diff --git a/client/image.go b/client/image.go index 3806b493..a0cf0b0f 100644 --- a/client/image.go +++ b/client/image.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "crypto/md5" "encoding/hex" "io" "math/rand" @@ -93,6 +94,7 @@ func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSee var r interface{} var err error + var input highway.BdhInput switch target.SourceType { case message.SourceGroup: r, err = c.sendAndWait(c.buildGroupImageStorePacket(target.PrimaryID, fh, int32(length))) @@ -115,22 +117,18 @@ func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSee } } + input = highway.BdhInput{ + CommandID: cmd, + Body: img, + Size: length, + Sum: fh, + Ticket: rsp.UploadKey, + Ext: ext, + } if tc > 1 && length > 3*1024*1024 { - _, err = c.highwaySession.UploadBDHMultiThread(highway.BdhMultiThreadInput{ - CommandID: cmd, - Body: utils.ReaderAtFrom2ReadSeeker(img, nil), - Size: length, - Sum: fh, - Ticket: rsp.UploadKey, - Ext: ext, - }, 4) + _, err = c.highwaySession.UploadBDHMultiThread(input, tc) } else { - _, err = c.highwaySession.UploadBDH(highway.BdhInput{ - CommandID: cmd, - Body: img, - Ticket: rsp.UploadKey, - Ext: ext, - }) + _, err = c.highwaySession.UploadBDH(input) } if err != nil { return nil, errors.Wrap(err, "upload failed") @@ -326,10 +324,13 @@ func (c *QQClient) uploadOcrImage(img io.Reader) (string, error) { Uuid: binary.GenUUID(r), }) - buf, _ := io.ReadAll(img) + h := md5.New() + buf, _ := io.ReadAll(io.TeeReader(img, h)) rsp, err := c.highwaySession.UploadBDH(highway.BdhInput{ CommandID: 76, Body: bytes.NewReader(buf), + Size: int64(len(buf)), + Sum: h.Sum(nil), Ticket: c.highwaySession.SigSession, Ext: ext, Encrypt: false, diff --git a/client/internal/highway/bdh.go b/client/internal/highway/bdh.go index 2921d37e..bfb70c61 100644 --- a/client/internal/highway/bdh.go +++ b/client/internal/highway/bdh.go @@ -15,22 +15,13 @@ import ( "github.com/Mrs4s/MiraiGo/client/internal/network" "github.com/Mrs4s/MiraiGo/client/pb" "github.com/Mrs4s/MiraiGo/internal/proto" - "github.com/Mrs4s/MiraiGo/utils" ) type BdhInput struct { CommandID int32 - Body io.ReadSeeker - Ticket []byte - Ext []byte - Encrypt bool -} - -type BdhMultiThreadInput struct { - CommandID int32 - Body io.ReaderAt - Sum []byte - Size int64 + Body io.Reader + Sum []byte // md5 sum of body + Size int64 // body size Ticket []byte Ext []byte Encrypt bool @@ -46,24 +37,12 @@ func (bdh *BdhInput) encrypt(key []byte) error { return nil } -func (bdh *BdhMultiThreadInput) encrypt(key []byte) error { - if bdh.Encrypt { - if len(key) == 0 { - return errors.New("session key not found. maybe miss some packet?") - } - bdh.Ext = binary.NewTeaCipher(key).Encrypt(bdh.Ext) - } - return nil -} - func (s *Session) UploadBDH(input BdhInput) ([]byte, error) { if len(s.SsoAddr) == 0 { return nil, errors.New("srv addrs not found. maybe miss some packet?") } addr := s.SsoAddr[0].String() - sum, length := utils.ComputeMd5AndLength(input.Body) - _, _ = input.Body.Seek(0, io.SeekStart) if err := input.encrypt(s.SessionKey); err != nil { return nil, err } @@ -95,12 +74,12 @@ func (s *Session) UploadBDH(input BdhInput) ([]byte, error) { head, _ := proto.Marshal(&pb.ReqDataHighwayHead{ MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052), MsgSeghead: &pb.SegHead{ - Filesize: length, + Filesize: input.Size, Dataoffset: int64(offset), Datalength: int32(rl), Serviceticket: input.Ticket, Md5: ch[:], - FileMd5: sum, + FileMd5: input.Sum, }, ReqExtendinfo: input.Ext, }) @@ -127,17 +106,11 @@ func (s *Session) UploadBDH(input BdhInput) ([]byte, error) { return rspExt, nil } -func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount int) ([]byte, error) { +func (s *Session) UploadBDHMultiThread(input BdhInput, threadCount int) ([]byte, error) { // for small file and small thread count, // use UploadBDH instead of UploadBDHMultiThread if input.Size < 1024*1024*3 || threadCount < 2 { - return s.UploadBDH(BdhInput{ - CommandID: input.CommandID, - Body: io.NewSectionReader(input.Body, 0, input.Size), - Ticket: input.Ticket, - Ext: input.Ext, - Encrypt: input.Encrypt, - }) + return s.UploadBDH(input) } if len(s.SsoAddr) == 0 { @@ -149,36 +122,21 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in return nil, err } - type BlockMetaData struct { - Id int - Offset int64 - } const blockSize int64 = 1024 * 512 var ( - blocks []BlockMetaData - rspExt []byte - BlockId = ^uint32(0) // -1 - uploadedCount uint32 - cond = sync.NewCond(&sync.Mutex{}) + rspExt []byte + completedThread uint32 + cond = sync.NewCond(&sync.Mutex{}) + offset = int64(0) + count = (input.Size + blockSize - 1) / blockSize + id = 0 ) - // Init Blocks - { - var temp int64 = 0 - for temp+blockSize < input.Size { - blocks = append(blocks, BlockMetaData{ - Id: len(blocks), - Offset: temp, - }) - temp += blockSize - } - blocks = append(blocks, BlockMetaData{ - Id: len(blocks), - Offset: temp, - }) - } doUpload := func() error { // send signal complete uploading - defer cond.Signal() + defer func() { + atomic.AddUint32(&completedThread, 1) + cond.Signal() + }() conn, err := net.DialTimeout("tcp", addr, time.Second*20) if err != nil { @@ -192,22 +150,20 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in chunk := make([]byte, blockSize) for { - nextId := atomic.AddUint32(&BlockId, 1) - if nextId >= uint32(len(blocks)) { - break - } - block := blocks[nextId] - if block.Id == len(blocks)-1 { - cond.L.Lock() - for atomic.LoadUint32(&uploadedCount) != uint32(len(blocks))-1 { + cond.L.Lock() // lock protect reading + off := offset + offset += blockSize + id++ + if int64(id) == count { // last + for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) { cond.Wait() } + } else if int64(id) > count { cond.L.Unlock() + break } chunk = chunk[:blockSize] - - cond.L.Lock() // lock protect reading - n, err := input.Body.ReadAt(chunk, block.Offset) + n, err := io.ReadFull(input.Body, chunk) cond.L.Unlock() if err != nil { @@ -225,7 +181,7 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in MsgBasehead: s.dataHighwayHead(4096, input.CommandID, 2052), MsgSeghead: &pb.SegHead{ Filesize: input.Size, - Dataoffset: block.Offset, + Dataoffset: off, Datalength: int32(n), Serviceticket: input.Ticket, Md5: ch[:], @@ -248,7 +204,6 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in if rspHead.RspExtendinfo != nil { rspExt = rspHead.RspExtendinfo } - atomic.AddUint32(&uploadedCount, 1) } return nil } diff --git a/client/ptt.go b/client/ptt.go index 0d3e7fdf..17c4185a 100644 --- a/client/ptt.go +++ b/client/ptt.go @@ -1,6 +1,7 @@ package client import ( + "crypto/md5" "encoding/hex" "io" @@ -74,6 +75,8 @@ func (c *QQClient) UploadVoice(target message.Source, voice io.ReadSeeker) (*mes rsp, err := c.highwaySession.UploadBDH(highway.BdhInput{ CommandID: cmd, Body: voice, + Sum: fh, + Size: length, Ticket: c.highwaySession.SigSession, Ext: ext, Encrypt: false, @@ -120,14 +123,17 @@ func (c *QQClient) UploadVoice(target message.Source, voice io.ReadSeeker) (*mes // UploadShortVideo 将视频和封面上传到服务器, 返回 message.ShortVideoElement 可直接发送 // thread 上传线程数 func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadSeeker, thread int) (*message.ShortVideoElement, error) { - videoHash, videoLen := utils.ComputeMd5AndLength(video) - thumbHash, thumbLen := utils.ComputeMd5AndLength(thumb) + thumbHash := md5.New() + thumbLen, _ := io.Copy(thumbHash, thumb) + thumbSum := thumbHash.Sum(nil) + videoSum, videoLen := utils.ComputeMd5AndLength(io.TeeReader(video, thumbHash)) + sum := thumbHash.Sum(nil) - key := string(videoHash) + string(thumbHash) + key := string(sum) pttWaiter.Wait(key) defer pttWaiter.Done(key) - i, err := c.sendAndWait(c.buildPttGroupShortVideoUploadReqPacket(target, videoHash, thumbHash, videoLen, thumbLen)) + i, err := c.sendAndWait(c.buildPttGroupShortVideoUploadReqPacket(target, videoSum, thumbSum, videoLen, thumbLen)) if err != nil { return nil, errors.Wrap(err, "upload req error") } @@ -135,8 +141,8 @@ func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadS videoElement := &message.ShortVideoElement{ Size: int32(videoLen), ThumbSize: int32(thumbLen), - Md5: videoHash, - ThumbMd5: thumbHash, + Md5: videoSum, + ThumbMd5: thumbSum, Guild: target.SourceType == message.SourceGuildChannel, } if rsp.FileExists == 1 { @@ -149,28 +155,20 @@ func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadS if target.SourceType == message.SourceGuildChannel { cmd = 89 } - ext, _ := proto.Marshal(c.buildPttShortVideoProto(target, videoHash, thumbHash, videoLen, thumbLen).PttShortVideoUploadReq) + ext, _ := proto.Marshal(c.buildPttShortVideoProto(target, videoSum, thumbSum, videoLen, thumbLen).PttShortVideoUploadReq) + combined := utils.MultiReadSeeker(thumb, video) + input := highway.BdhInput{ + CommandID: cmd, + Body: combined, + Size: videoLen + thumbLen, + Sum: sum, + Ticket: c.highwaySession.SigSession, + Ext: ext, + Encrypt: true, + } if thread > 1 { - sum, _ := utils.ComputeMd5AndLength(utils.MultiReadSeeker(thumb, video)) - input := highway.BdhMultiThreadInput{ - CommandID: cmd, - Body: utils.ReaderAtFrom2ReadSeeker(thumb, video), - Size: videoLen + thumbLen, - Sum: sum, - Ticket: c.highwaySession.SigSession, - Ext: ext, - Encrypt: true, - } hwRsp, err = c.highwaySession.UploadBDHMultiThread(input, thread) } else { - multi := utils.MultiReadSeeker(thumb, video) - input := highway.BdhInput{ - CommandID: cmd, - Body: multi, - Ticket: c.highwaySession.SigSession, - Ext: ext, - Encrypt: true, - } hwRsp, err = c.highwaySession.UploadBDH(input) } if err != nil {