From 9c65ff4dcd86e1d67dd3f8d3b7a36761c67d7bd1 Mon Sep 17 00:00:00 2001 From: Ink-33 Date: Sat, 29 Aug 2020 17:24:09 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=9D=E8=AF=95=E4=BF=AE=E5=A4=8D=E6=AD=A3?= =?UTF-8?q?=E5=90=91WebSocket=E4=B8=8B=E7=9A=84=E9=89=B4=E6=9D=83panic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/websocket.go | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/server/websocket.go b/server/websocket.go index 7f0cae6..5887d8a 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -206,10 +206,16 @@ func (c *websocketClient) onBotPushEvent(m coolq.MSG) { func (s *websocketServer) event(w http.ResponseWriter, r *http.Request) { if s.token != "" { - if r.URL.Query().Get("access_token") != s.token && strings.SplitN(r.Header.Get("Authorization"), " ", 2)[1] != s.token { + if r.URL.Query().Get("access_token") != s.token { log.Warnf("已拒绝 %v 的 Websocket 请求: Token错误", r.RemoteAddr) w.WriteHeader(401) return + } else if auth := r.Header.Get("Authorization"); auth != "" { + if strings.SplitN(auth, " ", 2)[1] != s.token { + log.Warnf("已拒绝 %v 的 Websocket 请求: Token错误", r.RemoteAddr) + w.WriteHeader(401) + return + } } } c, err := upgrader.Upgrade(w, r, nil) @@ -235,10 +241,16 @@ func (s *websocketServer) event(w http.ResponseWriter, r *http.Request) { func (s *websocketServer) api(w http.ResponseWriter, r *http.Request) { if s.token != "" { - if r.URL.Query().Get("access_token") != s.token && strings.SplitN(r.Header.Get("Authorization"), " ", 2)[1] != s.token { + if r.URL.Query().Get("access_token") != s.token { log.Warnf("已拒绝 %v 的 Websocket 请求: Token错误", r.RemoteAddr) w.WriteHeader(401) return + } else if auth := r.Header.Get("Authorization"); auth != "" { + if strings.SplitN(auth, " ", 2)[1] != s.token { + log.Warnf("已拒绝 %v 的 Websocket 请求: Token错误", r.RemoteAddr) + w.WriteHeader(401) + return + } } } c, err := upgrader.Upgrade(w, r, nil) @@ -253,10 +265,16 @@ func (s *websocketServer) api(w http.ResponseWriter, r *http.Request) { func (s *websocketServer) any(w http.ResponseWriter, r *http.Request) { if s.token != "" { - if r.URL.Query().Get("access_token") != s.token && strings.SplitN(r.Header.Get("Authorization"), " ", 2)[1] != s.token { + if r.URL.Query().Get("access_token") != s.token { log.Warnf("已拒绝 %v 的 Websocket 请求: Token错误", r.RemoteAddr) w.WriteHeader(401) return + } else if auth := r.Header.Get("Authorization"); auth != "" { + if strings.SplitN(auth, " ", 2)[1] != s.token { + log.Warnf("已拒绝 %v 的 Websocket 请求: Token错误", r.RemoteAddr) + w.WriteHeader(401) + return + } } } c, err := upgrader.Upgrade(w, r, nil)