diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 73f69e09..00000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml -# Editor-based HTTP Client requests -/httpRequests/ diff --git a/utils/connection.go b/utils/connection.go index 5d758622..92485000 100644 --- a/utils/connection.go +++ b/utils/connection.go @@ -4,12 +4,13 @@ import ( "encoding/binary" "io" "net" - "time" + "sync" "github.com/pkg/errors" ) type TCPListener struct { + lock sync.RWMutex conn net.Conn plannedDisconnect func(*TCPListener) unexpectedDisconnect func(*TCPListener, error) @@ -34,40 +35,38 @@ func (t *TCPListener) Connect(addr *net.TCPAddr) error { if err != nil { return errors.Wrap(err, "dial tcp error") } + t.lock.Lock() + defer t.lock.Unlock() t.conn = conn return nil } func (t *TCPListener) Write(buf []byte) error { - if t.conn == nil { - return ErrConnectionClosed + err := t.rLockDo(func() error { + _, e := t.conn.Write(buf) + return e + }) + if err == nil { + return nil } - _, err := t.conn.Write(buf) - if err != nil { - if t.conn != nil { - t.close() - t.invokeUnexpectedDisconnect(err) - } - return ErrConnectionClosed - } - return nil + + t.unexpectedClose(err) + return ErrConnectionClosed } func (t *TCPListener) ReadBytes(len int) ([]byte, error) { - if t.conn == nil { - return nil, ErrConnectionClosed - } buf := make([]byte, len) - _, err := io.ReadFull(t.conn, buf) - if err != nil { - time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁 - if t.conn != nil { - t.close() - t.invokeUnexpectedDisconnect(err) - } - return nil, ErrConnectionClosed + err := t.rLockDo(func() error { + _, e := io.ReadFull(t.conn, buf) + return e + }) + if err == nil { + return buf, nil } - return buf, nil + + //time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁 + t.unexpectedClose(err) + return nil, ErrConnectionClosed } func (t *TCPListener) ReadInt32() (int32, error) { @@ -79,19 +78,26 @@ func (t *TCPListener) ReadInt32() (int32, error) { } func (t *TCPListener) Close() { - if t.conn == nil { - return + if !t.connIsNil() { + t.close() + t.invokePlannedDisconnect() + } +} + +func (t *TCPListener) unexpectedClose(err error) { + if !t.connIsNil() { + t.close() + t.invokeUnexpectedDisconnect(err) } - t.close() - t.invokePlannedDisconnect() } func (t *TCPListener) close() { - if t.conn == nil { - return + t.lock.Lock() + defer t.lock.Unlock() + if t.conn != nil { + _ = t.conn.Close() + t.conn = nil } - _ = t.conn.Close() - t.conn = nil } func (t *TCPListener) invokePlannedDisconnect() { @@ -105,3 +111,18 @@ func (t *TCPListener) invokeUnexpectedDisconnect(err error) { go t.unexpectedDisconnect(t, err) } } + +func (t *TCPListener) rLockDo(fn func() error) error { + t.lock.RLock() + defer t.lock.RUnlock() + if t.conn == nil { + return ErrConnectionClosed + } + return fn() +} + +func (t *TCPListener) connIsNil() bool { + t.lock.RLock() + defer t.lock.RUnlock() + return t.conn == nil +}