1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 11:07:40 +08:00

highway: refactor multi thread uploading

This commit is contained in:
wdvxdr 2022-02-27 15:42:02 +08:00
parent e6ad62569d
commit ca4580dad5
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6
3 changed files with 66 additions and 112 deletions

View File

@ -2,6 +2,7 @@ package client
import (
"bytes"
"crypto/md5"
"encoding/hex"
"io"
"math/rand"
@ -93,6 +94,7 @@ func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSee
var r interface{}
var err error
var input highway.BdhInput
switch target.SourceType {
case message.SourceGroup:
r, err = c.sendAndWait(c.buildGroupImageStorePacket(target.PrimaryID, fh, int32(length)))
@ -115,22 +117,18 @@ func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSee
}
}
input = highway.BdhInput{
CommandID: cmd,
Body: img,
Size: length,
Sum: fh,
Ticket: rsp.UploadKey,
Ext: ext,
}
if tc > 1 && length > 3*1024*1024 {
_, err = c.highwaySession.UploadBDHMultiThread(highway.BdhMultiThreadInput{
CommandID: cmd,
Body: utils.ReaderAtFrom2ReadSeeker(img, nil),
Size: length,
Sum: fh,
Ticket: rsp.UploadKey,
Ext: ext,
}, 4)
_, err = c.highwaySession.UploadBDHMultiThread(input, tc)
} else {
_, err = c.highwaySession.UploadBDH(highway.BdhInput{
CommandID: cmd,
Body: img,
Ticket: rsp.UploadKey,
Ext: ext,
})
_, err = c.highwaySession.UploadBDH(input)
}
if err != nil {
return nil, errors.Wrap(err, "upload failed")
@ -326,10 +324,13 @@ func (c *QQClient) uploadOcrImage(img io.Reader) (string, error) {
Uuid: binary.GenUUID(r),
})
buf, _ := io.ReadAll(img)
h := md5.New()
buf, _ := io.ReadAll(io.TeeReader(img, h))
rsp, err := c.highwaySession.UploadBDH(highway.BdhInput{
CommandID: 76,
Body: bytes.NewReader(buf),
Size: int64(len(buf)),
Sum: h.Sum(nil),
Ticket: c.highwaySession.SigSession,
Ext: ext,
Encrypt: false,

View File

@ -15,22 +15,13 @@ import (
"github.com/Mrs4s/MiraiGo/client/internal/network"
"github.com/Mrs4s/MiraiGo/client/pb"
"github.com/Mrs4s/MiraiGo/internal/proto"
"github.com/Mrs4s/MiraiGo/utils"
)
type BdhInput struct {
CommandID int32
Body io.ReadSeeker
Ticket []byte
Ext []byte
Encrypt bool
}
type BdhMultiThreadInput struct {
CommandID int32
Body io.ReaderAt
Sum []byte
Size int64
Body io.Reader
Sum []byte // md5 sum of body
Size int64 // body size
Ticket []byte
Ext []byte
Encrypt bool
@ -46,24 +37,12 @@ func (bdh *BdhInput) encrypt(key []byte) error {
return nil
}
func (bdh *BdhMultiThreadInput) encrypt(key []byte) error {
if bdh.Encrypt {
if len(key) == 0 {
return errors.New("session key not found. maybe miss some packet?")
}
bdh.Ext = binary.NewTeaCipher(key).Encrypt(bdh.Ext)
}
return nil
}
func (s *Session) UploadBDH(input BdhInput) ([]byte, error) {
if len(s.SsoAddr) == 0 {
return nil, errors.New("srv addrs not found. maybe miss some packet?")
}
addr := s.SsoAddr[0].String()
sum, length := utils.ComputeMd5AndLength(input.Body)
_, _ = input.Body.Seek(0, io.SeekStart)
if err := input.encrypt(s.SessionKey); err != nil {
return nil, err
}
@ -95,12 +74,12 @@ func (s *Session) UploadBDH(input BdhInput) ([]byte, error) {
head, _ := proto.Marshal(&pb.ReqDataHighwayHead{
MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052),
MsgSeghead: &pb.SegHead{
Filesize: length,
Filesize: input.Size,
Dataoffset: int64(offset),
Datalength: int32(rl),
Serviceticket: input.Ticket,
Md5: ch[:],
FileMd5: sum,
FileMd5: input.Sum,
},
ReqExtendinfo: input.Ext,
})
@ -127,17 +106,11 @@ func (s *Session) UploadBDH(input BdhInput) ([]byte, error) {
return rspExt, nil
}
func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount int) ([]byte, error) {
func (s *Session) UploadBDHMultiThread(input BdhInput, threadCount int) ([]byte, error) {
// for small file and small thread count,
// use UploadBDH instead of UploadBDHMultiThread
if input.Size < 1024*1024*3 || threadCount < 2 {
return s.UploadBDH(BdhInput{
CommandID: input.CommandID,
Body: io.NewSectionReader(input.Body, 0, input.Size),
Ticket: input.Ticket,
Ext: input.Ext,
Encrypt: input.Encrypt,
})
return s.UploadBDH(input)
}
if len(s.SsoAddr) == 0 {
@ -149,36 +122,21 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in
return nil, err
}
type BlockMetaData struct {
Id int
Offset int64
}
const blockSize int64 = 1024 * 512
var (
blocks []BlockMetaData
rspExt []byte
BlockId = ^uint32(0) // -1
uploadedCount uint32
cond = sync.NewCond(&sync.Mutex{})
rspExt []byte
completedThread uint32
cond = sync.NewCond(&sync.Mutex{})
offset = int64(0)
count = (input.Size + blockSize - 1) / blockSize
id = 0
)
// Init Blocks
{
var temp int64 = 0
for temp+blockSize < input.Size {
blocks = append(blocks, BlockMetaData{
Id: len(blocks),
Offset: temp,
})
temp += blockSize
}
blocks = append(blocks, BlockMetaData{
Id: len(blocks),
Offset: temp,
})
}
doUpload := func() error {
// send signal complete uploading
defer cond.Signal()
defer func() {
atomic.AddUint32(&completedThread, 1)
cond.Signal()
}()
conn, err := net.DialTimeout("tcp", addr, time.Second*20)
if err != nil {
@ -192,22 +150,20 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in
chunk := make([]byte, blockSize)
for {
nextId := atomic.AddUint32(&BlockId, 1)
if nextId >= uint32(len(blocks)) {
break
}
block := blocks[nextId]
if block.Id == len(blocks)-1 {
cond.L.Lock()
for atomic.LoadUint32(&uploadedCount) != uint32(len(blocks))-1 {
cond.L.Lock() // lock protect reading
off := offset
offset += blockSize
id++
if int64(id) == count { // last
for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) {
cond.Wait()
}
} else if int64(id) > count {
cond.L.Unlock()
break
}
chunk = chunk[:blockSize]
cond.L.Lock() // lock protect reading
n, err := input.Body.ReadAt(chunk, block.Offset)
n, err := io.ReadFull(input.Body, chunk)
cond.L.Unlock()
if err != nil {
@ -225,7 +181,7 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in
MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052),
MsgSeghead: &pb.SegHead{
Filesize: input.Size,
Dataoffset: block.Offset,
Dataoffset: off,
Datalength: int32(n),
Serviceticket: input.Ticket,
Md5: ch[:],
@ -248,7 +204,6 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in
if rspHead.RspExtendinfo != nil {
rspExt = rspHead.RspExtendinfo
}
atomic.AddUint32(&uploadedCount, 1)
}
return nil
}

View File

@ -1,6 +1,7 @@
package client
import (
"crypto/md5"
"encoding/hex"
"io"
@ -74,6 +75,8 @@ func (c *QQClient) UploadVoice(target message.Source, voice io.ReadSeeker) (*mes
rsp, err := c.highwaySession.UploadBDH(highway.BdhInput{
CommandID: cmd,
Body: voice,
Sum: fh,
Size: length,
Ticket: c.highwaySession.SigSession,
Ext: ext,
Encrypt: false,
@ -120,14 +123,17 @@ func (c *QQClient) UploadVoice(target message.Source, voice io.ReadSeeker) (*mes
// UploadShortVideo 将视频和封面上传到服务器, 返回 message.ShortVideoElement 可直接发送
// thread 上传线程数
func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadSeeker, thread int) (*message.ShortVideoElement, error) {
videoHash, videoLen := utils.ComputeMd5AndLength(video)
thumbHash, thumbLen := utils.ComputeMd5AndLength(thumb)
thumbHash := md5.New()
thumbLen, _ := io.Copy(thumbHash, thumb)
thumbSum := thumbHash.Sum(nil)
videoSum, videoLen := utils.ComputeMd5AndLength(io.TeeReader(video, thumbHash))
sum := thumbHash.Sum(nil)
key := string(videoHash) + string(thumbHash)
key := string(sum)
pttWaiter.Wait(key)
defer pttWaiter.Done(key)
i, err := c.sendAndWait(c.buildPttGroupShortVideoUploadReqPacket(target, videoHash, thumbHash, videoLen, thumbLen))
i, err := c.sendAndWait(c.buildPttGroupShortVideoUploadReqPacket(target, videoSum, thumbSum, videoLen, thumbLen))
if err != nil {
return nil, errors.Wrap(err, "upload req error")
}
@ -135,8 +141,8 @@ func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadS
videoElement := &message.ShortVideoElement{
Size: int32(videoLen),
ThumbSize: int32(thumbLen),
Md5: videoHash,
ThumbMd5: thumbHash,
Md5: videoSum,
ThumbMd5: thumbSum,
Guild: target.SourceType == message.SourceGuildChannel,
}
if rsp.FileExists == 1 {
@ -149,28 +155,20 @@ func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadS
if target.SourceType == message.SourceGuildChannel {
cmd = 89
}
ext, _ := proto.Marshal(c.buildPttShortVideoProto(target, videoHash, thumbHash, videoLen, thumbLen).PttShortVideoUploadReq)
ext, _ := proto.Marshal(c.buildPttShortVideoProto(target, videoSum, thumbSum, videoLen, thumbLen).PttShortVideoUploadReq)
combined := utils.MultiReadSeeker(thumb, video)
input := highway.BdhInput{
CommandID: cmd,
Body: combined,
Size: videoLen + thumbLen,
Sum: sum,
Ticket: c.highwaySession.SigSession,
Ext: ext,
Encrypt: true,
}
if thread > 1 {
sum, _ := utils.ComputeMd5AndLength(utils.MultiReadSeeker(thumb, video))
input := highway.BdhMultiThreadInput{
CommandID: cmd,
Body: utils.ReaderAtFrom2ReadSeeker(thumb, video),
Size: videoLen + thumbLen,
Sum: sum,
Ticket: c.highwaySession.SigSession,
Ext: ext,
Encrypt: true,
}
hwRsp, err = c.highwaySession.UploadBDHMultiThread(input, thread)
} else {
multi := utils.MultiReadSeeker(thumb, video)
input := highway.BdhInput{
CommandID: cmd,
Body: multi,
Ticket: c.highwaySession.SigSession,
Ext: ext,
Encrypt: true,
}
hwRsp, err = c.highwaySession.UploadBDH(input)
}
if err != nil {