mirror of
https://github.com/Mrs4s/go-cqhttp.git
synced 2025-05-04 19:17:37 +08:00
265 lines
5.6 KiB
Go
265 lines
5.6 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/token"
|
|
"io"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// Param describes one parameter of a routed *CQBot method as it appears in
// the generated dispatcher.
//
//   - Name is the request key passed to p.Get: the snake_case form of the Go
//     parameter name, unless overridden by an @rename annotation.
//   - Type is the parameter's Go type as spelled in the signature
//     (e.g. "int64", "gjson.Result"); conv turns it into an accessor.
//   - Default is a Go expression (produced by convDefault from @default)
//     assigned before the optional lookup; "" means no default branch.
type Param struct {
	Name, Type, Default string
}
|
|
|
|
type Router struct {
|
|
Func string
|
|
Path string
|
|
Aliases []string
|
|
Params []Param
|
|
}
|
|
|
|
// generator emits the source text of the generated api.go; every method
// writes directly to out.
type generator struct{ out io.Writer }
|
|
|
|
func (g *generator) WriteString(s string) {
|
|
io.WriteString(g.out, s)
|
|
}
|
|
|
|
func (g *generator) generate(routers []Router) {
|
|
g.WriteString("// Code generated by cmd/api-generator. DO NOT EDIT.\n\n")
|
|
g.WriteString("package api\n\nimport (\n\n")
|
|
g.WriteString("\"github.com/Mrs4s/go-cqhttp/coolq\"\n")
|
|
g.WriteString("\"github.com/Mrs4s/go-cqhttp/global\"\n")
|
|
g.WriteString(")\n\n")
|
|
g.WriteString(`func (c *Caller) call(action string, p Getter) global.MSG {
|
|
switch action {
|
|
default:
|
|
return coolq.Failed(404, "API_NOT_FOUND", "API不存在")` + "\n")
|
|
for _, router := range routers {
|
|
g.router(router)
|
|
}
|
|
io.WriteString(g.out, ` }}`)
|
|
}
|
|
|
|
func (g *generator) router(router Router) {
|
|
g.WriteString(`case ` + strconv.Quote(router.Path))
|
|
for _, alias := range router.Aliases {
|
|
g.WriteString(`, ` + strconv.Quote(alias))
|
|
}
|
|
g.WriteString(":\n")
|
|
|
|
for i, p := range router.Params {
|
|
if p.Default == "" {
|
|
v := "p.Get(" + strconv.Quote(p.Name) + ")"
|
|
fmt.Fprintf(g.out, "p%d := %s\n", i, conv(v, p.Type))
|
|
} else {
|
|
fmt.Fprintf(g.out, "p%d := %s\n", i, p.Default)
|
|
fmt.Fprintf(g.out, "if pt := p.Get(%s); pt.Exists() {\n", strconv.Quote(p.Name))
|
|
fmt.Fprintf(g.out, "p%d = %s\n}\n", i, conv("pt", p.Type))
|
|
}
|
|
}
|
|
|
|
g.WriteString("\t\treturn c.bot." + router.Func + "(")
|
|
for i := range router.Params {
|
|
if i != 0 {
|
|
g.WriteString(", ")
|
|
}
|
|
fmt.Fprintf(g.out, "p%d", i)
|
|
}
|
|
g.WriteString(")\n")
|
|
}
|
|
|
|
// conv returns a Go expression that reads the value expression v as Go type
// t. v is presumably an expression of type gjson.Result, since it is returned
// unchanged when t is "gjson.Result".
//
// Supported types:
//   - "gjson.Result": v itself;
//   - "int64", "uint64", "bool", "string": v.Int(), v.Uint(), v.Bool(),
//     v.String();
//   - "int32", "uint32", "int": a conversion of v.Int(), e.g. "int32(v.Int())".
//
// Any other t panics with "unknown type: <t>".
func conv(v, t string) string {
	var accessor string
	switch t {
	case "gjson.Result":
		return v
	case "int32", "uint32", "int":
		// Narrower integer types are read via .Int() and converted.
		return t + "(" + v + ".Int())"
	case "int64":
		accessor = "Int"
	case "uint64":
		accessor = "Uint"
	case "bool":
		accessor = "Bool"
	case "string":
		accessor = "String"
	default:
		panic("unknown type: " + t)
	}
	return v + "." + accessor + "()"
}
|
|
|
|
func main() {
|
|
var routers []Router
|
|
src := flag.String("path", "", "source file")
|
|
flag.Parse()
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, *src, nil, parser.ParseComments)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
for _, decl := range file.Decls {
|
|
switch decl := decl.(type) {
|
|
case *ast.FuncDecl:
|
|
if !decl.Name.IsExported() || decl.Recv == nil {
|
|
continue
|
|
}
|
|
if st, ok := decl.Recv.List[0].Type.(*ast.StarExpr); !ok || st.X.(*ast.Ident).Name != "CQBot" {
|
|
continue
|
|
}
|
|
router := Router{Func: decl.Name.Name}
|
|
|
|
// compute params
|
|
for _, p := range decl.Type.Params.List {
|
|
var typ string
|
|
switch t := p.Type.(type) {
|
|
case *ast.Ident:
|
|
typ = t.Name
|
|
case *ast.SelectorExpr:
|
|
typ = t.X.(*ast.Ident).Name + "." + t.Sel.Name
|
|
}
|
|
for _, name := range p.Names {
|
|
router.Params = append(router.Params, Param{Name: snakecase(name.Name), Type: typ})
|
|
}
|
|
}
|
|
|
|
for _, comment := range decl.Doc.List {
|
|
annotation, args := match(comment.Text)
|
|
switch annotation {
|
|
case "route":
|
|
router.Path = args
|
|
case "alias":
|
|
router.Aliases = append(router.Aliases, args)
|
|
case "default":
|
|
for name, value := range parseMap(args, "=") {
|
|
for i, p := range router.Params {
|
|
if p.Name == name {
|
|
router.Params[i].Default = convDefault(value, p.Type)
|
|
}
|
|
}
|
|
}
|
|
case "rename":
|
|
for name, value := range parseMap(args, "->") {
|
|
for i, p := range router.Params {
|
|
if p.Name == name {
|
|
router.Params[i].Name = value
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if router.Path != "" {
|
|
routers = append(routers, router)
|
|
} else {
|
|
println(decl.Name.Name)
|
|
}
|
|
}
|
|
}
|
|
|
|
sort.Slice(routers, func(i, j int) bool {
|
|
return routers[i].Path < routers[j].Path
|
|
})
|
|
|
|
out := new(bytes.Buffer)
|
|
g := &generator{out: out}
|
|
g.generate(routers)
|
|
source, err := format.Source(out.Bytes())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
err = os.WriteFile("api.go", source, 0o644)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// unquote strips one layer of Go string-literal quoting from s.
//
// Behaviour by input:
//   - a double-quoted literal is decoded with strconv.Unquote, so escape
//     sequences are interpreted;
//   - a double-quoted literal that is not valid Go is returned unchanged;
//   - a back-quoted literal has its surrounding backticks trimmed;
//   - anything else, including the empty string, is returned as-is.
func unquote(s string) string {
	if s == "" {
		// Guard the s[0] access below; "" has nothing to unquote.
		return s
	}
	switch s[0] {
	case '"':
		// strconv.Unquote returns "" on error; keep the text as written
		// rather than silently replacing it with an empty value.
		if u, err := strconv.Unquote(s); err == nil {
			return u
		}
	case '`':
		return strings.Trim(s, "`")
	}
	return s
}
|
|
|
|
func parseMap(input string, sep string) map[string]string {
|
|
out := make(map[string]string)
|
|
for _, arg := range strings.Split(input, ",") {
|
|
k, v, ok := cut(arg, sep)
|
|
if !ok {
|
|
out[k] = "true"
|
|
}
|
|
k = strings.TrimSpace(k)
|
|
v = unquote(strings.TrimSpace(v))
|
|
out[k] = v
|
|
}
|
|
return out
|
|
}
|
|
|
|
func match(text string) (string, string) {
|
|
text = strings.TrimPrefix(text, "//")
|
|
text = strings.TrimSpace(text)
|
|
if !strings.HasPrefix(text, "@") || !strings.HasSuffix(text, ")") {
|
|
return "", ""
|
|
}
|
|
text = strings.Trim(text, "@)")
|
|
cmd, args, ok := cut(text, "(")
|
|
if !ok {
|
|
return "", ""
|
|
}
|
|
return cmd, unquote(args)
|
|
}
|
|
|
|
// cut slices s around the first occurrence of sep and returns the text
// before and after it. found reports whether sep occurs in s; if it does
// not, cut returns s, "", false. This is the same contract as strings.Cut
// (Go 1.18+), kept local to this file.
func cut(s, sep string) (before, after string, found bool) {
	at := strings.Index(s, sep)
	if at < 0 {
		return s, "", false
	}
	return s[:at], s[at+len(sep):], true
}
|
|
|
|
// replacer rewrites abbreviations before conversion to snake case, so that
// e.g. "groupID" becomes "group_id" rather than "group_i_d".
var replacer = strings.NewReplacer("ID", "Id")

// snakecase converts a lowerCamelCase Go identifier to snake_case. main uses
// it to derive request keys from parameter names, e.g. "autoEscape" ->
// "auto_escape" and "groupID" -> "group_id".
//
// Only ASCII upper-case letters start a new word. Each one is lower-cased
// and preceded by '_', except at the very start of the result. Every other
// byte (lower-case letters, digits, '_', non-ASCII UTF-8 bytes) is copied
// unchanged. Previously such bytes were XOR-ed with 0x20, which turned '_'
// into DEL.
func snakecase(s string) string {
	s = replacer.Replace(s)
	t := make([]byte, 0, len(s)+8)
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			if len(t) > 0 {
				t = append(t, '_')
			}
			c += 'a' - 'A'
		}
		t = append(t, c)
	}
	return string(t)
}
|
|
|
|
// convDefault turns the @default value s for a parameter of Go type t into
// the Go expression emitted as that parameter's default.
//
// It returns "" when no default branch is needed, in which case router()
// reads the parameter directly via p.Get:
//   - bool: "true" for s == "true"; "" for any other s;
//   - uint32: "uint32(<s>)" for s != "0"; "" for "0".
//
// Presumably "" works for zero-ish defaults because the accessor already
// yields the zero value for a missing key (TODO confirm). Any other t panics
// with "unhandled default value type:<t>".
func convDefault(s string, t string) string {
	switch {
	case t != "bool" && t != "uint32":
		panic("unhandled default value type:" + t)
	case t == "bool" && s == "true":
		return s
	case t == "uint32" && s != "0":
		return t + "(" + s + ")"
	default:
		return ""
	}
}
|