diff --git a/cmd/api-generator/main.go b/cmd/api-generator/main.go index cf96c06..f97d7dd 100644 --- a/cmd/api-generator/main.go +++ b/cmd/api-generator/main.go @@ -10,6 +10,7 @@ import ( "go/token" "io" "os" + "reflect" "sort" "strconv" "strings" @@ -28,7 +29,6 @@ type Param struct { type Router struct { Func string - Version []uint16 Path []string PathV11 []string // v11 only PathV12 []string // v12 only @@ -59,62 +59,33 @@ func (g *generator) header() { } func (g *generator) genRouter(routers []Router) { - var actions []string // for onebot v12 get_supported_actions - for _, router := range routers { - if len(router.PathV12) > 0 { - actions = append(actions, router.PathV12...) - } - if len(router.Path) > 0 { - actions = append(actions, router.Path...) - } - } - for i := range actions { - actions[i] = `"` + actions[i] + `"` - } - g.WriteString("import (\n\n") g.WriteString("\"github.com/Mrs4s/go-cqhttp/coolq\"\n") g.WriteString("\"github.com/Mrs4s/go-cqhttp/global\"\n") g.WriteString("\"github.com/Mrs4s/go-cqhttp/pkg/onebot\"\n") g.WriteString(")\n\n") - g.WriteString(`func (c *Caller) call(action string, spec *onebot.Spec, p Getter) global.MSG { - if spec.Version == 12 { + g.WriteString(`func (c *Caller) call(action string, spec *onebot.Spec, p Getter) global.MSG {`) + genVer := func(path int) { + g.writef(`if spec.Version == %d { switch action { -`) - for _, router := range routers { - g.router(router, PathV12) + `, path) + for _, router := range routers { + g.router(router, path) + } + g.WriteString("}}\n") } - io.WriteString(g.out, `}}`) - io.WriteString(g.out, "\n") - g.WriteString(`if spec.Version == 11 { - switch action { - `) - for _, router := range routers { - g.router(router, PathV11) - } - io.WriteString(g.out, `}}`) - io.WriteString(g.out, "\n") - io.WriteString(g.out, "switch action {\n") + genVer(PathV11) + genVer(PathV12) + // generic path + g.WriteString("switch action {\n") for _, router := range routers { g.router(router, PathAll) } - io.WriteString(g.out, `}`) - io.WriteString(g.out, "\n") - io.WriteString(g.out, "return coolq.Failed(404, \"API_NOT_FOUND\", \"API不存在\")}") + g.WriteString("}\n") + g.WriteString("return coolq.Failed(404, \"API_NOT_FOUND\", \"API不存在\")}") } func (g *generator) router(router Router, pathVersion int) { - /* - checkVersion := func(v uint16) bool { - for _, ver := range router.Version { - if ver == v { - return true - } - } - return false - } - */ - path := router.Path if pathVersion == PathV11 { path = router.PathV11 @@ -135,26 +106,17 @@ func (g *generator) router(router Router, pathVersion int) { } g.WriteString(":\n") - if len(router.Version) == 1 { // 目前来说只需要判断一个版本的情况 - check := make([]string, 0, len(router.Version)) - for _, ver := range router.Version { - check = append(check, fmt.Sprintf("spec.Version != %v", ver)) - } - fmt.Fprintf(g.out, "if %v {\n", strings.Join(check, " && ")) - fmt.Fprintf(g.out, "return coolq.Failed(405, \"VERSION_ERROR\", \"API版本不匹配\")}\n") - } - for i, p := range router.Params { - if p.Name == "spec" { + if p.Type == "*onebot.Spec" { continue } if p.Default == "" { v := "p.Get(" + strconv.Quote(p.Name) + ")" - fmt.Fprintf(g.out, "p%d := %s\n", i, conv(v, p.Type)) + g.writef("p%d := %s\n", i, conv(v, p.Type)) } else { - fmt.Fprintf(g.out, "p%d := %s\n", i, p.Default) - fmt.Fprintf(g.out, "if pt := p.Get(%s); pt.Exists() {\n", strconv.Quote(p.Name)) - fmt.Fprintf(g.out, "p%d = %s\n}\n", i, conv("pt", p.Type)) + g.writef("p%d := %s\n", i, p.Default) + g.writef("if pt := p.Get(%s); pt.Exists() {\n", strconv.Quote(p.Name)) + g.writef("p%d = %s\n}\n", i, conv("pt", p.Type)) } } @@ -163,11 +125,11 @@ func (g *generator) router(router Router, pathVersion int) { if i != 0 { g.WriteString(", ") } - if p.Name == "spec" { - fmt.Fprintf(g.out, "spec") + if p.Type == "*onebot.Spec" { + g.WriteString("spec") continue } - fmt.Fprintf(g.out, "p%d", i) + g.writef("p%d", i) } g.WriteString(")\n") } @@ -175,7 +137,7 @@ func (g *generator) router(router Router, pathVersion int) { func conv(v, t string) string { switch t { default: - panic("unknown type: " + t) + panic("unsupported type: " + t) case "gjson.Result", "*onebot.Spec": return v case "int64": @@ -208,23 +170,15 @@ func main() { for _, decl := range file.Decls { switch decl := decl.(type) { case *ast.FuncDecl: - if !decl.Name.IsExported() || decl.Recv == nil { - continue - } - if st, ok := decl.Recv.List[0].Type.(*ast.StarExpr); !ok || st.X.(*ast.Ident).Name != "CQBot" { + if !decl.Name.IsExported() || decl.Recv == nil || + typeName(decl.Recv.List[0].Type) != "*CQBot" { continue } router := Router{Func: decl.Name.Name} // compute params for _, p := range decl.Type.Params.List { - var typ string - switch t := p.Type.(type) { - case *ast.Ident: - typ = t.Name - case *ast.SelectorExpr: - typ = t.X.(*ast.Ident).Name + "." + t.Sel.Name - } + typ := typeName(p.Type) for _, name := range p.Names { router.Params = append(router.Params, Param{Name: snakecase(name.Name), Type: typ}) } @@ -261,13 +215,6 @@ func main() { } } } - case "version": - version := strings.Split(args, ",") - for _, v := range version { - if i, err := strconv.ParseUint(v, 10, 16); err == nil { - router.Version = append(router.Version, uint16(i)) - } - } } sort.Slice(router.Path, func(i, j int) bool { return router.Path[i] < router.Path[j] @@ -335,7 +282,7 @@ func unquote(s string) string { func parseMap(input string, sep string) map[string]string { out := make(map[string]string) for _, arg := range strings.Split(input, ",") { - k, v, ok := cut(arg, sep) + k, v, ok := strings.Cut(arg, sep) if !ok { out[k] = "true" } @@ -353,20 +300,13 @@ func match(text string) (string, string) { return "", "" } text = strings.Trim(text, "@)") - cmd, args, ok := cut(text, "(") + cmd, args, ok := strings.Cut(text, "(") if !ok { return "", "" } return cmd, unquote(args) } -func cut(s, sep string) (before, after string, found bool) { - if i := strings.Index(s, sep); i >= 0 { - return s[:i], s[i+len(sep):], true - } - return s, "", false -} - // some abbreviations need translation before transforming ro snake case var replacer = strings.NewReplacer("ID", "Id") @@ -400,3 +340,16 @@ func convDefault(s string, t string) string { } return "" } + +func typeName(x ast.Node) string { + switch x := x.(type) { + case *ast.Ident: + return x.Name + case *ast.SelectorExpr: + return typeName(x.X) + "." + x.Sel.Name + case *ast.StarExpr: + return "*" + typeName(x.X) + default: + panic("unhandled type: " + reflect.TypeOf(x).String()) + } +} diff --git a/modules/api/api.go b/modules/api/api.go index e260e80..8573a08 100644 --- a/modules/api/api.go +++ b/modules/api/api.go @@ -9,23 +9,6 @@ import ( ) func (c *Caller) call(action string, spec *onebot.Spec, p Getter) global.MSG { - if spec.Version == 12 { - switch action { - case "get_self_info": - return c.bot.CQGetLoginInfo() - case "get_user_info": - p0 := p.Get("user_id").Int() - return c.bot.CQGetStrangerInfo(p0) - case "get_version": - return c.bot.CQGetVersion() - case "send_message": - p0 := p.Get("group_id").String() - p1 := p.Get("user_id").String() - p2 := p.Get("detail_type").String() - p3 := p.Get("message") - return c.bot.CQSendMessageV12(p0, p1, p2, p3) - } - } if spec.Version == 11 { switch action { case ".handle_quick_operation": @@ -77,6 +60,23 @@ func (c *Caller) call(action string, spec *onebot.Spec, p Getter) global.MSG { return c.bot.CQSendPrivateMessage(p0, p1, p2, p3) } } + if spec.Version == 12 { + switch action { + case "get_self_info": + return c.bot.CQGetLoginInfo() + case "get_user_info": + p0 := p.Get("user_id").Int() + return c.bot.CQGetStrangerInfo(p0) + case "get_version": + return c.bot.CQGetVersion() + case "send_message": + p0 := p.Get("group_id").String() + p1 := p.Get("user_id").String() + p2 := p.Get("detail_type").String() + p3 := p.Get("message") + return c.bot.CQSendMessageV12(p0, p1, p2, p3) + } + } switch action { case ".get_word_slices": p0 := p.Get("content").String()