1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 11:07:40 +08:00

highway: refactor multi thread uploading

This commit is contained in:
wdvxdr 2022-02-27 15:42:02 +08:00
parent e6ad62569d
commit ca4580dad5
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6
3 changed files with 66 additions and 112 deletions

View File

@ -2,6 +2,7 @@ package client
import ( import (
"bytes" "bytes"
"crypto/md5"
"encoding/hex" "encoding/hex"
"io" "io"
"math/rand" "math/rand"
@ -93,6 +94,7 @@ func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSee
var r interface{} var r interface{}
var err error var err error
var input highway.BdhInput
switch target.SourceType { switch target.SourceType {
case message.SourceGroup: case message.SourceGroup:
r, err = c.sendAndWait(c.buildGroupImageStorePacket(target.PrimaryID, fh, int32(length))) r, err = c.sendAndWait(c.buildGroupImageStorePacket(target.PrimaryID, fh, int32(length)))
@ -115,22 +117,18 @@ func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSee
} }
} }
input = highway.BdhInput{
CommandID: cmd,
Body: img,
Size: length,
Sum: fh,
Ticket: rsp.UploadKey,
Ext: ext,
}
if tc > 1 && length > 3*1024*1024 { if tc > 1 && length > 3*1024*1024 {
_, err = c.highwaySession.UploadBDHMultiThread(highway.BdhMultiThreadInput{ _, err = c.highwaySession.UploadBDHMultiThread(input, tc)
CommandID: cmd,
Body: utils.ReaderAtFrom2ReadSeeker(img, nil),
Size: length,
Sum: fh,
Ticket: rsp.UploadKey,
Ext: ext,
}, 4)
} else { } else {
_, err = c.highwaySession.UploadBDH(highway.BdhInput{ _, err = c.highwaySession.UploadBDH(input)
CommandID: cmd,
Body: img,
Ticket: rsp.UploadKey,
Ext: ext,
})
} }
if err != nil { if err != nil {
return nil, errors.Wrap(err, "upload failed") return nil, errors.Wrap(err, "upload failed")
@ -326,10 +324,13 @@ func (c *QQClient) uploadOcrImage(img io.Reader) (string, error) {
Uuid: binary.GenUUID(r), Uuid: binary.GenUUID(r),
}) })
buf, _ := io.ReadAll(img) h := md5.New()
buf, _ := io.ReadAll(io.TeeReader(img, h))
rsp, err := c.highwaySession.UploadBDH(highway.BdhInput{ rsp, err := c.highwaySession.UploadBDH(highway.BdhInput{
CommandID: 76, CommandID: 76,
Body: bytes.NewReader(buf), Body: bytes.NewReader(buf),
Size: int64(len(buf)),
Sum: h.Sum(nil),
Ticket: c.highwaySession.SigSession, Ticket: c.highwaySession.SigSession,
Ext: ext, Ext: ext,
Encrypt: false, Encrypt: false,

View File

@ -15,22 +15,13 @@ import (
"github.com/Mrs4s/MiraiGo/client/internal/network" "github.com/Mrs4s/MiraiGo/client/internal/network"
"github.com/Mrs4s/MiraiGo/client/pb" "github.com/Mrs4s/MiraiGo/client/pb"
"github.com/Mrs4s/MiraiGo/internal/proto" "github.com/Mrs4s/MiraiGo/internal/proto"
"github.com/Mrs4s/MiraiGo/utils"
) )
type BdhInput struct { type BdhInput struct {
CommandID int32 CommandID int32
Body io.ReadSeeker Body io.Reader
Ticket []byte Sum []byte // md5 sum of body
Ext []byte Size int64 // body size
Encrypt bool
}
type BdhMultiThreadInput struct {
CommandID int32
Body io.ReaderAt
Sum []byte
Size int64
Ticket []byte Ticket []byte
Ext []byte Ext []byte
Encrypt bool Encrypt bool
@ -46,24 +37,12 @@ func (bdh *BdhInput) encrypt(key []byte) error {
return nil return nil
} }
func (bdh *BdhMultiThreadInput) encrypt(key []byte) error {
if bdh.Encrypt {
if len(key) == 0 {
return errors.New("session key not found. maybe miss some packet?")
}
bdh.Ext = binary.NewTeaCipher(key).Encrypt(bdh.Ext)
}
return nil
}
func (s *Session) UploadBDH(input BdhInput) ([]byte, error) { func (s *Session) UploadBDH(input BdhInput) ([]byte, error) {
if len(s.SsoAddr) == 0 { if len(s.SsoAddr) == 0 {
return nil, errors.New("srv addrs not found. maybe miss some packet?") return nil, errors.New("srv addrs not found. maybe miss some packet?")
} }
addr := s.SsoAddr[0].String() addr := s.SsoAddr[0].String()
sum, length := utils.ComputeMd5AndLength(input.Body)
_, _ = input.Body.Seek(0, io.SeekStart)
if err := input.encrypt(s.SessionKey); err != nil { if err := input.encrypt(s.SessionKey); err != nil {
return nil, err return nil, err
} }
@ -95,12 +74,12 @@ func (s *Session) UploadBDH(input BdhInput) ([]byte, error) {
head, _ := proto.Marshal(&pb.ReqDataHighwayHead{ head, _ := proto.Marshal(&pb.ReqDataHighwayHead{
MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052), MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052),
MsgSeghead: &pb.SegHead{ MsgSeghead: &pb.SegHead{
Filesize: length, Filesize: input.Size,
Dataoffset: int64(offset), Dataoffset: int64(offset),
Datalength: int32(rl), Datalength: int32(rl),
Serviceticket: input.Ticket, Serviceticket: input.Ticket,
Md5: ch[:], Md5: ch[:],
FileMd5: sum, FileMd5: input.Sum,
}, },
ReqExtendinfo: input.Ext, ReqExtendinfo: input.Ext,
}) })
@ -127,17 +106,11 @@ func (s *Session) UploadBDH(input BdhInput) ([]byte, error) {
return rspExt, nil return rspExt, nil
} }
func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount int) ([]byte, error) { func (s *Session) UploadBDHMultiThread(input BdhInput, threadCount int) ([]byte, error) {
// for small file and small thread count, // for small file and small thread count,
// use UploadBDH instead of UploadBDHMultiThread // use UploadBDH instead of UploadBDHMultiThread
if input.Size < 1024*1024*3 || threadCount < 2 { if input.Size < 1024*1024*3 || threadCount < 2 {
return s.UploadBDH(BdhInput{ return s.UploadBDH(input)
CommandID: input.CommandID,
Body: io.NewSectionReader(input.Body, 0, input.Size),
Ticket: input.Ticket,
Ext: input.Ext,
Encrypt: input.Encrypt,
})
} }
if len(s.SsoAddr) == 0 { if len(s.SsoAddr) == 0 {
@ -149,36 +122,21 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in
return nil, err return nil, err
} }
type BlockMetaData struct {
Id int
Offset int64
}
const blockSize int64 = 1024 * 512 const blockSize int64 = 1024 * 512
var ( var (
blocks []BlockMetaData rspExt []byte
rspExt []byte completedThread uint32
BlockId = ^uint32(0) // -1 cond = sync.NewCond(&sync.Mutex{})
uploadedCount uint32 offset = int64(0)
cond = sync.NewCond(&sync.Mutex{}) count = (input.Size + blockSize - 1) / blockSize
id = 0
) )
// Init Blocks
{
var temp int64 = 0
for temp+blockSize < input.Size {
blocks = append(blocks, BlockMetaData{
Id: len(blocks),
Offset: temp,
})
temp += blockSize
}
blocks = append(blocks, BlockMetaData{
Id: len(blocks),
Offset: temp,
})
}
doUpload := func() error { doUpload := func() error {
// send signal complete uploading // send signal complete uploading
defer cond.Signal() defer func() {
atomic.AddUint32(&completedThread, 1)
cond.Signal()
}()
conn, err := net.DialTimeout("tcp", addr, time.Second*20) conn, err := net.DialTimeout("tcp", addr, time.Second*20)
if err != nil { if err != nil {
@ -192,22 +150,20 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in
chunk := make([]byte, blockSize) chunk := make([]byte, blockSize)
for { for {
nextId := atomic.AddUint32(&BlockId, 1) cond.L.Lock() // lock protect reading
if nextId >= uint32(len(blocks)) { off := offset
break offset += blockSize
} id++
block := blocks[nextId] if int64(id) == count { // last
if block.Id == len(blocks)-1 { for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) {
cond.L.Lock()
for atomic.LoadUint32(&uploadedCount) != uint32(len(blocks))-1 {
cond.Wait() cond.Wait()
} }
} else if int64(id) > count {
cond.L.Unlock() cond.L.Unlock()
break
} }
chunk = chunk[:blockSize] chunk = chunk[:blockSize]
n, err := io.ReadFull(input.Body, chunk)
cond.L.Lock() // lock protect reading
n, err := input.Body.ReadAt(chunk, block.Offset)
cond.L.Unlock() cond.L.Unlock()
if err != nil { if err != nil {
@ -225,7 +181,7 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in
MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052), MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052),
MsgSeghead: &pb.SegHead{ MsgSeghead: &pb.SegHead{
Filesize: input.Size, Filesize: input.Size,
Dataoffset: block.Offset, Dataoffset: off,
Datalength: int32(n), Datalength: int32(n),
Serviceticket: input.Ticket, Serviceticket: input.Ticket,
Md5: ch[:], Md5: ch[:],
@ -248,7 +204,6 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in
if rspHead.RspExtendinfo != nil { if rspHead.RspExtendinfo != nil {
rspExt = rspHead.RspExtendinfo rspExt = rspHead.RspExtendinfo
} }
atomic.AddUint32(&uploadedCount, 1)
} }
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"crypto/md5"
"encoding/hex" "encoding/hex"
"io" "io"
@ -74,6 +75,8 @@ func (c *QQClient) UploadVoice(target message.Source, voice io.ReadSeeker) (*mes
rsp, err := c.highwaySession.UploadBDH(highway.BdhInput{ rsp, err := c.highwaySession.UploadBDH(highway.BdhInput{
CommandID: cmd, CommandID: cmd,
Body: voice, Body: voice,
Sum: fh,
Size: length,
Ticket: c.highwaySession.SigSession, Ticket: c.highwaySession.SigSession,
Ext: ext, Ext: ext,
Encrypt: false, Encrypt: false,
@ -120,14 +123,17 @@ func (c *QQClient) UploadVoice(target message.Source, voice io.ReadSeeker) (*mes
// UploadShortVideo 将视频和封面上传到服务器, 返回 message.ShortVideoElement 可直接发送 // UploadShortVideo 将视频和封面上传到服务器, 返回 message.ShortVideoElement 可直接发送
// thread 上传线程数 // thread 上传线程数
func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadSeeker, thread int) (*message.ShortVideoElement, error) { func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadSeeker, thread int) (*message.ShortVideoElement, error) {
videoHash, videoLen := utils.ComputeMd5AndLength(video) thumbHash := md5.New()
thumbHash, thumbLen := utils.ComputeMd5AndLength(thumb) thumbLen, _ := io.Copy(thumbHash, thumb)
thumbSum := thumbHash.Sum(nil)
videoSum, videoLen := utils.ComputeMd5AndLength(io.TeeReader(video, thumbHash))
sum := thumbHash.Sum(nil)
key := string(videoHash) + string(thumbHash) key := string(sum)
pttWaiter.Wait(key) pttWaiter.Wait(key)
defer pttWaiter.Done(key) defer pttWaiter.Done(key)
i, err := c.sendAndWait(c.buildPttGroupShortVideoUploadReqPacket(target, videoHash, thumbHash, videoLen, thumbLen)) i, err := c.sendAndWait(c.buildPttGroupShortVideoUploadReqPacket(target, videoSum, thumbSum, videoLen, thumbLen))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "upload req error") return nil, errors.Wrap(err, "upload req error")
} }
@ -135,8 +141,8 @@ func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadS
videoElement := &message.ShortVideoElement{ videoElement := &message.ShortVideoElement{
Size: int32(videoLen), Size: int32(videoLen),
ThumbSize: int32(thumbLen), ThumbSize: int32(thumbLen),
Md5: videoHash, Md5: videoSum,
ThumbMd5: thumbHash, ThumbMd5: thumbSum,
Guild: target.SourceType == message.SourceGuildChannel, Guild: target.SourceType == message.SourceGuildChannel,
} }
if rsp.FileExists == 1 { if rsp.FileExists == 1 {
@ -149,28 +155,20 @@ func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadS
if target.SourceType == message.SourceGuildChannel { if target.SourceType == message.SourceGuildChannel {
cmd = 89 cmd = 89
} }
ext, _ := proto.Marshal(c.buildPttShortVideoProto(target, videoHash, thumbHash, videoLen, thumbLen).PttShortVideoUploadReq) ext, _ := proto.Marshal(c.buildPttShortVideoProto(target, videoSum, thumbSum, videoLen, thumbLen).PttShortVideoUploadReq)
combined := utils.MultiReadSeeker(thumb, video)
input := highway.BdhInput{
CommandID: cmd,
Body: combined,
Size: videoLen + thumbLen,
Sum: sum,
Ticket: c.highwaySession.SigSession,
Ext: ext,
Encrypt: true,
}
if thread > 1 { if thread > 1 {
sum, _ := utils.ComputeMd5AndLength(utils.MultiReadSeeker(thumb, video))
input := highway.BdhMultiThreadInput{
CommandID: cmd,
Body: utils.ReaderAtFrom2ReadSeeker(thumb, video),
Size: videoLen + thumbLen,
Sum: sum,
Ticket: c.highwaySession.SigSession,
Ext: ext,
Encrypt: true,
}
hwRsp, err = c.highwaySession.UploadBDHMultiThread(input, thread) hwRsp, err = c.highwaySession.UploadBDHMultiThread(input, thread)
} else { } else {
multi := utils.MultiReadSeeker(thumb, video)
input := highway.BdhInput{
CommandID: cmd,
Body: multi,
Ticket: c.highwaySession.SigSession,
Ext: ext,
Encrypt: true,
}
hwRsp, err = c.highwaySession.UploadBDH(input) hwRsp, err = c.highwaySession.UploadBDH(input)
} }
if err != nil { if err != nil {