1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-05 03:23:49 +08:00

fix: use gorilla/websocket

This commit is contained in:
wdvxdr 2021-11-18 14:23:17 +08:00
parent c2c1fb00e5
commit a2d1e88ed5
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6

View File

@ -2,7 +2,6 @@ package server
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -13,10 +12,10 @@ import (
"time" "time"
"github.com/Mrs4s/MiraiGo/utils" "github.com/Mrs4s/MiraiGo/utils"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"nhooyr.io/websocket"
"github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/coolq"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
@ -55,6 +54,12 @@ type wsConn struct {
apiCaller *api.Caller apiCaller *api.Caller
} }
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// runWSServer 运行一个正向WS server // runWSServer 运行一个正向WS server
func runWSServer(b *coolq.CQBot, node yaml.Node) { func runWSServer(b *coolq.CQBot, node yaml.Node) {
var conf config.WebsocketServer var conf config.WebsocketServer
@ -133,7 +138,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
if c.token != "" { if c.token != "" {
header["Authorization"] = []string{"Token " + c.token} header["Authorization"] = []string{"Token " + c.token}
} }
conn, _, err := websocket.Dial(context.Background(), url, &websocket.DialOptions{HTTPHeader: header}) // nolint conn, _, err := websocket.DefaultDialer.Dial(url, header) // nolint
if err != nil { if err != nil {
log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, url, err) log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, url, err)
if c.reconnectInterval != 0 { if c.reconnectInterval != 0 {
@ -146,7 +151,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
switch typ { switch typ {
case "Event", "Universal": case "Event", "Universal":
handshake := fmt.Sprintf(`{"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, c.bot.Client.Uin, time.Now().Unix()) handshake := fmt.Sprintf(`{"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, c.bot.Client.Uin, time.Now().Unix())
err = conn.Write(context.Background(), websocket.MessageText, []byte(handshake)) err = conn.WriteMessage(websocket.TextMessage, []byte(handshake))
if err != nil { if err != nil {
log.Warnf("反向WebSocket 握手时出现错误: %v", err) log.Warnf("反向WebSocket 握手时出现错误: %v", err)
} }
@ -168,11 +173,10 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
} }
func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) { func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) {
defer func() { _ = conn.Close(websocket.StatusNormalClosure, "") }() defer func() { _ = conn.Close() }()
conn.Conn.SetReadLimit(1024 * 1024 * 128)
for { for {
buffer := global.NewBuffer() buffer := global.NewBuffer()
t, reader, err := conn.Conn.Reader(context.Background()) t, reader, err := conn.Conn.NextReader()
if err != nil { if err != nil {
log.Warnf("监听反向WS %s时出现错误: %v", typ, err) log.Warnf("监听反向WS %s时出现错误: %v", typ, err)
break break
@ -182,7 +186,7 @@ func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) {
log.Warnf("监听反向WS %s时出现错误: %v", typ, err) log.Warnf("监听反向WS %s时出现错误: %v", typ, err)
break break
} }
if t == websocket.MessageText { if t == websocket.TextMessage {
go func(buffer *bytes.Buffer) { go func(buffer *bytes.Buffer) {
defer global.PutBuffer(buffer) defer global.PutBuffer(buffer)
conn.handleRequest(c.bot, buffer.Bytes()) conn.handleRequest(c.bot, buffer.Bytes())
@ -211,11 +215,9 @@ func (c *websocketClient) onBotPushEvent(typ, url string, conn **wsConn) func(e
} }
log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes()) log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) if err := (*conn).WriteMessage(websocket.TextMessage, e.JSONBytes()); err != nil {
defer cancel()
if err := (*conn).Write(ctx, websocket.MessageText, e.JSONBytes()); err != nil {
log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err) log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err)
_ = (*conn).Close(websocket.StatusNormalClosure, "") _ = (*conn).Close()
if c.reconnectInterval != 0 { if c.reconnectInterval != 0 {
time.Sleep(c.reconnectInterval) time.Sleep(c.reconnectInterval)
c.connect(typ, url, conn) c.connect(typ, url, conn)
@ -231,23 +233,22 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(status) w.WriteHeader(status)
return return
} }
opts := &websocket.AcceptOptions{InsecureSkipVerify: true}
c, err := websocket.Accept(w, r, opts) c, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Warnf("处理 WebSocket 请求时出现错误: %v", err) log.Warnf("处理 WebSocket 请求时出现错误: %v", err)
return return
} }
err = c.Write(context.Background(), websocket.MessageText, []byte(s.handshake))
err = c.WriteMessage(websocket.TextMessage, []byte(s.handshake))
if err != nil { if err != nil {
log.Warnf("WebSocket 握手时出现错误: %v", err) log.Warnf("WebSocket 握手时出现错误: %v", err)
_ = c.Close(websocket.StatusNormalClosure, "") _ = c.Close()
return return
} }
log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr) log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr)
conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)}
s.mu.Lock() s.mu.Lock()
s.eventConn = append(s.eventConn, conn) s.eventConn = append(s.eventConn, conn)
s.mu.Unlock() s.mu.Unlock()
@ -260,18 +261,19 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(status) w.WriteHeader(status)
return return
} }
opts := &websocket.AcceptOptions{InsecureSkipVerify: true}
c, err := websocket.Accept(w, r, opts) c, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Warnf("处理 WebSocket 请求时出现错误: %v", err) log.Warnf("处理 WebSocket 请求时出现错误: %v", err)
return return
} }
log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr) log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr)
conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)}
if s.conf.RateLimit.Enabled { if s.conf.RateLimit.Enabled {
conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket))
} }
go s.listenAPI(conn) s.listenAPI(conn)
} }
func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) {
@ -281,18 +283,20 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(status) w.WriteHeader(status)
return return
} }
opts := &websocket.AcceptOptions{InsecureSkipVerify: true}
c, err := websocket.Accept(w, r, opts) c, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Warnf("处理 WebSocket 请求时出现错误: %v", err) log.Warnf("处理 WebSocket 请求时出现错误: %v", err)
return return
} }
err = c.Write(context.Background(), websocket.MessageText, []byte(s.handshake))
err = c.WriteMessage(websocket.TextMessage, []byte(s.handshake))
if err != nil { if err != nil {
log.Warnf("WebSocket 握手时出现错误: %v", err) log.Warnf("WebSocket 握手时出现错误: %v", err)
_ = c.Close(websocket.StatusNormalClosure, "") _ = c.Close()
return return
} }
log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr) log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr)
conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)}
if s.conf.RateLimit.Enabled { if s.conf.RateLimit.Enabled {
@ -305,11 +309,10 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) {
} }
func (s *webSocketServer) listenAPI(c *wsConn) { func (s *webSocketServer) listenAPI(c *wsConn) {
defer func() { _ = c.Close(websocket.StatusNormalClosure, "") }() defer func() { _ = c.Close() }()
c.Conn.SetReadLimit(1024 * 1024 * 128)
for { for {
buffer := global.NewBuffer() buffer := global.NewBuffer()
t, reader, err := c.Reader(context.Background()) t, reader, err := c.NextReader()
if err != nil { if err != nil {
break break
} }
@ -318,7 +321,7 @@ func (s *webSocketServer) listenAPI(c *wsConn) {
break break
} }
if t == websocket.MessageText { if t == websocket.TextMessage {
go func(buffer *bytes.Buffer) { go func(buffer *bytes.Buffer) {
defer global.PutBuffer(buffer) defer global.PutBuffer(buffer)
c.handleRequest(s.bot, buffer.Bytes()) c.handleRequest(s.bot, buffer.Bytes())
@ -333,7 +336,7 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
log.Printf("处置WS命令时发生无法恢复的异常%v\n%s", err, debug.Stack()) log.Printf("处置WS命令时发生无法恢复的异常%v\n%s", err, debug.Stack())
_ = c.Close(websocket.StatusInternalError, fmt.Sprint(err)) _ = c.Close()
} }
}() }()
j := gjson.Parse(utils.B2S(payload)) j := gjson.Parse(utils.B2S(payload))
@ -343,7 +346,7 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) {
if j.Get("echo").Exists() { if j.Get("echo").Exists() {
ret["echo"] = j.Get("echo").Value() ret["echo"] = j.Get("echo").Value()
} }
writer, _ := c.Writer(context.Background(), websocket.MessageText) writer, _ := c.NextWriter(websocket.TextMessage)
_ = json.NewEncoder(writer).Encode(ret) _ = json.NewEncoder(writer).Encode(ret)
_ = writer.Close() _ = writer.Close()
} }
@ -362,8 +365,8 @@ func (s *webSocketServer) onBotPushEvent(e *coolq.Event) {
for i := 0; i < len(s.eventConn); i++ { for i := 0; i < len(s.eventConn); i++ {
conn := s.eventConn[i] conn := s.eventConn[i]
log.Debugf("向WS客户端推送Event: %s", e.JSONBytes()) log.Debugf("向WS客户端推送Event: %s", e.JSONBytes())
if err := conn.Write(context.Background(), websocket.MessageText, e.JSONBytes()); err != nil { if err := conn.WriteMessage(websocket.TextMessage, e.JSONBytes()); err != nil {
_ = conn.Close(websocket.StatusNormalClosure, "") _ = conn.Close()
conn = nil conn = nil
continue continue
} }