1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 11:07:40 +08:00

207 lines
4.8 KiB
Go

package network
import (
goBinary "encoding/binary"
"fmt"
"io"
"net"
"strconv"
"sync"
"github.com/pkg/errors"
"github.com/Mrs4s/MiraiGo/binary"
"github.com/Mrs4s/MiraiGo/client/internal/auth"
)
// Transport is a network transport.
type Transport struct {
// sessionMu sync.Mutex
Sig *auth.SigInfo
Version *auth.AppVersion
Device *auth.Device
// connection
connMu sync.Mutex
servers []*net.TCPAddr
curServerAddr *net.TCPAddr
conn TCPListener
}
func (t *Transport) AddServerAddr(addr *net.TCPAddr) {
t.connMu.Lock()
defer t.connMu.Unlock()
t.servers = append(t.servers, addr)
}
func (t *Transport) GetServerAddr() *net.TCPAddr {
t.connMu.Lock()
defer t.connMu.Unlock()
return t.curServerAddr
}
func (t *Transport) ServerCount() int {
t.connMu.Lock()
defer t.connMu.Unlock()
return len(t.servers)
}
func (t *Transport) PlannedDisconnect(fun func(*TCPListener)) {
t.conn.PlannedDisconnect = fun
}
func (t *Transport) UnexpectedDisconnect(fun func(*TCPListener, error)) {
t.conn.UnexpectedDisconnect = fun
}
func (t *Transport) ConnectFastest() (chosen *net.TCPAddr, err error) {
t.connMu.Lock()
defer t.connMu.Unlock()
chosen, err = t.conn.ConnectFastest(t.servers)
t.curServerAddr = chosen
return
}
func (t *Transport) Close() {
t.conn.Close()
}
func (t *Transport) Write(data []byte) error {
return t.conn.Write(data)
}
func (t *Transport) packBody(req *Request, w *binary.Writer) {
pos := w.FillUInt32()
if req.Type == RequestTypeLogin {
w.WriteUInt32(uint32(req.SequenceID))
w.WriteUInt32(t.Version.AppId)
w.WriteUInt32(t.Version.SubAppId)
w.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00})
tgt := t.Sig.TGT
if len(tgt) == 0 || len(tgt) == 4 {
w.WriteUInt32(0x04)
} else {
w.WriteUInt32(uint32(len(tgt) + 4))
w.Write(tgt)
}
}
w.WriteString(req.CommandName)
w.WriteUInt32(uint32(len(t.Sig.OutPacketSessionID) + 4))
w.Write(t.Sig.OutPacketSessionID)
if req.Type == RequestTypeLogin {
w.WriteString(t.Device.IMEI)
w.WriteUInt32(0x04)
w.WriteUInt16(uint16(len(t.Sig.Ksid)) + 2)
w.Write(t.Sig.Ksid)
}
w.WriteUInt32(0x04)
w.WriteUInt32At(pos, uint32(w.Len()-pos))
w.WriteUInt32(uint32(len(req.Body) + 4))
w.Write(req.Body)
}
// PackPacket packs a packet.
func (t *Transport) PackPacket(req *Request) []byte {
// todo(wdvxdr): combine pack packet, send packet and return the response
if len(t.Sig.D2) == 0 {
req.EncryptType = EncryptTypeEmptyKey
}
w := binary.SelectWriter()
defer binary.PutWriter(w)
pos := w.FillUInt32()
// vvv w.Write(head) vvv
w.WriteUInt32(uint32(req.Type))
w.WriteByte(byte(req.EncryptType))
switch req.Type {
case RequestTypeLogin:
switch req.EncryptType {
case EncryptTypeD2Key:
w.WriteUInt32(uint32(len(t.Sig.D2) + 4))
w.Write(t.Sig.D2)
default:
w.WriteUInt32(4)
}
case RequestTypeSimple:
w.WriteUInt32(uint32(req.SequenceID))
}
w.WriteByte(0x00)
w.WriteString(strconv.FormatInt(req.Uin, 10))
// ^^^ w.Write(head) ^^^
w2 := binary.SelectWriter()
t.packBody(req, w2)
body := w2.Bytes()
// encrypt body
switch req.EncryptType {
case EncryptTypeD2Key:
body = binary.NewTeaCipher(t.Sig.D2Key).Encrypt(body)
case EncryptTypeEmptyKey:
body = binary.NewTeaCipher(emptyKey).Encrypt(body)
}
w.Write(body)
binary.PutWriter(w2)
w.WriteUInt32At(pos, uint32(w.Len()))
return append([]byte(nil), w.Bytes()...)
}
type PktHandler func(pkt *Request, netErr error)
type RequestHandler func(head []byte) (*Request, error)
func (t *Transport) NetLoop(pktHandler PktHandler, respHandler RequestHandler) {
go t.netLoop(pktHandler, respHandler)
}
// readPacket 帮助函数(Helper function)
func readPacket(conn *net.TCPConn, minSize, maxSize uint32) ([]byte, error) {
lBuf := make([]byte, 4)
_, err := io.ReadFull(conn, lBuf)
if err != nil {
return nil, err
}
l := goBinary.BigEndian.Uint32(lBuf)
if l < minSize || l > maxSize {
return nil, fmt.Errorf("parse incoming packet error: invalid packet length %v", l)
}
data := make([]byte, l-4)
_, err = io.ReadFull(conn, data)
return data, err
}
// netLoop 整个函数周期使用同一个连接,确保不会发生串线这种奇怪的事情
func (t *Transport) netLoop(pktHandler PktHandler, respHandler RequestHandler) {
conn := t.conn.getConn()
defer func() {
if r := recover(); r != nil {
pktHandler(nil, fmt.Errorf("panic: %v", r))
}
t.conn.Close()
}()
errCount := 0
for {
data, err := readPacket(conn, 4, 10<<20) // max 10MB
if err != nil {
// 在且仅在没有新连接建立时断线才被认为是意外的
if t.conn.getConn() == conn {
pktHandler(nil, errors.Wrap(ErrConnectionBroken, err.Error()))
}
return
}
req, err := respHandler(data)
if err == nil {
errCount = 0
goto ok
}
errCount++
if errCount > 2 {
err = errors.Wrap(ErrConnectionBroken, err.Error())
}
ok:
go pktHandler(req, err)
}
}