1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-05 19:43:49 +08:00

添加cli命令行 -c 指定加载配置文件

This commit is contained in:
scjtqs 2021-03-12 17:36:15 +08:00
parent 2e5f33fb08
commit 678e5023cc
4 changed files with 86 additions and 63 deletions

View File

@ -1,17 +1,26 @@
package global package global
import ( import (
"errors"
"os" "os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/hjson/hjson-go" "github.com/hjson/hjson-go"
jsoniter "github.com/json-iterator/go" "github.com/json-iterator/go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var json = jsoniter.ConfigCompatibleWithStandardLibrary var json = jsoniter.ConfigCompatibleWithStandardLibrary
var currentPath=getCurrentPath()
var DefaultConfFile =path.Join(currentPath,"config.hjson")
// DefaultConfigWithComments 为go-cqhttp的默认配置文件 // DefaultConfigWithComments 为go-cqhttp的默认配置文件
var DefaultConfigWithComments = ` var DefaultConfigWithComments = `
/* /*
@ -312,3 +321,34 @@ func (c *JSONConfig) Save(path string) error {
} }
return WriteAllText(path, string(data)) return WriteAllText(path, string(data))
} }
// getCurrentPath 获取当前文件的路径直接返回string
func getCurrentPath() string {
cwd, e := GetCurrentPath()
if e != nil {
panic(e)
}
return cwd
}
// GetCurrentPath 预留,获取当前目录地址
func GetCurrentPath() (string, error) {
file, err := exec.LookPath(os.Args[0])
if err != nil {
return "", err
}
fpath, err := filepath.Abs(file)
if err != nil {
return "", err
}
//fmt.Println("path111:", path)
if runtime.GOOS == "windows" {
fpath = strings.Replace(fpath, "\\", "/", -1)
}
//fmt.Println("path222:", path)
i := strings.LastIndex(fpath, "/")
if i < 0 {
return "", errors.New("system/path_error,Can't find '/' or '\\'");
}
//fmt.Println("path333:", path)
return string(fpath[0 : i+1]), nil
}

33
main.go
View File

@ -10,8 +10,8 @@ import (
"flag" "flag"
"fmt" "fmt"
"github.com/Mrs4s/go-cqhttp/global/terminal" "github.com/Mrs4s/go-cqhttp/global/terminal"
rotatelogs "github.com/lestrrat-go/file-rotatelogs" "github.com/lestrrat-go/file-rotatelogs"
easy "github.com/t-tomalak/logrus-easy-formatter" "github.com/t-tomalak/logrus-easy-formatter"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -36,22 +36,28 @@ import (
"github.com/Mrs4s/MiraiGo/client" "github.com/Mrs4s/MiraiGo/client"
"github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/coolq"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
jsoniter "github.com/json-iterator/go" "github.com/json-iterator/go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var json = jsoniter.ConfigCompatibleWithStandardLibrary var json = jsoniter.ConfigCompatibleWithStandardLibrary
var conf *global.JSONConfig var conf *global.JSONConfig
var isFastStart = false var isFastStart = false
var c string
var d bool var d bool
var h bool var h bool
func init() { func init() {
var debug bool
flag.StringVar(&c, "c", global.DefaultConfFile, "configuration filename default is config.hjson")
flag.BoolVar(&d, "d", false, "running as a daemon") flag.BoolVar(&d, "d", false, "running as a daemon")
flag.BoolVar(&debug, "D", false, "debug mode")
flag.BoolVar(&h, "h", false, "this help") flag.BoolVar(&h, "h", false, "this help")
flag.Parse() flag.Parse()
//通过-c 参数替换 配置文件路径
global.DefaultConfFile = c
logFormatter := &easy.Formatter{ logFormatter := &easy.Formatter{
TimestampFormat: "2006-01-02 15:04:05", TimestampFormat: "2006-01-02 15:04:05",
LogFormat: "[%time%] [%lvl%]: %msg% \n", LogFormat: "[%time%] [%lvl%]: %msg% \n",
@ -67,8 +73,11 @@ func init() {
os.Exit(1) os.Exit(1)
} }
if debug {
conf.Debug=true
}
// 在debug模式下,将在标准输出中打印当前执行行数 // 在debug模式下,将在标准输出中打印当前执行行数
if conf.Debug { if conf.Debug{
log.SetReportCaller(true) log.SetReportCaller(true)
} }
@ -98,8 +107,8 @@ func init() {
goConf.ReverseServers[0].ReverseEventURL = conf.WSReverseEventURL goConf.ReverseServers[0].ReverseEventURL = conf.WSReverseEventURL
goConf.ReverseServers[0].ReverseReconnectInterval = conf.WSReverseReconnectInterval goConf.ReverseServers[0].ReverseReconnectInterval = conf.WSReverseReconnectInterval
} }
if err := goConf.Save("config.hjson"); err != nil { if err := goConf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
} }
_ = os.Remove("cqhttp.json") _ = os.Remove("cqhttp.json")
} }
@ -160,7 +169,7 @@ func main() {
time.Sleep(time.Second * 10) time.Sleep(time.Second * 10)
} }
if conf.Uin == 0 || (conf.Password == "" && conf.PasswordEncrypted == "") { if conf.Uin == 0 || (conf.Password == "" && conf.PasswordEncrypted == "") {
log.Warnf("请修改 config.hjson 以添加账号密码.") log.Warnf("请修改 %s 以添加账号密码.", global.DefaultConfFile)
if !isFastStart { if !isFastStart {
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
} }
@ -195,7 +204,7 @@ func main() {
global.PasswordHash = md5.Sum([]byte(conf.Password)) global.PasswordHash = md5.Sum([]byte(conf.Password))
conf.Password = "" conf.Password = ""
conf.PasswordEncrypted = "AES:" + PasswordHashEncrypt(global.PasswordHash[:], byteKey) conf.PasswordEncrypted = "AES:" + PasswordHashEncrypt(global.PasswordHash[:], byteKey)
_ = conf.Save("config.hjson") _ = conf.Save(global.DefaultConfFile)
} }
if conf.PasswordEncrypted != "" { if conf.PasswordEncrypted != "" {
if len(byteKey) == 0 { if len(byteKey) == 0 {
@ -224,7 +233,7 @@ func main() {
passwordHash := md5.Sum([]byte(password)) passwordHash := md5.Sum([]byte(password))
newPasswordHash := PasswordHashEncrypt(passwordHash[:], byteKey) newPasswordHash := PasswordHashEncrypt(passwordHash[:], byteKey)
conf.PasswordEncrypted = "AES:" + newPasswordHash conf.PasswordEncrypted = "AES:" + newPasswordHash
_ = conf.Save("config.hjson") _ = conf.Save(global.DefaultConfFile)
log.Debug("密码加密方案升级完成") log.Debug("密码加密方案升级完成")
} }
@ -518,15 +527,15 @@ func getConfig() *global.JSONConfig {
conf.HTTPConfig.PostUrls[post] = os.Getenv("HTTP_SECRET") conf.HTTPConfig.PostUrls[post] = os.Getenv("HTTP_SECRET")
} }
} else { } else {
conf = global.LoadConfig("config.hjson") conf = global.LoadConfig(global.DefaultConfFile)
} }
if conf == nil { if conf == nil {
err := global.WriteAllText("config.hjson", global.DefaultConfigWithComments) err := global.WriteAllText(global.DefaultConfFile, global.DefaultConfigWithComments)
if err != nil { if err != nil {
log.Fatalf("创建默认配置文件时出现错误: %v", err) log.Fatalf("创建默认配置文件时出现错误: %v", err)
return nil return nil
} }
log.Infof("默认配置文件已生成, 请编辑 config.hjson 后重启程序.") log.Infof("默认配置文件已生成, 请编辑 %s 后重启程序.", global.DefaultConfFile)
if !isFastStart { if !isFastStart {
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
} }

View File

@ -22,10 +22,10 @@ import (
"github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/coolq"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
jsoniter "github.com/json-iterator/go" "github.com/json-iterator/go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
asciiart "github.com/yinghau76/go-ascii-art" "github.com/yinghau76/go-ascii-art"
) )
var json = jsoniter.ConfigCompatibleWithStandardLibrary var json = jsoniter.ConfigCompatibleWithStandardLibrary
@ -364,7 +364,7 @@ func GetConf() *global.JSONConfig {
if JSONConfig != nil { if JSONConfig != nil {
return JSONConfig return JSONConfig
} }
conf := global.LoadConfig("config.hjson") conf := global.LoadConfig(global.DefaultConfFile)
return conf return conf
} }
@ -559,9 +559,9 @@ func AdminDoConfigBase(s *webServer, c *gin.Context) {
conf.EnableDB = false conf.EnableDB = false
} }
conf.AccessToken = c.PostForm("access_token") conf.AccessToken = c.PostForm("access_token")
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))
@ -584,9 +584,9 @@ func AdminDoConfigHTTP(s *webServer, c *gin.Context) {
if c.PostForm("post_url") != "" { if c.PostForm("post_url") != "" {
conf.HTTPConfig.PostUrls[c.PostForm("post_url")] = c.PostForm("post_secret") conf.HTTPConfig.PostUrls[c.PostForm("post_url")] = c.PostForm("post_secret")
} }
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))
@ -604,9 +604,9 @@ func AdminDoConfigWS(s *webServer, c *gin.Context) {
} else { } else {
conf.WSConfig.Enabled = false conf.WSConfig.Enabled = false
} }
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))
@ -626,9 +626,9 @@ func AdminDoConfigReverseWS(s *webServer, c *gin.Context) {
} else { } else {
conf.ReverseServers[0].Enabled = false conf.ReverseServers[0].Enabled = false
} }
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))
@ -641,13 +641,13 @@ func AdminDoConfigJSON(s *webServer, c *gin.Context) {
JSON := c.PostForm("json") JSON := c.PostForm("json")
err := json.Unmarshal([]byte(JSON), &conf) err := json.Unmarshal([]byte(JSON), &conf)
if err != nil { if err != nil {
log.Warnf("尝试加载配置文件 %v 时出现错误: %v", "config.hjson", err) log.Warnf("尝试加载配置文件 %v 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
return return
} }
if err := conf.Save("config.hjson"); err != nil { if err := conf.Save(global.DefaultConfFile); err != nil {
log.Fatalf("保存 config.hjson 时出现错误: %v", err) log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err)
c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err)))
} else { } else {
JSONConfig = nil JSONConfig = nil
c.JSON(200, coolq.OK(coolq.MSG{})) c.JSON(200, coolq.OK(coolq.MSG{}))

View File

@ -2,14 +2,11 @@
package server package server
import ( import (
"errors"
"fmt" "fmt"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"runtime"
"strings" "strings"
) )
@ -37,38 +34,15 @@ func Daemon() {
log.Info("[PID] ", proc.Process.Pid) log.Info("[PID] ", proc.Process.Pid)
//pid写入到pid文件中方便后续stop的时候kill //pid写入到pid文件中方便后续stop的时候kill
pidErr:=savePid("go-cqhttp.pid",fmt.Sprintf("%d",proc.Process.Pid)) pidErr := savePid("go-cqhttp.pid", fmt.Sprintf("%d", proc.Process.Pid))
if pidErr != nil{ if pidErr != nil {
log.Errorf("save pid file error: %v",pidErr) log.Errorf("save pid file error: %v", pidErr)
} }
os.Exit(0) os.Exit(0)
} }
// savePid 保存pid到文件中便于后续restart/stop的时候kill pid用。 // savePid 保存pid到文件中便于后续restart/stop的时候kill pid用。
func savePid(path string,data string) error { func savePid(path string, data string) error {
return global.WriteAllText(path, data) return global.WriteAllText(path, data)
} }
// GetCurrentPath 预留,获取当前目录地址
func GetCurrentPath() (string, error) {
file, err := exec.LookPath(os.Args[0])
if err != nil {
return "", err
}
path, err := filepath.Abs(file)
if err != nil {
return "", err
}
//fmt.Println("path111:", path)
if runtime.GOOS == "windows" {
path = strings.Replace(path, "\\", "/", -1)
}
//fmt.Println("path222:", path)
i := strings.LastIndex(path, "/")
if i < 0 {
return "",errors.New("system/path_error,Can't find '/' or '\\'");
}
//fmt.Println("path333:", path)
return string(path[0 : i+1]), nil
}