1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 19:17:38 +08:00

fix: many of the network issue.

This commit is contained in:
Mrs4s 2021-04-10 19:56:45 +08:00
parent 32a3570b09
commit 8edebe66ad
3 changed files with 213 additions and 108 deletions

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"crypto/md5" "crypto/md5"
"fmt" "fmt"
"io"
"math" "math"
"math/rand" "math/rand"
"net" "net"
@ -43,12 +42,12 @@ type QQClient struct {
GroupList []*GroupInfo GroupList []*GroupInfo
OnlineClients []*OtherClientInfo OnlineClients []*OtherClientInfo
Online bool Online bool
NetLooping bool // NetLooping bool
SequenceId int32 SequenceId int32
OutGoingPacketSessionId []byte OutGoingPacketSessionId []byte
RandomKey []byte RandomKey []byte
Conn net.Conn TCP *utils.TCPListener
ConnectTime time.Time ConnectTime time.Time
handlers HandlerMap handlers HandlerMap
@ -176,6 +175,7 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient {
AllowSlider: true, AllowSlider: true,
RandomKey: make([]byte, 16), RandomKey: make([]byte, 16),
OutGoingPacketSessionId: []byte{0x02, 0xB0, 0x5B, 0x8B}, OutGoingPacketSessionId: []byte{0x02, 0xB0, 0x5B, 0x8B},
TCP: &utils.TCPListener{},
sigInfo: &loginSigInfo{}, sigInfo: &loginSigInfo{},
requestPacketRequestID: 1921334513, requestPacketRequestID: 1921334513,
groupSeq: int32(rand.Intn(20000)), groupSeq: int32(rand.Intn(20000)),
@ -236,7 +236,10 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient {
if len(cli.servers) > 3 { if len(cli.servers) > 3 {
cli.servers = cli.servers[0 : len(cli.servers)/2] // 保留ping值中位数以上的server cli.servers = cli.servers[0 : len(cli.servers)/2] // 保留ping值中位数以上的server
} }
cli.TCP.PlanedDisconnect(cli.planedDisconnect)
cli.TCP.UnexpectedDisconnect(cli.unexpectedDisconnect)
rand.Read(cli.RandomKey) rand.Read(cli.RandomKey)
go cli.netLoop()
return cli return cli
} }
@ -245,13 +248,10 @@ func (c *QQClient) Login() (*LoginResponse, error) {
if c.Online { if c.Online {
return nil, ErrAlreadyOnline return nil, ErrAlreadyOnline
} }
if c.Conn == nil {
err := c.connect() err := c.connect()
if err != nil { if err != nil {
return nil, err return nil, err
} }
go c.netLoop()
}
rsp, err := c.sendAndWait(c.buildLoginPacket()) rsp, err := c.sendAndWait(c.buildLoginPacket())
if err != nil { if err != nil {
c.Disconnect() c.Disconnect()
@ -268,13 +268,10 @@ func (c *QQClient) TokenLogin(token []byte) error {
if c.Online { if c.Online {
return ErrAlreadyOnline return ErrAlreadyOnline
} }
if c.Conn == nil {
err := c.connect() err := c.connect()
if err != nil { if err != nil {
return err return err
} }
go c.netLoop()
}
{ {
r := binary.NewReader(token) r := binary.NewReader(token)
c.Uin = r.ReadInt64() c.Uin = r.ReadInt64()
@ -288,7 +285,7 @@ func (c *QQClient) TokenLogin(token []byte) error {
c.OutGoingPacketSessionId = r.ReadBytesShort() c.OutGoingPacketSessionId = r.ReadBytesShort()
SystemDeviceInfo.TgtgtKey = r.ReadBytesShort() SystemDeviceInfo.TgtgtKey = r.ReadBytesShort()
} }
_, err := c.sendAndWait(c.buildRequestChangeSigPacket()) _, err = c.sendAndWait(c.buildRequestChangeSigPacket())
if err != nil { if err != nil {
return err return err
} }
@ -300,13 +297,10 @@ func (c *QQClient) FetchQRCode() (*QRCodeLoginResponse, error) {
if c.Online { if c.Online {
return nil, ErrAlreadyOnline return nil, ErrAlreadyOnline
} }
if c.Conn == nil {
err := c.connect() err := c.connect()
if err != nil { if err != nil {
return nil, err return nil, err
} }
go c.netLoop()
}
i, err := c.sendAndWait(c.buildQRCodeFetchRequestPacket()) i, err := c.sendAndWait(c.buildQRCodeFetchRequestPacket())
if err != nil { if err != nil {
return nil, errors.Wrap(err, "fetch qrcode error") return nil, errors.Wrap(err, "fetch qrcode error")
@ -389,7 +383,6 @@ func (c *QQClient) init() {
if len(c.g) == 0 { if len(c.g) == 0 {
c.Warning("device lock is disable. http api may fail.") c.Warning("device lock is disable. http api may fail.")
} }
c.Online = true
_ = c.registerClient() _ = c.registerClient()
c.groupSysMsgCache, _ = c.GetGroupSystemMessages() c.groupSysMsgCache, _ = c.GetGroupSystemMessages()
if !c.heartbeatEnabled { if !c.heartbeatEnabled {
@ -826,35 +819,6 @@ func (g *GroupInfo) removeMember(uin int64) {
}) })
} }
func (c *QQClient) connect() error {
c.Info("connect to server: %v", c.servers[c.currServerIndex].String())
conn, err := net.DialTCP("tcp", nil, c.servers[c.currServerIndex])
c.currServerIndex++
if c.currServerIndex == len(c.servers) {
c.currServerIndex = 0
}
if err != nil || conn == nil {
c.retryTimes++
if c.retryTimes > len(c.servers) {
return errors.New("All servers are unreachable")
}
c.Error("connect server error: %v", err)
return err
}
c.retryTimes = 0
c.ConnectTime = time.Now()
c.Conn = conn
return nil
}
func (c *QQClient) Disconnect() {
c.NetLooping = false
c.Online = false
if c.Conn != nil {
_ = c.Conn.Close()
}
}
func (c *QQClient) SetCustomServer(servers []*net.TCPAddr) { func (c *QQClient) SetCustomServer(servers []*net.TCPAddr) {
c.servers = append(servers, c.servers...) c.servers = append(servers, c.servers...)
} }
@ -866,6 +830,9 @@ func (c *QQClient) SendGroupGift(groupCode, uin uint64, gift message.GroupGift)
func (c *QQClient) registerClient() error { func (c *QQClient) registerClient() error {
_, err := c.sendAndWait(c.buildClientRegisterPacket()) _, err := c.sendAndWait(c.buildClientRegisterPacket())
if err == nil {
c.Online = true
}
return err return err
} }
@ -894,7 +861,7 @@ func (c *QQClient) nextHighwayApplySeq() int32 {
} }
func (c *QQClient) send(pkt []byte) error { func (c *QQClient) send(pkt []byte) error {
_, err := c.Conn.Write(pkt) err := c.TCP.Write(pkt)
if err != nil { if err != nil {
c.stat.PacketLost++ c.stat.PacketLost++
} else { } else {
@ -959,54 +926,89 @@ func (c *QQClient) waitPacket(cmd string, f func(interface{}, error)) func() {
} }
} }
func (c *QQClient) netLoop() { func (c *QQClient) connect() error {
if c.NetLooping { c.Info("connect to server: %v", c.servers[c.currServerIndex].String())
err := c.TCP.Connect(c.servers[c.currServerIndex])
c.currServerIndex++
if c.currServerIndex == len(c.servers) {
c.currServerIndex = 0
}
if err != nil {
c.retryTimes++
if c.retryTimes > len(c.servers) {
return errors.New("All servers are unreachable")
}
c.Error("connect server error: %v", err)
return err
}
c.retryTimes = 0
c.ConnectTime = time.Now()
return nil
}
func (c *QQClient) quickReconnect() {
c.Disconnect()
if err := c.connect(); err != nil {
c.Error("connect server error: %v", err)
c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "quick reconnect failed"})
return return
} }
c.NetLooping = true if err := c.registerClient(); err != nil {
reader := binary.NewNetworkReader(c.Conn) c.Error("register client failed: %v", err)
retry := 0
errCount := 0
for c.NetLooping {
l, err := reader.ReadInt32()
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
c.Error("connection dropped by server: %v", err)
c.stat.DisconnectTimes++
err = c.connect()
if err != nil {
c.Error("connect server error: %v", err)
break
}
reader = binary.NewNetworkReader(c.Conn)
go func() {
if e := c.registerClient(); e != nil && e.Error() != "Packet timed out" { // 掉线在心跳已经有判断了, 只需要处理返回值
c.lastLostMsg = "register client failed: " + e.Error()
c.Disconnect() c.Disconnect()
c.Error("reconnect failed: " + e.Error()) c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "register error"})
//break return
} }
}() }
func (c *QQClient) Disconnect() {
c.TCP.Close()
}
func (c *QQClient) planedDisconnect(_ *utils.TCPListener) {
c.Debug("planed disconnect.")
c.stat.DisconnectTimes++
c.Online = false
}
func (c *QQClient) unexpectedDisconnect(_ *utils.TCPListener, e error) {
c.Error("unexpected disconnect: %v", e)
c.stat.DisconnectTimes++
if err := c.connect(); err != nil {
c.Error("connect server error: %v", err)
c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "connection dropped by server."})
return
} }
if l <= 0 { if err := c.registerClient(); err != nil {
retry++ c.Error("register client failed: %v", err)
time.Sleep(time.Second * 3) c.Disconnect()
if retry > 10 { c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "register error"})
break return
} }
}
func (c *QQClient) netLoop() {
// todo: release this
errCount := 0
for {
l, err := c.TCP.ReadInt32()
if err != nil {
time.Sleep(time.Millisecond * 500)
continue continue
} }
data, _ := reader.ReadBytes(int(l) - 4) data, _ := c.TCP.ReadBytes(int(l) - 4)
pkt, err := packets.ParseIncomingPacket(data, c.sigInfo.d2Key) pkt, err := packets.ParseIncomingPacket(data, c.sigInfo.d2Key)
if err != nil { if err != nil {
c.Error("parse incoming packet error: %v", err) c.Error("parse incoming packet error: %v", err)
if errors.Is(err, packets.ErrSessionExpired) || errors.Is(err, packets.ErrPacketDropped) { if errors.Is(err, packets.ErrSessionExpired) || errors.Is(err, packets.ErrPacketDropped) {
break c.quickReconnect()
continue
} }
errCount++ errCount++
if errCount > 2 { if errCount > 2 {
break c.quickReconnect()
continue
} }
//log.Println("parse incoming packet error: " + err.Error())
continue continue
} }
payload := pkt.Payload payload := pkt.Payload
@ -1018,7 +1020,6 @@ func (c *QQClient) netLoop() {
} }
} }
errCount = 0 errCount = 0
retry = 0
c.Debug("rev pkt: %v seq: %v", pkt.CommandName, pkt.SequenceId) c.Debug("rev pkt: %v seq: %v", pkt.CommandName, pkt.SequenceId)
c.stat.PacketReceived++ c.stat.PacketReceived++
go func() { go func() {
@ -1057,35 +1058,35 @@ func (c *QQClient) netLoop() {
} }
}() }()
} }
/*
c.NetLooping = false c.NetLooping = false
c.Online = false c.Online = false
_ = c.Conn.Close() _ = c.TCP.Close()
if c.lastLostMsg == "" { if c.lastLostMsg == "" {
c.lastLostMsg = "Connection lost." c.lastLostMsg = "Connection lost."
} }
c.stat.LostTimes++ c.stat.LostTimes++
c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: c.lastLostMsg}) c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: c.lastLostMsg})
*/
} }
func (c *QQClient) doHeartbeat() { func (c *QQClient) doHeartbeat() {
c.heartbeatEnabled = true c.heartbeatEnabled = true
times := 0 times := 0
for c.Online { for c.Online {
time.Sleep(time.Second * 30)
seq := c.nextSeq() seq := c.nextSeq()
sso := packets.BuildSsoPacket(seq, c.version.AppId, "Heartbeat.Alive", SystemDeviceInfo.IMEI, []byte{}, c.OutGoingPacketSessionId, []byte{}, c.ksid) sso := packets.BuildSsoPacket(seq, c.version.AppId, "Heartbeat.Alive", SystemDeviceInfo.IMEI, []byte{}, c.OutGoingPacketSessionId, []byte{}, c.ksid)
packet := packets.BuildLoginPacket(c.Uin, 0, []byte{}, sso, []byte{}) packet := packets.BuildLoginPacket(c.Uin, 0, []byte{}, sso, []byte{})
_, err := c.sendAndWait(seq, packet) _, err := c.sendAndWait(seq, packet)
if err != nil { if errors.Is(err, utils.ErrConnectionClosed) {
c.lastLostMsg = "Heartbeat failed: " + err.Error() continue
c.Disconnect()
break
} }
times++ times++
if times >= 7 { if times >= 7 {
_ = c.registerClient() _ = c.registerClient()
times = 0 times = 0
} }
time.Sleep(time.Second * 30)
} }
c.heartbeatEnabled = false c.heartbeatEnabled = false
} }

