1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-04 19:17:37 +08:00

fix: fix ws-reverse connect info

This commit is contained in:
wdvxdr 2021-11-17 11:41:28 +08:00
parent ea2bda523f
commit dba2bf2881
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6

View File

@ -40,14 +40,14 @@ type webSocketServer struct {
// websocketClient WebSocket客户端实例 // websocketClient WebSocket客户端实例
type websocketClient struct { type websocketClient struct {
bot *coolq.CQBot bot *coolq.CQBot
conf *config.WebsocketReverse
mu sync.Mutex mu sync.Mutex
universal *wsConn universal *wsConn
event *wsConn event *wsConn
token string token string
filter string filter string
reconnectInterval time.Duration
limiter api.Handler
} }
type wsConn struct { type wsConn struct {
@ -81,10 +81,8 @@ func runWSServer(b *coolq.CQBot, node yaml.Node) {
mux.HandleFunc("/event", s.event) mux.HandleFunc("/event", s.event)
mux.HandleFunc("/api", s.api) mux.HandleFunc("/api", s.api)
mux.HandleFunc("/", s.any) mux.HandleFunc("/", s.any)
go func() {
log.Infof("CQ WebSocket 服务器已启动: %v", addr) log.Infof("CQ WebSocket 服务器已启动: %v", addr)
log.Fatal(http.ListenAndServe(addr, &mux)) log.Fatal(http.ListenAndServe(addr, &mux))
}()
} }
// runWSClient 运行一个反向向WS client // runWSClient 运行一个反向向WS client
@ -100,26 +98,33 @@ func runWSClient(b *coolq.CQBot, node yaml.Node) {
c := &websocketClient{ c := &websocketClient{
bot: b, bot: b,
conf: &conf,
token: conf.AccessToken, token: conf.AccessToken,
filter: conf.Filter, filter: conf.Filter,
} }
filter.Add(c.filter) filter.Add(c.filter)
if c.conf.Universal != "" { if conf.ReconnectInterval != 0 {
c.reconnectInterval = time.Duration(conf.ReconnectInterval) * time.Millisecond
}
if conf.RateLimit.Enabled {
c.limiter = rateLimit(conf.RateLimit.Frequency, conf.RateLimit.Bucket)
}
if conf.Universal != "" {
c.connect("Universal", conf.Universal, &c.universal) c.connect("Universal", conf.Universal, &c.universal)
} else { c.bot.OnEventPush(c.onBotPushEvent("Universal", conf.Universal, &c.universal))
if c.conf.API != "" { return // 连接到 Universal 后, 不再连接其他
}
if conf.API != "" {
c.connect("API", conf.API, nil) c.connect("API", conf.API, nil)
} }
if c.conf.Event != "" { if conf.Event != "" {
c.connect("Event", conf.Event, &c.event) c.connect("Event", conf.Event, &c.event)
c.bot.OnEventPush(c.onBotPushEvent("Event", conf.Event, &c.event))
} }
} }
c.bot.OnEventPush(c.onBotPushEvent)
}
func (c *websocketClient) connect(typ, url string, conptr **wsConn) { func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, c.conf.API) log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, url)
header := http.Header{ header := http.Header{
"X-Client-Role": []string{typ}, "X-Client-Role": []string{typ},
"X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)},
@ -130,9 +135,9 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
} }
conn, _, err := websocket.Dial(context.Background(), url, &websocket.DialOptions{HTTPHeader: header}) // nolint conn, _, err := websocket.Dial(context.Background(), url, &websocket.DialOptions{HTTPHeader: header}) // nolint
if err != nil { if err != nil {
log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, c.conf.API, err) log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, url, err)
if c.conf.ReconnectInterval != 0 { if c.reconnectInterval != 0 {
time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) time.Sleep(c.reconnectInterval)
c.connect(typ, url, conptr) c.connect(typ, url, conptr)
} }
return return
@ -147,37 +152,32 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
} }
} }
log.Infof("已连接到反向WebSocket %s服务器 %v", typ, c.conf.API) log.Infof("已连接到反向WebSocket %s服务器 %v", typ, url)
wrappedConn := &wsConn{Conn: conn, apiCaller: api.NewCaller(c.bot)} wrappedConn := &wsConn{Conn: conn, apiCaller: api.NewCaller(c.bot)}
if c.conf.RateLimit.Enabled { if c.limiter != nil {
wrappedConn.apiCaller.Use(rateLimit(c.conf.RateLimit.Frequency, c.conf.RateLimit.Bucket)) wrappedConn.apiCaller.Use(c.limiter)
} }
if conptr != nil { if conptr != nil {
*conptr = wrappedConn *conptr = wrappedConn
} }
switch typ { go c.listenAPI(typ, url, wrappedConn)
case "API":
go c.listenAPI(wrappedConn, false)
case "Universal":
go c.listenAPI(wrappedConn, true)
}
} }
func (c *websocketClient) listenAPI(conn *wsConn, u bool) { func (c *websocketClient) listenAPI(typ, url string, conn *wsConn) {
defer func() { _ = conn.Close(websocket.StatusNormalClosure, "") }() defer func() { _ = conn.Close(websocket.StatusNormalClosure, "") }()
conn.Conn.SetReadLimit(1024 * 1024 * 128) conn.Conn.SetReadLimit(1024 * 1024 * 128)
for { for {
buffer := global.NewBuffer() buffer := global.NewBuffer()
t, reader, err := conn.Conn.Reader(context.Background()) t, reader, err := conn.Conn.Reader(context.Background())
if err != nil { if err != nil {
log.Warnf("监听反向WS API时出现错误: %v", err) log.Warnf("监听反向WS %s时出现错误: %v", typ, err)
break break
} }
_, err = buffer.ReadFrom(reader) _, err = buffer.ReadFrom(reader)
if err != nil { if err != nil {
log.Warnf("监听反向WS API时出现错误: %v", err) log.Warnf("监听反向WS %s时出现错误: %v", typ, err)
break break
} }
if t == websocket.MessageText { if t == websocket.MessageText {
@ -189,48 +189,36 @@ func (c *websocketClient) listenAPI(conn *wsConn, u bool) {
global.PutBuffer(buffer) global.PutBuffer(buffer)
} }
} }
if c.conf.ReconnectInterval != 0 { if c.reconnectInterval != 0 {
time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval)) time.Sleep(c.reconnectInterval)
if !u { if typ == "API" { // Universal 不重连,避免多次重连
go c.connect("API", c.conf.API, nil) go c.connect(typ, url, nil)
} }
} }
} }
func (c *websocketClient) onBotPushEvent(e *coolq.Event) { func (c *websocketClient) onBotPushEvent(typ, url string, conn **wsConn) func(e *coolq.Event) {
filter := filter.Find(c.filter) return func(e *coolq.Event) {
if filter != nil && !filter.Eval(gjson.Parse(e.JSONString())) {
log.Debugf("上报Event %s 到 WS服务器 时被过滤.", e.JSONBytes())
return
}
push := func(conn *wsConn, reconnect func()) {
log.Debugf("向WS服务器推送Event: %s", e.JSONBytes())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
if err := conn.Write(ctx, websocket.MessageText, e.JSONBytes()); err != nil {
log.Warnf("向WS服务器推送 Event 时出现错误: %v", err)
_ = conn.Close(websocket.StatusNormalClosure, "")
if c.conf.ReconnectInterval != 0 {
time.Sleep(time.Millisecond * time.Duration(c.conf.ReconnectInterval))
reconnect()
}
}
}
connect := func(typ, url string, conptr **wsConn) func() {
return func() {
c.connect(typ, url, conptr)
}
}
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.event != nil { flt := filter.Find(c.filter)
push(c.event, connect("Event", c.conf.Event, &c.event)) if flt != nil && !flt.Eval(gjson.Parse(e.JSONString())) {
log.Debugf("上报Event %s 到 WS服务器 时被过滤.", e.JSONBytes())
return
}
log.Debugf("向反向WS %s服务器推送Event: %s", typ, e.JSONBytes())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
if err := (*conn).Write(ctx, websocket.MessageText, e.JSONBytes()); err != nil {
log.Warnf("向反向WS %s服务器推送 Event 时出现错误: %v", typ, err)
_ = (*conn).Close(websocket.StatusNormalClosure, "")
if c.reconnectInterval != 0 {
time.Sleep(c.reconnectInterval)
c.connect(typ, url, conn)
}
} }
if c.universal != nil {
push(c.universal, connect("Universal", c.conf.Universal, &c.universal))
} }
} }