1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-04 19:17:37 +08:00
2023-02-27 15:22:19 +08:00

403 lines
9.0 KiB
Go

package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io"
"os"
"sort"
"strconv"
"strings"
)
// Command-line flags of the generator.
var (
	// supported selects genSupported output instead of the router.
	supported = flag.Bool("supported", false, "genRouter supported.go")
	// output is the path the formatted generated code is written to.
	output = flag.String("o", "", "output file")
	// pkg is the package name placed in the generated file's package clause.
	pkg = flag.String("pkg", "", "package name")
	// src is a comma-separated list of Go source files to scan.
	src = flag.String("path", "", "source file")
)
type (
	// Param is one parameter of a CQBot method as seen by the generator.
	Param struct {
		Name    string // request key: snake_case of the Go name, or the @rename target
		Type    string // Go type as spelled in the method signature
		Default string // Go expression from @default, or "" when there is none
	}

	// Router describes how one CQBot method is exposed as an API action.
	Router struct {
		Func    string   // name of the CQBot method to call
		Version []uint16 // versions listed in @version
		Path    []string // @route paths, matched for every version
		PathV11 []string // v11 only (@route11)
		PathV12 []string // v12 only (@route12)
		Params  []Param  // method parameters, in declaration order
	}

	// generator accumulates generated Go source in out.
	generator struct {
		out io.Writer
	}
)
// Path selectors passed to generator.router: PathAll picks Router.Path,
// PathV11 picks Router.PathV11 and PathV12 picks Router.PathV12.
const PathAll, PathV11, PathV12 = 0, 11, 12
// WriteString writes s verbatim to the generator output. The write error is
// discarded on purpose; main backs g.out with a *bytes.Buffer.
func (g *generator) WriteString(s string) {
	_, _ = io.WriteString(g.out, s)
}
// writef formats a according to format and writes the result to the
// generator output, discarding the write error like WriteString does.
func (g *generator) writef(format string, a ...any) {
	_, _ = fmt.Fprintf(g.out, format, a...)
}
// header writes the "Code generated ... DO NOT EDIT." marker followed by the
// package clause named by the -pkg flag.
func (g *generator) header() {
	g.writef("// Code generated by cmd/api-generator. DO NOT EDIT.\n\npackage %s\n\n", *pkg)
}
// genRouter writes the imports and the Caller.call dispatcher for routers.
//
// The generated function matches the action in three stages:
//  1. @route12 paths, only when spec.Version == 12;
//  2. @route11 paths, only when spec.Version == 11;
//  3. version-independent @route paths.
//
// It returns a 404 "API_NOT_FOUND" failure when nothing matches. Every emitted
// case returns, so an unmatched action falls through to the next stage.
func (g *generator) genRouter(routers []Router) {
	g.WriteString("import (\n\n")
	g.WriteString("\"github.com/Mrs4s/go-cqhttp/coolq\"\n")
	g.WriteString("\"github.com/Mrs4s/go-cqhttp/global\"\n")
	g.WriteString("\"github.com/Mrs4s/go-cqhttp/pkg/onebot\"\n")
	g.WriteString(")\n\n")
	g.WriteString("func (c *Caller) call(action string, spec *onebot.Spec, p Getter) global.MSG {\n")

	// cases emits one switch case per router for the given path selector;
	// routers without paths for that selector emit nothing.
	cases := func(pathVersion int) {
		for _, r := range routers {
			g.router(r, pathVersion)
		}
	}

	g.WriteString("if spec.Version == 12 {\nswitch action {\n")
	cases(PathV12)
	g.WriteString("}}\n")

	g.WriteString("if spec.Version == 11 {\nswitch action {\n")
	cases(PathV11)
	g.WriteString("}}\n")

	g.WriteString("switch action {\n")
	cases(PathAll)
	g.WriteString("}\n")

	g.WriteString("return coolq.Failed(404, \"API_NOT_FOUND\", \"API不存在\")}")
}
// router writes one `case` clause of the generated switch for the given
// router, using the path list chosen by pathVersion (PathV11 -> PathV11,
// PathV12 -> PathV12, anything else -> Path). Nothing is written when that
// list is empty.
//
// The clause body:
//   - rejects the request with a 405 "VERSION_ERROR" when exactly one
//     @version is declared and spec.Version differs from it;
//   - declares p<i> for every parameter except one named "spec", read from
//     the request or initialized from its @default;
//   - calls c.bot.<Func> with those values, passing spec through unchanged.
func (g *generator) router(router Router, pathVersion int) {
	var paths []string
	switch pathVersion {
	case PathV11:
		paths = router.PathV11
	case PathV12:
		paths = router.PathV12
	default:
		paths = router.Path
	}
	if len(paths) == 0 {
		return
	}

	quoted := make([]string, len(paths))
	for i, p := range paths {
		quoted[i] = strconv.Quote(p)
	}
	g.WriteString("case " + strings.Join(quoted, ", ") + ":\n")

	// Only the single-version restriction is needed for now.
	if len(router.Version) == 1 {
		g.writef("if spec.Version != %v {\n", router.Version[0])
		g.WriteString("return coolq.Failed(405, \"VERSION_ERROR\", \"API版本不匹配\")}\n")
	}

	args := make([]string, len(router.Params))
	for i, param := range router.Params {
		if param.Name == "spec" {
			args[i] = "spec"
			continue
		}
		args[i] = "p" + strconv.Itoa(i)
		if param.Default == "" {
			getter := "p.Get(" + strconv.Quote(param.Name) + ")"
			g.writef("%s := %s\n", args[i], conv(getter, param.Type))
			continue
		}
		// Start from the default and override it only if the key is present.
		g.writef("%s := %s\n", args[i], param.Default)
		g.writef("if pt := p.Get(%s); pt.Exists() {\n", strconv.Quote(param.Name))
		g.writef("%s = %s\n}\n", args[i], conv("pt", param.Type))
	}

	g.WriteString("\t\treturn c.bot." + router.Func + "(" + strings.Join(args, ", ") + ")\n")
}
// conv returns a Go expression that converts v, an expression of type
// gjson.Result, into a value of Go type t. gjson.Result and *onebot.Spec are
// passed through unchanged. It panics for any type it does not recognise.
func conv(v, t string) string {
	switch t {
	case "gjson.Result", "*onebot.Spec":
		return v
	case "int64":
		return v + ".Int()"
	case "uint64":
		return v + ".Uint()"
	case "bool":
		return v + ".Bool()"
	case "string":
		return v + ".String()"
	case "int", "int32":
		// Narrower signed types are converted from gjson's int64.
		return t + "(" + v + ".Int())"
	case "uint16", "uint32":
		// Narrower unsigned types are converted from gjson's uint64.
		return t + "(" + v + ".Uint())"
	}
	panic("unknown type: " + t)
}
// main parses the comma-separated Go files given by -path, then collects
// every exported method with a *CQBot receiver that carries at least one
// @route/@route11/@route12 annotation. It writes gofmt-formatted generated
// code to -o: the output of genSupported when -supported is set, otherwise
// the router. Exported CQBot methods without any route are listed on stderr.
// Parse, format and write failures panic.
func main() {
	flag.Parse()

	var routers []Router
	fset := token.NewFileSet()
	for _, s := range strings.Split(*src, ",") {
		file, err := parser.ParseFile(fset, s, nil, parser.ParseComments)
		if err != nil {
			panic(err)
		}
		for _, decl := range file.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			if !ok || !isCQBotMethod(fn) {
				continue
			}
			router := routerFromDecl(fn)
			if len(router.Path) == 0 && len(router.PathV11) == 0 && len(router.PathV12) == 0 {
				// Report unrouted methods so a missing annotation is noticed.
				fmt.Fprintln(os.Stderr, fn.Name.Name)
				continue
			}
			routers = append(routers, router)
		}
	}
	sort.Slice(routers, func(i, j int) bool {
		return routerSortKey(routers[i]) < routerSortKey(routers[j])
	})

	out := new(bytes.Buffer)
	g := &generator{out: out}
	g.header()
	if *supported {
		g.genSupported(routers)
	} else {
		g.genRouter(routers)
	}
	source, err := format.Source(out.Bytes())
	if err != nil {
		panic(err)
	}
	if err := os.WriteFile(*output, source, 0o644); err != nil {
		panic(err)
	}
}

