From f5b16b19c23c30fe8e25982a4922e0abb058ccf7 Mon Sep 17 00:00:00 2001 From: LXY <767763591@qq.com> Date: Mon, 27 Dec 2021 12:46:56 +0800 Subject: [PATCH] Network refactor (#229) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 重构:netLoop下移到Listener * 格式:修正NewClient中注释段缩进 * 格式:更名&将代码移动到对应位置 * 格式:TCPListener去锁化 * 修正:netLoop中的错误调用 * 修正:使其可用 * 修正:使功能一致 * 修正:现在可正常运行 * 优化:更早的释放锁(? * 修正:未写完的部分 * 修正:潜在的断线时仍然认为在线这件事&删除空重复文件 * 文档:添加部分注释 * 修正:CoverError可能引起死锁 * 修正:永远不会被触发的DisconnectEvent * 文档:将注释移动至对应位置 --- client/client.go | 171 +++++++++++++++++---- client/internal/network/conn.go | 141 +++++++++++------ client/internal/network/response.go | 1 + client/internal/network/transport.go | 83 +++++++++- client/network.go | 220 +++++++++++++-------------- message/message.go | 1 + utils/func.go | 26 ++++ utils/func_test.go | 29 ++++ utils/http.go | 5 +- 9 files changed, 485 insertions(+), 192 deletions(-) create mode 100644 utils/func.go create mode 100644 utils/func_test.go diff --git a/client/client.go b/client/client.go index 6ce36024..7e79f761 100644 --- a/client/client.go +++ b/client/client.go @@ -48,12 +48,12 @@ type QQClient struct { SequenceId atomic.Int32 SessionId []byte RandomKey []byte - TCP *network.TCPListener ConnectTime time.Time // todo: combine net conn, transport, pending into one struct pendingMu sync.Mutex pending map[uint16]*network.Call + //TCP *network.TCPListener transport *network.Transport oicq *oicq.Codec @@ -131,7 +131,7 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient { Uin: uin, PasswordMd5: passwordMd5, AllowSlider: true, - TCP: &network.TCPListener{}, + //TCP: &network.TCPListener{}, sig: &auth.SigInfo{ OutPacketSessionID: []byte{0x02, 0xB0, 0x5B, 0x8B}, }, @@ -181,15 +181,15 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient { } if len(cli.servers) == 0 { cli.servers = []*net.TCPAddr{ // default servers - {IP: net.IP{42, 81, 172, 81}, Port: 80}, - {IP: net.IP{114, 221, 148, 59}, Port: 14000}, - {IP: net.IP{42, 81, 172, 147}, Port: 443}, - {IP: net.IP{125, 94, 60, 146}, Port: 80}, - {IP: net.IP{114, 221, 144, 215}, Port: 80}, {IP: net.IP{42, 81, 172, 22}, Port: 80}, + {IP: net.IP{42, 81, 172, 81}, Port: 80}, + {IP: net.IP{42, 81, 172, 147}, Port: 443}, + {IP: net.IP{114, 221, 144, 215}, Port: 80}, + {IP: net.IP{114, 221, 148, 59}, Port: 14000}, + {IP: net.IP{125, 94, 60, 146}, Port: 80}, } } - pings := make([]int64, len(cli.servers)) + /*pings := make([]int64, len(cli.servers)) wg := sync.WaitGroup{} wg.Add(len(cli.servers)) for i := range cli.servers { @@ -209,9 +209,9 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient { }) if len(cli.servers) > 3 { cli.servers = cli.servers[0 : len(cli.servers)/2] // 保留ping值中位数以上的server - } - cli.TCP.PlannedDisconnect(cli.plannedDisconnect) - cli.TCP.UnexpectedDisconnect(cli.unexpectedDisconnect) + }*/ + cli.transport.PlannedDisconnect(cli.plannedDisconnect) + cli.transport.UnexpectedDisconnect(cli.unexpectedDisconnect) rand.Read(cli.RandomKey) return cli } @@ -255,11 +255,50 @@ func (c *QQClient) TokenLogin(token []byte) error { if c.Online.Load() { return ErrAlreadyOnline } - err := c.connect() + err := c.LoadToken(token) if err != nil { return err } - { + return c.ReLogin() +} + +func (c *QQClient) ReLogin() error { + if c.Online.Load() { + return ErrAlreadyOnline + } + err := c.connectFastest() + if err != nil { + return err + } + _, err = c.callAndDecode(c.buildRequestChangeSigRequest(c.version.MainSigMap), decodeExchangeEmpResponse) + if err != nil { + return err + } + err = c.init(true) + // 登录失败 + if err != nil { + c.Disconnect() + } + return err +} + +func (c *QQClient) DumpToken() []byte { + return binary.NewWriterF(func(w *binary.Writer) { + w.WriteUInt64(uint64(c.Uin)) + w.WriteBytesShort(c.sig.D2) + w.WriteBytesShort(c.sig.D2Key) + w.WriteBytesShort(c.sig.TGT) + w.WriteBytesShort(c.sig.SrmToken) + w.WriteBytesShort(c.sig.T133) + w.WriteBytesShort(c.sig.EncryptedA1) + w.WriteBytesShort(c.oicq.WtSessionTicketKey) + w.WriteBytesShort(c.sig.OutPacketSessionID) + w.WriteBytesShort(c.deviceInfo.TgtgtKey) + }) +} + +func (c *QQClient) LoadToken(token []byte) error { + return utils.CoverError(func() { r := binary.NewReader(token) c.Uin = r.ReadInt64() c.sig.D2 = r.ReadBytesShort() @@ -272,18 +311,83 @@ func (c *QQClient) TokenLogin(token []byte) error { c.sig.OutPacketSessionID = r.ReadBytesShort() // SystemDeviceInfo.TgtgtKey = r.ReadBytesShort() c.deviceInfo.TgtgtKey = r.ReadBytesShort() - } - _, err = c.callAndDecode(c.buildRequestChangeSigRequest(c.version.MainSigMap), decodeExchangeEmpResponse) - if err != nil { - return err - } - return c.init(true) + copy(SystemDeviceInfo.TgtgtKey, c.deviceInfo.TgtgtKey) + }) } +func (c *QQClient) DumpDevice() []byte { + return binary.NewWriterF(func(w *binary.Writer) { + w.WriteBytesShort(c.deviceInfo.Display) + w.WriteBytesShort(c.deviceInfo.Product) + w.WriteBytesShort(c.deviceInfo.Device) + w.WriteBytesShort(c.deviceInfo.Board) + w.WriteBytesShort(c.deviceInfo.Brand) + w.WriteBytesShort(c.deviceInfo.Model) + w.WriteBytesShort(c.deviceInfo.Bootloader) + w.WriteBytesShort(c.deviceInfo.FingerPrint) + w.WriteBytesShort(c.deviceInfo.BootId) + w.WriteBytesShort(c.deviceInfo.ProcVersion) + w.WriteBytesShort(c.deviceInfo.BaseBand) + w.WriteBytesShort(c.deviceInfo.SimInfo) + w.WriteBytesShort(c.deviceInfo.OSType) + w.WriteBytesShort(c.deviceInfo.MacAddress) + w.WriteBytesShort(c.deviceInfo.IpAddress) + w.WriteBytesShort(c.deviceInfo.WifiBSSID) + w.WriteBytesShort(c.deviceInfo.WifiSSID) + w.WriteBytesShort(c.deviceInfo.IMSIMd5) + w.WriteStringShort(c.deviceInfo.IMEI) + w.WriteBytesShort(c.deviceInfo.APN) + w.WriteBytesShort(c.deviceInfo.VendorName) + w.WriteBytesShort(c.deviceInfo.VendorOSName) + w.WriteBytesShort(c.deviceInfo.AndroidId) + + w.Write(c.PasswordMd5[:]) + }) +} + +func (c *QQClient) LoadDevice(device []byte) error { + return utils.CoverError(func() { + r := binary.NewReader(device) + c.deviceInfo.Display = r.ReadBytesShort() + c.deviceInfo.Product = r.ReadBytesShort() + c.deviceInfo.Device = r.ReadBytesShort() + c.deviceInfo.Board = r.ReadBytesShort() + c.deviceInfo.Brand = r.ReadBytesShort() + c.deviceInfo.Model = r.ReadBytesShort() + c.deviceInfo.Bootloader = r.ReadBytesShort() + c.deviceInfo.FingerPrint = r.ReadBytesShort() + c.deviceInfo.BootId = r.ReadBytesShort() + c.deviceInfo.ProcVersion = r.ReadBytesShort() + c.deviceInfo.BaseBand = r.ReadBytesShort() + c.deviceInfo.SimInfo = r.ReadBytesShort() + c.deviceInfo.OSType = r.ReadBytesShort() + c.deviceInfo.MacAddress = r.ReadBytesShort() + c.deviceInfo.IpAddress = r.ReadBytesShort() + c.deviceInfo.WifiBSSID = r.ReadBytesShort() + c.deviceInfo.WifiSSID = r.ReadBytesShort() + c.deviceInfo.IMSIMd5 = r.ReadBytesShort() + c.deviceInfo.IMEI = r.ReadStringShort() + c.deviceInfo.APN = r.ReadBytesShort() + c.deviceInfo.VendorName = r.ReadBytesShort() + c.deviceInfo.VendorOSName = r.ReadBytesShort() + c.deviceInfo.AndroidId = r.ReadBytesShort() + + copy(c.PasswordMd5[:], r.ReadBytes(md5.Size)) + }) +} + +// FetchQRCode 以默认值获取登录二维码 +// 函数已被弃用 请使用FetchQRCodeCustomSize获得更可控结果 +// 但该兼容函数不会被删除 +// Deprecated use FetchQRCodeCustomSize(3, 4, 2) instead func (c *QQClient) FetchQRCode() (*QRCodeLoginResponse, error) { return c.FetchQRCodeCustomSize(3, 4, 2) } +// FetchQRCodeCustomSize 以特定参数获取登录二维码 +// size: 块尺寸 默认值3 即单个黑/白块大小为3x3像素 +// margin: 与图片边界的距离 默认值4 即二维码主体至图片边界有4像素白色填充 +// ecLevel: 纠错等级 可用值:1,2,3 默认值2 func (c *QQClient) FetchQRCodeCustomSize(size, margin, ecLevel uint32) (*QRCodeLoginResponse, error) { if c.Online.Load() { return nil, ErrAlreadyOnline @@ -400,15 +504,13 @@ func (c *QQClient) init(tokenLogin bool) error { d2() } } - c.groupSysMsgCache, _ = c.GetGroupSystemMessages() - if !c.heartbeatEnabled { - go c.doHeartbeat() - } + go c.doHeartbeat() _ = c.RefreshStatus() if c.version.Protocol == auth.QiDian { _, _ = c.callAndDecode(c.buildLoginExtraPacket(), decodeLoginExtraResponse) // 小登录 _, _ = c.callAndDecode(c.buildConnKeyRequestPacket(), decodeConnKeyResponse) // big data key 如果等待 config push 的话时间来不及 } + c.groupSysMsgCache, _ = c.GetGroupSystemMessages() req := c.buildGetMessageRequest(msg.SyncFlag_START, time.Now().Unix()) req.Params = network.Params{"used_reg_proxy": true, "init": true} _, _ = c.callAndDecode(req, decodeMessageSvcPacket) @@ -765,10 +867,21 @@ func (c *QQClient) nextHighwayApplySeq() int32 { } func (c *QQClient) doHeartbeat() { + // 不需要atomic/锁 + if c.heartbeatEnabled { + return + } c.heartbeatEnabled = true + defer func() { + c.heartbeatEnabled = false + }() times := 0 - for c.Online.Load() { - time.Sleep(time.Second * 30) + ticker := time.NewTicker(time.Second * 30) + for range ticker.C { + if !c.Online.Load() { + ticker.Stop() + return // 下线停止goroutine,for gc + } seq := c.nextSeq() req := network.Request{ Type: network.RequestTypeLogin, @@ -779,8 +892,11 @@ func (c *QQClient) doHeartbeat() { Body: EmptyBytes, } _, err := c.call(&req) - if errors.Is(err, network.ErrConnectionClosed) { - continue + if err != nil { + if errors.Is(err, network.ErrConnectionBroken) { + break + } + continue // skip time++ } times++ if times >= 7 { @@ -788,5 +904,4 @@ func (c *QQClient) doHeartbeat() { times = 0 } } - c.heartbeatEnabled = false } diff --git a/client/internal/network/conn.go b/client/internal/network/conn.go index a43614c1..5052b320 100644 --- a/client/internal/network/conn.go +++ b/client/internal/network/conn.go @@ -5,33 +5,40 @@ import ( "io" "net" "sync" + "sync/atomic" + "unsafe" "github.com/pkg/errors" ) type TCPListener struct { - lock sync.RWMutex - conn net.Conn - connected bool - plannedDisconnect func(*TCPListener) - unexpectedDisconnect func(*TCPListener, error) + //lock sync.RWMutex + conn *net.TCPConn + //connected bool + + // PlannedDisconnect 预料中的断开连接 + // 如调用 Close() Connect() + PlannedDisconnect func(*TCPListener) + + // UnexpectedDisconnect 未预料的断开连接 + UnexpectedDisconnect func(*TCPListener, error) } -var ErrConnectionClosed = errors.New("connection closed") - -// PlannedDisconnect 预料中的断开连接 -// 如调用 Close() Connect() -func (t *TCPListener) PlannedDisconnect(f func(*TCPListener)) { - t.lock.Lock() - defer t.lock.Unlock() - t.plannedDisconnect = f +func (t *TCPListener) getConn() *net.TCPConn { + return (*net.TCPConn)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&t.conn)))) } -// UnexpectedDisconnect 未预料的断开连接 -func (t *TCPListener) UnexpectedDisconnect(f func(*TCPListener, error)) { - t.lock.Lock() - defer t.lock.Unlock() - t.unexpectedDisconnect = f +func (t *TCPListener) setConn(conn *net.TCPConn) (swapped bool) { + return atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.conn)), unsafe.Pointer(nil), unsafe.Pointer(conn)) +} + +func (t *TCPListener) closeConn() *net.TCPConn { + return (*net.TCPConn)(atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.conn)), unsafe.Pointer(nil))) +} + +func (t *TCPListener) Connected() bool { + // 等同于 t.getConn() != nil (? copilot写的) + return t.getConn() != nil } func (t *TCPListener) Connect(addr *net.TCPAddr) error { @@ -40,24 +47,59 @@ func (t *TCPListener) Connect(addr *net.TCPAddr) error { if err != nil { return errors.Wrap(err, "dial tcp error") } - t.lock.Lock() - defer t.lock.Unlock() - t.conn = conn - t.connected = true + t.setConn(conn) + //t.lock.Lock() + //defer t.lock.Unlock() + //t.conn = conn return nil } +// ConnectFastest 连接到最快的服务器 +// TODO 禁用不可用服务器 +func (t *TCPListener) ConnectFastest(addr []*net.TCPAddr) (*net.TCPAddr, error) { + ch := make(chan error) + wg := sync.WaitGroup{} + wg.Add(len(addr)) + for _, remote := range addr { + go func(remote *net.TCPAddr) { + defer wg.Done() + conn, err := net.DialTCP("tcp", nil, remote) + if err != nil { + return + } + //addrs = append(addrs, remote) + if !t.setConn(conn) { + _ = conn.Close() + return + } + ch <- nil + }(remote) + } + go func() { + wg.Wait() + if t.getConn() == nil { + ch <- errors.New("All addr are unreachable") + } + }() + err := <-ch + if err != nil { + return nil, err + } + conn := t.getConn() + return conn.RemoteAddr().(*net.TCPAddr), nil +} + func (t *TCPListener) Write(buf []byte) error { if conn := t.getConn(); conn != nil { _, err := conn.Write(buf) if err != nil { t.unexpectedClose(err) - return ErrConnectionClosed + return ErrConnectionBroken } return nil } - return ErrConnectionClosed + return ErrConnectionBroken } func (t *TCPListener) ReadBytes(len int) ([]byte, error) { @@ -67,12 +109,12 @@ func (t *TCPListener) ReadBytes(len int) ([]byte, error) { if err != nil { // time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁 t.unexpectedClose(err) - return nil, ErrConnectionClosed + return nil, ErrConnectionBroken } return buf, nil } - return nil, ErrConnectionClosed + return nil, ErrConnectionBroken } func (t *TCPListener) ReadInt32() (int32, error) { @@ -94,34 +136,37 @@ func (t *TCPListener) unexpectedClose(err error) { } func (t *TCPListener) close() { - t.lock.Lock() - defer t.lock.Unlock() - if t.conn != nil { - _ = t.conn.Close() - t.conn = nil + if conn := t.closeConn(); conn != nil { + _ = conn.Close() } } func (t *TCPListener) invokePlannedDisconnect() { - t.lock.RLock() - defer t.lock.RUnlock() - if t.plannedDisconnect != nil && t.connected { - go t.plannedDisconnect(t) - t.connected = false - } + //if t.Connected() { + t.PlannedDisconnect(t) + //} + //t.lock.RLock() + //defer t.lock.RUnlock() + //if t.plannedDisconnect != nil && t.connected { + // go t.plannedDisconnect(t) + // t.connected = false + //} } func (t *TCPListener) invokeUnexpectedDisconnect(err error) { - t.lock.RLock() - defer t.lock.RUnlock() - if t.unexpectedDisconnect != nil && t.connected { - go t.unexpectedDisconnect(t, err) - t.connected = false - } + //if t.Connected() { + t.UnexpectedDisconnect(t, err) + //} + //t.lock.RLock() + //defer t.lock.RUnlock() + //if t.unexpectedDisconnect != nil && t.connected { + // go t.unexpectedDisconnect(t, err) + // t.connected = false + //} } -func (t *TCPListener) getConn() net.Conn { - t.lock.RLock() - defer t.lock.RUnlock() - return t.conn -} +//func (t *TCPListener) getConn() net.Conn { +// t.lock.RLock() +// defer t.lock.RUnlock() +// return t.conn +//} diff --git a/client/internal/network/response.go b/client/internal/network/response.go index c51c0cdc..556e8519 100644 --- a/client/internal/network/response.go +++ b/client/internal/network/response.go @@ -22,6 +22,7 @@ var ( ErrSessionExpired = errors.New("session expired") ErrPacketDropped = errors.New("packet dropped") ErrInvalidPacketType = errors.New("invalid packet type") + ErrConnectionBroken = errors.New("connection broken") ) func (t *Transport) ReadRequest(head []byte) (*Request, error) { diff --git a/client/internal/network/transport.go b/client/internal/network/transport.go index 06c06e45..2b700a20 100644 --- a/client/internal/network/transport.go +++ b/client/internal/network/transport.go @@ -1,6 +1,11 @@ package network import ( + goBinary "encoding/binary" + "fmt" + "github.com/pkg/errors" + "io" + "net" "strconv" "github.com/Mrs4s/MiraiGo/binary" @@ -15,7 +20,27 @@ type Transport struct { Device *auth.Device // connection - // conn *TCPListener + conn TCPListener +} + +func (t *Transport) PlannedDisconnect(fun func(*TCPListener)) { + t.conn.PlannedDisconnect = fun +} + +func (t *Transport) UnexpectedDisconnect(fun func(*TCPListener, error)) { + t.conn.UnexpectedDisconnect = fun +} + +func (t *Transport) ConnectFastest(servers []*net.TCPAddr) (chosen *net.TCPAddr, err error) { + return t.conn.ConnectFastest(servers) +} + +func (t *Transport) Close() { + t.conn.Close() +} + +func (t *Transport) Write(data []byte) error { + return t.conn.Write(data) } func (t *Transport) packBody(req *Request, w *binary.Writer) { @@ -96,3 +121,59 @@ func (t *Transport) PackPacket(req *Request) []byte { w.WriteUInt32At(pos, uint32(w.Len())) return append([]byte(nil), w.Bytes()...) } + +type PktHandler func(pkt *Request, netErr error) +type RequestHandler func(head []byte) (*Request, error) + +func (t *Transport) NetLoop(pktHandler PktHandler, respHandler RequestHandler) { + go t.netLoop(pktHandler, respHandler) +} + +// readPacket 帮助函数(Helper function) +func readPacket(conn *net.TCPConn, minSize, maxSize uint32) ([]byte, error) { + lBuf := make([]byte, 4) + _, err := io.ReadFull(conn, lBuf) + if err != nil { + return nil, err + } + l := goBinary.BigEndian.Uint32(lBuf) + if l < minSize || l > maxSize { + return nil, fmt.Errorf("parse incoming packet error: invalid packet length %v", l) + } + data := make([]byte, l-4) + _, err = io.ReadFull(conn, data) + return data, err +} + +// netLoop 整个函数周期使用同一个连接,确保不会发生串线这种奇怪的事情 +func (t *Transport) netLoop(pktHandler PktHandler, respHandler RequestHandler) { + conn := t.conn.getConn() + defer func() { + if r := recover(); r != nil { + pktHandler(nil, fmt.Errorf("panic: %v", r)) + } + t.conn.Close() + }() + errCount := 0 + for { + data, err := readPacket(conn, 4, 10<<20) // max 10MB + if err != nil { + // 在且仅在没有新连接建立时断线才被认为是意外的 + if t.conn.getConn() == conn { + pktHandler(nil, errors.Wrap(ErrConnectionBroken, err.Error())) + } + return + } + req, err := respHandler(data) + if err == nil { + errCount = 0 + goto ok + } + errCount++ + if errCount > 2 { + err = errors.Wrap(ErrConnectionBroken, err.Error()) + } + ok: + go pktHandler(req, err) + } +} diff --git a/client/network.go b/client/network.go index ed01ad95..5588f711 100644 --- a/client/network.go +++ b/client/network.go @@ -1,6 +1,7 @@ package client import ( + "github.com/Mrs4s/MiraiGo/message" "net" "runtime/debug" "sync" @@ -10,7 +11,6 @@ import ( "github.com/Mrs4s/MiraiGo/client/internal/network" "github.com/Mrs4s/MiraiGo/client/internal/oicq" - "github.com/Mrs4s/MiraiGo/message" "github.com/Mrs4s/MiraiGo/utils" ) @@ -84,22 +84,22 @@ func (c *QQClient) ConnectionQualityTest() *ConnectionQualityInfo { return r } -// connect 连接到 QQClient.servers 中的服务器 -func (c *QQClient) connect() error { - c.Info("connect to server: %v", c.servers[c.currServerIndex].String()) - err := c.TCP.Connect(c.servers[c.currServerIndex]) - c.currServerIndex++ - if c.currServerIndex == len(c.servers) { - c.currServerIndex = 0 - } +func (c *QQClient) connectFastest() error { + c.Disconnect() + addr, err := c.transport.ConnectFastest(c.servers) if err != nil { - c.retryTimes++ - if c.retryTimes > len(c.servers) { - return errors.New("All servers are unreachable") - } - c.Error("connect server error: %v", err) + c.Disconnect() return err } + c.Debug("connected to server: %v [fastest]", addr.String()) + c.transport.NetLoop(c.pktProc, c.transport.ReadRequest) + c.retryTimes = 0 + c.ConnectTime = time.Now() + return nil +} + +// connect 连接到 QQClient.servers 中的服务器 +func (c *QQClient) connect() error { c.once.Do(func() { c.OnGroupMessage(func(_ *QQClient, _ *message.GroupMessage) { c.stat.MessageReceived.Add(1) @@ -116,11 +116,30 @@ func (c *QQClient) connect() error { c.onGroupMessageReceipt("internal", func(_ *QQClient, _ *groupMessageReceiptEvent) { c.stat.MessageSent.Add(1) }) - go c.netLoop() + //go c.netLoop() }) + return c.connectFastest() // 暂时? + /*c.Info("connect to server: %v", c.servers[c.currServerIndex].String()) + err := c.TCP.Connect(c.servers[c.currServerIndex]) + c.currServerIndex++ + if c.currServerIndex == len(c.servers) { + c.currServerIndex = 0 + } + if err != nil { + c.retryTimes++ + if c.retryTimes > len(c.servers) { + return errors.New("All servers are unreachable") + } + c.Error("connect server error: %v", err) + return err + } c.retryTimes = 0 c.ConnectTime = time.Now() - return nil + return nil*/ +} + +func (c *QQClient) QuickReconnect() { + c.quickReconnect() // TODO "用户请求快速重连" } // quickReconnect 快速重连 @@ -129,7 +148,7 @@ func (c *QQClient) quickReconnect() { time.Sleep(time.Millisecond * 200) if err := c.connect(); err != nil { c.Error("connect server error: %v", err) - c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "quick reconnect failed"}) + c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "快速重连失败"}) return } if err := c.registerClient(); err != nil { @@ -143,7 +162,7 @@ func (c *QQClient) quickReconnect() { // Disconnect 中断连接, 不释放资源 func (c *QQClient) Disconnect() { c.Online.Store(false) - c.TCP.Close() + c.transport.Close() } func (c *QQClient) send(call *network.Call) { @@ -160,9 +179,9 @@ func (c *QQClient) send(call *network.Call) { c.pendingMu.Lock() call = c.pending[seq] delete(c.pending, seq) + c.pendingMu.Unlock() call.Err = err call.Done <- call - c.pendingMu.Unlock() } } @@ -194,7 +213,7 @@ func (c *QQClient) callAndDecode(req *network.Request, decoder func(*QQClient, * // sendPacket 向服务器发送一个数据包 func (c *QQClient) sendPacket(pkt []byte) error { - err := c.TCP.Write(pkt) + err := c.transport.Write(pkt) if err != nil { c.stat.PacketLost.Add(1) } else { @@ -257,100 +276,79 @@ func (c *QQClient) unexpectedDisconnect(_ *network.TCPListener, e error) { } } -// netLoop 通过循环来不停接收数据包 -func (c *QQClient) netLoop() { - errCount := 0 - for c.alive { - l, err := c.TCP.ReadInt32() - if err != nil { - time.Sleep(time.Millisecond * 500) - continue +func (c *QQClient) pktProc(req *network.Request, netErr error) { + if netErr != nil { + switch true { + case errors.Is(netErr, network.ErrConnectionBroken): + go c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: netErr.Error()}) + c.QuickReconnect() + case errors.Is(netErr, network.ErrSessionExpired) || errors.Is(netErr, network.ErrPacketDropped): + c.Disconnect() + go c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "session expired"}) } - if l < 4 || l > 1024*1024*10 { // max 10MB - c.Error("parse incoming packet error: invalid packet length %v", l) - errCount++ - if errCount > 2 { - go c.quickReconnect() - } - continue - } - data, _ := c.TCP.ReadBytes(int(l) - 4) - req, err := c.transport.ReadRequest(data) - if err != nil { - c.Error("parse incoming packet error: %v", err) - if errors.Is(err, network.ErrSessionExpired) || errors.Is(err, network.ErrPacketDropped) { - c.Disconnect() - go c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "session expired"}) - continue - } - errCount++ - if errCount > 2 { - go c.quickReconnect() - } - continue - } - if req.EncryptType == network.EncryptTypeEmptyKey { - m, err := c.oicq.Unmarshal(req.Body) - if err != nil { - c.Error("decrypt payload error: %v", err) - if errors.Is(err, oicq.ErrUnknownFlag) { - go c.quickReconnect() - } - continue - } - req.Body = m.Body - } - errCount = 0 - c.Debug("rev pkt: %v seq: %v", req.CommandName, req.SequenceID) - c.stat.PacketReceived.Add(1) - go func(req *network.Request) { - defer func() { - if pan := recover(); pan != nil { - c.Error("panic on decoder %v : %v\n%s", req.CommandName, pan, debug.Stack()) - c.Dump("packet decode error: %v - %v", req.Body, req.CommandName, pan) - } - }() + c.Error("parse incoming packet error: %v", netErr) + return + } - // snapshot of read call - c.pendingMu.Lock() - call := c.pending[uint16(req.SequenceID)] - if call != nil { - call.Response = &network.Response{ - SequenceID: req.SequenceID, - CommandName: req.CommandName, - Body: req.Body, - Params: call.Request.Params, - // Request: nil, - } - } - c.pendingMu.Unlock() - if call != nil && call.Request.CommandName == req.CommandName { - select { - case call.Done <- call: - default: - // we don't want blocking - } - return + if req.EncryptType == network.EncryptTypeEmptyKey { + m, err := c.oicq.Unmarshal(req.Body) + if err != nil { + c.Error("decrypt payload error: %v", err) + if errors.Is(err, oicq.ErrUnknownFlag) { + go c.quickReconnect() // TODO "服务器发送未知响应" } + } + req.Body = m.Body + } - if decoder, ok := decoders[req.CommandName]; ok { - // found predefined decoder - resp := network.Response{ - SequenceID: req.SequenceID, - CommandName: req.CommandName, - Body: req.Body, - // Request: nil, - } - decoded, err := decoder(c, &resp) - if err != nil { - c.Debug("decode req %v error: %+v", req.CommandName, err) - } - if f, ok := c.waiters.Load(req.CommandName); ok { // 在不存在handler的情况下触发wait - f.(func(interface{}, error))(decoded, err) - } - } else { - c.Debug("Unhandled Command: %s\nSeq: %d\nThis message can be ignored.", req.CommandName, req.SequenceID) - } - }(req) + defer func() { + if pan := recover(); pan != nil { + c.Error("panic on decoder %v : %v\n%s", req.CommandName, pan, debug.Stack()) + c.Dump("packet decode error: %v - %v", req.Body, req.CommandName, pan) + } + }() + + c.Debug("rev resp: %v seq: %v", req.CommandName, req.SequenceID) + c.stat.PacketReceived.Add(1) + + // snapshot of read call + c.pendingMu.Lock() + call := c.pending[uint16(req.SequenceID)] + if call != nil { + call.Response = &network.Response{ + SequenceID: req.SequenceID, + CommandName: req.CommandName, + Body: req.Body, + Params: call.Request.Params, + // Request: nil, + } + } + c.pendingMu.Unlock() + if call != nil && call.Request.CommandName == req.CommandName { + select { + case call.Done <- call: + default: + // we don't want blocking + } + return + } + + if decoder, ok := decoders[req.CommandName]; ok { + // found predefined decoder + resp := network.Response{ + SequenceID: req.SequenceID, + CommandName: req.CommandName, + Body: req.Body, + // Request: nil, + } + decoded, err := decoder(c, &resp) + if err != nil { + c.Debug("decode req %v error: %+v", req.CommandName, err) + } + if f, ok := c.waiters.Load(req.CommandName); ok { // 在不存在handler的情况下触发wait + f.(func(interface{}, error))(decoded, err) + } + } else { + c.Debug("Unhandled Command: %s\nSeq: %d\nThis message can be ignored.", req.CommandName, req.SequenceID) } } diff --git a/message/message.go b/message/message.go index 5fc996d9..7d1ef49c 100644 --- a/message/message.go +++ b/message/message.go @@ -236,6 +236,7 @@ func ToProtoElems(elems []IMessageElement, generalFlags bool) (r []*msg.Elem) { if len(elems) == 0 { return nil } + r = make([]*msg.Elem, 0, len(elems)) for _, elem := range elems { if reply, ok := elem.(*ReplyElement); ok { r = append(r, &msg.Elem{ diff --git a/utils/func.go b/utils/func.go new file mode 100644 index 00000000..794de3da --- /dev/null +++ b/utils/func.go @@ -0,0 +1,26 @@ +package utils + +import "fmt" + +// CoverError == catch{} +func CoverError(fun func()) error { + if fun == nil { + return nil + } + errCh := make(chan error, 1) + func() { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + errCh <- err + } else { + errCh <- fmt.Errorf("%v", r) + } + } else { + errCh <- nil + } + }() + fun() + }() + return <-errCh +} diff --git a/utils/func_test.go b/utils/func_test.go new file mode 100644 index 00000000..37c50857 --- /dev/null +++ b/utils/func_test.go @@ -0,0 +1,29 @@ +package utils + +import ( + "errors" + "testing" +) + +var errTest = errors.New("test error") + +func TestCoverError(t *testing.T) { + err := CoverError(nil) + if err != nil { + t.Errorf(`CoverError(nil) = %v, want nil`, err) + } + + err = CoverError(func() { + panic("test") + }) + if err.Error() != "test" { + t.Errorf(`CoverError(func() { panic("test") }) = %v, want "test"`, err) + } + + err = CoverError(func() { + panic(errTest) + }) + if err != errTest { + t.Errorf(`CoverError(func() { panic(errTest) }) = %v, want errTest`, err) + } +} diff --git a/utils/http.go b/utils/http.go index ea594e84..32ca5ce1 100644 --- a/utils/http.go +++ b/utils/http.go @@ -10,10 +10,7 @@ import ( var client = &http.Client{ Transport: &http.Transport{ - ForceAttemptHTTP2: true, - MaxConnsPerHost: 0, - MaxIdleConns: 0, - MaxIdleConnsPerHost: 999, + ForceAttemptHTTP2: true, }, }