mirror of
https://github.com/Mrs4s/go-cqhttp.git
synced 2025-05-04 19:17:37 +08:00
356 lines
7.6 KiB
Go
356 lines
7.6 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/token"
|
|
"io"
|
|
"os"
|
|
"reflect"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// Command-line flags controlling the generator.
var (
	// supported switches output from the API router (genRouter) to the
	// supported-API listing (genSupported).
	supported = flag.Bool("supported", false, "genRouter supported.go")
	// output is the path the gofmt'ed result is written to.
	output = flag.String("o", "", "output file")
	// pkg is the package name used in the generated file's package clause.
	pkg = flag.String("pkg", "", "package name")
	// src is a comma-separated list of Go source files to scan for routes.
	src = flag.String("path", "", "source file")
)
|
|
|
|
// Param describes one parameter of a routed CQBot method.
//
// Name is the request key the value is read from (the snake_cased Go
// parameter name unless overridden by @rename), Type is the parameter's Go
// type as rendered by typeName, and Default is a Go expression used when the
// request omits the key (empty means no explicit default).
type Param struct {
	Name, Type, Default string
}

// Router describes one CQBot method exposed as an API action.
type Router struct {
	// Func is the name of the *CQBot method the generated case calls.
	Func string

	// Path holds the @route action names (any version); PathV11 and PathV12
	// hold the @route11 / @route12 names that apply to one version only.
	Path, PathV11, PathV12 []string

	// Params lists the method's parameters in declaration order.
	Params []Param
}
|
|
|
|
// generator accumulates generated Go source text in out.
type generator struct {
	out io.Writer
}

// Route-version selectors for (*generator).router. PathAll selects the
// version-independent @route paths; PathV11 and PathV12 select the
// @route11 / @route12 paths. genRouter also writes the numeric value into
// the generated `if spec.Version == N` guard.
const (
	PathAll, PathV11, PathV12 = 0, 11, 12
)

// WriteString appends s verbatim to the generated output. Write errors are
// deliberately ignored; main writes into an in-memory bytes.Buffer.
func (g *generator) WriteString(s string) {
	if sw, ok := g.out.(io.StringWriter); ok {
		_, _ = sw.WriteString(s)
		return
	}
	_, _ = g.out.Write([]byte(s))
}

