diff --git a/client/client.go b/client/client.go index f0817d4e..29c1f199 100644 --- a/client/client.go +++ b/client/client.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/md5" "fmt" - "io" "math" "math/rand" "net" @@ -43,12 +42,12 @@ type QQClient struct { GroupList []*GroupInfo OnlineClients []*OtherClientInfo Online bool - NetLooping bool + // NetLooping bool SequenceId int32 OutGoingPacketSessionId []byte RandomKey []byte - Conn net.Conn + TCP *utils.TCPListener ConnectTime time.Time handlers HandlerMap @@ -176,6 +175,7 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient { AllowSlider: true, RandomKey: make([]byte, 16), OutGoingPacketSessionId: []byte{0x02, 0xB0, 0x5B, 0x8B}, + TCP: &utils.TCPListener{}, sigInfo: &loginSigInfo{}, requestPacketRequestID: 1921334513, groupSeq: int32(rand.Intn(20000)), @@ -236,7 +236,10 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient { if len(cli.servers) > 3 { cli.servers = cli.servers[0 : len(cli.servers)/2] // 保留ping值中位数以上的server } + cli.TCP.PlanedDisconnect(cli.planedDisconnect) + cli.TCP.UnexpectedDisconnect(cli.unexpectedDisconnect) rand.Read(cli.RandomKey) + go cli.netLoop() return cli } @@ -245,12 +248,9 @@ func (c *QQClient) Login() (*LoginResponse, error) { if c.Online { return nil, ErrAlreadyOnline } - if c.Conn == nil { - err := c.connect() - if err != nil { - return nil, err - } - go c.netLoop() + err := c.connect() + if err != nil { + return nil, err } rsp, err := c.sendAndWait(c.buildLoginPacket()) if err != nil { @@ -268,12 +268,9 @@ func (c *QQClient) TokenLogin(token []byte) error { if c.Online { return ErrAlreadyOnline } - if c.Conn == nil { - err := c.connect() - if err != nil { - return err - } - go c.netLoop() + err := c.connect() + if err != nil { + return err } { r := binary.NewReader(token) @@ -288,7 +285,7 @@ func (c *QQClient) TokenLogin(token []byte) error { c.OutGoingPacketSessionId = r.ReadBytesShort() SystemDeviceInfo.TgtgtKey = r.ReadBytesShort() } - _, err := c.sendAndWait(c.buildRequestChangeSigPacket()) + _, err = c.sendAndWait(c.buildRequestChangeSigPacket()) if err != nil { return err } @@ -300,12 +297,9 @@ func (c *QQClient) FetchQRCode() (*QRCodeLoginResponse, error) { if c.Online { return nil, ErrAlreadyOnline } - if c.Conn == nil { - err := c.connect() - if err != nil { - return nil, err - } - go c.netLoop() + err := c.connect() + if err != nil { + return nil, err } i, err := c.sendAndWait(c.buildQRCodeFetchRequestPacket()) if err != nil { @@ -389,7 +383,6 @@ func (c *QQClient) init() { if len(c.g) == 0 { c.Warning("device lock is disable. http api may fail.") } - c.Online = true _ = c.registerClient() c.groupSysMsgCache, _ = c.GetGroupSystemMessages() if !c.heartbeatEnabled { @@ -826,35 +819,6 @@ func (g *GroupInfo) removeMember(uin int64) { }) } -func (c *QQClient) connect() error { - c.Info("connect to server: %v", c.servers[c.currServerIndex].String()) - conn, err := net.DialTCP("tcp", nil, c.servers[c.currServerIndex]) - c.currServerIndex++ - if c.currServerIndex == len(c.servers) { - c.currServerIndex = 0 - } - if err != nil || conn == nil { - c.retryTimes++ - if c.retryTimes > len(c.servers) { - return errors.New("All servers are unreachable") - } - c.Error("connect server error: %v", err) - return err - } - c.retryTimes = 0 - c.ConnectTime = time.Now() - c.Conn = conn - return nil -} - -func (c *QQClient) Disconnect() { - c.NetLooping = false - c.Online = false - if c.Conn != nil { - _ = c.Conn.Close() - } -} - func (c *QQClient) SetCustomServer(servers []*net.TCPAddr) { c.servers = append(servers, c.servers...) } @@ -866,6 +830,9 @@ func (c *QQClient) SendGroupGift(groupCode, uin uint64, gift message.GroupGift) func (c *QQClient) registerClient() error { _, err := c.sendAndWait(c.buildClientRegisterPacket()) + if err == nil { + c.Online = true + } return err } @@ -894,7 +861,7 @@ func (c *QQClient) nextHighwayApplySeq() int32 { } func (c *QQClient) send(pkt []byte) error { - _, err := c.Conn.Write(pkt) + err := c.TCP.Write(pkt) if err != nil { c.stat.PacketLost++ } else { @@ -959,54 +926,89 @@ func (c *QQClient) waitPacket(cmd string, f func(interface{}, error)) func() { } } -func (c *QQClient) netLoop() { - if c.NetLooping { +func (c *QQClient) connect() error { + c.Info("connect to server: %v", c.servers[c.currServerIndex].String()) + err := c.TCP.Connect(c.servers[c.currServerIndex]) + c.currServerIndex++ + if c.currServerIndex == len(c.servers) { + c.currServerIndex = 0 + } + if err != nil { + c.retryTimes++ + if c.retryTimes > len(c.servers) { + return errors.New("All servers are unreachable") + } + c.Error("connect server error: %v", err) + return err + } + c.retryTimes = 0 + c.ConnectTime = time.Now() + return nil +} + +func (c *QQClient) quickReconnect() { + c.Disconnect() + if err := c.connect(); err != nil { + c.Error("connect server error: %v", err) + c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "quick reconnect failed"}) return } - c.NetLooping = true - reader := binary.NewNetworkReader(c.Conn) - retry := 0 + if err := c.registerClient(); err != nil { + c.Error("register client failed: %v", err) + c.Disconnect() + c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "register error"}) + return + } +} + +func (c *QQClient) Disconnect() { + c.TCP.Close() +} + +func (c *QQClient) planedDisconnect(_ *utils.TCPListener) { + c.Debug("planed disconnect.") + c.stat.DisconnectTimes++ + c.Online = false +} + +func (c *QQClient) unexpectedDisconnect(_ *utils.TCPListener, e error) { + c.Error("unexpected disconnect: %v", e) + c.stat.DisconnectTimes++ + if err := c.connect(); err != nil { + c.Error("connect server error: %v", err) + c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "connection dropped by server."}) + return + } + if err := c.registerClient(); err != nil { + c.Error("register client failed: %v", err) + c.Disconnect() + c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "register error"}) + return + } +} + +func (c *QQClient) netLoop() { + // todo: release this errCount := 0 - for c.NetLooping { - l, err := reader.ReadInt32() - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) { - c.Error("connection dropped by server: %v", err) - c.stat.DisconnectTimes++ - err = c.connect() - if err != nil { - c.Error("connect server error: %v", err) - break - } - reader = binary.NewNetworkReader(c.Conn) - go func() { - if e := c.registerClient(); e != nil && e.Error() != "Packet timed out" { // 掉线在心跳已经有判断了, 只需要处理返回值 - c.lastLostMsg = "register client failed: " + e.Error() - c.Disconnect() - c.Error("reconnect failed: " + e.Error()) - //break - } - }() - } - if l <= 0 { - retry++ - time.Sleep(time.Second * 3) - if retry > 10 { - break - } + for { + l, err := c.TCP.ReadInt32() + if err != nil { + time.Sleep(time.Millisecond * 500) continue } - data, _ := reader.ReadBytes(int(l) - 4) + data, _ := c.TCP.ReadBytes(int(l) - 4) pkt, err := packets.ParseIncomingPacket(data, c.sigInfo.d2Key) if err != nil { c.Error("parse incoming packet error: %v", err) if errors.Is(err, packets.ErrSessionExpired) || errors.Is(err, packets.ErrPacketDropped) { - break + c.quickReconnect() + continue } errCount++ if errCount > 2 { - break + c.quickReconnect() + continue } - //log.Println("parse incoming packet error: " + err.Error()) continue } payload := pkt.Payload @@ -1018,7 +1020,6 @@ func (c *QQClient) netLoop() { } } errCount = 0 - retry = 0 c.Debug("rev pkt: %v seq: %v", pkt.CommandName, pkt.SequenceId) c.stat.PacketReceived++ go func() { @@ -1057,35 +1058,35 @@ func (c *QQClient) netLoop() { } }() } - c.NetLooping = false - c.Online = false - _ = c.Conn.Close() - if c.lastLostMsg == "" { - c.lastLostMsg = "Connection lost." - } - c.stat.LostTimes++ - c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: c.lastLostMsg}) + /* + c.NetLooping = false + c.Online = false + _ = c.TCP.Close() + if c.lastLostMsg == "" { + c.lastLostMsg = "Connection lost." + } + c.stat.LostTimes++ + c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: c.lastLostMsg}) + */ } func (c *QQClient) doHeartbeat() { c.heartbeatEnabled = true times := 0 for c.Online { + time.Sleep(time.Second * 30) seq := c.nextSeq() sso := packets.BuildSsoPacket(seq, c.version.AppId, "Heartbeat.Alive", SystemDeviceInfo.IMEI, []byte{}, c.OutGoingPacketSessionId, []byte{}, c.ksid) packet := packets.BuildLoginPacket(c.Uin, 0, []byte{}, sso, []byte{}) _, err := c.sendAndWait(seq, packet) - if err != nil { - c.lastLostMsg = "Heartbeat failed: " + err.Error() - c.Disconnect() - break + if errors.Is(err, utils.ErrConnectionClosed) { + continue } times++ if times >= 7 { _ = c.registerClient() times = 0 } - time.Sleep(time.Second * 30) } c.heartbeatEnabled = false } diff --git a/client/decoders.go b/client/decoders.go index f48fbf09..e83587c2 100644 --- a/client/decoders.go +++ b/client/decoders.go @@ -737,17 +737,16 @@ func decodeForceOfflinePacket(c *QQClient, _ *incomingPacketInfo, payload []byte data.ReadFrom(jce.NewJceReader(request.SBuffer)) r := jce.NewJceReader(data.Map["req_PushForceOffline"]["PushNotifyPack.RequestPushForceOffline"][1:]) tips := r.ReadString(2) - c.lastLostMsg = tips - c.NetLooping = false - c.Online = false + c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: tips}) + c.Disconnect() return nil, nil } // StatSvc.ReqMSFOffline func decodeMSFOfflinePacket(c *QQClient, _ *incomingPacketInfo, _ []byte) (interface{}, error) { - c.lastLostMsg = "服务器端强制下线." - c.NetLooping = false - c.Online = false + // c.lastLostMsg = "服务器端强制下线." + c.Disconnect() + c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "服务端强制下线."}) return nil, nil } diff --git a/utils/connection.go b/utils/connection.go new file mode 100644 index 00000000..1f26c81d --- /dev/null +++ b/utils/connection.go @@ -0,0 +1,105 @@ +package utils + +import ( + "github.com/pkg/errors" + "io" + "net" + "time" +) + +type TCPListener struct { + conn net.Conn + planedDisconnect func(*TCPListener) + unexpectedDisconnect func(*TCPListener, error) +} + +var ErrConnectionClosed = errors.New("connection closed") + +// PlanedDisconnect 预料中的断开连接 +// 如调用 Close() Connect() +func (t *TCPListener) PlanedDisconnect(f func(*TCPListener)) { + t.planedDisconnect = f +} + +// UnexpectedDisconnect 未预料钟的断开连接 +func (t *TCPListener) UnexpectedDisconnect(f func(*TCPListener, error)) { + t.unexpectedDisconnect = f +} + +func (t *TCPListener) Connect(addr *net.TCPAddr) error { + t.Close() + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return errors.Wrap(err, "dial tcp error") + } + t.conn = conn + return nil +} + +func (t *TCPListener) Write(buf []byte) error { + if t.conn == nil { + return ErrConnectionClosed + } + _, err := t.conn.Write(buf) + if err != nil { + if t.conn != nil { + t.close() + t.invokeUnexpectedDisconnect(err) + } + return ErrConnectionClosed + } + return nil +} + +func (t *TCPListener) ReadBytes(len int) ([]byte, error) { + if t.conn == nil { + return nil, ErrConnectionClosed + } + buf := make([]byte, len) + _, err := io.ReadFull(t.conn, buf) + if err != nil { + time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析 + if t.conn != nil { + t.close() + t.invokeUnexpectedDisconnect(err) + } + return nil, ErrConnectionClosed + } + return buf, nil +} + +func (t *TCPListener) ReadInt32() (int32, error) { + b, err := t.ReadBytes(4) + if err != nil { + return 0, err + } + return (int32(b[0]) << 24) | (int32(b[1]) << 16) | (int32(b[2]) << 8) | int32(b[3]), nil +} + +func (t *TCPListener) Close() { + if t.conn == nil { + return + } + t.close() + t.invokePlanedDisconnect() +} + +func (t *TCPListener) close() { + if t.conn == nil { + return + } + _ = t.conn.Close() + t.conn = nil +} + +func (t *TCPListener) invokePlanedDisconnect() { + if t.planedDisconnect != nil { + t.planedDisconnect(t) + } +} + +func (t *TCPListener) invokeUnexpectedDisconnect(err error) { + if t.unexpectedDisconnect != nil { + t.unexpectedDisconnect(t, err) + } +}