1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-04 19:17:37 +08:00
2023-02-27 16:00:28 +08:00

356 lines
7.6 KiB
Go

package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io"
"os"
"reflect"
"sort"
"strconv"
"strings"
)
// Command-line flags controlling what is generated and where it is written.
var (
	// supported selects the genSupported output instead of the router output.
	supported = flag.Bool("supported", false, "genRouter supported.go")
	// output is the path of the file the formatted result is written to.
	output = flag.String("o", "", "output file")
	// pkg is the name used in the generated file's package clause.
	pkg = flag.String("pkg", "", "package name")
	// src is a comma-separated list of Go source files to scan.
	src = flag.String("path", "", "source file")
)
// Param describes one parameter of an API method and how the generated code
// obtains its value.
type Param struct {
// Name is the key passed to p.Get in the generated code.
Name string
// Type is the parameter's Go type name as written in the method signature.
Type string
// Default is a Go expression used as the initial value; empty means the
// parameter has no default.
Default string
}
// Router ties one method to the action names that dispatch to it.
type Router struct {
// Func is the method name invoked on c.bot in the generated code.
Func string
// Path holds the action names emitted in the version-independent switch.
Path []string
PathV11 []string // action names emitted only inside the spec.Version == 11 branch
PathV12 []string // action names emitted only inside the spec.Version == 12 branch
// Params lists the method's parameters in declaration order.
Params []Param
}
// generator emits generated source text to out.
type generator struct {
out io.Writer // destination for everything the generator writes
}
// Path selectors accepted by (*generator).router to choose which of a
// Router's path lists is emitted.
const (
PathAll = 0 // selects Router.Path
PathV11 = 11 // selects Router.PathV11; also the value compared with spec.Version
PathV12 = 12 // selects Router.PathV12; also the value compared with spec.Version
)
// WriteString appends s to the generator's output. The result of the write,
// including any error, is discarded.
func (g *generator) WriteString(s string) {
	_, _ = io.WriteString(g.out, s)
}
// writef formats a according to format and appends the result to the
// generator's output. The result of the write, including any error, is
// discarded.
func (g *generator) writef(format string, a ...any) {
	_, _ = fmt.Fprintf(g.out, format, a...)
}
// header writes the "Code generated" banner followed by a package clause
// naming the package given by the -pkg flag.
func (g *generator) header() {
	const banner = "// Code generated by cmd/api-generator. DO NOT EDIT.\n\n"
	g.WriteString(banner)
	g.writef("package %s\n\n", *pkg)
}
// genRouter writes the import block and the (*Caller).call dispatcher.
//
// The dispatcher first emits one version-guarded switch for each of PathV11
// and PathV12, then a version-independent switch for PathAll, and finally a
// fallback return reporting that the API does not exist.
func (g *generator) genRouter(routers []Router) {
	// emitCases writes the case clauses of every router for one path selector.
	emitCases := func(version int) {
		for i := range routers {
			g.router(routers[i], version)
		}
	}

	g.WriteString("import (\n\n")
	for _, line := range []string{
		"\"github.com/Mrs4s/go-cqhttp/coolq\"\n",
		"\"github.com/Mrs4s/go-cqhttp/global\"\n",
		"\"github.com/Mrs4s/go-cqhttp/pkg/onebot\"\n",
	} {
		g.WriteString(line)
	}
	g.WriteString(")\n\n")

	g.WriteString(`func (c *Caller) call(action string, spec *onebot.Spec, p Getter) global.MSG {`)

	// Version-specific paths come first.
	for _, version := range []int{PathV11, PathV12} {
		g.writef("if spec.Version == %d {\nswitch action {\n", version)
		emitCases(version)
		g.WriteString("}}\n")
	}

	// generic path
	g.WriteString("switch action {\n")
	emitCases(PathAll)
	g.WriteString("}\n")

	g.WriteString("return coolq.Failed(404, \"API_NOT_FOUND\", \"API不存在\")}")
}
// router writes the case clause that dispatches to router.Func.
//
// pathVersion selects which path list supplies the case labels: PathV11 and
// PathV12 pick the version-specific lists, any other value picks Path.
// Nothing is written when the selected list is empty.
//
// Inside the clause each parameter is read into a local named after its
// index (p0, p1, ...), except parameters of type *onebot.Spec, which are
// passed the dispatcher's spec argument directly.
func (g *generator) router(router Router, pathVersion int) {
	var paths []string
	switch pathVersion {
	case PathV11:
		paths = router.PathV11
	case PathV12:
		paths = router.PathV12
	default:
		paths = router.Path
	}
	if len(paths) == 0 {
		return
	}

	// "case " opens the label list; later labels are comma separated.
	sep := `case `
	for _, name := range paths {
		g.WriteString(sep)
		g.WriteString(strconv.Quote(name))
		sep = `, `
	}
	g.WriteString(":\n")

	for idx, param := range router.Params {
		if param.Type == "*onebot.Spec" {
			continue
		}
		if param.Default != "" {
			// Start from the default and overwrite it only when the key exists.
			g.writef("p%d := %s\n", idx, param.Default)
			g.writef("if pt := p.Get(%s); pt.Exists() {\n", strconv.Quote(param.Name))
			g.writef("p%d = %s\n}\n", idx, conv("pt", param.Type))
			continue
		}
		lookup := "p.Get(" + strconv.Quote(param.Name) + ")"
		g.writef("p%d := %s\n", idx, conv(lookup, param.Type))
	}

	g.WriteString("\t\treturn c.bot." + router.Func + "(")
	for idx, param := range router.Params {
		if idx > 0 {
			g.WriteString(", ")
		}
		switch param.Type {
		case "*onebot.Spec":
			g.WriteString("spec")
		default:
			g.writef("p%d", idx)
		}
	}
	g.WriteString(")\n")
}
// conv returns the Go expression that turns the expression v into a value of
// type t in the generated code.
//
// "gjson.Result" and "*onebot.Spec" are passed through untouched. Other
// supported types call an accessor on v (.Int(), .Uint(), .Bool() or
// .String()) and, where the accessor's result is wider than t, wrap the call
// in a conversion to t. conv panics for any other type name.
func conv(v, t string) string {
	var accessor string
	needsCast := false

	switch t {
	case "gjson.Result", "*onebot.Spec":
		return v
	case "bool":
		accessor = ".Bool()"
	case "string":
		accessor = ".String()"
	case "int64":
		accessor = ".Int()"
	case "uint64":
		accessor = ".Uint()"
	case "int", "int32":
		accessor, needsCast = ".Int()", true
	case "uint16", "uint32":
		accessor, needsCast = ".Uint()", true
	default:
		panic("unsupported type: " + t)
	}

	expr := v + accessor
	if needsCast {
		return t + "(" + expr + ")"
	}
	return expr
}
// main parses the flags, scans every file named by -path for exported
// methods on *CQBot, and writes the gofmt-formatted generated source to the
// file named by -o.
//
// Methods without any route annotation are reported on standard error and
// left out. Any parse, format or write failure aborts the run with a panic.
func main() {
	flag.Parse()

	var routers []Router
	fset := token.NewFileSet()
	for _, name := range strings.Split(*src, ",") {
		file, err := parser.ParseFile(fset, name, nil, parser.ParseComments)
		if err != nil {
			panic(err)
		}
		for _, decl := range file.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			if !ok || !isBotAPI(fn) {
				continue
			}
			router := routerFromDecl(fn)
			if len(router.Path) == 0 && len(router.PathV11) == 0 && len(router.PathV12) == 0 {
				// Surface exported methods that are not reachable through any route.
				fmt.Fprintln(os.Stderr, fn.Name.Name)
				continue
			}
			routers = append(routers, router)
		}
	}

	// Order routers by their first path so the generated file is stable.
	sort.Slice(routers, func(i, j int) bool {
		return firstPath(routers[i]) < firstPath(routers[j])
	})

	out := new(bytes.Buffer)
	g := &generator{out: out}
	g.header()
	if *supported {
		g.genSupported(routers)
	} else {
		g.genRouter(routers)
	}

	source, err := format.Source(out.Bytes())
	if err != nil {
		panic(err)
	}
	if err := os.WriteFile(*output, source, 0o644); err != nil {
		panic(err)
	}
}

