// Command api-generator scans Go source files for exported methods on *CQBot
// that carry doc-comment annotations and writes api.go (in the current
// directory), whose Caller.call method dispatches an action name and protocol
// version to the matching method.
//
// Usage:
//
//	api-generator -path a.go,b.go
//
// Recognised annotations, one per doc-comment line:
//
//	@route(a, b)      action names dispatched for every version
//	@route11(a)       action names dispatched only when version == 11
//	@route12(a)       action names dispatched only when version == 12
//	@default(k=v, k2) parameter defaults; a bare key is shorthand for k=true
//	@rename(x->y)     change a parameter's request key (after snake_casing)
//	@version(11, 12)  allowed versions; currently enforced only when exactly one is given
package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io"
	"os"
	"sort"
	"strconv"
	"strings"
)

// Param is one parameter of an exported CQBot method as exposed to API callers.
type Param struct {
	Name    string // request key: snake_cased Go parameter name, possibly @rename'd
	Type    string // Go type as written in source, e.g. "int64" or "gjson.Result"
	Default string // Go expression used when the key is absent; "" means no default
}

// Router describes how one CQBot method is reachable through the API.
type Router struct {
	Func    string   // method name on *CQBot
	Version []uint16 // versions listed by @version
	Path    []string // action names for every version (@route)
	PathV11 []string // v11 only (@route11)
	PathV12 []string // v12 only (@route12)
	Params  []Param
}

// generator writes the (unformatted) source of api.go to out.
type generator struct {
	out io.Writer
}

// Selectors telling generator.router which of a Router's path lists to emit.
const (
	PathAll = 0
	PathV11 = 11
	PathV12 = 12
)

// WriteString writes s to the generator output. Write errors are ignored:
// main always writes into an in-memory bytes.Buffer.
func (g *generator) WriteString(s string) {
	_, _ = io.WriteString(g.out, s)
}

// generate emits the whole of api.go: a Caller.call method that switches on
// the action name among v12-only routes (when version == 12), then v11-only
// routes (when version == 11), then routes shared by all versions, and
// finally reports API_NOT_FOUND.
func (g *generator) generate(routers []Router) {
	// Actions reachable by a v12 caller, reported by get_supported_actions.
	var actions []string
	for _, router := range routers {
		for _, p := range router.PathV12 {
			actions = append(actions, strconv.Quote(p))
		}
		for _, p := range router.Path {
			actions = append(actions, strconv.Quote(p))
		}
	}

	// TODO: turn the v12 and all-version switch-cases from constants into
	// array lookups so that get_supported_actions can reuse them.
	g.WriteString("// Code generated by cmd/api-generator. DO NOT EDIT.\n\n")
	g.WriteString("package api\n\nimport (\n\n")
	g.WriteString("\"github.com/Mrs4s/go-cqhttp/coolq\"\n")
	g.WriteString("\"github.com/Mrs4s/go-cqhttp/global\"\n")
	g.WriteString(")\n\n")

	fmt.Fprintf(g.out, `func (c *Caller) call(action string, version uint16, p Getter) global.MSG {
if version == 12 {
if action == "get_supported_actions" {
return coolq.OK([]string{%v})
}
switch action {
`, strings.Join(actions, ","))
	for _, router := range routers {
		g.router(router, PathV12)
	}
	g.WriteString("}}\n")

	g.WriteString("if version == 11 {\nswitch action {\n")
	for _, router := range routers {
		g.router(router, PathV11)
	}
	g.WriteString("}}\n")

	g.WriteString("switch action {\n")
	for _, router := range routers {
		g.router(router, PathAll)
	}
	g.WriteString("}\n")
	g.WriteString("return coolq.Failed(404, \"API_NOT_FOUND\", \"API不存在\")}")
}

