mirror of
https://github.com/Mrs4s/go-cqhttp.git
synced 2025-05-05 03:23:49 +08:00
refactor(server): unify http & ws-server Authorization
This commit is contained in:
parent
2bbc75672f
commit
499108cb1b
@ -33,6 +33,8 @@ type CQBot struct {
|
|||||||
Client *client.QQClient
|
Client *client.QQClient
|
||||||
|
|
||||||
events []func(*bytes.Buffer)
|
events []func(*bytes.Buffer)
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
db *leveldb.DB
|
db *leveldb.DB
|
||||||
friendReqCache sync.Map
|
friendReqCache sync.Map
|
||||||
tempSessionCache sync.Map
|
tempSessionCache sync.Map
|
||||||
@ -123,6 +125,8 @@ func NewQQBot(cli *client.QQClient, conf *config.Config) *CQBot {
|
|||||||
|
|
||||||
// OnEventPush 注册事件上报函数
|
// OnEventPush 注册事件上报函数
|
||||||
func (bot *CQBot) OnEventPush(f func(buf *bytes.Buffer)) {
|
func (bot *CQBot) OnEventPush(f func(buf *bytes.Buffer)) {
|
||||||
|
bot.mu.Lock()
|
||||||
|
defer bot.mu.Unlock()
|
||||||
bot.events = append(bot.events, f)
|
bot.events = append(bot.events, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,8 +46,6 @@ type httpCtx struct {
|
|||||||
json gjson.Result
|
json gjson.Result
|
||||||
query url.Values
|
query url.Values
|
||||||
postForm url.Values
|
postForm url.Values
|
||||||
|
|
||||||
headerAuth string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpCtx) Get(s string) gjson.Result {
|
func (h *httpCtx) Get(s string) gjson.Result {
|
||||||
@ -71,7 +69,6 @@ func (h *httpCtx) Get(s string) gjson.Result {
|
|||||||
func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||||
var ctx httpCtx
|
var ctx httpCtx
|
||||||
contentType := request.Header.Get("Content-Type")
|
contentType := request.Header.Get("Content-Type")
|
||||||
ctx.headerAuth = request.Header.Get("Authorization")
|
|
||||||
switch request.Method {
|
switch request.Method {
|
||||||
case http.MethodPost:
|
case http.MethodPost:
|
||||||
if strings.Contains(contentType, "application/json") {
|
if strings.Contains(contentType, "application/json") {
|
||||||
@ -102,7 +99,7 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.accessToken != "" {
|
if s.accessToken != "" {
|
||||||
if status := checkAuth(ctx, s.accessToken); status != http.StatusOK {
|
if status := checkAuth(request, s.accessToken); status != http.StatusOK {
|
||||||
writer.WriteHeader(status)
|
writer.WriteHeader(status)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -118,10 +115,10 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request
|
|||||||
_ = json.NewEncoder(writer).Encode(ret)
|
_ = json.NewEncoder(writer).Encode(ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkAuth(ctx httpCtx, token string) int {
|
func checkAuth(req *http.Request, token string) int {
|
||||||
auth := ctx.headerAuth
|
auth := req.Header.Get("Authorization")
|
||||||
if auth == "" {
|
if auth == "" {
|
||||||
auth = ctx.query.Get("access_token")
|
auth = req.URL.Query().Get("access_token")
|
||||||
} else {
|
} else {
|
||||||
authN := strings.SplitN(auth, " ", 2)
|
authN := strings.SplitN(auth, " ", 2)
|
||||||
if len(authN) == 2 {
|
if len(authN) == 2 {
|
||||||
@ -175,29 +172,25 @@ func RunHTTPServerAndClients(bot *coolq.CQBot, conf *config.HTTPServer) {
|
|||||||
client:
|
client:
|
||||||
for _, c := range conf.Post {
|
for _, c := range conf.Post {
|
||||||
if c.URL != "" {
|
if c.URL != "" {
|
||||||
go newHTTPClient().Run(c.URL, c.Secret, conf.Filter, conf.Timeout, bot)
|
go HTTPClient{
|
||||||
|
bot: bot,
|
||||||
|
secret: c.Secret,
|
||||||
|
addr: c.URL,
|
||||||
|
filter: conf.Filter,
|
||||||
|
timeout: conf.Timeout,
|
||||||
|
}.Run()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// newHTTPClient 返回反向HTTP客户端
|
|
||||||
func newHTTPClient() *HTTPClient {
|
|
||||||
return &HTTPClient{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run 运行反向HTTP服务
|
// Run 运行反向HTTP服务
|
||||||
func (c *HTTPClient) Run(addr, secret, filter string, timeout int32, bot *coolq.CQBot) {
|
func (c HTTPClient) Run() {
|
||||||
c.bot = bot
|
addFilter(c.filter)
|
||||||
c.secret = secret
|
|
||||||
c.addr = addr
|
|
||||||
c.timeout = timeout
|
|
||||||
c.filter = filter
|
|
||||||
addFilter(filter)
|
|
||||||
if c.timeout < 5 {
|
if c.timeout < 5 {
|
||||||
c.timeout = 5
|
c.timeout = 5
|
||||||
}
|
}
|
||||||
bot.OnEventPush(c.onBotPushEvent)
|
c.bot.OnEventPush(c.onBotPushEvent)
|
||||||
log.Infof("HTTP POST上报器已启动: %v", addr)
|
log.Infof("HTTP POST上报器已启动: %v", c.addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) {
|
func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) {
|
||||||
|
@ -31,8 +31,8 @@ type webSocketServer struct {
|
|||||||
filter string
|
filter string
|
||||||
}
|
}
|
||||||
|
|
||||||
// WebSocketClient WebSocket客户端实例
|
// websocketClient WebSocket客户端实例
|
||||||
type WebSocketClient struct {
|
type websocketClient struct {
|
||||||
bot *coolq.CQBot
|
bot *coolq.CQBot
|
||||||
conf *config.WebsocketReverse
|
conf *config.WebsocketReverse
|
||||||
|
|
||||||
@ -59,11 +59,12 @@ func RunWebSocketServer(b *coolq.CQBot, conf *config.WebsocketServer) {
|
|||||||
if conf.Disabled {
|
if conf.Disabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s := new(webSocketServer)
|
s := &webSocketServer{
|
||||||
s.conf = conf
|
bot: b,
|
||||||
s.bot = b
|
conf: conf,
|
||||||
s.token = conf.AccessToken
|
token: conf.AccessToken,
|
||||||
s.filter = conf.Filter
|
filter: conf.Filter,
|
||||||
|
}
|
||||||
addFilter(s.filter)
|
addFilter(s.filter)
|
||||||
addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port)
|
addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port)
|
||||||
s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`,
|
s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`,
|
||||||
@ -83,11 +84,12 @@ func RunWebSocketClient(b *coolq.CQBot, conf *config.WebsocketReverse) {
|
|||||||
if conf.Disabled {
|
if conf.Disabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c := new(WebSocketClient)
|
c := &websocketClient{
|
||||||
c.bot = b
|
bot: b,
|
||||||
c.conf = conf
|
conf: conf,
|
||||||
c.token = conf.AccessToken
|
token: conf.AccessToken,
|
||||||
c.filter = conf.Filter
|
filter: conf.Filter,
|
||||||
|
}
|
||||||
addFilter(c.filter)
|
addFilter(c.filter)
|
||||||
if c.conf.Universal != "" {
|
if c.conf.Universal != "" {
|
||||||
c.connectUniversal()
|
c.connectUniversal()
|
||||||
@ -102,7 +104,7 @@ func RunWebSocketClient(b *coolq.CQBot, conf *config.WebsocketReverse) {
|
|||||||
c.bot.OnEventPush(c.onBotPushEvent)
|
c.bot.OnEventPush(c.onBotPushEvent)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WebSocketClient) connectAPI() {
|
func (c *websocketClient) connectAPI() {
|
||||||
log.Infof("开始尝试连接到反向WebSocket API服务器: %v", c.conf.API)
|
log.Infof("开始尝试连接到反向WebSocket API服务器: %v", c.conf.API)
|
||||||
header := http.Header{
|
header := http.Header{
|
||||||
"X-Client-Role": []string{"API"},
|
"X-Client-Role": []string{"API"},
|
||||||
@ -129,7 +131,7 @@ func (c *WebSocketClient) connectAPI() {
|
|||||||
go c.listenAPI(wrappedConn, false)
|
go c.listenAPI(wrappedConn, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WebSocketClient) connectEvent() {
|
func (c *websocketClient) connectEvent() {
|
||||||
log.Infof("开始尝试连接到反向WebSocket Event服务器: %v", c.conf.Event)
|
log.Infof("开始尝试连接到反向WebSocket Event服务器: %v", c.conf.Event)
|
||||||
header := http.Header{
|
header := http.Header{
|
||||||
"X-Client-Role": []string{"Event"},
|
"X-Client-Role": []string{"Event"},
|
||||||
@ -160,7 +162,7 @@ func (c *WebSocketClient) connectEvent() {
|
|||||||
c.eventConn = &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)}
|
c.eventConn = &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WebSocketClient) connectUniversal() {
|
func (c *websocketClient) connectUniversal() {
|
||||||
log.Infof("开始尝试连接到反向WebSocket Universal服务器: %v", c.conf.Universal)
|
log.Infof("开始尝试连接到反向WebSocket Universal服务器: %v", c.conf.Universal)
|
||||||
header := http.Header{
|
header := http.Header{
|
||||||
"X-Client-Role": []string{"Universal"},
|
"X-Client-Role": []string{"Universal"},
|
||||||
@ -194,7 +196,7 @@ func (c *WebSocketClient) connectUniversal() {
|
|||||||
c.universalConn = wrappedConn
|
c.universalConn = wrappedConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WebSocketClient) listenAPI(conn *webSocketConn, u bool) {
|
func (c *websocketClient) listenAPI(conn *webSocketConn, u bool) {
|
||||||
defer func() { _ = conn.Close() }()
|
defer func() { _ = conn.Close() }()
|
||||||
for {
|
for {
|
||||||
buffer := global.NewBuffer()
|
buffer := global.NewBuffer()
|
||||||
@ -225,7 +227,7 @@ func (c *WebSocketClient) listenAPI(conn *webSocketConn, u bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) {
|
func (c *websocketClient) onBotPushEvent(m *bytes.Buffer) {
|
||||||
filter := findFilter(c.filter)
|
filter := findFilter(c.filter)
|
||||||
if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) {
|
if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) {
|
||||||
log.Debugf("上报Event %v 到 WS客户端 时被过滤.", utils.B2S(m.Bytes()))
|
log.Debugf("上报Event %v 到 WS客户端 时被过滤.", utils.B2S(m.Bytes()))
|
||||||
@ -263,28 +265,11 @@ func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *webSocketServer) auth(r *http.Request) (bool, int) {
|
|
||||||
if s.token != "" { // s.token == s.conf.AccessToken
|
|
||||||
var auth string
|
|
||||||
if auth = r.URL.Query().Get("access_token"); auth == "" {
|
|
||||||
headAuth := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
|
|
||||||
if len(headAuth) != 2 || headAuth[1] == "" {
|
|
||||||
return false, 401
|
|
||||||
}
|
|
||||||
auth = headAuth[1]
|
|
||||||
}
|
|
||||||
if auth != s.token {
|
|
||||||
log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败", r.RemoteAddr)
|
|
||||||
return false, 403
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {
|
func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {
|
||||||
isAuth, errReason := s.auth(r)
|
status := checkAuth(r, s.token)
|
||||||
if !isAuth {
|
if status != http.StatusOK {
|
||||||
w.WriteHeader(errReason)
|
log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status)
|
||||||
|
w.WriteHeader(status)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c, err := upgrader.Upgrade(w, r, nil)
|
c, err := upgrader.Upgrade(w, r, nil)
|
||||||
@ -309,9 +294,10 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) {
|
func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) {
|
||||||
isAuth, errReason := s.auth(r)
|
status := checkAuth(r, s.token)
|
||||||
if !isAuth {
|
if status != http.StatusOK {
|
||||||
w.WriteHeader(errReason)
|
log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status)
|
||||||
|
w.WriteHeader(status)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c, err := upgrader.Upgrade(w, r, nil)
|
c, err := upgrader.Upgrade(w, r, nil)
|
||||||
@ -328,9 +314,10 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) {
|
func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) {
|
||||||
isAuth, errReason := s.auth(r)
|
status := checkAuth(r, s.token)
|
||||||
if !isAuth {
|
if status != http.StatusOK {
|
||||||
w.WriteHeader(errReason)
|
log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败(code:%d)", r.RemoteAddr, status)
|
||||||
|
w.WriteHeader(status)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c, err := upgrader.Upgrade(w, r, nil)
|
c, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user