diff --git a/global/config/config.go b/global/config/config.go index 58ea8e5..cca800b 100644 --- a/global/config/config.go +++ b/global/config/config.go @@ -125,6 +125,13 @@ type WebsocketReverse struct { MiddleWares `yaml:"middlewares"` } +type LambdaServer struct { + Disabled bool `yaml:"disabled"` + Type string `yaml:"type"` + + MiddleWares `yaml:"middlewares"` +} + // LevelDBConfig leveldb 相关配置 type LevelDBConfig struct { Enable bool `yaml:"enable"` @@ -288,6 +295,13 @@ const httpDefault = ` # HTTP 通信设置 # secret: '' # 密钥 ` +const lambdaDefault = ` # LambdaServer 配置 + - lambda: + type: scf # scf: 腾讯云函数 aws: aws Lambda + middlewares: + <<: *default # 引用默认中间件 +` + const wsDefault = ` # 正向WS设置 - ws: # 正向WS服务器监听地址 diff --git a/main.go b/main.go index 2cb9880..f0fa72d 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "os" "os/exec" "path" + "path/filepath" "runtime" "strings" "sync" @@ -70,13 +71,33 @@ func init() { // 通过-c 参数替换 配置文件路径 config.DefaultConfigFile = c + + conf = config.Get() + if debug { + conf.Output.Debug = true + } + // 在debug模式下,将在标准输出中打印当前执行行数 + if conf.Output.Debug { + log.SetReportCaller(true) + } +} + +func main() { + if h { + help() + } + if d { + server.Daemon() + } + if wd != "" { + resetWorkDir() + } + logFormatter := &easy.Formatter{ TimestampFormat: "2006-01-02 15:04:05", LogFormat: "[%time%] [%lvl%]: %msg% \n", } - conf = config.Get() - rotateOptions := []rotatelogs.Option{ rotatelogs.WithRotationTime(time.Hour * 24), } @@ -94,13 +115,6 @@ func init() { panic(err) } - if debug { - conf.Output.Debug = true - } - // 在debug模式下,将在标准输出中打印当前执行行数 - if conf.Output.Debug { - log.SetReportCaller(true) - } log.AddHook(global.NewLocalHook(w, logFormatter, global.GetLogLevel(conf.Output.LogLevel)...)) if !global.PathExists(global.ImagePath) { @@ -123,18 +137,7 @@ func init() { log.Fatalf("创建发送图片缓存文件夹失败: %v", err) } } -} -func main() { - if h { - help() - } - if d { - server.Daemon() - } - if wd != "" { - resetWorkDir() - } var byteKey []byte arg := os.Args if len(arg) > 1 { @@ -418,6 +421,14 @@ func main() { go server.RunPprofServer(pc) } } + if p, ok := m["lambda"]; ok { + lc := new(config.LambdaServer) + if err := p.Decode(lc); err != nil { + log.Warn("读取pprof配置失败 :", err) + } else { + go server.RunLambdaClient(bot, lc) + } + } } log.Info("资源初始化完成, 开始处理信息.") log.Info("アトリは、高性能ですから!") @@ -613,7 +624,8 @@ func resetWorkDir() { args = append(args, os.Args[i]) } } - proc := exec.Command(os.Args[0], args...) + p, _ := filepath.Abs(os.Args[0]) + proc := exec.Command(p, args...) proc.Stdin = os.Stdin proc.Stdout = os.Stdout proc.Stderr = os.Stderr diff --git a/scripts/bootstrap b/scripts/bootstrap new file mode 100644 index 0000000..bb0d671 --- /dev/null +++ b/scripts/bootstrap @@ -0,0 +1,10 @@ +#!/bin/sh +echo "Start GOCQHTTP~~~" +unset http_proxy +unset https_proxy + +mkdir /mnt/ +rm -rf /mnt/data/leveldb +cp -f config.yml /mnt/config.yml +cp -f device.json /mnt/device.json +./go-cqhttp -w="/mnt/" faststart diff --git a/server/scf.go b/server/scf.go new file mode 100644 index 0000000..ef167f8 --- /dev/null +++ b/server/scf.go @@ -0,0 +1,179 @@ +package server + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/url" + "os" + "runtime/debug" + "strings" + + "github.com/Mrs4s/MiraiGo/utils" + log "github.com/sirupsen/logrus" + + "github.com/Mrs4s/go-cqhttp/coolq" + "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/global/config" +) + +type lambdaClient struct { + nextUrl string + responseURL string + lambdaType string + + client http.Client + http *httpServer +} + +type lambdaResponse struct { + IsBase64Encoded bool `json:"isBase64Encoded"` + StatusCode int `json:"statusCode"` + Headers map[string]string `json:"headers"` + Body string `json:"body"` +} + +type lambdaResponseWriter struct { + statusCode int + header http.Header + buf *bytes.Buffer +} + +func (l *lambdaResponseWriter) Header() http.Header { + return l.header +} + +func (l *lambdaResponseWriter) Write(data []byte) (int, error) { + buffer := global.NewBuffer() + defer global.PutBuffer(buffer) + body := "" + if l.buf != nil { + body = utils.B2S(data) + } + header := make(map[string]string) + for k, v := range l.header { + header[k] = v[0] + } + json.NewEncoder(buffer).Encode(&lambdaResponse{ + IsBase64Encoded: false, + StatusCode: l.statusCode, + Headers: header, + Body: body, + }) + + r, _ := http.NewRequest("POST", cli.responseURL, buffer) + do, err := cli.client.Do(r) + if err != nil { + return 0, err + } + _ = do.Body.Close() + return len(data), nil +} + +func (l *lambdaResponseWriter) WriteHeader(statusCode int) { + l.statusCode = statusCode +} + +var cli *lambdaClient + +func init() { + cli = &lambdaClient{ + lambdaType: "scf", + client: http.Client{Timeout: 0}, + } + switch cli.lambdaType { // todo: aws + case "scf": // tencent serverless function + base := fmt.Sprintf("http://%s:%s/runtime/", + os.Getenv("SCF_RUNTIME_API"), + os.Getenv("SCF_RUNTIME_API_PORT"), + ) + cli.nextUrl = base + "invocation/next" + cli.responseURL = base + "invocation/response" + post, err := http.Post(base+"init/ready", "", nil) + if err != nil { + log.Warnf("lambda 初始化失败: %v", err) + return + } + _ = post.Body.Close() + case "aws": // aws lambda + const apiVersion = "2018-06-01" + base := fmt.Sprintf("http://%s/%s/runtime/", os.Getenv("AWS_LAMBDA_RUNTIME_API"), apiVersion) + cli.nextUrl = base + "invocation/next" + cli.responseURL = base + "invocation/response" + } +} + +// RunLambdaClient type: [scf,aws] +func RunLambdaClient(bot *coolq.CQBot, conf *config.LambdaServer) { + api := newAPICaller(bot) + if conf.RateLimit.Enabled { + api.use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)) + } + server := &httpServer{ + api: api, + accessToken: conf.AccessToken, + } + + for { + req := cli.next() + func() { + defer func() { + if e := recover(); e != nil { + log.Warnf("Lambda 出现不可恢复错误: %v\n%s", e, debug.Stack()) + } + }() + buffer := global.NewBuffer() + defer global.PutBuffer(buffer) + server.ServeHTTP(&lambdaResponseWriter{ + header: make(http.Header), + buf: buffer, + }, req) + }() + } +} + +type lambdaInvoke struct { + Headers map[string]string + HttpMethod string `json:"httpMethod"` + Body string `json:"body"` + Path string `json:"path"` + QueryString map[string]string + RequestContext struct { + Path string `json:"path"` + } `json:"requestContext"` +} + +func (c *lambdaClient) next() *http.Request { + r, err := http.NewRequest(http.MethodGet, c.nextUrl, nil) + if err != nil { + return nil + } + resp, err := c.client.Do(r) + if err != nil { + return nil + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil + } + var req = new(http.Request) + var invoke = new(lambdaInvoke) + _ = json.NewDecoder(resp.Body).Decode(invoke) + + req.Method = invoke.HttpMethod + req.Body = io.NopCloser(strings.NewReader(invoke.Body)) + req.Header = make(map[string][]string) + for k, v := range invoke.Headers { + req.Header.Set(k, v) + } + req.URL = new(url.URL) + req.URL.Path = strings.TrimPrefix(invoke.Path, invoke.RequestContext.Path) + // todo: avoid encoding + query := make(url.Values) + for k, v := range invoke.QueryString { + query[k] = []string{v} + } + req.URL.RawQuery = query.Encode() + return req +}