diff --git a/server/websocket.go b/server/websocket.go index 6ffb837..144170e 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -2,7 +2,6 @@ package server import ( "bytes" - "context" "encoding/json" "fmt" "net/http" @@ -13,10 +12,10 @@ import ( "time" "github.com/Mrs4s/MiraiGo/utils" + "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "gopkg.in/yaml.v3" - "nhooyr.io/websocket" "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" @@ -55,6 +54,12 @@ type wsConn struct { apiCaller *api.Caller } +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + // runWSServer 运行一个正向WS server func runWSServer(b *coolq.CQBot, node yaml.Node) { var conf config.WebsocketServer @@ -133,7 +138,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) { if c.token != "" { header["Authorization"] = []string{"Token " + c.token} } - conn, _, err := websocket.Dial(context.Background(), url, &websocket.DialOptions{HTTPHeader: header}) // nolint + conn, _, err := websocket.DefaultDialer.Dial(url, header) // nolint if err != nil { log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, url, err) if c.reconnectInterval != 0 { @@ -146,7 +151,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) { switch typ { case "Event", "Universal": handshake := fmt.Sprintf(`{"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, c.bot.Client.Uin, time.Now().Unix()) - err = conn.Write(context.Background(), websocket.MessageText, []byte(handshake)) + err = conn.WriteMessage(websocket.TextMessage, []byte(handshake)) if err != nil { log.Warnf("反向WebSocket 握手时出现错误: %v", err) } @@ -168,11 +173,10 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) { } func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) { - defer func() { _ = conn.Close(websocket.StatusNormalClosure, "") }() - conn.Conn.SetReadLimit(1024 * 1024 * 128) + defer func() { _ = conn.Close() }() for { buffer := global.NewBuffer() - t, reader, err := conn.Conn.Reader(context.Background()) + t, reader, err := conn.Conn.NextReader() if err != nil { log.Warnf("监听反向WS %s时出现错误: %v", typ, err) break @@ -182,7 +186,7 @@ func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) { log.Warnf("监听反向WS %s时出现错误: %v", typ, err) break } - if t == websocket.MessageText { + if t == websocket.TextMessage { go func(buffer *bytes.Buffer) { defer global.PutBuffer(buffer) conn.handleRequest(c.bot, buffer.Bytes()) @@ -211,11 +215,9 @@ func (c *websocketClient) onBotPushEvent(typ, url string, conn **wsConn) func(e } log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes()) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() - if err := (*conn).Write(ctx, websocket.MessageText, e.JSONBytes()); err != nil { + if err := (*conn).WriteMessage(websocket.TextMessage, e.JSONBytes()); err != nil { log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err) - _ = (*conn).Close(websocket.StatusNormalClosure, "") + _ = (*conn).Close() if c.reconnectInterval != 0 { time.Sleep(c.reconnectInterval) c.connect(typ, url, conn) @@ -231,23 +233,22 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { w.WriteHeader(status) return } - opts := &websocket.AcceptOptions{InsecureSkipVerify: true} - c, err := websocket.Accept(w, r, opts) + + c, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Warnf("处理 WebSocket 请求时出现错误: %v", err) return } - err = c.Write(context.Background(), websocket.MessageText, []byte(s.handshake)) + + err = c.WriteMessage(websocket.TextMessage, []byte(s.handshake)) if err != nil { log.Warnf("WebSocket 握手时出现错误: %v", err) - _ = c.Close(websocket.StatusNormalClosure, "") + _ = c.Close() return } log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr) - conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} - s.mu.Lock() s.eventConn = append(s.eventConn, conn) s.mu.Unlock() @@ -260,18 +261,19 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { w.WriteHeader(status) return } - opts := &websocket.AcceptOptions{InsecureSkipVerify: true} - c, err := websocket.Accept(w, r, opts) + + c, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Warnf("处理 WebSocket 请求时出现错误: %v", err) return } + log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr) conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} if s.conf.RateLimit.Enabled { conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) } - go s.listenAPI(conn) + s.listenAPI(conn) } func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { @@ -281,18 +283,20 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { w.WriteHeader(status) return } - opts := &websocket.AcceptOptions{InsecureSkipVerify: true} - c, err := websocket.Accept(w, r, opts) + + c, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Warnf("处理 WebSocket 请求时出现错误: %v", err) return } - err = c.Write(context.Background(), websocket.MessageText, []byte(s.handshake)) + + err = c.WriteMessage(websocket.TextMessage, []byte(s.handshake)) if err != nil { log.Warnf("WebSocket 握手时出现错误: %v", err) - _ = c.Close(websocket.StatusNormalClosure, "") + _ = c.Close() return } + log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr) conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} if s.conf.RateLimit.Enabled { @@ -305,11 +309,10 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { } func (s *webSocketServer) listenAPI(c *wsConn) { - defer func() { _ = c.Close(websocket.StatusNormalClosure, "") }() - c.Conn.SetReadLimit(1024 * 1024 * 128) + defer func() { _ = c.Close() }() for { buffer := global.NewBuffer() - t, reader, err := c.Reader(context.Background()) + t, reader, err := c.NextReader() if err != nil { break } @@ -318,7 +321,7 @@ func (s *webSocketServer) listenAPI(c *wsConn) { break } - if t == websocket.MessageText { + if t == websocket.TextMessage { go func(buffer *bytes.Buffer) { defer global.PutBuffer(buffer) c.handleRequest(s.bot, buffer.Bytes()) @@ -333,7 +336,7 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) { defer func() { if err := recover(); err != nil { log.Printf("处置WS命令时发生无法恢复的异常:%v\n%s", err, debug.Stack()) - _ = c.Close(websocket.StatusInternalError, fmt.Sprint(err)) + _ = c.Close() } }() j := gjson.Parse(utils.B2S(payload)) @@ -343,7 +346,7 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) { if j.Get("echo").Exists() { ret["echo"] = j.Get("echo").Value() } - writer, _ := c.Writer(context.Background(), websocket.MessageText) + writer, _ := c.NextWriter(websocket.TextMessage) _ = json.NewEncoder(writer).Encode(ret) _ = writer.Close() } @@ -362,8 +365,8 @@ func (s *webSocketServer) onBotPushEvent(e *coolq.Event) { for i := 0; i < len(s.eventConn); i++ { conn := s.eventConn[i] log.Debugf("向WS客户端推送Event: %s", e.JSONBytes()) - if err := conn.Write(context.Background(), websocket.MessageText, e.JSONBytes()); err != nil { - _ = conn.Close(websocket.StatusNormalClosure, "") + if err := conn.WriteMessage(websocket.TextMessage, e.JSONBytes()); err != nil { + _ = conn.Close() conn = nil continue }