diff --git a/global/config.go b/global/config.go index ac35020..bf9191b 100644 --- a/global/config.go +++ b/global/config.go @@ -1,17 +1,26 @@ package global import ( + "errors" "os" + "os/exec" + "path" + "path/filepath" + "runtime" "strconv" + "strings" "time" "github.com/hjson/hjson-go" - jsoniter "github.com/json-iterator/go" + "github.com/json-iterator/go" log "github.com/sirupsen/logrus" ) var json = jsoniter.ConfigCompatibleWithStandardLibrary +var currentPath=getCurrentPath() +var DefaultConfFile =path.Join(currentPath,"config.hjson") + // DefaultConfigWithComments 为go-cqhttp的默认配置文件 var DefaultConfigWithComments = ` /* @@ -312,3 +321,34 @@ func (c *JSONConfig) Save(path string) error { } return WriteAllText(path, string(data)) } +// getCurrentPath 获取当前文件的路径,直接返回string +func getCurrentPath() string { + cwd, e := GetCurrentPath() + if e != nil { + panic(e) + } + return cwd +} + +// GetCurrentPath 预留,获取当前目录地址 +func GetCurrentPath() (string, error) { + file, err := exec.LookPath(os.Args[0]) + if err != nil { + return "", err + } + fpath, err := filepath.Abs(file) + if err != nil { + return "", err + } + //fmt.Println("path111:", path) + if runtime.GOOS == "windows" { + fpath = strings.Replace(fpath, "\\", "/", -1) + } + //fmt.Println("path222:", path) + i := strings.LastIndex(fpath, "/") + if i < 0 { + return "", errors.New("system/path_error,Can't find '/' or '\\'"); + } + //fmt.Println("path333:", path) + return string(fpath[0 : i+1]), nil +} \ No newline at end of file diff --git a/main.go b/main.go index f5abd3d..010cc55 100644 --- a/main.go +++ b/main.go @@ -10,8 +10,8 @@ import ( "flag" "fmt" "github.com/Mrs4s/go-cqhttp/global/terminal" - rotatelogs "github.com/lestrrat-go/file-rotatelogs" - easy "github.com/t-tomalak/logrus-easy-formatter" + "github.com/lestrrat-go/file-rotatelogs" + "github.com/t-tomalak/logrus-easy-formatter" "io" "io/ioutil" "net/http" @@ -36,22 +36,28 @@ import ( "github.com/Mrs4s/MiraiGo/client" "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" - jsoniter "github.com/json-iterator/go" + "github.com/json-iterator/go" log "github.com/sirupsen/logrus" ) var json = jsoniter.ConfigCompatibleWithStandardLibrary var conf *global.JSONConfig var isFastStart = false +var c string var d bool var h bool func init() { + var debug bool + flag.StringVar(&c, "c", global.DefaultConfFile, "configuration filename default is config.hjson") flag.BoolVar(&d, "d", false, "running as a daemon") + flag.BoolVar(&debug, "D", false, "debug mode") flag.BoolVar(&h, "h", false, "this help") flag.Parse() + //通过-c 参数替换 配置文件路径 + global.DefaultConfFile = c logFormatter := &easy.Formatter{ TimestampFormat: "2006-01-02 15:04:05", LogFormat: "[%time%] [%lvl%]: %msg% \n", @@ -67,8 +73,11 @@ func init() { os.Exit(1) } + if debug { + conf.Debug=true + } // 在debug模式下,将在标准输出中打印当前执行行数 - if conf.Debug { + if conf.Debug{ log.SetReportCaller(true) } @@ -98,8 +107,8 @@ func init() { goConf.ReverseServers[0].ReverseEventURL = conf.WSReverseEventURL goConf.ReverseServers[0].ReverseReconnectInterval = conf.WSReverseReconnectInterval } - if err := goConf.Save("config.hjson"); err != nil { - log.Fatalf("保存 config.hjson 时出现错误: %v", err) + if err := goConf.Save(global.DefaultConfFile); err != nil { + log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err) } _ = os.Remove("cqhttp.json") } @@ -160,7 +169,7 @@ func main() { time.Sleep(time.Second * 10) } if conf.Uin == 0 || (conf.Password == "" && conf.PasswordEncrypted == "") { - log.Warnf("请修改 config.hjson 以添加账号密码.") + log.Warnf("请修改 %s 以添加账号密码.", global.DefaultConfFile) if !isFastStart { time.Sleep(time.Second * 5) } @@ -195,7 +204,7 @@ func main() { global.PasswordHash = md5.Sum([]byte(conf.Password)) conf.Password = "" conf.PasswordEncrypted = "AES:" + PasswordHashEncrypt(global.PasswordHash[:], byteKey) - _ = conf.Save("config.hjson") + _ = conf.Save(global.DefaultConfFile) } if conf.PasswordEncrypted != "" { if len(byteKey) == 0 { @@ -224,7 +233,7 @@ func main() { passwordHash := md5.Sum([]byte(password)) newPasswordHash := PasswordHashEncrypt(passwordHash[:], byteKey) conf.PasswordEncrypted = "AES:" + newPasswordHash - _ = conf.Save("config.hjson") + _ = conf.Save(global.DefaultConfFile) log.Debug("密码加密方案升级完成") } @@ -518,15 +527,15 @@ func getConfig() *global.JSONConfig { conf.HTTPConfig.PostUrls[post] = os.Getenv("HTTP_SECRET") } } else { - conf = global.LoadConfig("config.hjson") + conf = global.LoadConfig(global.DefaultConfFile) } if conf == nil { - err := global.WriteAllText("config.hjson", global.DefaultConfigWithComments) + err := global.WriteAllText(global.DefaultConfFile, global.DefaultConfigWithComments) if err != nil { log.Fatalf("创建默认配置文件时出现错误: %v", err) return nil } - log.Infof("默认配置文件已生成, 请编辑 config.hjson 后重启程序.") + log.Infof("默认配置文件已生成, 请编辑 %s 后重启程序.", global.DefaultConfFile) if !isFastStart { time.Sleep(time.Second * 5) } diff --git a/server/apiAdmin.go b/server/apiAdmin.go index be62817..a5656c0 100644 --- a/server/apiAdmin.go +++ b/server/apiAdmin.go @@ -22,10 +22,10 @@ import ( "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" "github.com/gin-gonic/gin" - jsoniter "github.com/json-iterator/go" + "github.com/json-iterator/go" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - asciiart "github.com/yinghau76/go-ascii-art" + "github.com/yinghau76/go-ascii-art" ) var json = jsoniter.ConfigCompatibleWithStandardLibrary @@ -364,7 +364,7 @@ func GetConf() *global.JSONConfig { if JSONConfig != nil { return JSONConfig } - conf := global.LoadConfig("config.hjson") + conf := global.LoadConfig(global.DefaultConfFile) return conf } @@ -559,9 +559,9 @@ func AdminDoConfigBase(s *webServer, c *gin.Context) { conf.EnableDB = false } conf.AccessToken = c.PostForm("access_token") - if err := conf.Save("config.hjson"); err != nil { - log.Fatalf("保存 config.hjson 时出现错误: %v", err) - c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) + if err := conf.Save(global.DefaultConfFile); err != nil { + log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err) + c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err))) } else { JSONConfig = nil c.JSON(200, coolq.OK(coolq.MSG{})) @@ -584,9 +584,9 @@ func AdminDoConfigHTTP(s *webServer, c *gin.Context) { if c.PostForm("post_url") != "" { conf.HTTPConfig.PostUrls[c.PostForm("post_url")] = c.PostForm("post_secret") } - if err := conf.Save("config.hjson"); err != nil { - log.Fatalf("保存 config.hjson 时出现错误: %v", err) - c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) + if err := conf.Save(global.DefaultConfFile); err != nil { + log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err) + c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err))) } else { JSONConfig = nil c.JSON(200, coolq.OK(coolq.MSG{})) @@ -604,9 +604,9 @@ func AdminDoConfigWS(s *webServer, c *gin.Context) { } else { conf.WSConfig.Enabled = false } - if err := conf.Save("config.hjson"); err != nil { - log.Fatalf("保存 config.hjson 时出现错误: %v", err) - c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) + if err := conf.Save(global.DefaultConfFile); err != nil { + log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err) + c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err))) } else { JSONConfig = nil c.JSON(200, coolq.OK(coolq.MSG{})) @@ -626,9 +626,9 @@ func AdminDoConfigReverseWS(s *webServer, c *gin.Context) { } else { conf.ReverseServers[0].Enabled = false } - if err := conf.Save("config.hjson"); err != nil { - log.Fatalf("保存 config.hjson 时出现错误: %v", err) - c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) + if err := conf.Save(global.DefaultConfFile); err != nil { + log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err) + c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err))) } else { JSONConfig = nil c.JSON(200, coolq.OK(coolq.MSG{})) @@ -641,13 +641,13 @@ func AdminDoConfigJSON(s *webServer, c *gin.Context) { JSON := c.PostForm("json") err := json.Unmarshal([]byte(JSON), &conf) if err != nil { - log.Warnf("尝试加载配置文件 %v 时出现错误: %v", "config.hjson", err) - c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) + log.Warnf("尝试加载配置文件 %v 时出现错误: %v", global.DefaultConfFile, err) + c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err))) return } - if err := conf.Save("config.hjson"); err != nil { - log.Fatalf("保存 config.hjson 时出现错误: %v", err) - c.JSON(200, Failed(502, "保存 config.hjson 时出现错误:"+fmt.Sprintf("%v", err))) + if err := conf.Save(global.DefaultConfFile); err != nil { + log.Fatalf("保存 %s 时出现错误: %v", global.DefaultConfFile, err) + c.JSON(200, Failed(502, "保存 "+global.DefaultConfFile+" 时出现错误:"+fmt.Sprintf("%v", err))) } else { JSONConfig = nil c.JSON(200, coolq.OK(coolq.MSG{})) diff --git a/server/daemon.go b/server/daemon.go index e36b8af..09cef00 100644 --- a/server/daemon.go +++ b/server/daemon.go @@ -2,14 +2,11 @@ package server import ( - "errors" "fmt" "github.com/Mrs4s/go-cqhttp/global" log "github.com/sirupsen/logrus" "os" "os/exec" - "path/filepath" - "runtime" "strings" ) @@ -37,38 +34,15 @@ func Daemon() { log.Info("[PID] ", proc.Process.Pid) //pid写入到pid文件中,方便后续stop的时候kill - pidErr:=savePid("go-cqhttp.pid",fmt.Sprintf("%d",proc.Process.Pid)) - if pidErr != nil{ - log.Errorf("save pid file error: %v",pidErr) + pidErr := savePid("go-cqhttp.pid", fmt.Sprintf("%d", proc.Process.Pid)) + if pidErr != nil { + log.Errorf("save pid file error: %v", pidErr) } os.Exit(0) } // savePid 保存pid到文件中,便于后续restart/stop的时候kill pid用。 -func savePid(path string,data string) error { +func savePid(path string, data string) error { return global.WriteAllText(path, data) } - -// GetCurrentPath 预留,获取当前目录地址 -func GetCurrentPath() (string, error) { - file, err := exec.LookPath(os.Args[0]) - if err != nil { - return "", err - } - path, err := filepath.Abs(file) - if err != nil { - return "", err - } - //fmt.Println("path111:", path) - if runtime.GOOS == "windows" { - path = strings.Replace(path, "\\", "/", -1) - } - //fmt.Println("path222:", path) - i := strings.LastIndex(path, "/") - if i < 0 { - return "",errors.New("system/path_error,Can't find '/' or '\\'"); - } - //fmt.Println("path333:", path) - return string(path[0 : i+1]), nil -}