1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-04 19:17:37 +08:00

server: support unix socket

Fixes #1415
This commit is contained in:
wdvxdr 2022-03-23 21:06:32 +08:00
parent d42d8dd395
commit 40a765b117
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6
2 changed files with 72 additions and 15 deletions

View File

@ -2,8 +2,10 @@ package server
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
@ -65,6 +67,7 @@ type HTTPClient struct {
filter string
apiPort int
timeout int32
client *http.Client
MaxRetries uint64
RetriesInterval uint64
}
@ -77,8 +80,7 @@ type httpCtx struct {
const httpDefault = `
- http: # HTTP 通信设置
host: 127.0.0.1 # 服务端监听地址
port: 5700 # 服务端监听端口
address: 0.0.0.0:5700 # HTTP监听地址
timeout: 5 # 反向 HTTP 超时时间, 单位秒<5 时将被忽略
long-polling: # 长轮询拓展
enabled: false # 是否开启
@ -300,8 +302,30 @@ func (c HTTPClient) Run() {
if c.timeout < 5 {
c.timeout = 5
}
rawAddress := c.addr
network, address := resolveURI(c.addr)
client := &http.Client{
Timeout: time.Second * time.Duration(c.timeout),
Transport: &http.Transport{
DialContext: func(_ context.Context, _, addr string) (net.Conn, error) {
if network == "unix" {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
filepath, err := base64.RawURLEncoding.DecodeString(host)
if err == nil {
addr = string(filepath)
}
}
return net.Dial(network, addr)
},
},
}
c.addr = address // clean path
c.client = client
log.Infof("HTTP POST上报器已启动: %v", rawAddress)
c.bot.OnEventPush(c.onBotPushEvent)
log.Infof("HTTP POST上报器已启动: %v", c.addr)
}
func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
@ -313,7 +337,6 @@ func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
}
}
client := http.Client{Timeout: time.Second * time.Duration(c.timeout)}
header := make(http.Header)
header.Set("X-Self-ID", strconv.FormatInt(c.bot.Client.Uin, 10))
header.Set("User-Agent", "CQHttp/4.15.0")
@ -338,7 +361,7 @@ func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
}
req.Header = header
res, err = client.Do(req)
res, err = c.client.Do(req)
if res != nil {
//goland:noinspection GoDeferInLoop
defer res.Body.Close()

View File

@ -2,6 +2,7 @@ package server
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"net"
@ -77,9 +78,7 @@ var upgrader = websocket.Upgrader{
const wsDefault = ` # 正向WS设置
- ws:
# 正向WS服务器监听地址
host: 127.0.0.1
# 正向WS服务器监听端口
port: 6700
address: 0.0.0.0:8080
middlewares:
<<: *default # 引用默认中间件
`
@ -213,8 +212,25 @@ func runWSClient(b *coolq.CQBot, node yaml.Node) {
}
}
func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, url)
func resolveURI(addr string) (network, address string) {
network, address = "tcp", addr
uri, err := url.Parse(addr)
if err == nil && uri.Scheme != "" {
scheme, ext, _ := strings.Cut(uri.Scheme, "+")
if ext != "" {
network = ext
uri.Scheme = scheme // remove `+unix`/`+tcp4`
if ext == "unix" {
uri.Host = base64.StdEncoding.EncodeToString([]byte(uri.Host + uri.Path))
}
address = uri.String()
}
}
return
}
func (c *websocketClient) connect(typ, addr string, conptr **wsConn) {
log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, addr)
header := http.Header{
"X-Client-Role": []string{typ},
"X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)},
@ -223,12 +239,30 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
if c.token != "" {
header["Authorization"] = []string{"Token " + c.token}
}
conn, _, err := websocket.DefaultDialer.Dial(url, header) // nolint
network, address := resolveURI(addr)
dialer := websocket.Dialer{
NetDial: func(_, addr string) (net.Conn, error) {
if network == "unix" {
host, _, err := net.SplitHostPort(addr)
if err != nil {
log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, url, err)
host = addr
}
filepath, err := base64.RawURLEncoding.DecodeString(host)
if err == nil {
addr = string(filepath)
}
}
return net.Dial(network, addr) // support unix socket transport
},
}
conn, _, err := dialer.Dial(address, header) // nolint
if err != nil {
log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, addr, err)
if c.reconnectInterval != 0 {
time.Sleep(c.reconnectInterval)
c.connect(typ, url, conptr)
c.connect(typ, addr, conptr)
}
return
}
@ -242,7 +276,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
}
}
log.Infof("已连接到反向WebSocket %s服务器 %v", typ, url)
log.Infof("已连接到反向WebSocket %s服务器 %v", typ, addr)
var wrappedConn *wsConn
if conptr != nil && *conptr != nil {
@ -261,7 +295,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
}
if typ != "Event" {
go c.listenAPI(typ, url, wrappedConn)
go c.listenAPI(typ, addr, wrappedConn)
}
}