1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-06-18 13:35:03 +08:00

client/internal/highway: auto switch multi-thread upload

This commit is contained in:
wdvxdr 2023-02-20 15:05:19 +08:00
parent e7a34fdd90
commit af032dec96
6 changed files with 31 additions and 62 deletions

View File

@ -43,10 +43,10 @@ type imageUploadResponse struct {
IsExists bool IsExists bool
} }
func (c *QQClient) UploadImage(target message.Source, img io.ReadSeeker, thread ...int) (message.IMessageElement, error) { func (c *QQClient) UploadImage(target message.Source, img io.ReadSeeker) (message.IMessageElement, error) {
switch target.SourceType { switch target.SourceType {
case message.SourceGroup, message.SourceGuildChannel, message.SourceGuildDirect: case message.SourceGroup, message.SourceGuildChannel, message.SourceGuildDirect:
return c.uploadGroupOrGuildImage(target, img, thread...) return c.uploadGroupOrGuildImage(target, img)
case message.SourcePrivate: case message.SourcePrivate:
return c.uploadPrivateImage(target.PrimaryID, img, 0) return c.uploadPrivateImage(target.PrimaryID, img, 0)
default: default:
@ -54,7 +54,7 @@ func (c *QQClient) UploadImage(target message.Source, img io.ReadSeeker, thread
} }
} }
func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSeeker, thread ...int) (message.IMessageElement, error) { func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSeeker) (message.IMessageElement, error) {
_, _ = img.Seek(0, io.SeekStart) // safe _, _ = img.Seek(0, io.SeekStart) // safe
fh, length := utils.ComputeMd5AndLength(img) fh, length := utils.ComputeMd5AndLength(img)
_, _ = img.Seek(0, io.SeekStart) _, _ = img.Seek(0, io.SeekStart)
@ -63,10 +63,6 @@ func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSee
imgWaiter.Wait(key) imgWaiter.Wait(key)
defer imgWaiter.Done(key) defer imgWaiter.Done(key)
tc := 1
if len(thread) > 0 {
tc = thread[0]
}
cmd := int32(2) cmd := int32(2)
ext := EmptyBytes ext := EmptyBytes
if target.SourceType != message.SourceGroup { // guild if target.SourceType != message.SourceGroup { // guild
@ -112,11 +108,7 @@ func (c *QQClient) uploadGroupOrGuildImage(target message.Source, img io.ReadSee
Ticket: rsp.UploadKey, Ticket: rsp.UploadKey,
Ext: ext, Ext: ext,
} }
if tc > 1 && length > 3*1024*1024 { _, err = c.highwaySession.Upload(input)
_, err = c.highwaySession.UploadBDHMultiThread(input)
} else {
_, err = c.highwaySession.UploadBDH(input)
}
if err != nil { if err != nil {
return nil, errors.Wrap(err, "upload failed") return nil, errors.Wrap(err, "upload failed")
} }
@ -306,7 +298,7 @@ func (c *QQClient) uploadOcrImage(img io.Reader, size int32, sum []byte) (string
Uuid: binary.GenUUID(r), Uuid: binary.GenUUID(r),
}) })
rsp, err := c.highwaySession.UploadBDH(highway.Transaction{ rsp, err := c.highwaySession.Upload(highway.Transaction{
CommandID: 76, CommandID: 76,
Body: img, Body: img,
Size: int64(size), Size: int64(size),

View File

@ -3,7 +3,6 @@ package highway
import ( import (
"crypto/md5" "crypto/md5"
"io" "io"
"net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -36,29 +35,7 @@ func (bdh *Transaction) encrypt(key []byte) error {
return nil return nil
} }
func (s *Session) retry(upload func(s *Session, addr Addr, trans *Transaction) ([]byte, error), trans *Transaction) ([]byte, error) { func (s *Session) uploadSingle(trans Transaction) ([]byte, error) {
// try to find a available server
for _, addr := range s.SsoAddr {
r, err := upload(s, addr, trans)
if err == nil {
return r, nil
}
if _, ok := err.(net.Error); ok {
// try another server
// TODO: delete broken servers?
continue
}
return nil, err
}
return nil, errors.New("cannot found available server")
}
func (s *Session) UploadBDH(trans Transaction) ([]byte, error) {
// encrypt ext data
if err := trans.encrypt(s.SessionKey); err != nil {
return nil, err
}
pc, err := s.selectConn() pc, err := s.selectConn()
if err != nil { if err != nil {
return nil, err return nil, err
@ -66,7 +43,7 @@ func (s *Session) UploadBDH(trans Transaction) ([]byte, error) {
defer s.putIdleConn(pc) defer s.putIdleConn(pc)
reader := binary.NewNetworkReader(pc.conn) reader := binary.NewNetworkReader(pc.conn)
const chunkSize = 256 * 1024 const chunkSize = 128 * 1024
var rspExt []byte var rspExt []byte
offset := 0 offset := 0
chunk := make([]byte, chunkSize) chunk := make([]byte, chunkSize)
@ -105,7 +82,7 @@ func (s *Session) UploadBDH(trans Transaction) ([]byte, error) {
buffers := frame(head, chunk) buffers := frame(head, chunk)
_, err = buffers.WriteTo(pc.conn) _, err = buffers.WriteTo(pc.conn)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "write pc error") return nil, errors.Wrap(err, "write conn error")
} }
rspHead, err := readResponse(reader) rspHead, err := readResponse(reader)
if err != nil { if err != nil {
@ -124,18 +101,22 @@ func (s *Session) UploadBDH(trans Transaction) ([]byte, error) {
return rspExt, nil return rspExt, nil
} }
func (s *Session) UploadBDHMultiThread(trans Transaction) ([]byte, error) { func (s *Session) Upload(trans Transaction) ([]byte, error) {
// for small file and small thread count,
// use UploadBDH instead of UploadBDHMultiThread
if trans.Size < 1024*1024*3 {
return s.UploadBDH(trans)
}
// encrypt ext data // encrypt ext data
if err := trans.encrypt(s.SessionKey); err != nil { if err := trans.encrypt(s.SessionKey); err != nil {
return nil, err return nil, err
} }
const maxThreadCount = 4
threadCount := int(trans.Size) / (3 * 512 * 1024) // 1 thread upload 1.5 MB
if threadCount > maxThreadCount {
threadCount = maxThreadCount
}
if threadCount < 2 {
// single thread upload
return s.uploadSingle(trans)
}
// pick a address // pick a address
// TODO: pick smarter // TODO: pick smarter
pc, err := s.selectConn() pc, err := s.selectConn()
@ -145,9 +126,7 @@ func (s *Session) UploadBDHMultiThread(trans Transaction) ([]byte, error) {
addr := pc.addr addr := pc.addr
s.putIdleConn(pc) s.putIdleConn(pc)
// TODO: use idle conn
const blockSize int64 = 256 * 1024 const blockSize int64 = 256 * 1024
const threadCount = 4
var ( var (
rspExt []byte rspExt []byte
completedThread uint32 completedThread uint32
@ -163,11 +142,12 @@ func (s *Session) UploadBDHMultiThread(trans Transaction) ([]byte, error) {
cond.Signal() cond.Signal()
}() }()
// todo: get from pool?
pc, err := s.connect(addr) pc, err := s.connect(addr)
if err != nil { if err != nil {
return err return err
} }
// defer s.putIdleConn(pc) // TODO: should we put back? defer s.putIdleConn(pc)
reader := binary.NewNetworkReader(pc.conn) reader := binary.NewNetworkReader(pc.conn)
chunk := make([]byte, blockSize) chunk := make([]byte, blockSize)
@ -176,7 +156,8 @@ func (s *Session) UploadBDHMultiThread(trans Transaction) ([]byte, error) {
off := offset off := offset
offset += blockSize offset += blockSize
id++ id++
if int64(id) == count { // last last := int64(id) == count
if last { // last
for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) { for atomic.LoadUint32(&completedThread) != uint32(threadCount-1) {
cond.Wait() cond.Wait()
} }
@ -228,7 +209,7 @@ func (s *Session) UploadBDHMultiThread(trans Transaction) ([]byte, error) {
if rspHead.ErrorCode != 0 { if rspHead.ErrorCode != 0 {
return errors.Errorf("upload failed: %d", rspHead.ErrorCode) return errors.Errorf("upload failed: %d", rspHead.ErrorCode)
} }
if rspHead.RspExtendinfo != nil { if last && rspHead.RspExtendinfo != nil {
rspExt = rspHead.RspExtendinfo rspExt = rspHead.RspExtendinfo
} }
} }

View File

@ -135,7 +135,7 @@ type persistConn struct {
ping int64 // echo ping ping int64 // echo ping
} }
const maxIdleConn = 5 const maxIdleConn = 7
type idle struct { type idle struct {
pc persistConn pc persistConn
@ -207,6 +207,7 @@ func (s *Session) connect(addr Addr) (persistConn, error) {
if err != nil { if err != nil {
return persistConn{}, err return persistConn{}, err
} }
_ = conn.(*net.TCPConn).SetKeepAlive(true)
// close conn // close conn
runtime.SetFinalizer(conn, func(conn net.Conn) { runtime.SetFinalizer(conn, func(conn net.Conn) {

View File

@ -313,7 +313,7 @@ func (builder *ForwardMessageBuilder) Main(m *message.ForwardMessage) *message.F
Sum: bodyHash[:], Sum: bodyHash[:],
Size: int64(len(body)), Size: int64(len(body)),
} }
_, err = c.highwaySession.UploadBDH(input) _, err = c.highwaySession.Upload(input)
if err != nil { if err != nil {
return nil return nil
} }

View File

@ -73,7 +73,7 @@ func (c *QQClient) UploadVoice(target message.Source, voice io.ReadSeeker) (*mes
ext = c.buildGroupPttStoreBDHExt(target.PrimaryID, fh, int32(length), 0, int32(length)) ext = c.buildGroupPttStoreBDHExt(target.PrimaryID, fh, int32(length), 0, int32(length))
} }
// multi-thread upload is no need // multi-thread upload is no need
rsp, err := c.highwaySession.UploadBDH(highway.Transaction{ rsp, err := c.highwaySession.Upload(highway.Transaction{
CommandID: cmd, CommandID: cmd,
Body: voice, Body: voice,
Sum: fh, Sum: fh,
@ -121,8 +121,7 @@ func (c *QQClient) UploadVoice(target message.Source, voice io.ReadSeeker) (*mes
} }
// UploadShortVideo 将视频和封面上传到服务器, 返回 message.ShortVideoElement 可直接发送 // UploadShortVideo 将视频和封面上传到服务器, 返回 message.ShortVideoElement 可直接发送
// thread 上传线程数 func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadSeeker) (*message.ShortVideoElement, error) {
func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadSeeker, thread int) (*message.ShortVideoElement, error) {
thumbHash := md5.New() thumbHash := md5.New()
thumbLen, _ := io.Copy(thumbHash, thumb) thumbLen, _ := io.Copy(thumbHash, thumb)
thumbSum := thumbHash.Sum(nil) thumbSum := thumbHash.Sum(nil)
@ -168,11 +167,7 @@ func (c *QQClient) UploadShortVideo(target message.Source, video, thumb io.ReadS
Ext: ext, Ext: ext,
Encrypt: true, Encrypt: true,
} }
if thread > 1 { hwRsp, err = c.highwaySession.Upload(input)
hwRsp, err = c.highwaySession.UploadBDHMultiThread(input)
} else {
hwRsp, err = c.highwaySession.UploadBDH(input)
}
if err != nil { if err != nil {
return nil, errors.Wrap(err, "upload video file error") return nil, errors.Wrap(err, "upload video file error")
} }

View File

@ -147,7 +147,7 @@ func (c *QQClient) UploadFile(target message.Source, file *LocalFile) error {
if target.SourceType == message.SourcePrivate { if target.SourceType == message.SourcePrivate {
input.CommandID = 69 input.CommandID = 69
} }
if _, err := c.highwaySession.UploadBDHMultiThread(input); err != nil { if _, err := c.highwaySession.Upload(input); err != nil {
return errors.Wrap(err, "upload failed") return errors.Wrap(err, "upload failed")
} }
} }