// Command api-generator reads a Go source file, collects the exported methods
// with receiver *CQBot whose doc comment carries an @route annotation, and
// writes api.go: a (*Caller).call dispatcher that maps action names to those
// methods.
//
// Recognized doc-comment annotations (one per comment line):
//
//	@route(a,b)          action names dispatched to the method
//	@default(name=value) value used when the request omits the parameter
//	@rename(old->new)    request key used for the parameter
//
// Parameter names are converted to snake_case before @default and @rename
// are matched against them.
package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io"
	"os"
	"sort"
	"strconv"
	"strings"
)

// Param describes one parameter of a routed method.
type Param struct {
	Name    string // request key: snake_case of the Go name, or the @rename target
	Type    string // Go type as spelled in the source, e.g. "int64" or "gjson.Result"
	Default string // Go expression for the initial value; "" means no default
}

// Router describes one routed method and the action names dispatched to it.
type Router struct {
	Func   string   // method name on CQBot
	Path   []string // action names, sorted
	Params []Param  // parameters in declaration order
}

// generator writes the generated dispatcher source to out.
type generator struct {
	out io.Writer
}

// WriteString writes s to g.out. Write errors are ignored; this tool only
// writes into an in-memory bytes.Buffer (see render).
func (g *generator) WriteString(s string) {
	_, _ = io.WriteString(g.out, s)
}

// generate writes the complete, not yet gofmt-formatted, api.go source for
// routers: a switch over the action name whose default case reports 404.
func (g *generator) generate(routers []Router) {
	g.WriteString("// Code generated by cmd/api-generator. DO NOT EDIT.\n\n")
	g.WriteString("package api\n\nimport (\n\n")
	g.WriteString("\"github.com/Mrs4s/go-cqhttp/coolq\"\n")
	g.WriteString("\"github.com/Mrs4s/go-cqhttp/global\"\n")
	g.WriteString(")\n\n")
	g.WriteString(`func (c *Caller) call(action string, p Getter) global.MSG {
	switch action {
	default:
		return coolq.Failed(404, "API_NOT_FOUND", "API不存在")` + "\n")
	for _, router := range routers {
		g.router(router)
	}
	g.WriteString(` }}`)
}

// router writes one switch case. Each parameter becomes a local p<i>, read
// from the request getter p; parameters with a default start at that value
// and are overwritten only when the request key exists. The case then
// returns the result of calling the bot method with p0..pN in order.
func (g *generator) router(router Router) {
	g.WriteString(`case `)
	for i, p := range router.Path {
		if i != 0 {
			g.WriteString(`, `)
		}
		g.WriteString(strconv.Quote(p))
	}
	g.WriteString(":\n")
	for i, p := range router.Params {
		if p.Default == "" {
			v := "p.Get(" + strconv.Quote(p.Name) + ")"
			fmt.Fprintf(g.out, "p%d := %s\n", i, conv(v, p.Type))
			continue
		}
		fmt.Fprintf(g.out, "p%d := %s\n", i, p.Default)
		fmt.Fprintf(g.out, "if pt := p.Get(%s); pt.Exists() {\n", strconv.Quote(p.Name))
		fmt.Fprintf(g.out, "p%d = %s\n}\n", i, conv("pt", p.Type))
	}
	g.WriteString("\t\treturn c.bot." + router.Func + "(")
	for i := range router.Params {
		if i != 0 {
			g.WriteString(", ")
		}
		fmt.Fprintf(g.out, "p%d", i)
	}
	g.WriteString(")\n")
}

// conv returns a Go expression converting the getter result expression v
// into a value of type t, using its Int/Uint/Float/Bool/String accessors.
// It panics for types it does not know how to convert.
func conv(v, t string) string {
	switch t {
	case "gjson.Result":
		return v
	case "bool":
		return v + ".Bool()"
	case "string":
		return v + ".String()"
	case "int64":
		return v + ".Int()"
	case "uint64":
		return v + ".Uint()"
	case "float64":
		return v + ".Float()"
	case "int", "int8", "int16", "int32":
		return t + "(" + v + ".Int())"
	case "uint", "uint8", "uint16", "uint32":
		return t + "(" + v + ".Uint())"
	case "float32":
		return t + "(" + v + ".Float())"
	default:
		panic("unknown type: " + t)
	}
}

