diff --git a/coolq/bot.go b/coolq/bot.go index 92d3191..6cd9620 100644 --- a/coolq/bot.go +++ b/coolq/bot.go @@ -33,7 +33,7 @@ var json = jsoniter.ConfigCompatibleWithStandardLibrary type CQBot struct { Client *client.QQClient - events []func(*bytes.Buffer) + events []func(*Event) mu sync.Mutex db *leveldb.DB @@ -45,6 +45,33 @@ type CQBot struct { // MSG 消息Map type MSG map[string]interface{} +type Event struct { + RawMsg MSG + + once sync.Once + buffer *bytes.Buffer +} + +func (e *Event) marshal() { + if e.buffer == nil { + e.buffer = global.NewBuffer() + } + _ = json.NewEncoder(e.buffer).Encode(e.RawMsg) +} + +// JsonBytes return byes of json by lazy marshalling. +func (e *Event) JsonBytes() []byte { + e.once.Do(e.marshal) + return e.buffer.Bytes() +} + +// JsonString return string of json without extra allocation +// by lazy marshalling. +func (e *Event) JsonString() string { + e.once.Do(e.marshal) + return utils.B2S(e.buffer.Bytes()) +} + // ForceFragmented 是否启用强制分片 var ForceFragmented = false @@ -125,7 +152,7 @@ func NewQQBot(cli *client.QQClient, conf *config.Config) *CQBot { } // OnEventPush 注册事件上报函数 -func (bot *CQBot) OnEventPush(f func(buf *bytes.Buffer)) { +func (bot *CQBot) OnEventPush(f func(e *Event)) { bot.mu.Lock() defer bot.mu.Unlock() bot.events = append(bot.events, f) @@ -404,12 +431,11 @@ func (bot *CQBot) Release() { } func (bot *CQBot) dispatchEventMessage(m MSG) { - buf := global.NewBuffer() + event := &Event{RawMsg: m} wg := sync.WaitGroup{} wg.Add(len(bot.events)) - _ = json.NewEncoder(buf).Encode(m) for _, f := range bot.events { - go func(fn func(*bytes.Buffer)) { + go func(fn func(*Event)) { defer func() { wg.Done() if pan := recover(); pan != nil { @@ -418,7 +444,7 @@ func (bot *CQBot) dispatchEventMessage(m MSG) { }() start := time.Now() - fn(buf) + fn(event) end := time.Now() if end.Sub(start) > time.Second*5 { log.Debugf("警告: 事件处理耗时超过 5 秒 (%v), 请检查应用是否有堵塞.", end.Sub(start)) @@ -426,7 +452,7 @@ func (bot *CQBot) dispatchEventMessage(m MSG) { }(f) } wg.Wait() - global.PutBuffer(buf) + global.PutBuffer(event.buffer) } func (bot *CQBot) formatGroupMessage(m *message.GroupMessage) MSG { diff --git a/global/buffer.go b/global/buffer.go index b2b8578..73079f3 100644 --- a/global/buffer.go +++ b/global/buffer.go @@ -20,7 +20,7 @@ func NewBuffer() *bytes.Buffer { func PutBuffer(buf *bytes.Buffer) { // See https://golang.org/issue/23199 const maxSize = 1 << 16 - if buf.Cap() < maxSize { // 对于大Buffer直接丢弃 + if buf != nil && buf.Cap() < maxSize { // 对于大Buffer直接丢弃 buf.Reset() bufferPool.Put(buf) } diff --git a/server/http.go b/server/http.go index 7057db4..e663632 100644 --- a/server/http.go +++ b/server/http.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "context" "crypto/hmac" "crypto/sha1" @@ -194,24 +193,24 @@ func (c HTTPClient) Run() { log.Infof("HTTP POST上报器已启动: %v", c.addr) } -func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) { +func (c *HTTPClient) onBotPushEvent(e *coolq.Event) { var res string if c.filter != "" { filter := findFilter(c.filter) - if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) { - log.Debugf("上报Event %v 到 HTTP 服务器 %v 时被过滤.", c.addr, utils.B2S(m.Bytes())) + if filter != nil && !filter.Eval(gjson.Parse(e.JsonString())) { + log.Debugf("上报Event %v 到 HTTP 服务器 %s 时被过滤.", c.addr, e.JsonBytes()) return } } - err := gout.POST(c.addr).SetJSON(m.Bytes()).BindBody(&res).SetHeader(func() gout.H { + err := gout.POST(c.addr).SetJSON(e.JsonBytes()).BindBody(&res).SetHeader(func() gout.H { h := gout.H{ "X-Self-ID": c.bot.Client.Uin, "User-Agent": "CQHttp/4.15.0", } if c.secret != "" { mac := hmac.New(sha1.New, []byte(c.secret)) - _, err := mac.Write(m.Bytes()) + _, err := mac.Write(e.JsonBytes()) if err != nil { log.Error(err) return nil @@ -229,12 +228,12 @@ func (c *HTTPClient) onBotPushEvent(m *bytes.Buffer) { return nil }).Do() if err != nil { - log.Warnf("上报Event数据 %v 到 %v 失败: %v", utils.B2S(m.Bytes()), c.addr, err) + log.Warnf("上报Event数据 %s 到 %v 失败: %v", e.JsonBytes(), c.addr, err) return } - log.Debugf("上报Event数据 %v 到 %v", utils.B2S(m.Bytes()), c.addr) + log.Debugf("上报Event数据 %s 到 %v", e.JsonBytes(), c.addr) if gjson.Valid(res) { - c.bot.CQHandleQuickOperation(gjson.Parse(utils.B2S(m.Bytes())), gjson.Parse(res)) + c.bot.CQHandleQuickOperation(gjson.Parse(e.JsonString()), gjson.Parse(res)) } } diff --git a/server/websocket.go b/server/websocket.go index 03d7076..ff7c8c3 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -228,19 +228,19 @@ func (c *websocketClient) listenAPI(conn *webSocketConn, u bool) { } } -func (c *websocketClient) onBotPushEvent(m *bytes.Buffer) { +func (c *websocketClient) onBotPushEvent(e *coolq.Event) { filter := findFilter(c.filter) - if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) { - log.Debugf("上报Event %v 到 WS客户端 时被过滤.", utils.B2S(m.Bytes())) + if filter != nil && !filter.Eval(gjson.Parse(e.JsonString())) { + log.Debugf("上报Event %s 到 WS客户端 时被过滤.", e.JsonBytes()) return } if c.eventConn != nil { - log.Debugf("向WS服务器 %v 推送Event: %v", c.eventConn.RemoteAddr().String(), utils.B2S(m.Bytes())) + log.Debugf("向WS服务器 %v 推送Event: %s", c.eventConn.RemoteAddr().String(), e.JsonBytes()) conn := c.eventConn conn.Lock() defer conn.Unlock() _ = c.eventConn.SetWriteDeadline(time.Now().Add(time.Second * 15)) - if err := c.eventConn.WriteMessage(websocket.TextMessage, m.Bytes()); err != nil { + if err := c.eventConn.WriteMessage(websocket.TextMessage, e.JsonBytes()); err != nil { log.Warnf("向WS服务器 %v 推送Event时出现错误: %v", c.eventConn.RemoteAddr().String(), err) _ = c.eventConn.Close() if c.conf.ReconnectInterval != 0 { @@ -250,12 +250,12 @@ func (c *websocketClient) onBotPushEvent(m *bytes.Buffer) { } } if c.universalConn != nil { - log.Debugf("向WS服务器 %v 推送Event: %v", c.universalConn.RemoteAddr().String(), utils.B2S(m.Bytes())) + log.Debugf("向WS服务器 %v 推送Event: %s", c.universalConn.RemoteAddr().String(), e.JsonBytes()) conn := c.universalConn conn.Lock() defer conn.Unlock() _ = c.universalConn.SetWriteDeadline(time.Now().Add(time.Second * 15)) - if err := c.universalConn.WriteMessage(websocket.TextMessage, m.Bytes()); err != nil { + if err := c.universalConn.WriteMessage(websocket.TextMessage, e.JsonBytes()); err != nil { log.Warnf("向WS服务器 %v 推送Event时出现错误: %v", c.universalConn.RemoteAddr().String(), err) _ = c.universalConn.Close() if c.conf.ReconnectInterval != 0 { @@ -386,21 +386,21 @@ func (c *webSocketConn) handleRequest(_ *coolq.CQBot, payload []byte) { _ = c.WriteJSON(ret) } -func (s *webSocketServer) onBotPushEvent(m *bytes.Buffer) { +func (s *webSocketServer) onBotPushEvent(e *coolq.Event) { s.eventConnMutex.Lock() defer s.eventConnMutex.Unlock() filter := findFilter(s.filter) - if filter != nil && !filter.Eval(gjson.Parse(utils.B2S(m.Bytes()))) { - log.Debugf("上报Event %v 到 WS客户端 时被过滤.", utils.B2S(m.Bytes())) + if filter != nil && !filter.Eval(gjson.Parse(e.JsonString())) { + log.Debugf("上报Event %s 到 WS客户端 时被过滤.", e.JsonBytes()) return } for i, l := 0, len(s.eventConn); i < l; i++ { conn := s.eventConn[i] - log.Debugf("向WS客户端 %v 推送Event: %v", conn.RemoteAddr().String(), utils.B2S(m.Bytes())) + log.Debugf("向WS客户端 %v 推送Event: %s", conn.RemoteAddr().String(), e.JsonBytes()) conn.Lock() - if err := conn.WriteMessage(websocket.TextMessage, m.Bytes()); err != nil { + if err := conn.WriteMessage(websocket.TextMessage, e.JsonBytes()); err != nil { _ = conn.Close() next := i + 1 if next >= l {