diff --git a/coolq/bot.go b/coolq/bot.go index 04ba515..098a3e2 100644 --- a/coolq/bot.go +++ b/coolq/bot.go @@ -32,7 +32,9 @@ var json = jsoniter.ConfigCompatibleWithStandardLibrary type CQBot struct { Client *client.QQClient - events []func(*bytes.Buffer) + events []func(*bytes.Buffer) + mu sync.Mutex + db *leveldb.DB friendReqCache sync.Map tempSessionCache sync.Map @@ -123,6 +125,8 @@ func NewQQBot(cli *client.QQClient, conf *config.Config) *CQBot { // OnEventPush 注册事件上报函数 func (bot *CQBot) OnEventPush(f func(buf *bytes.Buffer)) { + bot.mu.Lock() + defer bot.mu.Unlock() bot.events = append(bot.events, f) } diff --git a/server/http.go b/server/http.go index d37096e..dc1f8ae 100644 --- a/server/http.go +++ b/server/http.go @@ -46,8 +46,6 @@ type httpCtx struct { json gjson.Result query url.Values postForm url.Values - - headerAuth string } func (h *httpCtx) Get(s string) gjson.Result { @@ -71,7 +69,6 @@ func (h *httpCtx) Get(s string) gjson.Result { func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) { var ctx httpCtx contentType := request.Header.Get("Content-Type") - ctx.headerAuth = request.Header.Get("Authorization") switch request.Method { case http.MethodPost: if strings.Contains(contentType, "application/json") { @@ -102,7 +99,7 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request } if s.accessToken != "" { - if status := checkAuth(ctx, s.accessToken); status != http.StatusOK { + if status := checkAuth(request, s.accessToken); status != http.StatusOK { writer.WriteHeader(status) return } @@ -118,10 +115,10 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request _ = json.NewEncoder(writer).Encode(ret) } -func checkAuth(ctx httpCtx, token string) int { - auth := ctx.headerAuth +func checkAuth(req *http.Request, token string) int { + auth := req.Header.Get("Authorization") if auth == "" { - auth = ctx.query.Get("access_token") + auth = req.URL.Query().Get("access_token") } else { authN := strings.SplitN(auth, " ", 2) if len(authN) == 2 { @@ -175,29 +172,25 @@ func RunHTTPServerAndClients(bot *coolq.CQBot, conf *config.HTTPServer) { client: for _, c := range conf.Post { if c.URL != "" { - go newHTTPClient().Run(c.URL, c.Secret, conf.Filter, conf.Timeout, bot) + go HTTPClient{ + bot: bot, + secret: c.Secret, + addr: c.URL, + filter: conf.Filter, + timeout: conf.Timeout, + }.Run() } } } -// newHTTPClient 返回反向HTTP客户端 -func newHTTPClient() *HTTPClient { - return &HTTPClient{} -} - // Run 运行反向HTTP服务 -func (c *HTTPClient) Run(addr, secret, filter string, timeout int32, bot *coolq.CQBot) { - c.bot = bot - c.secret = secret - c.addr = addr - c.timeout = timeout - c.filter = filter - addFilter(filter) +func (c HTTPClient) Run() { + addFilter(c.filter) if c.timeout < 5 { c.timeout = 5 } - bot.OnEventPush(c.onBotPushEvent) - log.Infof("HTTP POST上报器已启动: %v", addr) + c.bot.OnEventPush(c.onBotPushEvent) + log.Infof("HTTP POST上报器已启动: %v", c.addr) } func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) { diff --git a/server/websocket.go b/server/websocket.go index f4266e5..1421b88 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -31,8 +31,8 @@ type webSocketServer struct { filter string } -// WebSocketClient WebSocket客户端实例 -type WebSocketClient struct { +// websocketClient WebSocket客户端实例 +type websocketClient struct { bot *coolq.CQBot conf *config.WebsocketReverse @@ -59,11 +59,12 @@ func RunWebSocketServer(b *coolq.CQBot, conf *config.WebsocketServer) { if conf.Disabled { return } - s := new(webSocketServer) - s.conf = conf - s.bot = b - s.token = conf.AccessToken - s.filter = conf.Filter + s := &webSocketServer{ + bot: b, + conf: conf, + token: conf.AccessToken, + filter: conf.Filter, + } addFilter(s.filter) addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port) s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, @@ -83,11 +84,12 @@ func RunWebSocketClient(b *coolq.CQBot, conf *config.WebsocketReverse) { if conf.Disabled { return } - c := new(WebSocketClient) - c.bot = b - c.conf = conf - c.token = conf.AccessToken - c.filter = conf.Filter + c := &websocketClient{ + bot: b, + conf: conf, + token: conf.AccessToken, + filter: conf.Filter, + } addFilter(c.filter) if c.conf.Universal != "" { c.connectUniversal() @@ -102,7 +104,7 @@ func RunWebSocketClient(b *coolq.CQBot, conf *config.WebsocketReverse) { c.bot.OnEventPush(c.onBotPushEvent) } -func (c *WebSocketClient) connectAPI() { +func (c *websocketClient) connectAPI() { log.Infof("开始尝试连接到反向WebSocket API服务器: %v", c.conf.API) header := http.Header{ "X-Client-Role": []string{"API"}, @@ -129,7 +131,7 @@ func (c *WebSocketClient) connectAPI() { go c.listenAPI(wrappedConn, false) } -func (c *WebSocketClient) connectEvent() { +func (c *websocketClient) connectEvent() { log.Infof("开始尝试连接到反向WebSocket Event服务器: %v", c.conf.Event) header := http.Header{ "X-Client-Role": []string{"Event"}, @@ -160,7 +162,7 @@ func (c *WebSocketClient) connectEvent() { c.eventConn = &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)} } -func (c *WebSocketClient) connectUniversal() { +func (c *websocketClient) connectUniversal() { log.Infof("开始尝试连接到反向WebSocket Universal服务器: %v", c.conf.Universal) header := http.Header{ "X-Client-Role": []string{"Universal"}, @@ -194,7 +196,7 @@ func (c *WebSocketClient) connectUniversal() { c.universalConn = wrappedConn } -func (c *WebSocketClient) listenAPI(conn *webSocketConn, u bool) { +func (c *websocketClient) listenAPI(conn *webSocketConn, u bool) { defer func() { _ = conn.Close() }() for { buffer := global.NewBuffer() @@ -225,7 +227,7 @@ func (c *WebSocketClient) listenAPI(conn *webSocketConn, u bool) { } } -func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) { +func (c *websocketClient) onBotPushEvent(m *bytes.Buffer) { filter := findFilter(c.filter) if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) { log.Debugf("上报Event %v 到 WS客户端 时被过滤.", utils.B2S(m.Bytes())) @@ -263,28 +265,11 @@ func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) { } } -func (s *webSocketServer) auth(r *http.Request) (bool, int) { - if s.token != "" { // s.token == s.conf.AccessToken - var auth string - if auth = r.URL.Query().Get("access_token"); auth == "" { - headAuth := strings.SplitN(r.Header.Get("Authorization"), " ", 2) - if len(headAuth) != 2 || headAuth[1] == "" { - return false, 401 - } - auth = headAuth[1] - } - if auth != s.token { - log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败", r.RemoteAddr) - return false, 403 - } - } - return true, 0 -} - func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { - isAuth, errReason := s.auth(r) - if !isAuth { - w.WriteHeader(errReason) + status := checkAuth(r, s.token) + if status != http.StatusOK { + log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status) + w.WriteHeader(status) return } c, err := upgrader.Upgrade(w, r, nil) @@ -309,9 +294,10 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { } func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { - isAuth, errReason := s.auth(r) - if !isAuth { - w.WriteHeader(errReason) + status := checkAuth(r, s.token) + if status != http.StatusOK { + log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status) + w.WriteHeader(status) return } c, err := upgrader.Upgrade(w, r, nil) @@ -328,9 +314,10 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { } func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { - isAuth, errReason := s.auth(r) - if !isAuth { - w.WriteHeader(errReason) + status := checkAuth(r, s.token) + if status != http.StatusOK { + log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status) + w.WriteHeader(status) return } c, err := upgrader.Upgrade(w, r, nil)