// isCQBotMethod reports whether fn is an exported method whose receiver is
// spelled *CQBot. Receivers of any other shape are rejected instead of
// panicking on a failed type assertion.
func isCQBotMethod(fn *ast.FuncDecl) bool {
	if !fn.Name.IsExported() || fn.Recv == nil || len(fn.Recv.List) == 0 {
		return false
	}
	star, ok := fn.Recv.List[0].Type.(*ast.StarExpr)
	if !ok {
		return false
	}
	id, ok := star.X.(*ast.Ident)
	return ok && id.Name == "CQBot"
}

// routerFromDecl builds the Router for fn from its parameters and from the
// annotations in its doc comment. The doc comment may be absent; an
// undocumented method simply yields a Router without paths.
func routerFromDecl(fn *ast.FuncDecl) Router {
	router := Router{Func: fn.Name.Name}
	for _, field := range fn.Type.Params.List {
		typ := exprTypeName(field.Type)
		for _, name := range field.Names {
			router.Params = append(router.Params, Param{Name: snakecase(name.Name), Type: typ})
		}
	}
	if fn.Doc != nil {
		for _, comment := range fn.Doc.List {
			applyAnnotation(&router, comment.Text)
		}
	}
	// Sort once, after all annotations have been applied.
	sort.Strings(router.Path)
	sort.Strings(router.PathV11)
	sort.Strings(router.PathV12)
	return router
}

