diff --git a/binary/jce/reader.go b/binary/jce/reader.go index d1609e06..9ff7f626 100644 --- a/binary/jce/reader.go +++ b/binary/jce/reader.go @@ -2,6 +2,7 @@ package jce import ( "bytes" + goBinary "encoding/binary" "math" "reflect" ) @@ -113,12 +114,12 @@ func (r *JceReader) readUInt16() uint16 { func (r *JceReader) readInt32() int32 { b := r.readBytes(4) - return (int32(b[0]) << 24) | (int32(b[1]) << 16) | (int32(b[2]) << 8) | int32(b[3]) + return int32(goBinary.BigEndian.Uint32(b)) } func (r *JceReader) readInt64() int64 { b := r.readBytes(8) - return (int64(b[0]) << 56) | (int64(b[1]) << 48) | (int64(b[2]) << 40) | (int64(b[3]) << 32) | (int64(b[4]) << 24) | (int64(b[5]) << 16) | (int64(b[6]) << 8) | int64(b[7]) + return int64(goBinary.BigEndian.Uint64(b)) } func (r *JceReader) readFloat32() float32 { @@ -334,7 +335,32 @@ func (r *JceReader) ReadJceStruct(obj IJceStruct, tag int) { r.skipToStructEnd() } -func (r *JceReader) ReadMapF(tag int, f func(interface{}, interface{})) { +func (r *JceReader) ReadMap(i interface{}, tag int) { + v := reflect.ValueOf(i) + if v.Kind() != reflect.Map { + return + } + if !r.skipToTag(tag) { + return + } + t := v.Type() + + kt := t.Key() + kv := reflect.New(kt) + + vt := t.Elem() + vv := reflect.New(vt) + + r.readHead() + s := r.ReadInt32(0) + for i := 0; i < int(s); i++ { + r.ReadObject(kv.Interface(), 0) + r.ReadObject(vv.Interface(), 1) + v.SetMapIndex(kv.Elem(), vv.Elem()) + } +} + +func (r *JceReader) _ReadMapF(tag int, f func(interface{}, interface{})) { if !r.skipToTag(tag) { return } @@ -408,6 +434,11 @@ func (r *JceReader) ReadObject(i interface{}, tag int) { if va.Kind() != reflect.Ptr || va.IsNil() { return } + if ve := va.Elem(); ve.Kind() == reflect.Map { + ve.Set(reflect.MakeMap(ve.Type())) + r.ReadMap(ve.Interface(), tag) + return + } switch o := i.(type) { case *byte: *o = r.ReadByte(tag) diff --git a/binary/jce/reader_test.go b/binary/jce/reader_test.go index f3d2846a..66881f4e 100644 --- a/binary/jce/reader_test.go +++ b/binary/jce/reader_test.go @@ -41,3 +41,43 @@ func BenchmarkJceReader_ReadSlice(b *testing.B) { r.ReadSlice(&result, 1) } } + +var req = RequestDataVersion2{ + Map: map[string]map[string][]byte{ + "1": { + "123": []byte(`123`), + }, + "2": { + "123": []byte(`123`), + }, + "3": { + "123": []byte(`123`), + }, + "4": { + "123": []byte(`123`), + }, + "5": { + "123": []byte(`123`), + }, + }} + +func TestRequestDataVersion2_ReadFrom(t *testing.T) { + // todo(wdv): fuzz test + w := NewJceWriter() + w.WriteObject(req.Map, 0) + src := w.Bytes() + result := RequestDataVersion2{} + result.ReadFrom(NewJceReader(src)) + assert.Equal(t, req, result) +} + +func BenchmarkRequestDataVersion2_ReadFrom(b *testing.B) { + w := NewJceWriter() + w.WriteObject(req.Map, 0) + src := w.Bytes() + b.SetBytes(int64(len(src))) + result := &RequestDataVersion2{} + for i := 0; i < b.N; i++ { + result.ReadFrom(NewJceReader(src)) + } +} diff --git a/binary/jce/structs.go b/binary/jce/structs.go index 9583912d..ed0669d5 100644 --- a/binary/jce/structs.go +++ b/binary/jce/structs.go @@ -557,8 +557,8 @@ func (pkt *RequestPacket) ReadFrom(r *JceReader) { pkt.SFuncName = r.ReadString(6) r.ReadSlice(&pkt.SBuffer, 7) pkt.ITimeout = r.ReadInt32(8) - r.ReadMapF(9, func(k interface{}, v interface{}) { pkt.Context[k.(string)] = v.(string) }) - r.ReadMapF(10, func(k interface{}, v interface{}) { pkt.Status[k.(string)] = v.(string) }) + r.ReadMap(pkt.Context, 9) + r.ReadMap(pkt.Status, 10) } func (pkt *RequestDataVersion3) ToBytes() []byte { @@ -569,9 +569,7 @@ func (pkt *RequestDataVersion3) ToBytes() []byte { func (pkt *RequestDataVersion3) ReadFrom(r *JceReader) { pkt.Map = make(map[string][]byte) - r.ReadMapF(0, func(k interface{}, v interface{}) { - pkt.Map[k.(string)] = v.([]byte) - }) + r.ReadMap(pkt.Map, 0) } func (pkt *RequestDataVersion2) ToBytes() []byte { @@ -582,7 +580,8 @@ func (pkt *RequestDataVersion2) ToBytes() []byte { func (pkt *RequestDataVersion2) ReadFrom(r *JceReader) { pkt.Map = make(map[string]map[string][]byte) - r.ReadMapF(0, func(k interface{}, v interface{}) { + // r.ReadMap(pkt.Map, 0) + r._ReadMapF(0, func(k interface{}, v interface{}) { pkt.Map[k.(string)] = make(map[string][]byte) for k2, v := range v.(map[interface{}]interface{}) { pkt.Map[k.(string)][k2.(string)] = v.([]byte)