diff --git a/coolq/bot.go b/coolq/bot.go index 1bdd3ae..c36db69 100644 --- a/coolq/bot.go +++ b/coolq/bot.go @@ -274,7 +274,7 @@ func (bot *CQBot) SendPrivateMessage(target int64, groupID int64, m *message.Sen return 0 } if i, ok := elem.(*message.VoiceElement); ok { - fv, err := bot.Client.UploadPrivatePtt(target, i.Data) + fv, err := bot.Client.UploadPrivatePtt(target, bytes.NewReader(i.Data)) // todo: io.ReadSeeker if err != nil { log.Warnf("警告: 私聊 %v 消息语音上传失败: %v", target, err) continue diff --git a/global/fs.go b/global/fs.go index 63e22c7..6107261 100644 --- a/global/fs.go +++ b/global/fs.go @@ -15,6 +15,7 @@ import ( "strconv" "strings" + "github.com/Mrs4s/MiraiGo/utils" log "github.com/sirupsen/logrus" ) @@ -33,21 +34,6 @@ const ( CachePath = "data/cache" ) -var ( - // ImageFs go-cqhttp使用的图片缓存目录 - ImageFs = os.DirFS("data/images") - // ImageOldFs 兼容旧版go-cqhttp使用的图片缓存目录 - ImageOldFs = os.DirFS("data/image") - // VoiceFs go-cqhttp使用的语音缓存目录 - VoiceFs = os.DirFS("data/voices") - // VoiceFsOld 兼容旧版go-cqhttp使用的语音缓存目录 - VoiceFsOld = os.DirFS("data/record") - // VideoFs go-cqhttp使用的视频缓存目录 - VideoFs = os.DirFS("data/videos") - // CacheFs go-cqhttp使用的缓存目录 - CacheFs = os.DirFS("data/cache") -) - var ( // ErrSyntax Path语法错误时返回的错误 ErrSyntax = errors.New("syntax error") @@ -75,7 +61,7 @@ func ReadAllText(path string) string { // WriteAllText 将给定text写入给定path func WriteAllText(path, text string) error { - return ioutil.WriteFile(path, []byte(text), 0644) + return ioutil.WriteFile(path, utils.S2B(text), 0644) } // Check 检测err是否为nil @@ -98,7 +84,7 @@ func IsAMRorSILK(b []byte) bool { func FindFile(file, cache, p string) (data []byte, err error) { data, err = nil, ErrSyntax switch { - case strings.HasPrefix(file, "http") || strings.HasPrefix(file, "https"): + case strings.HasPrefix(file, "http"): // https also has prefix http if cache == "" { cache = "1" } @@ -113,7 +99,7 @@ func FindFile(file, cache, p string) (data []byte, err error) { return nil, err } case strings.HasPrefix(file, "base64"): - data, err = base64.StdEncoding.DecodeString(strings.ReplaceAll(file, "base64://", "")) + data, err = base64.StdEncoding.DecodeString(strings.TrimPrefix(file, "base64://")) if err != nil { return nil, err } diff --git a/global/update/update.go b/global/update/update.go index f7ca0c3..30ab241 100644 --- a/global/update/update.go +++ b/global/update/update.go @@ -5,34 +5,31 @@ import ( "bufio" "bytes" "fmt" + "hash" "io" "io/ioutil" "os" "path/filepath" - "strings" "github.com/dustin/go-humanize" "github.com/kardianos/osext" log "github.com/sirupsen/logrus" ) -// WriteCounter 写入量计算实例 -type WriteCounter struct { +// WriteSumCounter 写入量计算实例 +type WriteSumCounter struct { Total uint64 + Hash hash.Hash } // Write 方法将写入的byte长度追加至写入的总长度Total中 -func (wc *WriteCounter) Write(p []byte) (int, error) { +func (wc *WriteSumCounter) Write(p []byte) (int, error) { n := len(p) wc.Total += uint64(n) - wc.PrintProgress() - return n, nil -} - -// PrintProgress 方法将打印当前的总写入量 -func (wc *WriteCounter) PrintProgress() { - fmt.Printf("\r%s", strings.Repeat(" ", 35)) + wc.Hash.Write(p) + fmt.Printf("\r ") fmt.Printf("\rDownloading... %s complete", humanize.Bytes(wc.Total)) + return n, nil } // FromStream copy form getlantern/go-update diff --git a/global/update/update_others.go b/global/update/update_others.go index 7326497..d8020bf 100644 --- a/global/update/update_others.go +++ b/global/update/update_others.go @@ -6,46 +6,47 @@ import ( "archive/tar" "bytes" "compress/gzip" + "crypto/sha256" + "errors" "fmt" "io" "net/http" - - log "github.com/sirupsen/logrus" ) // Update go-cqhttp自我更新 -func Update(url string) { +func Update(url string, sum []byte) error { resp, err := http.Get(url) if err != nil { - log.Error("更新失败: ", err) - return + return err } defer resp.Body.Close() - wc := WriteCounter{} - data, err := io.ReadAll(io.TeeReader(resp.Body, &wc)) - if err != nil { - log.Error("更新失败: ", err) - return + wc := WriteSumCounter{ + Hash: sha256.New(), } - gr, err := gzip.NewReader(bytes.NewReader(data)) + rsp, err := io.ReadAll(io.TeeReader(resp.Body, &wc)) if err != nil { - log.Error("更新失败: ", err) - return + return err + } + if !bytes.Equal(wc.Hash.Sum(nil), sum) { + return errors.New("文件已损坏") + } + gr, err := gzip.NewReader(bytes.NewReader(rsp)) + if err != nil { + return err } tr := tar.NewReader(gr) for { header, err := tr.Next() - if err == io.EOF { - return + if err != nil { + return err } if header.Name == "go-cqhttp" { err, _ := FromStream(tr) fmt.Println() if err != nil { - log.Error("更新失败!", err) - return + return err } - log.Info("更新完成!") + return nil } } } diff --git a/global/update/update_windows.go b/global/update/update_windows.go index 0570522..7fce75b 100644 --- a/global/update/update_windows.go +++ b/global/update/update_windows.go @@ -3,34 +3,39 @@ package update import ( "archive/zip" "bytes" + "crypto/sha256" + "errors" "fmt" "io" "net/http" - - log "github.com/sirupsen/logrus" ) // Update go-cqhttp自我更新 -func Update(url string) { +func Update(url string, sum []byte) error { resp, err := http.Get(url) if err != nil { - log.Error("更新失败: ", err) - return + return err } defer resp.Body.Close() - wc := WriteCounter{} - rsp, _ := io.ReadAll(io.TeeReader(resp.Body, &wc)) + wc := WriteSumCounter{ + Hash: sha256.New(), + } + rsp, err := io.ReadAll(io.TeeReader(resp.Body, &wc)) + if err != nil { + return err + } + if !bytes.Equal(wc.Hash.Sum(nil), sum) { + return errors.New("文件已损坏") + } reader, _ := zip.NewReader(bytes.NewReader(rsp), resp.ContentLength) file, err := reader.Open("go-cqhttp.exe") if err != nil { - log.Error("更新失败!", err) - return + return err } err, _ = FromStream(file) fmt.Println() if err != nil { - log.Error("更新失败!", err) - return + return err } - log.Info("更新完成!") + return nil } diff --git a/go.mod b/go.mod index 3a6f72f..e10301b 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.16 require ( github.com/Baozisoftware/qrcode-terminal-go v0.0.0-20170407111555-c0650d8dff0f - github.com/Mrs4s/MiraiGo v0.0.0-20210413103959-192b8c562ffd + github.com/Mrs4s/MiraiGo v0.0.0-20210413121026-86e0f3327a76 github.com/dustin/go-humanize v1.0.0 github.com/gin-contrib/pprof v1.3.0 github.com/gin-gonic/gin v1.7.1 diff --git a/go.sum b/go.sum index 6b6dbc3..8ef6d6b 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT github.com/Baozisoftware/qrcode-terminal-go v0.0.0-20170407111555-c0650d8dff0f h1:2dk3eOnYllh+wUOuDhOoC2vUVoJF/5z478ryJ+wzEII= github.com/Baozisoftware/qrcode-terminal-go v0.0.0-20170407111555-c0650d8dff0f/go.mod h1:4a58ifQTEe2uwwsaqbh3i2un5/CBPg+At/qHpt18Tmk= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Mrs4s/MiraiGo v0.0.0-20210413103959-192b8c562ffd h1:62vdLRaK6D8HDFAuHx5C001PvcPtMke8JPzcfoF2Pg4= -github.com/Mrs4s/MiraiGo v0.0.0-20210413103959-192b8c562ffd/go.mod h1:NjiWhlvGxwv1ftOWIoiFa/OzklnAYI4YqNexFOKSZKw= +github.com/Mrs4s/MiraiGo v0.0.0-20210413121026-86e0f3327a76 h1:ZcT48WS0sl8FTO4xbITwCsj1bLVO+AxHcxaO+8ovmPY= +github.com/Mrs4s/MiraiGo v0.0.0-20210413121026-86e0f3327a76/go.mod h1:NjiWhlvGxwv1ftOWIoiFa/OzklnAYI4YqNexFOKSZKw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/main.go b/main.go index 3df3d2f..d158dc8 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "crypto/aes" "crypto/md5" "crypto/sha1" @@ -401,7 +402,7 @@ func main() { if s, ok := m["ws"]; ok { sc := new(config.WebsocketServer) if err := s.Decode(sc); err != nil { - log.Warn("读取http配置失败 :", err) + log.Warn("读取正向Websocket配置失败 :", err) } else { go server.RunWebSocketServer(bot, sc) } @@ -409,7 +410,7 @@ func main() { if c, ok := m["ws-reverse"]; ok { rc := new(config.WebsocketReverse) if err := c.Decode(rc); err != nil { - log.Warn("读取正向Websocket配置失败 :", err) + log.Warn("读取反向Websocket配置失败 :", err) } else { go server.RunWebSocketClient(bot, rc) } @@ -417,7 +418,7 @@ func main() { if p, ok := m["pprof"]; ok { pc := new(config.PprofServer) if err := p.Decode(pc); err != nil { - log.Warn("读取反向Websocket配置失败 :", err) + log.Warn("读取pprof配置失败 :", err) } else { go server.RunPprofServer(pc) } @@ -484,46 +485,77 @@ func checkUpdate() { func selfUpdate(imageURL string) { log.Infof("正在检查更新.") - var res string + var res, r string if err := gout.GET("https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest").BindBody(&res).Do(); err != nil { log.Warnf("检查更新失败: %v", err) return } info := gjson.Parse(res) version := info.Get("tag_name").Str - if coolq.Version != version { - log.Info("当前最新版本为 ", version) - log.Warn("是否更新(y/N): ") - r := strings.TrimSpace(readLine()) - if r != "y" && r != "Y" { - log.Warn("已取消更新!") - } else { - log.Info("正在更新,请稍等...") - url := fmt.Sprintf( - "%v/Mrs4s/go-cqhttp/releases/download/%v/go-cqhttp_%v_%v", - func() string { - if imageURL != "" { - return imageURL - } - return "https://github.com" - }(), - version, runtime.GOOS, func() string { - if runtime.GOARCH == "arm" { - return "armv7" - } - return runtime.GOARCH - }(), - ) - if runtime.GOOS == "windows" { - url += ".zip" - } else { - url += ".tar.gz" - } - update.Update(url) - } - } else { + if coolq.Version == version { log.Info("当前版本已经是最新版本!") + goto wait } + log.Info("当前最新版本为 ", version) + log.Warn("是否更新(y/N): ") + r = strings.TrimSpace(readLine()) + if r != "y" && r != "Y" { + log.Warn("已取消更新!") + } else { + log.Info("正在更新,请稍等...") + sumURL := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/go-cqhttp_checksums.txt", + func() string { + if imageURL != "" { + return imageURL + } + return "https://github.com" + }(), version) + closer, err := global.HTTPGetReadCloser(sumURL) + if err != nil { + log.Error("更新失败: ", err) + goto wait + } + rd := bufio.NewReader(closer) + binaryName := fmt.Sprintf("go-cqhttp_%v_%v.%v", runtime.GOOS, func() string { + if runtime.GOARCH == "arm" { + return "armv7" + } + return runtime.GOARCH + }(), func() string { + if runtime.GOOS == "windows" { + return "zip" + } else { + return "tar.gz" + } + }()) + var sum []byte + for { + str, err := rd.ReadString('\n') + if err != nil { + break + } + str = strings.TrimSpace(str) + if strings.HasSuffix(str, binaryName) { + sum, _ = hex.DecodeString(strings.TrimSuffix(str, " "+binaryName)) + break + } + } + url := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/%v", + func() string { + if imageURL != "" { + return imageURL + } + return "https://github.com" + }(), version, binaryName) + + err = update.Update(url, sum) + if err != nil { + log.Error("更新失败: ", err) + } else { + log.Info("更新成功!") + } + } +wait: log.Info("按 Enter 继续....") readLine() os.Exit(0)