// applyAnnotation applies a single comment line to router if it is one of
// the recognised annotations: @route, @route11, @route12, @default, @rename
// or @version. Other comments are ignored.
func applyAnnotation(router *Router, text string) {
	annotation, args := match(text)
	switch annotation {
	case "route":
		router.Path = appendRoutes(router.Path, args)
	case "route11":
		router.PathV11 = appendRoutes(router.PathV11, args)
	case "route12":
		router.PathV12 = appendRoutes(router.PathV12, args)
	case "default":
		for name, value := range parseMap(args, "=") {
			for i, p := range router.Params {
				if p.Name == name {
					router.Params[i].Default = convDefault(value, p.Type)
				}
			}
		}
	case "rename":
		for name, value := range parseMap(args, "->") {
			for i, p := range router.Params {
				if p.Name == name {
					router.Params[i].Name = value
				}
			}
		}
	case "version":
		for _, v := range strings.Split(args, ",") {
			// Trim so "11, 12" is not silently reduced to just 11.
			if ver, err := strconv.ParseUint(strings.TrimSpace(v), 10, 16); err == nil {
				router.Version = append(router.Version, uint16(ver))
			}
		}
	}
}

// appendRoutes appends the comma-separated, optionally quoted routes in args
// to paths, trimming whitespace around each entry.
func appendRoutes(paths []string, args string) []string {
	for _, route := range strings.Split(args, ",") {
		paths = append(paths, unquote(strings.TrimSpace(route)))
	}
	return paths
}

// exprTypeName renders a parameter type expression the way conv expects it
// (for example "int64", "gjson.Result" or "*onebot.Spec"). It returns "" for
// shapes it does not handle.
func exprTypeName(expr ast.Expr) string {
	switch t := expr.(type) {
	case *ast.Ident:
		return t.Name
	case *ast.SelectorExpr:
		if x, ok := t.X.(*ast.Ident); ok {
			return x.Name + "." + t.Sel.Name
		}
	case *ast.StarExpr:
		if elem := exprTypeName(t.X); elem != "" {
			return "*" + elem
		}
	}
	return ""
}

