diff --git a/coolq/event.go b/coolq/event.go index d3351b6..c9302a8 100644 --- a/coolq/event.go +++ b/coolq/event.go @@ -2,6 +2,7 @@ package coolq import ( "encoding/hex" + "fmt" "io/ioutil" "path" "strconv" @@ -40,6 +41,7 @@ func (bot *CQBot) privateMessageEvent(c *client.QQClient, m *message.PrivateMess bot.oneWayMsgCache.Store(m.Sender.Uin, "") } id := m.Id + fmt.Println(m) if bot.db != nil { id = bot.InsertPrivateMessage(m) } diff --git a/global/param.go b/global/param.go index 535ddf1..0eed41b 100644 --- a/global/param.go +++ b/global/param.go @@ -11,18 +11,6 @@ import ( "github.com/tidwall/gjson" ) -var trueSet = map[string]struct{}{ - "true": {}, - "yes": {}, - "1": {}, -} - -var falseSet = map[string]struct{}{ - "false": {}, - "no": {}, - "0": {}, -} - // EnsureBool 判断给定的p是否可表示为合法Bool类型,否则返回defaultVal // // 支持的合法类型有 @@ -55,13 +43,14 @@ func EnsureBool(p interface{}, defaultVal bool) bool { str = s } str = strings.ToLower(str) - if _, ok := trueSet[str]; ok { + switch str { + case "true", "yes", "1": return true - } - if _, ok := falseSet[str]; ok { + case "false", "no", "0": return false + default: + return defaultVal } - return defaultVal } // VersionNameCompare 检查版本名是否需要更新, 仅适用于 go-cqhttp 的版本命名规则 diff --git a/go.mod b/go.mod index 696520f..5b16ac5 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,7 @@ require ( github.com/Baozisoftware/qrcode-terminal-go v0.0.0-20170407111555-c0650d8dff0f github.com/Mrs4s/MiraiGo v0.0.0-20210525010101-8f0cd9494d64 github.com/dustin/go-humanize v1.0.0 - github.com/gin-contrib/pprof v1.3.0 - github.com/gin-gonic/gin v1.7.1 + github.com/gin-gonic/gin v1.7.1 // indirect github.com/gorilla/websocket v1.4.2 github.com/guonaihong/gout v0.1.9 github.com/jonboulle/clockwork v0.2.2 // indirect diff --git a/go.sum b/go.sum index e423e29..f106a5e 100644 --- a/go.sum +++ b/go.sum @@ -14,15 +14,11 @@ github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25Kn github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/gin-contrib/pprof v1.3.0 h1:G9eK6HnbkSqDZBYbzG4wrjCsA4e+cvYAHUZw6W+W9K0= -github.com/gin-contrib/pprof v1.3.0/go.mod h1:waMjT1H9b179t3CxuG1cV3DHpga6ybizwfBaM5OXaB0= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.6.0/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= -github.com/gin-gonic/gin v1.6.2/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gin-gonic/gin v1.7.1 h1:qC89GU3p8TvKWMAVhEpmpB2CIb1hnqt2UdKZaP93mS8= github.com/gin-gonic/gin v1.7.1/go.mod h1:jD2toBW3GZUr5UMcdrwQA10I7RuaFOl/SGeDjXkfUtY= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= diff --git a/server/http.go b/server/http.go index c77ae13..dca1484 100644 --- a/server/http.go +++ b/server/http.go @@ -7,29 +7,32 @@ import ( "crypto/sha1" "encoding/hex" "fmt" + "io" "net/http" + "net/url" "os" "strings" "time" - "github.com/Mrs4s/go-cqhttp/coolq" - "github.com/Mrs4s/go-cqhttp/global/config" - "github.com/Mrs4s/MiraiGo/utils" - "github.com/gin-gonic/gin" "github.com/guonaihong/gout" "github.com/guonaihong/gout/dataflow" + jsoniter "github.com/json-iterator/go" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" + + "github.com/Mrs4s/go-cqhttp/coolq" + "github.com/Mrs4s/go-cqhttp/global/config" ) type httpServer struct { - engine *gin.Engine - bot *coolq.CQBot - HTTP *http.Server - api *apiCaller + HTTP *http.Server + api *apiCaller + accessToken string } +var json = jsoniter.ConfigCompatibleWithStandardLibrary + // HTTPClient 反向HTTP上报客户端 type HTTPClient struct { bot *coolq.CQBot @@ -39,8 +42,102 @@ type HTTPClient struct { timeout int32 } -type httpContext struct { - ctx *gin.Context +type httpCtx struct { + json gjson.Result + query url.Values + postForm url.Values + + headerAuth string +} + +func (h *httpCtx) Get(s string) gjson.Result { + j := h.json.Get(s) + if j.Exists() { + return j + } + if h.postForm != nil { + if form := h.postForm.Get(s); form != "" { + return gjson.Result{Type: gjson.String, Str: form} + } + } + if h.query != nil { + if query := h.query.Get(s); query != "" { + return gjson.Result{Type: gjson.String, Str: query} + } + } + return gjson.Result{} +} + +func (s *httpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + var ctx httpCtx + contentType := request.Header.Get("Content-Type") + ctx.headerAuth = request.Header.Get("Authorization") + switch request.Method { + case http.MethodPost: + if strings.Contains(contentType, "application/json") { + body, err := io.ReadAll(request.Body) + if err != nil { + log.Warnf("获取请求 %v 的Body时出现错误: %v", request.RequestURI, err) + writer.WriteHeader(http.StatusBadRequest) + return + } + if !gjson.ValidBytes(body) { + log.Warnf("已拒绝客户端 %v 的请求: 非法Json", request.RemoteAddr) + writer.WriteHeader(http.StatusBadRequest) + return + } + ctx.json = gjson.Parse(utils.B2S(body)) + } + if strings.Contains(contentType, "application/x-www-form-urlencoded") { + ctx.postForm = request.PostForm + } + fallthrough + case http.MethodGet: + ctx.query = request.URL.Query() + + default: + log.Warnf("已拒绝客户端 %v 的请求: 方法错误", request.RemoteAddr) + writer.WriteHeader(http.StatusNotFound) + return + } + + if s.accessToken != "" { + if status := checkAuth(ctx, s.accessToken); status != http.StatusOK { + writer.WriteHeader(status) + return + } + } + + action := strings.TrimPrefix(request.URL.Path, "/") + action = strings.TrimSuffix(action, "_async") + log.Debugf("HTTPServer接收到API调用: %v", action) + ret := s.api.callAPI(action, &ctx) + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _ = json.NewEncoder(writer).Encode(ret) +} + +func checkAuth(ctx httpCtx, token string) int { + auth := ctx.headerAuth + if auth == "" { + auth = ctx.query.Get("access_token") + switch { + case auth == "": + return http.StatusUnauthorized + case auth != token: + return http.StatusForbidden + } + } else { + auth := strings.SplitN(auth, " ", 2) + switch { + case len(auth) != 2 || auth[1] == "": + return http.StatusUnauthorized + case auth[1] != token: + return http.StatusForbidden + } + } + return http.StatusOK } // RunHTTPServerAndClients 启动HTTP服务器与HTTP上报客户端 @@ -49,78 +146,24 @@ func RunHTTPServerAndClients(bot *coolq.CQBot, conf *config.HTTPServer) { return } var ( - s = new(httpServer) - authToken = conf.AccessToken - addr string + s = new(httpServer) + addr string ) + s.accessToken = conf.AccessToken if conf.Host == "" || conf.Port == 0 { goto client } addr = fmt.Sprintf("%s:%d", conf.Host, conf.Port) - gin.SetMode(gin.ReleaseMode) - s.engine = gin.New() - s.bot = bot - s.api = newAPICaller(s.bot) + s.api = newAPICaller(bot) if conf.RateLimit.Enabled { s.api.use(rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)) } - s.engine.Use(func(c *gin.Context) { - if c.Request.Method != "GET" && c.Request.Method != "POST" { - log.Warnf("已拒绝客户端 %v 的请求: 方法错误", c.Request.RemoteAddr) - c.Status(404) - return - } - if c.Request.Method == "POST" && strings.Contains(c.Request.Header.Get("Content-Type"), "application/json") { - d, err := c.GetRawData() - if err != nil { - log.Warnf("获取请求 %v 的Body时出现错误: %v", c.Request.RequestURI, err) - c.Status(400) - return - } - if !gjson.ValidBytes(d) { - log.Warnf("已拒绝客户端 %v 的请求: 非法Json", c.Request.RemoteAddr) - c.Status(400) - return - } - c.Set("json_body", gjson.ParseBytes(d)) - } - c.Next() - }) - - if authToken != "" { - s.engine.Use(func(c *gin.Context) { - auth := c.Request.Header.Get("Authorization") - if auth == "" { - headAuth := c.Query("access_token") - switch { - case headAuth == "": - c.AbortWithStatus(401) - return - case headAuth != authToken: - c.AbortWithStatus(403) - return - } - } else { - auth := strings.SplitN(auth, " ", 2) - switch { - case len(auth) != 2 || auth[1] == "": - c.AbortWithStatus(401) - return - case auth[1] != authToken: - c.AbortWithStatus(403) - return - } - } - }) - } - - s.engine.Any("/:action", s.HandleActions) go func() { log.Infof("CQ HTTP 服务器已启动: %v", addr) s.HTTP = &http.Server{ Addr: addr, - Handler: s.engine, + Handler: s, } if err := s.HTTP.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Error(err) @@ -202,34 +245,6 @@ func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) { } } -func (s *httpServer) HandleActions(c *gin.Context) { - action := strings.TrimSuffix(c.Param("action"), "_async") - log.Debugf("HTTPServer接收到API调用: %v", action) - c.JSON(200, s.api.callAPI(action, httpContext{ctx: c})) -} - -func (h httpContext) Get(k string) gjson.Result { - c := h.ctx - if q := c.Query(k); q != "" { - return gjson.Result{Type: gjson.String, Str: q} - } - if c.Request.Method == "POST" { - if h := c.Request.Header.Get("Content-Type"); h != "" { - if strings.Contains(h, "application/x-www-form-urlencoded") { - if p, ok := c.GetPostForm(k); ok { - return gjson.Result{Type: gjson.String, Str: p} - } - } - if strings.Contains(h, "application/json") { - if obj, ok := c.Get("json_body"); ok { - return obj.(gjson.Result).Get(k) - } - } - } - } - return gjson.Result{Type: gjson.Null, Str: ""} -} - func (s *httpServer) ShutDown() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/server/pprof.go b/server/pprof.go index 6ccb137..0b6b95a 100644 --- a/server/pprof.go +++ b/server/pprof.go @@ -3,11 +3,10 @@ package server import ( "fmt" "net/http" + "net/http/pprof" "os" "time" - "github.com/gin-contrib/pprof" - "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" "github.com/Mrs4s/go-cqhttp/global/config" @@ -18,14 +17,18 @@ func RunPprofServer(conf *config.PprofServer) { if conf.Disabled { return } - gin.SetMode(gin.ReleaseMode) - engine := gin.New() addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port) - pprof.Register(engine) + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + server := http.Server{Addr: addr, Handler: mux} go func() { log.Infof("pprof debug 服务器已启动: %v/debug/pprof", addr) log.Warnf("警告: pprof 服务不支持鉴权, 请不要运行在公网.") - if err := engine.Run(addr); err != nil && err != http.ErrServerClosed { + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Error(err) log.Infof("pprof 服务启动失败, 请检查端口是否被占用.") log.Warnf("将在五秒后退出.")