1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-04 19:17:37 +08:00

refactor(server): unify http & ws-server Authorization

This commit is contained in:
wdvxdr 2021-06-24 20:18:38 +08:00
parent 2bbc75672f
commit 499108cb1b
No known key found for this signature in database
GPG Key ID: 55FF1414A69CEBA6
3 changed files with 51 additions and 67 deletions

View File

@ -32,7 +32,9 @@ var json = jsoniter.ConfigCompatibleWithStandardLibrary
type CQBot struct { type CQBot struct {
Client *client.QQClient Client *client.QQClient
events []func(*bytes.Buffer) events []func(*bytes.Buffer)
mu sync.Mutex
db *leveldb.DB db *leveldb.DB
friendReqCache sync.Map friendReqCache sync.Map
tempSessionCache sync.Map tempSessionCache sync.Map
@ -123,6 +125,8 @@ func NewQQBot(cli *client.QQClient, conf *config.Config) *CQBot {
// OnEventPush 注册事件上报函数 // OnEventPush 注册事件上报函数
func (bot *CQBot) OnEventPush(f func(buf *bytes.Buffer)) { func (bot *CQBot) OnEventPush(f func(buf *bytes.Buffer)) {
bot.mu.Lock()
defer bot.mu.Unlock()
bot.events = append(bot.events, f) bot.events = append(bot.events, f)
} }

View File

@ -46,8 +46,6 @@ type httpCtx struct {
json gjson.Result json gjson.Result
query url.Values query url.Values
postForm url.Values postForm url.Values
headerAuth string
} }
func (h *httpCtx) Get(s string) gjson.Result { func (h *httpCtx) Get(s string) gjson.Result {
@ -71,7 +69,6 @@ func (h *httpCtx) Get(s string) gjson.Result {
func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
var ctx httpCtx var ctx httpCtx
contentType := request.Header.Get("Content-Type") contentType := request.Header.Get("Content-Type")
ctx.headerAuth = request.Header.Get("Authorization")
switch request.Method { switch request.Method {
case http.MethodPost: case http.MethodPost:
if strings.Contains(contentType, "application/json") { if strings.Contains(contentType, "application/json") {
@ -102,7 +99,7 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request
} }
if s.accessToken != "" { if s.accessToken != "" {
if status := checkAuth(ctx, s.accessToken); status != http.StatusOK { if status := checkAuth(request, s.accessToken); status != http.StatusOK {
writer.WriteHeader(status) writer.WriteHeader(status)
return return
} }
@ -118,10 +115,10 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request
_ = json.NewEncoder(writer).Encode(ret) _ = json.NewEncoder(writer).Encode(ret)
} }
func checkAuth(ctx httpCtx, token string) int { func checkAuth(req *http.Request, token string) int {
auth := ctx.headerAuth auth := req.Header.Get("Authorization")
if auth == "" { if auth == "" {
auth = ctx.query.Get("access_token") auth = req.URL.Query().Get("access_token")
} else { } else {
authN := strings.SplitN(auth, " ", 2) authN := strings.SplitN(auth, " ", 2)
if len(authN) == 2 { if len(authN) == 2 {
@ -175,29 +172,25 @@ func RunHTTPServerAndClients(bot *coolq.CQBot, conf *config.HTTPServer) {
client: client:
for _, c := range conf.Post { for _, c := range conf.Post {
if c.URL != "" { if c.URL != "" {
go newHTTPClient().Run(c.URL, c.Secret, conf.Filter, conf.Timeout, bot) go HTTPClient{
bot: bot,
secret: c.Secret,
addr: c.URL,
filter: conf.Filter,
timeout: conf.Timeout,
}.Run()
} }
} }
} }
// newHTTPClient 返回反向HTTP客户端
func newHTTPClient() *HTTPClient {
return &HTTPClient{}
}
// Run 运行反向HTTP服务 // Run 运行反向HTTP服务
func (c *HTTPClient) Run(addr, secret, filter string, timeout int32, bot *coolq.CQBot) { func (c HTTPClient) Run() {
c.bot = bot addFilter(c.filter)
c.secret = secret
c.addr = addr
c.timeout = timeout
c.filter = filter
addFilter(filter)
if c.timeout < 5 { if c.timeout < 5 {
c.timeout = 5 c.timeout = 5
} }
bot.OnEventPush(c.onBotPushEvent) c.bot.OnEventPush(c.onBotPushEvent)
log.Infof("HTTP POST上报器已启动: %v", addr) log.Infof("HTTP POST上报器已启动: %v", c.addr)
} }
func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) { func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) {

View File

@ -31,8 +31,8 @@ type webSocketServer struct {
filter string filter string
} }
// WebSocketClient WebSocket客户端实例 // websocketClient WebSocket客户端实例
type WebSocketClient struct { type websocketClient struct {
bot *coolq.CQBot bot *coolq.CQBot
conf *config.WebsocketReverse conf *config.WebsocketReverse
@ -59,11 +59,12 @@ func RunWebSocketServer(b *coolq.CQBot, conf *config.WebsocketServer) {
if conf.Disabled { if conf.Disabled {
return return
} }
s := new(webSocketServer) s := &webSocketServer{
s.conf = conf bot: b,
s.bot = b conf: conf,
s.token = conf.AccessToken token: conf.AccessToken,
s.filter = conf.Filter filter: conf.Filter,
}
addFilter(s.filter) addFilter(s.filter)
addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port) addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port)
s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`,
@ -83,11 +84,12 @@ func RunWebSocketClient(b *coolq.CQBot, conf *config.WebsocketReverse) {
if conf.Disabled { if conf.Disabled {
return return
} }
c := new(WebSocketClient) c := &websocketClient{
c.bot = b bot: b,
c.conf = conf conf: conf,
c.token = conf.AccessToken token: conf.AccessToken,
c.filter = conf.Filter filter: conf.Filter,
}
addFilter(c.filter) addFilter(c.filter)
if c.conf.Universal != "" { if c.conf.Universal != "" {
c.connectUniversal() c.connectUniversal()
@ -102,7 +104,7 @@ func RunWebSocketClient(b *coolq.CQBot, conf *config.WebsocketReverse) {
c.bot.OnEventPush(c.onBotPushEvent) c.bot.OnEventPush(c.onBotPushEvent)
} }
func (c *WebSocketClient) connectAPI() { func (c *websocketClient) connectAPI() {
log.Infof("开始尝试连接到反向WebSocket API服务器: %v", c.conf.API) log.Infof("开始尝试连接到反向WebSocket API服务器: %v", c.conf.API)
header := http.Header{ header := http.Header{
"X-Client-Role": []string{"API"}, "X-Client-Role": []string{"API"},
@ -129,7 +131,7 @@ func (c *WebSocketClient) connectAPI() {
go c.listenAPI(wrappedConn, false) go c.listenAPI(wrappedConn, false)
} }
func (c *WebSocketClient) connectEvent() { func (c *websocketClient) connectEvent() {
log.Infof("开始尝试连接到反向WebSocket Event服务器: %v", c.conf.Event) log.Infof("开始尝试连接到反向WebSocket Event服务器: %v", c.conf.Event)
header := http.Header{ header := http.Header{
"X-Client-Role": []string{"Event"}, "X-Client-Role": []string{"Event"},
@ -160,7 +162,7 @@ func (c *WebSocketClient) connectEvent() {
c.eventConn = &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)} c.eventConn = &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)}
} }
func (c *WebSocketClient) connectUniversal() { func (c *websocketClient) connectUniversal() {
log.Infof("开始尝试连接到反向WebSocket Universal服务器: %v", c.conf.Universal) log.Infof("开始尝试连接到反向WebSocket Universal服务器: %v", c.conf.Universal)
header := http.Header{ header := http.Header{
"X-Client-Role": []string{"Universal"}, "X-Client-Role": []string{"Universal"},
@ -194,7 +196,7 @@ func (c *WebSocketClient) connectUniversal() {
c.universalConn = wrappedConn c.universalConn = wrappedConn
} }
func (c *WebSocketClient) listenAPI(conn *webSocketConn, u bool) { func (c *websocketClient) listenAPI(conn *webSocketConn, u bool) {
defer func() { _ = conn.Close() }() defer func() { _ = conn.Close() }()
for { for {
buffer := global.NewBuffer() buffer := global.NewBuffer()
@ -225,7 +227,7 @@ func (c *WebSocketClient) listenAPI(conn *webSocketConn, u bool) {
} }
} }
func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) { func (c *websocketClient) onBotPushEvent(m *bytes.Buffer) {
filter := findFilter(c.filter) filter := findFilter(c.filter)
if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) { if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) {
log.Debugf("上报Event %v 到 WS客户端 时被过滤.", utils.B2S(m.Bytes())) log.Debugf("上报Event %v 到 WS客户端 时被过滤.", utils.B2S(m.Bytes()))
@ -263,28 +265,11 @@ func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) {
} }
} }
func (s *webSocketServer) auth(r *http.Request) (bool, int) {
if s.token != "" { // s.token == s.conf.AccessToken
var auth string
if auth = r.URL.Query().Get("access_token"); auth == "" {
headAuth := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(headAuth) != 2 || headAuth[1] == "" {
return false, 401
}
auth = headAuth[1]
}
if auth != s.token {
log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败", r.RemoteAddr)
return false, 403
}
}
return true, 0
}
func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {
isAuth, errReason := s.auth(r) status := checkAuth(r, s.token)
if !isAuth { if status != http.StatusOK {
w.WriteHeader(errReason) log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status)
w.WriteHeader(status)
return return
} }
c, err := upgrader.Upgrade(w, r, nil) c, err := upgrader.Upgrade(w, r, nil)
@ -309,9 +294,10 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {
} }
func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) {
isAuth, errReason := s.auth(r) status := checkAuth(r, s.token)
if !isAuth { if status != http.StatusOK {
w.WriteHeader(errReason) log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status)
w.WriteHeader(status)
return return
} }
c, err := upgrader.Upgrade(w, r, nil) c, err := upgrader.Upgrade(w, r, nil)
@ -328,9 +314,10 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) {
} }
func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) {
isAuth, errReason := s.auth(r) status := checkAuth(r, s.token)
if !isAuth { if status != http.StatusOK {
w.WriteHeader(errReason) log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status)
w.WriteHeader(status)
return return
} }
c, err := upgrader.Upgrade(w, r, nil) c, err := upgrader.Upgrade(w, r, nil)