1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-04 11:07:39 +08:00

refactor: move api.go to modules/api, filter to modules/filter

maybe do code generate for api route
This commit is contained in:
wdvxdr 2021-10-27 11:40:01 +08:00
parent 4837b9677e
commit 208563d4c9
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6
7 changed files with 172 additions and 150 deletions

View File

@ -1,69 +1,74 @@
package server
// Package api implements the API route for servers.
package api
import (
"strings"
"github.com/tidwall/gjson"
"github.com/Mrs4s/go-cqhttp/coolq"
"github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/param"
"github.com/tidwall/gjson"
"github.com/Mrs4s/go-cqhttp/modules/filter"
)
type resultGetter interface {
// Getter 参数获取
type Getter interface {
Get(string) gjson.Result
}
type handler func(action string, p resultGetter) global.MSG
// Handler 中间件
type Handler func(action string, p Getter) global.MSG
type apiCaller struct {
// Caller api route caller
type Caller struct {
bot *coolq.CQBot
handlers []handler
handlers []Handler
}
func getLoginInfo(bot *coolq.CQBot, _ resultGetter) global.MSG {
func getLoginInfo(bot *coolq.CQBot, _ Getter) global.MSG {
return bot.CQGetLoginInfo()
}
func getQiDianAccountInfo(bot *coolq.CQBot, _ resultGetter) global.MSG {
func getQiDianAccountInfo(bot *coolq.CQBot, _ Getter) global.MSG {
return bot.CQGetQiDianAccountInfo()
}
func getFriendList(bot *coolq.CQBot, _ resultGetter) global.MSG {
func getFriendList(bot *coolq.CQBot, _ Getter) global.MSG {
return bot.CQGetFriendList()
}
func getUnidirectionalFriendList(bot *coolq.CQBot, _ resultGetter) global.MSG {
func getUnidirectionalFriendList(bot *coolq.CQBot, _ Getter) global.MSG {
return bot.CQGetUnidirectionalFriendList()
}
func deleteFriend(bot *coolq.CQBot, p resultGetter) global.MSG {
func deleteFriend(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQDeleteFriend(p.Get("[user_id,id].0").Int())
}
func deleteUnidirectionalFriend(bot *coolq.CQBot, p resultGetter) global.MSG {
func deleteUnidirectionalFriend(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQDeleteUnidirectionalFriend(p.Get("user_id").Int())
}
func getGroupList(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupList(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupList(p.Get("no_cache").Bool())
}
func getGroupInfo(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupInfo(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupInfo(p.Get("group_id").Int(), p.Get("no_cache").Bool())
}
func getGroupMemberList(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupMemberList(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupMemberList(p.Get("group_id").Int(), p.Get("no_cache").Bool())
}
func getGroupMemberInfo(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupMemberInfo(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupMemberInfo(
p.Get("group_id").Int(), p.Get("user_id").Int(), p.Get("no_cache").Bool(),
)
}
func sendMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
func sendMSG(bot *coolq.CQBot, p Getter) global.MSG {
autoEscape := param.EnsureBool(p.Get("auto_escape"), false)
if p.Get("message_type").Str == "private" {
return bot.CQSendPrivateMessage(p.Get("user_id").Int(), p.Get("group_id").Int(), p.Get("message"), autoEscape)
@ -80,25 +85,25 @@ func sendMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
return global.MSG{}
}
func sendGroupMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
func sendGroupMSG(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSendGroupMessage(p.Get("group_id").Int(), p.Get("message"),
param.EnsureBool(p.Get("auto_escape"), false))
}
func sendGroupForwardMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
func sendGroupForwardMSG(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSendGroupForwardMessage(p.Get("group_id").Int(), p.Get("messages"))
}
func sendPrivateMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
func sendPrivateMSG(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSendPrivateMessage(p.Get("user_id").Int(), p.Get("group_id").Int(), p.Get("message"),
param.EnsureBool(p.Get("auto_escape"), false))
}
func deleteMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
func deleteMSG(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQDeleteMessage(int32(p.Get("message_id").Int()))
}
func setFriendAddRequest(bot *coolq.CQBot, p resultGetter) global.MSG {
func setFriendAddRequest(bot *coolq.CQBot, p Getter) global.MSG {
apr := true
if p.Get("approve").Exists() {
apr = p.Get("approve").Bool()
@ -106,7 +111,7 @@ func setFriendAddRequest(bot *coolq.CQBot, p resultGetter) global.MSG {
return bot.CQProcessFriendRequest(p.Get("flag").Str, apr)
}
func setGroupAddRequest(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupAddRequest(bot *coolq.CQBot, p Getter) global.MSG {
subType := p.Get("sub_type").Str
apr := true
if subType == "" {
@ -118,19 +123,19 @@ func setGroupAddRequest(bot *coolq.CQBot, p resultGetter) global.MSG {
return bot.CQProcessGroupRequest(p.Get("flag").Str, subType, p.Get("reason").Str, apr)
}
func setGroupCard(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupCard(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupCard(p.Get("group_id").Int(), p.Get("user_id").Int(), p.Get("card").Str)
}
func setGroupSpecialTitle(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupSpecialTitle(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupSpecialTitle(p.Get("group_id").Int(), p.Get("user_id").Int(), p.Get("special_title").Str)
}
func setGroupKick(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupKick(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupKick(p.Get("group_id").Int(), p.Get("user_id").Int(), p.Get("message").Str, p.Get("reject_add_request").Bool())
}
func setGroupBan(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupBan(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupBan(p.Get("group_id").Int(), p.Get("user_id").Int(), func() uint32 {
if p.Get("duration").Exists() {
return uint32(p.Get("duration").Int())
@ -139,7 +144,7 @@ func setGroupBan(bot *coolq.CQBot, p resultGetter) global.MSG {
}())
}
func setGroupWholeBan(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupWholeBan(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupWholeBan(p.Get("group_id").Int(), func() bool {
if p.Get("enable").Exists() {
return p.Get("enable").Bool()
@ -148,11 +153,11 @@ func setGroupWholeBan(bot *coolq.CQBot, p resultGetter) global.MSG {
}())
}
func setGroupName(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupName(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupName(p.Get("group_id").Int(), p.Get("group_name").Str)
}
func setGroupAdmin(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupAdmin(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupAdmin(p.Get("group_id").Int(), p.Get("user_id").Int(), func() bool {
if p.Get("enable").Exists() {
return p.Get("enable").Bool()
@ -161,19 +166,19 @@ func setGroupAdmin(bot *coolq.CQBot, p resultGetter) global.MSG {
}())
}
func sendGroupNotice(bot *coolq.CQBot, p resultGetter) global.MSG {
func sendGroupNotice(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupMemo(p.Get("group_id").Int(), p.Get("content").Str, p.Get("image").String())
}
func setGroupLeave(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupLeave(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupLeave(p.Get("group_id").Int())
}
func getImage(bot *coolq.CQBot, p resultGetter) global.MSG {
func getImage(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetImage(p.Get("file").Str)
}
func getForwardMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
func getForwardMSG(bot *coolq.CQBot, p Getter) global.MSG {
id := p.Get("message_id").Str
if id == "" {
id = p.Get("id").Str
@ -181,11 +186,11 @@ func getForwardMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
return bot.CQGetForwardMessage(id)
}
func getMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
func getMSG(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetMessage(int32(p.Get("message_id").Int()))
}
func downloadFile(bot *coolq.CQBot, p resultGetter) global.MSG {
func downloadFile(bot *coolq.CQBot, p Getter) global.MSG {
headers := map[string]string{}
headersToken := p.Get("headers")
if headersToken.IsArray() {
@ -208,11 +213,11 @@ func downloadFile(bot *coolq.CQBot, p resultGetter) global.MSG {
return bot.CQDownloadFile(p.Get("url").Str, headers, int(p.Get("thread_count").Int()))
}
func getGroupHonorInfo(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupHonorInfo(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupHonorInfo(p.Get("group_id").Int(), p.Get("type").Str)
}
func setRestart(_ *coolq.CQBot, _ resultGetter) global.MSG {
func setRestart(_ *coolq.CQBot, _ Getter) global.MSG {
/*
var delay int64
delay = p.Get("delay").Int()
@ -227,112 +232,112 @@ func setRestart(_ *coolq.CQBot, _ resultGetter) global.MSG {
return global.MSG{"data": nil, "retcode": 99, "msg": "restart un-supported now", "wording": "restart函数暂不兼容", "status": "failed"}
}
func canSendImage(bot *coolq.CQBot, _ resultGetter) global.MSG {
func canSendImage(bot *coolq.CQBot, _ Getter) global.MSG {
return bot.CQCanSendImage()
}
func canSendRecord(bot *coolq.CQBot, _ resultGetter) global.MSG {
func canSendRecord(bot *coolq.CQBot, _ Getter) global.MSG {
return bot.CQCanSendRecord()
}
func getStrangerInfo(bot *coolq.CQBot, p resultGetter) global.MSG {
func getStrangerInfo(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetStrangerInfo(p.Get("user_id").Int())
}
func getStatus(bot *coolq.CQBot, _ resultGetter) global.MSG {
func getStatus(bot *coolq.CQBot, _ Getter) global.MSG {
return bot.CQGetStatus()
}
func getVersionInfo(bot *coolq.CQBot, _ resultGetter) global.MSG {
func getVersionInfo(bot *coolq.CQBot, _ Getter) global.MSG {
return bot.CQGetVersionInfo()
}
func getGroupSystemMSG(bot *coolq.CQBot, _ resultGetter) global.MSG {
func getGroupSystemMSG(bot *coolq.CQBot, _ Getter) global.MSG {
return bot.CQGetGroupSystemMessages()
}
func getGroupFileSystemInfo(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupFileSystemInfo(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupFileSystemInfo(p.Get("group_id").Int())
}
func getGroupRootFiles(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupRootFiles(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupRootFiles(p.Get("group_id").Int())
}
func getGroupFilesByFolder(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupFilesByFolder(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupFilesByFolderID(p.Get("group_id").Int(), p.Get("folder_id").Str)
}
func getGroupFileURL(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupFileURL(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupFileURL(p.Get("group_id").Int(), p.Get("file_id").Str, int32(p.Get("busid").Int()))
}
func uploadGroupFile(bot *coolq.CQBot, p resultGetter) global.MSG {
func uploadGroupFile(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQUploadGroupFile(p.Get("group_id").Int(), p.Get("file").Str, p.Get("name").Str, p.Get("folder").Str)
}
func groupFileCreateFolder(bot *coolq.CQBot, p resultGetter) global.MSG {
func groupFileCreateFolder(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGroupFileCreateFolder(p.Get("group_id").Int(), p.Get("folder_id").Str, p.Get("name").Str)
}
func deleteGroupFolder(bot *coolq.CQBot, p resultGetter) global.MSG {
func deleteGroupFolder(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGroupFileDeleteFolder(p.Get("group_id").Int(), p.Get("folder_id").Str)
}
func deleteGroupFile(bot *coolq.CQBot, p resultGetter) global.MSG {
func deleteGroupFile(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGroupFileDeleteFile(p.Get("group_id").Int(), p.Get("file_id").Str, int32(p.Get("bus_id").Int()))
}
func getGroupMsgHistory(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupMsgHistory(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetGroupMessageHistory(p.Get("group_id").Int(), p.Get("message_seq").Int())
}
func getVipInfo(bot *coolq.CQBot, p resultGetter) global.MSG {
func getVipInfo(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetVipInfo(p.Get("user_id").Int())
}
func reloadEventFilter(_ *coolq.CQBot, p resultGetter) global.MSG {
addFilter(p.Get("file").String())
func reloadEventFilter(_ *coolq.CQBot, p Getter) global.MSG {
filter.Add(p.Get("file").String())
return coolq.OK(nil)
}
func getGroupAtAllRemain(bot *coolq.CQBot, p resultGetter) global.MSG {
func getGroupAtAllRemain(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetAtAllRemain(p.Get("group_id").Int())
}
func ocrImage(bot *coolq.CQBot, p resultGetter) global.MSG {
func ocrImage(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQOcrImage(p.Get("image").Str)
}
func getOnlineClients(bot *coolq.CQBot, p resultGetter) global.MSG {
func getOnlineClients(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetOnlineClients(p.Get("no_cache").Bool())
}
func getWordSlices(bot *coolq.CQBot, p resultGetter) global.MSG {
func getWordSlices(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetWordSlices(p.Get("content").Str)
}
func setGroupPortrait(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupPortrait(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetGroupPortrait(p.Get("group_id").Int(), p.Get("file").String(), p.Get("cache").String())
}
func setEssenceMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
func setEssenceMSG(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetEssenceMessage(int32(p.Get("message_id").Int()))
}
func deleteEssenceMSG(bot *coolq.CQBot, p resultGetter) global.MSG {
func deleteEssenceMSG(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQDeleteEssenceMessage(int32(p.Get("message_id").Int()))
}
func getEssenceMsgList(bot *coolq.CQBot, p resultGetter) global.MSG {
func getEssenceMsgList(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetEssenceMessageList(p.Get("group_id").Int())
}
func checkURLSafely(bot *coolq.CQBot, p resultGetter) global.MSG {
func checkURLSafely(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQCheckURLSafely(p.Get("url").String())
}
func setGroupAnonymousBan(bot *coolq.CQBot, p resultGetter) global.MSG {
func setGroupAnonymousBan(bot *coolq.CQBot, p Getter) global.MSG {
obj := p.Get("anonymous")
flag := p.Get("anonymous_flag")
if !flag.Exists() {
@ -347,24 +352,24 @@ func setGroupAnonymousBan(bot *coolq.CQBot, p resultGetter) global.MSG {
return bot.CQSetGroupAnonymousBan(p.Get("group_id").Int(), flag.String(), int32(p.Get("duration").Int()))
}
func handleQuickOperation(bot *coolq.CQBot, p resultGetter) global.MSG {
func handleQuickOperation(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQHandleQuickOperation(p.Get("context"), p.Get("operation"))
}
func getModelShow(bot *coolq.CQBot, p resultGetter) global.MSG {
func getModelShow(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQGetModelShow(p.Get("model").String())
}
func setModelShow(bot *coolq.CQBot, p resultGetter) global.MSG {
func setModelShow(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQSetModelShow(p.Get("model").String(), p.Get("model_show").String())
}
func markMSGAsRead(bot *coolq.CQBot, p resultGetter) global.MSG {
func markMSGAsRead(bot *coolq.CQBot, p Getter) global.MSG {
return bot.CQMarkMessageAsRead(int32(p.Get("message_id").Int()))
}
// API 是go-cqhttp当前支持的所有api的映射表
var API = map[string]func(*coolq.CQBot, resultGetter) global.MSG{
var API = map[string]func(*coolq.CQBot, Getter) global.MSG{
"get_login_info": getLoginInfo,
"get_friend_list": getFriendList,
"get_unidirectional_friend_list": getUnidirectionalFriendList,
@ -431,7 +436,8 @@ var API = map[string]func(*coolq.CQBot, resultGetter) global.MSG{
"mark_msg_as_read": markMSGAsRead,
}
func (api *apiCaller) callAPI(action string, p resultGetter) global.MSG {
// Call specific API
func (api *Caller) Call(action string, p Getter) global.MSG {
for _, fn := range api.handlers {
if ret := fn(action, p); ret != nil {
return ret
@ -443,13 +449,15 @@ func (api *apiCaller) callAPI(action string, p resultGetter) global.MSG {
return coolq.Failed(404, "API_NOT_FOUND", "API不存在")
}
func (api *apiCaller) use(middlewares ...handler) {
// Use add handlers to the API caller
func (api *Caller) Use(middlewares ...Handler) {
api.handlers = append(api.handlers, middlewares...)
}
func newAPICaller(bot *coolq.CQBot) *apiCaller {
return &apiCaller{
// NewCaller create a new API caller
func NewCaller(bot *coolq.CQBot) *Caller {
return &Caller{
bot: bot,
handlers: []handler{},
handlers: make([]Handler, 0),
}
}

View File

@ -1,4 +1,5 @@
package global
// Package filter implements an event filter for go-cqhttp
package filter
import (
"regexp"
@ -54,14 +55,14 @@ func newAndOp(argument gjson.Result) Filter {
opKey := key.Str[1:]
op.operands = append(op.operands, operationNode{"", Generate(opKey, value)})
case value.IsObject():
// is an normal key with an object as the value
// is a normal key with an object as the value
// "foo": {
// ".bar": "baz"
// }
opKey := key.String()
op.operands = append(op.operands, operationNode{opKey, Generate("and", value)})
default:
// is an normal key with a non-object as the value
// is a normal key with a non-object as the value
// "foo": "bar"
opKey := key.String()
op.operands = append(op.operands, operationNode{opKey, Generate("eq", value)})
@ -79,7 +80,7 @@ func (op *andOperator) Eval(payload gjson.Result) bool {
// is an operator
res = res && operand.filter.Eval(payload)
} else {
// is an normal key
// is a normal key
val := payload.Get(operand.key)
res = res && operand.filter.Eval(val)
}

View File

@ -0,0 +1,45 @@
package filter
import (
"os"
"sync"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
var (
filters = make(map[string]Filter)
filterMutex sync.RWMutex
)
// Add adds a filter to the list of filters
func Add(file string) {
if file == "" {
return
}
bs, err := os.ReadFile(file)
if err != nil {
logrus.Error("init filter error: ", err)
return
}
defer func() {
if err := recover(); err != nil {
logrus.Error("init filter error: ", err)
}
}()
filter := Generate("and", gjson.ParseBytes(bs))
filterMutex.Lock()
filters[file] = filter
filterMutex.Unlock()
}
// Find returns the filter for the given file
func Find(file string) Filter {
if file == "" {
return nil
}
filterMutex.RLock()
defer filterMutex.RUnlock()
return filters[file]
}

View File

@ -24,12 +24,14 @@ import (
"github.com/Mrs4s/go-cqhttp/coolq"
"github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/Mrs4s/go-cqhttp/modules/api"
"github.com/Mrs4s/go-cqhttp/modules/config"
"github.com/Mrs4s/go-cqhttp/modules/filter"
)
type httpServer struct {
HTTP *http.Server
api *apiCaller
api *api.Caller
accessToken string
}
@ -121,12 +123,12 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request
if base.AcceptOneBot12HTTPEndPoint && request.URL.Path == "/" {
action := strings.TrimSuffix(ctx.Get("action").Str, "_async")
log.Debugf("HTTPServer接收到API调用: %v", action)
response = s.api.callAPI(action, ctx.Get("params"))
response = s.api.Call(action, ctx.Get("params"))
} else {
action := strings.TrimPrefix(request.URL.Path, "/")
action = strings.TrimSuffix(action, "_async")
log.Debugf("HTTPServer接收到API调用: %v", action)
response = s.api.callAPI(action, &ctx)
response = s.api.Call(action, &ctx)
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
@ -176,12 +178,12 @@ func runHTTP(bot *coolq.CQBot, node yaml.Node) {
goto client
}
addr = fmt.Sprintf("%s:%d", conf.Host, conf.Port)
s.api = newAPICaller(bot)
s.api = api.NewCaller(bot)
if conf.RateLimit.Enabled {
s.api.use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket))
s.api.Use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket))
}
if conf.LongPolling.Enabled {
s.api.use(longPolling(bot, conf.LongPolling.MaxQueueSize))
s.api.Use(longPolling(bot, conf.LongPolling.MaxQueueSize))
}
go func() {
@ -215,7 +217,7 @@ client:
// Run 运行反向HTTP服务
func (c HTTPClient) Run() {
addFilter(c.filter)
filter.Add(c.filter)
if c.timeout < 5 {
c.timeout = 5
}
@ -225,7 +227,7 @@ func (c HTTPClient) Run() {
func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
if c.filter != "" {
filter := findFilter(c.filter)
filter := filter.Find(c.filter)
if filter != nil && !filter.Eval(gjson.Parse(e.JSONString())) {
log.Debugf("上报Event %v 到 HTTP 服务器 %s 时被过滤.", c.addr, e.JSONBytes())
return

View File

@ -3,61 +3,25 @@ package server
import (
"container/list"
"context"
"os"
"sync"
"time"
"github.com/Mrs4s/go-cqhttp/coolq"
"github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/modules/api"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"golang.org/x/time/rate"
)
var (
filters = make(map[string]global.Filter)
filterMutex sync.RWMutex
)
func rateLimit(frequency float64, bucketSize int) handler {
func rateLimit(frequency float64, bucketSize int) api.Handler {
limiter := rate.NewLimiter(rate.Limit(frequency), bucketSize)
return func(_ string, _ resultGetter) global.MSG {
return func(_ string, _ api.Getter) global.MSG {
_ = limiter.Wait(context.Background())
return nil
}
}
func addFilter(file string) {
if file == "" {
return
}
bs, err := os.ReadFile(file)
if err != nil {
log.Error("init filter error: ", err)
return
}
defer func() {
if err := recover(); err != nil {
log.Error("init filter error: ", err)
}
}()
filter := global.Generate("and", gjson.ParseBytes(bs))
filterMutex.Lock()
filters[file] = filter
filterMutex.Unlock()
}
func findFilter(file string) global.Filter {
if file == "" {
return nil
}
filterMutex.RLock()
defer filterMutex.RUnlock()
return filters[file]
}
func longPolling(bot *coolq.CQBot, maxSize int) handler {
func longPolling(bot *coolq.CQBot, maxSize int) api.Handler {
var mutex sync.Mutex
cond := sync.NewCond(&mutex)
queue := list.New()
@ -70,7 +34,7 @@ func longPolling(bot *coolq.CQBot, maxSize int) handler {
}
cond.Signal()
})
return func(action string, p resultGetter) global.MSG {
return func(action string, p api.Getter) global.MSG {
if action != "get_updates" {
return nil
}

View File

@ -17,6 +17,7 @@ import (
"github.com/Mrs4s/go-cqhttp/coolq"
"github.com/Mrs4s/go-cqhttp/global"
api2 "github.com/Mrs4s/go-cqhttp/modules/api"
"github.com/Mrs4s/go-cqhttp/modules/config"
)
@ -115,9 +116,9 @@ func runLambda(bot *coolq.CQBot, node yaml.Node) {
log.Fatal("unknown lambda type:", conf.Type)
}
api := newAPICaller(bot)
api := api2.NewCaller(bot)
if conf.RateLimit.Enabled {
api.use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket))
api.Use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket))
}
server := &httpServer{
api: api,

View File

@ -12,16 +12,17 @@ import (
"sync"
"time"
"github.com/Mrs4s/MiraiGo/utils"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"gopkg.in/yaml.v3"
"nhooyr.io/websocket"
"github.com/Mrs4s/go-cqhttp/coolq"
"github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/modules/api"
"github.com/Mrs4s/go-cqhttp/modules/config"
"github.com/Mrs4s/MiraiGo/utils"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/Mrs4s/go-cqhttp/modules/filter"
)
type webSocketServer struct {
@ -51,7 +52,7 @@ type websocketClient struct {
type wsConn struct {
*websocket.Conn
apiCaller *apiCaller
apiCaller *api.Caller
}
// runWSServer 运行一个正向WS server
@ -71,7 +72,7 @@ func runWSServer(b *coolq.CQBot, node yaml.Node) {
token: conf.AccessToken,
filter: conf.Filter,
}
addFilter(s.filter)
filter.Add(s.filter)
addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port)
s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`,
b.Client.Uin, time.Now().Unix())
@ -103,7 +104,7 @@ func runWSClient(b *coolq.CQBot, node yaml.Node) {
token: conf.AccessToken,
filter: conf.Filter,
}
addFilter(c.filter)
filter.Add(c.filter)
if c.conf.Universal != "" {
c.connect("Universal", conf.Universal, &c.universal)
} else {
@ -147,9 +148,9 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
}
log.Infof("已连接到反向WebSocket %s服务器 %v", typ, c.conf.API)
wrappedConn := &wsConn{Conn: conn, apiCaller: newAPICaller(c.bot)}
wrappedConn := &wsConn{Conn: conn, apiCaller: api.NewCaller(c.bot)}
if c.conf.RateLimit.Enabled {
wrappedConn.apiCaller.use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket))
wrappedConn.apiCaller.Use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket))
}
if conptr != nil {
@ -196,7 +197,7 @@ func (c *websocketClient) listenAPI(conn *wsConn, u bool) {
}
func (c *websocketClient) onBotPushEvent(e *coolq.Event) {
filter := findFilter(c.filter)
filter := filter.Find(c.filter)
if filter != nil && !filter.Eval(gjson.Parse(e.JSONString())) {
log.Debugf("上报Event %s 到 WS服务器 时被过滤.", e.JSONBytes())
return
@ -253,7 +254,7 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) {
log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr)
conn := &wsConn{Conn: c, apiCaller: newAPICaller(s.bot)}
conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)}
s.mu.Lock()
s.eventConn = append(s.eventConn, conn)
@ -273,9 +274,9 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) {
return
}
log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr)
conn := &wsConn{Conn: c, apiCaller: newAPICaller(s.bot)}
conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)}
if s.conf.RateLimit.Enabled {
conn.apiCaller.use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket))
conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket))
}
go s.listenAPI(conn)
}
@ -299,9 +300,9 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) {
return
}
log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr)
conn := &wsConn{Conn: c, apiCaller: newAPICaller(s.bot)}
conn := &wsConn{Conn: c, apiCaller: api.NewCaller(s.bot)}
if s.conf.RateLimit.Enabled {
conn.apiCaller.use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket))
conn.apiCaller.Use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket))
}
s.mu.Lock()
s.eventConn = append(s.eventConn, conn)
@ -343,7 +344,7 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) {
j := gjson.Parse(utils.B2S(payload))
t := strings.TrimSuffix(j.Get("action").Str, "_async")
log.Debugf("WS接收到API调用: %v 参数: %v", t, j.Get("params").Raw)
ret := c.apiCaller.callAPI(t, j.Get("params"))
ret := c.apiCaller.Call(t, j.Get("params"))
if j.Get("echo").Exists() {
ret["echo"] = j.Get("echo").Value()
}
@ -353,7 +354,7 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) {
}
func (s *webSocketServer) onBotPushEvent(e *coolq.Event) {
filter := findFilter(s.filter)
filter := filter.Find(s.filter)
if filter != nil && !filter.Eval(gjson.Parse(e.JSONString())) {
log.Debugf("上报Event %s 到 WS客户端 时被过滤.", e.JSONBytes())
return
@ -373,7 +374,7 @@ func (s *webSocketServer) onBotPushEvent(e *coolq.Event) {
}
if i != j {
// i != j means that some connection has been closed.
// use a in-place removal to avoid copying.
// use an in-place removal to avoid copying.
s.eventConn[j] = conn
}
j++