1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-05 03:23:49 +08:00

style: move global/update to internal/selfupdate

also move checkUpdate to internal/selfupdate
This commit is contained in:
wdvxdr 2021-09-25 10:58:54 +08:00
parent 110982651d
commit 883fca089d
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6
8 changed files with 244 additions and 230 deletions

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/Mrs4s/MiraiGo/binary" "github.com/Mrs4s/MiraiGo/binary"
"github.com/Mrs4s/MiraiGo/client" "github.com/Mrs4s/MiraiGo/client"
@ -16,18 +17,16 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var format = "string"
// SetMessageFormat 设置消息上报格式默认为string // SetMessageFormat 设置消息上报格式默认为string
func SetMessageFormat(f string) { func SetMessageFormat(f string) {
format = f base.PostFormat = f
} }
// ToFormattedMessage 将给定[]message.IMessageElement转换为通过coolq.SetMessageFormat所定义的消息上报格式 // ToFormattedMessage 将给定[]message.IMessageElement转换为通过coolq.SetMessageFormat所定义的消息上报格式
func ToFormattedMessage(e []message.IMessageElement, groupID int64, isRaw ...bool) (r interface{}) { func ToFormattedMessage(e []message.IMessageElement, groupID int64, isRaw ...bool) (r interface{}) {
if format == "string" { if base.PostFormat == "string" {
r = ToStringMessage(e, groupID, isRaw...) r = ToStringMessage(e, groupID, isRaw...)
} else if format == "array" { } else if base.PostFormat == "array" {
r = ToArrayMessage(e, groupID) r = ToArrayMessage(e, groupID)
} }
return return

View File

@ -1,93 +0,0 @@
// Package update 包含go-cqhttp自我更新相关函数
package update
import (
"bufio"
"bytes"
"fmt"
"hash"
"io"
"os"
"path/filepath"
"github.com/dustin/go-humanize"
"github.com/kardianos/osext"
log "github.com/sirupsen/logrus"
)
// WriteSumCounter 写入量计算实例
type WriteSumCounter struct {
Total uint64
Hash hash.Hash
}
// Write 方法将写入的byte长度追加至写入的总长度Total中
func (wc *WriteSumCounter) Write(p []byte) (int, error) {
n := len(p)
wc.Total += uint64(n)
wc.Hash.Write(p)
fmt.Printf("\r ")
fmt.Printf("\rDownloading... %s complete", humanize.Bytes(wc.Total))
return n, nil
}
// FromStream copy form getlantern/go-update
func FromStream(updateWith io.Reader) (err error, errRecover error) {
updatePath, err := osext.Executable()
if err != nil {
return
}
var newBytes []byte
// no patch to apply, go on through
bufBytes := bufio.NewReader(updateWith)
updateWith = io.Reader(bufBytes)
newBytes, err = io.ReadAll(updateWith)
if err != nil {
return
}
// get the directory the executable exists in
updateDir := filepath.Dir(updatePath)
filename := filepath.Base(updatePath)
// Copy the contents of of newbinary to a the new executable file
newPath := filepath.Join(updateDir, fmt.Sprintf(".%s.new", filename))
fp, err := os.OpenFile(newPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
if err != nil {
return
}
// We won't log this error, because it's always going to happen.
defer func() { _ = fp.Close() }()
if _, err = io.Copy(fp, bytes.NewReader(newBytes)); err != nil {
log.Errorf("Unable to copy data: %v\n", err)
}
// if we don't call fp.Close(), windows won't let us move the new executable
// because the file will still be "in use"
if err := fp.Close(); err != nil {
log.Errorf("Unable to close file: %v\n", err)
}
// this is where we'll move the executable to so that we can swap in the updated replacement
oldPath := filepath.Join(updateDir, fmt.Sprintf(".%s.old", filename))
// delete any existing old exec file - this is necessary on Windows for two reasons:
// 1. after a successful update, Windows can't remove the .old file because the process is still running
// 2. windows rename operations fail if the destination file already exists
_ = os.Remove(oldPath)
// move the existing executable to a new file in the same directory
err = os.Rename(updatePath, oldPath)
if err != nil {
return
}
// move the new executable in to become the new program
err = os.Rename(newPath, updatePath)
if err != nil {
// copy unsuccessful
errRecover = os.Rename(oldPath, updatePath)
} else {
// copy successful, remove the old binary
_ = os.Remove(oldPath)
}
return
}

View File

@ -1,7 +1,11 @@
// Package base provides base config for go-cqhttp // Package base provides base config for go-cqhttp
package base package base
import "github.com/Mrs4s/go-cqhttp/global/config" import (
log "github.com/sirupsen/logrus"
"github.com/Mrs4s/go-cqhttp/global/config"
)
// flags // flags
var ( var (
@ -13,9 +17,10 @@ var (
ForceFragmented bool // 是否启用强制分片 ForceFragmented bool // 是否启用强制分片
SkipMimeScan bool // 是否跳过Mime扫描 SkipMimeScan bool // 是否跳过Mime扫描
Proxy string // 存储 proxy_rewrite,用于设置代理 PostFormat = "string" // 上报格式 string or array
PasswordHash [16]byte // 存储QQ密码哈希供登录使用 Proxy string // 存储 proxy_rewrite,用于设置代理
AccountToken []byte // 存储AccountToken供登录使用 PasswordHash [16]byte // 存储QQ密码哈希供登录使用
AccountToken []byte // 存储AccountToken供登录使用
) )
// Parse parses flags from config file // Parse parses flags from config file
@ -32,5 +37,11 @@ func Parse() {
} }
{ // string { // string
Proxy = conf.Message.ProxyRewrite Proxy = conf.Message.ProxyRewrite
if conf.Message.PostFormat != "string" && conf.Message.PostFormat != "array" {
log.Warnf("post-format 配置错误, 将自动使用 string")
PostFormat = "string"
} else {
PostFormat = conf.Message.PostFormat
}
} }
} }

