From 1771cda11c42cec8cb6fb1b8326926c35f16ed79 Mon Sep 17 00:00:00 2001 From: wdvxdr Date: Thu, 7 Oct 2021 22:24:45 +0800 Subject: [PATCH] feat(server): unify websocket client connect --- server/websocket.go | 180 +++++++++++++++++--------------------------- 1 file changed, 70 insertions(+), 110 deletions(-) diff --git a/server/websocket.go b/server/websocket.go index c9ee06e..64dd856 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -28,11 +28,12 @@ type webSocketServer struct { bot *coolq.CQBot conf *config.WebsocketServer - eventConn []*webSocketConn - eventConnMutex sync.Mutex - token string - handshake string - filter string + mu sync.Mutex + eventConn []*wsConn + + token string + handshake string + filter string } // websocketClient WebSocket客户端实例 @@ -40,13 +41,15 @@ type websocketClient struct { bot *coolq.CQBot conf *config.WebsocketReverse - universalConn *webSocketConn - eventConn *webSocketConn - token string - filter string + mu sync.Mutex + + universal *wsConn + event *wsConn + token string + filter string } -type webSocketConn struct { +type wsConn struct { *websocket.Conn apiCaller *apiCaller } @@ -102,20 +105,20 @@ func runWSClient(b *coolq.CQBot, node yaml.Node) { } addFilter(c.filter) if c.conf.Universal != "" { - c.connectUniversal() + c.connect("Universal", conf.Universal, &c.universal) } else { if c.conf.API != "" { - c.connectAPI() + c.connect("API", conf.API, nil) } if c.conf.Event != "" { - c.connectEvent() + c.connect("Event", conf.Event, &c.event) } } c.bot.OnEventPush(c.onBotPushEvent) } -func (c *websocketClient) connectAPI() { - log.Infof("开始尝试连接到反向WebSocket API服务器: %v", c.conf.API) +func (c *websocketClient) connect(typ, url string, conptr **wsConn) { + log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, c.conf.API) header := http.Header{ "X-Client-Role": []string{"API"}, "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, @@ -124,98 +127,44 @@ func (c *websocketClient) connectAPI() { if c.token != "" { header["Authorization"] = []string{"Token " + c.token} } - conn, _, err := websocket.Dial(context.Background(), c.conf.Universal, &websocket.DialOptions{HTTPHeader: header}) // nolint + conn, _, err := websocket.Dial(context.Background(), url, &websocket.DialOptions{HTTPHeader: header}) // nolint if err != nil { - log.Warnf("连接到反向WebSocket API服务器 %v 时出现错误: %v", c.conf.API, err) + log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, c.conf.API, err) if c.conf.ReconnectInterval != 0 { time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) - c.connectAPI() + c.connect(typ, url, conptr) } return } - log.Infof("已连接到反向WebSocket API服务器 %v", c.conf.API) - wrappedConn := &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)} + + switch typ { + case "Event", "Universal": + handshake := fmt.Sprintf(`{"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, c.bot.Client.Uin, time.Now().Unix()) + err = conn.Write(context.Background(), websocket.MessageText, []byte(handshake)) + if err != nil { + log.Warnf("反向WebSocket 握手时出现错误: %v", err) + } + } + + log.Infof("已连接到反向WebSocket %s服务器 %v", typ, c.conf.API) + wrappedConn := &wsConn{Conn: conn, apiCaller: newAPICaller(c.bot)} if c.conf.RateLimit.Enabled { wrappedConn.apiCaller.use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket)) } - go c.listenAPI(wrappedConn, false) -} -func (c *websocketClient) connectEvent() { - log.Infof("开始尝试连接到反向WebSocket Event服务器: %v", c.conf.Event) - header := http.Header{ - "X-Client-Role": []string{"Event"}, - "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, - "User-Agent": []string{"CQHttp/4.15.0"}, - } - if c.token != "" { - header["Authorization"] = []string{"Token " + c.token} - } - conn, _, err := websocket.Dial(context.Background(), c.conf.Universal, &websocket.DialOptions{HTTPHeader: header}) // nolint - if err != nil { - log.Warnf("连接到反向WebSocket Event服务器 %v 时出现错误: %v", c.conf.Event, err) - if c.conf.ReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) - c.connectEvent() - } - return + if conptr != nil { + *conptr = wrappedConn } - handshake := fmt.Sprintf(`{"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, - c.bot.Client.Uin, time.Now().Unix()) - err = conn.Write(context.Background(), websocket.MessageText, []byte(handshake)) - if err != nil { - log.Warnf("反向WebSocket 握手时出现错误: %v", err) - } - - log.Infof("已连接到反向WebSocket Event服务器 %v", c.conf.Event) - if c.eventConn == nil { - wrappedConn := &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)} - c.eventConn = wrappedConn - } else { - c.eventConn.Conn = conn + switch typ { + case "API": + go c.listenAPI(wrappedConn, false) + case "Universal": + go c.listenAPI(wrappedConn, true) } } -func (c *websocketClient) connectUniversal() { - log.Infof("开始尝试连接到反向WebSocket Universal服务器: %v", c.conf.Universal) - header := http.Header{ - "X-Client-Role": []string{"Universal"}, - "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, - "User-Agent": []string{"CQHttp/4.15.0"}, - } - if c.token != "" { - header["Authorization"] = []string{"Token " + c.token} - } - conn, _, err := websocket.Dial(context.Background(), c.conf.Universal, &websocket.DialOptions{HTTPHeader: header}) // nolint - if err != nil { - log.Warnf("连接到反向WebSocket Universal服务器 %v 时出现错误: %v", c.conf.Universal, err) - if c.conf.ReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) - c.connectUniversal() - } - return - } - handshake := fmt.Sprintf(`{"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, - c.bot.Client.Uin, time.Now().Unix()) - err = conn.Write(context.Background(), websocket.MessageText, []byte(handshake)) - if err != nil { - log.Warnf("反向WebSocket 握手时出现错误: %v", err) - } - - if c.universalConn == nil { - wrappedConn := &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)} - if c.conf.RateLimit.Enabled { - wrappedConn.apiCaller.use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket)) - } - c.universalConn = wrappedConn - } else { - c.universalConn.Conn = conn - } - go c.listenAPI(c.universalConn, true) -} - -func (c *websocketClient) listenAPI(conn *webSocketConn, u bool) { +func (c *websocketClient) listenAPI(conn *wsConn, u bool) { defer func() { _ = conn.Close(websocket.StatusNormalClosure, "") }() for { buffer := global.NewBuffer() @@ -241,7 +190,7 @@ func (c *websocketClient) listenAPI(conn *webSocketConn, u bool) { if c.conf.ReconnectInterval != 0 { time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) if !u { - go c.connectAPI() + go c.connect("API", c.conf.API, nil) } } } @@ -252,7 +201,7 @@ func (c *websocketClient) onBotPushEvent(e *coolq.Event) { log.Debugf("上报Event %s 到 WS服务器 时被过滤.", e.JSONBytes()) return } - push := func(conn *webSocketConn, reconnect func()) { + push := func(conn *wsConn, reconnect func()) { log.Debugf("向WS服务器推送Event: %s", e.JSONBytes()) ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() @@ -265,11 +214,21 @@ func (c *websocketClient) onBotPushEvent(e *coolq.Event) { } } } - if c.eventConn != nil { - push(c.eventConn, c.connectEvent) + + connect := func(typ, url string, conptr **wsConn) func() { + return func() { + c.connect(typ, url, conptr) + } } - if c.universalConn != nil { - push(c.universalConn, c.connectUniversal) + + c.mu.Lock() + defer c.mu.Unlock() + + if c.event != nil { + push(c.event, connect("Event", c.conf.Event, &c.event)) + } + if c.universal != nil { + push(c.universal, connect("Universal", c.conf.Universal, &c.universal)) } } @@ -294,11 +253,11 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr) - conn := &webSocketConn{Conn: c, apiCaller: newAPICaller(s.bot)} + conn := &wsConn{Conn: c, apiCaller: newAPICaller(s.bot)} - s.eventConnMutex.Lock() + s.mu.Lock() s.eventConn = append(s.eventConn, conn) - s.eventConnMutex.Unlock() + s.mu.Unlock() } func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { @@ -314,7 +273,7 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { return } log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr) - conn := &webSocketConn{Conn: c, apiCaller: newAPICaller(s.bot)} + conn := &wsConn{Conn: c, apiCaller: newAPICaller(s.bot)} if s.conf.RateLimit.Enabled { conn.apiCaller.use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) } @@ -340,17 +299,17 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { return } log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr) - conn := &webSocketConn{Conn: c, apiCaller: newAPICaller(s.bot)} + conn := &wsConn{Conn: c, apiCaller: newAPICaller(s.bot)} if s.conf.RateLimit.Enabled { conn.apiCaller.use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) } - s.eventConnMutex.Lock() + s.mu.Lock() s.eventConn = append(s.eventConn, conn) - s.eventConnMutex.Unlock() + s.mu.Unlock() s.listenAPI(conn) } -func (s *webSocketServer) listenAPI(c *webSocketConn) { +func (s *webSocketServer) listenAPI(c *wsConn) { defer func() { _ = c.Close(websocket.StatusNormalClosure, "") }() for { buffer := global.NewBuffer() @@ -374,7 +333,7 @@ func (s *webSocketServer) listenAPI(c *webSocketConn) { } } -func (c *webSocketConn) handleRequest(_ *coolq.CQBot, payload []byte) { +func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) { defer func() { if err := recover(); err != nil { log.Printf("处置WS命令时发生无法恢复的异常:%v\n%s", err, debug.Stack()) @@ -394,14 +353,15 @@ func (c *webSocketConn) handleRequest(_ *coolq.CQBot, payload []byte) { } func (s *webSocketServer) onBotPushEvent(e *coolq.Event) { - s.eventConnMutex.Lock() - defer s.eventConnMutex.Unlock() - filter := findFilter(s.filter) if filter != nil && !filter.Eval(gjson.Parse(e.JSONString())) { log.Debugf("上报Event %s 到 WS客户端 时被过滤.", e.JSONBytes()) return } + + s.mu.Lock() + defer s.mu.Unlock() + j := 0 for i := 0; i < len(s.eventConn); i++ { conn := s.eventConn[i]