From 5ddb246ab26537823c608e3c43730ab771e9fbaa Mon Sep 17 00:00:00 2001 From: wdvxdr Date: Wed, 9 Sep 2020 10:47:14 +0800 Subject: [PATCH] feature: _rate_limit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit api限速 --- docs/config.md | 8 ++ global/config.go | 14 ++++ global/ratelimit.go | 20 +++++ go.mod | 1 + go.sum | 2 + main.go | 3 + server/http.go | 196 ++++++++++++++++++++++++-------------------- server/websocket.go | 3 +- 8 files changed, 159 insertions(+), 88 deletions(-) create mode 100644 global/ratelimit.go diff --git a/docs/config.md b/docs/config.md index 76968a5..ccdff7c 100644 --- a/docs/config.md +++ b/docs/config.md @@ -27,6 +27,11 @@ go-cqhttp 支持导入CQHTTP的配置文件, 具体步骤为: "relogin_delay": 3, "max_relogin_times": 0 }, + "_rate_limit": { + "enabled": false, + "frequency": 1, + "bucket_size": 1 + }, "post_message_format": "string", "ignore_invalid_cqcode": false, "force_fragmented": true, @@ -66,6 +71,9 @@ go-cqhttp 支持导入CQHTTP的配置文件, 具体步骤为: | relogin | bool | 是否自动重新登录 | | relogin_delay | int | 重登录延时(秒) | | max_relogin_times | uint | 最大重登录次数,若0则不设置上限 | +| _rate_limit | bool | 是否启用API调用限速 | +| frequency | float64 | 1s内能调用API的次数 | +| bucket_size | int | 令牌桶的大小,默认为1,修改此值可允许一定程度内连续调用api | | post_message_format | string | 上报信息类型 | | ignore_invalid_cqcode| bool | 是否忽略错误的CQ码 | | force_fragmented | bool | 是否强制分片发送群长消息 | diff --git a/global/config.go b/global/config.go index 07ba209..31ed776 100644 --- a/global/config.go +++ b/global/config.go @@ -21,6 +21,11 @@ type JsonConfig struct { ReLoginDelay int `json:"relogin_delay"` MaxReloginTimes uint `json:"max_relogin_times"` } `json:"relogin"` + RateLimit struct { + Enabled bool `json:"enabled"` + Frequency float64 `json:"frequency"` + BucketSize int `json:"bucket_size"` + } `json:"_rate_limit"` IgnoreInvalidCQCode bool `json:"ignore_invalid_cqcode"` ForceFragmented bool `json:"force_fragmented"` HeartbeatInterval time.Duration `json:"heartbeat_interval"` @@ -85,6 +90,15 @@ func DefaultConfig() *JsonConfig { ReLoginDelay: 3, MaxReloginTimes: 0, }, + RateLimit: struct { + Enabled bool `json:"enabled"` + Frequency float64 `json:"frequency"` + BucketSize int `json:"bucket_size"` + }{ + Enabled: false, + Frequency: 1, + BucketSize: 1, + }, PostMessageFormat: "string", ForceFragmented: true, HttpConfig: &GoCQHttpConfig{ diff --git a/global/ratelimit.go b/global/ratelimit.go new file mode 100644 index 0000000..79594be --- /dev/null +++ b/global/ratelimit.go @@ -0,0 +1,20 @@ +package global + +import ( + "context" + "golang.org/x/time/rate" +) + +var limiter *rate.Limiter +var limitEnable = false + +func RateLimit(ctx context.Context) { + if limitEnable { + _ = limiter.Wait(ctx) + } +} + +func InitLimiter(r float64, b int) { + limitEnable = true + limiter = rate.NewLimiter(rate.Limit(r), b) +} diff --git a/go.mod b/go.mod index 85418c5..319ffc2 100644 --- a/go.mod +++ b/go.mod @@ -26,5 +26,6 @@ require ( github.com/yinghau76/go-ascii-art v0.0.0-20190517192627-e7f465a30189 golang.org/x/net v0.0.0-20200904194848-62affa334b73 // indirect golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f // indirect + golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e gopkg.in/yaml.v2 v2.3.0 // indirect ) diff --git a/go.sum b/go.sum index 7473e10..4a4dada 100644 --- a/go.sum +++ b/go.sum @@ -159,6 +159,8 @@ golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f h1:Fqb3ao1hUmOR3GkUOg/Y+BadL golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s= +golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/main.go b/main.go index f5eae47..3aeaba6 100644 --- a/main.go +++ b/main.go @@ -263,6 +263,9 @@ func main() { } else { coolq.SetMessageFormat(conf.PostMessageFormat) } + if conf.RateLimit.Enabled { + global.InitLimiter(conf.RateLimit.Frequency, conf.RateLimit.BucketSize) + } coolq.IgnoreInvalidCQCode = conf.IgnoreInvalidCQCode coolq.ForceFragmented = conf.ForceFragmented if conf.HttpConfig != nil && conf.HttpConfig.Enabled { diff --git a/server/http.go b/server/http.go index 1db9789..aec1134 100644 --- a/server/http.go +++ b/server/http.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/hmac" "crypto/sha1" "encoding/hex" @@ -74,93 +75,7 @@ func (s *httpServer) Run(addr, authToken string, bot *coolq.CQBot) { }) } - s.engine.Any("/get_login_info", s.GetLoginInfo) - s.engine.Any("/get_login_info_async", s.GetLoginInfo) - - s.engine.Any("/get_friend_list", s.GetFriendList) - s.engine.Any("/get_friend_list_async", s.GetFriendList) - - s.engine.Any("/get_group_list", s.GetGroupList) - s.engine.Any("/get_group_list_async", s.GetGroupList) - - s.engine.Any("/get_group_info", s.GetGroupInfo) - s.engine.Any("/get_group_info_async", s.GetGroupInfo) - - s.engine.Any("/get_group_member_list", s.GetGroupMemberList) - s.engine.Any("/get_group_member_list_async", s.GetGroupMemberList) - - s.engine.Any("/get_group_member_info", s.GetGroupMemberInfo) - s.engine.Any("/get_group_member_info_async", s.GetGroupMemberInfo) - - s.engine.Any("/send_msg", s.SendMessage) - s.engine.Any("/send_msg_async", s.SendMessage) - - s.engine.Any("/send_private_msg", s.SendPrivateMessage) - s.engine.Any("/send_private_msg_async", s.SendPrivateMessage) - - s.engine.Any("/send_group_msg", s.SendGroupMessage) - s.engine.Any("/send_group_msg_async", s.SendGroupMessage) - - s.engine.Any("/send_group_forward_msg", s.SendGroupForwardMessage) - s.engine.Any("/send_group_forward_msg_async", s.SendGroupForwardMessage) - - s.engine.Any("/delete_msg", s.DeleteMessage) - s.engine.Any("/delete_msg_async", s.DeleteMessage) - - s.engine.Any("/set_friend_add_request", s.ProcessFriendRequest) - s.engine.Any("/set_friend_add_request_async", s.ProcessFriendRequest) - - s.engine.Any("/set_group_add_request", s.ProcessGroupRequest) - s.engine.Any("/set_group_add_request_async", s.ProcessGroupRequest) - - s.engine.Any("/set_group_card", s.SetGroupCard) - s.engine.Any("/set_group_card_async", s.SetGroupCard) - - s.engine.Any("/set_group_special_title", s.SetSpecialTitle) - s.engine.Any("/set_group_special_title_async", s.SetSpecialTitle) - - s.engine.Any("/set_group_kick", s.SetGroupKick) - s.engine.Any("/set_group_kick_async", s.SetGroupKick) - - s.engine.Any("/set_group_ban", s.SetGroupBan) - s.engine.Any("/set_group_ban_async", s.SetGroupBan) - - s.engine.Any("/set_group_whole_ban", s.SetWholeBan) - s.engine.Any("/set_group_whole_ban_async", s.SetWholeBan) - - s.engine.Any("/set_group_name", s.SetGroupName) - s.engine.Any("/set_group_name_async", s.SetGroupName) - - s.engine.Any("/_send_group_notice", s.SendGroupNotice) - s.engine.Any("/_send_group_notice_async", s.SendGroupNotice) - - s.engine.Any("/set_group_leave", s.SetGroupLeave) - s.engine.Any("/set_group_leave_async", s.SetGroupLeave) - - s.engine.Any("/get_image", s.GetImage) - - s.engine.Any("/get_forward_msg", s.GetForwardMessage) - - s.engine.Any("/get_group_msg", s.GetGroupMessage) - - s.engine.Any("/get_group_honor_info", s.GetGroupHonorInfo) - - s.engine.Any("/can_send_image", s.CanSendImage) - s.engine.Any("/can_send_image_async", s.CanSendImage) - - s.engine.Any("/can_send_record", s.CanSendRecord) - s.engine.Any("/can_send_record_async", s.CanSendRecord) - - s.engine.Any("/get_status", s.GetStatus) - s.engine.Any("/get_status_async", s.GetStatus) - - s.engine.Any("/get_version_info", s.GetVersionInfo) - s.engine.Any("/get_version_info_async", s.GetVersionInfo) - - s.engine.Any("/_get_vip_info", s.GetVipInfo) - s.engine.Any("/_get_vip_info_async", s.GetVipInfo) - - s.engine.Any("/.handle_quick_operation", s.HandleQuickOperation) + s.engine.Any("/:action", s.HandleActions) go func() { log.Infof("CQ HTTP 服务器已启动: %v", addr) @@ -213,6 +128,17 @@ func (c *httpClient) onBotPushEvent(m coolq.MSG) { } } +func (s *httpServer) HandleActions(c *gin.Context) { + global.RateLimit(context.Background()) + action := strings.ReplaceAll(c.Param("action"), "_async", "") + log.Debugf("HTTPServer接收到API调用: %v", action) + if f, ok := httpApi[action]; ok { + f(s, c) + } else { + c.JSON(200, coolq.Failed(404)) + } +} + func (s *httpServer) GetLoginInfo(c *gin.Context) { c.JSON(200, s.bot.CQGetLoginInfo()) } @@ -455,3 +381,99 @@ func getParamWithType(c *gin.Context, k string) (string, gjson.Type) { } return "", gjson.Null } + +var httpApi = map[string]func(s *httpServer, c *gin.Context){ + "get_login_info": func(s *httpServer, c *gin.Context) { + s.GetLoginInfo(c) + }, + "get_friend_list": func(s *httpServer, c *gin.Context) { + s.GetFriendList(c) + }, + "get_group_list": func(s *httpServer, c *gin.Context) { + s.GetGroupList(c) + }, + "get_group_info": func(s *httpServer, c *gin.Context) { + s.GetGroupInfo(c) + }, + "get_group_member_list": func(s *httpServer, c *gin.Context) { + s.GetGroupMemberList(c) + }, + "get_group_member_info": func(s *httpServer, c *gin.Context) { + s.GetGroupMemberInfo(c) + }, + "send_msg": func(s *httpServer, c *gin.Context) { + s.SendMessage(c) + }, + "send_group_msg": func(s *httpServer, c *gin.Context) { + s.SendGroupMessage(c) + }, + "send_group_forward_msg": func(s *httpServer, c *gin.Context) { + s.SendGroupForwardMessage(c) + }, + "send_private_msg": func(s *httpServer, c *gin.Context) { + s.SendPrivateMessage(c) + }, + "delete_msg": func(s *httpServer, c *gin.Context) { + s.DeleteMessage(c) + }, + "set_friend_add_request": func(s *httpServer, c *gin.Context) { + s.ProcessFriendRequest(c) + }, + "set_group_add_request": func(s *httpServer, c *gin.Context) { + s.ProcessGroupRequest(c) + }, + "set_group_card": func(s *httpServer, c *gin.Context) { + s.SetGroupCard(c) + }, + "set_group_special_title": func(s *httpServer, c *gin.Context) { + s.SetSpecialTitle(c) + }, + "set_group_kick": func(s *httpServer, c *gin.Context) { + s.SetGroupKick(c) + }, + "set_group_ban": func(s *httpServer, c *gin.Context) { + s.SetGroupBan(c) + }, + "set_group_whole_ban": func(s *httpServer, c *gin.Context) { + s.SetWholeBan(c) + }, + "set_group_name": func(s *httpServer, c *gin.Context) { + s.SetGroupName(c) + }, + "_send_group_notice": func(s *httpServer, c *gin.Context) { + s.SendGroupNotice(c) + }, + "set_group_leave": func(s *httpServer, c *gin.Context) { + s.SetGroupLeave(c) + }, + "get_image": func(s *httpServer, c *gin.Context) { + s.GetImage(c) + }, + "get_forward_msg": func(s *httpServer, c *gin.Context) { + s.GetForwardMessage(c) + }, + "get_group_msg": func(s *httpServer, c *gin.Context) { + s.GetGroupMessage(c) + }, + "get_group_honor_info": func(s *httpServer, c *gin.Context) { + s.GetGroupHonorInfo(c) + }, + "can_send_image": func(s *httpServer, c *gin.Context) { + s.CanSendImage(c) + }, + "can_send_record": func(s *httpServer, c *gin.Context) { + s.CanSendRecord(c) + }, + "get_status": func(s *httpServer, c *gin.Context) { + s.GetStatus(c) + }, + "get_version_info": func(s *httpServer, c *gin.Context) { + s.GetVersionInfo(c) + }, + "_get_vip_info": func(s *httpServer, c *gin.Context) { + s.GetVipInfo(c) + }, + ".handle_quick_operation": func(s *httpServer, c *gin.Context) { + s.HandleQuickOperation(c) + }, +} diff --git a/server/websocket.go b/server/websocket.go index 1ac2401..54fb92a 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "net/http" "strconv" @@ -319,7 +320,7 @@ func (c *websocketConn) handleRequest(bot *coolq.CQBot, payload []byte) { c.Close() } }() - + global.RateLimit(context.Background()) j := gjson.ParseBytes(payload) t := strings.ReplaceAll(j.Get("action").Str, "_async", "") log.Debugf("WS接收到API调用: %v 参数: %v", t, j.Get("params").Raw)