diff --git a/binary/jce/reader.go b/binary/jce/reader.go index 7d2f0a7e..e1f88fb3 100644 --- a/binary/jce/reader.go +++ b/binary/jce/reader.go @@ -1,15 +1,16 @@ package jce import ( - "bytes" goBinary "encoding/binary" - "io" "math" "reflect" + + "github.com/Mrs4s/MiraiGo/utils" ) type JceReader struct { - buf *bytes.Reader + buf []byte + off int } type HeadData struct { @@ -18,25 +19,26 @@ type HeadData struct { } func NewJceReader(data []byte) *JceReader { - buf := bytes.NewReader(data) - return &JceReader{buf: buf} + return &JceReader{buf: data} } func (r *JceReader) readHead() (hd HeadData, l int32) { - b, _ := r.buf.ReadByte() + b := r.buf[r.off] hd.Type = b & 0xF hd.Tag = (int(b) & 0xF0) >> 4 + l = 1 if hd.Tag == 15 { - b, _ = r.buf.ReadByte() + b = r.buf[r.off+1] hd.Tag = int(b) & 0xFF - return hd, 2 + l = 2 } - return hd, 1 + r.off += int(l) + return } func (r *JceReader) peakHead() (h HeadData, l int32) { h, l = r.readHead() - _, _ = r.buf.Seek(int64(-l), io.SeekCurrent) + r.off -= int(l) return } @@ -55,7 +57,7 @@ func (r *JceReader) skipField(t byte) { case 3, 5: r.skip(8) case 6: - b, _ := r.buf.ReadByte() + b := r.readByte() r.skip(int(b)) case 7: r.skip(int(r.readInt32())) @@ -89,33 +91,28 @@ func (r *JceReader) SkipField(c int) { } } -func (r *JceReader) readBytes(len int) []byte { - if len == 0 { - return nil - } - b := make([]byte, len) - _, err := r.buf.Read(b) - if err != nil { - panic(err) +func (r *JceReader) readBytes(n int) []byte { + if r.off+n > len(r.buf) { + panic("readBytes: EOF") } + b := make([]byte, n) + n = copy(b, r.buf[r.off:]) + r.off += n return b } func (r *JceReader) readByte() byte { - b, err := r.buf.ReadByte() - if err != nil { - panic(err) + if r.off >= len(r.buf) { + panic("readByte: EOF") } + b := r.buf[r.off] + r.off++ return b } func (r *JceReader) readUInt16() uint16 { - f, _ := r.buf.ReadByte() - s, err := r.buf.ReadByte() - if err != nil { - panic(err) - } - return uint16((int32(f) << 8) + int32(s)) + b := r.readBytes(2) + return uint16((int32(b[0]) << 8) + int32(b[1])) } func (r *JceReader) readInt32() int32 { @@ -274,9 +271,9 @@ func (r *JceReader) ReadString(tag int) string { hd, _ := r.readHead() switch hd.Type { case 6: - return string(r.readBytes(int(r.readByte()))) + return utils.B2S(r.readBytes(int(r.readByte()))) case 7: - return string(r.readBytes(int(r.readInt32()))) + return utils.B2S(r.readBytes(int(r.readInt32()))) default: return "" } @@ -364,35 +361,60 @@ func (r *JceReader) ReadJceStruct(obj IJceStruct, tag int) { func (r *JceReader) ReadMap(i interface{}, tag int) { v := reflect.ValueOf(i) - if v.Kind() != reflect.Map { - return - } - if !r.skipToTag(tag) { + r.readMap(v, tag) +} + +func (r *JceReader) readMap(v reflect.Value, tag int) { + if v.Kind() != reflect.Map || !r.skipToTag(tag) { return } t := v.Type() kt := t.Key() - kv := reflect.New(kt) - vt := t.Elem() - vv := reflect.New(vt) - r.readHead() s := r.ReadInt32(0) + + // map with string key or string value is very common. + // specialize for string + if kt.Kind() == reflect.String && vt.Kind() == reflect.String { + for i := 0; i < int(s); i++ { + kv := reflect.ValueOf(r.ReadString(0)) + vv := reflect.ValueOf(r.ReadString(1)) + v.SetMapIndex(kv, vv) + } + return + } + + if kt.Kind() == reflect.String { + vv := reflect.New(vt) + for i := 0; i < int(s); i++ { + kv := reflect.ValueOf(r.ReadString(0)) + r.readObject(vv, 1) + v.SetMapIndex(kv, vv.Elem()) + } + return + } + + kv := reflect.New(kt) + vv := reflect.New(vt) for i := 0; i < int(s); i++ { - r.ReadObject(kv.Interface(), 0) - r.ReadObject(vv.Interface(), 1) + r.readObject(kv, 0) + r.readObject(vv, 1) v.SetMapIndex(kv.Elem(), vv.Elem()) } } func (r *JceReader) ReadSlice(i interface{}, tag int) { - t := reflect.TypeOf(i) - v := reflect.ValueOf(i).Elem() + r.readSlice(reflect.ValueOf(i), tag) +} + +func (r *JceReader) readSlice(v reflect.Value, tag int) { + t := v.Type() if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Slice { return } + v = v.Elem() t = t.Elem() if !r.skipToTag(tag) { return @@ -404,7 +426,7 @@ func (r *JceReader) ReadSlice(i interface{}, tag int) { t = t.Elem() val := reflect.New(t) for i := 0; i < int(s); i++ { - r.ReadObject(val.Interface(), 0) + r.readObject(val, 0) sv.Index(i).Set(val.Elem()) } v.Set(sv) @@ -417,43 +439,52 @@ func (r *JceReader) ReadSlice(i interface{}, tag int) { } func (r *JceReader) ReadObject(i interface{}, tag int) { - va := reflect.ValueOf(i) - if va.Kind() != reflect.Ptr || va.IsNil() { - return - } - if ve := va.Elem(); ve.Kind() == reflect.Map { - ve.Set(reflect.MakeMap(ve.Type())) - r.ReadMap(ve.Interface(), tag) - return - } - switch o := i.(type) { - case *byte: - *o = r.ReadByte(tag) - case *bool: - *o = r.ReadBool(tag) - case *int16: - *o = r.ReadInt16(tag) - case *int: - *o = int(r.ReadInt32(tag)) - case *int32: - *o = r.ReadInt32(tag) - case *int64: - *o = r.ReadInt64(tag) - case *float32: - *o = r.ReadFloat32(tag) - case *float64: - *o = r.ReadFloat64(tag) - case *string: - *o = r.ReadString(tag) - case *[]byte: - r.ReadSlice(o, tag) - case IJceStruct: - r.readHead() - o.ReadFrom(r) - r.skipToStructEnd() - } + v := reflect.ValueOf(i) + r.readObject(v, tag) } -func (r *JceReader) ReadAvailable() []byte { - return r.readBytes(r.buf.Len()) +func (r *JceReader) readObject(v reflect.Value, tag int) { + if v.Kind() != reflect.Ptr || v.IsNil() { + return + } + elemType := v.Type().Elem() + if elemType.Kind() == reflect.Map { + elem := v.Elem() + elem.Set(reflect.MakeMap(elem.Type())) + r.readMap(elem, tag) + return + } else if elemType.Kind() == reflect.Slice && // *[]byte + elemType.Elem().Kind() == reflect.Uint8 { + elem := v.Elem() + elem.SetBytes(r.ReadBytes(tag)) + return + } + + switch elemType.Kind() { + case reflect.Uint8, reflect.Int8: + *(*uint8)(pointerOf(v)) = r.ReadByte(tag) + case reflect.Bool: + *(*bool)(pointerOf(v)) = r.ReadBool(tag) + case reflect.Uint16, reflect.Int16: + *(*int16)(pointerOf(v)) = r.ReadInt16(tag) + case reflect.Uint32, reflect.Int32: + *(*int32)(pointerOf(v)) = r.ReadInt32(tag) + case reflect.Uint64, reflect.Int64: + *(*int64)(pointerOf(v)) = r.ReadInt64(tag) + case reflect.String: + *(*string)(pointerOf(v)) = r.ReadString(tag) + + default: + // other cases + switch o := v.Interface().(type) { + case IJceStruct: + r.readHead() + o.ReadFrom(r) + r.skipToStructEnd() + case *float32: + *o = r.ReadFloat32(tag) + case *float64: + *o = r.ReadFloat64(tag) + } + } }