diff --git a/binary/jce/gen/structs_parser.go b/binary/jce/gen/structs_parser.go new file mode 100644 index 00000000..b09c7cb2 --- /dev/null +++ b/binary/jce/gen/structs_parser.go @@ -0,0 +1,156 @@ +package main + +import ( + "flag" + "fmt" + "io" + "os" + "reflect" + "strconv" + "strings" + "sync" + + "github.com/Mrs4s/MiraiGo/binary/jce" +) + +const head = `// Code generated by structs_parser; DO NOT EDIT. + +package jce + +` + +func main() { + f := flag.String("f", "structs_tobytes.go", "output file.") + flag.Parse() + fmt.Println("gen runs on arg", *f) + w, err := os.Create(*f) + if err != nil { + panic(err) + } + w.WriteString(head) + WriteJceStruct(w, &jce.RequestPacket{}) + WriteJceStruct(w, &jce.RequestDataVersion3{}) + WriteJceStruct(w, &jce.RequestDataVersion2{}) + WriteJceStruct(w, &jce.SvcReqRegister{}) + WriteJceStruct(w, &jce.FriendListRequest{}) + WriteJceStruct(w, &jce.SummaryCardReq{}) + WriteJceStruct(w, &jce.SummaryCardReqSearch{}) + WriteJceStruct(w, &jce.TroopListRequest{}) + WriteJceStruct(w, &jce.TroopMemberListRequest{}) + WriteJceStruct(w, &jce.SvcRespPushMsg{}) + WriteJceStruct(w, &jce.ModifyGroupCardRequest{}) + WriteJceStruct(w, &jce.SvcReqGetDevLoginInfo{}) + WriteJceStruct(w, &jce.SvcReqRegisterNew{}) + WriteJceStruct(w, &jce.DelFriendReq{}) + err = w.Close() + if err != nil { + panic(err) + } +} + +func writeObject(w io.Writer, v reflect.Value, tag byte, name string) { + k := v.Kind() + if k == reflect.Map { + switch v.Interface().(type) { + case map[string]string: + w.Write([]byte(fmt.Sprintf("\tw.writeMapStrStr(pkt.%s, %d)\n", name, tag))) + case map[string][]byte: + w.Write([]byte(fmt.Sprintf("\tw.writeMapStrBytes(pkt.%s, %d)\n", name, tag))) + case map[string]map[string][]byte: + w.Write([]byte(fmt.Sprintf("\tw.writeMapStrMapStrBytes(pkt.%s, %d)\n", name, tag))) + default: + w.Write([]byte(fmt.Sprintf("\tw.writeMap(pkt.%s, %d)\n", name, tag))) + } + return + } + if k == reflect.Slice { + switch v.Interface().(type) { + case []byte: + w.Write([]byte(fmt.Sprintf("\tw.WriteBytes(pkt.%s, %d)\n", name, tag))) + case []int64: + w.Write([]byte(fmt.Sprintf("\tw.WriteInt64Slice(pkt.%s, %d)\n", name, tag))) + case [][]byte: + w.Write([]byte(fmt.Sprintf("\tw.WriteBytesSlice(pkt.%s, %d)\n", name, tag))) + case []jce.IJceStruct: + w.Write([]byte(fmt.Sprintf("\tw.WriteJceStructSlice(pkt.%s, %d)\n", name, tag))) + default: + w.Write([]byte(fmt.Sprintf("\tw.writeSlice(pkt.%s, %d)\n", name, tag))) + } + return + } + switch k { + case reflect.Uint8, reflect.Int8: + w.Write([]byte(fmt.Sprintf("\tw.WriteByte(byte(pkt.%s), %d)\n", name, tag))) + case reflect.Uint16, reflect.Int16: + w.Write([]byte(fmt.Sprintf("\tw.WriteInt16(int16(pkt.%s), %d)\n", name, tag))) + case reflect.Uint32, reflect.Int32: + w.Write([]byte(fmt.Sprintf("\tw.WriteInt32(int32(pkt.%s), %d)\n", name, tag))) + case reflect.Uint64, reflect.Int64: + w.Write([]byte(fmt.Sprintf("\tw.WriteInt64(int64(pkt.%s), %d)\n", name, tag))) + case reflect.String: + w.Write([]byte(fmt.Sprintf("\tw.WriteString(pkt.%s, %d)\n", name, tag))) + default: + switch v.Interface().(type) { + case jce.IJceStruct: + w.Write([]byte(fmt.Sprintf("\tw.WriteJceStruct(pkt.%s, %d)\n", name, tag))) + case float32: + w.Write([]byte(fmt.Sprintf("\tw.WriteFloat32(pkt.%s, %d)\n", name, tag))) + case float64: + w.Write([]byte(fmt.Sprintf("\tw.WriteFloat64(pkt.%s, %d)\n", name, tag))) + } + } +} + +type decoder struct { + index int + id int + name string +} + +var decoderCache = sync.Map{} + +// writeJceStructRaw 写入 Jce 结构体 +func writeJceStructRaw(w io.Writer, s interface{}) { + t := reflect.TypeOf(s) + if t.Kind() != reflect.Ptr { + return + } + t = t.Elem() + v := reflect.ValueOf(s).Elem() + var jceDec []decoder + dec, ok := decoderCache.Load(t) + if ok { // 从缓存中加载 + jceDec = dec.([]decoder) + } else { // 初次反射 + jceDec = make([]decoder, 0, t.NumField()) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + strId := field.Tag.Get("jceId") + if strId == "" { + continue + } + id, err := strconv.Atoi(strId) + if err != nil { + continue + } + jceDec = append(jceDec, decoder{ + index: i, + id: id, + name: field.Name, + }) + } + decoderCache.Store(t, jceDec) // 存入缓存 + } + for _, dec := range jceDec { + obj := v.Field(dec.index) + writeObject(w, obj, byte(dec.id), dec.name) + } +} + +func WriteJceStruct(w io.Writer, s jce.IJceStruct) { + w.Write([]byte(fmt.Sprintf("func (pkt %s) ToBytes() []byte {\n", strings.ReplaceAll(reflect.TypeOf(s).String(), "jce.", "")))) + w.Write([]byte("\tw := NewJceWriter()\n")) + writeJceStructRaw(w, s) + w.Write([]byte("\treturn w.Bytes()\n")) + w.Write([]byte("}\n\n")) +} diff --git a/binary/jce/reader_test.go b/binary/jce/reader_test.go index 7caec125..4bd4c7ea 100644 --- a/binary/jce/reader_test.go +++ b/binary/jce/reader_test.go @@ -2,6 +2,7 @@ package jce import ( "math/rand" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -15,7 +16,7 @@ func TestJceReader_ReadSlice(t *testing.T) { s[i] = b } w := NewJceWriter() - w.WriteObject(s, 1) + w.WriteBytesSlice(s, 1) r := NewJceReader(w.Bytes()) result := r.ReadByteArrArr(1) assert.Equal(t, s, result) @@ -64,7 +65,7 @@ var req = RequestDataVersion2{ func TestRequestDataVersion2_ReadFrom(t *testing.T) { // todo(wdv): fuzz test w := NewJceWriter() - w.WriteObject(req.Map, 0) + w.writeMapStrMapStrBytes(req.Map, 0) src := w.Bytes() result := RequestDataVersion2{} result.ReadFrom(NewJceReader(src)) @@ -73,7 +74,7 @@ func TestRequestDataVersion2_ReadFrom(t *testing.T) { func BenchmarkRequestDataVersion2_ReadFrom(b *testing.B) { w := NewJceWriter() - w.WriteObject(req.Map, 0) + w.writeMapStrMapStrBytes(req.Map, 0) src := w.Bytes() b.SetBytes(int64(len(src))) result := &RequestDataVersion2{} @@ -93,3 +94,39 @@ func TestJceReader_ReadBytes(t *testing.T) { assert.Equal(t, b, rb) } + +func (w *JceWriter) WriteObject(i interface{}, tag byte) { + t := reflect.TypeOf(i) + if t.Kind() == reflect.Map { + w.WriteMap(i, tag) + return + } + if t.Kind() == reflect.Slice { + if b, ok := i.([]byte); ok { + w.WriteBytes(b, tag) + return + } + w.WriteSlice(i, tag) + return + } + switch o := i.(type) { + case byte: + w.WriteByte(o, tag) + case bool: + w.WriteBool(o, tag) + case int16: + w.WriteInt16(o, tag) + case int32: + w.WriteInt32(o, tag) + case int64: + w.WriteInt64(o, tag) + case float32: + w.WriteFloat32(o, tag) + case float64: + w.WriteFloat64(o, tag) + case string: + w.WriteString(o, tag) + case IJceStruct: + w.WriteJceStruct(o, tag) + } +} diff --git a/binary/jce/structs.go b/binary/jce/structs.go index 1cb644e2..b686a5f4 100644 --- a/binary/jce/structs.go +++ b/binary/jce/structs.go @@ -5,6 +5,7 @@ type IJceStruct interface { ReadFrom(*JceReader) } +//go:generate go run gen/structs_parser.go -f structs_tobytes.go type ( RequestPacket struct { IVersion int16 `jceId:"1"` @@ -539,12 +540,6 @@ type ( } ) -func (pkt *RequestPacket) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - func (pkt *RequestPacket) ReadFrom(r *JceReader) { pkt.IVersion = r.ReadInt16(1) pkt.CPacketType = r.ReadByte(2) @@ -560,22 +555,10 @@ func (pkt *RequestPacket) ReadFrom(r *JceReader) { pkt.Status = r.ReadMapStrStr(10) } -func (pkt *RequestDataVersion3) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - func (pkt *RequestDataVersion3) ReadFrom(r *JceReader) { pkt.Map = r.ReadMapStrByte(0) } -func (pkt *RequestDataVersion2) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - func (pkt *RequestDataVersion2) ReadFrom(r *JceReader) { pkt.Map = r.ReadMapStrMapStrByte(0) } @@ -625,12 +608,6 @@ func (pkt *BigDataIPInfo) ReadFrom(r *JceReader) { pkt.Port = r.ReadInt64(2) } -func (pkt *SvcReqRegister) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - func (pkt *SvcRespRegister) ReadFrom(r *JceReader) { pkt.Uin = r.ReadInt64(0) pkt.Bid = r.ReadInt64(1) @@ -652,24 +629,6 @@ func (pkt *SvcRespRegister) ReadFrom(r *JceReader) { pkt.ExtOnlineStatus = r.ReadInt64(17) } -func (pkt *FriendListRequest) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - -func (pkt *SummaryCardReq) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - -func (pkt *SummaryCardReqSearch) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - func (pkt *FriendInfo) ReadFrom(r *JceReader) { pkt.FriendUin = r.ReadInt64(0) pkt.GroupId = r.ReadByte(1) @@ -683,12 +642,6 @@ func (pkt *FriendInfo) ReadFrom(r *JceReader) { pkt.CardID = r.ReadBytes(41) } -func (pkt *TroopListRequest) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - func (pkt *TroopNumber) ReadFrom(r *JceReader) { pkt.GroupUin = r.ReadInt64(0) pkt.GroupCode = r.ReadInt64(1) @@ -699,12 +652,6 @@ func (pkt *TroopNumber) ReadFrom(r *JceReader) { pkt.MaxGroupMemberNum = r.ReadInt64(29) } -func (pkt *TroopMemberListRequest) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - func (pkt *TroopMemberInfo) ReadFrom(r *JceReader) { pkt.MemberUin = r.ReadInt64(0) pkt.FaceId = r.ReadInt16(1) @@ -800,33 +747,3 @@ func (pkt *InstanceInfo) ReadFrom(r *JceReader) { pkt.ProductType = r.ReadInt64(3) pkt.ClientType = r.ReadInt64(4) } - -func (pkt *SvcRespPushMsg) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - -func (pkt *ModifyGroupCardRequest) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - -func (pkt *SvcReqGetDevLoginInfo) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - -func (pkt *SvcReqRegisterNew) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} - -func (pkt *DelFriendReq) ToBytes() []byte { - w := NewJceWriter() - w.WriteJceStructRaw(pkt) - return w.Bytes() -} diff --git a/binary/jce/structs_tobytes.go b/binary/jce/structs_tobytes.go new file mode 100644 index 00000000..4c172eac --- /dev/null +++ b/binary/jce/structs_tobytes.go @@ -0,0 +1,205 @@ +// Code generated by structs_parser; DO NOT EDIT. + +package jce + +func (pkt *RequestPacket) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt16(int16(pkt.IVersion), 1) + w.WriteByte(byte(pkt.CPacketType), 2) + w.WriteInt32(int32(pkt.IMessageType), 3) + w.WriteInt32(int32(pkt.IRequestId), 4) + w.WriteString(pkt.SServantName, 5) + w.WriteString(pkt.SFuncName, 6) + w.WriteBytes(pkt.SBuffer, 7) + w.WriteInt32(int32(pkt.ITimeout), 8) + w.writeMapStrStr(pkt.Context, 9) + w.writeMapStrStr(pkt.Status, 10) + return w.Bytes() +} + +func (pkt *RequestDataVersion3) ToBytes() []byte { + w := NewJceWriter() + w.writeMapStrBytes(pkt.Map, 0) + return w.Bytes() +} + +func (pkt *RequestDataVersion2) ToBytes() []byte { + w := NewJceWriter() + w.writeMapStrMapStrBytes(pkt.Map, 0) + return w.Bytes() +} + +func (pkt *SvcReqRegister) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt64(int64(pkt.Uin), 0) + w.WriteInt64(int64(pkt.Bid), 1) + w.WriteByte(byte(pkt.ConnType), 2) + w.WriteString(pkt.Other, 3) + w.WriteInt32(int32(pkt.Status), 4) + w.WriteByte(byte(pkt.OnlinePush), 5) + w.WriteByte(byte(pkt.IsOnline), 6) + w.WriteByte(byte(pkt.IsShowOnline), 7) + w.WriteByte(byte(pkt.KickPC), 8) + w.WriteByte(byte(pkt.KickWeak), 9) + w.WriteInt64(int64(pkt.Timestamp), 10) + w.WriteInt64(int64(pkt.IOSVersion), 11) + w.WriteByte(byte(pkt.NetType), 12) + w.WriteString(pkt.BuildVer, 13) + w.WriteByte(byte(pkt.RegType), 14) + w.WriteBytes(pkt.DevParam, 15) + w.WriteBytes(pkt.Guid, 16) + w.WriteInt32(int32(pkt.LocaleId), 17) + w.WriteByte(byte(pkt.SilentPush), 18) + w.WriteString(pkt.DevName, 19) + w.WriteString(pkt.DevType, 20) + w.WriteString(pkt.OSVer, 21) + w.WriteByte(byte(pkt.OpenPush), 22) + w.WriteInt64(int64(pkt.LargeSeq), 23) + w.WriteInt64(int64(pkt.LastWatchStartTime), 24) + w.WriteInt64(int64(pkt.OldSSOIp), 26) + w.WriteInt64(int64(pkt.NewSSOIp), 27) + w.WriteString(pkt.ChannelNo, 28) + w.WriteInt64(int64(pkt.CPID), 29) + w.WriteString(pkt.VendorName, 30) + w.WriteString(pkt.VendorOSName, 31) + w.WriteString(pkt.IOSIdfa, 32) + w.WriteBytes(pkt.B769, 33) + w.WriteByte(byte(pkt.IsSetStatus), 34) + w.WriteBytes(pkt.ServerBuf, 35) + w.WriteByte(byte(pkt.SetMute), 36) + w.WriteInt64(int64(pkt.ExtOnlineStatus), 38) + w.WriteInt32(int32(pkt.BatteryStatus), 39) + return w.Bytes() +} + +func (pkt *FriendListRequest) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt32(int32(pkt.Reqtype), 0) + w.WriteByte(byte(pkt.IfReflush), 1) + w.WriteInt64(int64(pkt.Uin), 2) + w.WriteInt16(int16(pkt.StartIndex), 3) + w.WriteInt16(int16(pkt.FriendCount), 4) + w.WriteByte(byte(pkt.GroupId), 5) + w.WriteByte(byte(pkt.IfGetGroupInfo), 6) + w.WriteByte(byte(pkt.GroupStartIndex), 7) + w.WriteByte(byte(pkt.GroupCount), 8) + w.WriteByte(byte(pkt.IfGetMSFGroup), 9) + w.WriteByte(byte(pkt.IfShowTermType), 10) + w.WriteInt64(int64(pkt.Version), 11) + w.WriteInt64Slice(pkt.UinList, 12) + w.WriteInt32(int32(pkt.AppType), 13) + w.WriteByte(byte(pkt.IfGetDOVId), 14) + w.WriteByte(byte(pkt.IfGetBothFlag), 15) + w.WriteBytes(pkt.D50, 16) + w.WriteBytes(pkt.D6B, 17) + w.WriteInt64Slice(pkt.SnsTypeList, 18) + return w.Bytes() +} + +func (pkt *SummaryCardReq) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt64(int64(pkt.Uin), 0) + w.WriteInt32(int32(pkt.ComeFrom), 1) + w.WriteInt64(int64(pkt.QzoneFeedTimestamp), 2) + w.WriteByte(byte(pkt.IsFriend), 3) + w.WriteInt64(int64(pkt.GroupCode), 4) + w.WriteInt64(int64(pkt.GroupUin), 5) + w.WriteInt64(int64(pkt.GetControl), 8) + w.WriteInt32(int32(pkt.AddFriendSource), 9) + w.WriteBytes(pkt.SecureSig, 10) + w.WriteBytesSlice(pkt.ReqServices, 14) + w.WriteInt64(int64(pkt.TinyId), 15) + w.WriteInt64(int64(pkt.LikeSource), 16) + w.WriteByte(byte(pkt.ReqMedalWallInfo), 18) + w.WriteInt64Slice(pkt.Req0x5ebFieldId, 19) + w.WriteByte(byte(pkt.ReqNearbyGodInfo), 20) + w.WriteByte(byte(pkt.ReqExtendCard), 22) + return w.Bytes() +} + +func (pkt *SummaryCardReqSearch) ToBytes() []byte { + w := NewJceWriter() + w.WriteString(pkt.Keyword, 0) + w.WriteString(pkt.CountryCode, 1) + w.WriteInt32(int32(pkt.Version), 2) + w.WriteBytesSlice(pkt.ReqServices, 3) + return w.Bytes() +} + +func (pkt *TroopListRequest) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt64(int64(pkt.Uin), 0) + w.WriteByte(byte(pkt.GetMSFMsgFlag), 1) + w.WriteBytes(pkt.Cookies, 2) + w.WriteInt64Slice(pkt.GroupInfo, 3) + w.WriteByte(byte(pkt.GroupFlagExt), 4) + w.WriteInt32(int32(pkt.Version), 5) + w.WriteInt64(int64(pkt.CompanyId), 6) + w.WriteInt64(int64(pkt.VersionNum), 7) + w.WriteByte(byte(pkt.GetLongGroupName), 8) + return w.Bytes() +} + +func (pkt *TroopMemberListRequest) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt64(int64(pkt.Uin), 0) + w.WriteInt64(int64(pkt.GroupCode), 1) + w.WriteInt64(int64(pkt.NextUin), 2) + w.WriteInt64(int64(pkt.GroupUin), 3) + w.WriteInt64(int64(pkt.Version), 4) + w.WriteInt64(int64(pkt.ReqType), 5) + w.WriteInt64(int64(pkt.GetListAppointTime), 6) + w.WriteByte(byte(pkt.RichCardNameVer), 7) + return w.Bytes() +} + +func (pkt *SvcRespPushMsg) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt64(int64(pkt.Uin), 0) + w.WriteJceStructSlice(pkt.DelInfos, 1) + w.WriteInt32(int32(pkt.Svrip), 2) + w.WriteBytes(pkt.PushToken, 3) + w.WriteInt32(int32(pkt.ServiceType), 4) + return w.Bytes() +} + +func (pkt *ModifyGroupCardRequest) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt64(int64(pkt.Zero), 0) + w.WriteInt64(int64(pkt.GroupCode), 1) + w.WriteInt64(int64(pkt.NewSeq), 2) + w.WriteJceStructSlice(pkt.UinInfo, 3) + return w.Bytes() +} + +func (pkt *SvcReqGetDevLoginInfo) ToBytes() []byte { + w := NewJceWriter() + w.WriteBytes(pkt.Guid, 0) + w.WriteString(pkt.AppName, 1) + w.WriteInt64(int64(pkt.LoginType), 2) + w.WriteInt64(int64(pkt.Timestamp), 3) + w.WriteInt64(int64(pkt.NextItemIndex), 4) + w.WriteInt64(int64(pkt.RequireMax), 5) + w.WriteInt64(int64(pkt.GetDevListType), 6) + return w.Bytes() +} + +func (pkt *SvcReqRegisterNew) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt64(int64(pkt.RequestOptional), 0) + w.WriteByte(byte(pkt.DisGroupMsgFilter), 14) + w.WriteByte(byte(pkt.GroupMask), 15) + w.WriteInt64(int64(pkt.EndSeq), 16) + w.WriteBytes(pkt.O769Body, 20) + return w.Bytes() +} + +func (pkt *DelFriendReq) ToBytes() []byte { + w := NewJceWriter() + w.WriteInt64(int64(pkt.Uin), 0) + w.WriteInt64(int64(pkt.DelUin), 1) + w.WriteByte(byte(pkt.DelType), 2) + w.WriteInt32(int32(pkt.Version), 3) + return w.Bytes() +} + diff --git a/binary/jce/writer.go b/binary/jce/writer.go index 6159d1be..d5251666 100644 --- a/binary/jce/writer.go +++ b/binary/jce/writer.go @@ -3,6 +3,7 @@ package jce import ( "bytes" goBinary "encoding/binary" + "math" "reflect" "strconv" "sync" @@ -16,18 +17,16 @@ func NewJceWriter() *JceWriter { return &JceWriter{buf: new(bytes.Buffer)} } -func (w *JceWriter) writeHead(t byte, tag int) { - if tag < 15 { - b := byte(tag<<4) | t - w.buf.WriteByte(b) - } else if tag < 256 { - b := 0xF0 | t - w.buf.WriteByte(b) - w.buf.WriteByte(byte(tag)) +func (w *JceWriter) writeHead(t, tag byte) { + if tag < 0xF { + w.buf.WriteByte(byte(tag<<4) | t) + } else { + w.buf.WriteByte(0xF0 | t) + w.buf.WriteByte(tag) } } -func (w *JceWriter) WriteByte(b byte, tag int) *JceWriter { +func (w *JceWriter) WriteByte(b, tag byte) *JceWriter { if b == 0 { w.writeHead(12, tag) } else { @@ -37,7 +36,7 @@ func (w *JceWriter) WriteByte(b byte, tag int) *JceWriter { return w } -func (w *JceWriter) WriteBool(b bool, tag int) { +func (w *JceWriter) WriteBool(b bool, tag byte) { var by byte = 0 if b { by = 1 @@ -45,7 +44,7 @@ func (w *JceWriter) WriteBool(b bool, tag int) { w.WriteByte(by, tag) } -func (w *JceWriter) WriteInt16(n int16, tag int) { +func (w *JceWriter) WriteInt16(n int16, tag byte) { switch { case n >= -128 && n <= 127: w.WriteByte(byte(n), tag) @@ -55,14 +54,14 @@ func (w *JceWriter) WriteInt16(n int16, tag int) { } //go:nosplit -func (w *JceWriter) putInt16(n int16, tag int) { +func (w *JceWriter) putInt16(n int16, tag byte) { w.writeHead(1, tag) var buf [2]byte goBinary.BigEndian.PutUint16(buf[:], uint16(n)) w.buf.Write(buf[:]) } -func (w *JceWriter) WriteInt32(n int32, tag int) *JceWriter { +func (w *JceWriter) WriteInt32(n int32, tag byte) *JceWriter { switch { case n >= -128 && n <= 127: w.WriteByte(byte(n), tag) @@ -75,14 +74,14 @@ func (w *JceWriter) WriteInt32(n int32, tag int) *JceWriter { } //go:nosplit -func (w *JceWriter) putInt32(n int32, tag int) { +func (w *JceWriter) putInt32(n int32, tag byte) { w.writeHead(2, tag) var buf [4]byte goBinary.BigEndian.PutUint32(buf[:], uint32(n)) w.buf.Write(buf[:]) } -func (w *JceWriter) WriteInt64(n int64, tag int) *JceWriter { +func (w *JceWriter) WriteInt64(n int64, tag byte) *JceWriter { switch { case n >= -128 && n <= 127: w.WriteByte(byte(n), tag) @@ -97,49 +96,56 @@ func (w *JceWriter) WriteInt64(n int64, tag int) *JceWriter { } //go:nosplit -func (w *JceWriter) putInt64(n int64, tag int) { +func (w *JceWriter) putInt64(n int64, tag byte) { w.writeHead(3, tag) var buf [8]byte goBinary.BigEndian.PutUint64(buf[:], uint64(n)) w.buf.Write(buf[:]) } -func (w *JceWriter) WriteFloat32(n float32, tag int) { +//go:nosplit +func (w *JceWriter) WriteFloat32(n float32, tag byte) { w.writeHead(4, tag) - _ = goBinary.Write(w.buf, goBinary.BigEndian, n) + var buf [4]byte + goBinary.BigEndian.PutUint32(buf[:], math.Float32bits(n)) + w.buf.Write(buf[:]) } -func (w *JceWriter) WriteFloat64(n float64, tag int) { +//go:nosplit +func (w *JceWriter) WriteFloat64(n float64, tag byte) { w.writeHead(5, tag) - _ = goBinary.Write(w.buf, goBinary.BigEndian, n) + var buf [8]byte + goBinary.BigEndian.PutUint64(buf[:], math.Float64bits(n)) + w.buf.Write(buf[:]) } -func (w *JceWriter) WriteString(s string, tag int) *JceWriter { - by := []byte(s) - if len(by) > 255 { +func (w *JceWriter) WriteString(s string, tag byte) *JceWriter { + if len(s) > 255 { w.writeHead(7, tag) - _ = goBinary.Write(w.buf, goBinary.BigEndian, int32(len(by))) - w.buf.Write(by) + var buf [4]byte + goBinary.BigEndian.PutUint32(buf[:], uint32(len(s))) + w.buf.Write(buf[:]) + w.buf.WriteString(s) return w } w.writeHead(6, tag) - w.buf.WriteByte(byte(len(by))) - w.buf.Write(by) + w.buf.WriteByte(byte(len(s))) + w.buf.WriteString(s) return w } -func (w *JceWriter) WriteBytes(l []byte, tag int) *JceWriter { +func (w *JceWriter) WriteBytes(l []byte, tag byte) *JceWriter { w.writeHead(13, tag) - w.writeHead(0, 0) + w.buf.WriteByte(0) // w.writeHead(0, 0) w.WriteInt32(int32(len(l)), 0) w.buf.Write(l) return w } -func (w *JceWriter) WriteInt64Slice(l []int64, tag int) { +func (w *JceWriter) WriteInt64Slice(l []int64, tag byte) { w.writeHead(9, tag) if len(l) == 0 { - w.WriteInt32(0, 0) + w.buf.WriteByte(0) // w.WriteInt32(0, 0) return } w.WriteInt32(int32(len(l)), 0) @@ -148,7 +154,19 @@ func (w *JceWriter) WriteInt64Slice(l []int64, tag int) { } } -func (w *JceWriter) WriteSlice(i interface{}, tag int) { +func (w *JceWriter) WriteBytesSlice(l [][]byte, tag byte) { + w.writeHead(9, tag) + if len(l) == 0 { + w.buf.WriteByte(0) // w.WriteInt32(0, 0) + return + } + w.WriteInt32(int32(len(l)), 0) + for _, v := range l { + w.WriteBytes(v, 0) + } +} + +func (w *JceWriter) WriteSlice(i interface{}, tag byte) { va := reflect.ValueOf(i) if va.Kind() != reflect.Slice { panic("JceWriter.WriteSlice: not a slice") @@ -156,13 +174,13 @@ func (w *JceWriter) WriteSlice(i interface{}, tag int) { w.writeSlice(va, tag) } -func (w *JceWriter) writeSlice(slice reflect.Value, tag int) { +func (w *JceWriter) writeSlice(slice reflect.Value, tag byte) { if slice.Kind() != reflect.Slice { return } w.writeHead(9, tag) if slice.Len() == 0 { - w.WriteInt32(0, 0) + w.buf.WriteByte(0) // w.WriteInt32(0, 0) return } w.WriteInt32(int32(slice.Len()), 0) @@ -172,10 +190,10 @@ func (w *JceWriter) writeSlice(slice reflect.Value, tag int) { } } -func (w *JceWriter) WriteJceStructSlice(l []IJceStruct, tag int) { +func (w *JceWriter) WriteJceStructSlice(l []IJceStruct, tag byte) { w.writeHead(9, tag) if len(l) == 0 { - w.WriteInt32(0, 0) + w.buf.WriteByte(0) // w.WriteInt32(0, 0) return } w.WriteInt32(int32(len(l)), 0) @@ -184,7 +202,7 @@ func (w *JceWriter) WriteJceStructSlice(l []IJceStruct, tag int) { } } -func (w *JceWriter) WriteMap(m interface{}, tag int) { +func (w *JceWriter) WriteMap(m interface{}, tag byte) { va := reflect.ValueOf(m) if va.Kind() != reflect.Map { panic("JceWriter.WriteMap: not a map") @@ -192,10 +210,10 @@ func (w *JceWriter) WriteMap(m interface{}, tag int) { w.writeMap(va, tag) } -func (w *JceWriter) writeMap(m reflect.Value, tag int) { +func (w *JceWriter) writeMap(m reflect.Value, tag byte) { if m.IsNil() { w.writeHead(8, tag) - w.WriteInt32(0, 0) + w.buf.WriteByte(0) // w.WriteInt32(0, 0) return } if m.Kind() != reflect.Map { @@ -210,54 +228,72 @@ func (w *JceWriter) writeMap(m reflect.Value, tag int) { } } -func (w *JceWriter) WriteObject(i interface{}, tag int) { - t := reflect.TypeOf(i) - if t.Kind() == reflect.Map { - w.WriteMap(i, tag) +func (w *JceWriter) writeMapStrStr(m map[string]string, tag byte) { + if m == nil { + w.writeHead(8, tag) + w.buf.WriteByte(0) // w.WriteInt32(0, 0) return } - if t.Kind() == reflect.Slice { - if b, ok := i.([]byte); ok { - w.WriteBytes(b, tag) - return - } - w.WriteSlice(i, tag) - return - } - switch o := i.(type) { - case byte: - w.WriteByte(o, tag) - case bool: - w.WriteBool(o, tag) - case int16: - w.WriteInt16(o, tag) - case int32: - w.WriteInt32(o, tag) - case int64: - w.WriteInt64(o, tag) - case float32: - w.WriteFloat32(o, tag) - case float64: - w.WriteFloat64(o, tag) - case string: - w.WriteString(o, tag) - case IJceStruct: - w.WriteJceStruct(o, tag) + w.writeHead(8, tag) + w.WriteInt32(int32(len(m)), 0) + for k, v := range m { + w.WriteString(k, 0) + w.WriteString(v, 1) } } -func (w *JceWriter) writeObject(v reflect.Value, tag int) { +func (w *JceWriter) writeMapStrBytes(m map[string][]byte, tag byte) { + if m == nil { + w.writeHead(8, tag) + w.buf.WriteByte(0) // w.WriteInt32(0, 0) + return + } + w.writeHead(8, tag) + w.WriteInt32(int32(len(m)), 0) + for k, v := range m { + w.WriteString(k, 0) + w.WriteBytes(v, 1) + } +} + +func (w *JceWriter) writeMapStrMapStrBytes(m map[string]map[string][]byte, tag byte) { + if m == nil { + w.writeHead(8, tag) + w.buf.WriteByte(0) // w.WriteInt32(0, 0) + return + } + w.writeHead(8, tag) + w.WriteInt32(int32(len(m)), 0) + for k, v := range m { + w.WriteString(k, 0) + w.writeMapStrBytes(v, 1) + } +} + +func (w *JceWriter) writeObject(v reflect.Value, tag byte) { k := v.Kind() if k == reflect.Map { - w.writeMap(v, tag) + switch o := v.Interface().(type) { + case map[string]string: + w.writeMapStrStr(o, tag) + case map[string][]byte: + w.writeMapStrBytes(o, tag) + case map[string]map[string][]byte: + w.writeMapStrMapStrBytes(o, tag) + default: + w.writeMap(v, tag) + } return } if k == reflect.Slice { - if v.Type().Elem().Kind() == reflect.Uint8 { - w.WriteBytes(v.Bytes(), tag) - return + switch o := v.Interface().(type) { + case []byte: + w.WriteBytes(o, tag) + case []IJceStruct: + w.WriteJceStructSlice(o, tag) + default: + w.writeSlice(v, tag) } - w.writeSlice(v, tag) return } switch k { @@ -293,7 +329,6 @@ var decoderCache = sync.Map{} // WriteJceStructRaw 写入 Jce 结构体 func (w *JceWriter) WriteJceStructRaw(s interface{}) { t := reflect.TypeOf(s) - reflect.ValueOf(s).Interface() if t.Kind() != reflect.Ptr { return } @@ -324,11 +359,11 @@ func (w *JceWriter) WriteJceStructRaw(s interface{}) { } for _, dec := range jceDec { obj := v.Field(dec.index) - w.writeObject(obj, dec.id) + w.writeObject(obj, byte(dec.id)) } } -func (w *JceWriter) WriteJceStruct(s IJceStruct, tag int) { +func (w *JceWriter) WriteJceStruct(s IJceStruct, tag byte) { w.writeHead(10, tag) w.WriteJceStructRaw(s) w.writeHead(11, 0) diff --git a/binary/jce/writer_test.go b/binary/jce/writer_test.go index 1eafdd00..2a2d85bb 100644 --- a/binary/jce/writer_test.go +++ b/binary/jce/writer_test.go @@ -12,7 +12,7 @@ func BenchmarkJceWriter_WriteMap(b *testing.B) { var x = globalBytes for i := 0; i < b.N; i++ { w := NewJceWriter() - w.WriteMap(req.Map, 0) + w.writeMapStrMapStrBytes(req.Map, 0) x = w.Bytes() } globalBytes = x @@ -41,18 +41,14 @@ var reqPacket1 = &RequestPacket{ func BenchmarkJceWriter_WriteJceStructRaw(b *testing.B) { var x = globalBytes for i := 0; i < b.N; i++ { - w := NewJceWriter() - w.WriteJceStructRaw(reqPacket1) - x = w.Bytes() + _ = reqPacket1.ToBytes() } globalBytes = x b.SetBytes(int64(len(globalBytes))) } func TestJceWriter_WriteJceStructRaw(t *testing.T) { - w := NewJceWriter() - w.WriteJceStructRaw(reqPacket1) - r := NewJceReader(w.Bytes()) + r := NewJceReader(reqPacket1.ToBytes()) var reqPacket2 RequestPacket reqPacket2.ReadFrom(r) assert.Equal(t, reqPacket1, &reqPacket2)