1
0
mirror of https://github.com/Mrs4s/MiraiGo.git synced 2025-05-04 19:17:38 +08:00
LXY f5b16b19c2
Network refactor (#229)
* 重构:netLoop下移到Listener

* 格式:修正NewClient中注释段缩进

* 格式:更名&将代码移动到对应位置

* 格式:TCPListener去锁化

* 修正:netLoop中的错误调用

* 修正:使其可用

* 修正:使功能一致

* 修正:现在可正常运行

* 优化:更早的释放锁(?

* 修正:未写完的部分

* 修正:潜在的断线时仍然认为在线这件事&删除空重复文件

* 文档:添加部分注释

* 修正:CoverError可能引起死锁

* 修正:永远不会被触发的DisconnectEvent

* 文档:将注释移动至对应位置
2021-12-27 15:36:52 +08:00

173 lines
3.8 KiB
Go

package network
import (
"encoding/binary"
"io"
"net"
"sync"
"sync/atomic"
"unsafe"
"github.com/pkg/errors"
)
// TCPListener owns at most one outbound TCP connection and reports when it goes
// away. The connection pointer is only accessed through atomic operations
// (getConn, setConn, closeConn).
type TCPListener struct {
	// conn is the active connection, or nil when disconnected. Do not read or
	// assign it directly; use the atomic helpers.
	conn *net.TCPConn

	// PlannedDisconnect is invoked on an expected disconnect, e.g. from
	// Close() or Connect().
	PlannedDisconnect func(*TCPListener)

	// UnexpectedDisconnect is invoked when a read or write on the connection
	// fails, together with the error that caused it.
	UnexpectedDisconnect func(*TCPListener, error)
}
// getConn atomically loads the active connection; nil means not connected.
func (t *TCPListener) getConn() *net.TCPConn {
	slot := (*unsafe.Pointer)(unsafe.Pointer(&t.conn))
	current := atomic.LoadPointer(slot)
	return (*net.TCPConn)(current)
}
// setConn installs conn only when no connection is currently held (atomic
// compare-and-swap against nil) and reports whether it was installed.
func (t *TCPListener) setConn(conn *net.TCPConn) (swapped bool) {
	slot := (*unsafe.Pointer)(unsafe.Pointer(&t.conn))
	swapped = atomic.CompareAndSwapPointer(slot, unsafe.Pointer(nil), unsafe.Pointer(conn))
	return swapped
}
// closeConn atomically detaches the active connection, leaving nil behind, and
// returns whatever was held (possibly nil). It does not close the socket; the
// caller is responsible for that.
func (t *TCPListener) closeConn() *net.TCPConn {
	slot := (*unsafe.Pointer)(unsafe.Pointer(&t.conn))
	previous := atomic.SwapPointer(slot, unsafe.Pointer(nil))
	return (*net.TCPConn)(previous)
}
// Connected reports whether the listener currently holds a connection.
func (t *TCPListener) Connected() bool {
	current := t.getConn()
	return current != nil
}
// Connect drops any existing connection (which triggers the planned-disconnect
// path via Close) and dials addr.
//
// It returns the dial error wrapped with context. If another goroutine
// installed a connection while this one was dialing, the newly dialed socket
// is closed instead of being leaked, and an error is returned because the
// listener is not connected to addr.
func (t *TCPListener) Connect(addr *net.TCPAddr) error {
	t.Close()
	conn, err := net.DialTCP("tcp", nil, addr)
	if err != nil {
		return errors.Wrap(err, "dial tcp error")
	}
	if !t.setConn(conn) {
		// setConn only swaps from nil, so someone else won the race; discard ours.
		_ = conn.Close()
		return errors.New("connect: listener already has an active connection")
	}
	return nil
}
// ConnectFastest dials every address in addr concurrently and keeps the first
// connection that can be installed; every later successful dial is closed.
//
// It returns the remote address of the kept connection. It returns an error if
// no connection could be installed, which covers every dial failing, an empty
// addr, or the listener already holding a connection, since setConn only swaps
// from nil.
//
// TODO: disable (skip) servers that are known to be unavailable.
func (t *TCPListener) ConnectFastest(addr []*net.TCPAddr) (*net.TCPAddr, error) {
	// Buffered to len(addr) so a dialer never blocks on send, even if the first
	// winner is dropped and a later dialer claims the now-empty slot.
	won := make(chan *net.TCPAddr, len(addr))
	var wg sync.WaitGroup
	wg.Add(len(addr))
	for _, remote := range addr {
		go func(remote *net.TCPAddr) {
			defer wg.Done()
			conn, err := net.DialTCP("tcp", nil, remote)
			if err != nil {
				return
			}
			raddr := conn.RemoteAddr().(*net.TCPAddr)
			if !t.setConn(conn) {
				// Another connection already occupies the slot; discard this one.
				_ = conn.Close()
				return
			}
			won <- raddr
		}(remote)
	}
	// Closing once all dialers are done lets the receive below observe
	// "nobody won" instead of blocking forever. This goroutine always exits.
	go func() {
		wg.Wait()
		close(won)
	}()
	raddr, ok := <-won
	if !ok {
		return nil, errors.New("All addr are unreachable")
	}
	return raddr, nil
}
// Write sends buf on the active connection.
//
// It returns ErrConnectionBroken if there is no connection or the write fails.
// A failed write detaches and closes the connection and reports the cause via
// invokeUnexpectedDisconnect, but only if the failing connection is still the
// active one. If it was already detached (e.g. by Close or Connect on another
// goroutine), the failure is expected. In that case it must neither close a
// newer connection nor be reported as unexpected.
func (t *TCPListener) Write(buf []byte) error {
	conn := t.getConn()
	if conn == nil {
		return ErrConnectionBroken
	}
	if _, err := conn.Write(buf); err != nil {
		slot := (*unsafe.Pointer)(unsafe.Pointer(&t.conn))
		if atomic.CompareAndSwapPointer(slot, unsafe.Pointer(conn), unsafe.Pointer(nil)) {
			_ = conn.Close()
			t.invokeUnexpectedDisconnect(err)
		}
		return ErrConnectionBroken
	}
	return nil
}
// ReadBytes reads exactly n bytes from the active connection.
//
// It returns ErrConnectionBroken if there is no connection or the read fails,
// including a short read (io.ReadFull). On failure the connection is detached
// and closed, and the cause is reported via invokeUnexpectedDisconnect. This
// only happens if the failing connection is still the active one. If Close or
// Connect already detached it, the read error is expected and must neither
// close a newer connection nor be reported as unexpected.
func (t *TCPListener) ReadBytes(n int) ([]byte, error) {
	conn := t.getConn()
	if conn == nil {
		return nil, ErrConnectionBroken
	}
	buf := make([]byte, n)
	if _, err := io.ReadFull(conn, buf); err != nil {
		// The server may send an offline packet and close the connection right
		// away; that packet may not have been parsed yet at this point, so a
		// lock might still be needed here.
		// time.Sleep(time.Millisecond * 100)
		slot := (*unsafe.Pointer)(unsafe.Pointer(&t.conn))
		if atomic.CompareAndSwapPointer(slot, unsafe.Pointer(conn), unsafe.Pointer(nil)) {
			_ = conn.Close()
			t.invokeUnexpectedDisconnect(err)
		}
		return nil, ErrConnectionBroken
	}
	return buf, nil
}
// ReadInt32 reads four bytes and decodes them as a big-endian signed 32-bit
// integer. Any error from ReadBytes is returned unchanged, with a zero value.
func (t *TCPListener) ReadInt32() (int32, error) {
	raw, err := t.ReadBytes(4)
	if err == nil {
		return int32(binary.BigEndian.Uint32(raw)), nil
	}
	return 0, err
}
// Close detaches and closes the active connection (if any), then always runs
// the planned-disconnect path, even when nothing was connected.
func (t *TCPListener) Close() {
	if conn := t.closeConn(); conn != nil {
		_ = conn.Close()
	}
	t.invokePlannedDisconnect()
}
// unexpectedClose detaches and closes the active connection (if any), then
// reports err through the unexpected-disconnect path.
func (t *TCPListener) unexpectedClose(err error) {
	if conn := t.closeConn(); conn != nil {
		_ = conn.Close()
	}
	t.invokeUnexpectedDisconnect(err)
}
// close detaches the active connection and closes its socket; it is a no-op
// when nothing is connected. The socket's Close error is deliberately ignored.
func (t *TCPListener) close() {
	conn := t.closeConn()
	if conn == nil {
		return
	}
	_ = conn.Close()
}
// invokePlannedDisconnect calls PlannedDisconnect synchronously if it is set.
// Close (and therefore Connect) calls this whether or not a connection was
// open. The nil check keeps a listener without the callback configured from
// panicking.
func (t *TCPListener) invokePlannedDisconnect() {
	if cb := t.PlannedDisconnect; cb != nil {
		cb(t)
	}
}
// invokeUnexpectedDisconnect calls UnexpectedDisconnect synchronously with err
// if it is set. The nil check keeps a listener without the callback configured
// from panicking when a read or write fails.
func (t *TCPListener) invokeUnexpectedDisconnect(err error) {
	if cb := t.UnexpectedDisconnect; cb != nil {
		cb(t, err)
	}
}
//func (t *TCPListener) getConn() net.Conn {
// t.lock.RLock()
// defer t.lock.RUnlock()
// return t.conn
//}