diff --git a/client/internal/highway/bdh.go b/client/internal/highway/bdh.go index 36c05dd7..3a045cbd 100644 --- a/client/internal/highway/bdh.go +++ b/client/internal/highway/bdh.go @@ -59,20 +59,15 @@ func (s *Session) UploadBDH(trans Transaction) ([]byte, error) { if err := trans.encrypt(s.SessionKey); err != nil { return nil, err } - return s.retry(uploadBDH, &trans) -} -func uploadBDH(s *Session, addr Addr, trans *Transaction) ([]byte, error) { - conn, err := net.DialTimeout("tcp", addr.String(), time.Second*20) + pc, err := s.selectConn() if err != nil { return nil, err } - defer conn.Close() + defer s.putIdleConn(pc) + conn := pc.conn reader := binary.NewNetworkReader(conn) - if err = s.sendEcho(conn); err != nil { - return nil, err - } const chunkSize = 256 * 1024 var rspExt []byte @@ -113,7 +108,7 @@ func uploadBDH(s *Session, addr Addr, trans *Transaction) ([]byte, error) { buffers := frame(head, chunk) _, err = buffers.WriteTo(conn) if err != nil { - return nil, errors.Wrap(err, "write conn error") + return nil, errors.Wrap(err, "write pc error") } rspHead, err := readResponse(reader) if err != nil { @@ -147,6 +142,7 @@ func (s *Session) UploadBDHMultiThread(trans Transaction) ([]byte, error) { } func uploadBDHMultiThread(s *Session, addr Addr, trans *Transaction) ([]byte, error) { + // TODO: use idle conn const blockSize int64 = 256 * 1024 const threadCount = 4 var ( @@ -170,7 +166,7 @@ func uploadBDHMultiThread(s *Session, addr Addr, trans *Transaction) ([]byte, er } defer conn.Close() reader := binary.NewNetworkReader(conn) - if err = s.sendEcho(conn); err != nil { + if err = s.ping(&persistConn{conn: conn}); err != nil { return err } @@ -223,7 +219,7 @@ func uploadBDHMultiThread(s *Session, addr Addr, trans *Transaction) ([]byte, er buffers := frame(head, chunk) _, err = buffers.WriteTo(conn) if err != nil { - return errors.Wrap(err, "write conn error") + return errors.Wrap(err, "write pc error") } rspHead, err := readResponse(reader) if err != nil { diff --git a/client/internal/highway/highway.go b/client/internal/highway/highway.go index c34fe8fc..311e6a0c 100644 --- a/client/internal/highway/highway.go +++ b/client/internal/highway/highway.go @@ -3,7 +3,10 @@ package highway import ( "fmt" "net" + "runtime" + "sync" "sync/atomic" + "time" "github.com/pkg/errors" @@ -36,23 +39,28 @@ type Session struct { AppID int32 SigSession []byte SessionKey []byte - SsoAddr []Addr - seq int32 - /* - idleMu sync.Mutex - idleCount int - idle *idle - */ + addrMu sync.Mutex + idx int + SsoAddr []Addr + + seq int32 + idleMu sync.Mutex + idleCount int + idle *idle } const highwayMaxResponseSize int32 = 1024 * 100 // 100k func (s *Session) AddrLength() int { + s.addrMu.Lock() + defer s.addrMu.Unlock() return len(s.SsoAddr) } func (s *Session) AppendAddr(ip, port uint32) { + s.addrMu.Lock() + defer s.addrMu.Unlock() addr := Addr{ IP: ip, Port: int(port), @@ -82,14 +90,17 @@ func (s *Session) sendHeartbreak(conn net.Conn) error { return err } -func (s *Session) sendEcho(conn net.Conn) error { - err := s.sendHeartbreak(conn) +func (s *Session) ping(pc *persistConn) error { + start := time.Now() + err := s.sendHeartbreak(pc.conn) if err != nil { return errors.Wrap(err, "echo error") } - if _, err = readResponse(binary.NewNetworkReader(conn)); err != nil { + if _, err = readResponse(binary.NewNetworkReader(pc.conn)); err != nil { return errors.Wrap(err, "echo error") } + // update delay + pc.ping = time.Since(start).Milliseconds() return nil } @@ -113,22 +124,120 @@ func readResponse(r *binary.NetworkReader) (*pb.RspDataHighwayHead, error) { return rsp, nil } -/* +type persistConn struct { + conn net.Conn + ping int64 // echo ping +} + const maxIdleConn = 5 type idle struct { - conn net.Conn - delay int64 - next *idle + pc persistConn + next *idle } -// getConn ... -func (s *Session) getConn() net.Conn { +// getIdleConn ... +func (s *Session) getIdleConn() persistConn { s.idleMu.Lock() defer s.idleMu.Unlock() - conn := s.idle.conn + // no idle + if s.idle == nil { + return persistConn{} + } + + // switch the fastest idle conn + conn := s.idle.pc s.idle = s.idle.next + s.idleCount-- + if s.idleCount < 0 { + panic("idle count underflow") + } + return conn } -*/ + +func (s *Session) putIdleConn(pc persistConn) { + s.idleMu.Lock() + defer s.idleMu.Unlock() + + cur := &idle{pc: pc} + s.idleCount++ + if s.idle == nil { // quick path + s.idle = cur + return + } + + // insert between pre and succ + var pre, succ *idle + succ = s.idle + for succ != nil && succ.pc.ping < pc.ping { // keep idle list sorted by delay incremental + pre = succ + succ = succ.next + } + if pre != nil { + pre.next = cur + } + cur.next = succ + + // remove the slowest idle conn if idle count greater than maxIdleConn + if s.idleCount > maxIdleConn { + for cur.next != nil { + pre = cur + cur = cur.next + } + pre.next = nil + s.idleCount-- + } +} + +func (s *Session) connect(addr Addr) (persistConn, error) { + conn, err := net.DialTimeout("tcp", addr.String(), time.Second*3) + if err != nil { + return persistConn{}, err + } + + // close conn + runtime.SetFinalizer(conn, func(conn net.Conn) { + _ = conn.Close() + }) + + pc := persistConn{conn: conn} + if err = s.ping(&pc); err != nil { + return persistConn{}, err + } + return pc, nil +} + +func (s *Session) selectConn() (pc persistConn, err error) { + for { // select from idle pc + pc = s.getIdleConn() + if pc.conn == nil { + // no idle connection + break + } + err = s.ping(&pc) // ping + if err == nil { + return + } + } + + s.addrMu.Lock() + defer s.addrMu.Unlock() + try := 0 + idx := s.idx + for { + addr := s.SsoAddr[idx] + idx = (idx + 1) % len(s.SsoAddr) // move next + pc, err = s.connect(addr) + if err == nil { + break + } + try++ + if try > 5 { + break + } + } + s.idx = idx + return +}