From 208563d4c990a69bd86868680a157511b3ddcc25 Mon Sep 17 00:00:00 2001 From: wdvxdr Date: Wed, 27 Oct 2021 11:40:01 +0800 Subject: [PATCH] refactor: move api.go to modules/api, filter to modules/filter maybe do code generate for api route --- {server => modules/api}/api.go | 162 ++++++++++++++------------- {global => modules/filter}/filter.go | 9 +- modules/filter/middlewares.go | 45 ++++++++ server/http.go | 18 +-- server/middlewares.go | 46 +------- server/scf.go | 5 +- server/websocket.go | 37 +++--- 7 files changed, 172 insertions(+), 150 deletions(-) rename {server => modules/api}/api.go (69%) rename {global => modules/filter}/filter.go (96%) create mode 100644 modules/filter/middlewares.go diff --git a/server/api.go b/modules/api/api.go similarity index 69% rename from server/api.go rename to modules/api/api.go index 126a76e..98f022d 100644 --- a/server/api.go +++ b/modules/api/api.go @@ -1,69 +1,74 @@ -package server +// Package api implements the API route for servers. +package api import ( "strings" + "github.com/tidwall/gjson" + "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/param" - - "github.com/tidwall/gjson" + "github.com/Mrs4s/go-cqhttp/modules/filter" ) -type resultGetter interface { +// Getter 参数获取 +type Getter interface { Get(string) gjson.Result } -type handler func(action string, p resultGetter) global.MSG +// Handler 中间件 +type Handler func(action string, p Getter) global.MSG -type apiCaller struct { +// Caller api route caller +type Caller struct { bot *coolq.CQBot - handlers []handler + handlers []Handler } -func getLoginInfo(bot *coolq.CQBot, _ resultGetter) global.MSG { +func getLoginInfo(bot *coolq.CQBot, _ Getter) global.MSG { return bot.CQGetLoginInfo() } -func getQiDianAccountInfo(bot *coolq.CQBot, _ resultGetter) global.MSG { +func getQiDianAccountInfo(bot *coolq.CQBot, _ Getter) global.MSG { return bot.CQGetQiDianAccountInfo() } -func getFriendList(bot *coolq.CQBot, _ resultGetter) global.MSG { +func getFriendList(bot *coolq.CQBot, _ Getter) global.MSG { return bot.CQGetFriendList() } -func getUnidirectionalFriendList(bot *coolq.CQBot, _ resultGetter) global.MSG { +func getUnidirectionalFriendList(bot *coolq.CQBot, _ Getter) global.MSG { return bot.CQGetUnidirectionalFriendList() } -func deleteFriend(bot *coolq.CQBot, p resultGetter) global.MSG { +func deleteFriend(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQDeleteFriend(p.Get("[user_id,id].0").Int()) } -func deleteUnidirectionalFriend(bot *coolq.CQBot, p resultGetter) global.MSG { +func deleteUnidirectionalFriend(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQDeleteUnidirectionalFriend(p.Get("user_id").Int()) } -func getGroupList(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupList(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupList(p.Get("no_cache").Bool()) } -func getGroupInfo(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupInfo(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupInfo(p.Get("group_id").Int(), p.Get("no_cache").Bool()) } -func getGroupMemberList(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupMemberList(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupMemberList(p.Get("group_id").Int(), p.Get("no_cache").Bool()) } -func getGroupMemberInfo(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupMemberInfo(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupMemberInfo( p.Get("group_id").Int(), p.Get("user_id").Int(), p.Get("no_cache").Bool(), ) } -func sendMSG(bot *coolq.CQBot, p resultGetter) global.MSG { +func sendMSG(bot *coolq.CQBot, p Getter) global.MSG { autoEscape := param.EnsureBool(p.Get("auto_escape"), false) if p.Get("message_type").Str == "private" { return bot.CQSendPrivateMessage(p.Get("user_id").Int(), p.Get("group_id").Int(), p.Get("message"), autoEscape) @@ -80,25 +85,25 @@ func sendMSG(bot *coolq.CQBot, p resultGetter) global.MSG { return global.MSG{} } -func sendGroupMSG(bot *coolq.CQBot, p resultGetter) global.MSG { +func sendGroupMSG(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSendGroupMessage(p.Get("group_id").Int(), p.Get("message"), param.EnsureBool(p.Get("auto_escape"), false)) } -func sendGroupForwardMSG(bot *coolq.CQBot, p resultGetter) global.MSG { +func sendGroupForwardMSG(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSendGroupForwardMessage(p.Get("group_id").Int(), p.Get("messages")) } -func sendPrivateMSG(bot *coolq.CQBot, p resultGetter) global.MSG { +func sendPrivateMSG(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSendPrivateMessage(p.Get("user_id").Int(), p.Get("group_id").Int(), p.Get("message"), param.EnsureBool(p.Get("auto_escape"), false)) } -func deleteMSG(bot *coolq.CQBot, p resultGetter) global.MSG { +func deleteMSG(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQDeleteMessage(int32(p.Get("message_id").Int())) } -func setFriendAddRequest(bot *coolq.CQBot, p resultGetter) global.MSG { +func setFriendAddRequest(bot *coolq.CQBot, p Getter) global.MSG { apr := true if p.Get("approve").Exists() { apr = p.Get("approve").Bool() @@ -106,7 +111,7 @@ func setFriendAddRequest(bot *coolq.CQBot, p resultGetter) global.MSG { return bot.CQProcessFriendRequest(p.Get("flag").Str, apr) } -func setGroupAddRequest(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupAddRequest(bot *coolq.CQBot, p Getter) global.MSG { subType := p.Get("sub_type").Str apr := true if subType == "" { @@ -118,19 +123,19 @@ func setGroupAddRequest(bot *coolq.CQBot, p resultGetter) global.MSG { return bot.CQProcessGroupRequest(p.Get("flag").Str, subType, p.Get("reason").Str, apr) } -func setGroupCard(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupCard(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupCard(p.Get("group_id").Int(), p.Get("user_id").Int(), p.Get("card").Str) } -func setGroupSpecialTitle(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupSpecialTitle(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupSpecialTitle(p.Get("group_id").Int(), p.Get("user_id").Int(), p.Get("special_title").Str) } -func setGroupKick(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupKick(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupKick(p.Get("group_id").Int(), p.Get("user_id").Int(), p.Get("message").Str, p.Get("reject_add_request").Bool()) } -func setGroupBan(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupBan(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupBan(p.Get("group_id").Int(), p.Get("user_id").Int(), func() uint32 { if p.Get("duration").Exists() { return uint32(p.Get("duration").Int()) @@ -139,7 +144,7 @@ func setGroupBan(bot *coolq.CQBot, p resultGetter) global.MSG { }()) } -func setGroupWholeBan(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupWholeBan(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupWholeBan(p.Get("group_id").Int(), func() bool { if p.Get("enable").Exists() { return p.Get("enable").Bool() @@ -148,11 +153,11 @@ func setGroupWholeBan(bot *coolq.CQBot, p resultGetter) global.MSG { }()) } -func setGroupName(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupName(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupName(p.Get("group_id").Int(), p.Get("group_name").Str) } -func setGroupAdmin(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupAdmin(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupAdmin(p.Get("group_id").Int(), p.Get("user_id").Int(), func() bool { if p.Get("enable").Exists() { return p.Get("enable").Bool() @@ -161,19 +166,19 @@ func setGroupAdmin(bot *coolq.CQBot, p resultGetter) global.MSG { }()) } -func sendGroupNotice(bot *coolq.CQBot, p resultGetter) global.MSG { +func sendGroupNotice(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupMemo(p.Get("group_id").Int(), p.Get("content").Str, p.Get("image").String()) } -func setGroupLeave(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupLeave(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupLeave(p.Get("group_id").Int()) } -func getImage(bot *coolq.CQBot, p resultGetter) global.MSG { +func getImage(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetImage(p.Get("file").Str) } -func getForwardMSG(bot *coolq.CQBot, p resultGetter) global.MSG { +func getForwardMSG(bot *coolq.CQBot, p Getter) global.MSG { id := p.Get("message_id").Str if id == "" { id = p.Get("id").Str @@ -181,11 +186,11 @@ func getForwardMSG(bot *coolq.CQBot, p resultGetter) global.MSG { return bot.CQGetForwardMessage(id) } -func getMSG(bot *coolq.CQBot, p resultGetter) global.MSG { +func getMSG(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetMessage(int32(p.Get("message_id").Int())) } -func downloadFile(bot *coolq.CQBot, p resultGetter) global.MSG { +func downloadFile(bot *coolq.CQBot, p Getter) global.MSG { headers := map[string]string{} headersToken := p.Get("headers") if headersToken.IsArray() { @@ -208,11 +213,11 @@ func downloadFile(bot *coolq.CQBot, p resultGetter) global.MSG { return bot.CQDownloadFile(p.Get("url").Str, headers, int(p.Get("thread_count").Int())) } -func getGroupHonorInfo(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupHonorInfo(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupHonorInfo(p.Get("group_id").Int(), p.Get("type").Str) } -func setRestart(_ *coolq.CQBot, _ resultGetter) global.MSG { +func setRestart(_ *coolq.CQBot, _ Getter) global.MSG { /* var delay int64 delay = p.Get("delay").Int() @@ -227,112 +232,112 @@ func setRestart(_ *coolq.CQBot, _ resultGetter) global.MSG { return global.MSG{"data": nil, "retcode": 99, "msg": "restart un-supported now", "wording": "restart函数暂不兼容", "status": "failed"} } -func canSendImage(bot *coolq.CQBot, _ resultGetter) global.MSG { +func canSendImage(bot *coolq.CQBot, _ Getter) global.MSG { return bot.CQCanSendImage() } -func canSendRecord(bot *coolq.CQBot, _ resultGetter) global.MSG { +func canSendRecord(bot *coolq.CQBot, _ Getter) global.MSG { return bot.CQCanSendRecord() } -func getStrangerInfo(bot *coolq.CQBot, p resultGetter) global.MSG { +func getStrangerInfo(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetStrangerInfo(p.Get("user_id").Int()) } -func getStatus(bot *coolq.CQBot, _ resultGetter) global.MSG { +func getStatus(bot *coolq.CQBot, _ Getter) global.MSG { return bot.CQGetStatus() } -func getVersionInfo(bot *coolq.CQBot, _ resultGetter) global.MSG { +func getVersionInfo(bot *coolq.CQBot, _ Getter) global.MSG { return bot.CQGetVersionInfo() } -func getGroupSystemMSG(bot *coolq.CQBot, _ resultGetter) global.MSG { +func getGroupSystemMSG(bot *coolq.CQBot, _ Getter) global.MSG { return bot.CQGetGroupSystemMessages() } -func getGroupFileSystemInfo(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupFileSystemInfo(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupFileSystemInfo(p.Get("group_id").Int()) } -func getGroupRootFiles(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupRootFiles(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupRootFiles(p.Get("group_id").Int()) } -func getGroupFilesByFolder(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupFilesByFolder(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupFilesByFolderID(p.Get("group_id").Int(), p.Get("folder_id").Str) } -func getGroupFileURL(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupFileURL(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupFileURL(p.Get("group_id").Int(), p.Get("file_id").Str, int32(p.Get("busid").Int())) } -func uploadGroupFile(bot *coolq.CQBot, p resultGetter) global.MSG { +func uploadGroupFile(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQUploadGroupFile(p.Get("group_id").Int(), p.Get("file").Str, p.Get("name").Str, p.Get("folder").Str) } -func groupFileCreateFolder(bot *coolq.CQBot, p resultGetter) global.MSG { +func groupFileCreateFolder(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGroupFileCreateFolder(p.Get("group_id").Int(), p.Get("folder_id").Str, p.Get("name").Str) } -func deleteGroupFolder(bot *coolq.CQBot, p resultGetter) global.MSG { +func deleteGroupFolder(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGroupFileDeleteFolder(p.Get("group_id").Int(), p.Get("folder_id").Str) } -func deleteGroupFile(bot *coolq.CQBot, p resultGetter) global.MSG { +func deleteGroupFile(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGroupFileDeleteFile(p.Get("group_id").Int(), p.Get("file_id").Str, int32(p.Get("bus_id").Int())) } -func getGroupMsgHistory(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupMsgHistory(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetGroupMessageHistory(p.Get("group_id").Int(), p.Get("message_seq").Int()) } -func getVipInfo(bot *coolq.CQBot, p resultGetter) global.MSG { +func getVipInfo(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetVipInfo(p.Get("user_id").Int()) } -func reloadEventFilter(_ *coolq.CQBot, p resultGetter) global.MSG { - addFilter(p.Get("file").String()) +func reloadEventFilter(_ *coolq.CQBot, p Getter) global.MSG { + filter.Add(p.Get("file").String()) return coolq.OK(nil) } -func getGroupAtAllRemain(bot *coolq.CQBot, p resultGetter) global.MSG { +func getGroupAtAllRemain(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetAtAllRemain(p.Get("group_id").Int()) } -func ocrImage(bot *coolq.CQBot, p resultGetter) global.MSG { +func ocrImage(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQOcrImage(p.Get("image").Str) } -func getOnlineClients(bot *coolq.CQBot, p resultGetter) global.MSG { +func getOnlineClients(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetOnlineClients(p.Get("no_cache").Bool()) } -func getWordSlices(bot *coolq.CQBot, p resultGetter) global.MSG { +func getWordSlices(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetWordSlices(p.Get("content").Str) } -func setGroupPortrait(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupPortrait(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetGroupPortrait(p.Get("group_id").Int(), p.Get("file").String(), p.Get("cache").String()) } -func setEssenceMSG(bot *coolq.CQBot, p resultGetter) global.MSG { +func setEssenceMSG(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetEssenceMessage(int32(p.Get("message_id").Int())) } -func deleteEssenceMSG(bot *coolq.CQBot, p resultGetter) global.MSG { +func deleteEssenceMSG(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQDeleteEssenceMessage(int32(p.Get("message_id").Int())) } -func getEssenceMsgList(bot *coolq.CQBot, p resultGetter) global.MSG { +func getEssenceMsgList(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetEssenceMessageList(p.Get("group_id").Int()) } -func checkURLSafely(bot *coolq.CQBot, p resultGetter) global.MSG { +func checkURLSafely(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQCheckURLSafely(p.Get("url").String()) } -func setGroupAnonymousBan(bot *coolq.CQBot, p resultGetter) global.MSG { +func setGroupAnonymousBan(bot *coolq.CQBot, p Getter) global.MSG { obj := p.Get("anonymous") flag := p.Get("anonymous_flag") if !flag.Exists() { @@ -347,24 +352,24 @@ func setGroupAnonymousBan(bot *coolq.CQBot, p resultGetter) global.MSG { return bot.CQSetGroupAnonymousBan(p.Get("group_id").Int(), flag.String(), int32(p.Get("duration").Int())) } -func handleQuickOperation(bot *coolq.CQBot, p resultGetter) global.MSG { +func handleQuickOperation(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQHandleQuickOperation(p.Get("context"), p.Get("operation")) } -func getModelShow(bot *coolq.CQBot, p resultGetter) global.MSG { +func getModelShow(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQGetModelShow(p.Get("model").String()) } -func setModelShow(bot *coolq.CQBot, p resultGetter) global.MSG { +func setModelShow(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQSetModelShow(p.Get("model").String(), p.Get("model_show").String()) } -func markMSGAsRead(bot *coolq.CQBot, p resultGetter) global.MSG { +func markMSGAsRead(bot *coolq.CQBot, p Getter) global.MSG { return bot.CQMarkMessageAsRead(int32(p.Get("message_id").Int())) } // API 是go-cqhttp当前支持的所有api的映射表 -var API = map[string]func(*coolq.CQBot, resultGetter) global.MSG{ +var API = map[string]func(*coolq.CQBot, Getter) global.MSG{ "get_login_info": getLoginInfo, "get_friend_list": getFriendList, "get_unidirectional_friend_list": getUnidirectionalFriendList, @@ -431,7 +436,8 @@ var API = map[string]func(*coolq.CQBot, resultGetter) global.MSG{ "mark_msg_as_read": markMSGAsRead, } -func (api *apiCaller) callAPI(action string, p resultGetter) global.MSG { +// Call specific API +func (api *Caller) Call(action string, p Getter) global.MSG { for _, fn := range api.handlers { if ret := fn(action, p); ret != nil { return ret @@ -443,13 +449,15 @@ func (api *apiCaller) callAPI(action string, p resultGetter) global.MSG { return coolq.Failed(404, "API_NOT_FOUND", "API不存在") } -func (api *apiCaller) use(middlewares ...handler) { +// Use add handlers to the API caller +func (api *Caller) Use(middlewares ...Handler) { api.handlers = append(api.handlers, middlewares...) } -func newAPICaller(bot *coolq.CQBot) *apiCaller { - return &apiCaller{ +// NewCaller create a new API caller +func NewCaller(bot *coolq.CQBot) *Caller { + return &Caller{ bot: bot, - handlers: []handler{}, + handlers: make([]Handler, 0), } } diff --git a/global/filter.go b/modules/filter/filter.go similarity index 96% rename from global/filter.go rename to modules/filter/filter.go index f308757..1ffa165 100644 --- a/global/filter.go +++ b/modules/filter/filter.go @@ -1,4 +1,5 @@ -package global +// Package filter implements an event filter for go-cqhttp +package filter import ( "regexp" @@ -54,14 +55,14 @@ func newAndOp(argument gjson.Result) Filter { opKey := key.Str[1:] op.operands = append(op.operands, operationNode{"", Generate(opKey, value)}) case value.IsObject(): - // is an normal key with an object as the value + // is a normal key with an object as the value // "foo": { // ".bar": "baz" // } opKey := key.String() op.operands = append(op.operands, operationNode{opKey, Generate("and", value)}) default: - // is an normal key with a non-object as the value + // is a normal key with a non-object as the value // "foo": "bar" opKey := key.String() op.operands = append(op.operands, operationNode{opKey, Generate("eq", value)}) @@ -79,7 +80,7 @@ func (op *andOperator) Eval(payload gjson.Result) bool { // is an operator res = res && operand.filter.Eval(payload) } else { - // is an normal key + // is a normal key val := payload.Get(operand.key) res = res && operand.filter.Eval(val) } diff --git a/modules/filter/middlewares.go b/modules/filter/middlewares.go new file mode 100644 index 0000000..4f7a73e --- /dev/null +++ b/modules/filter/middlewares.go @@ -0,0 +1,45 @@ +package filter + +import ( + "os" + "sync" + + "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +var ( + filters = make(map[string]Filter) + filterMutex sync.RWMutex +) + +// Add adds a filter to the list of filters +func Add(file string) { + if file == "" { + return + } + bs, err := os.ReadFile(file) + if err != nil { + logrus.Error("init filter error: ", err) + return + } + defer func() { + if err := recover(); err != nil { + logrus.Error("init filter error: ", err) + } + }() + filter := Generate("and", gjson.ParseBytes(bs)) + filterMutex.Lock() + filters[file] = filter + filterMutex.Unlock() +} + +// Find returns the filter for the given file +func Find(file string) Filter { + if file == "" { + return nil + } + filterMutex.RLock() + defer filterMutex.RUnlock() + return filters[file] +} diff --git a/server/http.go b/server/http.go index a8dba28..58dc1d4 100644 --- a/server/http.go +++ b/server/http.go @@ -24,12 +24,14 @@ import ( "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" + "github.com/Mrs4s/go-cqhttp/modules/api" "github.com/Mrs4s/go-cqhttp/modules/config" + "github.com/Mrs4s/go-cqhttp/modules/filter" ) type httpServer struct { HTTP *http.Server - api *apiCaller + api *api.Caller accessToken string } @@ -121,12 +123,12 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request if base.AcceptOneBot12HTTPEndPoint && request.URL.Path == "/" { action := strings.TrimSuffix(ctx.Get("action").Str, "_async") log.Debugf("HTTPServer接收到API调用: %v", action) - response = s.api.callAPI(action, ctx.Get("params")) + response = s.api.Call(action, ctx.Get("params")) } else { action := strings.TrimPrefix(request.URL.Path, "/") action = strings.TrimSuffix(action, "_async") log.Debugf("HTTPServer接收到API调用: %v", action) - response = s.api.callAPI(action, &ctx) + response = s.api.Call(action, &ctx) } writer.Header().Set("Content-Type", "application/json; charset=utf-8") @@ -176,12 +178,12 @@ func runHTTP(bot *coolq.CQBot, node yaml.Node) { goto client } addr = fmt.Sprintf("%s:%d", conf.Host, conf.Port) - s.api = newAPICaller(bot) + s.api = api.NewCaller(bot) if conf.RateLimit.Enabled { - s.api.use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)) + s.api.Use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)) } if conf.LongPolling.Enabled { - s.api.use(longPolling(bot, conf.LongPolling.MaxQueueSize)) + s.api.Use(longPolling(bot, conf.LongPolling.MaxQueueSize)) } go func() { @@ -215,7 +217,7 @@ client: // Run 运行反向HTTP服务 func (c HTTPClient) Run() { - addFilter(c.filter) + filter.Add(c.filter) if c.timeout < 5 { c.timeout = 5 } @@ -225,7 +227,7 @@ func (c HTTPClient) Run() { func (c *HTTPClient) onBotPushEvent(e *coolq.Event) { if c.filter != "" { - filter := findFilter(c.filter) + filter := filter.Find(c.filter) if filter != nil && !filter.Eval(gjson.Parse(e.JSONString())) { log.Debugf("上报Event %v 到 HTTP 服务器 %s 时被过滤.", c.addr, e.JSONBytes()) return diff --git a/server/middlewares.go b/server/middlewares.go index 6241ca2..aff1a8d 100644 --- a/server/middlewares.go +++ b/server/middlewares.go @@ -3,61 +3,25 @@ package server import ( "container/list" "context" - "os" "sync" "time" "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/modules/api" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" "golang.org/x/time/rate" ) -var ( - filters = make(map[string]global.Filter) - filterMutex sync.RWMutex -) - -func rateLimit(frequency float64, bucketSize int) handler { +func rateLimit(frequency float64, bucketSize int) api.Handler { limiter := rate.NewLimiter(rate.Limit(frequency), bucketSize) - return func(_ string, _ resultGetter) global.MSG { + return func(_ string, _ api.Getter) global.MSG { _ = limiter.Wait(context.Background()) return nil } } -func addFilter(file string) { - if file == "" { - return - } - bs, err := os.ReadFile(file) - if err != nil { - log.Error("init filter error: ", err) - return - } - defer func() { - if err := recover(); err != nil { - log.Error("init filter error: ", err) - } - }() - filter := global.Generate("and", gjson.ParseBytes(bs)) - filterMutex.Lock() - filters[file] = filter - filterMutex.Unlock() -} - -func findFilter(file string) global.Filter { - if file == "" { - return nil - } - filterMutex.RLock() - defer filterMutex.RUnlock() - return filters[file] -} - -func longPolling(bot *coolq.CQBot, maxSize int) handler { +func longPolling(bot *coolq.CQBot, maxSize int) api.Handler { var mutex sync.Mutex cond := sync.NewCond(&mutex) queue := list.New() @@ -70,7 +34,7 @@ func longPolling(bot *coolq.CQBot, maxSize int) handler { } cond.Signal() }) - return func(action string, p resultGetter) global.MSG { + return func(action string, p api.Getter) global.MSG { if action != "get_updates" { return nil } diff --git a/server/scf.go b/server/scf.go index 6a0ef8e..d863419 100644 --- a/server/scf.go +++ b/server/scf.go @@ -17,6 +17,7 @@ import ( "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" + api2 "github.com/Mrs4s/go-cqhttp/modules/api" "github.com/Mrs4s/go-cqhttp/modules/config" ) @@ -115,9 +116,9 @@ func runLambda(bot *coolq.CQBot, node yaml.Node) { log.Fatal("unknown lambda type:", conf.Type) } - api := newAPICaller(bot) + api := api2.NewCaller(bot) if conf.RateLimit.Enabled { - api.use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)) + api.Use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)) } server := &httpServer{ api: api, diff --git a/server/websocket.go b/server/websocket.go index fcbc49b..9b1a547 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -12,16 +12,17 @@ import ( "sync" "time" + "github.com/Mrs4s/MiraiGo/utils" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" "gopkg.in/yaml.v3" "nhooyr.io/websocket" "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/modules/api" "github.com/Mrs4s/go-cqhttp/modules/config" - - "github.com/Mrs4s/MiraiGo/utils" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" + "github.com/Mrs4s/go-cqhttp/modules/filter" ) type webSocketServer struct { @@ -51,7 +52,7 @@ type websocketClient struct { type wsConn struct { *websocket.Conn - apiCaller *apiCaller + apiCaller *api.Caller } // runWSServer 运行一个正向WS server @@ -71,7 +72,7 @@ func runWSServer(b *coolq.CQBot, node yaml.Node) { token: conf.AccessToken, filter: conf.Filter, } - addFilter(s.filter) + filter.Add(s.filter) addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port) s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, b.Client.Uin, time.Now().Unix()) @@ -103,7 +104,7 @@ func runWSClient(b *coolq.CQBot, node yaml.Node) { token: conf.AccessToken, filter: conf.Filter, } - addFilter(c.filter) + filter.Add(c.filter) if c.conf.Universal != "" { c.connect("Universal", conf.Universal, &c.universal) } else { @@ -147,9 +148,9 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) { } log.Infof("已连接到反向WebSocket %s服务器 %v", typ, c.conf.API) - wrappedConn := &wsConn{Conn: conn, apiCaller: newAPICaller(c.bot)} + wrappedConn := &wsConn{Conn: conn, apiCaller: api.NewCaller(c.bot)} if c.conf.RateLimit.Enabled { - wrappedConn.apiCaller.use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket)) + wrappedConn.apiCaller.Use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket)) } if conptr != nil { @@ -196,7 +197,7 @@ func (c *websocketClient) listenAPI(conn *wsConn, u bool) { } func (c *websocketClient) onBotPushEvent(e *coolq.Event) { - filter := findFilter(c.filter) + filter := filter.Find(c.filter) if filter != nil && !filter.Eval(gjson.Parse(e.JSONString())) { log.Debugf("上报Event %s 到 WS服务器 时被过滤.", e.JSONBytes()) return @@ -253,7 +254,7 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr) - conn := &wsConn{Conn: c, apiCaller: newAPICaller(s.bot)} + conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} s.mu.Lock() s.eventConn = append(s.eventConn, conn) @@ -273,9 +274,9 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { return } log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr) - conn := &wsConn{Conn: c, apiCaller: newAPICaller(s.bot)} + conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} if s.conf.RateLimit.Enabled { - conn.apiCaller.use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) + conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) } go s.listenAPI(conn) } @@ -299,9 +300,9 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { return } log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr) - conn := &wsConn{Conn: c, apiCaller: newAPICaller(s.bot)} + conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)} if s.conf.RateLimit.Enabled { - conn.apiCaller.use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) + conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) } s.mu.Lock() s.eventConn = append(s.eventConn, conn) @@ -343,7 +344,7 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) { j := gjson.Parse(utils.B2S(payload)) t := strings.TrimSuffix(j.Get("action").Str, "_async") log.Debugf("WS接收到API调用: %v 参数: %v", t, j.Get("params").Raw) - ret := c.apiCaller.callAPI(t, j.Get("params")) + ret := c.apiCaller.Call(t, j.Get("params")) if j.Get("echo").Exists() { ret["echo"] = j.Get("echo").Value() } @@ -353,7 +354,7 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) { } func (s *webSocketServer) onBotPushEvent(e *coolq.Event) { - filter := findFilter(s.filter) + filter := filter.Find(s.filter) if filter != nil && !filter.Eval(gjson.Parse(e.JSONString())) { log.Debugf("上报Event %s 到 WS客户端 时被过滤.", e.JSONBytes()) return @@ -373,7 +374,7 @@ func (s *webSocketServer) onBotPushEvent(e *coolq.Event) { } if i != j { // i != j means that some connection has been closed. - // use a in-place removal to avoid copying. + // use an in-place removal to avoid copying. s.eventConn[j] = conn } j++