1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-05 03:23:49 +08:00

server: support unix socket

Fixes #1415
This commit is contained in:
wdvxdr 2022-03-23 21:06:32 +08:00
parent d42d8dd395
commit 40a765b117
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6
2 changed files with 72 additions and 15 deletions

View File

@ -2,8 +2,10 @@ package server
import ( import (
"bytes" "bytes"
"context"
"crypto/hmac" "crypto/hmac"
"crypto/sha1" "crypto/sha1"
"encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -65,6 +67,7 @@ type HTTPClient struct {
filter string filter string
apiPort int apiPort int
timeout int32 timeout int32
client *http.Client
MaxRetries uint64 MaxRetries uint64
RetriesInterval uint64 RetriesInterval uint64
} }
@ -77,8 +80,7 @@ type httpCtx struct {
const httpDefault = ` const httpDefault = `
- http: # HTTP 通信设置 - http: # HTTP 通信设置
host: 127.0.0.1 # 服务端监听地址 address: 0.0.0.0:5700 # HTTP监听地址
port: 5700 # 服务端监听端口
timeout: 5 # 反向 HTTP 超时时间, 单位秒<5 时将被忽略 timeout: 5 # 反向 HTTP 超时时间, 单位秒<5 时将被忽略
long-polling: # 长轮询拓展 long-polling: # 长轮询拓展
enabled: false # 是否开启 enabled: false # 是否开启
@ -300,8 +302,30 @@ func (c HTTPClient) Run() {
if c.timeout < 5 { if c.timeout < 5 {
c.timeout = 5 c.timeout = 5
} }
rawAddress := c.addr
network, address := resolveURI(c.addr)
client := &http.Client{
Timeout: time.Second * time.Duration(c.timeout),
Transport: &http.Transport{
DialContext: func(_ context.Context, _, addr string) (net.Conn, error) {
if network == "unix" {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
filepath, err := base64.RawURLEncoding.DecodeString(host)
if err == nil {
addr = string(filepath)
}
}
return net.Dial(network, addr)
},
},
}
c.addr = address // clean path
c.client = client
log.Infof("HTTP POST上报器已启动: %v", rawAddress)
c.bot.OnEventPush(c.onBotPushEvent) c.bot.OnEventPush(c.onBotPushEvent)
log.Infof("HTTP POST上报器已启动: %v", c.addr)
} }
func (c *HTTPClient) onBotPushEvent(e *coolq.Event) { func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
@ -313,7 +337,6 @@ func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
} }
} }
client := http.Client{Timeout: time.Second * time.Duration(c.timeout)}
header := make(http.Header) header := make(http.Header)
header.Set("X-Self-ID", strconv.FormatInt(c.bot.Client.Uin, 10)) header.Set("X-Self-ID", strconv.FormatInt(c.bot.Client.Uin, 10))
header.Set("User-Agent", "CQHttp/4.15.0") header.Set("User-Agent", "CQHttp/4.15.0")
@ -338,7 +361,7 @@ func (c *HTTPClient) onBotPushEvent(e *coolq.Event) {
} }
req.Header = header req.Header = header
res, err = client.Do(req) res, err = c.client.Do(req)
if res != nil { if res != nil {
//goland:noinspection GoDeferInLoop //goland:noinspection GoDeferInLoop
defer res.Body.Close() defer res.Body.Close()

View File

@ -2,6 +2,7 @@ package server
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
@ -77,9 +78,7 @@ var upgrader = websocket.Upgrader{
const wsDefault = ` # 正向WS设置 const wsDefault = ` # 正向WS设置
- ws: - ws:
# 正向WS服务器监听地址 # 正向WS服务器监听地址
host: 127.0.0.1 address: 0.0.0.0:8080
# 正向WS服务器监听端口
port: 6700
middlewares: middlewares:
<<: *default # 引用默认中间件 <<: *default # 引用默认中间件
` `
@ -213,8 +212,25 @@ func runWSClient(b *coolq.CQBot, node yaml.Node) {
} }
} }
func (c *websocketClient) connect(typ, url string, conptr **wsConn) { func resolveURI(addr string) (network, address string) {
log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, url) network, address = "tcp", addr
uri, err := url.Parse(addr)
if err == nil && uri.Scheme != "" {
scheme, ext, _ := strings.Cut(uri.Scheme, "+")
if ext != "" {
network = ext
uri.Scheme = scheme // remove `+unix`/`+tcp4`
if ext == "unix" {
uri.Host = base64.StdEncoding.EncodeToString([]byte(uri.Host + uri.Path))
}
address = uri.String()
}
}
return
}
func (c *websocketClient) connect(typ, addr string, conptr **wsConn) {
log.Infof("开始尝试连接到反向WebSocket %s服务器: %v", typ, addr)
header := http.Header{ header := http.Header{
"X-Client-Role": []string{typ}, "X-Client-Role": []string{typ},
"X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)}, "X-Self-ID": []string{strconv.FormatInt(c.bot.Client.Uin, 10)},
@ -223,12 +239,30 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
if c.token != "" { if c.token != "" {
header["Authorization"] = []string{"Token " + c.token} header["Authorization"] = []string{"Token " + c.token}
} }
conn, _, err := websocket.DefaultDialer.Dial(url, header) // nolint
network, address := resolveURI(addr)
dialer := websocket.Dialer{
NetDial: func(_, addr string) (net.Conn, error) {
if network == "unix" {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
filepath, err := base64.RawURLEncoding.DecodeString(host)
if err == nil {
addr = string(filepath)
}
}
return net.Dial(network, addr) // support unix socket transport
},
}
conn, _, err := dialer.Dial(address, header) // nolint
if err != nil { if err != nil {
log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, url, err) log.Warnf("连接到反向WebSocket %s服务器 %v 时出现错误: %v", typ, addr, err)
if c.reconnectInterval != 0 { if c.reconnectInterval != 0 {
time.Sleep(c.reconnectInterval) time.Sleep(c.reconnectInterval)
c.connect(typ, url, conptr) c.connect(typ, addr, conptr)
} }
return return
} }
@ -242,7 +276,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
} }
} }
log.Infof("已连接到反向WebSocket %s服务器 %v", typ, url) log.Infof("已连接到反向WebSocket %s服务器 %v", typ, addr)
var wrappedConn *wsConn var wrappedConn *wsConn
if conptr != nil && *conptr != nil { if conptr != nil && *conptr != nil {
@ -261,7 +295,7 @@ func (c *websocketClient) connect(typ, url string, conptr **wsConn) {
} }
if typ != "Event" { if typ != "Event" {
go c.listenAPI(typ, url, wrappedConn) go c.listenAPI(typ, addr, wrappedConn)
} }
} }