diff --git a/server/websocket.go b/server/websocket.go index 5d75b92..525d469 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -39,15 +39,15 @@ type webSocketServer struct { // websocketClient WebSocket客户端实例 type websocketClient struct { - bot *coolq.CQBot - conf *config.WebsocketReverse - - mu sync.Mutex - + bot *coolq.CQBot + mu sync.Mutex universal *wsConn event *wsConn - token string - filter string + + token string + filter string + reconnectInterval time.Duration + limiter api.Handler } type wsConn struct { @@ -81,10 +81,8 @@ func runWSServer(b *coolq.CQBot, node yaml.Node) { mux.HandleFunc("/event", s.event) mux.HandleFunc("/api", s.api) mux.HandleFunc("/", s.any) - go func() { - log.Infof("CQ WebSocket 服务器已启动: %v", addr) - log.Fatal(http.ListenAndServe(addr, &mux)) - }() + log.Infof("CQ WebSocket 服务器已启动: %v", addr) + log.Fatal(http.ListenAndServe(addr, &mux)) } // runWSClient 运行一个反向向WS client @@ -100,26 +98,33 @@ func runWSClient(b *coolq.CQBot, node yaml.Node) { c := &websocketClient{ bot: b, - conf: &conf, token: conf.AccessToken, filter: conf.Filter, } filter.Add(c.filter) - if c.conf.Universal != "" { - c.connect("Universal", conf.Universal, &c.universal) - } else { - if c.conf.API != "" { - c.connect("API", conf.API, nil) - } - if c.conf.Event != "" { - c.connect("Event", conf.Event, &c.event) - } + if conf.ReconnectInterval != 0 { + c.reconnectInterval = time.Duration(conf.ReconnectInterval) * time.Millisecond + } + if conf.RateLimit.Enabled { + c.limiter = rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket) + } + + if conf.Universal != "" { + c.connect("Universal", conf.Universal, &c.universal) + c.bot.OnEventPush(c.onBotPushEvent("Universal", conf.Universal, &c.universal)) + return // 连接到 Universal 后, 不再连接其他 + } + if conf.API != "" { + c.connect("API", conf.API, nil) + } + if conf.Event != "" { + c.connect("Event", conf.Event, &c.event) + c.bot.OnEventPush(c.onBotPushEvent("Event", conf.Event, &c.event)) } - c.bot.OnEventPush(c.onBotPushEvent) } func (c *websocketClient) connect(typ, url string, conptr **wsConn) { - log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, c.conf.API) + log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, url) header := http.Header{ "X-Client-Role": []string{typ}, "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, @@ -130,9 +135,9 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) { } conn, _, err := websocket.Dial(context.Background(), url, &websocket.DialOptions{HTTPHeader: header}) // nolint if err != nil { - log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, c.conf.API, err) - if c.conf.ReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) + log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, url, err) + if c.reconnectInterval != 0 { + time.Sleep(c.reconnectInterval) c.connect(typ, url, conptr) } return @@ -147,37 +152,32 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) { } } - log.Infof("已连接到反向WebSocket %s服务器 %v", typ, c.conf.API) + log.Infof("已连接到反向WebSocket %s服务器 %v", typ, url) wrappedConn := &wsConn{Conn: conn, apiCaller: api.NewCaller(c.bot)} - if c.conf.RateLimit.Enabled { - wrappedConn.apiCaller.Use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket)) + if c.limiter != nil { + wrappedConn.apiCaller.Use(c.limiter) } if conptr != nil { *conptr = wrappedConn } - switch typ { - case "API": - go c.listenAPI(wrappedConn, false) - case "Universal": - go c.listenAPI(wrappedConn, true) - } + go c.listenAPI(typ, url, wrappedConn) } -func (c *websocketClient) listenAPI(conn *wsConn, u bool) { +func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) { defer func() { _ = conn.Close(websocket.StatusNormalClosure, "") }() conn.Conn.SetReadLimit(1024 * 1024 * 128) for { buffer := global.NewBuffer() t, reader, err := conn.Conn.Reader(context.Background()) if err != nil { - log.Warnf("监听反向WS API时出现错误: %v", err) + log.Warnf("监听反向WS %s时出现错误: %v", typ, err) break } _, err = buffer.ReadFrom(reader) if err != nil { - log.Warnf("监听反向WS API时出现错误: %v", err) + log.Warnf("监听反向WS %s时出现错误: %v", typ, err) break } if t == websocket.MessageText { @@ -189,49 +189,37 @@ func (c *websocketClient) listenAPI(conn *wsConn, u bool) { global.PutBuffer(buffer) } } - if c.conf.ReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) - if !u { - go c.connect("API", c.conf.API, nil) + if c.reconnectInterval != 0 { + time.Sleep(c.reconnectInterval) + if typ == "API" { // Universal 不重连,避免多次重连 + go c.connect(typ, url, nil) } } } -func (c *websocketClient) onBotPushEvent(e *coolq.Event) { - filter := filter.Find(c.filter) - if filter != nil && !filter.Eval(gjson.Parse(e.JSONString())) { - log.Debugf("上报Event %s 到 WS服务器 时被过滤.", e.JSONBytes()) - return - } - push := func(conn *wsConn, reconnect func()) { - log.Debugf("向WS服务器推送Event: %s", e.JSONBytes()) +func (c *websocketClient) onBotPushEvent(typ, url string, conn **wsConn) func(e *coolq.Event) { + return func(e *coolq.Event) { + c.mu.Lock() + defer c.mu.Unlock() + + flt := filter.Find(c.filter) + if flt != nil && !flt.Eval(gjson.Parse(e.JSONString())) { + log.Debugf("上报Event %s 到 WS服务器 时被过滤.", e.JSONBytes()) + return + } + + log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes()) ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() - if err := conn.Write(ctx, websocket.MessageText, e.JSONBytes()); err != nil { - log.Warnf("向WS服务器推送 Event 时出现错误: %v", err) - _ = conn.Close(websocket.StatusNormalClosure, "") - if c.conf.ReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) - reconnect() + if err := (*conn).Write(ctx, websocket.MessageText, e.JSONBytes()); err != nil { + log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err) + _ = (*conn).Close(websocket.StatusNormalClosure, "") + if c.reconnectInterval != 0 { + time.Sleep(c.reconnectInterval) + c.connect(typ, url, conn) } } } - - connect := func(typ, url string, conptr **wsConn) func() { - return func() { - c.connect(typ, url, conptr) - } - } - - c.mu.Lock() - defer c.mu.Unlock() - - if c.event != nil { - push(c.event, connect("Event", c.conf.Event, &c.event)) - } - if c.universal != nil { - push(c.universal, connect("Universal", c.conf.Universal, &c.universal)) - } } func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {