1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 11:07:40 +08:00

client/internal/highway: use pool for multithread upload

This commit is contained in:
wdvxdr 2023-02-19 23:06:15 +08:00
parent 3b97ce341b
commit e7a34fdd90
3 changed files with 28 additions and 17 deletions

View File

@ -6,7 +6,6 @@ import (
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -66,9 +65,7 @@ func (s *Session) UploadBDH(trans Transaction) ([]byte, error) {
} }
defer s.putIdleConn(pc) defer s.putIdleConn(pc)
conn := pc.conn reader := binary.NewNetworkReader(pc.conn)
reader := binary.NewNetworkReader(conn)
const chunkSize = 256 * 1024 const chunkSize = 256 * 1024
var rspExt []byte var rspExt []byte
offset := 0 offset := 0
@ -106,7 +103,7 @@ func (s *Session) UploadBDH(trans Transaction) ([]byte, error) {
}) })
offset += rl offset += rl
buffers := frame(head, chunk) buffers := frame(head, chunk)
_, err = buffers.WriteTo(conn) _, err = buffers.WriteTo(pc.conn)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "write pc error") return nil, errors.Wrap(err, "write pc error")
} }
@ -138,10 +135,16 @@ func (s *Session) UploadBDHMultiThread(trans Transaction) ([]byte, error) {
if err := trans.encrypt(s.SessionKey); err != nil { if err := trans.encrypt(s.SessionKey); err != nil {
return nil, err return nil, err
} }
return s.retry(uploadBDHMultiThread, &trans)
}
func uploadBDHMultiThread(s *Session, addr Addr, trans *Transaction) ([]byte, error) { // pick a address
// TODO: pick smarter
pc, err := s.selectConn()
if err != nil {
return nil, err
}
addr := pc.addr
s.putIdleConn(pc)
// TODO: use idle conn // TODO: use idle conn
const blockSize int64 = 256 * 1024 const blockSize int64 = 256 * 1024
const threadCount = 4 const threadCount = 4
@ -160,16 +163,13 @@ func uploadBDHMultiThread(s *Session, addr Addr, trans *Transaction) ([]byte, er
cond.Signal() cond.Signal()
}() }()
conn, err := net.DialTimeout("tcp", addr.String(), time.Second*20) pc, err := s.connect(addr)
if err != nil { if err != nil {
return err return err
} }
defer conn.Close() // defer s.putIdleConn(pc) // TODO: should we put back?
reader := binary.NewNetworkReader(conn)
if err = s.ping(&persistConn{conn: conn}); err != nil {
return err
}
reader := binary.NewNetworkReader(pc.conn)
chunk := make([]byte, blockSize) chunk := make([]byte, blockSize)
for { for {
cond.L.Lock() // lock protect reading cond.L.Lock() // lock protect reading
@ -217,7 +217,7 @@ func uploadBDHMultiThread(s *Session, addr Addr, trans *Transaction) ([]byte, er
ReqExtendinfo: trans.Ext, ReqExtendinfo: trans.Ext,
}) })
buffers := frame(head, chunk) buffers := frame(head, chunk)
_, err = buffers.WriteTo(conn) _, err = buffers.WriteTo(pc.conn)
if err != nil { if err != nil {
return errors.Wrap(err, "write conn error") return errors.Wrap(err, "write conn error")
} }

View File

@ -34,6 +34,10 @@ func (a Addr) String() string {
return fmt.Sprintf("%v:%v", binary.UInt32ToIPV4Address(a.IP), a.Port) return fmt.Sprintf("%v:%v", binary.UInt32ToIPV4Address(a.IP), a.Port)
} }
func (a Addr) empty() bool {
return a.IP == 0 || a.Port == 0
}
type Session struct { type Session struct {
Uin string Uin string
AppID int32 AppID int32
@ -127,6 +131,7 @@ func readResponse(r *binary.NetworkReader) (*pb.RspDataHighwayHead, error) {
type persistConn struct { type persistConn struct {
conn net.Conn conn net.Conn
addr Addr
ping int64 // echo ping ping int64 // echo ping
} }
@ -162,6 +167,11 @@ func (s *Session) putIdleConn(pc persistConn) {
s.idleMu.Lock() s.idleMu.Lock()
defer s.idleMu.Unlock() defer s.idleMu.Unlock()
// check persistConn
if pc.conn == nil || pc.addr.empty() {
panic("put bad idle conn")
}
cur := &idle{pc: pc} cur := &idle{pc: pc}
s.idleCount++ s.idleCount++
if s.idle == nil { // quick path if s.idle == nil { // quick path
@ -203,7 +213,7 @@ func (s *Session) connect(addr Addr) (persistConn, error) {
_ = conn.Close() _ = conn.Close()
}) })
pc := persistConn{conn: conn} pc := persistConn{conn: conn, addr: addr}
if err = s.ping(&pc); err != nil { if err = s.ping(&pc); err != nil {
return persistConn{}, err return persistConn{}, err
} }
@ -225,6 +235,7 @@ func (s *Session) selectConn() (pc persistConn, err error) {
// no idle connection // no idle connection
break break
} }
err = s.ping(&pc) // ping err = s.ping(&pc) // ping
if err == nil { if err == nil {
return return

View File

@ -147,7 +147,7 @@ func (c *QQClient) UploadFile(target message.Source, file *LocalFile) error {
if target.SourceType == message.SourcePrivate { if target.SourceType == message.SourcePrivate {
input.CommandID = 69 input.CommandID = 69
} }
if _, err := c.highwaySession.UploadBDH(input); err != nil { if _, err := c.highwaySession.UploadBDHMultiThread(input); err != nil {
return errors.Wrap(err, "upload failed") return errors.Wrap(err, "upload failed")
} }
} }