mirror of
https://github.com/Mrs4s/MiraiGo.git
synced 2025-05-04 19:17:38 +08:00
Network refactor (#229)
* 重构:netLoop下移到Listener * 格式:修正NewClient中注释段缩进 * 格式:更名&将代码移动到对应位置 * 格式:TCPListener去锁化 * 修正:netLoop中的错误调用 * 修正:使其可用 * 修正:使功能一致 * 修正:现在可正常运行 * 优化:更早的释放锁(? * 修正:未写完的部分 * 修正:潜在的断线时仍然认为在线这件事&删除空重复文件 * 文档:添加部分注释 * 修正:CoverError可能引起死锁 * 修正:永远不会被触发的DisconnectEvent * 文档:将注释移动至对应位置
This commit is contained in:
parent
9bd6b38f90
commit
f5b16b19c2
171
client/client.go
171
client/client.go
@ -48,12 +48,12 @@ type QQClient struct {
|
|||||||
SequenceId atomic.Int32
|
SequenceId atomic.Int32
|
||||||
SessionId []byte
|
SessionId []byte
|
||||||
RandomKey []byte
|
RandomKey []byte
|
||||||
TCP *network.TCPListener
|
|
||||||
ConnectTime time.Time
|
ConnectTime time.Time
|
||||||
|
|
||||||
// todo: combine net conn, transport, pending into one struct
|
// todo: combine net conn, transport, pending into one struct
|
||||||
pendingMu sync.Mutex
|
pendingMu sync.Mutex
|
||||||
pending map[uint16]*network.Call
|
pending map[uint16]*network.Call
|
||||||
|
//TCP *network.TCPListener
|
||||||
transport *network.Transport
|
transport *network.Transport
|
||||||
oicq *oicq.Codec
|
oicq *oicq.Codec
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient {
|
|||||||
Uin: uin,
|
Uin: uin,
|
||||||
PasswordMd5: passwordMd5,
|
PasswordMd5: passwordMd5,
|
||||||
AllowSlider: true,
|
AllowSlider: true,
|
||||||
TCP: &network.TCPListener{},
|
//TCP: &network.TCPListener{},
|
||||||
sig: &auth.SigInfo{
|
sig: &auth.SigInfo{
|
||||||
OutPacketSessionID: []byte{0x02, 0xB0, 0x5B, 0x8B},
|
OutPacketSessionID: []byte{0x02, 0xB0, 0x5B, 0x8B},
|
||||||
},
|
},
|
||||||
@ -181,15 +181,15 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient {
|
|||||||
}
|
}
|
||||||
if len(cli.servers) == 0 {
|
if len(cli.servers) == 0 {
|
||||||
cli.servers = []*net.TCPAddr{ // default servers
|
cli.servers = []*net.TCPAddr{ // default servers
|
||||||
{IP: net.IP{42, 81, 172, 81}, Port: 80},
|
|
||||||
{IP: net.IP{114, 221, 148, 59}, Port: 14000},
|
|
||||||
{IP: net.IP{42, 81, 172, 147}, Port: 443},
|
|
||||||
{IP: net.IP{125, 94, 60, 146}, Port: 80},
|
|
||||||
{IP: net.IP{114, 221, 144, 215}, Port: 80},
|
|
||||||
{IP: net.IP{42, 81, 172, 22}, Port: 80},
|
{IP: net.IP{42, 81, 172, 22}, Port: 80},
|
||||||
|
{IP: net.IP{42, 81, 172, 81}, Port: 80},
|
||||||
|
{IP: net.IP{42, 81, 172, 147}, Port: 443},
|
||||||
|
{IP: net.IP{114, 221, 144, 215}, Port: 80},
|
||||||
|
{IP: net.IP{114, 221, 148, 59}, Port: 14000},
|
||||||
|
{IP: net.IP{125, 94, 60, 146}, Port: 80},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pings := make([]int64, len(cli.servers))
|
/*pings := make([]int64, len(cli.servers))
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(len(cli.servers))
|
wg.Add(len(cli.servers))
|
||||||
for i := range cli.servers {
|
for i := range cli.servers {
|
||||||
@ -209,9 +209,9 @@ func NewClientMd5(uin int64, passwordMd5 [16]byte) *QQClient {
|
|||||||
})
|
})
|
||||||
if len(cli.servers) > 3 {
|
if len(cli.servers) > 3 {
|
||||||
cli.servers = cli.servers[0 : len(cli.servers)/2] // 保留ping值中位数以上的server
|
cli.servers = cli.servers[0 : len(cli.servers)/2] // 保留ping值中位数以上的server
|
||||||
}
|
}*/
|
||||||
cli.TCP.PlannedDisconnect(cli.plannedDisconnect)
|
cli.transport.PlannedDisconnect(cli.plannedDisconnect)
|
||||||
cli.TCP.UnexpectedDisconnect(cli.unexpectedDisconnect)
|
cli.transport.UnexpectedDisconnect(cli.unexpectedDisconnect)
|
||||||
rand.Read(cli.RandomKey)
|
rand.Read(cli.RandomKey)
|
||||||
return cli
|
return cli
|
||||||
}
|
}
|
||||||
@ -255,11 +255,50 @@ func (c *QQClient) TokenLogin(token []byte) error {
|
|||||||
if c.Online.Load() {
|
if c.Online.Load() {
|
||||||
return ErrAlreadyOnline
|
return ErrAlreadyOnline
|
||||||
}
|
}
|
||||||
err := c.connect()
|
err := c.LoadToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
{
|
return c.ReLogin()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *QQClient) ReLogin() error {
|
||||||
|
if c.Online.Load() {
|
||||||
|
return ErrAlreadyOnline
|
||||||
|
}
|
||||||
|
err := c.connectFastest()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = c.callAndDecode(c.buildRequestChangeSigRequest(c.version.MainSigMap), decodeExchangeEmpResponse)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = c.init(true)
|
||||||
|
// 登录失败
|
||||||
|
if err != nil {
|
||||||
|
c.Disconnect()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *QQClient) DumpToken() []byte {
|
||||||
|
return binary.NewWriterF(func(w *binary.Writer) {
|
||||||
|
w.WriteUInt64(uint64(c.Uin))
|
||||||
|
w.WriteBytesShort(c.sig.D2)
|
||||||
|
w.WriteBytesShort(c.sig.D2Key)
|
||||||
|
w.WriteBytesShort(c.sig.TGT)
|
||||||
|
w.WriteBytesShort(c.sig.SrmToken)
|
||||||
|
w.WriteBytesShort(c.sig.T133)
|
||||||
|
w.WriteBytesShort(c.sig.EncryptedA1)
|
||||||
|
w.WriteBytesShort(c.oicq.WtSessionTicketKey)
|
||||||
|
w.WriteBytesShort(c.sig.OutPacketSessionID)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.TgtgtKey)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *QQClient) LoadToken(token []byte) error {
|
||||||
|
return utils.CoverError(func() {
|
||||||
r := binary.NewReader(token)
|
r := binary.NewReader(token)
|
||||||
c.Uin = r.ReadInt64()
|
c.Uin = r.ReadInt64()
|
||||||
c.sig.D2 = r.ReadBytesShort()
|
c.sig.D2 = r.ReadBytesShort()
|
||||||
@ -272,18 +311,83 @@ func (c *QQClient) TokenLogin(token []byte) error {
|
|||||||
c.sig.OutPacketSessionID = r.ReadBytesShort()
|
c.sig.OutPacketSessionID = r.ReadBytesShort()
|
||||||
// SystemDeviceInfo.TgtgtKey = r.ReadBytesShort()
|
// SystemDeviceInfo.TgtgtKey = r.ReadBytesShort()
|
||||||
c.deviceInfo.TgtgtKey = r.ReadBytesShort()
|
c.deviceInfo.TgtgtKey = r.ReadBytesShort()
|
||||||
}
|
copy(SystemDeviceInfo.TgtgtKey, c.deviceInfo.TgtgtKey)
|
||||||
_, err = c.callAndDecode(c.buildRequestChangeSigRequest(c.version.MainSigMap), decodeExchangeEmpResponse)
|
})
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.init(true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *QQClient) DumpDevice() []byte {
|
||||||
|
return binary.NewWriterF(func(w *binary.Writer) {
|
||||||
|
w.WriteBytesShort(c.deviceInfo.Display)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.Product)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.Device)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.Board)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.Brand)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.Model)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.Bootloader)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.FingerPrint)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.BootId)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.ProcVersion)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.BaseBand)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.SimInfo)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.OSType)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.MacAddress)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.IpAddress)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.WifiBSSID)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.WifiSSID)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.IMSIMd5)
|
||||||
|
w.WriteStringShort(c.deviceInfo.IMEI)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.APN)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.VendorName)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.VendorOSName)
|
||||||
|
w.WriteBytesShort(c.deviceInfo.AndroidId)
|
||||||
|
|
||||||
|
w.Write(c.PasswordMd5[:])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *QQClient) LoadDevice(device []byte) error {
|
||||||
|
return utils.CoverError(func() {
|
||||||
|
r := binary.NewReader(device)
|
||||||
|
c.deviceInfo.Display = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.Product = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.Device = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.Board = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.Brand = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.Model = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.Bootloader = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.FingerPrint = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.BootId = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.ProcVersion = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.BaseBand = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.SimInfo = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.OSType = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.MacAddress = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.IpAddress = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.WifiBSSID = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.WifiSSID = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.IMSIMd5 = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.IMEI = r.ReadStringShort()
|
||||||
|
c.deviceInfo.APN = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.VendorName = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.VendorOSName = r.ReadBytesShort()
|
||||||
|
c.deviceInfo.AndroidId = r.ReadBytesShort()
|
||||||
|
|
||||||
|
copy(c.PasswordMd5[:], r.ReadBytes(md5.Size))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchQRCode 以默认值获取登录二维码
|
||||||
|
// 函数已被弃用 请使用FetchQRCodeCustomSize获得更可控结果
|
||||||
|
// 但该兼容函数不会被删除
|
||||||
|
// Deprecated use FetchQRCodeCustomSize(3, 4, 2) instead
|
||||||
func (c *QQClient) FetchQRCode() (*QRCodeLoginResponse, error) {
|
func (c *QQClient) FetchQRCode() (*QRCodeLoginResponse, error) {
|
||||||
return c.FetchQRCodeCustomSize(3, 4, 2)
|
return c.FetchQRCodeCustomSize(3, 4, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FetchQRCodeCustomSize 以特定参数获取登录二维码
|
||||||
|
// size: 块尺寸 默认值3 即单个黑/白块大小为3x3像素
|
||||||
|
// margin: 与图片边界的距离 默认值4 即二维码主体至图片边界有4像素白色填充
|
||||||
|
// ecLevel: 纠错等级 可用值:1,2,3 默认值2
|
||||||
func (c *QQClient) FetchQRCodeCustomSize(size, margin, ecLevel uint32) (*QRCodeLoginResponse, error) {
|
func (c *QQClient) FetchQRCodeCustomSize(size, margin, ecLevel uint32) (*QRCodeLoginResponse, error) {
|
||||||
if c.Online.Load() {
|
if c.Online.Load() {
|
||||||
return nil, ErrAlreadyOnline
|
return nil, ErrAlreadyOnline
|
||||||
@ -400,15 +504,13 @@ func (c *QQClient) init(tokenLogin bool) error {
|
|||||||
d2()
|
d2()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.groupSysMsgCache, _ = c.GetGroupSystemMessages()
|
go c.doHeartbeat()
|
||||||
if !c.heartbeatEnabled {
|
|
||||||
go c.doHeartbeat()
|
|
||||||
}
|
|
||||||
_ = c.RefreshStatus()
|
_ = c.RefreshStatus()
|
||||||
if c.version.Protocol == auth.QiDian {
|
if c.version.Protocol == auth.QiDian {
|
||||||
_, _ = c.callAndDecode(c.buildLoginExtraPacket(), decodeLoginExtraResponse) // 小登录
|
_, _ = c.callAndDecode(c.buildLoginExtraPacket(), decodeLoginExtraResponse) // 小登录
|
||||||
_, _ = c.callAndDecode(c.buildConnKeyRequestPacket(), decodeConnKeyResponse) // big data key 如果等待 config push 的话时间来不及
|
_, _ = c.callAndDecode(c.buildConnKeyRequestPacket(), decodeConnKeyResponse) // big data key 如果等待 config push 的话时间来不及
|
||||||
}
|
}
|
||||||
|
c.groupSysMsgCache, _ = c.GetGroupSystemMessages()
|
||||||
req := c.buildGetMessageRequest(msg.SyncFlag_START, time.Now().Unix())
|
req := c.buildGetMessageRequest(msg.SyncFlag_START, time.Now().Unix())
|
||||||
req.Params = network.Params{"used_reg_proxy": true, "init": true}
|
req.Params = network.Params{"used_reg_proxy": true, "init": true}
|
||||||
_, _ = c.callAndDecode(req, decodeMessageSvcPacket)
|
_, _ = c.callAndDecode(req, decodeMessageSvcPacket)
|
||||||
@ -765,10 +867,21 @@ func (c *QQClient) nextHighwayApplySeq() int32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *QQClient) doHeartbeat() {
|
func (c *QQClient) doHeartbeat() {
|
||||||
|
// 不需要atomic/锁
|
||||||
|
if c.heartbeatEnabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
c.heartbeatEnabled = true
|
c.heartbeatEnabled = true
|
||||||
|
defer func() {
|
||||||
|
c.heartbeatEnabled = false
|
||||||
|
}()
|
||||||
times := 0
|
times := 0
|
||||||
for c.Online.Load() {
|
ticker := time.NewTicker(time.Second * 30)
|
||||||
time.Sleep(time.Second * 30)
|
for range ticker.C {
|
||||||
|
if !c.Online.Load() {
|
||||||
|
ticker.Stop()
|
||||||
|
return // 下线停止goroutine,for gc
|
||||||
|
}
|
||||||
seq := c.nextSeq()
|
seq := c.nextSeq()
|
||||||
req := network.Request{
|
req := network.Request{
|
||||||
Type: network.RequestTypeLogin,
|
Type: network.RequestTypeLogin,
|
||||||
@ -779,8 +892,11 @@ func (c *QQClient) doHeartbeat() {
|
|||||||
Body: EmptyBytes,
|
Body: EmptyBytes,
|
||||||
}
|
}
|
||||||
_, err := c.call(&req)
|
_, err := c.call(&req)
|
||||||
if errors.Is(err, network.ErrConnectionClosed) {
|
if err != nil {
|
||||||
continue
|
if errors.Is(err, network.ErrConnectionBroken) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
continue // skip time++
|
||||||
}
|
}
|
||||||
times++
|
times++
|
||||||
if times >= 7 {
|
if times >= 7 {
|
||||||
@ -788,5 +904,4 @@ func (c *QQClient) doHeartbeat() {
|
|||||||
times = 0
|
times = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.heartbeatEnabled = false
|
|
||||||
}
|
}
|
||||||
|
@ -5,33 +5,40 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TCPListener struct {
|
type TCPListener struct {
|
||||||
lock sync.RWMutex
|
//lock sync.RWMutex
|
||||||
conn net.Conn
|
conn *net.TCPConn
|
||||||
connected bool
|
//connected bool
|
||||||
plannedDisconnect func(*TCPListener)
|
|
||||||
unexpectedDisconnect func(*TCPListener, error)
|
// PlannedDisconnect 预料中的断开连接
|
||||||
|
// 如调用 Close() Connect()
|
||||||
|
PlannedDisconnect func(*TCPListener)
|
||||||
|
|
||||||
|
// UnexpectedDisconnect 未预料的断开连接
|
||||||
|
UnexpectedDisconnect func(*TCPListener, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrConnectionClosed = errors.New("connection closed")
|
func (t *TCPListener) getConn() *net.TCPConn {
|
||||||
|
return (*net.TCPConn)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&t.conn))))
|
||||||
// PlannedDisconnect 预料中的断开连接
|
|
||||||
// 如调用 Close() Connect()
|
|
||||||
func (t *TCPListener) PlannedDisconnect(f func(*TCPListener)) {
|
|
||||||
t.lock.Lock()
|
|
||||||
defer t.lock.Unlock()
|
|
||||||
t.plannedDisconnect = f
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnexpectedDisconnect 未预料的断开连接
|
func (t *TCPListener) setConn(conn *net.TCPConn) (swapped bool) {
|
||||||
func (t *TCPListener) UnexpectedDisconnect(f func(*TCPListener, error)) {
|
return atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.conn)), unsafe.Pointer(nil), unsafe.Pointer(conn))
|
||||||
t.lock.Lock()
|
}
|
||||||
defer t.lock.Unlock()
|
|
||||||
t.unexpectedDisconnect = f
|
func (t *TCPListener) closeConn() *net.TCPConn {
|
||||||
|
return (*net.TCPConn)(atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.conn)), unsafe.Pointer(nil)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TCPListener) Connected() bool {
|
||||||
|
// 等同于 t.getConn() != nil (? copilot写的)
|
||||||
|
return t.getConn() != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPListener) Connect(addr *net.TCPAddr) error {
|
func (t *TCPListener) Connect(addr *net.TCPAddr) error {
|
||||||
@ -40,24 +47,59 @@ func (t *TCPListener) Connect(addr *net.TCPAddr) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "dial tcp error")
|
return errors.Wrap(err, "dial tcp error")
|
||||||
}
|
}
|
||||||
t.lock.Lock()
|
t.setConn(conn)
|
||||||
defer t.lock.Unlock()
|
//t.lock.Lock()
|
||||||
t.conn = conn
|
//defer t.lock.Unlock()
|
||||||
t.connected = true
|
//t.conn = conn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnectFastest 连接到最快的服务器
|
||||||
|
// TODO 禁用不可用服务器
|
||||||
|
func (t *TCPListener) ConnectFastest(addr []*net.TCPAddr) (*net.TCPAddr, error) {
|
||||||
|
ch := make(chan error)
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(len(addr))
|
||||||
|
for _, remote := range addr {
|
||||||
|
go func(remote *net.TCPAddr) {
|
||||||
|
defer wg.Done()
|
||||||
|
conn, err := net.DialTCP("tcp", nil, remote)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
//addrs = append(addrs, remote)
|
||||||
|
if !t.setConn(conn) {
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch <- nil
|
||||||
|
}(remote)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
if t.getConn() == nil {
|
||||||
|
ch <- errors.New("All addr are unreachable")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
err := <-ch
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn := t.getConn()
|
||||||
|
return conn.RemoteAddr().(*net.TCPAddr), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TCPListener) Write(buf []byte) error {
|
func (t *TCPListener) Write(buf []byte) error {
|
||||||
if conn := t.getConn(); conn != nil {
|
if conn := t.getConn(); conn != nil {
|
||||||
_, err := conn.Write(buf)
|
_, err := conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.unexpectedClose(err)
|
t.unexpectedClose(err)
|
||||||
return ErrConnectionClosed
|
return ErrConnectionBroken
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return ErrConnectionClosed
|
return ErrConnectionBroken
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPListener) ReadBytes(len int) ([]byte, error) {
|
func (t *TCPListener) ReadBytes(len int) ([]byte, error) {
|
||||||
@ -67,12 +109,12 @@ func (t *TCPListener) ReadBytes(len int) ([]byte, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁
|
// time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁
|
||||||
t.unexpectedClose(err)
|
t.unexpectedClose(err)
|
||||||
return nil, ErrConnectionClosed
|
return nil, ErrConnectionBroken
|
||||||
}
|
}
|
||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, ErrConnectionClosed
|
return nil, ErrConnectionBroken
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPListener) ReadInt32() (int32, error) {
|
func (t *TCPListener) ReadInt32() (int32, error) {
|
||||||
@ -94,34 +136,37 @@ func (t *TCPListener) unexpectedClose(err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPListener) close() {
|
func (t *TCPListener) close() {
|
||||||
t.lock.Lock()
|
if conn := t.closeConn(); conn != nil {
|
||||||
defer t.lock.Unlock()
|
_ = conn.Close()
|
||||||
if t.conn != nil {
|
|
||||||
_ = t.conn.Close()
|
|
||||||
t.conn = nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPListener) invokePlannedDisconnect() {
|
func (t *TCPListener) invokePlannedDisconnect() {
|
||||||
t.lock.RLock()
|
//if t.Connected() {
|
||||||
defer t.lock.RUnlock()
|
t.PlannedDisconnect(t)
|
||||||
if t.plannedDisconnect != nil && t.connected {
|
//}
|
||||||
go t.plannedDisconnect(t)
|
//t.lock.RLock()
|
||||||
t.connected = false
|
//defer t.lock.RUnlock()
|
||||||
}
|
//if t.plannedDisconnect != nil && t.connected {
|
||||||
|
// go t.plannedDisconnect(t)
|
||||||
|
// t.connected = false
|
||||||
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPListener) invokeUnexpectedDisconnect(err error) {
|
func (t *TCPListener) invokeUnexpectedDisconnect(err error) {
|
||||||
t.lock.RLock()
|
//if t.Connected() {
|
||||||
defer t.lock.RUnlock()
|
t.UnexpectedDisconnect(t, err)
|
||||||
if t.unexpectedDisconnect != nil && t.connected {
|
//}
|
||||||
go t.unexpectedDisconnect(t, err)
|
//t.lock.RLock()
|
||||||
t.connected = false
|
//defer t.lock.RUnlock()
|
||||||
}
|
//if t.unexpectedDisconnect != nil && t.connected {
|
||||||
|
// go t.unexpectedDisconnect(t, err)
|
||||||
|
// t.connected = false
|
||||||
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPListener) getConn() net.Conn {
|
//func (t *TCPListener) getConn() net.Conn {
|
||||||
t.lock.RLock()
|
// t.lock.RLock()
|
||||||
defer t.lock.RUnlock()
|
// defer t.lock.RUnlock()
|
||||||
return t.conn
|
// return t.conn
|
||||||
}
|
//}
|
||||||
|
@ -22,6 +22,7 @@ var (
|
|||||||
ErrSessionExpired = errors.New("session expired")
|
ErrSessionExpired = errors.New("session expired")
|
||||||
ErrPacketDropped = errors.New("packet dropped")
|
ErrPacketDropped = errors.New("packet dropped")
|
||||||
ErrInvalidPacketType = errors.New("invalid packet type")
|
ErrInvalidPacketType = errors.New("invalid packet type")
|
||||||
|
ErrConnectionBroken = errors.New("connection broken")
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t *Transport) ReadRequest(head []byte) (*Request, error) {
|
func (t *Transport) ReadRequest(head []byte) (*Request, error) {
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
package network
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
goBinary "encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/Mrs4s/MiraiGo/binary"
|
"github.com/Mrs4s/MiraiGo/binary"
|
||||||
@ -15,7 +20,27 @@ type Transport struct {
|
|||||||
Device *auth.Device
|
Device *auth.Device
|
||||||
|
|
||||||
// connection
|
// connection
|
||||||
// conn *TCPListener
|
conn TCPListener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) PlannedDisconnect(fun func(*TCPListener)) {
|
||||||
|
t.conn.PlannedDisconnect = fun
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) UnexpectedDisconnect(fun func(*TCPListener, error)) {
|
||||||
|
t.conn.UnexpectedDisconnect = fun
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) ConnectFastest(servers []*net.TCPAddr) (chosen *net.TCPAddr, err error) {
|
||||||
|
return t.conn.ConnectFastest(servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) Close() {
|
||||||
|
t.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) Write(data []byte) error {
|
||||||
|
return t.conn.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) packBody(req *Request, w *binary.Writer) {
|
func (t *Transport) packBody(req *Request, w *binary.Writer) {
|
||||||
@ -96,3 +121,59 @@ func (t *Transport) PackPacket(req *Request) []byte {
|
|||||||
w.WriteUInt32At(pos, uint32(w.Len()))
|
w.WriteUInt32At(pos, uint32(w.Len()))
|
||||||
return append([]byte(nil), w.Bytes()...)
|
return append([]byte(nil), w.Bytes()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PktHandler func(pkt *Request, netErr error)
|
||||||
|
type RequestHandler func(head []byte) (*Request, error)
|
||||||
|
|
||||||
|
func (t *Transport) NetLoop(pktHandler PktHandler, respHandler RequestHandler) {
|
||||||
|
go t.netLoop(pktHandler, respHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// readPacket 帮助函数(Helper function)
|
||||||
|
func readPacket(conn *net.TCPConn, minSize, maxSize uint32) ([]byte, error) {
|
||||||
|
lBuf := make([]byte, 4)
|
||||||
|
_, err := io.ReadFull(conn, lBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l := goBinary.BigEndian.Uint32(lBuf)
|
||||||
|
if l < minSize || l > maxSize {
|
||||||
|
return nil, fmt.Errorf("parse incoming packet error: invalid packet length %v", l)
|
||||||
|
}
|
||||||
|
data := make([]byte, l-4)
|
||||||
|
_, err = io.ReadFull(conn, data)
|
||||||
|
return data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// netLoop 整个函数周期使用同一个连接,确保不会发生串线这种奇怪的事情
|
||||||
|
func (t *Transport) netLoop(pktHandler PktHandler, respHandler RequestHandler) {
|
||||||
|
conn := t.conn.getConn()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
pktHandler(nil, fmt.Errorf("panic: %v", r))
|
||||||
|
}
|
||||||
|
t.conn.Close()
|
||||||
|
}()
|
||||||
|
errCount := 0
|
||||||
|
for {
|
||||||
|
data, err := readPacket(conn, 4, 10<<20) // max 10MB
|
||||||
|
if err != nil {
|
||||||
|
// 在且仅在没有新连接建立时断线才被认为是意外的
|
||||||
|
if t.conn.getConn() == conn {
|
||||||
|
pktHandler(nil, errors.Wrap(ErrConnectionBroken, err.Error()))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req, err := respHandler(data)
|
||||||
|
if err == nil {
|
||||||
|
errCount = 0
|
||||||
|
goto ok
|
||||||
|
}
|
||||||
|
errCount++
|
||||||
|
if errCount > 2 {
|
||||||
|
err = errors.Wrap(ErrConnectionBroken, err.Error())
|
||||||
|
}
|
||||||
|
ok:
|
||||||
|
go pktHandler(req, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/Mrs4s/MiraiGo/message"
|
||||||
"net"
|
"net"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"sync"
|
"sync"
|
||||||
@ -10,7 +11,6 @@ import (
|
|||||||
|
|
||||||
"github.com/Mrs4s/MiraiGo/client/internal/network"
|
"github.com/Mrs4s/MiraiGo/client/internal/network"
|
||||||
"github.com/Mrs4s/MiraiGo/client/internal/oicq"
|
"github.com/Mrs4s/MiraiGo/client/internal/oicq"
|
||||||
"github.com/Mrs4s/MiraiGo/message"
|
|
||||||
"github.com/Mrs4s/MiraiGo/utils"
|
"github.com/Mrs4s/MiraiGo/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -84,22 +84,22 @@ func (c *QQClient) ConnectionQualityTest() *ConnectionQualityInfo {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect 连接到 QQClient.servers 中的服务器
|
func (c *QQClient) connectFastest() error {
|
||||||
func (c *QQClient) connect() error {
|
c.Disconnect()
|
||||||
c.Info("connect to server: %v", c.servers[c.currServerIndex].String())
|
addr, err := c.transport.ConnectFastest(c.servers)
|
||||||
err := c.TCP.Connect(c.servers[c.currServerIndex])
|
|
||||||
c.currServerIndex++
|
|
||||||
if c.currServerIndex == len(c.servers) {
|
|
||||||
c.currServerIndex = 0
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.retryTimes++
|
c.Disconnect()
|
||||||
if c.retryTimes > len(c.servers) {
|
|
||||||
return errors.New("All servers are unreachable")
|
|
||||||
}
|
|
||||||
c.Error("connect server error: %v", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
c.Debug("connected to server: %v [fastest]", addr.String())
|
||||||
|
c.transport.NetLoop(c.pktProc, c.transport.ReadRequest)
|
||||||
|
c.retryTimes = 0
|
||||||
|
c.ConnectTime = time.Now()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect 连接到 QQClient.servers 中的服务器
|
||||||
|
func (c *QQClient) connect() error {
|
||||||
c.once.Do(func() {
|
c.once.Do(func() {
|
||||||
c.OnGroupMessage(func(_ *QQClient, _ *message.GroupMessage) {
|
c.OnGroupMessage(func(_ *QQClient, _ *message.GroupMessage) {
|
||||||
c.stat.MessageReceived.Add(1)
|
c.stat.MessageReceived.Add(1)
|
||||||
@ -116,11 +116,30 @@ func (c *QQClient) connect() error {
|
|||||||
c.onGroupMessageReceipt("internal", func(_ *QQClient, _ *groupMessageReceiptEvent) {
|
c.onGroupMessageReceipt("internal", func(_ *QQClient, _ *groupMessageReceiptEvent) {
|
||||||
c.stat.MessageSent.Add(1)
|
c.stat.MessageSent.Add(1)
|
||||||
})
|
})
|
||||||
go c.netLoop()
|
//go c.netLoop()
|
||||||
})
|
})
|
||||||
|
return c.connectFastest() // 暂时?
|
||||||
|
/*c.Info("connect to server: %v", c.servers[c.currServerIndex].String())
|
||||||
|
err := c.TCP.Connect(c.servers[c.currServerIndex])
|
||||||
|
c.currServerIndex++
|
||||||
|
if c.currServerIndex == len(c.servers) {
|
||||||
|
c.currServerIndex = 0
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
c.retryTimes++
|
||||||
|
if c.retryTimes > len(c.servers) {
|
||||||
|
return errors.New("All servers are unreachable")
|
||||||
|
}
|
||||||
|
c.Error("connect server error: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
c.retryTimes = 0
|
c.retryTimes = 0
|
||||||
c.ConnectTime = time.Now()
|
c.ConnectTime = time.Now()
|
||||||
return nil
|
return nil*/
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *QQClient) QuickReconnect() {
|
||||||
|
c.quickReconnect() // TODO "用户请求快速重连"
|
||||||
}
|
}
|
||||||
|
|
||||||
// quickReconnect 快速重连
|
// quickReconnect 快速重连
|
||||||
@ -129,7 +148,7 @@ func (c *QQClient) quickReconnect() {
|
|||||||
time.Sleep(time.Millisecond * 200)
|
time.Sleep(time.Millisecond * 200)
|
||||||
if err := c.connect(); err != nil {
|
if err := c.connect(); err != nil {
|
||||||
c.Error("connect server error: %v", err)
|
c.Error("connect server error: %v", err)
|
||||||
c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "quick reconnect failed"})
|
c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "快速重连失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := c.registerClient(); err != nil {
|
if err := c.registerClient(); err != nil {
|
||||||
@ -143,7 +162,7 @@ func (c *QQClient) quickReconnect() {
|
|||||||
// Disconnect 中断连接, 不释放资源
|
// Disconnect 中断连接, 不释放资源
|
||||||
func (c *QQClient) Disconnect() {
|
func (c *QQClient) Disconnect() {
|
||||||
c.Online.Store(false)
|
c.Online.Store(false)
|
||||||
c.TCP.Close()
|
c.transport.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *QQClient) send(call *network.Call) {
|
func (c *QQClient) send(call *network.Call) {
|
||||||
@ -160,9 +179,9 @@ func (c *QQClient) send(call *network.Call) {
|
|||||||
c.pendingMu.Lock()
|
c.pendingMu.Lock()
|
||||||
call = c.pending[seq]
|
call = c.pending[seq]
|
||||||
delete(c.pending, seq)
|
delete(c.pending, seq)
|
||||||
|
c.pendingMu.Unlock()
|
||||||
call.Err = err
|
call.Err = err
|
||||||
call.Done <- call
|
call.Done <- call
|
||||||
c.pendingMu.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,7 +213,7 @@ func (c *QQClient) callAndDecode(req *network.Request, decoder func(*QQClient, *
|
|||||||
|
|
||||||
// sendPacket 向服务器发送一个数据包
|
// sendPacket 向服务器发送一个数据包
|
||||||
func (c *QQClient) sendPacket(pkt []byte) error {
|
func (c *QQClient) sendPacket(pkt []byte) error {
|
||||||
err := c.TCP.Write(pkt)
|
err := c.transport.Write(pkt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.stat.PacketLost.Add(1)
|
c.stat.PacketLost.Add(1)
|
||||||
} else {
|
} else {
|
||||||
@ -257,100 +276,79 @@ func (c *QQClient) unexpectedDisconnect(_ *network.TCPListener, e error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// netLoop 通过循环来不停接收数据包
|
func (c *QQClient) pktProc(req *network.Request, netErr error) {
|
||||||
func (c *QQClient) netLoop() {
|
if netErr != nil {
|
||||||
errCount := 0
|
switch true {
|
||||||
for c.alive {
|
case errors.Is(netErr, network.ErrConnectionBroken):
|
||||||
l, err := c.TCP.ReadInt32()
|
go c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: netErr.Error()})
|
||||||
if err != nil {
|
c.QuickReconnect()
|
||||||
time.Sleep(time.Millisecond * 500)
|
case errors.Is(netErr, network.ErrSessionExpired) || errors.Is(netErr, network.ErrPacketDropped):
|
||||||
continue
|
c.Disconnect()
|
||||||
|
go c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "session expired"})
|
||||||
}
|
}
|
||||||
if l < 4 || l > 1024*1024*10 { // max 10MB
|
c.Error("parse incoming packet error: %v", netErr)
|
||||||
c.Error("parse incoming packet error: invalid packet length %v", l)
|
return
|
||||||
errCount++
|
}
|
||||||
if errCount > 2 {
|
|
||||||
go c.quickReconnect()
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data, _ := c.TCP.ReadBytes(int(l) - 4)
|
|
||||||
req, err := c.transport.ReadRequest(data)
|
|
||||||
if err != nil {
|
|
||||||
c.Error("parse incoming packet error: %v", err)
|
|
||||||
if errors.Is(err, network.ErrSessionExpired) || errors.Is(err, network.ErrPacketDropped) {
|
|
||||||
c.Disconnect()
|
|
||||||
go c.dispatchDisconnectEvent(&ClientDisconnectedEvent{Message: "session expired"})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
errCount++
|
|
||||||
if errCount > 2 {
|
|
||||||
go c.quickReconnect()
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if req.EncryptType == network.EncryptTypeEmptyKey {
|
|
||||||
m, err := c.oicq.Unmarshal(req.Body)
|
|
||||||
if err != nil {
|
|
||||||
c.Error("decrypt payload error: %v", err)
|
|
||||||
if errors.Is(err, oicq.ErrUnknownFlag) {
|
|
||||||
go c.quickReconnect()
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
req.Body = m.Body
|
|
||||||
}
|
|
||||||
errCount = 0
|
|
||||||
c.Debug("rev pkt: %v seq: %v", req.CommandName, req.SequenceID)
|
|
||||||
c.stat.PacketReceived.Add(1)
|
|
||||||
go func(req *network.Request) {
|
|
||||||
defer func() {
|
|
||||||
if pan := recover(); pan != nil {
|
|
||||||
c.Error("panic on decoder %v : %v\n%s", req.CommandName, pan, debug.Stack())
|
|
||||||
c.Dump("packet decode error: %v - %v", req.Body, req.CommandName, pan)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// snapshot of read call
|
if req.EncryptType == network.EncryptTypeEmptyKey {
|
||||||
c.pendingMu.Lock()
|
m, err := c.oicq.Unmarshal(req.Body)
|
||||||
call := c.pending[uint16(req.SequenceID)]
|
if err != nil {
|
||||||
if call != nil {
|
c.Error("decrypt payload error: %v", err)
|
||||||
call.Response = &network.Response{
|
if errors.Is(err, oicq.ErrUnknownFlag) {
|
||||||
SequenceID: req.SequenceID,
|
go c.quickReconnect() // TODO "服务器发送未知响应"
|
||||||
CommandName: req.CommandName,
|
|
||||||
Body: req.Body,
|
|
||||||
Params: call.Request.Params,
|
|
||||||
// Request: nil,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.pendingMu.Unlock()
|
|
||||||
if call != nil && call.Request.CommandName == req.CommandName {
|
|
||||||
select {
|
|
||||||
case call.Done <- call:
|
|
||||||
default:
|
|
||||||
// we don't want blocking
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
req.Body = m.Body
|
||||||
|
}
|
||||||
|
|
||||||
if decoder, ok := decoders[req.CommandName]; ok {
|
defer func() {
|
||||||
// found predefined decoder
|
if pan := recover(); pan != nil {
|
||||||
resp := network.Response{
|
c.Error("panic on decoder %v : %v\n%s", req.CommandName, pan, debug.Stack())
|
||||||
SequenceID: req.SequenceID,
|
c.Dump("packet decode error: %v - %v", req.Body, req.CommandName, pan)
|
||||||
CommandName: req.CommandName,
|
}
|
||||||
Body: req.Body,
|
}()
|
||||||
// Request: nil,
|
|
||||||
}
|
c.Debug("rev resp: %v seq: %v", req.CommandName, req.SequenceID)
|
||||||
decoded, err := decoder(c, &resp)
|
c.stat.PacketReceived.Add(1)
|
||||||
if err != nil {
|
|
||||||
c.Debug("decode req %v error: %+v", req.CommandName, err)
|
// snapshot of read call
|
||||||
}
|
c.pendingMu.Lock()
|
||||||
if f, ok := c.waiters.Load(req.CommandName); ok { // 在不存在handler的情况下触发wait
|
call := c.pending[uint16(req.SequenceID)]
|
||||||
f.(func(interface{}, error))(decoded, err)
|
if call != nil {
|
||||||
}
|
call.Response = &network.Response{
|
||||||
} else {
|
SequenceID: req.SequenceID,
|
||||||
c.Debug("Unhandled Command: %s\nSeq: %d\nThis message can be ignored.", req.CommandName, req.SequenceID)
|
CommandName: req.CommandName,
|
||||||
}
|
Body: req.Body,
|
||||||
}(req)
|
Params: call.Request.Params,
|
||||||
|
// Request: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.pendingMu.Unlock()
|
||||||
|
if call != nil && call.Request.CommandName == req.CommandName {
|
||||||
|
select {
|
||||||
|
case call.Done <- call:
|
||||||
|
default:
|
||||||
|
// we don't want blocking
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoder, ok := decoders[req.CommandName]; ok {
|
||||||
|
// found predefined decoder
|
||||||
|
resp := network.Response{
|
||||||
|
SequenceID: req.SequenceID,
|
||||||
|
CommandName: req.CommandName,
|
||||||
|
Body: req.Body,
|
||||||
|
// Request: nil,
|
||||||
|
}
|
||||||
|
decoded, err := decoder(c, &resp)
|
||||||
|
if err != nil {
|
||||||
|
c.Debug("decode req %v error: %+v", req.CommandName, err)
|
||||||
|
}
|
||||||
|
if f, ok := c.waiters.Load(req.CommandName); ok { // 在不存在handler的情况下触发wait
|
||||||
|
f.(func(interface{}, error))(decoded, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.Debug("Unhandled Command: %s\nSeq: %d\nThis message can be ignored.", req.CommandName, req.SequenceID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -236,6 +236,7 @@ func ToProtoElems(elems []IMessageElement, generalFlags bool) (r []*msg.Elem) {
|
|||||||
if len(elems) == 0 {
|
if len(elems) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
r = make([]*msg.Elem, 0, len(elems))
|
||||||
for _, elem := range elems {
|
for _, elem := range elems {
|
||||||
if reply, ok := elem.(*ReplyElement); ok {
|
if reply, ok := elem.(*ReplyElement); ok {
|
||||||
r = append(r, &msg.Elem{
|
r = append(r, &msg.Elem{
|
||||||
|
26
utils/func.go
Normal file
26
utils/func.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// CoverError == catch{}
|
||||||
|
func CoverError(fun func()) error {
|
||||||
|
if fun == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if err, ok := r.(error); ok {
|
||||||
|
errCh <- err
|
||||||
|
} else {
|
||||||
|
errCh <- fmt.Errorf("%v", r)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
errCh <- nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
fun()
|
||||||
|
}()
|
||||||
|
return <-errCh
|
||||||
|
}
|
29
utils/func_test.go
Normal file
29
utils/func_test.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errTest = errors.New("test error")
|
||||||
|
|
||||||
|
func TestCoverError(t *testing.T) {
|
||||||
|
err := CoverError(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf(`CoverError(nil) = %v, want nil`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = CoverError(func() {
|
||||||
|
panic("test")
|
||||||
|
})
|
||||||
|
if err.Error() != "test" {
|
||||||
|
t.Errorf(`CoverError(func() { panic("test") }) = %v, want "test"`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = CoverError(func() {
|
||||||
|
panic(errTest)
|
||||||
|
})
|
||||||
|
if err != errTest {
|
||||||
|
t.Errorf(`CoverError(func() { panic(errTest) }) = %v, want errTest`, err)
|
||||||
|
}
|
||||||
|
}
|
@ -10,10 +10,7 @@ import (
|
|||||||
|
|
||||||
var client = &http.Client{
|
var client = &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
ForceAttemptHTTP2: true,
|
ForceAttemptHTTP2: true,
|
||||||
MaxConnsPerHost: 0,
|
|
||||||
MaxIdleConns: 0,
|
|
||||||
MaxIdleConnsPerHost: 999,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user