diff --git a/utils/connection.go b/utils/connection.go index 92485000..f7e3b1de 100644 --- a/utils/connection.go +++ b/utils/connection.go @@ -21,11 +21,15 @@ var ErrConnectionClosed = errors.New("connection closed") // PlannedDisconnect 预料中的断开连接 // 如调用 Close() Connect() func (t *TCPListener) PlannedDisconnect(f func(*TCPListener)) { + t.lock.Lock() + defer t.lock.Unlock() t.plannedDisconnect = f } // UnexpectedDisconnect 未预料钟的断开连接 func (t *TCPListener) UnexpectedDisconnect(f func(*TCPListener, error)) { + t.lock.Lock() + defer t.lock.Unlock() t.unexpectedDisconnect = f } @@ -42,30 +46,30 @@ func (t *TCPListener) Connect(addr *net.TCPAddr) error { } func (t *TCPListener) Write(buf []byte) error { - err := t.rLockDo(func() error { - _, e := t.conn.Write(buf) - return e - }) - if err == nil { + if conn := t.getConn(); conn != nil { + _, err := conn.Write(buf) + if err != nil { + t.unexpectedClose(err) + return ErrConnectionClosed + } return nil } - t.unexpectedClose(err) return ErrConnectionClosed } func (t *TCPListener) ReadBytes(len int) ([]byte, error) { buf := make([]byte, len) - err := t.rLockDo(func() error { - _, e := io.ReadFull(t.conn, buf) - return e - }) - if err == nil { + if conn := t.getConn(); conn != nil { + _, err := io.ReadFull(conn, buf) + if err != nil { + // time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁 + t.unexpectedClose(err) + return nil, ErrConnectionClosed + } return buf, nil } - //time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁 - t.unexpectedClose(err) return nil, ErrConnectionClosed } @@ -78,17 +82,13 @@ func (t *TCPListener) ReadInt32() (int32, error) { } func (t *TCPListener) Close() { - if !t.connIsNil() { - t.close() - t.invokePlannedDisconnect() - } + t.close() + t.invokePlannedDisconnect() } func (t *TCPListener) unexpectedClose(err error) { - if !t.connIsNil() { - t.close() - t.invokeUnexpectedDisconnect(err) - } + t.close() + t.invokeUnexpectedDisconnect(err) } func (t *TCPListener) close() { @@ -101,28 +101,23 @@ func (t *TCPListener) close() { } func (t *TCPListener) invokePlannedDisconnect() { + t.lock.RLock() + defer t.lock.RUnlock() if t.plannedDisconnect != nil { go t.plannedDisconnect(t) } } func (t *TCPListener) invokeUnexpectedDisconnect(err error) { + t.lock.RLock() + defer t.lock.RUnlock() if t.unexpectedDisconnect != nil { go t.unexpectedDisconnect(t, err) } } -func (t *TCPListener) rLockDo(fn func() error) error { +func (t *TCPListener) getConn() net.Conn { t.lock.RLock() defer t.lock.RUnlock() - if t.conn == nil { - return ErrConnectionClosed - } - return fn() -} - -func (t *TCPListener) connIsNil() bool { - t.lock.RLock() - defer t.lock.RUnlock() - return t.conn == nil + return t.conn }