1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-05 03:23:49 +08:00

feat: routing api by version & inject version field

This commit is contained in:
Mrs4s 2022-06-15 16:02:52 +08:00
parent 7349fd4b82
commit f900fd62fb
No known key found for this signature in database
GPG Key ID: 3186E98FA19CE3A7
8 changed files with 286 additions and 108 deletions

View File

@ -23,7 +23,10 @@ type Param struct {
type Router struct {
Func string
Version []uint16
Path []string
PathV11 []string // v11 only
PathV12 []string // v12 only
Params []Param
}
@ -31,29 +34,90 @@ type generator struct {
out io.Writer
}
const (
PathAll = 0
PathV11 = 11
PathV12 = 12
)
func (g *generator) WriteString(s string) {
io.WriteString(g.out, s)
}
func (g *generator) generate(routers []Router) {
var actions []string // for onebot v12 get_supported_actions
for _, router := range routers {
if len(router.PathV12) > 0 {
actions = append(actions, router.PathV12...)
}
if len(router.Path) > 0 {
actions = append(actions, router.Path...)
}
}
for i := range actions {
actions[i] = `"` + actions[i] + `"`
}
// TODO: v12 和 all 的 switch-case 由常量改为数组寻址, 以利用 get_supported_actions
g.WriteString("// Code generated by cmd/api-generator. DO NOT EDIT.\n\n")
g.WriteString("package api\n\nimport (\n\n")
g.WriteString("\"github.com/Mrs4s/go-cqhttp/coolq\"\n")
g.WriteString("\"github.com/Mrs4s/go-cqhttp/global\"\n")
g.WriteString(")\n\n")
g.WriteString(`func (c *Caller) call(action string, p Getter) global.MSG {
switch action {
default:
return coolq.Failed(404, "API_NOT_FOUND", "API不存在")` + "\n")
for _, router := range routers {
g.router(router)
g.WriteString(fmt.Sprintf(`func (c *Caller) call(action string, version uint16, p Getter) global.MSG {
if version == 12 {
if action == "get_supported_actions" {
return coolq.OK([]string{%v})
}
io.WriteString(g.out, ` }}`)
switch action {
`, strings.Join(actions, ",")))
for _, router := range routers {
g.router(router, PathV12)
}
io.WriteString(g.out, `}}`)
io.WriteString(g.out, "\n")
g.WriteString(`if version == 11 {
switch action {
`)
for _, router := range routers {
g.router(router, PathV11)
}
io.WriteString(g.out, `}}`)
io.WriteString(g.out, "\n")
io.WriteString(g.out, "switch action {\n")
for _, router := range routers {
g.router(router, PathAll)
}
io.WriteString(g.out, `}`)
io.WriteString(g.out, "\n")
io.WriteString(g.out, "return coolq.Failed(404, \"API_NOT_FOUND\", \"API不存在\")}")
}
func (g *generator) router(router Router) {
func (g *generator) router(router Router, pathVersion int) {
/*
checkVersion := func(v uint16) bool {
for _, ver := range router.Version {
if ver == v {
return true
}
}
return false
}
*/
path := router.Path
if pathVersion == PathV11 {
path = router.PathV11
}
if pathVersion == PathV12 {
path = router.PathV12
}
if len(path) == 0 {
return
}
g.WriteString(`case `)
for i, p := range router.Path {
for i, p := range path {
if i != 0 {
g.WriteString(`, `)
}
@ -61,7 +125,19 @@ func (g *generator) router(router Router) {
}
g.WriteString(":\n")
if len(router.Version) == 1 { // 目前来说只需要判断一个版本的情况
check := make([]string, 0, len(router.Version))
for _, ver := range router.Version {
check = append(check, fmt.Sprintf("version != %v", ver))
}
fmt.Fprintf(g.out, "if %v {\n", strings.Join(check, " && "))
fmt.Fprintf(g.out, "return coolq.Failed(405, \"VERSION_ERROR\", \"API版本不匹配\")}\n")
}
for i, p := range router.Params {
if p.Name == "version" {
continue
}
if p.Default == "" {
v := "p.Get(" + strconv.Quote(p.Name) + ")"
fmt.Fprintf(g.out, "p%d := %s\n", i, conv(v, p.Type))
@ -73,10 +149,14 @@ func (g *generator) router(router Router) {
}
g.WriteString("\t\treturn c.bot." + router.Func + "(")
for i := range router.Params {
for i, p := range router.Params {
if i != 0 {
g.WriteString(", ")
}
if p.Name == "version" {
fmt.Fprintf(g.out, "version")
continue
}
fmt.Fprintf(g.out, "p%d", i)
}
g.WriteString(")\n")
@ -100,6 +180,8 @@ func conv(v, t string) string {
return v + ".Uint()"
case "uint32":
return "uint32(" + v + ".Uint())"
case "uint16":
return "uint16(" + v + ".Uint())"
}
}
@ -108,7 +190,8 @@ func main() {
src := flag.String("path", "", "source file")
flag.Parse()
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, *src, nil, parser.ParseComments)
for _, s := range strings.Split(*src, ",") {
file, err := parser.ParseFile(fset, s, nil, parser.ParseComments)
if err != nil {
panic(err)
}
@ -145,6 +228,14 @@ func main() {
for _, route := range strings.Split(args, ",") {
router.Path = append(router.Path, unquote(route))
}
case "route11":
for _, route := range strings.Split(args, ",") {
router.PathV11 = append(router.PathV11, unquote(route))
}
case "route12":
for _, route := range strings.Split(args, ",") {
router.PathV12 = append(router.PathV12, unquote(route))
}
case "default":
for name, value := range parseMap(args, "=") {
for i, p := range router.Params {
@ -161,21 +252,47 @@ func main() {
}
}
}
case "version":
version := strings.Split(args, ",")
for _, v := range version {
if i, err := strconv.ParseUint(v, 10, 16); err == nil {
router.Version = append(router.Version, uint16(i))
}
}
}
sort.Slice(router.Path, func(i, j int) bool {
return router.Path[i] < router.Path[j]
})
sort.Slice(router.PathV11, func(i, j int) bool {
return router.PathV11[i] < router.PathV11[j]
})
sort.Slice(router.PathV12, func(i, j int) bool {
return router.PathV12[i] < router.PathV12[j]
})
}
if router.Path != nil {
if router.Path != nil || router.PathV11 != nil || router.PathV12 != nil {
routers = append(routers, router)
} else {
println(decl.Name.Name)
}
}
}
}
sort.Slice(routers, func(i, j int) bool {
return routers[i].Path[0] < routers[j].Path[0]
path := func(r Router) string {
if r.Path != nil {
return r.Path[0]
}
if r.PathV11 != nil {
return r.PathV11[0]
}
if r.PathV12 != nil {
return r.PathV12[0]
}
return ""
}
return path(routers[i]) < path(routers[j])
})
out := new(bytes.Buffer)

View File

@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/segmentio/asm/base64"
"github.com/Mrs4s/MiraiGo/binary"
@ -25,7 +26,6 @@ import (
"github.com/Mrs4s/go-cqhttp/db"
"github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/Mrs4s/go-cqhttp/internal/cache"
"github.com/Mrs4s/go-cqhttp/internal/param"
"github.com/Mrs4s/go-cqhttp/modules/filter"
@ -47,7 +47,8 @@ var defaultPageToken = guildMemberPageToken{
// CQGetLoginInfo 获取登录号信息
//
// https://git.io/Jtz1I
// @route(get_login_info)
// @route11(get_login_info)
// @route12(get_self_info)
func (bot *CQBot) CQGetLoginInfo() global.MSG {
return OK(global.MSG{"user_id": bot.Client.Uin, "nickname": bot.Client.Nickname})
}
@ -655,7 +656,7 @@ func (bot *CQBot) CQGroupFileDeleteFile(groupID int64, id string, busID int32) g
//
// https://docs.go-cqhttp.org/api/#%E8%8E%B7%E5%8F%96%E4%B8%AD%E6%96%87%E5%88%86%E8%AF%8D-%E9%9A%90%E8%97%8F-api
// @route(.get_word_slices)
func (bot *CQBot) CQGetWordSlices(content string) global.MSG {
func (bot *CQBot) CQGetWordSlices(content string, version uint16) global.MSG {
slices, err := bot.Client.GetWordSegmentation(content)
if err != nil {
return Failed(100, "WORD_SEGMENTATION_API_ERROR", err.Error())
@ -1342,7 +1343,8 @@ func (bot *CQBot) CQGetGroupHonorInfo(groupID int64, t string) global.MSG {
// CQGetStrangerInfo 获取陌生人信息
//
// https://git.io/Jtz17
// @route(get_stranger_info)
// @route11(get_stranger_info)
// @route12(get_user_info)
func (bot *CQBot) CQGetStrangerInfo(userID int64) global.MSG {
info, err := bot.Client.GetSummaryInfo(userID)
if err != nil {
@ -1755,7 +1757,7 @@ func (bot *CQBot) CQGetOnlineClients(noCache bool) global.MSG {
// CQCanSendImage 检查是否可以发送图片(此处永远返回true)
//
// https://git.io/Jtz1N
// @route(can_send_image)
// @route11(can_send_image)
func (bot *CQBot) CQCanSendImage() global.MSG {
return OK(global.MSG{"yes": true})
}
@ -1763,7 +1765,7 @@ func (bot *CQBot) CQCanSendImage() global.MSG {
// CQCanSendRecord 检查是否可以发送语音(此处永远返回true)
//
// https://git.io/Jtz1x
// @route(can_send_record)
// @route11(can_send_record)
func (bot *CQBot) CQCanSendRecord() global.MSG {
return OK(global.MSG{"yes": true})
}
@ -1831,7 +1833,8 @@ func (bot *CQBot) CQSetGroupAnonymousBan(groupID int64, flag string, duration in
//
// https://git.io/JtzMe
// @route(get_status)
func (bot *CQBot) CQGetStatus() global.MSG {
func (bot *CQBot) CQGetStatus(version uint16) global.MSG {
if version == 11 {
return OK(global.MSG{
"app_initialized": true,
"app_enabled": true,
@ -1841,6 +1844,12 @@ func (bot *CQBot) CQGetStatus() global.MSG {
"good": bot.Client.Online.Load(),
"stat": bot.Client.GetStatistics(),
})
}
return OK(global.MSG{
"online": bot.Client.Online.Load(),
"good": bot.Client.Online.Load(),
"stat": bot.Client.GetStatistics(),
})
}
// CQSetEssenceMessage 扩展API-设置精华消息
@ -1917,7 +1926,7 @@ func (bot *CQBot) CQCheckURLSafely(url string) global.MSG {
// CQGetVersionInfo 获取版本信息
//
// https://git.io/JtwUs
// @route(get_version_info)
// @route11(get_version_info)
func (bot *CQBot) CQGetVersionInfo() global.MSG {
wd, _ := os.Getwd()
return OK(global.MSG{
@ -2044,7 +2053,7 @@ func (bot *CQBot) CQReloadEventFilter(file string) global.MSG {
// OK 生成成功返回值
func OK(data interface{}) global.MSG {
return global.MSG{"data": data, "retcode": 0, "status": "ok"}
return global.MSG{"data": data, "retcode": 0, "status": "ok", "message": ""}
}
// Failed 生成失败返回值
@ -2056,7 +2065,7 @@ func Failed(code int, msg ...string) global.MSG {
if len(msg) > 1 {
w = msg[1]
}
return global.MSG{"data": nil, "retcode": code, "msg": m, "wording": w, "status": "failed"}
return global.MSG{"data": nil, "retcode": code, "msg": m, "wording": w, "message": w, "status": "failed"}
}
func limitedString(str string) string {

23
coolq/api_v12.go Normal file
View File

@ -0,0 +1,23 @@
package coolq
import (
"runtime"
"github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base"
)
// CQGetVersion 获取版本信息 OneBotV12
//
// https://git.io/JtwUs
// @route12(get_version)
func (bot *CQBot) CQGetVersion() global.MSG {
return OK(global.MSG{
"impl": "go_cqhttp",
"platform": "qq",
"version": base.Version,
"onebot_version": 12,
"runtime_version": runtime.Version(),
"runtime_os": runtime.GOOS,
})
}

View File

@ -111,7 +111,7 @@ func NewQQBot(cli *client.QQClient) *CQBot {
for {
<-t.C
bot.dispatchEvent("meta_event/heartbeat", global.MSG{
"status": bot.CQGetStatus()["data"],
"status": bot.CQGetStatus(11)["data"],
"interval": base.HeartbeatInterval.Milliseconds(),
})
}

View File

@ -7,13 +7,40 @@ import (
"github.com/Mrs4s/go-cqhttp/global"
)
func (c *Caller) call(action string, p Getter) global.MSG {
func (c *Caller) call(action string, version uint16, p Getter) global.MSG {
if version == 12 {
if action == "get_supported_actions" {
return coolq.OK([]string{".get_word_slices", ".handle_quick_operation", ".ocr_image", "ocr_image", "_get_group_notice", "_get_model_show", "_send_group_notice", "_set_model_show", "check_url_safely", "create_group_file_folder", "create_guild_role", "delete_essence_msg", "delete_friend", "delete_group_file", "delete_group_folder", "delete_guild_role", "delete_msg", "delete_unidirectional_friend", "download_file", "get_essence_msg_list", "get_forward_msg", "get_friend_list", "get_group_at_all_remain", "get_group_file_system_info", "get_group_file_url", "get_group_files_by_folder", "get_group_honor_info", "get_group_info", "get_group_list", "get_group_member_info", "get_group_member_list", "get_group_msg_history", "get_group_root_files", "get_group_system_msg", "get_guild_channel_list", "get_guild_list", "get_guild_member_list", "get_guild_member_profile", "get_guild_meta_by_guest", "get_guild_msg", "get_guild_roles", "get_guild_service_profile", "get_image", "get_self_info", "get_msg", "get_online_clients", "get_status", "get_user_info", "get_topic_channel_feeds", "get_unidirectional_friend_list", "get_version", "mark_msg_as_read", "qidian_get_account_info", "reload_event_filter", "send_forward_msg", "send_group_forward_msg", "send_group_msg", "send_group_sign", "send_guild_channel_msg", "send_msg", "send_private_forward_msg", "send_private_msg", "set_essence_msg", "set_friend_add_request", "set_group_add_request", "set_group_admin", "set_group_anonymous_ban", "set_group_ban", "set_group_card", "set_group_kick", "set_group_leave", "set_group_name", "set_group_portrait", "set_group_special_title", "set_group_whole_ban", "set_guild_member_role", "set_qq_profile", "update_guild_role", "upload_group_file"})
}
switch action {
case "get_self_info":
return c.bot.CQGetLoginInfo()
case "get_user_info":
p0 := p.Get("user_id").Int()
return c.bot.CQGetStrangerInfo(p0)
case "get_version":
return c.bot.CQGetVersion()
}
}
if version == 11 {
switch action {
case "can_send_image":
return c.bot.CQCanSendImage()
case "can_send_record":
return c.bot.CQCanSendRecord()
case "get_login_info":
return c.bot.CQGetLoginInfo()
case "get_stranger_info":
p0 := p.Get("user_id").Int()
return c.bot.CQGetStrangerInfo(p0)
case "get_version_info":
return c.bot.CQGetVersionInfo()
}
}
switch action {
default:
return coolq.Failed(404, "API_NOT_FOUND", "API不存在")
case ".get_word_slices":
p0 := p.Get("content").String()
return c.bot.CQGetWordSlices(p0)
return c.bot.CQGetWordSlices(p0, version)
case ".handle_quick_operation":
p0 := p.Get("context")
p1 := p.Get("operation")
@ -36,10 +63,6 @@ func (c *Caller) call(action string, p Getter) global.MSG {
p0 := p.Get("model").String()
p1 := p.Get("model_show").String()
return c.bot.CQSetModelShow(p0, p1)
case "can_send_image":
return c.bot.CQCanSendImage()
case "can_send_record":
return c.bot.CQCanSendRecord()
case "check_url_safely":
p0 := p.Get("url").String()
return c.bot.CQCheckURLSafely(p0)
@ -166,8 +189,6 @@ func (c *Caller) call(action string, p Getter) global.MSG {
case "get_image":
p0 := p.Get("file").String()
return c.bot.CQGetImage(p0)
case "get_login_info":
return c.bot.CQGetLoginInfo()
case "get_msg":
p0 := int32(p.Get("message_id").Int())
return c.bot.CQGetMessage(p0)
@ -175,18 +196,13 @@ func (c *Caller) call(action string, p Getter) global.MSG {
p0 := p.Get("no_cache").Bool()
return c.bot.CQGetOnlineClients(p0)
case "get_status":
return c.bot.CQGetStatus()
case "get_stranger_info":
p0 := p.Get("user_id").Int()
return c.bot.CQGetStrangerInfo(p0)
return c.bot.CQGetStatus(version)
case "get_topic_channel_feeds":
p0 := p.Get("guild_id").Uint()
p1 := p.Get("channel_id").Uint()
return c.bot.CQGetTopicChannelFeeds(p0, p1)
case "get_unidirectional_friend_list":
return c.bot.CQGetUnidirectionalFriendList()
case "get_version_info":
return c.bot.CQGetVersionInfo()
case "mark_msg_as_read":
p0 := int32(p.Get("message_id").Int())
return c.bot.CQMarkMessageAsRead(p0)
@ -338,4 +354,5 @@ func (c *Caller) call(action string, p Getter) global.MSG {
p3 := p.Get("folder").String()
return c.bot.CQUploadGroupFile(p0, p1, p2, p3)
}
return coolq.Failed(404, "API_NOT_FOUND", "API不存在")
}

View File

@ -8,7 +8,7 @@ import (
"github.com/Mrs4s/go-cqhttp/global"
)
//go:generate go run github.com/Mrs4s/go-cqhttp/cmd/api-generator -path=./../../coolq/api.go
//go:generate go run github.com/Mrs4s/go-cqhttp/cmd/api-generator -path=./../../coolq/api.go,./../../coolq/api_v12.go
// Getter 参数获取
type Getter interface {
@ -25,13 +25,13 @@ type Caller struct {
}
// Call specific API
func (c *Caller) Call(action string, p Getter) global.MSG {
func (c *Caller) Call(action string, version uint16, p Getter) global.MSG {
for _, fn := range c.handlers {
if ret := fn(action, p); ret != nil {
return ret
}
}
return c.call(action, p)
return c.call(action, version, p)
}
// Use add handlers to the API caller

View File

@ -34,6 +34,7 @@ import (
// HTTPServer HTTP通信相关配置
type HTTPServer struct {
Disabled bool `yaml:"disabled"`
Version uint16 `yaml:"version"`
Address string `yaml:"address"`
Host string `yaml:"host"`
Port int `yaml:"port"`
@ -57,6 +58,7 @@ type httpServerPost struct {
type httpServer struct {
api *api.Caller
accessToken string
version uint16
}
// HTTPClient 反向HTTP上报客户端
@ -81,6 +83,7 @@ type httpCtx struct {
const httpDefault = `
- http: # HTTP 通信设置
address: 0.0.0.0:5700 # HTTP监听地址
version: 11 # OneBot协议版本, 支持 11/12
timeout: 5 # 反向 HTTP 超时时间, 单位秒<5 时将被忽略
long-polling: # 长轮询拓展
enabled: false # 是否开启
@ -150,6 +153,13 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request
contentType := request.Header.Get("Content-Type")
switch request.Method {
case http.MethodPost:
// todo: msg pack
if s.version == 12 && strings.Contains(contentType, "application/msgpack") {
log.Warnf("请求 %v 数据类型暂不支持: MsgPack", request.RequestURI)
writer.WriteHeader(http.StatusUnsupportedMediaType)
return
}
if strings.Contains(contentType, "application/json") {
body, err := io.ReadAll(request.Body)
if err != nil {
@ -190,12 +200,12 @@ func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request
if request.URL.Path == "/" {
action := strings.TrimSuffix(ctx.Get("action").Str, "_async")
log.Debugf("HTTPServer接收到API调用: %v", action)
response = s.api.Call(action, ctx.Get("params"))
response = s.api.Call(action, s.version, ctx.Get("params"))
} else {
action := strings.TrimPrefix(request.URL.Path, "/")
action = strings.TrimSuffix(action, "_async")
log.Debugf("HTTPServer接收到API调用: %v", action)
response = s.api.Call(action, &ctx)
response = s.api.Call(action, s.version, &ctx)
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
@ -245,9 +255,11 @@ func runHTTP(bot *coolq.CQBot, node yaml.Node) {
case conf.Disabled:
return
}
if conf.Version != 11 && conf.Version != 12 {
conf.Version = 11
}
network, addr := "tcp", conf.Address
s := &httpServer{accessToken: conf.AccessToken}
s := &httpServer{accessToken: conf.AccessToken, version: conf.Version}
switch {
case conf.Address != "":
uri, err := url.Parse(conf.Address)

View File

@ -470,7 +470,7 @@ func (c *wsConn) handleRequest(_ *coolq.CQBot, payload []byte) {
j := gjson.Parse(utils.B2S(payload))
t := strings.TrimSuffix(j.Get("action").Str, "_async")
log.Debugf("WS接收到API调用: %v 参数: %v", t, j.Get("params").Raw)
ret := c.apiCaller.Call(t, j.Get("params"))
ret := c.apiCaller.Call(t, 11, j.Get("params"))
if j.Get("echo").Exists() {
ret["echo"] = j.Get("echo").Value()
}