View File

@ -0,0 +1,206 @@
// Package selfupdate 版本升级检查和自更新
package selfupdate
import (
"bufio"
"encoding/hex"
"fmt"
"hash"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/dustin/go-humanize"
"github.com/kardianos/osext"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base"
)
func readLine() (str string) {
console := bufio.NewReader(os.Stdin)
str, _ = console.ReadString('\n')
str = strings.TrimSpace(str)
return
}
func lastVersion() (string, error) {
r, err := global.GetBytes("https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest")
if err != nil {
return "", err
}
return gjson.GetBytes(r, "tag_name").Str, nil
}
// CheckUpdate 检查更新
func CheckUpdate() {
logrus.Infof("正在检查更新.")
if base.Version == "(devel)" {
logrus.Warnf("检查更新失败: 使用的 Actions 测试版或自编译版本.")
return
}
latest, err := lastVersion()
if err != nil {
logrus.Warnf("检查更新失败: %v", err)
return
}
if global.VersionNameCompare(base.Version, latest) {
logrus.Infof("当前有更新的 go-cqhttp 可供更新, 请前往 https://github.com/Mrs4s/go-cqhttp/releases 下载.")
logrus.Infof("当前版本: %v 最新版本: %v", base.Version, latest)
return
}
logrus.Infof("检查更新完成. 当前已运行最新版本.")
}
func binaryName() string {
goarch := runtime.GOARCH
if goarch == "arm" {
goarch += "v7"
}
ext := "tar.gz"
if runtime.GOOS == "windows" {
ext = "zip"
}
return fmt.Sprintf("go-cqhttp_%v_%v.%v", runtime.GOOS, goarch, ext)
}
func checksum(github, version string) []byte {
sumURL := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/go-cqhttp_checksums.txt", github, version)
closer, err := global.HTTPGetReadCloser(sumURL)
if err != nil {
return nil
}
rd := bufio.NewReader(closer)
for {
str, err := rd.ReadString('\n')
if err != nil {
break
}
str = strings.TrimSpace(str)
if strings.HasSuffix(str, binaryName()) {
sum, _ := hex.DecodeString(strings.TrimSuffix(str, " "+binaryName()))
return sum
}
}
return nil
}
func wait() {
logrus.Info("按 Enter 继续....")
readLine()
os.Exit(0)
}
// SelfUpdate 自更新
func SelfUpdate(github string) {
if github == "" {
github = "https://github.com"
}
logrus.Infof("正在检查更新.")
latest, err := lastVersion()
if err != nil {
logrus.Warnf("获取最新版本失败: %v", err)
wait()
}
url := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/%v", github, latest, binaryName())
if base.Version == latest {
logrus.Info("当前版本已经是最新版本!")
wait()
}
logrus.Info("当前最新版本为 ", latest)
logrus.Warn("是否更新(y/N): ")
r := strings.TrimSpace(readLine())
if r != "y" && r != "Y" {
logrus.Warn("已取消更新!")
wait()
}
logrus.Info("正在更新,请稍等...")
sum := checksum(github, latest)
if sum != nil {
err = update(url, sum)
if err != nil {
logrus.Error("更新失败: ", err)
} else {
logrus.Info("更新成功!")
}
} else {
logrus.Error("checksum 失败!")
}
wait()
}
// writeSumCounter 写入量计算实例
type writeSumCounter struct {
total uint64
hash hash.Hash
}
// Write 方法将写入的byte长度追加至写入的总长度Total中
func (wc *writeSumCounter) Write(p []byte) (int, error) {
n := len(p)
wc.total += uint64(n)
wc.hash.Write(p)
fmt.Printf("\r ")
fmt.Printf("\rDownloading... %s complete", humanize.Bytes(wc.total))
return n, nil
}
// FromStream copy form getlantern/go-update
func fromStream(updateWith io.Reader) (err error, errRecover error) {
updatePath, err := osext.Executable()
if err != nil {
return
}
// get the directory the executable exists in
updateDir := filepath.Dir(updatePath)
filename := filepath.Base(updatePath)
// Copy the contents of of newbinary to a the new executable file
newPath := filepath.Join(updateDir, fmt.Sprintf(".%s.new", filename))
fp, err := os.OpenFile(newPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o755)
if err != nil {
return
}
// We won't log this error, because it's always going to happen.
defer func() { _ = fp.Close() }()
if _, err = io.Copy(fp, bufio.NewReader(updateWith)); err != nil {
logrus.Errorf("Unable to copy data: %v\n", err)
}
// if we don't call fp.Close(), windows won't let us move the new executable
// because the file will still be "in use"
if err := fp.Close(); err != nil {
logrus.Errorf("Unable to close file: %v\n", err)
}
// this is where we'll move the executable to so that we can swap in the updated replacement
oldPath := filepath.Join(updateDir, fmt.Sprintf(".%s.old", filename))
// delete any existing old exec file - this is necessary on Windows for two reasons:
// 1. after a successful update, Windows can't remove the .old file because the process is still running
// 2. windows rename operations fail if the destination file already exists
_ = os.Remove(oldPath)
// move the existing executable to a new file in the same directory
err = os.Rename(updatePath, oldPath)
if err != nil {
return
}
// move the new executable in to become the new program
err = os.Rename(newPath, updatePath)
if err != nil {
// copy unsuccessful
errRecover = os.Rename(oldPath, updatePath)
} else {
// copy successful, remove the old binary
_ = os.Remove(oldPath)
}
return
}

