From e7a34fdd90623cfb095b7b7bfbd3d7ac6c1f5e83 Mon Sep 17 00:00:00 2001 From: wdvxdr Date: Sun, 19 Feb 2023 23:06:15 +0800 Subject: [PATCH] client/internal/highway: use pool for multithread upload --- client/internal/highway/bdh.go | 30 +++++++++++++++--------------- client/internal/highway/highway.go | 13 ++++++++++++- client/upload_file.go | 2 +- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/client/internal/highway/bdh.go b/client/internal/highway/bdh.go index 909c97a2..cc9a21d2 100644 --- a/client/internal/highway/bdh.go +++ b/client/internal/highway/bdh.go @@ -6,7 +6,6 @@ import ( "net" "sync" "sync/atomic" - "time" "github.com/pkg/errors" "golang.org/x/sync/errgroup" @@ -66,9 +65,7 @@ func (s *Session) UploadBDH(trans Transaction) ([]byte, error) { } defer s.putIdleConn(pc) - conn := pc.conn - reader := binary.NewNetworkReader(conn) - + reader := binary.NewNetworkReader(pc.conn) const chunkSize = 256 * 1024 var rspExt []byte offset := 0 @@ -106,7 +103,7 @@ func (s *Session) UploadBDH(trans Transaction) ([]byte, error) { }) offset += rl buffers := frame(head, chunk) - _, err = buffers.WriteTo(conn) + _, err = buffers.WriteTo(pc.conn) if err != nil { return nil, errors.Wrap(err, "write pc error") } @@ -138,10 +135,16 @@ func (s *Session) UploadBDHMultiThread(trans Transaction) ([]byte, error) { if err := trans.encrypt(s.SessionKey); err != nil { return nil, err } - return s.retry(uploadBDHMultiThread, &trans) -} -func uploadBDHMultiThread(s *Session, addr Addr, trans *Transaction) ([]byte, error) { + // pick a address + // TODO: pick smarter + pc, err := s.selectConn() + if err != nil { + return nil, err + } + addr := pc.addr + s.putIdleConn(pc) + // TODO: use idle conn const blockSize int64 = 256 * 1024 const threadCount = 4 @@ -160,16 +163,13 @@ func uploadBDHMultiThread(s *Session, addr Addr, trans *Transaction) ([]byte, er cond.Signal() }() - conn, err := net.DialTimeout("tcp", addr.String(), time.Second*20) + pc, err := s.connect(addr) if err != nil { return err } - defer conn.Close() - reader := binary.NewNetworkReader(conn) - if err = s.ping(&persistConn{conn: conn}); err != nil { - return err - } + // defer s.putIdleConn(pc) // TODO: should we put back? + reader := binary.NewNetworkReader(pc.conn) chunk := make([]byte, blockSize) for { cond.L.Lock() // lock protect reading @@ -217,7 +217,7 @@ func uploadBDHMultiThread(s *Session, addr Addr, trans *Transaction) ([]byte, er ReqExtendinfo: trans.Ext, }) buffers := frame(head, chunk) - _, err = buffers.WriteTo(conn) + _, err = buffers.WriteTo(pc.conn) if err != nil { return errors.Wrap(err, "write conn error") } diff --git a/client/internal/highway/highway.go b/client/internal/highway/highway.go index f3928b51..8028e15c 100644 --- a/client/internal/highway/highway.go +++ b/client/internal/highway/highway.go @@ -34,6 +34,10 @@ func (a Addr) String() string { return fmt.Sprintf("%v:%v", binary.UInt32ToIPV4Address(a.IP), a.Port) } +func (a Addr) empty() bool { + return a.IP == 0 || a.Port == 0 +} + type Session struct { Uin string AppID int32 @@ -127,6 +131,7 @@ func readResponse(r *binary.NetworkReader) (*pb.RspDataHighwayHead, error) { type persistConn struct { conn net.Conn + addr Addr ping int64 // echo ping } @@ -162,6 +167,11 @@ func (s *Session) putIdleConn(pc persistConn) { s.idleMu.Lock() defer s.idleMu.Unlock() + // check persistConn + if pc.conn == nil || pc.addr.empty() { + panic("put bad idle conn") + } + cur := &idle{pc: pc} s.idleCount++ if s.idle == nil { // quick path @@ -203,7 +213,7 @@ func (s *Session) connect(addr Addr) (persistConn, error) { _ = conn.Close() }) - pc := persistConn{conn: conn} + pc := persistConn{conn: conn, addr: addr} if err = s.ping(&pc); err != nil { return persistConn{}, err } @@ -225,6 +235,7 @@ func (s *Session) selectConn() (pc persistConn, err error) { // no idle connection break } + err = s.ping(&pc) // ping if err == nil { return diff --git a/client/upload_file.go b/client/upload_file.go index febac67a..4f7fbacf 100644 --- a/client/upload_file.go +++ b/client/upload_file.go @@ -147,7 +147,7 @@ func (c *QQClient) UploadFile(target message.Source, file *LocalFile) error { if target.SourceType == message.SourcePrivate { input.CommandID = 69 } - if _, err := c.highwaySession.UploadBDH(input); err != nil { + if _, err := c.highwaySession.UploadBDHMultiThread(input); err != nil { return errors.Wrap(err, "upload failed") } }