func main() {
	src := flag.String("path", "", "source file")
	flag.Parse()
	if *src == "" {
		fmt.Fprintln(os.Stderr, "api-generator: -path is required")
		flag.Usage()
		os.Exit(2)
	}

	routers, err := loadRouters(*src, nil)
	if err != nil {
		panic(err)
	}
	source, err := render(routers)
	if err != nil {
		panic(err)
	}
	if err := os.WriteFile("api.go", source, 0o644); err != nil {
		panic(err)
	}
}

// loadRouters parses the Go file at path and returns its routed CQBot
// methods sorted by first action name. src is passed to parser.ParseFile;
// when nil the file is read from disk. The name of every exported CQBot
// method without an @route annotation is printed to stderr. An error is
// returned if parsing fails or an action name is routed more than once.
func loadRouters(path string, src interface{}) ([]Router, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, path, src, parser.ParseComments)
	if err != nil {
		return nil, err
	}

	var routers []Router
	for _, decl := range file.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok || !isBotMethod(fn) {
			continue
		}
		router := parseRouter(fn)
		if len(router.Path) == 0 {
			// Surface exported methods that were left out of the dispatcher.
			fmt.Fprintln(os.Stderr, fn.Name.Name)
			continue
		}
		routers = append(routers, router)
	}

	if err := checkDuplicateRoutes(routers); err != nil {
		return nil, err
	}
	sort.Slice(routers, func(i, j int) bool { return routers[i].Path[0] < routers[j].Path[0] })
	return routers, nil
}

// isBotMethod reports whether fn is an exported method with receiver *CQBot.
func isBotMethod(fn *ast.FuncDecl) bool {
	if !fn.Name.IsExported() || fn.Recv == nil || len(fn.Recv.List) == 0 {
		return false
	}
	star, ok := fn.Recv.List[0].Type.(*ast.StarExpr)
	if !ok {
		return false
	}
	ident, ok := star.X.(*ast.Ident)
	return ok && ident.Name == "CQBot"
}

// parseRouter builds the Router for fn from its parameter list and the
// annotations in its doc comment. Path is empty when no @route was found.
func parseRouter(fn *ast.FuncDecl) Router {
	router := Router{Func: fn.Name.Name}
	for _, field := range fn.Type.Params.List {
		typ := typeName(field.Type)
		for _, name := range field.Names {
			router.Params = append(router.Params, Param{Name: snakecase(name.Name), Type: typ})
		}
	}
	// Doc is nil for methods without a doc comment.
	if fn.Doc != nil {
		for _, comment := range fn.Doc.List {
			applyAnnotation(&router, comment.Text)
		}
	}
	sort.Strings(router.Path)
	return router
}

// typeName returns the source spelling of a parameter type for the forms
// conv understands (identifiers and pkg.Name selectors); other forms yield
// "", which conv later rejects.
func typeName(expr ast.Expr) string {
	switch t := expr.(type) {
	case *ast.Ident:
		return t.Name
	case *ast.SelectorExpr:
		if pkg, ok := t.X.(*ast.Ident); ok {
			return pkg.Name + "." + t.Sel.Name
		}
	}
	return ""
}

// applyAnnotation applies a single doc-comment line to router. Lines that
// are not @route, @default or @rename annotations are ignored.
func applyAnnotation(router *Router, text string) {
	annotation, args := match(text)
	switch annotation {
	case "route":
		for _, route := range strings.Split(args, ",") {
			if route = unquote(strings.TrimSpace(route)); route != "" {
				router.Path = append(router.Path, route)
			}
		}
	case "default":
		for name, value := range parseMap(args, "=") {
			for i, p := range router.Params {
				if p.Name == name {
					router.Params[i].Default = convDefault(value, p.Type)
				}
			}
		}
	case "rename":
		for name, value := range parseMap(args, "->") {
			for i, p := range router.Params {
				if p.Name == name {
					router.Params[i].Name = value
				}
			}
		}
	}
}