// router emits one `case` clause for the action names of router selected by
// pathVersion (PathV11, PathV12, anything else meaning Path). Nothing is
// written when that list is empty.
//
// Each parameter is read from the Getter p into a local pN (N = parameter
// position). A parameter named "version" receives the caller's version instead.
func (g *generator) router(router Router, pathVersion int) {
	var path []string
	switch pathVersion {
	case PathV11:
		path = router.PathV11
	case PathV12:
		path = router.PathV12
	default:
		path = router.Path
	}
	if len(path) == 0 {
		return
	}

	quoted := make([]string, len(path))
	for i, p := range path {
		quoted[i] = strconv.Quote(p)
	}
	g.WriteString("case " + strings.Join(quoted, ", ") + ":\n")

	// For now only routes restricted to exactly one version are guarded.
	if len(router.Version) == 1 {
		fmt.Fprintf(g.out, "if version != %v {\n", router.Version[0])
		g.WriteString("return coolq.Failed(405, \"VERSION_ERROR\", \"API版本不匹配\")}\n")
	}

	args := make([]string, len(router.Params))
	for i, p := range router.Params {
		if p.Name == "version" {
			args[i] = "version"
			continue
		}
		args[i] = "p" + strconv.Itoa(i)
		key := strconv.Quote(p.Name)
		if p.Default == "" {
			fmt.Fprintf(g.out, "p%d := %s\n", i, conv("p.Get("+key+")", p.Type))
			continue
		}
		// Start from the default and override only when the key is present.
		fmt.Fprintf(g.out, "p%d := %s\n", i, p.Default)
		fmt.Fprintf(g.out, "if pt := p.Get(%s); pt.Exists() {\n", key)
		fmt.Fprintf(g.out, "p%d = %s\n}\n", i, conv("pt", p.Type))
	}
	g.WriteString("\t\treturn c.bot." + router.Func + "(" + strings.Join(args, ", ") + ")\n")
}

// conv returns the Go expression converting v to a value of Go type t. In the
// generated code v evaluates to a result exposing Int/Uint/Bool/String/Float
// accessors (gjson.Result is passed through as-is). Panics on unsupported types.
func conv(v, t string) string {
	switch t {
	default:
		panic("unknown type: " + t)
	case "gjson.Result":
		return v
	case "int64":
		return v + ".Int()"
	case "bool":
		return v + ".Bool()"
	case "string":
		return v + ".String()"
	case "float64":
		return v + ".Float()"
	case "int32", "int":
		return t + "(" + v + ".Int())"
	case "uint64":
		return v + ".Uint()"
	case "uint32":
		return "uint32(" + v + ".Uint())"
	case "uint16":
		return "uint16(" + v + ".Uint())"
	}
}

// main parses every file in the comma-separated -path flag, collects the
// annotated CQBot methods, and writes the gofmt'ed dispatch code to api.go.
func main() {
	src := flag.String("path", "", "comma-separated list of source files")
	flag.Parse()
	if *src == "" {
		flag.Usage()
		os.Exit(2)
	}

	var routers []Router
	for _, name := range strings.Split(*src, ",") {
		data, err := os.ReadFile(name)
		if err != nil {
			panic(err)
		}
		found, err := collectRouters(name, data)
		if err != nil {
			panic(err)
		}
		routers = append(routers, found...)
	}
	// Stable sort keeps the output deterministic when first paths tie.
	sort.SliceStable(routers, func(i, j int) bool {
		return firstPath(routers[i]) < firstPath(routers[j])
	})

	out := new(bytes.Buffer)
	g := &generator{out: out}
	g.generate(routers)
	source, err := format.Source(out.Bytes())
	if err != nil {
		panic(err)
	}
	if err := os.WriteFile("api.go", source, 0o644); err != nil {
		panic(err)
	}
}

