mirror of
https://github.com/Mrs4s/MiraiGo.git
synced 2025-05-04 11:07:40 +08:00
225 lines
5.0 KiB
Go
225 lines
5.0 KiB
Go
package highway
|
|
|
|
import (
|
|
"crypto/md5"
|
|
"io"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/pkg/errors"
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/Mrs4s/MiraiGo/binary"
|
|
"github.com/Mrs4s/MiraiGo/client/pb"
|
|
"github.com/Mrs4s/MiraiGo/internal/proto"
|
|
)
|
|
|
|
// Transaction describes one payload to be uploaded to the highway server
// via Session.Upload.
type Transaction struct {
	// CommandID is copied into the CommandId field of every segment header.
	CommandID int32
	// Body supplies the payload; it is consumed sequentially in chunks.
	Body io.Reader
	// Sum is the md5 sum of the whole body (sent as FileMd5).
	Sum []byte
	// Size is the total body size in bytes (sent as Filesize).
	Size int64
	// Ticket is the service ticket attached to each segment.
	Ticket []byte
	// Ext is extra request info attached to each segment as ReqExtendinfo.
	Ext []byte
	// Encrypt requests that Ext be encrypted with the session key before
	// the upload starts.
	Encrypt bool
}
|
|
|
|
func (bdh *Transaction) encrypt(key []byte) error {
|
|
if !bdh.Encrypt {
|
|
return nil
|
|
}
|
|
if len(key) == 0 {
|
|
return errors.New("session key not found. maybe miss some packet?")
|
|
}
|
|
bdh.Ext = binary.NewTeaCipher(key).Encrypt(bdh.Ext)
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) uploadSingle(trans Transaction) ([]byte, error) {
|
|
pc, err := s.selectConn()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer s.putIdleConn(pc)
|
|
|
|
reader := binary.NewNetworkReader(pc.conn)
|
|
const chunkSize = 128 * 1024
|
|
var rspExt []byte
|
|
offset := 0
|
|
chunk := make([]byte, chunkSize)
|
|
for {
|
|
chunk = chunk[:cap(chunk)]
|
|
rl, err := io.ReadFull(trans.Body, chunk)
|
|
if rl == 0 {
|
|
break
|
|
}
|
|
if errors.Is(err, io.ErrUnexpectedEOF) {
|
|
chunk = chunk[:rl]
|
|
}
|
|
ch := md5.Sum(chunk)
|
|
head, _ := proto.Marshal(&pb.ReqDataHighwayHead{
|
|
MsgBasehead: &pb.DataHighwayHead{
|
|
Version: 1,
|
|
Uin: s.Uin,
|
|
Command: _REQ_CMD_DATA,
|
|
Seq: s.nextSeq(),
|
|
Appid: s.AppID,
|
|
Dataflag: 4096,
|
|
CommandId: trans.CommandID,
|
|
LocaleId: 2052,
|
|
},
|
|
MsgSeghead: &pb.SegHead{
|
|
Filesize: trans.Size,
|
|
Dataoffset: int64(offset),
|
|
Datalength: int32(rl),
|
|
Serviceticket: trans.Ticket,
|
|
Md5: ch[:],
|
|
FileMd5: trans.Sum,
|
|
},
|
|
ReqExtendinfo: trans.Ext,
|
|
})
|
|
offset += rl
|
|
buffers := frame(head, chunk)
|
|
_, err = buffers.WriteTo(pc.conn)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "write conn error")
|
|
}
|
|
rspHead, err := readResponse(reader)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "highway upload error")
|
|
}
|
|
if rspHead.ErrorCode != 0 {
|
|
return nil, errors.Errorf("upload failed: %d", rspHead.ErrorCode)
|
|
}
|
|
if rspHead.RspExtendinfo != nil {
|
|
rspExt = rspHead.RspExtendinfo
|
|
}
|
|
if rspHead.MsgSeghead != nil && rspHead.MsgSeghead.Serviceticket != nil {
|
|
trans.Ticket = rspHead.MsgSeghead.Serviceticket
|
|
}
|
|
}
|
|
return rspExt, nil
|
|
}
|
|
|
|
func (s *Session) Upload(trans Transaction) ([]byte, error) {
|
|
// encrypt ext data
|
|
if err := trans.encrypt(s.SessionKey); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
const maxThreadCount = 4
|
|
threadCount := int(trans.Size) / (3 * 512 * 1024) // 1 thread upload 1.5 MB
|
|
if threadCount > maxThreadCount {
|
|
threadCount = maxThreadCount
|
|
}
|
|
if threadCount < 2 {
|
|
// single thread upload
|
|
return s.uploadSingle(trans)
|
|
}
|
|
|
|
// pick a address
|
|
// TODO: pick smarter
|
|
pc, err := s.selectConn()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
addr := pc.addr
|
|
s.putIdleConn(pc)
|
|
|
|
const blockSize int64 = 256 * 1024
|
|
var (
|
|
rspExt []byte
|
|
completedThread uint32
|
|
cond = sync.NewCond(&sync.Mutex{})
|
|
offset = int64(0)
|
|
count = (trans.Size + blockSize - 1) / blockSize
|
|
id = 0
|
|
)
|
|
doUpload := func() error {
|
|
// send signal complete uploading
|
|
defer func() {
|
|
atomic.AddUint32(&completedThread, 1)
|
|
cond.Signal()
|
|
}()
|
|
|
|
// todo: get from pool?
|
|
pc, err := s.connect(addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer s.putIdleConn(pc)
|
|
|
|
reader := binary.NewNetworkReader(pc.conn)
|
|
chunk := make([]byte, blockSize)
|
|
for {
|
|
cond.L.Lock() // lock protect reading
|
|
off := offset
|
|
offset += blockSize
|
|
id++
|
|
last := int64(id) == count
|
|
if last { // last
|
|
for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) {
|
|
cond.Wait()
|
|
}
|
|
} else if int64(id) > count {
|
|
cond.L.Unlock()
|
|
break
|
|
}
|
|
chunk = chunk[:blockSize]
|
|
n, err := io.ReadFull(trans.Body, chunk)
|
|
cond.L.Unlock()
|
|
|
|
if n == 0 {
|
|
break
|
|
}
|
|
if errors.Is(err, io.ErrUnexpectedEOF) {
|
|
chunk = chunk[:n]
|
|
}
|
|
ch := md5.Sum(chunk)
|
|
head, _ := proto.Marshal(&pb.ReqDataHighwayHead{
|
|
MsgBasehead: &pb.DataHighwayHead{
|
|
Version: 1,
|
|
Uin: s.Uin,
|
|
Command: _REQ_CMD_DATA,
|
|
Seq: s.nextSeq(),
|
|
Appid: s.AppID,
|
|
Dataflag: 4096,
|
|
CommandId: trans.CommandID,
|
|
LocaleId: 2052,
|
|
},
|
|
MsgSeghead: &pb.SegHead{
|
|
Filesize: trans.Size,
|
|
Dataoffset: off,
|
|
Datalength: int32(n),
|
|
Serviceticket: trans.Ticket,
|
|
Md5: ch[:],
|
|
FileMd5: trans.Sum,
|
|
},
|
|
ReqExtendinfo: trans.Ext,
|
|
})
|
|
buffers := frame(head, chunk)
|
|
_, err = buffers.WriteTo(pc.conn)
|
|
if err != nil {
|
|
return errors.Wrap(err, "write conn error")
|
|
}
|
|
rspHead, err := readResponse(reader)
|
|
if err != nil {
|
|
return errors.Wrap(err, "highway upload error")
|
|
}
|
|
if rspHead.ErrorCode != 0 {
|
|
return errors.Errorf("upload failed: %d", rspHead.ErrorCode)
|
|
}
|
|
if last && rspHead.RspExtendinfo != nil {
|
|
rspExt = rspHead.RspExtendinfo
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
group := errgroup.Group{}
|
|
for i := 0; i < threadCount; i++ {
|
|
group.Go(doUpload)
|
|
}
|
|
return rspExt, group.Wait()
|
|
}
|