// isBotAPI reports whether fn is an exported method whose receiver type is *CQBot.
func isBotAPI(fn *ast.FuncDecl) bool {
	if !fn.Name.IsExported() || fn.Recv == nil || len(fn.Recv.List) == 0 {
		return false
	}
	return typeName(fn.Recv.List[0].Type) == "*CQBot"
}

// routerFromDecl builds the Router for fn from its parameter list and the
// @route, @route11, @route12, @default and @rename annotations in its doc comment.
func routerFromDecl(fn *ast.FuncDecl) Router {
	router := Router{Func: fn.Name.Name}

	// compute params
	for _, field := range fn.Type.Params.List {
		typ := typeName(field.Type)
		for _, name := range field.Names {
			router.Params = append(router.Params, Param{Name: snakecase(name.Name), Type: typ})
		}
	}

	// Doc is nil for a method without a doc comment; such a method simply
	// has no annotations.
	if fn.Doc == nil {
		return router
	}

	for _, comment := range fn.Doc.List {
		annotation, args := match(comment.Text)
		switch annotation {
		case "route":
			router.Path = append(router.Path, splitRoutes(args)...)
		case "route11":
			router.PathV11 = append(router.PathV11, splitRoutes(args)...)
		case "route12":
			router.PathV12 = append(router.PathV12, splitRoutes(args)...)
		case "default":
			for name, value := range parseMap(args, "=") {
				for i, p := range router.Params {
					if p.Name == name {
						router.Params[i].Default = convDefault(value, p.Type)
					}
				}
			}
		case "rename":
			for name, value := range parseMap(args, "->") {
				for i, p := range router.Params {
					if p.Name == name {
						router.Params[i].Name = value
					}
				}
			}
		}
	}

	// Sort once, after every annotation has been collected.
	sort.Strings(router.Path)
	sort.Strings(router.PathV11)
	sort.Strings(router.PathV12)
	return router
}

