diff --git a/server/websocket.go b/server/websocket.go index 144170e..38fb96a 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -50,10 +50,21 @@ type websocketClient struct { } type wsConn struct { - *websocket.Conn + mu sync.Mutex + conn *websocket.Conn apiCaller *api.Caller } +func (c *wsConn) WriteText(b []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + return c.conn.WriteMessage(websocket.TextMessage, b) +} + +func (c *wsConn) Close() error { + return c.conn.Close() +} + var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -158,7 +169,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) { } log.Infof("已连接到反向WebSocket %s服务器 %v", typ, url) - wrappedConn := &wsConn{Conn: conn, apiCaller: api.NewCaller(c.bot)} + wrappedConn := &wsConn{conn: conn, apiCaller: api.NewCaller(c.bot)} if c.limiter != nil { wrappedConn.apiCaller.Use(c.limiter) } @@ -176,7 +187,7 @@ func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) { defer func() { _ = conn.Close() }() for { buffer := global.NewBuffer() - t, reader, err := conn.Conn.NextReader() + t, reader, err := conn.conn.NextReader() if err != nil { log.Warnf("监听反向WS %s时出现错误: %v", typ, err) break @@ -215,7 +226,7 @@ func (c *websocketClient) onBotPushEvent(typ, url string, conn **wsConn) func(e } log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes()) - if err := (*conn).WriteMessage(websocket.TextMessage, e.JSONBytes()); err != nil { + if err := (*conn).WriteText(e.JSONBytes()); err != nil { log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err) _ = (*conn).Close() if c.reconnectInterval != 0 { @@ -248,7 +259,7 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { } log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr) - conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} + conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)} s.mu.Lock() s.eventConn = append(s.eventConn, conn) s.mu.Unlock() @@ -269,7 +280,7 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { } log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr) - conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} + conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)} if s.conf.RateLimit.Enabled { conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) } @@ -298,7 +309,7 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { } log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr) - conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} + conn := &wsConn{conn: c, apiCaller: api.NewCaller(s.bot)} if s.conf.RateLimit.Enabled { conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) } @@ -312,7 +323,7 @@ func (s *webSocketServer) listenAPI(c *wsConn) { defer func() { _ = c.Close() }() for { buffer := global.NewBuffer() - t, reader, err := c.NextReader() + t, reader, err := c.conn.NextReader() if err != nil { break } @@ -346,7 +357,10 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) { if j.Get("echo").Exists() { ret["echo"] = j.Get("echo").Value() } - writer, _ := c.NextWriter(websocket.TextMessage) + + c.mu.Lock() + defer c.mu.Unlock() + writer, _ := c.conn.NextWriter(websocket.TextMessage) _ = json.NewEncoder(writer).Encode(ret) _ = writer.Close() } @@ -365,7 +379,7 @@ func (s *webSocketServer) onBotPushEvent(e *coolq.Event) { for i := 0; i < len(s.eventConn); i++ { conn := s.eventConn[i] log.Debugf("向WS客户端推送Event: %s", e.JSONBytes()) - if err := conn.WriteMessage(websocket.TextMessage, e.JSONBytes()); err != nil { + if err := conn.WriteText(e.JSONBytes()); err != nil { _ = conn.Close() conn = nil continue