From 72a7430841845589e68dc7957f33a0267a17e9e3 Mon Sep 17 00:00:00 2001 From: Shigma <1700011071@pku.edu.cn> Date: Mon, 10 Aug 2020 20:43:31 +0800 Subject: [PATCH] add api type check --- server/http.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/server/http.go b/server/http.go index b191a62..15b9e8f 100644 --- a/server/http.go +++ b/server/http.go @@ -235,8 +235,8 @@ func (s *httpServer) SendMessage(c *gin.Context) { func (s *httpServer) SendPrivateMessage(c *gin.Context) { uid, _ := strconv.ParseInt(getParam(c, "user_id"), 10, 64) - msg := getParam(c, "message") - if gjson.Valid(msg) { + msg, t := getParamWithType(c, "message") + if t == gjson.JSON { c.JSON(200, s.bot.CQSendPrivateMessage(uid, gjson.Parse(msg))) return } @@ -245,8 +245,8 @@ func (s *httpServer) SendPrivateMessage(c *gin.Context) { func (s *httpServer) SendGroupMessage(c *gin.Context) { gid, _ := strconv.ParseInt(getParam(c, "group_id"), 10, 64) - msg := getParam(c, "message") - if gjson.Valid(msg) { + msg, t := getParamWithType(c, "message") + if t == gjson.JSON { c.JSON(200, s.bot.CQSendGroupMessage(gid, gjson.Parse(msg))) return } @@ -372,14 +372,19 @@ func getParamOrDefault(c *gin.Context, k, def string) string { } func getParam(c *gin.Context, k string) string { + p, _ := getParamWithType(c, k) + return p +} + +func getParamWithType(c *gin.Context, k string) (string, gjson.Type) { if q := c.Query(k); q != "" { - return q + return q, gjson.Null } if c.Request.Method == "POST" { if h := c.Request.Header.Get("Content-Type"); h != "" { if h == "application/x-www-form-urlencoded" { if p, ok := c.GetPostForm(k); ok { - return p + return p, gjson.Null } } if h == "application/json" { @@ -388,20 +393,20 @@ func getParam(c *gin.Context, k string) string { if res.Exists() { switch res.Type { case gjson.JSON: - return res.Raw + return res.Raw, gjson.JSON case gjson.String: - return res.Str + return res.Str, gjson.String case gjson.Number: - return strconv.FormatInt(res.Int(), 10) // 似乎没有需要接受 float 类型的api + return strconv.FormatInt(res.Int(), 10), gjson.Number // 似乎没有需要接受 float 类型的api case gjson.True: - return "true" + return "true", gjson.True case gjson.False: - return "false" + return "false", gjson.False } } } } } } - return "" + return "", gjson.Null }