// splitRoutes splits a comma-separated route list into unquoted route names,
// trimming surrounding whitespace and dropping empty items.
func splitRoutes(args string) []string {
	var routes []string
	for _, route := range strings.Split(args, ",") {
		route = strings.TrimSpace(route)
		if route == "" {
			continue
		}
		routes = append(routes, unquote(route))
	}
	return routes
}

// firstPath returns the first path of r, preferring Path, then PathV11, then
// PathV12, or "" when r has no path at all.
func firstPath(r Router) string {
	switch {
	case len(r.Path) > 0:
		return r.Path[0]
	case len(r.PathV11) > 0:
		return r.PathV11[0]
	case len(r.PathV12) > 0:
		return r.PathV12[0]
	}
	return ""
}
// unquote strips one layer of Go-style quoting from s.
//
// A string starting with a double quote is decoded with strconv.Unquote. If
// it is not a single well-formed literal (for example `"a","b"`), s is
// returned unchanged instead of being replaced by an empty string. A string
// starting with a backquote has all leading and trailing backquotes removed.
// Anything else, including the empty string, is returned as is.
func unquote(s string) string {
	if s == "" {
		return s
	}
	switch s[0] {
	case '"':
		if unquoted, err := strconv.Unquote(s); err == nil {
			return unquoted
		}
	case '`':
		return strings.Trim(s, "`")
	}
	return s
}
// parseMap parses a comma-separated list of "key<sep>value" items into a map.
//
// Keys and values are trimmed of surrounding whitespace and non-empty values
// are passed through unquote. An item without sep is treated as a bare key
// and mapped to "true". Blank items, such as the one left by a trailing
// comma, are ignored.
func parseMap(input string, sep string) map[string]string {
	out := make(map[string]string)
	for _, arg := range strings.Split(input, ",") {
		k, v, ok := strings.Cut(arg, sep)
		k = strings.TrimSpace(k)
		if !ok {
			if k != "" {
				out[k] = "true"
			}
			// There is no value to process for a bare key.
			continue
		}
		v = strings.TrimSpace(v)
		if v != "" {
			// Guard the empty value so unquote never has to inspect the
			// first byte of an empty string.
			v = unquote(v)
		}
		out[k] = v
	}
	return out
}
// match parses a comment line of the form "// @name(args)".
//
// It returns the annotation name and its argument text with one layer of
// quoting removed. Lines that do not start with '@' and end with ')' after
// the comment marker and surrounding whitespace are stripped, or that contain
// no '(', yield two empty strings.
func match(text string) (string, string) {
	text = strings.TrimSpace(strings.TrimPrefix(text, "//"))
	if !strings.HasPrefix(text, "@") || !strings.HasSuffix(text, ")") {
		return "", ""
	}
	// Remove exactly one leading '@' and one trailing ')', so that an
	// argument list which itself ends in ')' is kept intact.
	text = strings.TrimSuffix(strings.TrimPrefix(text, "@"), ")")
	cmd, args, ok := strings.Cut(text, "(")
	if !ok {
		return "", ""
	}
	if args == "" {
		// "@name()" carries no arguments; there is nothing to unquote.
		return cmd, ""
	}
	return cmd, unquote(args)
}
// Some abbreviations need translation before transforming to snake case:
// "ID" is rewritten to "Id" so it becomes a single "_id" segment.
var replacer = strings.NewReplacer("ID", "Id")