// collectRouters parses src (the contents of filename) and returns a Router
// for every exported *CQBot method that declares at least one route, in
// declaration order. Qualifying methods without any route are listed on
// stderr so missing annotations are noticed.
func collectRouters(filename string, src []byte) ([]Router, error) {
	file, err := parser.ParseFile(token.NewFileSet(), filename, src, parser.ParseComments)
	if err != nil {
		return nil, err
	}
	var routers []Router
	for _, decl := range file.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok || !isCQBotMethod(fn) {
			continue
		}
		router := parseRouter(fn)
		if len(router.Path) == 0 && len(router.PathV11) == 0 && len(router.PathV12) == 0 {
			fmt.Fprintln(os.Stderr, fn.Name.Name)
			continue
		}
		routers = append(routers, router)
	}
	return routers, nil
}

// isCQBotMethod reports whether fn is an exported method with receiver *CQBot.
func isCQBotMethod(fn *ast.FuncDecl) bool {
	if !fn.Name.IsExported() || fn.Recv == nil || len(fn.Recv.List) == 0 {
		return false
	}
	star, ok := fn.Recv.List[0].Type.(*ast.StarExpr)
	if !ok {
		return false
	}
	ident, ok := star.X.(*ast.Ident)
	return ok && ident.Name == "CQBot"
}

// parseRouter builds a Router from fn's parameters and doc-comment
// annotations. Annotations are applied in comment order; path lists are sorted.
func parseRouter(fn *ast.FuncDecl) Router {
	router := Router{Func: fn.Name.Name}
	for _, field := range fn.Type.Params.List {
		typ := typeName(field.Type)
		for _, name := range field.Names {
			router.Params = append(router.Params, Param{Name: snakecase(name.Name), Type: typ})
		}
	}
	// Doc is nil when the method has no doc comment.
	if fn.Doc != nil {
		for _, comment := range fn.Doc.List {
			annotation, args := match(comment.Text)
			applyAnnotation(&router, annotation, args)
		}
	}
	sort.Strings(router.Path)
	sort.Strings(router.PathV11)
	sort.Strings(router.PathV12)
	return router
}

// typeName renders a parameter type written as an identifier (int64) or a
// package selector (gjson.Result). Other forms yield "", which conv rejects.
func typeName(expr ast.Expr) string {
	switch t := expr.(type) {
	case *ast.Ident:
		return t.Name
	case *ast.SelectorExpr:
		if pkg, ok := t.X.(*ast.Ident); ok {
			return pkg.Name + "." + t.Sel.Name
		}
	}
	return ""
}

// applyAnnotation updates router for one parsed "@annotation(args)" line.
// Unknown annotations (including "") are ignored.
func applyAnnotation(router *Router, annotation, args string) {
	switch annotation {
	case "route":
		router.Path = append(router.Path, splitList(args)...)
	case "route11":
		router.PathV11 = append(router.PathV11, splitList(args)...)
	case "route12":
		router.PathV12 = append(router.PathV12, splitList(args)...)
	case "default":
		defaults := parseMap(args, "=")
		for i, p := range router.Params {
			if value, ok := defaults[p.Name]; ok {
				router.Params[i].Default = convDefault(value, p.Type)
			}
		}
	case "rename":
		// Look each parameter up once by its current name so that renames
		// never chain and the result does not depend on map iteration order.
		renames := parseMap(args, "->")
		for i, p := range router.Params {
			if value, ok := renames[p.Name]; ok {
				router.Params[i].Name = value
			}
		}
	case "version":
		for _, v := range strings.Split(args, ",") {
			if n, err := strconv.ParseUint(strings.TrimSpace(v), 10, 16); err == nil {
				router.Version = append(router.Version, uint16(n))
			}
		}
	}
}

// splitList splits a comma-separated annotation argument into unquoted,
// whitespace-trimmed items, dropping empty ones.
func splitList(args string) []string {
	var items []string
	for _, item := range strings.Split(args, ",") {
		item = strings.TrimSpace(item)
		if item == "" {
			continue
		}
		items = append(items, unquote(item))
	}
	return items
}

