From 971a9575ff722fbfcfedb140d4ffd68d36254471 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 11 Apr 2021 01:53:51 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20OneBot=20HTTP/WS=20?= =?UTF-8?q?=E7=8A=B6=E6=80=81=E7=A0=81=E6=A0=87=E5=87=86,=20Closes=20#812?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/http.go | 26 ++++++++++++++++------- server/websocket.go | 52 +++++++++++++++++++++++++-------------------- 2 files changed, 47 insertions(+), 31 deletions(-) diff --git a/server/http.go b/server/http.go index 4b783c4..c6f0283 100644 --- a/server/http.go +++ b/server/http.go @@ -86,16 +86,26 @@ func RunHTTPServerAndClients(bot *coolq.CQBot, conf *config.HTTPServer) { if authToken != "" { s.engine.Use(func(c *gin.Context) { auth := c.Request.Header.Get("Authorization") - switch { - case auth != "": - if strings.SplitN(auth, " ", 2)[1] != authToken { + if auth == "" { + headAuth := c.Query("access_token") + switch { + case headAuth == "": c.AbortWithStatus(401) + return + case headAuth != authToken: + c.AbortWithStatus(403) + return + } + } else { + auth := strings.SplitN(auth, " ", 2) + switch { + case len(auth) != 2 || auth[1] == "": + c.AbortWithStatus(401) + return + case auth[1] != authToken: + c.AbortWithStatus(403) + return } - case c.Query("access_token") != authToken: - c.AbortWithStatus(401) - return - default: - c.Next() } }) } diff --git a/server/websocket.go b/server/websocket.go index 5d4988c..3a019e6 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -264,15 +264,29 @@ func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) { } } -func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { - if s.conf.AccessToken != "" { - if auth := r.URL.Query().Get("access_token"); auth != s.token { - if auth := strings.SplitN(r.Header.Get("Authorization"), " ", 2); len(auth) != 2 || auth[1] != s.token { - log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败", r.RemoteAddr) - w.WriteHeader(401) - return +func (s *webSocketServer) auth(r *http.Request) (bool, int) { + if s.token != "" { // s.token == s.conf.AccessToken + var auth string + if auth = r.URL.Query().Get("access_token"); auth == "" { + headAuth := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(headAuth) != 2 || headAuth[1] == "" { + return false, 401 } + auth = headAuth[1] } + if auth != s.token { + log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败", r.RemoteAddr) + return false, 403 + } + } + return true, 0 +} + +func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { + isAuth, errReason := s.auth(r) + if !isAuth { + w.WriteHeader(errReason) + return } c, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -296,14 +310,10 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { } func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { - if s.token != "" { - if auth := r.URL.Query().Get("access_token"); auth != s.token { - if auth := strings.SplitN(r.Header.Get("Authorization"), " ", 2); len(auth) != 2 || auth[1] != s.token { - log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败", r.RemoteAddr) - w.WriteHeader(401) - return - } - } + isAuth, errReason := s.auth(r) + if !isAuth { + w.WriteHeader(errReason) + return } c, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -319,14 +329,10 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { } func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { - if s.token != "" { - if auth := r.URL.Query().Get("access_token"); auth != s.token { - if auth := strings.SplitN(r.Header.Get("Authorization"), " ", 2); len(auth) != 2 || auth[1] != s.token { - log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败", r.RemoteAddr) - w.WriteHeader(401) - return - } - } + isAuth, errReason := s.auth(r) + if !isAuth { + w.WriteHeader(errReason) + return } c, err := upgrader.Upgrade(w, r, nil) if err != nil {