From 3ea9c2547c0b0a12f455fb0518dcbcd485c83d45 Mon Sep 17 00:00:00 2001 From: wdvxdr Date: Sun, 28 Mar 2021 19:07:20 +0800 Subject: [PATCH] switch to yaml config(#770) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit return on filter file does not exist refactor: event filter 将event filter独立出来 refact: rate limit 将限速改成连接独立控制 switch to yaml config --- coolq/api.go | 8 - coolq/bot.go | 26 ++- global/config.go | 322 +------------------------------ global/config/config.go | 129 +++++++++++++ global/config/default_config.yml | 113 +++++++++++ global/filter.go | 75 +------ global/ratelimit.go | 25 --- go.mod | 2 +- go.sum | 2 - main.go | 288 ++++++++++----------------- server/api.go | 26 ++- server/http.go | 48 +++-- server/middlewares.go | 56 ++++++ server/websocket.go | 143 ++++++++------ 14 files changed, 571 insertions(+), 692 deletions(-) create mode 100644 global/config/config.go create mode 100644 global/config/default_config.yml delete mode 100644 global/ratelimit.go create mode 100644 server/middlewares.go diff --git a/coolq/api.go b/coolq/api.go index 72d90f1..849611d 100644 --- a/coolq/api.go +++ b/coolq/api.go @@ -1148,14 +1148,6 @@ func (bot *CQBot) CQOcrImage(imageID string) MSG { return OK(rsp) } -// CQReloadEventFilter 扩展API-重载事件过滤器 -// -// https://docs.go-cqhttp.org/api/#%E9%87%8D%E8%BD%BD%E4%BA%8B%E4%BB%B6%E8%BF%87%E6%BB%A4%E5%99%A8 -func (bot *CQBot) CQReloadEventFilter() MSG { - global.BootFilter() - return OK(nil) -} - // CQSetGroupPortrait 扩展API-设置群头像 // // https://docs.go-cqhttp.org/api/#%E8%AE%BE%E7%BD%AE%E7%BE%A4%E5%A4%B4%E5%83%8F diff --git a/coolq/bot.go b/coolq/bot.go index 9a830a0..504fb9f 100644 --- a/coolq/bot.go +++ b/coolq/bot.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/global/config" "github.com/Mrs4s/MiraiGo/binary" "github.com/Mrs4s/MiraiGo/client" @@ -44,11 +45,18 @@ type MSG map[string]interface{} var ForceFragmented = false // NewQQBot 初始化一个QQBot实例 -func NewQQBot(cli *client.QQClient, conf *global.JSONConfig) *CQBot { +func NewQQBot(cli *client.QQClient, conf *config.Config) *CQBot { bot := &CQBot{ Client: cli, } - if conf.EnableDB { + enableLevelDB := false + node, ok := conf.Database["leveldb"] + if ok { + lconf := new(config.LevelDBConfig) + _ = node.Decode(lconf) + enableLevelDB = lconf.Enable + } + if enableLevelDB { p := path.Join("data", "leveldb") db, err := leveldb.OpenFile(p, nil) if err != nil { @@ -62,7 +70,7 @@ func NewQQBot(cli *client.QQClient, conf *global.JSONConfig) *CQBot { } bot.Client.OnPrivateMessage(bot.privateMessageEvent) bot.Client.OnGroupMessage(bot.groupMessageEvent) - if conf.EnableSelfMessage { + if conf.Message.ReportSelfMessage { bot.Client.OnSelfPrivateMessage(bot.privateMessageEvent) bot.Client.OnSelfGroupMessage(bot.groupMessageEvent) } @@ -86,16 +94,17 @@ func NewQQBot(cli *client.QQClient, conf *global.JSONConfig) *CQBot { bot.Client.OnOtherClientStatusChanged(bot.otherClientStatusChangedEvent) bot.Client.OnGroupDigest(bot.groupEssenceMsg) go func() { - i := conf.HeartbeatInterval - if i < 0 { + i := conf.Heartbeat.Interval + if i < 0 || conf.Heartbeat.Disabled { log.Warn("警告: 心跳功能已关闭,若非预期,请检查配置文件。") return } if i == 0 { i = 5 } + t := time.NewTicker(time.Second * time.Duration(i)) for { - time.Sleep(time.Second * i) + <-t.C bot.dispatchEventMessage(MSG{ "time": time.Now().Unix(), "self_id": bot.Client.Uin, @@ -429,10 +438,6 @@ func (bot *CQBot) Release() { } func (bot *CQBot) dispatchEventMessage(m MSG) { - if global.EventFilter != nil && !global.EventFilter.Eval(global.MSG(m)) { - log.Debug("Event filtered!") - return - } buf := global.NewBuffer() wg := sync.WaitGroup{} wg.Add(len(bot.events)) @@ -445,6 +450,7 @@ func (bot *CQBot) dispatchEventMessage(m MSG) { log.Warnf("处理事件 %v 时出现错误: %v \n%s", m, pan, debug.Stack()) } }() + start := time.Now() fn(buf) end := time.Now() diff --git a/global/config.go b/global/config.go index b15a914..f75ed0b 100644 --- a/global/config.go +++ b/global/config.go @@ -1,330 +1,10 @@ package global -import ( - "os" - "path" - "strconv" - "time" - - "github.com/hjson/hjson-go" - jsoniter "github.com/json-iterator/go" - log "github.com/sirupsen/logrus" -) - -var json = jsoniter.ConfigCompatibleWithStandardLibrary - -var ( - currentPath = getCurrentPath() - DefaultConfFile = path.Join(currentPath, "config.hjson") - AccountToken []byte -) - -// DefaultConfigWithComments 为go-cqhttp的默认配置文件 -var DefaultConfigWithComments = ` -/* - go-cqhttp 默认配置文件 -*/ - -{ - // QQ号 - uin: 0 - // QQ密码 - password: "" - // 是否启用密码加密 - encrypt_password: false - // 加密后的密码, 如未启用密码加密将为空, 请勿随意修改. - password_encrypted: "" - // 是否启用内置数据库 - // 启用将会增加10-20MB的内存占用和一定的磁盘空间 - // 关闭将无法使用 撤回 回复 get_msg 等上下文相关功能 - enable_db: true - // 访问密钥, 强烈推荐在公网的服务器设置 - access_token: "" - // 重连设置 - relogin: { - // 是否启用自动重连 - // 如不启用掉线后将不会自动重连 - enabled: true - // 重连延迟, 单位秒 - relogin_delay: 3 - // 最大重连次数, 0为无限制 - max_relogin_times: 0 - } - // API限速设置 - // 该设置为全局生效 - // 原 cqhttp 虽然启用了 rate_limit 后缀, 但是基本没插件适配 - // 目前该限速设置为令牌桶算法, 请参考: - // https://baike.baidu.com/item/%E4%BB%A4%E7%89%8C%E6%A1%B6%E7%AE%97%E6%B3%95/6597000?fr=aladdin - _rate_limit: { - // 是否启用限速 - enabled: false - // 令牌回复频率, 单位秒 - frequency: 1 - // 令牌桶大小 - bucket_size: 1 - } - // 是否忽略无效的CQ码 - // 如果为假将原样发送 - ignore_invalid_cqcode: false - // 是否强制分片发送消息 - // 分片发送将会带来更快的速度 - // 但是兼容性会有些问题 - force_fragmented: false - // 心跳频率, 单位秒 - // -1 为关闭心跳 - heartbeat_interval: 0 - // HTTP设置 - http_config: { - // 是否启用正向HTTP服务器 - enabled: true - // 服务端监听地址 - host: 0.0.0.0 - // 服务端监听端口 - port: 5700 - // 反向HTTP超时时间, 单位秒 - // 最小值为5,小于5将会忽略本项设置 - timeout: 0 - // 反向HTTP POST地址列表 - // 格式: - // { - // 地址: secret - // } - post_urls: {} - } - // 正向WS设置 - ws_config: { - // 是否启用正向WS服务器 - enabled: true - // 正向WS服务器监听地址 - host: 0.0.0.0 - // 正向WS服务器监听端口 - port: 6700 - } - // 反向WS设置 - ws_reverse_servers: [ - // 可以添加多个反向WS推送 - { - // 是否启用该推送 - enabled: false - // 反向WS Universal 地址 - // 注意 设置了此项地址后下面两项将会被忽略 - // 留空请使用 "" - reverse_url: ws://you_websocket_universal.server - // 反向WS API 地址 - reverse_api_url: ws://you_websocket_api.server - // 反向WS Event 地址 - reverse_event_url: ws://you_websocket_event.server - // 重连间隔 单位毫秒 - reverse_reconnect_interval: 3000 - } - ] - // 上报数据类型 - // 可选: string array - post_message_format: string - // 是否使用服务器下发的新地址进行重连 - // 注意, 此设置可能导致在海外服务器上连接情况更差 - use_sso_address: false - // 是否启用 DEBUG - debug: false - // 日志等级 trace,debug,info,warn,error - log_level: "info" - // WebUi 设置 - web_ui: { - // 是否启用 WebUi - enabled: true - // 监听地址 - host: 127.0.0.1 - // 监听端口 - web_ui_port: 9999 - // 是否接收来自web的输入 - web_input: false - } -} -` +var AccountToken []byte // PasswordHash 存储QQ密码哈希供登录使用 var PasswordHash [16]byte -// JSONConfig Config对应的结构体 -type JSONConfig struct { - Uin int64 `json:"uin"` - Password string `json:"password"` - EncryptPassword bool `json:"encrypt_password"` - PasswordEncrypted string `json:"password_encrypted"` - EnableDB bool `json:"enable_db"` - EnableSelfMessage bool `json:"enable_self_message"` - AccessToken string `json:"access_token"` - ReLogin struct { - Enabled bool `json:"enabled"` - ReLoginDelay int `json:"relogin_delay"` - MaxReloginTimes uint `json:"max_relogin_times"` - } `json:"relogin"` - RateLimit struct { - Enabled bool `json:"enabled"` - Frequency float64 `json:"frequency"` - BucketSize int `json:"bucket_size"` - } `json:"_rate_limit"` - IgnoreInvalidCQCode bool `json:"ignore_invalid_cqcode"` - ForceFragmented bool `json:"force_fragmented"` - FixURL bool `json:"fix_url"` - ProxyRewrite string `json:"proxy_rewrite"` - HeartbeatInterval time.Duration `json:"heartbeat_interval"` - HTTPConfig *GoCQHTTPConfig `json:"http_config"` - WSConfig *GoCQWebSocketConfig `json:"ws_config"` - ReverseServers []*GoCQReverseWebSocketConfig `json:"ws_reverse_servers"` - PostMessageFormat string `json:"post_message_format"` - UseSSOAddress bool `json:"use_sso_address"` - Debug bool `json:"debug"` - LogLevel string `json:"log_level"` - // WebUI *GoCQWebUI `json:"web_ui"` -} - -// CQHTTPAPIConfig HTTPAPI对应的Config结构体 -type CQHTTPAPIConfig struct { - Host string `json:"host"` - Port uint16 `json:"port"` - UseHTTP bool `json:"use_http"` - WSHost string `json:"ws_host"` - WSPort uint16 `json:"ws_port"` - UseWS bool `json:"use_ws"` - WSReverseURL string `json:"ws_reverse_url"` - WSReverseAPIURL string `json:"ws_reverse_api_url"` - WSReverseEventURL string `json:"ws_reverse_event_url"` - WSReverseReconnectInterval uint16 `json:"ws_reverse_reconnect_interval"` - WSReverseReconnectOnCode1000 bool `json:"ws_reverse_reconnect_on_code_1000"` - UseWsReverse bool `json:"use_ws_reverse"` - PostURL string `json:"post_url"` - AccessToken string `json:"access_token"` - Secret string `json:"secret"` - PostMessageFormat string `json:"post_message_format"` -} - -// GoCQHTTPConfig 正向HTTP对应config结构体 -type GoCQHTTPConfig struct { - Enabled bool `json:"enabled"` - Host string `json:"host"` - Port uint16 `json:"port"` - Timeout int32 `json:"timeout"` - PostUrls map[string]string `json:"post_urls"` -} - -// GoCQWebSocketConfig 正向WebSocket对应Config结构体 -type GoCQWebSocketConfig struct { - Enabled bool `json:"enabled"` - Host string `json:"host"` - Port uint16 `json:"port"` -} - -// GoCQReverseWebSocketConfig 反向WebSocket对应Config结构体 -type GoCQReverseWebSocketConfig struct { - Enabled bool `json:"enabled"` - ReverseURL string `json:"reverse_url"` - ReverseAPIURL string `json:"reverse_api_url"` - ReverseEventURL string `json:"reverse_event_url"` - ReverseReconnectInterval uint16 `json:"reverse_reconnect_interval"` -} - -/* -// GoCQWebUI WebUI对应Config结构体 -type GoCQWebUI struct { - Enabled bool `json:"enabled"` - Host string `json:"host"` - WebUIPort uint64 `json:"web_ui_port"` - WebInput bool `json:"web_input"` -} -*/ - -// DefaultConfig 返回一份默认配置对应结构体 -func DefaultConfig() *JSONConfig { - return &JSONConfig{ - EnableDB: true, - ReLogin: struct { - Enabled bool `json:"enabled"` - ReLoginDelay int `json:"relogin_delay"` - MaxReloginTimes uint `json:"max_relogin_times"` - }{ - Enabled: true, - ReLoginDelay: 3, - MaxReloginTimes: 0, - }, - RateLimit: struct { - Enabled bool `json:"enabled"` - Frequency float64 `json:"frequency"` - BucketSize int `json:"bucket_size"` - }{ - Enabled: false, - Frequency: 1, - BucketSize: 1, - }, - PostMessageFormat: "string", - ForceFragmented: false, - HTTPConfig: &GoCQHTTPConfig{ - Enabled: true, - Host: "0.0.0.0", - Port: 5700, - PostUrls: map[string]string{}, - }, - WSConfig: &GoCQWebSocketConfig{ - Enabled: true, - Host: "0.0.0.0", - Port: 6700, - }, - ReverseServers: []*GoCQReverseWebSocketConfig{ - { - Enabled: false, - ReverseURL: "ws://you_websocket_universal.server", - ReverseAPIURL: "ws://you_websocket_api.server", - ReverseEventURL: "ws://you_websocket_event.server", - ReverseReconnectInterval: 3000, - }, - }, - } -} - -// LoadConfig 加载配置文件 -func LoadConfig(p string) *JSONConfig { - if !PathExists(p) { - log.Warnf("尝试加载配置文件 %v 失败: 文件不存在", p) - return nil - } - var dat map[string]interface{} - c := JSONConfig{} - err := hjson.Unmarshal([]byte(ReadAllText(p)), &dat) - if err == nil { - b, _ := json.Marshal(dat) - err = json.Unmarshal(b, &c) - } - if err != nil { - log.Warnf("尝试加载配置文件 %v 时出现错误: %v", p, err) - log.Infoln("原文件已备份") - _ = os.Rename(p, p+".backup"+strconv.FormatInt(time.Now().Unix(), 10)) - return nil - } - return &c -} - -// Save 写入配置文件至path -func (c *JSONConfig) Save(path string) error { - data, err := hjson.MarshalWithOptions(c, hjson.EncoderOptions{ - Eol: "\n", - BracesSameLine: true, - IndentBy: " ", - }) - if err != nil { - return err - } - return WriteAllText(path, string(data)) -} - -// getCurrentPath 获取当前文件的路径,直接返回string -func getCurrentPath() string { - cwd, e := os.Getwd() - if e != nil { - panic(e) - } - return cwd -} - /* // GetCurrentPath 预留,获取当前目录地址 func GetCurrentPath() (string, error) { diff --git a/global/config/config.go b/global/config/config.go new file mode 100644 index 0000000..cb94e11 --- /dev/null +++ b/global/config/config.go @@ -0,0 +1,129 @@ +package config + +import ( + _ "embed" // embed the default config file + "os" + "path" + + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" +) + +// DefaultConfig 默认配置文件 +//go:embed default_config.yml +var DefaultConfig string + +var currentPath = getCurrentPath() + +// DefaultConfigFile 默认配置文件路径 +var DefaultConfigFile = path.Join(currentPath, "config.yml") + +// Config 总配置文件 +type Config struct { + Account struct { + Uin int64 `yaml:"uin"` + Password string `yaml:"password"` + Encrypt bool `yaml:"encrypt"` + ReLogin struct { + Disabled bool `yaml:"disabled"` + Delay int `yaml:"delay"` + MaxTimes uint `yaml:"max-times"` + Interval int `yaml:"interval"` + } + UseSSOAddress bool `yaml:"use-sso-address"` + } `yaml:"account"` + + Heartbeat struct { + Disabled bool `yaml:"disabled"` + Interval int `yaml:"interval"` + } `yaml:"heartbeat"` + + Message struct { + PostFormat string `yaml:"post-format"` + IgnoreInvalidCQCode bool `yaml:"ignore-invalid-cqcode"` + ForceFragment bool `yaml:"force-fragment"` + FixURL bool `yaml:"fix-url"` + ProxyRewrite string `yaml:"proxy-rewrite"` + ReportSelfMessage bool `yaml:"report-self-message"` + } `yaml:"message"` + + Output struct { + LogLevel string `yaml:"log-level"` + Debug bool `yaml:"debug"` + } `yaml:"output"` + + Servers []map[string]yaml.Node `yaml:"servers"` + Database map[string]yaml.Node `yaml:"database"` +} + +type MiddleWares struct { + AccessToken string `yaml:"access-token"` + Filter string `yaml:"filter"` + RateLimit struct { + Enabled bool `yaml:"enabled"` + Frequency float64 `yaml:"frequency"` + Bucket int `yaml:"bucket"` + } `yaml:"rate-limit"` +} + +// HTTPServer HTTP通信相关配置 +type HTTPServer struct { + Disabled bool `yaml:"disabled"` + Host string `yaml:"host"` + Port int `yaml:"port"` + Timeout int32 `yaml:"timeout"` + Post []struct { + URL string `yaml:"url"` + Secret string `yaml:"secret"` + } + + MiddleWares `yaml:"middlewares"` +} + +// WebsocketServer 正向WS相关配置 +type WebsocketServer struct { + Disabled bool `yaml:"disabled"` + Host string `yaml:"host"` + Port int `yaml:"port"` + + MiddleWares `yaml:"middlewares"` +} + +// WebsocketReverse 反向WS相关配置 +type WebsocketReverse struct { + Disabled bool `yaml:"disabled"` + Universal string `yaml:"universal"` + API string `yaml:"api"` + Event string `yaml:"event"` + ReconnectInterval int `yaml:"reconnect-interval"` + + MiddleWares `yaml:"middlewares"` +} + +// LevelDBConfig leveldb 相关配置 +type LevelDBConfig struct { + Enable bool `yaml:"enable"` +} + +// Get 从默认配置文件路径中获取 +func Get() *Config { + file, err := os.Open(DefaultConfigFile) + if err != nil { + log.Error("获取配置文件失败: ", err) + return nil + } + config := &Config{} + if yaml.NewDecoder(file).Decode(config) != nil { + log.Fatal("配置文件不合法!", err) + } + return config +} + +// getCurrentPath 获取当前文件的路径,直接返回string +func getCurrentPath() string { + cwd, e := os.Getwd() + if e != nil { + panic(e) + } + return cwd +} diff --git a/global/config/default_config.yml b/global/config/default_config.yml new file mode 100644 index 0000000..262f118 --- /dev/null +++ b/global/config/default_config.yml @@ -0,0 +1,113 @@ +# go-cqhttp 默认配置文件 + +account: # 账号相关 + uin: 1233456 # QQ账号 + password: '' # 密码为空时使用扫码登录 + encrypt: true # 是否开启密码加密 + relogin: # 重连设置 + disable: false + delay: 3 # 重连延迟, 单位秒 + interval: 0 # 重连间隔 + max-times: 0 # 最大重连次数, 0为无限制 + + # 是否使用服务器下发的新地址进行重连 + # 注意, 此设置可能导致在海外服务器上连接情况更差 + use-sso-address: true + +heartbeat: + disabled: false # 是否开启心跳事件上报 + # 心跳频率, 单位秒 + # -1 为关闭心跳 + interval: 5 + +message: + # 上报数据类型 + # 可选: string,array + post-format: string + # 是否忽略无效的CQ码, 如果为假将原样发送 + ignore-invalid-cqcode: false + # 是否强制分片发送消息 + # 分片发送将会带来更快的速度 + # 但是兼容性会有些问题 + force-fragment: false + # 是否将url分片发送 + fix-url: false + # 下载图片等请求网络代理 + proxy-rewrite: '' + # 是否上报自身消息 + report-self-message: false + +output: + # 日志等级 trace,debug,info,warn,error日志等级 trace,debug,info,warn,error + log-level: warn + # 是否启用 DEBUG + debug: false # 开启调试模式 + +default-middlewares: &default + # 访问密钥, 强烈推荐在公网的服务器设置 + access-token: '' + # 事件过滤器文件目录 + filter: '' + # API限速设置 + # 该设置为全局生效 + # 原 cqhttp 虽然启用了 rate_limit 后缀, 但是基本没插件适配 + # 目前该限速设置为令牌桶算法, 请参考: + # https://baike.baidu.com/item/%E4%BB%A4%E7%89%8C%E6%A1%B6%E7%AE%97%E6%B3%95/6597000?fr=aladdin + rate-limit: + enabled: false # 是否启用限速 + frequency: 1 # 令牌回复频率, 单位秒 + bucket: 1 # 令牌桶大小 + +servers: + # HTTP 通信设置 + - http: + # 是否关闭正向HTTP服务器 + disabled: true + # 服务端监听地址 + host: 127.0.0.1 + # 服务端监听端口 + port: 5700 + # 反向HTTP超时时间, 单位秒 + # 最小值为5,小于5将会忽略本项设置 + timeout: 5 + middlewares: + <<: *default + # 反向HTTP POST地址列表 + post: + - url: 127.0.0.1:5701 # 地址 + secret: '' # 密钥 + #- url: 127.0.0.1:5701 # 地址 + # secret: '' # 密钥 + + # 正向WS设置 + - ws: + # 是否禁用当前服务 + disable: true + # 正向WS服务器监听地址 + host: 127.0.0.1 + # 正向WS服务器监听端口 + port: 6700 + middlewares: + <<: *default + + - ws-reverse: + disable: true + # 反向WS Universal 地址 + # 注意 设置了此项地址后下面两项将会被忽略 + universal: ws://your_websocket_universal.server + # 反向WS API 地址 + api: ws://your_websocket_api.server + # 反向WS Event 地址 + event: ws://your_websocket_event.server + # 重连间隔 单位毫秒 + reconnect-interval: 3000 + middlewares: + <<: *default + #- ws-reverse: # 可添加多个 + +database: # 数据库相关设置 + leveldb: + # 是否启用内置leveldb数据库 + # 启用将会增加10-20MB的内存占用和一定的磁盘空间 + # 关闭将无法使用 撤回 回复 get_msg 等上下文相关功能 + enable: true diff --git a/global/filter.go b/global/filter.go index 11296b4..362f230 100644 --- a/global/filter.go +++ b/global/filter.go @@ -1,51 +1,15 @@ package global import ( - "fmt" - "io/ioutil" "regexp" "strings" - log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) -// MSG 消息Map -type MSG map[string]interface{} - -// Get 尝试从消息Map中取出key为s的值,若不存在则返回MSG{} -// -// 若所给key对应的值的类型是global.MSG,则返回此值 -// -// 若所给key对应值的类型不是global.MSG,则返回MSG{"__str__": Val} -func (m MSG) Get(s string) MSG { - if v, ok := m[s]; ok { - if msg, ok := v.(MSG); ok { - return msg - } - return MSG{"__str__": v} // 用这个名字应该没问题吧 - } - return nil // 不存在为空 -} - -// String 将消息Map转化为String。若Map存在key "__str__",则返回此key对应的值,否则将输出整张消息Map对应的JSON字符串 -func (m MSG) String() string { - if m == nil { - return "" // 空 JSON - } - if str, ok := m["__str__"]; ok { - if str == nil { - return "" // 空 JSON - } - return fmt.Sprint(str) - } - str, _ := json.MarshalToString(m) - return str -} - // Filter 定义了一个消息上报过滤接口 type Filter interface { - Eval(payload MSG) bool + Eval(payload gjson.Result) bool } type operationNode struct { @@ -68,7 +32,7 @@ func notOperatorConstruct(argument gjson.Result) *NotOperator { } // Eval 对payload执行Not过滤 -func (op *NotOperator) Eval(payload MSG) bool { +func (op *NotOperator) Eval(payload gjson.Result) bool { return !op.operand.Eval(payload) } @@ -110,7 +74,7 @@ func andOperatorConstruct(argument gjson.Result) *AndOperator { } // Eval 对payload执行And过滤 -func (op *AndOperator) Eval(payload MSG) bool { +func (op *AndOperator) Eval(payload gjson.Result) bool { res := true for _, operand := range op.operands { if len(operand.key) == 0 { @@ -147,7 +111,7 @@ func orOperatorConstruct(argument gjson.Result) *OrOperator { } // Eval 对payload执行Or过滤 -func (op *OrOperator) Eval(payload MSG) bool { +func (op *OrOperator) Eval(payload gjson.Result) bool { res := false for _, operand := range op.operands { res = res || operand.Eval(payload) @@ -170,7 +134,7 @@ func equalOperatorConstruct(argument gjson.Result) *EqualOperator { } // Eval 对payload执行Equal过滤 -func (op *EqualOperator) Eval(payload MSG) bool { +func (op *EqualOperator) Eval(payload gjson.Result) bool { return payload.String() == op.operand } @@ -186,7 +150,7 @@ func notEqualOperatorConstruct(argument gjson.Result) *NotEqualOperator { } // Eval 对payload执行NotEqual过滤 -func (op *NotEqualOperator) Eval(payload MSG) bool { +func (op *NotEqualOperator) Eval(payload gjson.Result) bool { return !(payload.String() == op.operand) } @@ -214,7 +178,7 @@ func inOperatorConstruct(argument gjson.Result) *InOperator { } // Eval 对payload执行In过滤 -func (op *InOperator) Eval(payload MSG) bool { +func (op *InOperator) Eval(payload gjson.Result) bool { payloadStr := payload.String() if op.operandArray != nil { for _, value := range op.operandArray { @@ -242,7 +206,7 @@ func containsOperatorConstruct(argument gjson.Result) *ContainsOperator { } // Eval 对payload执行Contains过滤 -func (op *ContainsOperator) Eval(payload MSG) bool { +func (op *ContainsOperator) Eval(payload gjson.Result) bool { return strings.Contains(payload.String(), op.operand) } @@ -261,7 +225,7 @@ func regexOperatorConstruct(argument gjson.Result) *RegexOperator { } // Eval 对payload执行RegexO过滤 -func (op *RegexOperator) Eval(payload MSG) bool { +func (op *RegexOperator) Eval(payload gjson.Result) bool { matched := op.regex.MatchString(payload.String()) return matched } @@ -289,24 +253,3 @@ func Generate(opName string, argument gjson.Result) Filter { panic("the operator " + opName + " is not supported") } } - -// EventFilter 初始化一个nil过滤器 -var EventFilter Filter - -// BootFilter 启动事件过滤器 -func BootFilter() { - defer func() { - if e := recover(); e != nil { - log.Warnf("事件过滤器启动失败: %v", e) - EventFilter = nil - } else { - log.Info("事件过滤器启动成功.") - } - }() - f, err := ioutil.ReadFile("filter.json") - if err != nil { - panic(err) - } else { - EventFilter = Generate("and", gjson.ParseBytes(f)) - } -} diff --git a/global/ratelimit.go b/global/ratelimit.go deleted file mode 100644 index c1ad4e1..0000000 --- a/global/ratelimit.go +++ /dev/null @@ -1,25 +0,0 @@ -package global - -import ( - "context" - - "golang.org/x/time/rate" -) - -var ( - limiter *rate.Limiter - limitEnable = false -) - -// RateLimit 执行API调用速率限制 -func RateLimit(ctx context.Context) { - if limitEnable { - _ = limiter.Wait(ctx) - } -} - -// InitLimiter 初始化速率限制器 -func InitLimiter(frequency float64, bucketSize int) { - limitEnable = true - limiter = rate.NewLimiter(rate.Limit(frequency), bucketSize) -} diff --git a/go.mod b/go.mod index f0b10d0..2690707 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/gin-gonic/gin v1.6.3 github.com/gorilla/websocket v1.4.2 github.com/guonaihong/gout v0.1.6 - github.com/hjson/hjson-go v3.1.0+incompatible github.com/jonboulle/clockwork v0.2.2 // indirect github.com/json-iterator/go v1.1.10 github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 @@ -28,4 +27,5 @@ require ( golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 golang.org/x/term v0.0.0-20210317153231-de623e64d2a6 golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c ) diff --git a/go.sum b/go.sum index 64da59a..1439c9f 100644 --- a/go.sum +++ b/go.sum @@ -56,8 +56,6 @@ github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0U github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/guonaihong/gout v0.1.6 h1:Txej4NYvVJLZkW0Xgw1HuWfSWow5BgLF6vqlM2kRdno= github.com/guonaihong/gout v0.1.6/go.mod h1:P6P8+0+toYgmhFqzLxVde+9vQbCDHrxn56V9TglC5io= -github.com/hjson/hjson-go v3.1.0+incompatible h1:DY/9yE8ey8Zv22bY+mHV1uk2yRy0h8tKhZ77hEdi0Aw= -github.com/hjson/hjson-go v3.1.0+incompatible/go.mod h1:qsetwF8NlsTsOTwZTApNlTCerV+b2GjYRRcIk4JMFio= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= diff --git a/main.go b/main.go index f2d990e..534540d 100644 --- a/main.go +++ b/main.go @@ -5,7 +5,6 @@ import ( "crypto/aes" "crypto/md5" "crypto/sha1" - "encoding/base64" "encoding/hex" "flag" "fmt" @@ -14,7 +13,6 @@ import ( "os/signal" "path" "runtime" - "strconv" "strings" "sync" "syscall" @@ -22,14 +20,13 @@ import ( "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/global/config" "github.com/Mrs4s/go-cqhttp/global/terminal" "github.com/Mrs4s/go-cqhttp/global/update" "github.com/Mrs4s/go-cqhttp/server" - "github.com/Mrs4s/MiraiGo/binary" "github.com/Mrs4s/MiraiGo/client" "github.com/guonaihong/gout" - jsoniter "github.com/json-iterator/go" rotatelogs "github.com/lestrrat-go/file-rotatelogs" log "github.com/sirupsen/logrus" easy "github.com/t-tomalak/logrus-easy-formatter" @@ -39,8 +36,7 @@ import ( ) var ( - json = jsoniter.ConfigCompatibleWithStandardLibrary - conf *global.JSONConfig + conf *config.Config isFastStart = false c string d bool @@ -49,14 +45,14 @@ var ( func init() { var debug bool - flag.StringVar(&c, "c", global.DefaultConfFile, "configuration filename default is config.hjson") + flag.StringVar(&c, "c", config.DefaultConfigFile, "configuration filename default is config.hjson") flag.BoolVar(&d, "d", false, "running as a daemon") flag.BoolVar(&debug, "D", false, "debug mode") flag.BoolVar(&h, "h", false, "this help") flag.Parse() // 通过-c 参数替换 配置文件路径 - global.DefaultConfFile = c + config.DefaultConfigFile = c logFormatter := &easy.Formatter{ TimestampFormat: "2006-01-02 15:04:05", LogFormat: "[%time%] [%lvl%]: %msg% \n", @@ -67,50 +63,23 @@ func init() { panic(err) } - conf = getConfig() + conf = config.Get() if conf == nil { - os.Exit(1) + _ = os.WriteFile("config.yml", []byte(config.DefaultConfig), 0644) + log.Error("未找到配置文件,默认配置文件已生成!") + readLine() + os.Exit(0) } if debug { - conf.Debug = true + conf.Output.Debug = true } // 在debug模式下,将在标准输出中打印当前执行行数 - if conf.Debug { + if conf.Output.Debug { log.SetReportCaller(true) } - log.AddHook(global.NewLocalHook(w, logFormatter, global.GetLogLevel(conf.LogLevel)...)) - - if global.PathExists("cqhttp.json") { - log.Info("发现 cqhttp.json 将在五秒后尝试导入配置,按 Ctrl+C 取消.") - log.Warn("警告: 该操作会删除 cqhttp.json 并覆盖 config.hjson 文件.") - time.Sleep(time.Second * 5) - conf := global.CQHTTPAPIConfig{} - if err := json.Unmarshal([]byte(global.ReadAllText("cqhttp.json")), &conf); err != nil { - log.Fatalf("读取文件 cqhttp.json 失败: %v", err) - } - goConf := global.DefaultConfig() - goConf.AccessToken = conf.AccessToken - goConf.HTTPConfig.Host = conf.Host - goConf.HTTPConfig.Port = conf.Port - goConf.WSConfig.Host = conf.WSHost - goConf.WSConfig.Port = conf.WSPort - if conf.PostURL != "" { - goConf.HTTPConfig.PostUrls[conf.PostURL] = conf.Secret - } - if conf.UseWsReverse { - goConf.ReverseServers[0].Enabled = true - goConf.ReverseServers[0].ReverseURL = conf.WSReverseURL - goConf.ReverseServers[0].ReverseAPIURL = conf.WSReverseAPIURL - goConf.ReverseServers[0].ReverseEventURL = conf.WSReverseEventURL - goConf.ReverseServers[0].ReverseReconnectInterval = conf.WSReverseReconnectInterval - } - if err := goConf.Save(global.DefaultConfFile); err != nil { - log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err) - } - _ = os.Remove("cqhttp.json") - } + log.AddHook(global.NewLocalHook(w, logFormatter, global.GetLogLevel(conf.Output.LogLevel)...)) if !global.PathExists(global.ImagePath) { if err := os.MkdirAll(global.ImagePath, 0755); err != nil { @@ -154,8 +123,7 @@ func main() { } case "key": if len(arg) > i+1 { - b := []byte(arg[i+1]) - byteKey = b + byteKey = []byte(arg[i+1]) } case "faststart": isFastStart = true @@ -167,7 +135,8 @@ func main() { log.Warning("将等待10s后启动") time.Sleep(time.Second * 10) } - if (conf.Uin == 0 || (conf.Password == "" && conf.PasswordEncrypted == "")) && !global.PathExists("session.token") { + + if (conf.Account.Uin == 0 || (conf.Account.Password == "" && !conf.Account.Encrypt)) && !global.PathExists("session.token") { log.Warn("账号密码未配置, 将使用二维码登录.") if !isFastStart { log.Warn("将在 5秒 后继续.") @@ -176,11 +145,10 @@ func main() { } log.Info("当前版本:", coolq.Version) - if conf.Debug { + if conf.Output.Debug { log.SetLevel(log.DebugLevel) log.Warnf("已开启Debug模式.") log.Debugf("开发交流群: 192548878") - server.Debug = true /* if conf.WebUI == nil || !conf.WebUI.Enabled { log.Warnf("警告: 在Debug模式下未启用WebUi服务, 将无法进行性能分析.") @@ -199,52 +167,56 @@ func main() { log.Fatalf("加载设备信息失败: %v", err) } } - if conf.EncryptPassword && conf.PasswordEncrypted == "" { - log.Infof("密码加密已启用, 请输入Key对密码进行加密: (Enter 提交)") - byteKey, _ = term.ReadPassword(int(os.Stdin.Fd())) - global.PasswordHash = md5.Sum([]byte(conf.Password)) - conf.Password = "" - conf.PasswordEncrypted = "AES:" + PasswordHashEncrypt(global.PasswordHash[:], byteKey) - _ = conf.Save(global.DefaultConfFile) - } - if conf.PasswordEncrypted != "" { - if len(byteKey) == 0 { - log.Infof("密码加密已启用, 请输入Key对密码进行解密以继续: (Enter 提交)") - cancel := make(chan struct{}, 1) - go func() { - select { - case <-cancel: - return - case <-time.After(time.Second * 45): - log.Infof("解密key输入超时") - time.Sleep(3 * time.Second) - os.Exit(0) - } - }() + + if conf.Account.Encrypt { + if !global.PathExists("password.encrypt") { + if conf.Account.Password == "" { + log.Error("无法进行加密,请在配置文件中的添加密码后重新启动.") + readLine() + os.Exit(0) + } + log.Infof("密码加密已启用, 请输入Key对密码进行加密: (Enter 提交)") byteKey, _ = term.ReadPassword(int(os.Stdin.Fd())) - cancel <- struct{}{} + global.PasswordHash = md5.Sum([]byte(conf.Account.Password)) + _ = os.WriteFile("password.encrypt", []byte(PasswordHashEncrypt(global.PasswordHash[:], byteKey)), 0644) + log.Info("密码已加密,为了您的账号安全,请删除配置文件中的密码后重新启动.") + readLine() + os.Exit(0) } else { - log.Infof("密码加密已启用, 使用运行时传递的参数进行解密,按 Ctrl+C 取消.") - } + if conf.Account.Password != "" { + log.Error("密码已加密,为了您的账号安全,请删除配置文件中的密码后重新启动.") + readLine() + os.Exit(0) + } - // 升级客户端密码加密方案,MD5+TEA 加密密码 -> PBKDF2+AES 加密 MD5 - // 升级后的 PasswordEncrypted 字符串以"AES:"开始,其后为 Hex 编码的16字节加密 MD5 - if !strings.HasPrefix(conf.PasswordEncrypted, "AES:") { - password := OldPasswordDecrypt(conf.PasswordEncrypted, byteKey) - passwordHash := md5.Sum([]byte(password)) - newPasswordHash := PasswordHashEncrypt(passwordHash[:], byteKey) - conf.PasswordEncrypted = "AES:" + newPasswordHash - _ = conf.Save(global.DefaultConfFile) - log.Debug("密码加密方案升级完成") - } + if len(byteKey) == 0 { + log.Infof("密码加密已启用, 请输入Key对密码进行解密以继续: (Enter 提交)") + cancel := make(chan struct{}, 1) + go func() { + select { + case <-cancel: + return + case <-time.After(time.Second * 45): + log.Infof("解密key输入超时") + time.Sleep(3 * time.Second) + os.Exit(0) + } + }() + byteKey, _ = term.ReadPassword(int(os.Stdin.Fd())) + cancel <- struct{}{} + } else { + log.Infof("密码加密已启用, 使用运行时传递的参数进行解密,按 Ctrl+C 取消.") + } - ph, err := PasswordHashDecrypt(conf.PasswordEncrypted[4:], byteKey) - if err != nil { - log.Fatalf("加密存储的密码损坏,请尝试重新配置密码") + encrypt, _ := os.ReadFile("password.encrypt") + ph, err := PasswordHashDecrypt(string(encrypt), byteKey) + if err != nil { + log.Fatalf("加密存储的密码损坏,请尝试重新配置密码") + } + copy(global.PasswordHash[:], ph) } - copy(global.PasswordHash[:], ph) - } else if conf.Password != "" { - global.PasswordHash = md5.Sum([]byte(conf.Password)) + } else { + global.PasswordHash = md5.Sum([]byte(conf.Account.Password)) } if !isFastStart { log.Info("Bot将在5秒后登录并开始信息处理, 按 Ctrl+C 取消.") @@ -265,8 +237,8 @@ func main() { return "未知" }()) cli = client.NewClientEmpty() - if conf.Uin != 0 && global.PasswordHash != [16]byte{} { - cli.Uin = conf.Uin + if conf.Account.Uin != 0 && global.PasswordHash != [16]byte{} { + cli.Uin = conf.Account.Uin cli.PasswordMd5 = global.PasswordHash } cli.OnLog(func(c *client.QQClient, e *client.LogEvent) { @@ -288,7 +260,7 @@ func main() { log.Infof("读取到 %v 个自定义地址.", len(addr)) } cli.OnServerUpdated(func(bot *client.QQClient, e *client.ServerUpdatedEvent) bool { - if !conf.UseSSOAddress { + if !conf.Account.UseSSOAddress { log.Infof("收到服务器地址更新通知, 根据配置文件已忽略.") return false } @@ -311,10 +283,10 @@ func main() { conf.WebUI.Host = "127.0.0.1" } */ - global.Proxy = conf.ProxyRewrite + global.Proxy = conf.Message.ProxyRewrite // b := server.WebServer.Run(fmt.Sprintf("%s:%d", conf.WebUI.Host, conf.WebUI.WebUIPort), cli) // c := server.Console - isQRCodeLogin := (conf.Uin == 0 || len(conf.Password) == 0) && len(conf.PasswordEncrypted) == 0 + isQRCodeLogin := (conf.Account.Uin == 0 || len(conf.Account.Password) == 0) && !conf.Account.Encrypt isTokenLogin := false if global.PathExists("session.token") { token, err := ioutil.ReadFile("session.token") @@ -348,14 +320,14 @@ func main() { reLoginLock.Lock() defer reLoginLock.Unlock() log.Warnf("Bot已离线: %v", e.Message) - if !conf.ReLogin.Enabled { + if conf.Account.ReLogin.Disabled { os.Exit(1) } - if times > conf.ReLogin.MaxReloginTimes && conf.ReLogin.MaxReloginTimes != 0 { + if times > conf.Account.ReLogin.MaxTimes && conf.Account.ReLogin.MaxTimes != 0 { log.Fatalf("Bot重连次数超过限制, 停止") } - if conf.ReLogin.ReLoginDelay > 0 { - log.Warnf("将在 %v 秒后尝试重连. 重连次数:%v/%v", conf.ReLogin.ReLoginDelay, times, conf.ReLogin.MaxReloginTimes) + if conf.Account.ReLogin.Interval > 0 { + log.Warnf("将在 %v 秒后尝试重连. 重连次数:%v/%v", conf.Account.ReLogin.Interval, times, conf.Account.ReLogin.MaxTimes) } log.Warnf("尝试重连...") if cli.Online { @@ -381,31 +353,42 @@ func main() { global.Check(cli.ReloadGroupList()) log.Infof("共加载 %v 个群.", len(cli.GroupList)) bot := coolq.NewQQBot(cli, conf) - if conf.PostMessageFormat != "string" && conf.PostMessageFormat != "array" { - log.Warnf("post_message_format 配置错误, 将自动使用 string") + _ = bot.Client + if conf.Message.PostFormat != "string" && conf.Message.PostFormat != "array" { + log.Warnf("post-format 配置错误, 将自动使用 string") coolq.SetMessageFormat("string") } else { - coolq.SetMessageFormat(conf.PostMessageFormat) - } - if conf.RateLimit.Enabled { - global.InitLimiter(conf.RateLimit.Frequency, conf.RateLimit.BucketSize) + coolq.SetMessageFormat(conf.Message.PostFormat) } log.Info("正在加载事件过滤器.") - global.BootFilter() - coolq.IgnoreInvalidCQCode = conf.IgnoreInvalidCQCode - coolq.SplitURL = conf.FixURL - coolq.ForceFragmented = conf.ForceFragmented - if conf.HTTPConfig != nil && conf.HTTPConfig.Enabled { - go server.CQHTTPApiServer.Run(fmt.Sprintf("%s:%d", conf.HTTPConfig.Host, conf.HTTPConfig.Port), conf.AccessToken, bot) - for k, v := range conf.HTTPConfig.PostUrls { - server.NewHTTPClient().Run(k, v, conf.HTTPConfig.Timeout, bot) + coolq.IgnoreInvalidCQCode = conf.Message.IgnoreInvalidCQCode + coolq.SplitURL = conf.Message.FixURL + coolq.ForceFragmented = conf.Message.ForceFragment + for _, m := range conf.Servers { + if h, ok := m["http"]; ok { + hc := new(config.HTTPServer) + if err := h.Decode(hc); err != nil { + log.Warn("读取http配置失败 :", err) + } else { + go server.RunHTTPServerAndClients(bot, hc) + } + } + if s, ok := m["ws"]; ok { + sc := new(config.WebsocketServer) + if err := s.Decode(sc); err != nil { + log.Warn("读取http配置失败 :", err) + } else { + go server.RunWebSocketServer(bot, sc) + } + } + if c, ok := m["ws-reverse"]; ok { + rc := new(config.WebsocketReverse) + if err := c.Decode(rc); err != nil { + log.Warn("读取http配置失败 :", err) + } else { + go server.RunWebSocketClient(bot, rc) + } } - } - if conf.WSConfig != nil && conf.WSConfig.Enabled { - go server.WebSocketServer.Run(fmt.Sprintf("%s:%d", conf.WSConfig.Host, conf.WSConfig.Port), conf.AccessToken, bot) - } - for _, rc := range conf.ReverseServers { - go server.NewWebSocketClient(rc, conf.AccessToken, bot).Run() } log.Info("资源初始化完成, 开始处理信息.") log.Info("アトリは、高性能ですから!") @@ -446,25 +429,6 @@ func PasswordHashDecrypt(encryptedPasswordHash string, key []byte) ([]byte, erro return result, nil } -// OldPasswordDecrypt 使用key解密老password,仅供兼容使用 -func OldPasswordDecrypt(encryptedPassword string, key []byte) string { - defer func() { - if pan := recover(); pan != nil { - log.Fatalf("密码解密失败: %v", pan) - } - }() - encKey := md5.Sum(key) - encrypted, err := base64.StdEncoding.DecodeString(encryptedPassword) - if err != nil { - panic(err) - } - tea := binary.NewTeaCipher(encKey[:]) - if tea == nil { - panic("密钥错误") - } - return string(tea.Decrypt(encrypted)) -} - func checkUpdate() { log.Infof("正在检查更新.") if coolq.Version == "(devel)" { @@ -578,56 +542,6 @@ func restart(args []string) { } */ -func getConfig() *global.JSONConfig { - var conf *global.JSONConfig - switch { - case global.PathExists("config.json"): - conf = global.LoadConfig("config.json") - _ = conf.Save("config.hjson") - _ = os.Remove("config.json") - case os.Getenv("UIN") != "": - log.Infof("将从环境变量加载配置.") - uin, _ := strconv.ParseInt(os.Getenv("UIN"), 10, 64) - pwd := os.Getenv("PASS") - post := os.Getenv("HTTP_POST") - conf = &global.JSONConfig{ - Uin: uin, - Password: pwd, - HTTPConfig: &global.GoCQHTTPConfig{ - Enabled: true, - Host: "0.0.0.0", - Port: 5700, - PostUrls: map[string]string{}, - }, - WSConfig: &global.GoCQWebSocketConfig{ - Enabled: true, - Host: "0.0.0.0", - Port: 6700, - }, - PostMessageFormat: "string", - Debug: os.Getenv("DEBUG") == "true", - } - if post != "" { - conf.HTTPConfig.PostUrls[post] = os.Getenv("HTTP_SECRET") - } - default: - conf = global.LoadConfig(global.DefaultConfFile) - } - if conf == nil { - err := global.WriteAllText(global.DefaultConfFile, global.DefaultConfigWithComments) - if err != nil { - log.Fatalf("创建默认配置文件时出现错误: %v", err) - return nil - } - log.Infof("默认配置文件已生成, 请编辑 %s 后重启程序.", global.DefaultConfFile) - if !isFastStart { - time.Sleep(time.Second * 5) - } - return nil - } - return conf -} - // help cli命令行-h的帮助提示 func help() { fmt.Printf(`go-cqhttp service diff --git a/server/api.go b/server/api.go index cbdf33d..0265520 100644 --- a/server/api.go +++ b/server/api.go @@ -13,8 +13,11 @@ type resultGetter interface { Get(string) gjson.Result } +type handler func(action string, p resultGetter) coolq.MSG + type apiCaller struct { - bot *coolq.CQBot + bot *coolq.CQBot + handlers []handler } func getLoginInfo(bot *coolq.CQBot, _ resultGetter) coolq.MSG { @@ -259,8 +262,9 @@ func getVipInfo(bot *coolq.CQBot, p resultGetter) coolq.MSG { return bot.CQGetVipInfo(p.Get("user_id").Int()) } -func reloadEventFilter(bot *coolq.CQBot, _ resultGetter) coolq.MSG { - return bot.CQReloadEventFilter() +func reloadEventFilter(_ *coolq.CQBot, p resultGetter) coolq.MSG { + addFilter(p.Get("file").String()) + return coolq.OK(nil) } func getGroupAtAllRemain(bot *coolq.CQBot, p resultGetter) coolq.MSG { @@ -376,8 +380,24 @@ var API = map[string]func(*coolq.CQBot, resultGetter) coolq.MSG{ } func (api *apiCaller) callAPI(action string, p resultGetter) coolq.MSG { + for _, fn := range api.handlers { + if ret := fn(action, p); ret != nil { + return ret + } + } if f, ok := API[action]; ok { return f(api.bot, p) } return coolq.Failed(404, "API_NOT_FOUND", "API不存在") } + +func (api *apiCaller) use(middlewares ...handler) { + api.handlers = append(api.handlers, middlewares...) +} + +func newAPICaller(bot *coolq.CQBot) *apiCaller { + return &apiCaller{ + bot: bot, + handlers: []handler{}, + } +} diff --git a/server/http.go b/server/http.go index 33a77b9..be1359d 100644 --- a/server/http.go +++ b/server/http.go @@ -6,13 +6,14 @@ import ( "crypto/hmac" "crypto/sha1" "encoding/hex" + "fmt" "net/http" "os" "strings" "time" "github.com/Mrs4s/go-cqhttp/coolq" - "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/global/config" "github.com/Mrs4s/MiraiGo/utils" "github.com/gin-gonic/gin" @@ -26,7 +27,7 @@ type httpServer struct { engine *gin.Engine bot *coolq.CQBot HTTP *http.Server - api apiCaller + api *apiCaller } // HTTPClient 反向HTTP上报客户端 @@ -34,6 +35,7 @@ type HTTPClient struct { bot *coolq.CQBot secret string addr string + filter string timeout int32 } @@ -41,17 +43,22 @@ type httpContext struct { ctx *gin.Context } -// CQHTTPApiServer CQHTTPApiServer实例 -var CQHTTPApiServer = &httpServer{} - -// Debug 是否启用Debug模式 -var Debug = false - -func (s *httpServer) Run(addr, authToken string, bot *coolq.CQBot) { +func RunHTTPServerAndClients(bot *coolq.CQBot, conf *config.HTTPServer) { + if conf.Disabled { + return + } + var ( + s = new(httpServer) + authToken = conf.AccessToken + addr = fmt.Sprintf("%s:%d", conf.Host, conf.Port) + ) gin.SetMode(gin.ReleaseMode) s.engine = gin.New() s.bot = bot - s.api = apiCaller{s.bot} + s.api = newAPICaller(s.bot) + if conf.RateLimit.Enabled { + s.api.use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)) + } s.engine.Use(func(c *gin.Context) { if c.Request.Method != "GET" && c.Request.Method != "POST" { log.Warnf("已拒绝客户端 %v 的请求: 方法错误", c.Request.RemoteAddr) @@ -108,19 +115,25 @@ func (s *httpServer) Run(addr, authToken string, bot *coolq.CQBot) { os.Exit(1) } }() + + for _, c := range conf.Post { + go newHTTPClient().Run(c.URL, c.Secret, conf.Filter, conf.Timeout, bot) + } } -// NewHTTPClient 返回反向HTTP客户端 -func NewHTTPClient() *HTTPClient { +// newHTTPClient 返回反向HTTP客户端 +func newHTTPClient() *HTTPClient { return &HTTPClient{} } // Run 运行反向HTTP服务 -func (c *HTTPClient) Run(addr, secret string, timeout int32, bot *coolq.CQBot) { +func (c *HTTPClient) Run(addr, secret, filter string, timeout int32, bot *coolq.CQBot) { c.bot = bot c.secret = secret c.addr = addr c.timeout = timeout + c.filter = filter + addFilter(filter) if c.timeout < 5 { c.timeout = 5 } @@ -130,6 +143,14 @@ func (c *HTTPClient) Run(addr, secret string, timeout int32, bot *coolq.CQBot) { func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) { var res string + if c.filter != "" { + filter := findFilter(c.filter) + if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) { + log.Debugf("上报Event %v 到 HTTP 服务器 %v 时被过滤.", c.addr, utils.B2S(m.Bytes())) + return + } + } + err := gout.POST(c.addr).SetJSON(m.Bytes()).BindBody(&res).SetHeader(func() gout.H { h := gout.H{ "X-Self-ID": c.bot.Client.Uin, @@ -165,7 +186,6 @@ func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) { } func (s *httpServer) HandleActions(c *gin.Context) { - global.RateLimit(context.Background()) action := strings.ReplaceAll(c.Param("action"), "_async", "") log.Debugf("HTTPServer接收到API调用: %v", action) c.JSON(200, s.api.callAPI(action, httpContext{ctx: c})) diff --git a/server/middlewares.go b/server/middlewares.go new file mode 100644 index 0000000..b8f093b --- /dev/null +++ b/server/middlewares.go @@ -0,0 +1,56 @@ +package server + +import ( + "context" + "os" + "sync" + + "github.com/Mrs4s/go-cqhttp/coolq" + "github.com/Mrs4s/go-cqhttp/global" + + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "golang.org/x/time/rate" +) + +var ( + filters = make(map[string]global.Filter) + filterMutex sync.RWMutex +) + +func rateLimit(frequency float64, bucketSize int) handler { + limiter := rate.NewLimiter(rate.Limit(frequency), bucketSize) + return func(_ string, _ resultGetter) coolq.MSG { + _ = limiter.Wait(context.Background()) + return nil + } +} + +func addFilter(file string) { + if file == "" { + return + } + bs, err := os.ReadFile(file) + if err != nil { + log.Error("init filter error: ", err) + return + } + defer func() { + if err := recover(); err != nil { + log.Error("init filter error: ", err) + } + }() + filter := global.Generate("and", gjson.ParseBytes(bs)) + filterMutex.Lock() + filters[file] = filter + filterMutex.Unlock() +} + +func findFilter(file string) global.Filter { + if file == "" { + return nil + } + filterMutex.RLock() + defer filterMutex.RUnlock() + return filters[file] +} diff --git a/server/websocket.go b/server/websocket.go index 728380f..9b10613 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -2,7 +2,6 @@ package server import ( "bytes" - "context" "fmt" "net/http" "runtime/debug" @@ -13,6 +12,7 @@ import ( "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/global/config" "github.com/Mrs4s/MiraiGo/utils" "github.com/gorilla/websocket" @@ -21,43 +21,50 @@ import ( ) type webSocketServer struct { - bot *coolq.CQBot - token string + bot *coolq.CQBot + conf *config.WebsocketServer + eventConn []*webSocketConn eventConnMutex sync.Mutex + token string handshake string + filter string } // WebSocketClient WebSocket客户端实例 type WebSocketClient struct { - conf *global.GoCQReverseWebSocketConfig - token string - bot *coolq.CQBot + bot *coolq.CQBot + conf *config.WebsocketReverse universalConn *webSocketConn eventConn *webSocketConn + token string + filter string } type webSocketConn struct { *websocket.Conn sync.Mutex - apiCaller apiCaller + apiCaller *apiCaller } -// WebSocketServer 初始化一个WebSocketServer实例 -var WebSocketServer = &webSocketServer{} - var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } -func (s *webSocketServer) Run(addr, authToken string, b *coolq.CQBot) { - s.token = authToken +// RunWebSocketServer 运行一个正向WS server +func RunWebSocketServer(b *coolq.CQBot, conf *config.WebsocketServer) { + s := new(webSocketServer) + s.conf = conf s.bot = b + s.token = conf.AccessToken + s.filter = conf.Filter + addFilter(s.filter) + addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port) s.handshake = fmt.Sprintf(`{"_post_method":2,"meta_event_type":"lifecycle","post_type":"meta_event","self_id":%d,"sub_type":"connect","time":%d}`, - s.bot.Client.Uin, time.Now().Unix()) + b.Client.Uin, time.Now().Unix()) b.OnEventPush(s.onBotPushEvent) http.HandleFunc("/event", s.event) http.HandleFunc("/api", s.api) @@ -68,23 +75,24 @@ func (s *webSocketServer) Run(addr, authToken string, b *coolq.CQBot) { }() } -// NewWebSocketClient 初始化一个NWebSocket客户端 -func NewWebSocketClient(conf *global.GoCQReverseWebSocketConfig, authToken string, b *coolq.CQBot) *WebSocketClient { - return &WebSocketClient{conf: conf, token: authToken, bot: b} -} - -// Run 运行实例 -func (c *WebSocketClient) Run() { - if !c.conf.Enabled { +// RunWebSocketClient 运行一个正向WS client +func RunWebSocketClient(b *coolq.CQBot, conf *config.WebsocketReverse) { + if conf.Disabled { return } - if c.conf.ReverseURL != "" { + c := new(WebSocketClient) + c.bot = b + c.conf = conf + c.token = conf.AccessToken + c.filter = conf.Filter + addFilter(c.filter) + if c.conf.Universal != "" { c.connectUniversal() } else { - if c.conf.ReverseAPIURL != "" { + if c.conf.API != "" { c.connectAPI() } - if c.conf.ReverseEventURL != "" { + if c.conf.Event != "" { c.connectEvent() } } @@ -92,7 +100,7 @@ func (c *WebSocketClient) Run() { } func (c *WebSocketClient) connectAPI() { - log.Infof("开始尝试连接到反向WebSocket API服务器: %v", c.conf.ReverseAPIURL) + log.Infof("开始尝试连接到反向WebSocket API服务器: %v", c.conf.API) header := http.Header{ "X-Client-Role": []string{"API"}, "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, @@ -101,22 +109,25 @@ func (c *WebSocketClient) connectAPI() { if c.token != "" { header["Authorization"] = []string{"Token " + c.token} } - conn, _, err := websocket.DefaultDialer.Dial(c.conf.ReverseAPIURL, header) // nolint + conn, _, err := websocket.DefaultDialer.Dial(c.conf.API, header) // nolint if err != nil { - log.Warnf("连接到反向WebSocket API服务器 %v 时出现错误: %v", c.conf.ReverseAPIURL, err) - if c.conf.ReverseReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReverseReconnectInterval)) + log.Warnf("连接到反向WebSocket API服务器 %v 时出现错误: %v", c.conf.API, err) + if c.conf.ReconnectInterval != 0 { + time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) c.connectAPI() } return } - log.Infof("已连接到反向WebSocket API服务器 %v", c.conf.ReverseAPIURL) - wrappedConn := &webSocketConn{Conn: conn, apiCaller: apiCaller{c.bot}} + log.Infof("已连接到反向WebSocket API服务器 %v", c.conf.API) + wrappedConn := &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)} + if c.conf.RateLimit.Enabled { + wrappedConn.apiCaller.use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket)) + } go c.listenAPI(wrappedConn, false) } func (c *WebSocketClient) connectEvent() { - log.Infof("开始尝试连接到反向WebSocket Event服务器: %v", c.conf.ReverseEventURL) + log.Infof("开始尝试连接到反向WebSocket Event服务器: %v", c.conf.Event) header := http.Header{ "X-Client-Role": []string{"Event"}, "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, @@ -125,11 +136,11 @@ func (c *WebSocketClient) connectEvent() { if c.token != "" { header["Authorization"] = []string{"Token " + c.token} } - conn, _, err := websocket.DefaultDialer.Dial(c.conf.ReverseEventURL, header) // nolint + conn, _, err := websocket.DefaultDialer.Dial(c.conf.Event, header) // nolint if err != nil { - log.Warnf("连接到反向WebSocket Event服务器 %v 时出现错误: %v", c.conf.ReverseEventURL, err) - if c.conf.ReverseReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReverseReconnectInterval)) + log.Warnf("连接到反向WebSocket Event服务器 %v 时出现错误: %v", c.conf.Event, err) + if c.conf.ReconnectInterval != 0 { + time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) c.connectEvent() } return @@ -142,12 +153,12 @@ func (c *WebSocketClient) connectEvent() { log.Warnf("反向WebSocket 握手时出现错误: %v", err) } - log.Infof("已连接到反向WebSocket Event服务器 %v", c.conf.ReverseEventURL) - c.eventConn = &webSocketConn{Conn: conn, apiCaller: apiCaller{c.bot}} + log.Infof("已连接到反向WebSocket Event服务器 %v", c.conf.Event) + c.eventConn = &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)} } func (c *WebSocketClient) connectUniversal() { - log.Infof("开始尝试连接到反向WebSocket Universal服务器: %v", c.conf.ReverseURL) + log.Infof("开始尝试连接到反向WebSocket Universal服务器: %v", c.conf.Universal) header := http.Header{ "X-Client-Role": []string{"Universal"}, "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, @@ -156,11 +167,11 @@ func (c *WebSocketClient) connectUniversal() { if c.token != "" { header["Authorization"] = []string{"Token " + c.token} } - conn, _, err := websocket.DefaultDialer.Dial(c.conf.ReverseURL, header) // nolint + conn, _, err := websocket.DefaultDialer.Dial(c.conf.Universal, header) // nolint if err != nil { - log.Warnf("连接到反向WebSocket Universal服务器 %v 时出现错误: %v", c.conf.ReverseURL, err) - if c.conf.ReverseReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReverseReconnectInterval)) + log.Warnf("连接到反向WebSocket Universal服务器 %v 时出现错误: %v", c.conf.Universal, err) + if c.conf.ReconnectInterval != 0 { + time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) c.connectUniversal() } return @@ -172,7 +183,10 @@ func (c *WebSocketClient) connectUniversal() { log.Warnf("反向WebSocket 握手时出现错误: %v", err) } - wrappedConn := &webSocketConn{Conn: conn, apiCaller: apiCaller{c.bot}} + wrappedConn := &webSocketConn{Conn: conn, apiCaller: newAPICaller(c.bot)} + if c.conf.RateLimit.Enabled { + wrappedConn.apiCaller.use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket)) + } go c.listenAPI(wrappedConn, true) c.universalConn = wrappedConn } @@ -200,8 +214,8 @@ func (c *WebSocketClient) listenAPI(conn *webSocketConn, u bool) { global.PutBuffer(buffer) } } - if c.conf.ReverseReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReverseReconnectInterval)) + if c.conf.ReconnectInterval != 0 { + time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) if !u { go c.connectAPI() } @@ -209,6 +223,12 @@ func (c *WebSocketClient) listenAPI(conn *webSocketConn, u bool) { } func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) { + filter := findFilter(c.filter) + if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) { + log.Debugf("上报Event %v 到 WS客户端 时被过滤.", utils.B2S(m.Bytes())) + return + } + if c.eventConn != nil { log.Debugf("向WS服务器 %v 推送Event: %v", c.eventConn.RemoteAddr().String(), utils.B2S(m.Bytes())) conn := c.eventConn @@ -218,8 +238,8 @@ func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) { if err := c.eventConn.WriteMessage(websocket.TextMessage, m.Bytes()); err != nil { log.Warnf("向WS服务器 %v 推送Event时出现错误: %v", c.eventConn.RemoteAddr().String(), err) _ = c.eventConn.Close() - if c.conf.ReverseReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReverseReconnectInterval)) + if c.conf.ReconnectInterval != 0 { + time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) c.connectEvent() } } @@ -233,8 +253,8 @@ func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) { if err := c.universalConn.WriteMessage(websocket.TextMessage, m.Bytes()); err != nil { log.Warnf("向WS服务器 %v 推送Event时出现错误: %v", c.universalConn.RemoteAddr().String(), err) _ = c.universalConn.Close() - if c.conf.ReverseReconnectInterval != 0 { - time.Sleep(time.Millisecond * time.Duration(c.conf.ReverseReconnectInterval)) + if c.conf.ReconnectInterval != 0 { + time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) c.connectUniversal() } } @@ -242,7 +262,7 @@ func (c *WebSocketClient) onBotPushEvent(m *bytes.Buffer) { } func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { - if s.token != "" { + if s.conf.AccessToken != "" { if auth := r.URL.Query().Get("access_token"); auth != s.token { if auth := strings.SplitN(r.Header.Get("Authorization"), " ", 2); len(auth) != 2 || auth[1] != s.token { log.Warnf("已拒绝 %v 的 WebSocket 请求: Token鉴权失败", r.RemoteAddr) @@ -265,7 +285,7 @@ func (s *webSocketServer) event(w http.ResponseWriter, r *http.Request) { log.Infof("接受 WebSocket 连接: %v (/event)", r.RemoteAddr) - conn := &webSocketConn{Conn: c, apiCaller: apiCaller{s.bot}} + conn := &webSocketConn{Conn: c, apiCaller: newAPICaller(s.bot)} s.eventConnMutex.Lock() s.eventConn = append(s.eventConn, conn) @@ -288,7 +308,10 @@ func (s *webSocketServer) api(w http.ResponseWriter, r *http.Request) { return } log.Infof("接受 WebSocket 连接: %v (/api)", r.RemoteAddr) - conn := &webSocketConn{Conn: c, apiCaller: apiCaller{s.bot}} + conn := &webSocketConn{Conn: c, apiCaller: newAPICaller(s.bot)} + if s.conf.RateLimit.Enabled { + conn.apiCaller.use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) + } go s.listenAPI(conn) } @@ -314,7 +337,10 @@ func (s *webSocketServer) any(w http.ResponseWriter, r *http.Request) { return } log.Infof("接受 WebSocket 连接: %v (/)", r.RemoteAddr) - conn := &webSocketConn{Conn: c, apiCaller: apiCaller{s.bot}} + conn := &webSocketConn{Conn: c, apiCaller: newAPICaller(s.bot)} + if s.conf.RateLimit.Enabled { + conn.apiCaller.use(rateLimit(s.conf.RateLimit.Frequency, s.conf.RateLimit.Bucket)) + } s.eventConnMutex.Lock() s.eventConn = append(s.eventConn, conn) s.eventConnMutex.Unlock() @@ -352,7 +378,6 @@ func (c *webSocketConn) handleRequest(_ *coolq.CQBot, payload []byte) { _ = c.Close() } }() - global.RateLimit(context.Background()) j := gjson.ParseBytes(payload) t := strings.ReplaceAll(j.Get("action").Str, "_async", "") log.Debugf("WS接收到API调用: %v 参数: %v", t, j.Get("params").Raw) @@ -368,6 +393,14 @@ func (c *webSocketConn) handleRequest(_ *coolq.CQBot, payload []byte) { func (s *webSocketServer) onBotPushEvent(m *bytes.Buffer) { s.eventConnMutex.Lock() defer s.eventConnMutex.Unlock() + + filter := findFilter(s.filter) + if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) { + fmt.Printf("1213") + log.Debugf("上报Event %v 到 WS客户端 时被过滤.", utils.B2S(m.Bytes())) + return + } + for i, l := 0, len(s.eventConn); i < l; i++ { conn := s.eventConn[i] log.Debugf("向WS客户端 %v 推送Event: %v", conn.RemoteAddr().String(), utils.B2S(m.Bytes()))