1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-04 19:17:37 +08:00

server: fix concurrent write in ws

This commit is contained in:
wdvxdr 2021-12-05 22:40:44 +08:00
parent a06a891186
commit dfec28dac6
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6

View File

@ -50,10 +50,21 @@ type websocketClient struct {
} }
type wsConn struct { type wsConn struct {
*websocket.Conn mu sync.Mutex
conn *websocket.Conn
apiCaller *api.Caller apiCaller *api.Caller
} }
func (c *wsConn) WriteText(b []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return c.conn.WriteMessage(websocket.TextMessage, b)
}
func (c *wsConn) Close() error {
return c.conn.Close()
}
var upgrader = websocket.Upgrader{ var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true return true
@ -158,7 +169,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
} }
log.Infof("已连接到反向WebSocket %s服务器 %v", typ, url) log.Infof("已连接到反向WebSocket %s服务器 %v", typ, url)
wrappedConn := &wsConn{Conn: conn, apiCaller: api.NewCaller(c.bot)} wrappedConn := &wsConn{conn: conn, apiCaller: api.NewCaller(c.bot)}
if c.limiter != nil { if c.limiter != nil {
wrappedConn.apiCaller.Use(c.limiter) wrappedConn.apiCaller.Use(c.limiter)
} }
@ -176,7 +187,7 @@ func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) {
defer func() { _ = conn.Close() }() defer func() { _ = conn.Close() }()
for { for {
buffer := global.NewBuffer() buffer := global.NewBuffer()
t, reader, err := conn.Conn.NextReader() t, reader, err := conn.conn.NextReader()
if err != nil { if err != nil {
log.Warnf("监听反向WS %s时出现错误: %v", typ, err) log.Warnf("监听反向WS %s时出现错误: %v", typ, err)
break break
@ -215,7 +226,7 @@ func (c *websocketClient) onBotPushEvent(typ, url string, conn **wsConn) func(e
} }
log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes()) log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes())
if err := (*conn).WriteMessage(websocket.TextMessage, e.JSONBytes()); err != nil { if err := (*conn).WriteText(e.JSONBytes()); err != nil {
log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err) log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err)
_ = (*conn).Close() _ = (*conn).Close()
if c.reconnectInterval != 0 { if c.reconnectInterval != 0 {
@ -248,7 +259,7 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {
} }
log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr) log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr)
conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)}
s.mu.Lock() s.mu.Lock()
s.eventConn = append(s.eventConn, conn) s.eventConn = append(s.eventConn, conn)
s.mu.Unlock() s.mu.Unlock()
@ -269,7 +280,7 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) {
} }
log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr) log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr)
conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)}
if s.conf.RateLimit.Enabled { if s.conf.RateLimit.Enabled {
conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket))
} }
@ -298,7 +309,7 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) {
} }
log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr) log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr)
conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)}
if s.conf.RateLimit.Enabled { if s.conf.RateLimit.Enabled {
conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket))
} }
@ -312,7 +323,7 @@ func (s *webSocketServer) listenAPI(c *wsConn) {
defer func() { _ = c.Close() }() defer func() { _ = c.Close() }()
for { for {
buffer := global.NewBuffer() buffer := global.NewBuffer()
t, reader, err := c.NextReader() t, reader, err := c.conn.NextReader()
if err != nil { if err != nil {
break break
} }
@ -346,7 +357,10 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) {
if j.Get("echo").Exists() { if j.Get("echo").Exists() {
ret["echo"] = j.Get("echo").Value() ret["echo"] = j.Get("echo").Value()
} }
writer, _ := c.NextWriter(websocket.TextMessage)
c.mu.Lock()
defer c.mu.Unlock()
writer, _ := c.conn.NextWriter(websocket.TextMessage)
_ = json.NewEncoder(writer).Encode(ret) _ = json.NewEncoder(writer).Encode(ret)
_ = writer.Close() _ = writer.Close()
} }
@ -365,7 +379,7 @@ func (s *webSocketServer) onBotPushEvent(e *coolq.Event) {
for i := 0; i < len(s.eventConn); i++ { for i := 0; i < len(s.eventConn); i++ {
conn := s.eventConn[i] conn := s.eventConn[i]
log.Debugf("向WS客户端推送Event: %s", e.JSONBytes()) log.Debugf("向WS客户端推送Event: %s", e.JSONBytes())
if err := conn.WriteMessage(websocket.TextMessage, e.JSONBytes()); err != nil { if err := conn.WriteText(e.JSONBytes()); err != nil {
_ = conn.Close() _ = conn.Close()
conn = nil conn = nil
continue continue