View File

@ -737,17 +737,16 @@ func decodeForceOfflinePacket(c *QQClient, _ *incomingPacketInfo, payload []byte
data.ReadFrom(jce.NewJceReader(request.SBuffer)) data.ReadFrom(jce.NewJceReader(request.SBuffer))
r := jce.NewJceReader(data.Map["req_PushForceOffline"]["PushNotifyPack.RequestPushForceOffline"][1:]) r := jce.NewJceReader(data.Map["req_PushForceOffline"]["PushNotifyPack.RequestPushForceOffline"][1:])
tips := r.ReadString(2) tips := r.ReadString(2)
c.lastLostMsg = tips c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: tips})
c.NetLooping = false c.Disconnect()
c.Online = false
return nil, nil return nil, nil
} }
// StatSvc.ReqMSFOffline // StatSvc.ReqMSFOffline
func decodeMSFOfflinePacket(c *QQClient, _ *incomingPacketInfo, _ []byte) (interface{}, error) { func decodeMSFOfflinePacket(c *QQClient, _ *incomingPacketInfo, _ []byte) (interface{}, error) {
c.lastLostMsg = "服务器端强制下线." // c.lastLostMsg = "服务器端强制下线."
c.NetLooping = false c.Disconnect()
c.Online = false c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "服务端强制下线."})
return nil, nil return nil, nil
} }

