mirror of
https://github.com/Mrs4s/MiraiGo.git
synced 2025-05-04 19:17:38 +08:00
* 重构:netLoop下移到Listener * 格式:修正NewClient中注释段缩进 * 格式:更名&将代码移动到对应位置 * 格式:TCPListener去锁化 * 修正:netLoop中的错误调用 * 修正:使其可用 * 修正:使功能一致 * 修正:现在可正常运行 * 优化:更早的释放锁(? * 修正:未写完的部分 * 修正:潜在的断线时仍然认为在线这件事&删除空重复文件 * 文档:添加部分注释 * 修正:CoverError可能引起死锁 * 修正:永远不会被触发的DisconnectEvent * 文档:将注释移动至对应位置
173 lines
3.8 KiB
Go
173 lines
3.8 KiB
Go
package network
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"unsafe"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// TCPListener wraps a single TCP connection and reports to its owner when
// that connection goes away.
type TCPListener struct {
	// conn is the active connection, or nil when disconnected. The methods
	// of this type read and write it through sync/atomic helpers rather
	// than under a mutex.
	conn *net.TCPConn

	// PlannedDisconnect is called after an intentional disconnect, such as
	// one caused by Close() or Connect().
	PlannedDisconnect func(*TCPListener)

	// UnexpectedDisconnect is called after an unintended disconnect and
	// receives the error that caused it.
	UnexpectedDisconnect func(*TCPListener, error)
}
|
|
|
|
func (t *TCPListener) getConn() *net.TCPConn {
|
|
return (*net.TCPConn)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&t.conn))))
|
|
}
|
|
|
|
func (t *TCPListener) setConn(conn *net.TCPConn) (swapped bool) {
|
|
return atomic.CompareAndSwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.conn)), unsafe.Pointer(nil), unsafe.Pointer(conn))
|
|
}
|
|
|
|
func (t *TCPListener) closeConn() *net.TCPConn {
|
|
return (*net.TCPConn)(atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&t.conn)), unsafe.Pointer(nil)))
|
|
}
|
|
|
|
func (t *TCPListener) Connected() bool {
|
|
// 等同于 t.getConn() != nil (? copilot写的)
|
|
return t.getConn() != nil
|
|
}
|
|
|
|
func (t *TCPListener) Connect(addr *net.TCPAddr) error {
|
|
t.Close()
|
|
conn, err := net.DialTCP("tcp", nil, addr)
|
|
if err != nil {
|
|
return errors.Wrap(err, "dial tcp error")
|
|
}
|
|
t.setConn(conn)
|
|
//t.lock.Lock()
|
|
//defer t.lock.Unlock()
|
|
//t.conn = conn
|
|
return nil
|
|
}
|
|
|
|
// ConnectFastest 连接到最快的服务器
|
|
// TODO 禁用不可用服务器
|
|
func (t *TCPListener) ConnectFastest(addr []*net.TCPAddr) (*net.TCPAddr, error) {
|
|
ch := make(chan error)
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(len(addr))
|
|
for _, remote := range addr {
|
|
go func(remote *net.TCPAddr) {
|
|
defer wg.Done()
|
|
conn, err := net.DialTCP("tcp", nil, remote)
|
|
if err != nil {
|
|
return
|
|
}
|
|
//addrs = append(addrs, remote)
|
|
if !t.setConn(conn) {
|
|
_ = conn.Close()
|
|
return
|
|
}
|
|
ch <- nil
|
|
}(remote)
|
|
}
|
|
go func() {
|
|
wg.Wait()
|
|
if t.getConn() == nil {
|
|
ch <- errors.New("All addr are unreachable")
|
|
}
|
|
}()
|
|
err := <-ch
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conn := t.getConn()
|
|
return conn.RemoteAddr().(*net.TCPAddr), nil
|
|
}
|
|
|
|
func (t *TCPListener) Write(buf []byte) error {
|
|
if conn := t.getConn(); conn != nil {
|
|
_, err := conn.Write(buf)
|
|
if err != nil {
|
|
t.unexpectedClose(err)
|
|
return ErrConnectionBroken
|
|
}
|
|
return nil
|
|
}
|
|
|
|
return ErrConnectionBroken
|
|
}
|
|
|
|
func (t *TCPListener) ReadBytes(len int) ([]byte, error) {
|
|
buf := make([]byte, len)
|
|
if conn := t.getConn(); conn != nil {
|
|
_, err := io.ReadFull(conn, buf)
|
|
if err != nil {
|
|
// time.Sleep(time.Millisecond * 100) // 服务器会发送offline包后立即断开连接, 此时还没解析, 可能还是得加锁
|
|
t.unexpectedClose(err)
|
|
return nil, ErrConnectionBroken
|
|
}
|
|
return buf, nil
|
|
}
|
|
|
|
return nil, ErrConnectionBroken
|
|
}
|
|
|
|
func (t *TCPListener) ReadInt32() (int32, error) {
|
|
b, err := t.ReadBytes(4)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return int32(binary.BigEndian.Uint32(b)), nil
|
|
}
|
|
|
|
func (t *TCPListener) Close() {
|
|
t.close()
|
|
t.invokePlannedDisconnect()
|
|
}
|
|
|
|
func (t *TCPListener) unexpectedClose(err error) {
|
|
t.close()
|
|
t.invokeUnexpectedDisconnect(err)
|
|
}
|
|
|
|
func (t *TCPListener) close() {
|
|
if conn := t.closeConn(); conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
func (t *TCPListener) invokePlannedDisconnect() {
|
|
//if t.Connected() {
|
|
t.PlannedDisconnect(t)
|
|
//}
|
|
//t.lock.RLock()
|
|
//defer t.lock.RUnlock()
|
|
//if t.plannedDisconnect != nil && t.connected {
|
|
// go t.plannedDisconnect(t)
|
|
// t.connected = false
|
|
//}
|
|
}
|
|
|
|
func (t *TCPListener) invokeUnexpectedDisconnect(err error) {
|
|
//if t.Connected() {
|
|
t.UnexpectedDisconnect(t, err)
|
|
//}
|
|
//t.lock.RLock()
|
|
//defer t.lock.RUnlock()
|
|
//if t.unexpectedDisconnect != nil && t.connected {
|
|
// go t.unexpectedDisconnect(t, err)
|
|
// t.connected = false
|
|
//}
|
|
}
|
|
|
|
//func (t *TCPListener) getConn() net.Conn {
|
|
// t.lock.RLock()
|
|
// defer t.lock.RUnlock()
|
|
// return t.conn
|
|
//}
|