1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-05 03:23:50 +08:00
2022-02-27 16:46:11 +08:00

217 lines
5.1 KiB
Go

package highway
import (
"crypto/md5"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
"github.com/Mrs4s/MiraiGo/binary"
"github.com/Mrs4s/MiraiGo/client/internal/network"
"github.com/Mrs4s/MiraiGo/client/pb"
"github.com/Mrs4s/MiraiGo/internal/proto"
)
// BdhInput bundles everything UploadBDH and UploadBDHMultiThread need to
// send one payload through the highway.
type BdhInput struct {
	// CommandID is handed to dataHighwayHead when the base head of each
	// segment is built.
	CommandID int32
	// Body is the stream the payload is read from, block by block.
	Body io.Reader
	// Sum is the MD5 sum of the whole body; it is sent as FileMd5 with
	// every segment.
	Sum []byte
	// Size is the body size in bytes; it is sent as Filesize with every
	// segment.
	Size int64
	// Ticket is sent as the Serviceticket of every segment.
	Ticket []byte
	// Ext is sent as ReqExtendinfo with every segment. When Encrypt is set
	// it is TEA-encrypted with the session key before the upload starts.
	Ext []byte
	// Encrypt requests encryption of Ext before uploading.
	Encrypt bool
}
// encrypt replaces bdh.Ext with its TEA-encrypted form, using key, when
// bdh.Encrypt is set. It is a no-op when encryption was not requested.
//
// An error is returned when encryption was requested but key is empty.
func (bdh *BdhInput) encrypt(key []byte) error {
	if !bdh.Encrypt {
		return nil
	}
	if len(key) == 0 {
		return errors.New("session key not found. maybe miss some packet?")
	}
	cipher := binary.NewTeaCipher(key)
	bdh.Ext = cipher.Encrypt(bdh.Ext)
	return nil
}
// UploadBDH uploads input.Body to the first server in s.SsoAddr over a single
// TCP connection. The body is read in chunks of up to 256 KiB; each chunk is
// sent as one head/body frame and the server's response is read before the
// next chunk is sent.
//
// If input.Encrypt is set, input.Ext is encrypted with s.SessionKey first.
// When a response carries a service ticket, that ticket is used for the
// following chunks.
//
// It returns the last non-nil RspExtendinfo received. An error is returned
// when no server address or (for encrypted uploads) no session key is known,
// when connecting, reading the body, writing, or reading a response fails, or
// when the server answers with a non-zero error code.
func (s *Session) UploadBDH(input BdhInput) ([]byte, error) {
	if len(s.SsoAddr) == 0 {
		return nil, errors.New("srv addrs not found. maybe miss some packet?")
	}
	addr := s.SsoAddr[0].String()

	if err := input.encrypt(s.SessionKey); err != nil {
		return nil, err
	}

	conn, err := net.DialTimeout("tcp", addr, time.Second*20)
	if err != nil {
		return nil, errors.Wrap(err, "connect error")
	}
	defer conn.Close()

	reader := binary.NewNetworkReader(conn)
	if err = s.sendEcho(conn); err != nil {
		return nil, err
	}

	const chunkSize = 256 * 1024
	var (
		rspExt []byte
		offset int64 // int64 so large bodies do not overflow on 32-bit platforms
	)
	chunk := make([]byte, chunkSize)
	for {
		// Restore the full length: a previous short read may have shrunk it.
		chunk = chunk[:chunkSize]
		rl, err := io.ReadFull(input.Body, chunk)
		if errors.Is(err, io.EOF) {
			// Nothing left to read.
			break
		}
		if errors.Is(err, io.ErrUnexpectedEOF) {
			// Short final chunk: send only the bytes actually read.
			chunk = chunk[:rl]
		} else if err != nil {
			// A real read failure must not be sent as if it were data.
			return nil, errors.Wrap(err, "read body error")
		}

		ch := md5.Sum(chunk)
		head, err := proto.Marshal(&pb.ReqDataHighwayHead{
			MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052),
			MsgSeghead: &pb.SegHead{
				Filesize:      input.Size,
				Dataoffset:    offset,
				Datalength:    int32(rl),
				Serviceticket: input.Ticket,
				Md5:           ch[:],
				FileMd5:       input.Sum,
			},
			ReqExtendinfo: input.Ext,
		})
		if err != nil {
			return nil, errors.Wrap(err, "marshal head error")
		}
		offset += int64(rl)

		frame := network.HeadBodyFrame(head, chunk)
		if _, err = frame.WriteTo(conn); err != nil {
			return nil, errors.Wrap(err, "write conn error")
		}

		rspHead, _, err := readResponse(reader)
		if err != nil {
			return nil, errors.Wrap(err, "highway upload error")
		}
		if rspHead.ErrorCode != 0 {
			return nil, errors.Errorf("upload failed: %d", rspHead.ErrorCode)
		}
		if rspHead.RspExtendinfo != nil {
			rspExt = rspHead.RspExtendinfo
		}
		if rspHead.MsgSeghead != nil && rspHead.MsgSeghead.Serviceticket != nil {
			// Use the ticket handed back by the server for later chunks.
			input.Ticket = rspHead.MsgSeghead.Serviceticket
		}
	}
	return rspExt, nil
}
// UploadBDHMultiThread uploads input.Body to the first server in s.SsoAddr
// using threadCount concurrent TCP connections. The body is read sequentially
// in 512 KiB blocks; each worker claims the next block, sends it, and reads
// the server's response before claiming another one. The final block is only
// read and sent once every other worker has finished.
//
// Bodies smaller than 3 MiB, or a threadCount below 2, are delegated to
// UploadBDH.
//
// If input.Encrypt is set, input.Ext is encrypted with s.SessionKey first.
//
// It returns the most recently received non-nil RspExtendinfo. If any worker
// fails, the first error is returned and the result is nil.
func (s *Session) UploadBDHMultiThread(input BdhInput, threadCount int) ([]byte, error) {
	// for small file and small thread count,
	// use UploadBDH instead of UploadBDHMultiThread
	if input.Size < 1024*1024*3 || threadCount < 2 {
		return s.UploadBDH(input)
	}

	if len(s.SsoAddr) == 0 {
		return nil, errors.New("srv addrs not found. maybe miss some packet?")
	}
	addr := s.SsoAddr[0].String()

	if err := input.encrypt(s.SessionKey); err != nil {
		return nil, err
	}

	const blockSize int64 = 1024 * 512
	var (
		rspExt          []byte
		rspMu           sync.Mutex // guards rspExt while workers are running
		completedThread uint32
		cond            = sync.NewCond(&sync.Mutex{})
		offset          = int64(0)
		count           = (input.Size + blockSize - 1) / blockSize
		id              = 0
	)

	// nextBlock claims the next block of the body and reads it into buf.
	// cond.L serialises the claim and the read so that offsets match the
	// order in which bytes leave input.Body. done reports that all blocks
	// have already been claimed. The worker that claims the final block
	// waits here until every other worker has finished.
	nextBlock := func(buf []byte) (off int64, n int, done bool, err error) {
		cond.L.Lock()
		defer cond.L.Unlock() // released even if input.Body.Read panics

		off = offset
		offset += blockSize
		id++
		if int64(id) > count {
			return off, 0, true, nil
		}
		if int64(id) == count { // last
			for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) {
				cond.Wait()
			}
		}
		n, err = io.ReadFull(input.Body, buf)
		return off, n, false, err
	}

	doUpload := func() error {
		// send signal complete uploading
		defer func() {
			// Update the counter under cond.L: the waiter checks it under the
			// same lock, so the update cannot slip in between its check and
			// its call to Wait (which would lose the wake-up).
			cond.L.Lock()
			atomic.AddUint32(&completedThread, 1)
			cond.L.Unlock()
			cond.Signal()
		}()

		conn, err := net.DialTimeout("tcp", addr, time.Second*20)
		if err != nil {
			return errors.Wrap(err, "connect error")
		}
		defer conn.Close()

		reader := binary.NewNetworkReader(conn)
		if err = s.sendEcho(conn); err != nil {
			return err
		}

		chunk := make([]byte, blockSize)
		for {
			off, n, done, err := nextBlock(chunk)
			if done {
				return nil
			}
			switch {
			case err == nil, errors.Is(err, io.ErrUnexpectedEOF):
				// Full block, or a short final block.
			case errors.Is(err, io.EOF):
				return nil
			default:
				return err
			}
			data := chunk[:n]

			ch := md5.Sum(data)
			head, err := proto.Marshal(&pb.ReqDataHighwayHead{
				MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052),
				MsgSeghead: &pb.SegHead{
					Filesize:      input.Size,
					Dataoffset:    off,
					Datalength:    int32(n),
					Serviceticket: input.Ticket,
					Md5:           ch[:],
					FileMd5:       input.Sum,
				},
				ReqExtendinfo: input.Ext,
			})
			if err != nil {
				return errors.Wrap(err, "marshal head error")
			}

			frame := network.HeadBodyFrame(head, data)
			if _, err = frame.WriteTo(conn); err != nil {
				return errors.Wrap(err, "write conn error")
			}

			rspHead, _, err := readResponse(reader)
			if err != nil {
				return errors.Wrap(err, "highway upload error")
			}
			if rspHead.ErrorCode != 0 {
				return errors.Errorf("upload failed: %d", rspHead.ErrorCode)
			}
			if rspHead.RspExtendinfo != nil {
				rspMu.Lock()
				rspExt = rspHead.RspExtendinfo
				rspMu.Unlock()
			}
		}
	}

	var group errgroup.Group
	for i := 0; i < threadCount; i++ {
		group.Go(doUpload)
	}
	// Wait before reading rspExt. In `return rspExt, group.Wait()` the order
	// of the variable read relative to the call is unspecified, so the result
	// could be read before the workers have stored it.
	if err := group.Wait(); err != nil {
		return nil, err
	}
	return rspExt, nil
}