// routerSortKey returns the first path of r, preferring @route, then @route11,
// then @route12. It returns "" if r has no paths.
func routerSortKey(r Router) string {
	switch {
	case len(r.Path) > 0:
		return r.Path[0]
	case len(r.PathV11) > 0:
		return r.PathV11[0]
	case len(r.PathV12) > 0:
		return r.PathV12[0]
	}
	return ""
}
// unquote strips Go-style quoting from s.
//   - A double-quoted string is unquoted with strconv.Unquote; if that fails,
//     for example on `"a","b"`, s is returned unchanged rather than replaced
//     by "".
//   - A back-quoted string has every leading and trailing '`' trimmed.
//   - Anything else, including "", is returned as is.
func unquote(s string) string {
	if s == "" {
		return s
	}
	switch s[0] {
	case '"':
		if u, err := strconv.Unquote(s); err == nil {
			return u
		}
	case '`':
		return strings.Trim(s, "`")
	}
	return s
}
// parseMap parses a comma-separated list of `key<sep>value` entries, such as
// "a=1, b=`x`" with sep "=". Keys and values are trimmed of whitespace and
// values are unquoted. A bare key without sep is shorthand for key=true.
// Blank entries are skipped.
func parseMap(input string, sep string) map[string]string {
	out := make(map[string]string)
	for _, arg := range strings.Split(input, ",") {
		if strings.TrimSpace(arg) == "" {
			continue
		}
		k, v, ok := cut(arg, sep)
		k = strings.TrimSpace(k)
		if !ok {
			// Previously this value was overwritten straight away; keep it.
			out[k] = "true"
			continue
		}
		v = strings.TrimSpace(v)
		if v != "" {
			v = unquote(v)
		}
		out[k] = v
	}
	return out
}
// match parses an annotation comment of the form `// @name(args)`. It returns
// name and the unquoted argument text, or two empty strings when text is not
// such an annotation.
func match(text string) (string, string) {
	text = strings.TrimSpace(strings.TrimPrefix(text, "//"))
	if !strings.HasPrefix(text, "@") || !strings.HasSuffix(text, ")") {
		return "", ""
	}
	// Remove exactly one '@' and one ')'. strings.Trim treats "@)" as a
	// cutset and would also strip a ')' belonging to the arguments.
	text = strings.TrimSuffix(strings.TrimPrefix(text, "@"), ")")
	cmd, args, ok := cut(text, "(")
	if !ok {
		return "", ""
	}
	if args == "" {
		return cmd, ""
	}
	return cmd, unquote(args)
}
// cut slices s around the first instance of sep. It returns the text before
// and after sep and whether sep was found; if not, it returns s, "", false.
func cut(s, sep string) (before, after string, found bool) {
	return strings.Cut(s, sep)
}
// replacer rewrites abbreviations that must be treated as one word before
// the conversion to snake case, e.g. "groupID" -> "groupId" -> "group_id".
var replacer = strings.NewReplacer("ID", "Id")

// snakecase converts a camelCase Go identifier to snake_case.
// Each ASCII uppercase letter becomes '_' plus its lowercase form; no
// underscore is added before a leading capital. All other bytes, including
// digits, '_' and multi-byte UTF-8, are copied unchanged. Previously any
// non-[a-z0-9] byte was case-flipped, which turned '_' into "\x7f".
func snakecase(s string) string {
	s = replacer.Replace(s)
	var b strings.Builder
	b.Grow(len(s) + 8)
	for i := 0; i < len(s); i++ {
		c := s[i]
		if 'A' <= c && c <= 'Z' {
			if i > 0 {
				b.WriteByte('_')
			}
			c += 'a' - 'A'
		}
		b.WriteByte(c)
	}
	return b.String()
}
// convDefault converts the @default value s of a parameter of Go type t into
// the Go expression used to initialize that parameter in generated code.
//
// It returns "" when no explicit default is needed, i.e. when s is the zero
// value; the generated code then reads the request field directly.
//   - bool: "true" yields "true"; any other value yields "".
//   - string: a non-empty value is returned as a Go string literal.
//   - integer types handled by conv: a value other than "0" is wrapped in an
//     explicit conversion, e.g. "int64(5)", so the variable gets type t.
//
// It panics for any other type.
func convDefault(s string, t string) string {
	switch t {
	case "bool":
		if s == "true" {
			return s
		}
		return ""
	case "string":
		if s == "" {
			return ""
		}
		return strconv.Quote(s)
	case "int", "int32", "int64", "uint16", "uint32", "uint64":
		if s == "0" {
			return ""
		}
		return t + "(" + s + ")"
	}
	panic("unhandled default value type:" + t)
}