diff --git a/binary/pool.go b/binary/pool.go index d0d7a244..645d10a4 100644 --- a/binary/pool.go +++ b/binary/pool.go @@ -32,30 +32,25 @@ func PutBuffer(w *Writer) { } } -type gzipWriter struct { - w *gzip.Writer - buf *bytes.Buffer -} - var gzipPool = sync.Pool{ New: func() interface{} { buf := new(bytes.Buffer) w := gzip.NewWriter(buf) - return &gzipWriter{ + return &GzipWriter{ w: w, buf: buf, } }, } -func acquireGzipWriter() *gzipWriter { - ret := gzipPool.Get().(*gzipWriter) +func AcquireGzipWriter() *GzipWriter { + ret := gzipPool.Get().(*GzipWriter) ret.buf.Reset() ret.w.Reset(ret.buf) return ret } -func releaseGzipWriter(w *gzipWriter) { +func ReleaseGzipWriter(w *GzipWriter) { // See https://golang.org/issue/23199 const maxSize = 1 << 16 if w.buf.Cap() < maxSize { diff --git a/binary/reader.go b/binary/reader.go index c64d1958..0e76a8fb 100644 --- a/binary/reader.go +++ b/binary/reader.go @@ -117,10 +117,8 @@ func (r *Reader) Len() int { } func (tlv TlvMap) Exists(key uint16) bool { - if _, ok := tlv[key]; ok { - return true - } - return false + _, ok := tlv[key] + return ok } // --- Network reader --- @@ -144,13 +142,13 @@ func (r *NetworkReader) ReadByte() (byte, error) { func (r *NetworkReader) ReadBytes(len int) ([]byte, error) { buf := make([]byte, len) _, err := io.ReadFull(r.conn, buf) - //for i := 0; i < len; i++ { - // b, err := r.ReadByte() - // if err != nil { + // for i := 0; i < len; i++ { + // b, err := r.ReadByte() + // if err != nil { // return nil, err - // } - // buf[i] = b - //} + // } + // buf[i] = b + // } return buf, err } diff --git a/binary/utils.go b/binary/utils.go index 29cc14ef..999aa19b 100644 --- a/binary/utils.go +++ b/binary/utils.go @@ -12,6 +12,23 @@ import ( "github.com/Mrs4s/MiraiGo/utils" ) +type GzipWriter struct { + w *gzip.Writer + buf *bytes.Buffer +} + +func (w *GzipWriter) Write(p []byte) (int, error) { + return w.w.Write(p) +} + +func (w *GzipWriter) Close() error { + return w.w.Close() +} + +func (w *GzipWriter) Bytes() []byte { + return w.buf.Bytes() +} + func ZlibUncompress(src []byte) []byte { b := bytes.NewReader(src) var out bytes.Buffer @@ -31,11 +48,11 @@ func ZlibCompress(data []byte) []byte { } func GZipCompress(data []byte) []byte { - gw := acquireGzipWriter() - _, _ = gw.w.Write(data) - _ = gw.w.Close() + gw := AcquireGzipWriter() + _, _ = gw.Write(data) + _ = gw.Close() ret := append([]byte(nil), gw.buf.Bytes()...) - releaseGzipWriter(gw) + ReleaseGzipWriter(gw) return ret } diff --git a/binary/writer.go b/binary/writer.go index acb54b58..dcd53900 100644 --- a/binary/writer.go +++ b/binary/writer.go @@ -79,20 +79,24 @@ func (w *Writer) WriteIntLvPacket(offset int, f func(writer *Writer)) { } func (w *Writer) WriteUniPacket(commandName string, sessionId, extraData, body []byte) { - w.WriteIntLvPacket(4, func(w *Writer) { - w.WriteString(commandName) - w.WriteUInt32(8) - w.Write(sessionId) + w1 := NewWriter() + { // WriteIntLvPacket + w1.WriteString(commandName) + w1.WriteUInt32(8) + w1.Write(sessionId) if len(extraData) == 0 { - w.WriteUInt32(0x04) + w1.WriteUInt32(0x04) } else { - w.WriteUInt32(uint32(len(extraData) + 4)) - w.Write(extraData) + w1.WriteUInt32(uint32(len(extraData) + 4)) + w1.Write(extraData) } - }) - w.WriteIntLvPacket(4, func(w *Writer) { - w.Write(body) - }) + } + data := w1.Bytes() + w.WriteUInt32(uint32(len(data) + 4)) + w.Write(data) + PutBuffer(w1) + w.WriteUInt32(uint32(len(body) + 4)) // WriteIntLvPacket + w.Write(body) } func (w *Writer) WriteBytesShort(data []byte) { diff --git a/protocol/packets/builders.go b/protocol/packets/builders.go index 1b4789e9..9d8325cf 100644 --- a/protocol/packets/builders.go +++ b/protocol/packets/builders.go @@ -27,15 +27,23 @@ func BuildLoginPacket(uin int64, bodyType byte, key, body, extraData []byte) []b func BuildUniPacket(uin int64, seq uint16, commandName string, encryptType byte, sessionID, extraData, key, body []byte) []byte { return binary.NewWriterF(func(w *binary.Writer) { - w.WriteIntLvPacket(4, func(w *binary.Writer) { - w.WriteUInt32(0x0B) - w.WriteByte(encryptType) - w.WriteUInt32(uint32(seq)) - w.WriteByte(0) - w.WriteString(strconv.FormatInt(uin, 10)) - w.EncryptAndWrite(key, binary.NewWriterF(func(w *binary.Writer) { - w.WriteUniPacket(commandName, sessionID, extraData, body) - })) - }) + w2 := binary.NewWriter() + { // w.WriteIntLvPacket + w2.WriteUInt32(0x0B) + w2.WriteByte(encryptType) + w2.WriteUInt32(uint32(seq)) + w2.WriteByte(0) + w2.WriteString(strconv.FormatInt(uin, 10)) + + // inline NewWriterF + w3 := binary.NewWriter() + w3.WriteUniPacket(commandName, sessionID, extraData, body) + w2.EncryptAndWrite(key, w3.Bytes()) + binary.PutBuffer(w3) + } + data := w2.Bytes() + w.WriteUInt32(uint32(len(data) + 4)) + w.Write(data) + binary.PutBuffer(w2) }) } diff --git a/protocol/packets/global.go b/protocol/packets/global.go index 48377f5b..861be889 100644 --- a/protocol/packets/global.go +++ b/protocol/packets/global.go @@ -3,9 +3,10 @@ package packets import ( "strconv" + "github.com/pkg/errors" + "github.com/Mrs4s/MiraiGo/binary" "github.com/Mrs4s/MiraiGo/protocol/crypto" - "github.com/pkg/errors" ) var (