1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 11:07:40 +08:00

264 lines
6.2 KiB
Go

package highway
import (
"crypto/md5"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
"github.com/Mrs4s/MiraiGo/binary"
"github.com/Mrs4s/MiraiGo/client/internal/network"
"github.com/Mrs4s/MiraiGo/client/pb"
"github.com/Mrs4s/MiraiGo/internal/proto"
"github.com/Mrs4s/MiraiGo/utils"
)
// BdhInput describes a single-connection highway (BDH) upload performed by
// Session.UploadBDH.
type BdhInput struct {
// CommandID is passed to dataHighwayHead as the command id of every segment.
CommandID int32
// Body is the data to upload. UploadBDH reads it once to compute MD5/length,
// seeks back to the start and then reads it again in chunks, so it must be
// seekable.
Body io.ReadSeeker
// Ticket is sent as SegHead.Serviceticket; UploadBDH replaces it with the
// ticket returned by the server, if any, for the following segments.
Ticket []byte
// Ext is sent as ReqExtendinfo; it is TEA-encrypted with the session key
// before upload when Encrypt is true.
Ext []byte
// Encrypt requests encryption of Ext (see encrypt).
Encrypt bool
}
// BdhMultiThreadInput describes a highway (BDH) upload performed by
// Session.UploadBDHMultiThread. While the multi-threaded path is disabled the
// upload is delegated to UploadBDH, which ignores Sum and Size and recomputes
// them from Body.
type BdhMultiThreadInput struct {
// CommandID is passed to dataHighwayHead as the command id of every segment.
CommandID int32
// Body is the data to upload; blocks are read from it at explicit offsets,
// so it must be seekable.
Body io.ReadSeeker
// Sum is sent as SegHead.FileMd5 (presumably the MD5 of the whole body —
// confirm with callers).
Sum []byte
// Size is sent as SegHead.Filesize and determines how Body is split into blocks.
Size int64
// Ticket is sent as SegHead.Serviceticket.
Ticket []byte
// Ext is sent as ReqExtendinfo; it is TEA-encrypted with the session key
// before upload when Encrypt is true.
Ext []byte
// Encrypt requests encryption of Ext (see encrypt).
Encrypt bool
}
// encrypt replaces bdh.Ext with its TEA encryption under key when bdh.Encrypt
// is set. It is a no-op when encryption is not requested and fails when it is
// requested but key is empty.
func (bdh *BdhInput) encrypt(key []byte) error {
	if !bdh.Encrypt {
		return nil
	}
	if len(key) == 0 {
		return errors.New("session key not found. maybe miss some packet?")
	}
	tea := binary.NewTeaCipher(key)
	bdh.Ext = tea.Encrypt(bdh.Ext)
	return nil
}
// encrypt replaces bdh.Ext with its TEA encryption under key when bdh.Encrypt
// is set. It is a no-op when encryption is not requested and fails when it is
// requested but key is empty.
func (bdh *BdhMultiThreadInput) encrypt(key []byte) error {
	if !bdh.Encrypt {
		return nil
	}
	if len(key) == 0 {
		return errors.New("session key not found. maybe miss some packet?")
	}
	tea := binary.NewTeaCipher(key)
	bdh.Ext = tea.Encrypt(bdh.Ext)
	return nil
}
// UploadBDH uploads input.Body to the first server in s.SsoAddr over a single
// TCP connection, in segments of up to 256 KiB.
//
// The body is read once to compute its MD5 and length, rewound, and then
// streamed. Each segment's response head is checked for a non-zero error
// code; a service ticket returned by the server replaces input.Ticket for the
// following segments. It returns the last non-nil RspExtendinfo received.
// An empty body sends no segments and returns (nil, nil).
func (s *Session) UploadBDH(input BdhInput) ([]byte, error) {
	if len(s.SsoAddr) == 0 {
		return nil, errors.New("srv addrs not found. maybe miss some packet?")
	}
	addr := s.SsoAddr[0].String()
	sum, length := utils.ComputeMd5AndLength(input.Body)
	// ComputeMd5AndLength consumed the body; it must be rewound before upload.
	if _, err := input.Body.Seek(0, io.SeekStart); err != nil {
		return nil, errors.Wrap(err, "seek body error")
	}
	if err := input.encrypt(s.SessionKey); err != nil {
		return nil, err
	}
	conn, err := net.DialTimeout("tcp", addr, time.Second*20)
	if err != nil {
		return nil, errors.Wrap(err, "connect error")
	}
	defer conn.Close()
	reader := binary.NewNetworkReader(conn)
	if err = s.sendEcho(conn); err != nil {
		return nil, err
	}

	const chunkSize = 256 * 1024
	var (
		rspExt []byte
		offset int64
	)
	buf := make([]byte, chunkSize)
	for {
		n, err := io.ReadFull(input.Body, buf)
		if errors.Is(err, io.EOF) {
			break // nothing left to send
		}
		// ErrUnexpectedEOF just means a short final chunk; anything else is a
		// real read failure and must not be uploaded as if it were data.
		if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
			return nil, errors.Wrap(err, "read body error")
		}
		chunk := buf[:n]
		chunkSum := md5.Sum(chunk)
		head, err := proto.Marshal(&pb.ReqDataHighwayHead{
			MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052),
			MsgSeghead: &pb.SegHead{
				Filesize:      length,
				Dataoffset:    offset,
				Datalength:    int32(n),
				Serviceticket: input.Ticket,
				Md5:           chunkSum[:],
				FileMd5:       sum,
			},
			ReqExtendinfo: input.Ext,
		})
		if err != nil {
			return nil, errors.Wrap(err, "marshal highway head error")
		}
		offset += int64(n)
		frame := network.HeadBodyFrame(head, chunk)
		if _, err = frame.WriteTo(conn); err != nil {
			return nil, errors.Wrap(err, "write conn error")
		}
		rspHead, _, err := readResponse(reader)
		if err != nil {
			return nil, errors.Wrap(err, "highway upload error")
		}
		if rspHead.ErrorCode != 0 {
			return nil, errors.Errorf("upload failed: %d", rspHead.ErrorCode)
		}
		if rspHead.RspExtendinfo != nil {
			rspExt = rspHead.RspExtendinfo
		}
		// The server may hand out a fresh ticket to use for later segments.
		if rspHead.MsgSeghead != nil && rspHead.MsgSeghead.Serviceticket != nil {
			input.Ticket = rspHead.MsgSeghead.Serviceticket
		}
	}
	return rspExt, nil
}
// UploadBDHMultiThread uploads input.Body to the first server in s.SsoAddr
// using up to threadCount parallel connections.
//
// NOTE: the parallel path is currently disabled (see FIXME below) and every
// call is delegated to UploadBDH. The parallel implementation is kept so it can
// be re-enabled once the server-side error is understood.
//
// In the parallel path the body is split into 512 KiB blocks. Each worker
// opens its own connection and claims blocks in ascending order. The final
// block is sent only after every other block has been acknowledged
// (presumably the server finalises the file on the last segment — confirm).
// If any worker fails, the others stop claiming blocks and the first error is
// returned with a nil extend-info.
func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount int) ([]byte, error) {
	// for small file and small thread count,
	// use UploadBDH instead of UploadBDHMultiThread
	// FIXME: enable multi-thread, now receive error code 81
	if true || input.Size < 1024*1024*3 || threadCount < 2 {
		return s.UploadBDH(BdhInput{
			CommandID: input.CommandID,
			Body:      input.Body,
			Ticket:    input.Ticket,
			Ext:       input.Ext,
			Encrypt:   input.Encrypt,
		})
	}
	if len(s.SsoAddr) == 0 {
		return nil, errors.New("srv addrs not found. maybe miss some packet?")
	}
	addr := s.SsoAddr[0].String()
	if err := input.encrypt(s.SessionKey); err != nil {
		return nil, err
	}

	const blockSize int64 = 1024 * 512
	// ceil(Size / blockSize) blocks; an empty body still yields one block.
	blockCount := int((input.Size + blockSize - 1) / blockSize)
	if blockCount < 1 {
		blockCount = 1
	}
	if threadCount > blockCount {
		threadCount = blockCount // extra workers would have nothing to do
	}
	lastID := uint32(blockCount - 1)

	var (
		mu       sync.Mutex // guards input.Body, rspExt, uploaded and failed
		cond     = sync.NewCond(&mu)
		nextID   = ^uint32(0) // the first atomic increment yields block 0
		uploaded uint32       // number of acknowledged blocks
		failed   bool         // set once any worker has returned an error
		rspExt   []byte
	)

	// markFailed records a worker failure and wakes the worker that may be
	// waiting to send the final block, so it cannot wait forever.
	markFailed := func() {
		mu.Lock()
		failed = true
		mu.Unlock()
		cond.Broadcast()
	}

	// readBlock fills buf with the bytes of Body starting at offset. Body is
	// shared by all workers, so the seek+read pair is done under mu.
	readBlock := func(buf []byte, offset int64) error {
		mu.Lock()
		defer mu.Unlock()
		if _, err := input.Body.Seek(offset, io.SeekStart); err != nil {
			return errors.Wrap(err, "seek body error")
		}
		// ReadFull: a plain Read may legally return fewer bytes than asked.
		if _, err := io.ReadFull(input.Body, buf); err != nil {
			return errors.Wrap(err, "read body error")
		}
		return nil
	}

	// uploadBlocks runs one worker: it claims blocks until none are left,
	// another worker fails, or an error occurs on its own connection.
	uploadBlocks := func() error {
		conn, err := net.DialTimeout("tcp", addr, time.Second*20)
		if err != nil {
			return errors.Wrap(err, "connect error")
		}
		defer conn.Close()
		reader := binary.NewNetworkReader(conn)
		if err = s.sendEcho(conn); err != nil {
			return err
		}
		buf := make([]byte, blockSize)
		for {
			id := atomic.AddUint32(&nextID, 1)
			if id > lastID {
				return nil // every block has been claimed
			}

			mu.Lock()
			if id == lastID {
				// Hold the final block back until all others are acknowledged.
				for uploaded != lastID && !failed {
					cond.Wait()
				}
			}
			aborted := failed
			mu.Unlock()
			if aborted {
				return nil // the failing worker reports the actual error
			}

			offset := int64(id) * blockSize
			want := input.Size - offset
			if want > blockSize {
				want = blockSize
			}
			if want < 0 {
				want = 0
			}
			chunk := buf[:want]
			if err := readBlock(chunk, offset); err != nil {
				return err
			}

			chunkSum := md5.Sum(chunk)
			head, err := proto.Marshal(&pb.ReqDataHighwayHead{
				MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052),
				MsgSeghead: &pb.SegHead{
					Filesize:      input.Size,
					Dataoffset:    offset,
					Datalength:    int32(len(chunk)),
					Serviceticket: input.Ticket,
					Md5:           chunkSum[:],
					FileMd5:       input.Sum,
				},
				ReqExtendinfo: input.Ext,
			})
			if err != nil {
				return errors.Wrap(err, "marshal highway head error")
			}
			frame := network.HeadBodyFrame(head, chunk)
			if _, err = frame.WriteTo(conn); err != nil {
				return errors.Wrap(err, "write conn error")
			}
			rspHead, _, err := readResponse(reader)
			if err != nil {
				return errors.Wrap(err, "highway upload error")
			}
			if rspHead.ErrorCode != 0 {
				return errors.Errorf("upload failed: %d", rspHead.ErrorCode)
			}

			// Shared state changes happen under mu, then the waiter is woken;
			// Broadcast after Unlock is safe because the waiter re-checks
			// uploaded under the lock.
			mu.Lock()
			if rspHead.RspExtendinfo != nil {
				rspExt = rspHead.RspExtendinfo
			}
			uploaded++
			mu.Unlock()
			cond.Broadcast()
		}
	}

	var group errgroup.Group
	for i := 0; i < threadCount; i++ {
		group.Go(func() error {
			if err := uploadBlocks(); err != nil {
				markFailed()
				return err
			}
			return nil
		})
	}
	// Wait must complete before rspExt is read: all writers are done then.
	if err := group.Wait(); err != nil {
		return nil, err
	}
	return rspExt, nil
}