mirror of
https://github.com/Mrs4s/MiraiGo.git
synced 2025-05-04 11:07:40 +08:00
207 lines
4.8 KiB
Go
207 lines
4.8 KiB
Go
package network
|
|
|
|
import (
|
|
goBinary "encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"github.com/Mrs4s/MiraiGo/binary"
|
|
"github.com/Mrs4s/MiraiGo/client/internal/auth"
|
|
)
|
|
|
|
// Transport is a network transport.
|
|
type Transport struct {
|
|
// sessionMu sync.Mutex
|
|
Sig *auth.SigInfo
|
|
Version *auth.AppVersion
|
|
Device *auth.Device
|
|
|
|
// connection
|
|
connMu sync.Mutex
|
|
servers []*net.TCPAddr
|
|
curServerAddr *net.TCPAddr
|
|
conn TCPListener
|
|
}
|
|
|
|
func (t *Transport) AddServerAddr(addr *net.TCPAddr) {
|
|
t.connMu.Lock()
|
|
defer t.connMu.Unlock()
|
|
t.servers = append(t.servers, addr)
|
|
}
|
|
|
|
func (t *Transport) GetServerAddr() *net.TCPAddr {
|
|
t.connMu.Lock()
|
|
defer t.connMu.Unlock()
|
|
return t.curServerAddr
|
|
}
|
|
|
|
func (t *Transport) ServerCount() int {
|
|
t.connMu.Lock()
|
|
defer t.connMu.Unlock()
|
|
return len(t.servers)
|
|
}
|
|
|
|
func (t *Transport) PlannedDisconnect(fun func(*TCPListener)) {
|
|
t.conn.PlannedDisconnect = fun
|
|
}
|
|
|
|
func (t *Transport) UnexpectedDisconnect(fun func(*TCPListener, error)) {
|
|
t.conn.UnexpectedDisconnect = fun
|
|
}
|
|
|
|
func (t *Transport) ConnectFastest() (chosen *net.TCPAddr, err error) {
|
|
t.connMu.Lock()
|
|
defer t.connMu.Unlock()
|
|
chosen, err = t.conn.ConnectFastest(t.servers)
|
|
t.curServerAddr = chosen
|
|
return
|
|
}
|
|
|
|
func (t *Transport) Close() {
|
|
t.conn.Close()
|
|
}
|
|
|
|
func (t *Transport) Write(data []byte) error {
|
|
return t.conn.Write(data)
|
|
}
|
|
|
|
func (t *Transport) packBody(req *Request, w *binary.Writer) {
|
|
pos := w.FillUInt32()
|
|
if req.Type == RequestTypeLogin {
|
|
w.WriteUInt32(uint32(req.SequenceID))
|
|
w.WriteUInt32(t.Version.AppId)
|
|
w.WriteUInt32(t.Version.SubAppId)
|
|
w.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00})
|
|
tgt := t.Sig.TGT
|
|
if len(tgt) == 0 || len(tgt) == 4 {
|
|
w.WriteUInt32(0x04)
|
|
} else {
|
|
w.WriteUInt32(uint32(len(tgt) + 4))
|
|
w.Write(tgt)
|
|
}
|
|
}
|
|
w.WriteString(req.CommandName)
|
|
w.WriteUInt32(uint32(len(t.Sig.OutPacketSessionID) + 4))
|
|
w.Write(t.Sig.OutPacketSessionID)
|
|
if req.Type == RequestTypeLogin {
|
|
w.WriteString(t.Device.IMEI)
|
|
w.WriteUInt32(0x04)
|
|
|
|
w.WriteUInt16(uint16(len(t.Sig.Ksid)) + 2)
|
|
w.Write(t.Sig.Ksid)
|
|
}
|
|
w.WriteUInt32(0x04)
|
|
w.WriteUInt32At(pos, uint32(w.Len()-pos))
|
|
|
|
w.WriteUInt32(uint32(len(req.Body) + 4))
|
|
w.Write(req.Body)
|
|
}
|
|
|
|
// PackPacket packs a packet.
|
|
func (t *Transport) PackPacket(req *Request) []byte {
|
|
// todo(wdvxdr): combine pack packet, send packet and return the response
|
|
if len(t.Sig.D2) == 0 {
|
|
req.EncryptType = EncryptTypeEmptyKey
|
|
}
|
|
|
|
w := binary.SelectWriter()
|
|
defer binary.PutWriter(w)
|
|
|
|
pos := w.FillUInt32()
|
|
// vvv w.Write(head) vvv
|
|
w.WriteUInt32(uint32(req.Type))
|
|
w.WriteByte(byte(req.EncryptType))
|
|
switch req.Type {
|
|
case RequestTypeLogin:
|
|
switch req.EncryptType {
|
|
case EncryptTypeD2Key:
|
|
w.WriteUInt32(uint32(len(t.Sig.D2) + 4))
|
|
w.Write(t.Sig.D2)
|
|
default:
|
|
w.WriteUInt32(4)
|
|
}
|
|
case RequestTypeSimple:
|
|
w.WriteUInt32(uint32(req.SequenceID))
|
|
}
|
|
w.WriteByte(0x00)
|
|
w.WriteString(strconv.FormatInt(req.Uin, 10))
|
|
// ^^^ w.Write(head) ^^^
|
|
|
|
w2 := binary.SelectWriter()
|
|
t.packBody(req, w2)
|
|
body := w2.Bytes()
|
|
// encrypt body
|
|
switch req.EncryptType {
|
|
case EncryptTypeD2Key:
|
|
body = binary.NewTeaCipher(t.Sig.D2Key).Encrypt(body)
|
|
case EncryptTypeEmptyKey:
|
|
body = binary.NewTeaCipher(emptyKey).Encrypt(body)
|
|
}
|
|
w.Write(body)
|
|
binary.PutWriter(w2)
|
|
|
|
w.WriteUInt32At(pos, uint32(w.Len()))
|
|
return append([]byte(nil), w.Bytes()...)
|
|
}
|
|
|
|
type PktHandler func(pkt *Request, netErr error)
|
|
type RequestHandler func(head []byte) (*Request, error)
|
|
|
|
func (t *Transport) NetLoop(pktHandler PktHandler, respHandler RequestHandler) {
|
|
go t.netLoop(pktHandler, respHandler)
|
|
}
|
|
|
|
// readPacket is a helper that reads one length-prefixed packet from conn.
//
// A packet starts with a 4-byte big-endian length that counts the prefix
// itself; the returned slice holds the remaining length-4 bytes.
//
// The length must lie within [minSize, maxSize] and must be at least the size
// of the prefix. Anything else is rejected with an error before the payload
// buffer is allocated. The explicit lower bound matters when minSize < 4:
// without it the unsigned subtraction length-4 would wrap around and request
// a buffer of almost 4 GiB.
//
// On a read failure the error from io.ReadFull is returned unchanged; the
// payload slice returned alongside it may be only partially filled.
func readPacket(conn *net.TCPConn, minSize, maxSize uint32) ([]byte, error) {
	const prefixLen = 4

	var prefix [prefixLen]byte
	if _, err := io.ReadFull(conn, prefix[:]); err != nil {
		return nil, err
	}

	l := goBinary.BigEndian.Uint32(prefix[:])
	if l < prefixLen || l < minSize || l > maxSize {
		return nil, fmt.Errorf("parse incoming packet error: invalid packet length %v", l)
	}

	payload := make([]byte, l-prefixLen)
	_, err := io.ReadFull(conn, payload)
	return payload, err
}
|
|
|
|
// netLoop 整个函数周期使用同一个连接,确保不会发生串线这种奇怪的事情
|
|
func (t *Transport) netLoop(pktHandler PktHandler, respHandler RequestHandler) {
|
|
conn := t.conn.getConn()
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
pktHandler(nil, fmt.Errorf("panic: %v", r))
|
|
}
|
|
t.conn.Close()
|
|
}()
|
|
errCount := 0
|
|
for {
|
|
data, err := readPacket(conn, 4, 10<<20) // max 10MB
|
|
if err != nil {
|
|
// 在且仅在没有新连接建立时断线才被认为是意外的
|
|
if t.conn.getConn() == conn {
|
|
pktHandler(nil, errors.Wrap(ErrConnectionBroken, err.Error()))
|
|
}
|
|
return
|
|
}
|
|
req, err := respHandler(data)
|
|
if err == nil {
|
|
errCount = 0
|
|
goto ok
|
|
}
|
|
errCount++
|
|
if errCount > 2 {
|
|
err = errors.Wrap(ErrConnectionBroken, err.Error())
|
|
}
|
|
ok:
|
|
go pktHandler(req, err)
|
|
}
|
|
}
|