// writef appends fmt-formatted text to the generated output.
func (g *generator) writef(format string, a ...any) {
	g.WriteString(fmt.Sprintf(format, a...))
}
|
|
|
|
func (g *generator) header() {
|
|
g.WriteString("// Code generated by cmd/api-generator. DO NOT EDIT.\n\n")
|
|
g.writef("package %s\n\n", *pkg)
|
|
}
|
|
|
|
func (g *generator) genRouter(routers []Router) {
|
|
g.WriteString("import (\n\n")
|
|
g.WriteString("\"github.com/Mrs4s/go-cqhttp/coolq\"\n")
|
|
g.WriteString("\"github.com/Mrs4s/go-cqhttp/global\"\n")
|
|
g.WriteString("\"github.com/Mrs4s/go-cqhttp/pkg/onebot\"\n")
|
|
g.WriteString(")\n\n")
|
|
g.WriteString(`func (c *Caller) call(action string, spec *onebot.Spec, p Getter) global.MSG {`)
|
|
genVer := func(path int) {
|
|
g.writef(`if spec.Version == %d {
|
|
switch action {
|
|
`, path)
|
|
for _, router := range routers {
|
|
g.router(router, path)
|
|
}
|
|
g.WriteString("}}\n")
|
|
}
|
|
genVer(PathV11)
|
|
genVer(PathV12)
|
|
// generic path
|
|
g.WriteString("switch action {\n")
|
|
for _, router := range routers {
|
|
g.router(router, PathAll)
|
|
}
|
|
g.WriteString("}\n")
|
|
g.WriteString("return coolq.Failed(404, \"API_NOT_FOUND\", \"API不存在\")}")
|
|
}
|
|
|
|
func (g *generator) router(router Router, pathVersion int) {
|
|
path := router.Path
|
|
if pathVersion == PathV11 {
|
|
path = router.PathV11
|
|
}
|
|
if pathVersion == PathV12 {
|
|
path = router.PathV12
|
|
}
|
|
if len(path) == 0 {
|
|
return
|
|
}
|
|
|
|
g.WriteString(`case `)
|
|
for i, p := range path {
|
|
if i != 0 {
|
|
g.WriteString(`, `)
|
|
}
|
|
g.WriteString(strconv.Quote(p))
|
|
}
|
|
g.WriteString(":\n")
|
|
|
|
for i, p := range router.Params {
|
|
if p.Type == "*onebot.Spec" {
|
|
continue
|
|
}
|
|
if p.Default == "" {
|
|
v := "p.Get(" + strconv.Quote(p.Name) + ")"
|
|
g.writef("p%d := %s\n", i, conv(v, p.Type))
|
|
} else {
|
|
g.writef("p%d := %s\n", i, p.Default)
|
|
g.writef("if pt := p.Get(%s); pt.Exists() {\n", strconv.Quote(p.Name))
|
|
g.writef("p%d = %s\n}\n", i, conv("pt", p.Type))
|
|
}
|
|
}
|
|
|
|
g.WriteString("\t\treturn c.bot." + router.Func + "(")
|
|
for i, p := range router.Params {
|
|
if i != 0 {
|
|
g.WriteString(", ")
|
|
}
|
|
if p.Type == "*onebot.Spec" {
|
|
g.WriteString("spec")
|
|
continue
|
|
}
|
|
g.writef("p%d", i)
|
|
}
|
|
g.WriteString(")\n")
|
|
}
|
|
|
|
// conv returns a Go expression converting the request value expression v
// to the Go type t. v is presumably a gjson.Result, since the
// "gjson.Result" case returns it unchanged. Unsupported types panic so the
// generator fails loudly.
func conv(v, t string) string {
	switch t {
	case "gjson.Result", "*onebot.Spec":
		return v
	case "bool":
		return v + ".Bool()"
	case "string":
		return v + ".String()"
	case "int64":
		return v + ".Int()"
	case "uint64":
		return v + ".Uint()"
	case "int", "int32":
		// Narrow the int64 accessor to the declared type.
		return t + "(" + v + ".Int())"
	case "uint16", "uint32":
		// Narrow the uint64 accessor to the declared type.
		return t + "(" + v + ".Uint())"
	}
	panic("unsupported type: " + t)
}
|
|
|
|
func main() {
|
|
var routers []Router
|
|
flag.Parse()
|
|
fset := token.NewFileSet()
|
|
for _, s := range strings.Split(*src, ",") {
|
|
file, err := parser.ParseFile(fset, s, nil, parser.ParseComments)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
for _, decl := range file.Decls {
|
|
switch decl := decl.(type) {
|
|
case *ast.FuncDecl:
|
|
if !decl.Name.IsExported() || decl.Recv == nil ||
|
|
typeName(decl.Recv.List[0].Type) != "*CQBot" {
|
|
continue
|
|
}
|
|
router := Router{Func: decl.Name.Name}
|
|
|
|
// compute params
|
|
for _, p := range decl.Type.Params.List {
|
|
typ := typeName(p.Type)
|
|
for _, name := range p.Names {
|
|
router.Params = append(router.Params, Param{Name: snakecase(name.Name), Type: typ})
|
|
}
|
|
}
|
|
|
|
for _, comment := range decl.Doc.List {
|
|
annotation, args := match(comment.Text)
|
|
switch annotation {
|
|
case "route":
|
|
for _, route := range strings.Split(args, ",") {
|
|
router.Path = append(router.Path, unquote(route))
|
|
}
|
|
case "route11":
|
|
for _, route := range strings.Split(args, ",") {
|
|
router.PathV11 = append(router.PathV11, unquote(route))
|
|
}
|
|
case "route12":
|
|
for _, route := range strings.Split(args, ",") {
|
|
router.PathV12 = append(router.PathV12, unquote(route))
|
|
}
|
|
case "default":
|
|
for name, value := range parseMap(args, "=") {
|
|
for i, p := range router.Params {
|
|
if p.Name == name {
|
|
router.Params[i].Default = convDefault(value, p.Type)
|
|
}
|
|
}
|
|
}
|
|
case "rename":
|
|
for name, value := range parseMap(args, "->") {
|
|
for i, p := range router.Params {
|
|
if p.Name == name {
|
|
router.Params[i].Name = value
|
|
}
|
|
}
|
|
}
|
|
}
|
|
sort.Slice(router.Path, func(i, j int) bool {
|
|
return router.Path[i] < router.Path[j]
|
|
})
|
|
sort.Slice(router.PathV11, func(i, j int) bool {
|
|
return router.PathV11[i] < router.PathV11[j]
|
|
})
|
|
sort.Slice(router.PathV12, func(i, j int) bool {
|
|
return router.PathV12[i] < router.PathV12[j]
|
|
})
|
|
}
|
|
if router.Path != nil || router.PathV11 != nil || router.PathV12 != nil {
|
|
routers = append(routers, router)
|
|
} else {
|
|
println(decl.Name.Name)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
sort.Slice(routers, func(i, j int) bool {
|
|
path := func(r Router) string {
|
|
if r.Path != nil {
|
|
return r.Path[0]
|
|
}
|
|
if r.PathV11 != nil {
|
|
return r.PathV11[0]
|
|
}
|
|
if r.PathV12 != nil {
|
|
return r.PathV12[0]
|
|
}
|
|
return ""
|
|
}
|
|
return path(routers[i]) < path(routers[j])
|
|
})
|
|
|
|
out := new(bytes.Buffer)
|
|
g := &generator{out: out}
|
|
g.header()
|
|
if *supported {
|
|
g.genSupported(routers)
|
|
} else {
|
|
g.genRouter(routers)
|
|
}
|
|
source, err := format.Source(out.Bytes())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
err = os.WriteFile(*output, source, 0o644)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// unquote strips one level of Go-style quoting from s. A double-quoted
// string is decoded with strconv.Unquote, and a backquoted string loses its
// surrounding backquotes. Anything else is returned unchanged, including the
// empty string and double-quoted text that fails to decode, rather than
// panicking or being silently replaced by "".
func unquote(s string) string {
	if len(s) < 2 {
		return s
	}
	switch {
	case s[0] == '"':
		if u, err := strconv.Unquote(s); err == nil {
			return u
		}
	case s[0] == '`' && s[len(s)-1] == '`':
		return s[1 : len(s)-1]
	}
	return s
}

// parseMap parses a comma-separated list of `key<sep>value` entries, for
// example "no_cache=true, times=3" with sep "=". Keys and values are trimmed
// of whitespace, and values are unquoted. An entry without sep is a boolean
// flag and maps to "true". Empty entries, such as those from "" or a
// trailing comma, are skipped.
func parseMap(input string, sep string) map[string]string {
	out := make(map[string]string)
	for _, arg := range strings.Split(input, ",") {
		k, v, ok := strings.Cut(arg, sep)
		k = strings.TrimSpace(k)
		if !ok {
			if k != "" {
				out[k] = "true"
			}
			continue
		}
		out[k] = unquote(strings.TrimSpace(v))
	}
	return out
}

// match recognises an annotation comment of the form `// @name(args)` and
// returns name and the unquoted args. Comments that are not annotations
// yield ("", "").
func match(text string) (string, string) {
	text = strings.TrimSpace(strings.TrimPrefix(text, "//"))
	if !strings.HasPrefix(text, "@") || !strings.HasSuffix(text, ")") {
		return "", ""
	}
	// Remove exactly one '@' and one ')'. A cutset trim would also eat
	// closing parentheses that belong to the arguments.
	text = strings.TrimSuffix(strings.TrimPrefix(text, "@"), ")")
	cmd, args, ok := strings.Cut(text, "(")
	if !ok {
		return "", ""
	}
	return cmd, unquote(args)
}
|
|
|
|
// replacer rewrites abbreviations before snake-casing so that, for example,
// "groupID" becomes "group_id" rather than "group_i_d".
var replacer = strings.NewReplacer("ID", "Id")

// snakecase converts a camelCase Go identifier (a method parameter name) to
// snake_case, e.g. "groupID" -> "group_id" and "noCache" -> "no_cache".
// Each ASCII upper-case letter becomes '_' plus its lower-case form, except
// at the start of the name, where no separator is added. Every other byte,
// including '_' and non-ASCII UTF-8, is copied unchanged instead of being
// XOR-mangled.
func snakecase(s string) string {
	s = replacer.Replace(s)
	var b strings.Builder
	b.Grow(len(s) + 4)
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			if i > 0 {
				b.WriteByte('_')
			}
			c += 'a' - 'A'
		}
		b.WriteByte(c)
	}
	return b.String()
}
|
|
|
|
// convDefault turns the literal s from a @default annotation into a Go
// expression of type t for the generated code.
//
// It returns "" when s denotes the zero value of t. router treats an empty
// Default as "no explicit default" and reads the request value directly;
// presumably a missing key then converts to that same zero value. Supported
// types match the ones conv handles:
//   - bool: any strconv.ParseBool spelling; only true needs a default.
//   - int, int32, int64, uint16, uint32, uint64: validated with
//     strconv.Parse{Int,Uint} (base 0) for t's bit size, emitted as t(s).
//   - string: emitted as a quoted Go string literal.
//
// Invalid values and unsupported types panic so bad annotations fail at
// generation time rather than silently producing a wrong default.
func convDefault(s string, t string) string {
	// Bit sizes for strconv; 0 means the platform's int size.
	signed := map[string]int{"int": 0, "int32": 32, "int64": 64}
	unsigned := map[string]int{"uint16": 16, "uint32": 32, "uint64": 64}

	if bits, ok := signed[t]; ok {
		n, err := strconv.ParseInt(s, 0, bits)
		if err != nil {
			panic("invalid " + t + " default value: " + strconv.Quote(s))
		}
		if n == 0 {
			return ""
		}
		return t + "(" + s + ")"
	}
	if bits, ok := unsigned[t]; ok {
		n, err := strconv.ParseUint(s, 0, bits)
		if err != nil {
			panic("invalid " + t + " default value: " + strconv.Quote(s))
		}
		if n == 0 {
			return ""
		}
		return t + "(" + s + ")"
	}

	switch t {
	case "bool":
		b, err := strconv.ParseBool(s)
		if err != nil {
			panic("invalid bool default value: " + strconv.Quote(s))
		}
		if b {
			return "true"
		}
		return ""
	case "string":
		if s == "" {
			return ""
		}
		return strconv.Quote(s)
	}
	panic("unhandled default value type:" + t)
}
|
|
|
|
// typeName renders the type expression x as Go source text, for example
// "int64", "gjson.Result", "*onebot.Spec", "[]string", "[4]byte",
// "map[string]int64" or "...string". It is used for receiver and parameter
// types.
//
// main resolves the parameter types of every exported *CQBot method before
// checking for route annotations. Slice, array, map and variadic forms are
// therefore rendered here instead of aborting generation for unrouted
// methods. Routed methods with such types still fail in conv with a clear
// "unsupported type" message. Any other node kind panics.
func typeName(x ast.Node) string {
	switch x := x.(type) {
	case *ast.Ident:
		return x.Name
	case *ast.SelectorExpr:
		return typeName(x.X) + "." + x.Sel.Name
	case *ast.StarExpr:
		return "*" + typeName(x.X)
	case *ast.ArrayType:
		if x.Len == nil {
			return "[]" + typeName(x.Elt)
		}
		// Fixed-size arrays are rendered only when the length is a literal.
		if lit, ok := x.Len.(*ast.BasicLit); ok {
			return "[" + lit.Value + "]" + typeName(x.Elt)
		}
	case *ast.MapType:
		return "map[" + typeName(x.Key) + "]" + typeName(x.Value)
	case *ast.Ellipsis:
		return "..." + typeName(x.Elt)
	}
	panic("unhandled type: " + reflect.TypeOf(x).String())
}
|