View File

@ -1,7 +1,7 @@
//go:build !windows //go:build !windows
// +build !windows // +build !windows
package update package selfupdate
import ( import (
"archive/tar" "archive/tar"
@ -14,21 +14,21 @@ import (
"net/http" "net/http"
) )
// Update go-cqhttp自我更新 // update go-cqhttp自我更新
func Update(url string, sum []byte) error { func update(url string, sum []byte) error {
resp, err := http.Get(url) resp, err := http.Get(url)
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
wc := WriteSumCounter{ wc := writeSumCounter{
Hash: sha256.New(), hash: sha256.New(),
} }
rsp, err := io.ReadAll(io.TeeReader(resp.Body, &wc)) rsp, err := io.ReadAll(io.TeeReader(resp.Body, &wc))
if err != nil { if err != nil {
return err return err
} }
if !bytes.Equal(wc.Hash.Sum(nil), sum) { if !bytes.Equal(wc.hash.Sum(nil), sum) {
return errors.New("文件已损坏") return errors.New("文件已损坏")
} }
gr, err := gzip.NewReader(bytes.NewReader(rsp)) gr, err := gzip.NewReader(bytes.NewReader(rsp))
@ -42,7 +42,7 @@ func Update(url string, sum []byte) error {
return err return err
} }
if header.Name == "go-cqhttp" { if header.Name == "go-cqhttp" {
err, _ := FromStream(tr) err, _ := fromStream(tr)
fmt.Println() fmt.Println()
if err != nil { if err != nil {
return err return err

View File

@ -1,4 +1,4 @@
package update package selfupdate
import ( import (
"archive/zip" "archive/zip"
@ -10,21 +10,21 @@ import (
"net/http" "net/http"
) )
// Update go-cqhttp自我更新 // update go-cqhttp自我更新
func Update(url string, sum []byte) error { func update(url string, sum []byte) error {
resp, err := http.Get(url) resp, err := http.Get(url)
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
wc := WriteSumCounter{ wc := writeSumCounter{
Hash: sha256.New(), hash: sha256.New(),
} }
rsp, err := io.ReadAll(io.TeeReader(resp.Body, &wc)) rsp, err := io.ReadAll(io.TeeReader(resp.Body, &wc))
if err != nil { if err != nil {
return err return err
} }
if !bytes.Equal(wc.Hash.Sum(nil), sum) { if !bytes.Equal(wc.hash.Sum(nil), sum) {
return errors.New("文件已损坏") return errors.New("文件已损坏")
} }
reader, _ := zip.NewReader(bytes.NewReader(rsp), resp.ContentLength) reader, _ := zip.NewReader(bytes.NewReader(rsp), resp.ContentLength)
@ -32,7 +32,7 @@ func Update(url string, sum []byte) error {
if err != nil { if err != nil {
return err return err
} }
err, _ = FromStream(file) err, _ = fromStream(file)
fmt.Println() fmt.Println()
if err != nil { if err != nil {
return err return err

View File

@ -24,7 +24,7 @@ func readLine() (str string) {
return return
} }
var readLineTimeout = func(t time.Duration, de string) (str string) { func readLineTimeout(t time.Duration, de string) (str string) {
r := make(chan string) r := make(chan string)
go func() { go func() {
select { select {

117
main.go
View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"bufio"
"crypto/aes" "crypto/aes"
"crypto/md5" "crypto/md5"
"crypto/sha1" "crypto/sha1"
@ -12,18 +11,15 @@ import (
"os/exec" "os/exec"
"path" "path"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/Mrs4s/MiraiGo/binary" "github.com/Mrs4s/MiraiGo/binary"
"github.com/Mrs4s/MiraiGo/client" "github.com/Mrs4s/MiraiGo/client"
"github.com/Mrs4s/MiraiGo/utils"
para "github.com/fumiama/go-hide-param" para "github.com/fumiama/go-hide-param"
rotatelogs "github.com/lestrrat-go/file-rotatelogs" rotatelogs "github.com/lestrrat-go/file-rotatelogs"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/pbkdf2"
"golang.org/x/term" "golang.org/x/term"
@ -31,8 +27,8 @@ import (
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/global/config" "github.com/Mrs4s/go-cqhttp/global/config"
"github.com/Mrs4s/go-cqhttp/global/terminal" "github.com/Mrs4s/go-cqhttp/global/terminal"
"github.com/Mrs4s/go-cqhttp/global/update"
"github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/Mrs4s/go-cqhttp/internal/selfupdate"
"github.com/Mrs4s/go-cqhttp/server" "github.com/Mrs4s/go-cqhttp/server"
) )
@ -119,9 +115,9 @@ func main() {
switch arg[i] { switch arg[i] {
case "update": case "update":
if len(arg) > i+1 { if len(arg) > i+1 {
selfUpdate(arg[i+1]) selfupdate.SelfUpdate(arg[i+1])
} else { } else {
selfUpdate("") selfupdate.SelfUpdate("")
} }
case "key": case "key":
p := i + 1 p := i + 1
@ -359,13 +355,6 @@ func main() {
} }
cli.SetOnlineStatus(allowStatus[int(conf.Account.Status)]) cli.SetOnlineStatus(allowStatus[int(conf.Account.Status)])
bot := coolq.NewQQBot(cli, conf) bot := coolq.NewQQBot(cli, conf)
_ = bot.Client
if conf.Message.PostFormat != "string" && conf.Message.PostFormat != "array" {
log.Warnf("post-format 配置错误, 将自动使用 string")
coolq.SetMessageFormat("string")
} else {
coolq.SetMessageFormat(conf.Message.PostFormat)
}
for _, m := range conf.Servers { for _, m := range conf.Servers {
if h, ok := m["http"]; ok { if h, ok := m["http"]; ok {
hc := new(config.HTTPServer) hc := new(config.HTTPServer)
@ -411,7 +400,7 @@ func main() {
log.Info("资源初始化完成, 开始处理信息.") log.Info("资源初始化完成, 开始处理信息.")
log.Info("アトリは、高性能ですから!") log.Info("アトリは、高性能ですから!")
go checkUpdate() go selfupdate.CheckUpdate()
<-global.SetupMainSignalHandler() <-global.SetupMainSignalHandler()
} }
@ -447,104 +436,6 @@ func PasswordHashDecrypt(encryptedPasswordHash string, key []byte) ([]byte, erro
return result, nil return result, nil
} }
func checkUpdate() {
log.Infof("正在检查更新.")
if base.Version == "(devel)" {
log.Warnf("检查更新失败: 使用的 Actions 测试版或自编译版本.")
return
}
r, err := global.GetBytes("https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest")
if err != nil {
log.Warnf("检查更新失败: %v", err)
return
}
info := gjson.Parse(utils.B2S(r))
if global.VersionNameCompare(base.Version, info.Get("tag_name").Str) {
log.Infof("当前有更新的 go-cqhttp 可供更新, 请前往 https://github.com/Mrs4s/go-cqhttp/releases 下载.")
log.Infof("当前版本: %v 最新版本: %v", base.Version, info.Get("tag_name").Str)
return
}
log.Infof("检查更新完成. 当前已运行最新版本.")
}
func selfUpdate(imageURL string) {
log.Infof("正在检查更新.")
var r string
res, err := global.GetBytes("https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest")
if err != nil {
log.Warnf("检查更新失败: %v", err)
return
}
info := gjson.Parse(utils.B2S(res))
version := info.Get("tag_name").Str
if base.Version == version {
log.Info("当前版本已经是最新版本!")
goto wait
}
log.Info("当前最新版本为 ", version)
log.Warn("是否更新(y/N): ")
r = strings.TrimSpace(readLine())
if r != "y" && r != "Y" {
log.Warn("已取消更新!")
} else {
log.Info("正在更新,请稍等...")
sumURL := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/go-cqhttp_checksums.txt",
func() string {
if imageURL != "" {
return imageURL
}
return "https://github.com"
}(), version)
closer, err := global.HTTPGetReadCloser(sumURL)
if err != nil {
log.Error("更新失败: ", err)
goto wait
}
rd := bufio.NewReader(closer)
binaryName := fmt.Sprintf("go-cqhttp_%v_%v.%v", runtime.GOOS, func() string {
if runtime.GOARCH == "arm" {
return "armv7"
}
return runtime.GOARCH
}(), func() string {
if runtime.GOOS == "windows" {
return "zip"
}
return "tar.gz"
}())
var sum []byte
for {
str, err := rd.ReadString('\n')
if err != nil {
break
}
str = strings.TrimSpace(str)
if strings.HasSuffix(str, binaryName) {
sum, _ = hex.DecodeString(strings.TrimSuffix(str, " "+binaryName))
break
}
}
url := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/%v",
func() string {
if imageURL != "" {
return imageURL
}
return "https://github.com"
}(), version, binaryName)
err = update.Update(url, sum)
if err != nil {
log.Error("更新失败: ", err)
} else {
log.Info("更新成功!")
}
}
wait:
log.Info("按 Enter 继续....")
readLine()
os.Exit(0)
}
// help cli命令行-h的帮助提示 // help cli命令行-h的帮助提示
func help() { func help() {
fmt.Printf(`go-cqhttp service fmt.Printf(`go-cqhttp service