// snakecase converts a Go identifier such as "groupID" into "group_id".
//
// Every ASCII upper-case letter is lower-cased and, unless it is the first
// byte of the identifier, preceded by an underscore. All other bytes
// (lower-case letters, digits, underscores, non-ASCII bytes) are copied
// unchanged.
func snakecase(s string) string {
	s = replacer.Replace(s)
	var b strings.Builder
	b.Grow(len(s) + 4) // room for a few inserted underscores
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			if i > 0 {
				b.WriteByte('_')
			}
			c += 'a' - 'A'
		}
		b.WriteByte(c)
	}
	return b.String()
}
// convDefault turns the annotation value s into the Go expression used as the
// default for a parameter of type t.
//
// Only "bool" and "uint32" are handled; any other type panics. The zero
// defaults ("0" for uint32, anything but "true" for bool) yield an empty
// string, meaning no explicit default is needed.
func convDefault(s string, t string) string {
	if t != "bool" && t != "uint32" {
		panic("unhandled default value type:" + t)
	}
	if t == "bool" && s == "true" {
		return s
	}
	if t == "uint32" && s != "0" {
		return t + "(" + s + ")"
	}
	return ""
}
// typeName renders a type expression as source text, e.g. "*onebot.Spec".
//
// It understands identifiers, selector expressions and pointer types, and
// panics with the node's dynamic type for anything else.
func typeName(x ast.Node) string {
	// Peel pointer and selector wrappers from the outside in: stars collect
	// in front of the innermost identifier, selector names behind it.
	var stars, selectors string
	for node := x; ; {
		switch n := node.(type) {
		case *ast.StarExpr:
			stars += "*"
			node = n.X
		case *ast.SelectorExpr:
			selectors = "." + n.Sel.Name + selectors
			node = n.X
		case *ast.Ident:
			return stars + n.Name + selectors
		default:
			panic("unhandled type: " + reflect.TypeOf(n).String())
		}
	}
}