1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-07 04:23:49 +08:00

Merge pull request #713 from scjtqs/origin/dev

添加cli命令行 -c 指定加载配置文件
This commit is contained in:
Mrs4s 2021-03-17 10:33:59 +08:00 committed by GitHub
commit 965c5377eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 49 deletions

View File

@ -1,8 +1,14 @@
package global package global
import ( import (
"errors"
"os" "os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/hjson/hjson-go" "github.com/hjson/hjson-go"
@ -12,6 +18,9 @@ import (
var json = jsoniter.ConfigCompatibleWithStandardLibrary var json = jsoniter.ConfigCompatibleWithStandardLibrary
var currentPath = getCurrentPath()
var DefaultConfFile = path.Join(currentPath, "config.hjson")
// DefaultConfigWithComments 为go-cqhttp的默认配置文件 // DefaultConfigWithComments 为go-cqhttp的默认配置文件
var DefaultConfigWithComments = ` var DefaultConfigWithComments = `
/* /*
@ -312,3 +321,33 @@ func (c *JSONConfig) Save(path string) error {
} }
return WriteAllText(path, string(data)) return WriteAllText(path, string(data))
} }
// getCurrentPath 获取当前文件的路径直接返回string
func getCurrentPath() string {
cwd, e := GetCurrentPath()
if e != nil {
panic(e)
}
return cwd
}
// GetCurrentPath 预留,获取当前目录地址
func GetCurrentPath() (string, error) {
file, err := exec.LookPath(os.Args[0])
if err != nil {
return "", err
}
fpath, err := filepath.Abs(file)
if err != nil {
return "", err
}
if runtime.GOOS == "windows" {
// fpath = strings.Replace(fpath, "\\", "/", -1)
fpath = strings.ReplaceAll(fpath, "\\", "/")
}
i := strings.LastIndex(fpath, "/")
if i < 0 {
return "", errors.New("system/path_error,Can't find '/' or '\\'")
}
return fpath[0 : i+1], nil
}

25
main.go
View File

@ -43,14 +43,20 @@ import (
var json = jsoniter.ConfigCompatibleWithStandardLibrary var json = jsoniter.ConfigCompatibleWithStandardLibrary
var conf *global.JSONConfig var conf *global.JSONConfig
var isFastStart = false var isFastStart = false
var c string
var d bool var d bool
var h bool var h bool
func init() { func init() {
var debug bool
flag.StringVar(&c, "c", global.DefaultConfFile, "configuration filename default is config.hjson")
flag.BoolVar(&d, "d", false, "running as a daemon") flag.BoolVar(&d, "d", false, "running as a daemon")
flag.BoolVar(&debug, "D", false, "debug mode")
flag.BoolVar(&h, "h", false, "this help") flag.BoolVar(&h, "h", false, "this help")
flag.Parse() flag.Parse()
// 通过-c 参数替换 配置文件路径
global.DefaultConfFile = c
logFormatter := &easy.Formatter{ logFormatter := &easy.Formatter{
TimestampFormat: "2006-01-02 15:04:05", TimestampFormat: "2006-01-02 15:04:05",
LogFormat: "[%time%] [%lvl%]: %msg% \n", LogFormat: "[%time%] [%lvl%]: %msg% \n",
@ -66,6 +72,9 @@ func init() {
os.Exit(1) os.Exit(1)
} }
if debug {
conf.Debug = true
}
// 在debug模式下,将在标准输出中打印当前执行行数 // 在debug模式下,将在标准输出中打印当前执行行数
if conf.Debug { if conf.Debug {
log.SetReportCaller(true) log.SetReportCaller(true)
@ -97,8 +106,8 @@ func init() {
goConf.ReverseServers[0].ReverseEventURL = conf.WSReverseEventURL goConf.ReverseServers[0].ReverseEventURL = conf.WSReverseEventURL
goConf.ReverseServers[0].ReverseReconnectInterval = conf.WSReverseReconnectInterval goConf.ReverseServers[0].ReverseReconnectInterval = conf.WSReverseReconnectInterval
} }
if err := goConf.Save("config.hjson"); err != nil { if err := goConf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
} }
_ = os.Remove("cqhttp.json") _ = os.Remove("cqhttp.json")
} }
@ -159,7 +168,7 @@ func main() {
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
} }
if conf.Uin == 0 || (conf.Password == "" && conf.PasswordEncrypted == "") { if conf.Uin == 0 || (conf.Password == "" && conf.PasswordEncrypted == "") {
log.Warnf("请修改 config.hjson 以添加账号密码.") log.Warnf("请修改 %s 以添加账号密码.", global.DefaultConfFile)
if !isFastStart { if !isFastStart {
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
} }
@ -194,7 +203,7 @@ func main() {
global.PasswordHash = md5.Sum([]byte(conf.Password)) global.PasswordHash = md5.Sum([]byte(conf.Password))
conf.Password = "" conf.Password = ""
conf.PasswordEncrypted = "AES:" + PasswordHashEncrypt(global.PasswordHash[:], byteKey) conf.PasswordEncrypted = "AES:" + PasswordHashEncrypt(global.PasswordHash[:], byteKey)
_ = conf.Save("config.hjson") _ = conf.Save(global.DefaultConfFile)
} }
if conf.PasswordEncrypted != "" { if conf.PasswordEncrypted != "" {
if len(byteKey) == 0 { if len(byteKey) == 0 {
@ -223,7 +232,7 @@ func main() {
passwordHash := md5.Sum([]byte(password)) passwordHash := md5.Sum([]byte(password))
newPasswordHash := PasswordHashEncrypt(passwordHash[:], byteKey) newPasswordHash := PasswordHashEncrypt(passwordHash[:], byteKey)
conf.PasswordEncrypted = "AES:" + newPasswordHash conf.PasswordEncrypted = "AES:" + newPasswordHash
_ = conf.Save("config.hjson") _ = conf.Save(global.DefaultConfFile)
log.Debug("密码加密方案升级完成") log.Debug("密码加密方案升级完成")
} }
@ -518,15 +527,15 @@ func getConfig() *global.JSONConfig {
conf.HTTPConfig.PostUrls[post] = os.Getenv("HTTP_SECRET") conf.HTTPConfig.PostUrls[post] = os.Getenv("HTTP_SECRET")
} }
default: default:
conf = global.LoadConfig("config.hjson") conf = global.LoadConfig(global.DefaultConfFile)
} }
if conf == nil { if conf == nil {
err := global.WriteAllText("config.hjson", global.DefaultConfigWithComments) err := global.WriteAllText(global.DefaultConfFile, global.DefaultConfigWithComments)
if err != nil { if err != nil {
log.Fatalf("创建默认配置文件时出现错误: %v", err) log.Fatalf("创建默认配置文件时出现错误: %v", err)
return nil return nil
} }
log.Infof("默认配置文件已生成, 请编辑 config.hjson 后重启程序.") log.Infof("默认配置文件已生成, 请编辑 %s 后重启程序.", global.DefaultConfFile)
if !isFastStart { if !isFastStart {
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
} }

View File

@ -361,7 +361,7 @@ func GetConf() *global.JSONConfig {
if JSONConfig != nil { if JSONConfig != nil {
return JSONConfig return JSONConfig
} }
conf := global.LoadConfig("config.hjson") conf := global.LoadConfig(global.DefaultConfFile)
return conf return conf
} }
@ -559,9 +559,9 @@ func AdminDoConfigBase(s *webServer, c *gin.Context) {
conf.EnableDB = false conf.EnableDB = false
} }
conf.AccessToken = c.PostForm("access_token") conf.AccessToken = c.PostForm("access_token")
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))
@ -584,9 +584,9 @@ func AdminDoConfigHTTP(s *webServer, c *gin.Context) {
if c.PostForm("post_url") != "" { if c.PostForm("post_url") != "" {
conf.HTTPConfig.PostUrls[c.PostForm("post_url")] = c.PostForm("post_secret") conf.HTTPConfig.PostUrls[c.PostForm("post_url")] = c.PostForm("post_secret")
} }
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))
@ -604,9 +604,9 @@ func AdminDoConfigWS(s *webServer, c *gin.Context) {
} else { } else {
conf.WSConfig.Enabled = false conf.WSConfig.Enabled = false
} }
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))
@ -626,9 +626,9 @@ func AdminDoConfigReverseWS(s *webServer, c *gin.Context) {
} else { } else {
conf.ReverseServers[0].Enabled = false conf.ReverseServers[0].Enabled = false
} }
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))
@ -641,13 +641,13 @@ func AdminDoConfigJSON(s *webServer, c *gin.Context) {
JSON := c.PostForm("json") JSON := c.PostForm("json")
err := json.Unmarshal([]byte(JSON), &conf) err := json.Unmarshal([]byte(JSON), &conf)
if err != nil { if err != nil {
log.Warnf("尝试加载配置文件 %v 时出现错误: %v", "config.hjson", err) log.Warnf("尝试加载配置文件 %v 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
return return
} }
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))

View File

@ -2,12 +2,9 @@
package server package server
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"runtime"
"strings" "strings"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
@ -51,23 +48,3 @@ func Daemon() {
func savePid(path string, data string) error { func savePid(path string, data string) error {
return global.WriteAllText(path, data) return global.WriteAllText(path, data)
} }
// GetCurrentPath 预留,获取当前目录地址
func GetCurrentPath() (string, error) {
file, err := exec.LookPath(os.Args[0])
if err != nil {
return "", err
}
path, err := filepath.Abs(file)
if err != nil {
return "", err
}
if runtime.GOOS == "windows" {
path = strings.ReplaceAll(path, "\\", "/")
}
i := strings.LastIndex(path, "/")
if i < 0 {
return "", errors.New("system/path_error,Can't find '/' or '\\'")
}
return path[0 : i+1], nil
}