// checkDuplicateRoutes returns an error if an action name is routed more
// than once; the generated switch would otherwise contain a duplicate case,
// which the Go compiler rejects.
func checkDuplicateRoutes(routers []Router) error {
	owner := make(map[string]string)
	for _, r := range routers {
		for _, path := range r.Path {
			if prev, dup := owner[path]; dup {
				return fmt.Errorf("route %q is declared by both %s and %s", path, prev, r.Func)
			}
			owner[path] = r.Func
		}
	}
	return nil
}

// render generates the api.go source for routers and gofmt-formats it.
func render(routers []Router) ([]byte, error) {
	var buf bytes.Buffer
	g := &generator{out: &buf}
	g.generate(routers)
	return format.Source(buf.Bytes())
}

// unquote strips Go-style quoting from s. A double-quoted string literal is
// unescaped with strconv.Unquote, and surrounding backquotes are removed.
// Anything else is returned unchanged, including "" and a malformed
// double-quoted literal such as `"a","b"`.
func unquote(s string) string {
	if s == "" {
		return s
	}
	switch s[0] {
	case '"':
		if u, err := strconv.Unquote(s); err == nil {
			return u
		}
	case '`':
		return strings.Trim(s, "`")
	}
	return s
}

// parseMap parses a comma-separated list of key<sep>value pairs. Keys and
// values are space-trimmed and values are unquoted. A bare key with no
// separator maps to "true", and empty entries are skipped.
func parseMap(input string, sep string) map[string]string {
	out := make(map[string]string)
	for _, arg := range strings.Split(input, ",") {
		k, v, ok := cut(arg, sep)
		k = strings.TrimSpace(k)
		if k == "" {
			continue
		}
		if !ok {
			out[k] = "true"
			continue
		}
		out[k] = unquote(strings.TrimSpace(v))
	}
	return out
}

// match splits a comment line of the form "// @name(args)" into name and
// the unquoted args. It returns "", "" if the line is not an annotation.
func match(text string) (string, string) {
	text = strings.TrimSpace(strings.TrimPrefix(text, "//"))
	if !strings.HasPrefix(text, "@") || !strings.HasSuffix(text, ")") {
		return "", ""
	}
	// Strip exactly one '@' and one ')' so args may themselves end in ')'.
	text = strings.TrimSuffix(strings.TrimPrefix(text, "@"), ")")
	cmd, args, ok := cut(text, "(")
	if !ok {
		return "", ""
	}
	return cmd, unquote(args)
}

// cut slices s around the first instance of sep, returning the text before
// and after sep and whether sep was found. If sep is absent it returns
// s, "", false. (Local equivalent of strings.Cut, which needs Go 1.18.)
func cut(s, sep string) (before, after string, found bool) {
	if i := strings.Index(s, sep); i >= 0 {
		return s[:i], s[i+len(sep):], true
	}
	return s, "", false
}

// replacer rewrites abbreviations that must be treated as a single word
// before converting to snake case ("groupID" -> "groupId" -> "group_id").
var replacer = strings.NewReplacer("ID", "Id")

// snakecase converts an ASCII mixedCaps identifier to snake_case: each
// upper-case letter is lower-cased and, unless it is the first byte,
// preceded by '_'. All other bytes are copied unchanged.
func snakecase(s string) string {
	s = replacer.Replace(s)
	t := make([]byte, 0, len(s)+4)
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			if i > 0 {
				t = append(t, '_')
			}
			c += 'a' - 'A'
		}
		t = append(t, c)
	}
	return string(t)
}

// convDefault turns an @default value s into the Go expression used as the
// initial value of a parameter of type t. It returns "" (no default) when s
// is the zero value of t ("false"/anything but "true" for bool, "0" for
// numbers, "" for string). Numeric defaults get an explicit conversion so
// that `pN := <default>` declares pN with type t rather than the untyped
// constant's default type. It panics for any other type.
func convDefault(s string, t string) string {
	switch t {
	case "bool":
		if s == "true" {
			return s
		}
	case "string":
		if s != "" {
			return strconv.Quote(s)
		}
	case "int", "int8", "int16", "int32", "int64",
		"uint", "uint8", "uint16", "uint32", "uint64",
		"float32", "float64":
		if s != "0" {
			return t + "(" + s + ")"
		}
	default:
		panic("unhandled default value type:" + t)
	}
	return ""
}