105
utils/connection.go Normal file
View File

@ -0,0 +1,105 @@
package utils
import (
"github.com/pkg/errors"
"io"
"net"
"time"
)
type TCPListener struct {
conn net.Conn
planedDisconnect func(*TCPListener)
unexpectedDisconnect func(*TCPListener, error)
}
var ErrConnectionClosed = errors.New("connection closed")
// PlanedDisconnect 预料中的断开连接
// 如调用 Close() Connect()
func (t *TCPListener) PlanedDisconnect(f func(*TCPListener)) {
t.planedDisconnect = f
}
// UnexpectedDisconnect 未预料钟的断开连接
func (t *TCPListener) UnexpectedDisconnect(f func(*TCPListener, error)) {
t.unexpectedDisconnect = f
}
func (t *TCPListener) Connect(addr *net.TCPAddr) error {
t.Close()
conn, err := net.DialTCP("tcp", nil, addr)
if err != nil {
return errors.Wrap(err, "dial tcp error")
}
t.conn = conn
return nil
}
func (t *TCPListener) Write(buf []byte) error {
if t.conn == nil {
return ErrConnectionClosed
}
_, err := t.conn.Write(buf)
if err != nil {
if t.conn != nil {
t.close()
t.invokeUnexpectedDisconnect(err)
}
return ErrConnectionClosed
}
return nil
}
func (t *TCPListener) ReadBytes(len int) ([]byte, error) {
if t.conn == nil {
return nil, ErrConnectionClosed
}
buf := make([]byte, len)
_, err := io.ReadFull(t.conn, buf)
if err != nil {
time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析
if t.conn != nil {
t.close()
t.invokeUnexpectedDisconnect(err)
}
return nil, ErrConnectionClosed
}
return buf, nil
}
func (t *TCPListener) ReadInt32() (int32, error) {
b, err := t.ReadBytes(4)
if err != nil {
return 0, err
}
return (int32(b[0]) << 24) | (int32(b[1]) << 16) | (int32(b[2]) << 8) | int32(b[3]), nil
}
func (t *TCPListener) Close() {
if t.conn == nil {
return
}
t.close()
t.invokePlanedDisconnect()
}
func (t *TCPListener) close() {
if t.conn == nil {
return
}
_ = t.conn.Close()
t.conn = nil
}
func (t *TCPListener) invokePlanedDisconnect() {
if t.planedDisconnect != nil {
t.planedDisconnect(t)
}
}
func (t *TCPListener) invokeUnexpectedDisconnect(err error) {
if t.unexpectedDisconnect != nil {
t.unexpectedDisconnect(t, err)
}
}