diff --git a/client/image.go b/client/image.go index 3806b493..7e1a4f1b 100644 --- a/client/image.go +++ b/client/image.go @@ -118,7 +118,7 @@ func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSee if tc > 1 && length > 3*1024*1024 { _, err = c.highwaySession.UploadBDHMultiThread(highway.BdhMultiThreadInput{ CommandID: cmd, - Body: utils.ReaderAtFrom2ReadSeeker(img, nil), + Body: img, Size: length, Sum: fh, Ticket: rsp.UploadKey, diff --git a/client/internal/highway/bdh.go b/client/internal/highway/bdh.go index 2921d37e..e9605f3c 100644 --- a/client/internal/highway/bdh.go +++ b/client/internal/highway/bdh.go @@ -28,7 +28,7 @@ type BdhInput struct { type BdhMultiThreadInput struct { CommandID int32 - Body io.ReaderAt + Body io.ReadSeeker Sum []byte Size int64 Ticket []byte @@ -130,10 +130,11 @@ func (s *Session) UploadBDH(input BdhInput) ([]byte, error) { func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount int) ([]byte, error) { // for small file and small thread count, // use UploadBDH instead of UploadBDHMultiThread - if input.Size < 1024*1024*3 || threadCount < 2 { + // FIXME: enable multi-thread, now receive error code 81 + if true || input.Size < 1024*1024*3 || threadCount < 2 { return s.UploadBDH(BdhInput{ CommandID: input.CommandID, - Body: io.NewSectionReader(input.Body, 0, input.Size), + Body: input.Body, Ticket: input.Ticket, Ext: input.Ext, Encrypt: input.Encrypt, @@ -207,7 +208,8 @@ func (s *Session) UploadBDHMultiThread(input BdhMultiThreadInput, threadCount in chunk = chunk[:blockSize] cond.L.Lock() // lock protect reading - n, err := input.Body.ReadAt(chunk, block.Offset) + _, _ = input.Body.Seek(block.Offset, io.SeekStart) + n, err := input.Body.Read(chunk) cond.L.Unlock() if err != nil { diff --git a/client/ptt.go b/client/ptt.go index 0d3e7fdf..9ad2caeb 100644 --- a/client/ptt.go +++ b/client/ptt.go @@ -150,11 +150,13 @@ func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadS cmd = 89 } ext, _ := proto.Marshal(c.buildPttShortVideoProto(target, videoHash, thumbHash, videoLen, thumbLen).PttShortVideoUploadReq) + combined := utils.DoubleReadSeeker(video, thumb) if thread > 1 { - sum, _ := utils.ComputeMd5AndLength(utils.MultiReadSeeker(thumb, video)) + sum, _ := utils.ComputeMd5AndLength(combined) + _, _ = combined.Seek(0, io.SeekStart) input := highway.BdhMultiThreadInput{ CommandID: cmd, - Body: utils.ReaderAtFrom2ReadSeeker(thumb, video), + Body: combined, Size: videoLen + thumbLen, Sum: sum, Ticket: c.highwaySession.SigSession, @@ -163,10 +165,9 @@ func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadS } hwRsp, err = c.highwaySession.UploadBDHMultiThread(input, thread) } else { - multi := utils.MultiReadSeeker(thumb, video) input := highway.BdhInput{ CommandID: cmd, - Body: multi, + Body: combined, Ticket: c.highwaySession.SigSession, Ext: ext, Encrypt: true, diff --git a/utils/sys.go b/utils/sys.go index 7ef22b45..75b54e7d 100644 --- a/utils/sys.go +++ b/utils/sys.go @@ -6,9 +6,57 @@ import ( "io" ) -type multiReadSeeker struct { - readers []io.ReadSeeker - multiReader io.Reader +type doubleReadSeeker struct { + rs1, rs2 io.ReadSeeker + rs1len, rs2len int64 + pos int64 +} + +func (r *doubleReadSeeker) Seek(offset int64, whence int) (int64, error) { + var err error + switch whence { + case io.SeekStart: + if offset < r.rs1len { + r.pos, err = r.rs1.Seek(offset, io.SeekStart) + return r.pos, err + } else { + r.pos, err = r.rs2.Seek(offset-r.rs1len, io.SeekStart) + r.pos += r.rs1len + return r.pos, err + } + case io.SeekEnd: // negative offset + return r.Seek(r.rs1len+r.rs2len+offset-1, io.SeekStart) + default: // io.SeekCurrent + return r.Seek(r.pos+offset, io.SeekStart) + } +} + +func (r *doubleReadSeeker) Read(p []byte) (n int, err error) { + switch { + case r.pos >= r.rs1len: // read only from the second reader + n, err := r.rs2.Read(p) + r.pos += int64(n) + return n, err + case r.pos+int64(len(p)) <= r.rs1len: // read only from the first reader + n, err := r.rs1.Read(p) + r.pos += int64(n) + return n, err + default: // read on the border - end of first reader and start of second reader + n1, err := r.rs1.Read(p) + r.pos += int64(n1) + if r.pos != r.rs1len || (err != nil && errors.Is(err, io.EOF)) { + // Read() might not read all, return + // If error (but not EOF), return + return n1, err + } + _, err = r.rs2.Seek(0, io.SeekStart) + if err != nil { + return n1, err + } + n2, err := r.rs2.Read(p[n1:]) + r.pos += int64(n2) + return n1 + n2, err + } } func ComputeMd5AndLength(r io.Reader) ([]byte, int64) { @@ -18,82 +66,19 @@ func ComputeMd5AndLength(r io.Reader) ([]byte, int64) { return fh, length } -func (r *multiReadSeeker) Read(p []byte) (int, error) { - if r.multiReader == nil { - var readers []io.Reader - for i := range r.readers { - _, _ = r.readers[i].Seek(0, io.SeekStart) - readers = append(readers, r.readers[i]) - } - r.multiReader = io.MultiReader(readers...) - } - return r.multiReader.Read(p) -} - -func (r *multiReadSeeker) Seek(offset int64, whence int) (int64, error) { - if whence != 0 || offset != 0 { - return -1, errors.New("unsupported offset") - } - r.multiReader = nil - return 0, nil -} - -func MultiReadSeeker(r ...io.ReadSeeker) io.ReadSeeker { - return &multiReadSeeker{ - readers: r, - } -} - -type multiReadAt struct { - first io.ReadSeeker - second io.ReadSeeker - firstSize int64 - secondSize int64 -} - -func (m *multiReadAt) ReadAt(p []byte, off int64) (n int, err error) { - if m.second == nil { // quick path - _, _ = m.first.Seek(off, io.SeekStart) - return m.first.Read(p) - } - if off < m.firstSize && off+int64(len(p)) < m.firstSize { - _, err = m.first.Seek(off, io.SeekStart) - if err != nil { - return - } - return m.first.Read(p) - } else if off < m.firstSize && off+int64(len(p)) >= m.firstSize { - _, _ = m.first.Seek(off, io.SeekStart) - _, _ = m.second.Seek(0, io.SeekStart) - n, err = m.first.Read(p[:m.firstSize-off]) - if err != nil { - return - } - n2, err := m.second.Read(p[m.firstSize-off:]) - return n + n2, err - } - _, err = m.second.Seek(off-m.firstSize, io.SeekStart) - if err != nil { - return - } - return m.second.Read(p) -} - -func ReaderAtFrom2ReadSeeker(first, second io.ReadSeeker) io.ReaderAt { - firstSize, _ := first.Seek(0, io.SeekEnd) - if second == nil { - return &multiReadAt{ - first: first, - firstSize: firstSize, - secondSize: 0, - } - } - secondSize, _ := second.Seek(0, io.SeekEnd) - return &multiReadAt{ - first: first, - second: second, - firstSize: firstSize, - secondSize: secondSize, +// DoubleReadSeeker combines two io.ReadSeeker into one. +// input two io.ReadSeeker must be at the start. +func DoubleReadSeeker(first, second io.ReadSeeker) io.ReadSeeker { + rs1Len, _ := first.Seek(0, io.SeekEnd) + _, _ = first.Seek(0, io.SeekStart) // reset to start + rs2Len, _ := second.Seek(0, io.SeekEnd) + _, _ = second.Seek(0, io.SeekStart) // reset to start + return &doubleReadSeeker{ + rs1: first, + rs2: second, + rs1len: rs1Len, + rs2len: rs2Len, + pos: 0, } }