// firstPath returns the key routers are ordered by: the first shared path,
// else the first v11 path, else the first v12 path, else "".
func firstPath(r Router) string {
	switch {
	case len(r.Path) > 0:
		return r.Path[0]
	case len(r.PathV11) > 0:
		return r.PathV11[0]
	case len(r.PathV12) > 0:
		return r.PathV12[0]
	}
	return ""
}

// unquote strips one level of Go quoting from a "double-quoted" or
// `back-quoted` string. Anything else, including a malformed double-quoted
// string, is returned unchanged.
func unquote(s string) string {
	if s == "" {
		return s
	}
	switch s[0] {
	case '"':
		if u, err := strconv.Unquote(s); err == nil {
			return u
		}
	case '`':
		return strings.Trim(s, "`")
	}
	return s
}

// parseMap parses "k1<sep>v1, k2<sep>v2, k3" into a map with whitespace-trimmed
// keys and unquoted values. A key without sep maps to "true"; empty keys are
// skipped.
func parseMap(input string, sep string) map[string]string {
	out := make(map[string]string)
	for _, arg := range strings.Split(input, ",") {
		k, v, ok := cut(arg, sep)
		k = strings.TrimSpace(k)
		if k == "" {
			continue
		}
		if !ok {
			out[k] = "true"
			continue
		}
		out[k] = unquote(strings.TrimSpace(v))
	}
	return out
}

// match extracts the annotation name and argument text from a comment line
// of the form "// @name(args)". It returns "", "" for any other line.
func match(text string) (string, string) {
	text = strings.TrimSpace(strings.TrimPrefix(text, "//"))
	if !strings.HasPrefix(text, "@") || !strings.HasSuffix(text, ")") {
		return "", ""
	}
	// Strip exactly one '@' and one ')' so arguments like "x=f()" survive.
	text = strings.TrimSuffix(strings.TrimPrefix(text, "@"), ")")
	cmd, args, ok := cut(text, "(")
	if !ok {
		return "", ""
	}
	return strings.TrimSpace(cmd), unquote(strings.TrimSpace(args))
}

// cut slices s around the first instance of sep (same contract as strings.Cut).
func cut(s, sep string) (before, after string, found bool) {
	if i := strings.Index(s, sep); i >= 0 {
		return s[:i], s[i+len(sep):], true
	}
	return s, "", false
}

// replacer normalizes abbreviations before converting to snake case, so
// "groupID" becomes "group_id" rather than "group_i_d".
var replacer = strings.NewReplacer("ID", "Id")

// snakecase converts a Go camelCase identifier to snake_case. ASCII
// uppercase letters are lowercased and, except at the start, prefixed with
// '_'. Every other byte is copied unchanged.
func snakecase(s string) string {
	s = replacer.Replace(s)
	var b strings.Builder
	b.Grow(len(s) + 4)
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			if i > 0 {
				b.WriteByte('_')
			}
			c += 'a' - 'A'
		}
		b.WriteByte(c)
	}
	return b.String()
}

// convDefault turns an annotation default value s for a parameter of Go type
// t into a Go expression. It returns "" when s equals the type's zero value,
// which leaves the parameter without a default: the plain accessor already
// yields that zero value when the key is absent.
// Panics on unsupported types or non-numeric integer defaults.
func convDefault(s string, t string) string {
	switch t {
	case "bool":
		if s == "true" {
			return s
		}
	case "string":
		if s != "" {
			return strconv.Quote(s)
		}
	case "int", "int32", "int64":
		if s != "0" {
			if _, err := strconv.ParseInt(s, 0, 64); err != nil {
				panic("invalid default value for " + t + ": " + s)
			}
			return t + "(" + s + ")"
		}
	case "uint16", "uint32", "uint64":
		if s != "0" {
			if _, err := strconv.ParseUint(s, 0, 64); err != nil {
				panic("invalid default value for " + t + ": " + s)
			}
			return t + "(" + s + ")"
		}
	default:
		panic("unhandled default value type:" + t)
	}
	return ""
}