1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 11:07:40 +08:00

fix(conn): fix deadlock.

Fixes: #179
This commit is contained in:
wdvxdr 2021-08-15 20:38:19 +08:00
parent b8399fa706
commit bebbc7f33b
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6

View File

@ -21,11 +21,15 @@ var ErrConnectionClosed = errors.New("connection closed")
// PlannedDisconnect 预料中的断开连接
// 如调用 Close() Connect()
func (t *TCPListener) PlannedDisconnect(f func(*TCPListener)) {
t.lock.Lock()
defer t.lock.Unlock()
t.plannedDisconnect = f
}
// UnexpectedDisconnect 未预料钟的断开连接
func (t *TCPListener) UnexpectedDisconnect(f func(*TCPListener, error)) {
t.lock.Lock()
defer t.lock.Unlock()
t.unexpectedDisconnect = f
}
@ -42,30 +46,30 @@ func (t *TCPListener) Connect(addr *net.TCPAddr) error {
}
func (t *TCPListener) Write(buf []byte) error {
err := t.rLockDo(func() error {
_, e := t.conn.Write(buf)
return e
})
if err == nil {
if conn := t.getConn(); conn != nil {
_, err := conn.Write(buf)
if err != nil {
t.unexpectedClose(err)
return ErrConnectionClosed
}
return nil
}
t.unexpectedClose(err)
return ErrConnectionClosed
}
func (t *TCPListener) ReadBytes(len int) ([]byte, error) {
buf := make([]byte, len)
err := t.rLockDo(func() error {
_, e := io.ReadFull(t.conn, buf)
return e
})
if err == nil {
if conn := t.getConn(); conn != nil {
_, err := io.ReadFull(conn, buf)
if err != nil {
// time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁
t.unexpectedClose(err)
return nil, ErrConnectionClosed
}
return buf, nil
}
//time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁
t.unexpectedClose(err)
return nil, ErrConnectionClosed
}
@ -78,18 +82,14 @@ func (t *TCPListener) ReadInt32() (int32, error) {
}
func (t *TCPListener) Close() {
if !t.connIsNil() {
t.close()
t.invokePlannedDisconnect()
}
}
func (t *TCPListener) unexpectedClose(err error) {
if !t.connIsNil() {
t.close()
t.invokeUnexpectedDisconnect(err)
}
}
func (t *TCPListener) close() {
t.lock.Lock()
@ -101,28 +101,23 @@ func (t *TCPListener) close() {
}
func (t *TCPListener) invokePlannedDisconnect() {
t.lock.RLock()
defer t.lock.RUnlock()
if t.plannedDisconnect != nil {
go t.plannedDisconnect(t)
}
}
func (t *TCPListener) invokeUnexpectedDisconnect(err error) {
t.lock.RLock()
defer t.lock.RUnlock()
if t.unexpectedDisconnect != nil {
go t.unexpectedDisconnect(t, err)
}
}
func (t *TCPListener) rLockDo(fn func() error) error {
func (t *TCPListener) getConn() net.Conn {
t.lock.RLock()
defer t.lock.RUnlock()
if t.conn == nil {
return ErrConnectionClosed
}
return fn()
}
func (t *TCPListener) connIsNil() bool {
t.lock.RLock()
defer t.lock.RUnlock()
return t.conn == nil
return t.conn
}