1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 11:07:40 +08:00

225 lines
5.0 KiB
Go

package highway
import (
"crypto/md5"
"io"
"sync"
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
"github.com/Mrs4s/MiraiGo/binary"
"github.com/Mrs4s/MiraiGo/client/pb"
"github.com/Mrs4s/MiraiGo/internal/proto"
)
// Transaction describes one highway upload: the payload to send plus the
// metadata that is copied into the head of every data request.
type Transaction struct {
// CommandID is sent as CommandId in the base head of every request.
CommandID int32
// Body is the payload; it is consumed sequentially in fixed-size chunks.
Body io.Reader
// Sum is the MD5 sum of the whole body, sent as FileMd5 with every chunk.
Sum []byte
// Size is the body size in bytes, sent as Filesize; Upload also uses it to
// decide how many connections to upload with.
Size int64
// Ticket is sent as Serviceticket in the segment head of every chunk.
Ticket []byte
// Ext is sent as ReqExtendinfo with every chunk. When Encrypt is set it is
// replaced by its encrypted form before uploading.
Ext []byte
// Encrypt reports whether Ext must be encrypted with the session key
// before it is sent.
Encrypt bool
}
// encrypt replaces bdh.Ext with the output of a TEA cipher keyed by key.
//
// It is a no-op when bdh.Encrypt is false. When encryption is requested but
// key is empty, an error is returned and bdh.Ext is left untouched.
func (bdh *Transaction) encrypt(key []byte) error {
	switch {
	case !bdh.Encrypt:
		// Nothing to do: the extension data is sent as-is.
		return nil
	case len(key) == 0:
		return errors.New("session key not found. maybe miss some packet?")
	}

	cipher := binary.NewTeaCipher(key)
	bdh.Ext = cipher.Encrypt(bdh.Ext)
	return nil
}
// uploadSingle uploads trans.Body over a single connection obtained from
// s.selectConn. The body is sent sequentially in chunks of up to 128 KiB and
// the response to each chunk is read before the next chunk is sent.
//
// It returns the last non-nil RspExtendinfo received. When a response carries
// a service ticket, that ticket is used for the following chunks (trans is
// passed by value, so the caller's Transaction is not modified).
//
// An error is returned when no connection can be selected, when reading the
// body fails for a reason other than reaching its end, when the request head
// cannot be marshaled, when writing to or reading from the connection fails,
// or when the server answers with a non-zero error code.
func (s *Session) uploadSingle(trans Transaction) ([]byte, error) {
	pc, err := s.selectConn()
	if err != nil {
		return nil, err
	}
	defer s.putIdleConn(pc)

	reader := binary.NewNetworkReader(pc.conn)
	const chunkSize = 128 * 1024
	var (
		rspExt []byte
		offset int64 // int64 so large bodies do not overflow on 32-bit platforms
	)
	chunk := make([]byte, chunkSize)
	for {
		// Restore the full buffer; a previous short read may have shrunk it.
		chunk = chunk[:cap(chunk)]
		rl, err := io.ReadFull(trans.Body, chunk)
		// io.EOF (nothing read) and io.ErrUnexpectedEOF (short final chunk)
		// only mark the end of the body. Anything else is a real read
		// failure and must not be reported as a successful upload.
		if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
			return nil, errors.Wrap(err, "read body error")
		}
		if rl == 0 {
			break
		}
		chunk = chunk[:rl]

		ch := md5.Sum(chunk)
		head, err := proto.Marshal(&pb.ReqDataHighwayHead{
			MsgBasehead: &pb.DataHighwayHead{
				Version:   1,
				Uin:       s.Uin,
				Command:   _REQ_CMD_DATA,
				Seq:       s.nextSeq(),
				Appid:     s.AppID,
				Dataflag:  4096,
				CommandId: trans.CommandID,
				LocaleId:  2052,
			},
			MsgSeghead: &pb.SegHead{
				Filesize:      trans.Size,
				Dataoffset:    offset,
				Datalength:    int32(rl),
				Serviceticket: trans.Ticket,
				Md5:           ch[:],
				FileMd5:       trans.Sum,
			},
			ReqExtendinfo: trans.Ext,
		})
		if err != nil {
			return nil, errors.Wrap(err, "marshal head error")
		}
		offset += int64(rl)

		buffers := frame(head, chunk)
		if _, err = buffers.WriteTo(pc.conn); err != nil {
			return nil, errors.Wrap(err, "write conn error")
		}
		rspHead, err := readResponse(reader)
		if err != nil {
			return nil, errors.Wrap(err, "highway upload error")
		}
		if rspHead.ErrorCode != 0 {
			return nil, errors.Errorf("upload failed: %d", rspHead.ErrorCode)
		}
		if rspHead.RspExtendinfo != nil {
			rspExt = rspHead.RspExtendinfo
		}
		// Use the ticket handed back by the server for the remaining chunks.
		if rspHead.MsgSeghead != nil && rspHead.MsgSeghead.Serviceticket != nil {
			trans.Ticket = rspHead.MsgSeghead.Serviceticket
		}
	}
	return rspExt, nil
}
// Upload sends trans to the highway server and returns the extension data of
// the response.
//
// trans.Ext is first encrypted with s.SessionKey when trans.Encrypt is set.
// One worker is used per 1.5 MiB of trans.Size, capped at four. With fewer
// than two workers the upload is delegated to uploadSingle.
//
// Otherwise every worker opens its own connection to the same address and the
// workers take turns reading 256 KiB blocks from trans.Body. The worker that
// picks up the final block waits until all other workers have finished before
// sending it; the RspExtendinfo of that final response is the returned value.
//
// The returned error is the first error reported by any worker: connection
// failure, body read failure, head marshal failure, write/read failure on the
// connection, or a non-zero server error code.
func (s *Session) Upload(trans Transaction) ([]byte, error) {
	// encrypt ext data
	if err := trans.encrypt(s.SessionKey); err != nil {
		return nil, err
	}

	const (
		maxThreadCount = 4
		bytesPerThread = 3 * 512 * 1024 // 1 thread uploads 1.5 MiB
	)
	// Divide and clamp in int64 before converting, so a large Size cannot be
	// truncated on platforms where int is 32 bits wide.
	threads := trans.Size / bytesPerThread
	if threads > maxThreadCount {
		threads = maxThreadCount
	}
	threadCount := int(threads)
	if threadCount < 2 {
		// single thread upload
		return s.uploadSingle(trans)
	}

	// pick an address
	// TODO: pick smarter
	pc, err := s.selectConn()
	if err != nil {
		return nil, err
	}
	addr := pc.addr
	s.putIdleConn(pc)

	const blockSize int64 = 256 * 1024
	var (
		rspExt          []byte
		completedThread uint32
		cond            = sync.NewCond(&sync.Mutex{})
		offset          = int64(0)
		count           = (trans.Size + blockSize - 1) / blockSize
		id              = 0
	)
	doUpload := func() error {
		// Signal that this worker is done. The counter is updated while
		// holding cond.L so the update cannot slip in between the waiter's
		// predicate check and its call to cond.Wait, which would lose the
		// wakeup and block the last-block worker forever.
		defer func() {
			cond.L.Lock()
			atomic.AddUint32(&completedThread, 1)
			cond.L.Unlock()
			cond.Signal()
		}()

		// todo: get from pool?
		pc, err := s.connect(addr)
		if err != nil {
			return err
		}
		defer s.putIdleConn(pc)

		reader := binary.NewNetworkReader(pc.conn)
		chunk := make([]byte, blockSize)
		for {
			cond.L.Lock() // the lock serializes reads from trans.Body
			off := offset
			offset += blockSize
			id++
			last := int64(id) == count
			if last {
				// Hold back the final block until every other worker is done.
				for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) {
					cond.Wait()
				}
			} else if int64(id) > count {
				cond.L.Unlock()
				break
			}
			chunk = chunk[:blockSize]
			n, err := io.ReadFull(trans.Body, chunk)
			cond.L.Unlock()
			// io.EOF and io.ErrUnexpectedEOF only mark the end of the body;
			// any other error is a real read failure and must be reported.
			if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
				return errors.Wrap(err, "read body error")
			}
			if n == 0 {
				break
			}
			chunk = chunk[:n]

			ch := md5.Sum(chunk)
			head, err := proto.Marshal(&pb.ReqDataHighwayHead{
				MsgBasehead: &pb.DataHighwayHead{
					Version:   1,
					Uin:       s.Uin,
					Command:   _REQ_CMD_DATA,
					Seq:       s.nextSeq(),
					Appid:     s.AppID,
					Dataflag:  4096,
					CommandId: trans.CommandID,
					LocaleId:  2052,
				},
				MsgSeghead: &pb.SegHead{
					Filesize:      trans.Size,
					Dataoffset:    off,
					Datalength:    int32(n),
					Serviceticket: trans.Ticket,
					Md5:           ch[:],
					FileMd5:       trans.Sum,
				},
				ReqExtendinfo: trans.Ext,
			})
			if err != nil {
				return errors.Wrap(err, "marshal head error")
			}

			buffers := frame(head, chunk)
			if _, err = buffers.WriteTo(pc.conn); err != nil {
				return errors.Wrap(err, "write conn error")
			}
			rspHead, err := readResponse(reader)
			if err != nil {
				return errors.Wrap(err, "highway upload error")
			}
			if rspHead.ErrorCode != 0 {
				return errors.Errorf("upload failed: %d", rspHead.ErrorCode)
			}
			if last && rspHead.RspExtendinfo != nil {
				rspExt = rspHead.RspExtendinfo
			}
		}
		return nil
	}

	var group errgroup.Group
	for i := 0; i < threadCount; i++ {
		group.Go(doUpload)
	}
	// Wait first, then read rspExt: within a single return statement the
	// order between reading a variable and calling a function is not
	// specified, so rspExt could otherwise be read before the workers finish.
	err = group.Wait()
	return rspExt, err
}