mirror of
https://github.com/Mrs4s/go-cqhttp.git
synced 2025-05-05 03:23:49 +08:00
389 lines
8.8 KiB
Go
389 lines
8.8 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/token"
|
|
"io"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// Param describes a single parameter of a bot method as seen by the generator.
type Param struct {
	// Name is the key looked up with p.Get in the generated code. It starts as
	// the snake_case form of the Go parameter name and may be replaced by a
	// rename annotation.
	Name string
	// Type is the Go type as written in the method signature,
	// e.g. "int64" or "gjson.Result".
	Type string
	// Default is the Go expression the generated variable is initialised with
	// before the request is consulted; empty means there is no default.
	Default string
}
|
|
|
|
type Router struct {
|
|
Func string
|
|
Version []uint16
|
|
Path []string
|
|
PathV11 []string // v11 only
|
|
PathV12 []string // v12 only
|
|
Params []Param
|
|
}
|
|
|
|
// generator emits the generated source text to out.
type generator struct {
	// out receives every piece of generated text.
	out io.Writer
}
|
|
|
|
// Selectors telling generator.router which of a Router's path lists to use.
const (
	// PathAll selects Router.Path (version-independent actions).
	PathAll = 0
	// PathV11 selects Router.PathV11.
	PathV11 = 11
	// PathV12 selects Router.PathV12.
	PathV12 = 12
)
|
|
|
|
func (g *generator) WriteString(s string) {
|
|
io.WriteString(g.out, s)
|
|
}
|
|
|
|
func (g *generator) generate(routers []Router) {
|
|
var actions []string // for onebot v12 get_supported_actions
|
|
for _, router := range routers {
|
|
if len(router.PathV12) > 0 {
|
|
actions = append(actions, router.PathV12...)
|
|
}
|
|
if len(router.Path) > 0 {
|
|
actions = append(actions, router.Path...)
|
|
}
|
|
}
|
|
for i := range actions {
|
|
actions[i] = `"` + actions[i] + `"`
|
|
}
|
|
|
|
// TODO: v12 和 all 的 switch-case 由常量改为数组寻址, 以利用 get_supported_actions
|
|
g.WriteString("// Code generated by cmd/api-generator. DO NOT EDIT.\n\n")
|
|
g.WriteString("package api\n\nimport (\n\n")
|
|
g.WriteString("\"github.com/Mrs4s/go-cqhttp/coolq\"\n")
|
|
g.WriteString("\"github.com/Mrs4s/go-cqhttp/global\"\n")
|
|
g.WriteString(")\n\n")
|
|
g.WriteString(fmt.Sprintf(`func (c *Caller) call(action string, version uint16, p Getter) global.MSG {
|
|
if version == 12 {
|
|
if action == "get_supported_actions" {
|
|
return coolq.OK([]string{%v})
|
|
}
|
|
switch action {
|
|
`, strings.Join(actions, ",")))
|
|
for _, router := range routers {
|
|
g.router(router, PathV12)
|
|
}
|
|
io.WriteString(g.out, `}}`)
|
|
io.WriteString(g.out, "\n")
|
|
g.WriteString(`if version == 11 {
|
|
switch action {
|
|
`)
|
|
for _, router := range routers {
|
|
g.router(router, PathV11)
|
|
}
|
|
io.WriteString(g.out, `}}`)
|
|
io.WriteString(g.out, "\n")
|
|
io.WriteString(g.out, "switch action {\n")
|
|
for _, router := range routers {
|
|
g.router(router, PathAll)
|
|
}
|
|
io.WriteString(g.out, `}`)
|
|
io.WriteString(g.out, "\n")
|
|
io.WriteString(g.out, "return coolq.Failed(404, \"API_NOT_FOUND\", \"API不存在\")}")
|
|
}
|
|
|
|
func (g *generator) router(router Router, pathVersion int) {
|
|
/*
|
|
checkVersion := func(v uint16) bool {
|
|
for _, ver := range router.Version {
|
|
if ver == v {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
*/
|
|
|
|
path := router.Path
|
|
if pathVersion == PathV11 {
|
|
path = router.PathV11
|
|
}
|
|
if pathVersion == PathV12 {
|
|
path = router.PathV12
|
|
}
|
|
if len(path) == 0 {
|
|
return
|
|
}
|
|
|
|
g.WriteString(`case `)
|
|
for i, p := range path {
|
|
if i != 0 {
|
|
g.WriteString(`, `)
|
|
}
|
|
g.WriteString(strconv.Quote(p))
|
|
}
|
|
g.WriteString(":\n")
|
|
|
|
if len(router.Version) == 1 { // 目前来说只需要判断一个版本的情况
|
|
check := make([]string, 0, len(router.Version))
|
|
for _, ver := range router.Version {
|
|
check = append(check, fmt.Sprintf("version != %v", ver))
|
|
}
|
|
fmt.Fprintf(g.out, "if %v {\n", strings.Join(check, " && "))
|
|
fmt.Fprintf(g.out, "return coolq.Failed(405, \"VERSION_ERROR\", \"API版本不匹配\")}\n")
|
|
}
|
|
|
|
for i, p := range router.Params {
|
|
if p.Name == "version" {
|
|
continue
|
|
}
|
|
if p.Default == "" {
|
|
v := "p.Get(" + strconv.Quote(p.Name) + ")"
|
|
fmt.Fprintf(g.out, "p%d := %s\n", i, conv(v, p.Type))
|
|
} else {
|
|
fmt.Fprintf(g.out, "p%d := %s\n", i, p.Default)
|
|
fmt.Fprintf(g.out, "if pt := p.Get(%s); pt.Exists() {\n", strconv.Quote(p.Name))
|
|
fmt.Fprintf(g.out, "p%d = %s\n}\n", i, conv("pt", p.Type))
|
|
}
|
|
}
|
|
|
|
g.WriteString("\t\treturn c.bot." + router.Func + "(")
|
|
for i, p := range router.Params {
|
|
if i != 0 {
|
|
g.WriteString(", ")
|
|
}
|
|
if p.Name == "version" {
|
|
fmt.Fprintf(g.out, "version")
|
|
continue
|
|
}
|
|
fmt.Fprintf(g.out, "p%d", i)
|
|
}
|
|
g.WriteString(")\n")
|
|
}
|
|
|
|
// conv returns the Go expression that converts the expression v (written into
// the generated code as a value offering Int/Uint/Bool/String accessors) to
// the parameter type t. It panics on a type it does not know.
func conv(v, t string) string {
	switch t {
	case "gjson.Result":
		// Passed through untouched.
		return v
	case "string":
		return v + ".String()"
	case "bool":
		return v + ".Bool()"
	case "int64":
		return v + ".Int()"
	case "uint64":
		return v + ".Uint()"
	case "int", "int32":
		// Narrower signed types need an explicit conversion.
		return t + "(" + v + ".Int())"
	case "uint16", "uint32":
		// Narrower unsigned types need an explicit conversion.
		return t + "(" + v + ".Uint())"
	}
	panic("unknown type: " + t)
}
|
|
|
|
func main() {
|
|
var routers []Router
|
|
src := flag.String("path", "", "source file")
|
|
flag.Parse()
|
|
fset := token.NewFileSet()
|
|
for _, s := range strings.Split(*src, ",") {
|
|
file, err := parser.ParseFile(fset, s, nil, parser.ParseComments)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
for _, decl := range file.Decls {
|
|
switch decl := decl.(type) {
|
|
case *ast.FuncDecl:
|
|
if !decl.Name.IsExported() || decl.Recv == nil {
|
|
continue
|
|
}
|
|
if st, ok := decl.Recv.List[0].Type.(*ast.StarExpr); !ok || st.X.(*ast.Ident).Name != "CQBot" {
|
|
continue
|
|
}
|
|
router := Router{Func: decl.Name.Name}
|
|
|
|
// compute params
|
|
for _, p := range decl.Type.Params.List {
|
|
var typ string
|
|
switch t := p.Type.(type) {
|
|
case *ast.Ident:
|
|
typ = t.Name
|
|
case *ast.SelectorExpr:
|
|
typ = t.X.(*ast.Ident).Name + "." + t.Sel.Name
|
|
}
|
|
for _, name := range p.Names {
|
|
router.Params = append(router.Params, Param{Name: snakecase(name.Name), Type: typ})
|
|
}
|
|
}
|
|
|
|
for _, comment := range decl.Doc.List {
|
|
annotation, args := match(comment.Text)
|
|
switch annotation {
|
|
case "route":
|
|
for _, route := range strings.Split(args, ",") {
|
|
router.Path = append(router.Path, unquote(route))
|
|
}
|
|
case "route11":
|
|
for _, route := range strings.Split(args, ",") {
|
|
router.PathV11 = append(router.PathV11, unquote(route))
|
|
}
|
|
case "route12":
|
|
for _, route := range strings.Split(args, ",") {
|
|
router.PathV12 = append(router.PathV12, unquote(route))
|
|
}
|
|
case "default":
|
|
for name, value := range parseMap(args, "=") {
|
|
for i, p := range router.Params {
|
|
if p.Name == name {
|
|
router.Params[i].Default = convDefault(value, p.Type)
|
|
}
|
|
}
|
|
}
|
|
case "rename":
|
|
for name, value := range parseMap(args, "->") {
|
|
for i, p := range router.Params {
|
|
if p.Name == name {
|
|
router.Params[i].Name = value
|
|
}
|
|
}
|
|
}
|
|
case "version":
|
|
version := strings.Split(args, ",")
|
|
for _, v := range version {
|
|
if i, err := strconv.ParseUint(v, 10, 16); err == nil {
|
|
router.Version = append(router.Version, uint16(i))
|
|
}
|
|
}
|
|
}
|
|
sort.Slice(router.Path, func(i, j int) bool {
|
|
return router.Path[i] < router.Path[j]
|
|
})
|
|
sort.Slice(router.PathV11, func(i, j int) bool {
|
|
return router.PathV11[i] < router.PathV11[j]
|
|
})
|
|
sort.Slice(router.PathV12, func(i, j int) bool {
|
|
return router.PathV12[i] < router.PathV12[j]
|
|
})
|
|
}
|
|
if router.Path != nil || router.PathV11 != nil || router.PathV12 != nil {
|
|
routers = append(routers, router)
|
|
} else {
|
|
println(decl.Name.Name)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
sort.Slice(routers, func(i, j int) bool {
|
|
path := func(r Router) string {
|
|
if r.Path != nil {
|
|
return r.Path[0]
|
|
}
|
|
if r.PathV11 != nil {
|
|
return r.PathV11[0]
|
|
}
|
|
if r.PathV12 != nil {
|
|
return r.PathV12[0]
|
|
}
|
|
return ""
|
|
}
|
|
return path(routers[i]) < path(routers[j])
|
|
})
|
|
|
|
out := new(bytes.Buffer)
|
|
g := &generator{out: out}
|
|
g.generate(routers)
|
|
source, err := format.Source(out.Bytes())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
err = os.WriteFile("api.go", source, 0o644)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// unquote strips Go-style quoting from s. A string starting with a double
// quote is decoded with strconv.Unquote; a string starting with a backtick
// has its surrounding backticks trimmed. Anything else, including the empty
// string and a double-quoted string that fails to decode, is returned
// unchanged so the caller keeps the raw text instead of silently losing it.
func unquote(s string) string {
	if s == "" {
		return s
	}
	switch s[0] {
	case '"':
		if decoded, err := strconv.Unquote(s); err == nil {
			return decoded
		}
	case '`':
		return strings.Trim(s, "`")
	}
	return s
}
|
|
|
|
func parseMap(input string, sep string) map[string]string {
|
|
out := make(map[string]string)
|
|
for _, arg := range strings.Split(input, ",") {
|
|
k, v, ok := cut(arg, sep)
|
|
if !ok {
|
|
out[k] = "true"
|
|
}
|
|
k = strings.TrimSpace(k)
|
|
v = unquote(strings.TrimSpace(v))
|
|
out[k] = v
|
|
}
|
|
return out
|
|
}
|
|
|
|
func match(text string) (string, string) {
|
|
text = strings.TrimPrefix(text, "//")
|
|
text = strings.TrimSpace(text)
|
|
if !strings.HasPrefix(text, "@") || !strings.HasSuffix(text, ")") {
|
|
return "", ""
|
|
}
|
|
text = strings.Trim(text, "@)")
|
|
cmd, args, ok := cut(text, "(")
|
|
if !ok {
|
|
return "", ""
|
|
}
|
|
return cmd, unquote(args)
|
|
}
|
|
|
|
// cut splits s around the first occurrence of sep. It returns the text before
// and after sep and found=true; when sep does not occur it returns s, "" and
// found=false.
func cut(s, sep string) (before, after string, found bool) {
	pos := strings.Index(s, sep)
	if pos < 0 {
		return s, "", false
	}
	head, tail := s[:pos], s[pos+len(sep):]
	return head, tail, true
}
|
|
|
|
// Some abbreviations need translation before transforming to snake case.
var replacer = strings.NewReplacer("ID", "Id")

// snakecase converts a camelCase identifier to snake_case. The replacer is
// applied first so that "ID" counts as one word. After that every byte that
// is not a lowercase ASCII letter or a digit is emitted as '_' followed by
// the byte with bit 0x20 flipped, which lowercases ASCII uppercase letters.
func snakecase(s string) string {
	s = replacer.Replace(s)
	var b strings.Builder
	b.Grow(32)
	for _, c := range []byte(s) {
		lower := 'a' <= c && c <= 'z'
		digit := '0' <= c && c <= '9'
		if !lower && !digit {
			b.WriteByte('_')
			c ^= 0x20
		}
		b.WriteByte(c)
	}
	return b.String()
}
|
|
|
|
// convDefault turns the textual default value s of a parameter of type t into
// the Go expression stored in Param.Default. Only "bool" and "uint32" are
// supported; any other type panics. A default equal to the type's zero value
// ("0" for uint32, anything other than "true" for bool) yields the empty
// string.
func convDefault(s string, t string) string {
	if t != "bool" && t != "uint32" {
		panic("unhandled default value type:" + t)
	}
	if t == "bool" && s == "true" {
		return s
	}
	if t == "uint32" && s != "0" {
		return t + "(" + s + ")"
	}
	return ""
}
|