1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 19:17:38 +08:00

fix(conn): fix deadlock.

Fixes: #179
This commit is contained in:
wdvxdr 2021-08-15 20:38:19 +08:00
parent b8399fa706
commit bebbc7f33b
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6

View File

@ -21,11 +21,15 @@ var ErrConnectionClosed = errors.New("connection closed")
// PlannedDisconnect 预料中的断开连接 // PlannedDisconnect 预料中的断开连接
// 如调用 Close() Connect() // 如调用 Close() Connect()
func (t *TCPListener) PlannedDisconnect(f func(*TCPListener)) { func (t *TCPListener) PlannedDisconnect(f func(*TCPListener)) {
t.lock.Lock()
defer t.lock.Unlock()
t.plannedDisconnect = f t.plannedDisconnect = f
} }
// UnexpectedDisconnect 未预料钟的断开连接 // UnexpectedDisconnect 未预料钟的断开连接
func (t *TCPListener) UnexpectedDisconnect(f func(*TCPListener, error)) { func (t *TCPListener) UnexpectedDisconnect(f func(*TCPListener, error)) {
t.lock.Lock()
defer t.lock.Unlock()
t.unexpectedDisconnect = f t.unexpectedDisconnect = f
} }
@ -42,30 +46,30 @@ func (t *TCPListener) Connect(addr *net.TCPAddr) error {
} }
func (t *TCPListener) Write(buf []byte) error { func (t *TCPListener) Write(buf []byte) error {
err := t.rLockDo(func() error { if conn := t.getConn(); conn != nil {
_, e := t.conn.Write(buf) _, err := conn.Write(buf)
return e if err != nil {
}) t.unexpectedClose(err)
if err == nil { return ErrConnectionClosed
}
return nil return nil
} }
t.unexpectedClose(err)
return ErrConnectionClosed return ErrConnectionClosed
} }
func (t *TCPListener) ReadBytes(len int) ([]byte, error) { func (t *TCPListener) ReadBytes(len int) ([]byte, error) {
buf := make([]byte, len) buf := make([]byte, len)
err := t.rLockDo(func() error { if conn := t.getConn(); conn != nil {
_, e := io.ReadFull(t.conn, buf) _, err := io.ReadFull(conn, buf)
return e if err != nil {
}) // time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁
if err == nil { t.unexpectedClose(err)
return nil, ErrConnectionClosed
}
return buf, nil return buf, nil
} }
//time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁
t.unexpectedClose(err)
return nil, ErrConnectionClosed return nil, ErrConnectionClosed
} }
@ -78,18 +82,14 @@ func (t *TCPListener) ReadInt32() (int32, error) {
} }
func (t *TCPListener) Close() { func (t *TCPListener) Close() {
if !t.connIsNil() {
t.close() t.close()
t.invokePlannedDisconnect() t.invokePlannedDisconnect()
} }
}
func (t *TCPListener) unexpectedClose(err error) { func (t *TCPListener) unexpectedClose(err error) {
if !t.connIsNil() {
t.close() t.close()
t.invokeUnexpectedDisconnect(err) t.invokeUnexpectedDisconnect(err)
} }
}
func (t *TCPListener) close() { func (t *TCPListener) close() {
t.lock.Lock() t.lock.Lock()
@ -101,28 +101,23 @@ func (t *TCPListener) close() {
} }
func (t *TCPListener) invokePlannedDisconnect() { func (t *TCPListener) invokePlannedDisconnect() {
t.lock.RLock()
defer t.lock.RUnlock()
if t.plannedDisconnect != nil { if t.plannedDisconnect != nil {
go t.plannedDisconnect(t) go t.plannedDisconnect(t)
} }
} }
func (t *TCPListener) invokeUnexpectedDisconnect(err error) { func (t *TCPListener) invokeUnexpectedDisconnect(err error) {
t.lock.RLock()
defer t.lock.RUnlock()
if t.unexpectedDisconnect != nil { if t.unexpectedDisconnect != nil {
go t.unexpectedDisconnect(t, err) go t.unexpectedDisconnect(t, err)
} }
} }
func (t *TCPListener) rLockDo(fn func() error) error { func (t *TCPListener) getConn() net.Conn {
t.lock.RLock() t.lock.RLock()
defer t.lock.RUnlock() defer t.lock.RUnlock()
if t.conn == nil { return t.conn
return ErrConnectionClosed
}
return fn()
}
func (t *TCPListener) connIsNil() bool {
t.lock.RLock()
defer t.lock.RUnlock()
return t.conn == nil
} }