diff --git a/docs/config.md b/docs/config.md index d3a2df2..93a299b 100644 --- a/docs/config.md +++ b/docs/config.md @@ -167,6 +167,16 @@ database: # 数据库相关设置 > 注5:关于MIME扫描, 详见[MIME](file.md#MIME) +### 环境变量 + +go-cqhttp 配置文件可以使用占位符来读取**环境变量**的值。 + +```yaml +account: # 账号相关 + uin: ${CQ_UIN} # 读取环境变量 CQ_UIN + password: ${CQ_PASSWORD} # 读取环境变量 CQ_UIN +``` + ## 在线状态 | 状态 | 值 | diff --git a/internal/param/param.go b/internal/param/param.go index 0b70967..c38bf1c 100644 --- a/internal/param/param.go +++ b/internal/param/param.go @@ -3,7 +3,6 @@ package param import ( "math" - "reflect" "regexp" "strings" "sync" @@ -94,37 +93,3 @@ func Base64DecodeString(s string) ([]byte, error) { n, err := e.Decode(dst, utils.S2B(s)) return dst[:n], err } - -// SetAtDefault 在变量 variable 为默认值 defaultValue 的时候修改为 value -func SetAtDefault(variable, value, defaultValue interface{}) { - v := reflect.ValueOf(variable) - v2 := reflect.ValueOf(value) - if v.Kind() != reflect.Ptr || v.IsNil() { - return - } - v = v.Elem() - if v.Interface() != defaultValue { - return - } - if v.Kind() != v2.Kind() { - return - } - v.Set(v2) -} - -// SetExcludeDefault 在目标值 value 不为默认值 defaultValue 时修改 variable 为 value -func SetExcludeDefault(variable, value, defaultValue interface{}) { - v := reflect.ValueOf(variable) - v2 := reflect.ValueOf(value) - if v.Kind() != reflect.Ptr || v.IsNil() { - return - } - v = v.Elem() - if reflect.Indirect(v2).Interface() != defaultValue { - return - } - if v.Kind() != v2.Kind() { - return - } - v.Set(v2) -} diff --git a/modules/config/config.go b/modules/config/config.go index 89ee1c9..e9b049d 100644 --- a/modules/config/config.go +++ b/modules/config/config.go @@ -7,14 +7,11 @@ import ( "fmt" "os" "regexp" - "strconv" "strings" "sync" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" - - "github.com/Mrs4s/go-cqhttp/internal/param" ) // defaultConfig 默认配置文件 @@ -74,9 +71,8 @@ type Config struct { // Server 的简介和初始配置 type Server struct { - Brief string - Default string - ParseEnv func() (string, *yaml.Node) + Brief string + Default string } // LevelDBConfig leveldb 相关配置 @@ -93,52 +89,17 @@ type MongoDBConfig struct { // Parse 从默认配置文件路径中获取 func Parse(path string) *Config { - fromEnv := os.Getenv("GCQ_UIN") != "" - file, err := os.ReadFile(path) config := &Config{} if err == nil { err = yaml.NewDecoder(strings.NewReader(expand(string(file), os.Getenv))).Decode(config) - if err != nil && !fromEnv { + if err != nil { log.Fatal("配置文件不合法!", err) } - } else if !fromEnv { + } else { generateConfig() os.Exit(0) } - if fromEnv { - // type convert tools - toInt64 := func(str string) int64 { - i, _ := strconv.ParseInt(str, 10, 64) - return i - } - - // load config from environment variable - param.SetAtDefault(&config.Account.Uin, toInt64(os.Getenv("GCQ_UIN")), int64(0)) - param.SetAtDefault(&config.Account.Password, os.Getenv("GCQ_PWD"), "") - param.SetAtDefault(&config.Account.Status, int32(toInt64(os.Getenv("GCQ_STATUS"))), int32(0)) - param.SetAtDefault(&config.Account.ReLogin.Disabled, !param.EnsureBool(os.Getenv("GCQ_RELOGIN_DISABLED"), true), false) - param.SetAtDefault(&config.Account.ReLogin.Delay, uint(toInt64(os.Getenv("GCQ_RELOGIN_DELAY"))), uint(0)) - param.SetAtDefault(&config.Account.ReLogin.MaxTimes, uint(toInt64(os.Getenv("GCQ_RELOGIN_MAX_TIMES"))), uint(0)) - dbConf := &LevelDBConfig{Enable: param.EnsureBool(os.Getenv("GCQ_LEVELDB"), true)} - if config.Database == nil { - config.Database = make(map[string]yaml.Node) - } - config.Database["leveldb"] = func() yaml.Node { - n := &yaml.Node{} - _ = n.Encode(dbConf) - return *n - }() - - for _, s := range serverconfs { - if s.ParseEnv != nil { - name, node := s.ParseEnv() - if node != nil { - config.Servers = append(config.Servers, map[string]yaml.Node{name: *node}) - } - } - } - } return config } diff --git a/modules/config/config_test.go b/modules/config/config_test.go new file mode 100644 index 0000000..295a450 --- /dev/null +++ b/modules/config/config_test.go @@ -0,0 +1,30 @@ +package config + +import ( + "strings" + "testing" +) + +func Test_expand(t *testing.T) { + tests := []struct { + src string + mapping func(string) string + expected string + }{ + { + src: "foo: ${bar}", + mapping: strings.ToUpper, + expected: "foo: BAR", + }, + { + src: "$123", + mapping: strings.ToUpper, + expected: "$123", + }, + } + for i, tt := range tests { + if got := expand(tt.src, tt.mapping); got != tt.expected { + t.Errorf("testcase %d failed, expected %v but got %v", i, tt.expected, got) + } + } +} diff --git a/server/http.go b/server/http.go index 1df3cb5..dcde4f9 100644 --- a/server/http.go +++ b/server/http.go @@ -24,7 +24,6 @@ import ( "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" - "github.com/Mrs4s/go-cqhttp/internal/param" "github.com/Mrs4s/go-cqhttp/modules/api" "github.com/Mrs4s/go-cqhttp/modules/config" "github.com/Mrs4s/go-cqhttp/modules/filter" @@ -97,50 +96,8 @@ const httpDefault = ` # retries-interval: 1000 # 重试时间,单位毫秒,0 时立即 ` -func nilParseUint(s string, base int, bitSize int) *uint64 { - pu, err := strconv.ParseUint(s, base, bitSize) - if err != nil { - return nil - } - return &pu -} - -func readEnvConfig() (string, *yaml.Node) { - if s, ok := os.LookupEnv("GCQ_HTTP_PORT"); !ok || s == "" { - return "", nil - } - - // type convert tools - toInt64 := func(str string) int64 { - i, _ := strconv.ParseInt(str, 10, 64) - return i - } - accessTokenEnv := os.Getenv("GCQ_ACCESS_TOKEN") - node := &yaml.Node{} - httpConf := &HTTPServer{ - Host: "0.0.0.0", - Port: 5700, - MiddleWares: MiddleWares{ - AccessToken: accessTokenEnv, - }, - } - param.SetExcludeDefault(&httpConf.Disabled, param.EnsureBool(os.Getenv("GCQ_HTTP_DISABLE"), false), false) - param.SetExcludeDefault(&httpConf.Host, os.Getenv("GCQ_HTTP_HOST"), "") - param.SetExcludeDefault(&httpConf.Port, int(toInt64(os.Getenv("GCQ_HTTP_PORT"))), 0) - if os.Getenv("GCQ_HTTP_POST_URL") != "" { - httpConf.Post = append(httpConf.Post, httpServerPost{ - os.Getenv("GCQ_HTTP_POST_URL"), - os.Getenv("GCQ_HTTP_POST_SECRET"), - nilParseUint(os.Getenv("GCQ_HTTP_POST_MAXRETRIES"), 10, 64), - nilParseUint(os.Getenv("GCQ_HTTP_POST_RETRIESINTERVAL"), 10, 64), - }) - } - _ = node.Encode(httpConf) - return "http", node -} - func init() { - config.AddServer(&config.Server{Brief: "HTTP通信", Default: httpDefault, ParseEnv: readEnvConfig}) + config.AddServer(&config.Server{Brief: "HTTP通信", Default: httpDefault}) } var joinQuery = regexp.MustCompile(`\[(.+?),(.+?)]\.0`) diff --git a/server/websocket.go b/server/websocket.go index 964b926..20d3a73 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "net/http" - "os" "runtime/debug" "strconv" "strings" @@ -20,7 +19,6 @@ import ( "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" - "github.com/Mrs4s/go-cqhttp/internal/param" "github.com/Mrs4s/go-cqhttp/modules/api" "github.com/Mrs4s/go-cqhttp/modules/config" "github.com/Mrs4s/go-cqhttp/modules/filter" @@ -123,52 +121,10 @@ func init() { config.AddServer(&config.Server{ Brief: "正向 Websocket 通信", Default: wsDefault, - ParseEnv: func() (string, *yaml.Node) { - if os.Getenv("GCQ_WS_PORT") != "" { - // type convert tools - toInt64 := func(str string) int64 { - i, _ := strconv.ParseInt(str, 10, 64) - return i - } - accessTokenEnv := os.Getenv("GCQ_ACCESS_TOKEN") - node := &yaml.Node{} - wsServerConf := &WebsocketServer{ - Host: "0.0.0.0", - Port: 6700, - MiddleWares: MiddleWares{ - AccessToken: accessTokenEnv, - }, - } - param.SetExcludeDefault(&wsServerConf.Disabled, param.EnsureBool(os.Getenv("GCQ_WS_DISABLE"), false), false) - param.SetExcludeDefault(&wsServerConf.Host, os.Getenv("GCQ_WS_HOST"), "") - param.SetExcludeDefault(&wsServerConf.Port, int(toInt64(os.Getenv("GCQ_WS_PORT"))), 0) - _ = node.Encode(wsServerConf) - return "ws", node - } - return "", nil - }, }) config.AddServer(&config.Server{ Brief: "反向 Websocket 通信", Default: wsReverseDefault, - ParseEnv: func() (string, *yaml.Node) { - if os.Getenv("GCQ_RWS_API") != "" || os.Getenv("GCQ_RWS_EVENT") != "" || os.Getenv("GCQ_RWS_UNIVERSAL") != "" { - accessTokenEnv := os.Getenv("GCQ_ACCESS_TOKEN") - node := &yaml.Node{} - rwsConf := &WebsocketReverse{ - MiddleWares: MiddleWares{ - AccessToken: accessTokenEnv, - }, - } - param.SetExcludeDefault(&rwsConf.Disabled, param.EnsureBool(os.Getenv("GCQ_RWS_DISABLE"), false), false) - param.SetExcludeDefault(&rwsConf.API, os.Getenv("GCQ_RWS_API"), "") - param.SetExcludeDefault(&rwsConf.Event, os.Getenv("GCQ_RWS_EVENT"), "") - param.SetExcludeDefault(&rwsConf.Universal, os.Getenv("GCQ_RWS_UNIVERSAL"), "") - _ = node.Encode(rwsConf) - return "ws-reverse", node - } - return "", nil - }, }) }