diff --git a/cmd/gocq/login.go b/cmd/gocq/login.go index d963d25..3fc3813 100644 --- a/cmd/gocq/login.go +++ b/cmd/gocq/login.go @@ -15,9 +15,9 @@ import ( "github.com/mattn/go-colorable" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/internal/download" ) var console = bufio.NewReader(os.Stdin) @@ -243,12 +243,11 @@ func getTicket(u string) (str string) { } func fetchCaptcha(id string) string { - data, err := global.GetBytes("https://captcha.go-cqhttp.org/captcha/ticket?id=" + id) + g, err := download.Request{URL: "https://captcha.go-cqhttp.org/captcha/ticket?id=" + id}.JSON() if err != nil { log.Warnf("获取 Ticket 时出现错误: %v", err) return "" } - g := gjson.ParseBytes(data) if g.Get("ticket").Exists() { return g.Get("ticket").String() } diff --git a/coolq/api.go b/coolq/api.go index 526389c..d7275ce 100644 --- a/coolq/api.go +++ b/coolq/api.go @@ -27,6 +27,7 @@ import ( "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/cache" + "github.com/Mrs4s/go-cqhttp/internal/download" "github.com/Mrs4s/go-cqhttp/internal/param" "github.com/Mrs4s/go-cqhttp/modules/filter" ) @@ -1545,12 +1546,8 @@ func (bot *CQBot) CQGetImage(file string) global.MSG { } local := path.Join(global.CachePath, file+path.Ext(msg["filename"].(string))) if !global.PathExists(local) { - if body, err := global.HTTPGetReadCloser(msg["url"].(string)); err == nil { - f, _ := os.OpenFile(local, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o0644) - _, _ = f.ReadFrom(body) - _ = body.Close() - _ = f.Close() - } else { + r := download.Request{URL: msg["url"].(string)} + if err := r.WriteToFile(local); err != nil { log.Warnf("下载图片 %v 时出现错误: %v", msg["url"], err) return Failed(100, "DOWNLOAD_IMAGE_ERROR", err.Error()) } @@ -1593,7 +1590,8 @@ func (bot *CQBot) CQDownloadFile(url string, headers gjson.Result, threadCount i return Failed(100, "DELETE_FILE_ERROR", err.Error()) } } - if err := global.DownloadFileMultiThreading(url, file, 0, threadCount, h); err != nil { + r := download.Request{URL: url, Header: h} + if err := r.WriteToFileMultiThreading(file, threadCount); err != nil { log.Warnf("下载链接 %v 时出现错误: %v", url, err) return Failed(100, "DOWNLOAD_FILE_ERROR", err.Error()) } diff --git a/coolq/cqcode.go b/coolq/cqcode.go index 4a3a4f3..c241e34 100644 --- a/coolq/cqcode.go +++ b/coolq/cqcode.go @@ -29,6 +29,7 @@ import ( "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/cache" + "github.com/Mrs4s/go-cqhttp/internal/download" "github.com/Mrs4s/go-cqhttp/internal/mime" "github.com/Mrs4s/go-cqhttp/internal/param" ) @@ -895,9 +896,9 @@ func (bot *CQBot) ToElement(t string, d map[string]string, sourceType message.So name := info.Get("track_info.name").Str mid := info.Get("track_info.mid").Str albumMid := info.Get("track_info.album.mid").Str - pinfo, _ := global.GetBytes("http://u.y.qq.com/cgi-bin/musicu.fcg?g_tk=2034008533&uin=0&format=json&data={\"comm\":{\"ct\":23,\"cv\":0},\"url_mid\":{\"module\":\"vkey.GetVkeyServer\",\"method\":\"CgiGetVkey\",\"param\":{\"guid\":\"4311206557\",\"songmid\":[\"" + mid + "\"],\"songtype\":[0],\"uin\":\"0\",\"loginflag\":1,\"platform\":\"23\"}}}&_=1599039471576") + pinfo, _ := download.Request{URL: "http://u.y.qq.com/cgi-bin/musicu.fcg?g_tk=2034008533&uin=0&format=json&data={\"comm\":{\"ct\":23,\"cv\":0},\"url_mid\":{\"module\":\"vkey.GetVkeyServer\",\"method\":\"CgiGetVkey\",\"param\":{\"guid\":\"4311206557\",\"songmid\":[\"" + mid + "\"],\"songtype\":[0],\"uin\":\"0\",\"loginflag\":1,\"platform\":\"23\"}}}&_=1599039471576"}.JSON() jumpURL := "https://i.y.qq.com/v8/playsong.html?platform=11&appshare=android_qq&appversion=10030010&hosteuin=oKnlNenz7i-s7c**&songmid=" + mid + "&type=0&appsongtype=1&_wv=1&source=qq&ADTAG=qfshare" - purl := gjson.ParseBytes(pinfo).Get("url_mid.data.midurlinfo.0.purl").Str + purl := pinfo.Get("url_mid.data.midurlinfo.0.purl").Str preview := "http://y.gtimg.cn/music/photo_new/T002R180x180M000" + albumMid + ".jpg" content := info.Get("track_info.singer.0.name").Str if d["content"] != "" { @@ -1089,8 +1090,11 @@ func (bot *CQBot) makeImageOrVideoElem(d map[string]string, video bool, sourceTy if exist { _ = os.Remove(cacheFile) } - if err := global.DownloadFileMultiThreading(f, cacheFile, maxSize, thread, nil); err != nil { - return nil, err + { + r := download.Request{URL: f, Limit: maxSize} + if err := r.WriteToFileMultiThreading(cacheFile, thread); err != nil { + return nil, err + } } useCacheFile: if video { diff --git a/coolq/event.go b/coolq/event.go index ebdc960..34500e2 100644 --- a/coolq/event.go +++ b/coolq/event.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "os" "path" "strconv" "strings" @@ -18,6 +17,7 @@ import ( "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/cache" + "github.com/Mrs4s/go-cqhttp/internal/download" ) // ToFormattedMessage 将给定[]message.IMessageElement转换为通过coolq.SetMessageFormat所定义的消息上报格式 @@ -666,7 +666,8 @@ func (bot *CQBot) checkMedia(e []message.IMessageElement, sourceID int64) { filename := hex.EncodeToString(i.Md5) + ".image" cache.Image.Insert(i.Md5, data) if i.Url != "" && !global.PathExists(path.Join(global.ImagePath, "guild-images", filename)) { - if err := global.DownloadFile(i.Url, path.Join(global.ImagePath, "guild-images", filename), -1, nil); err != nil { + r := download.Request{URL: i.Url} + if err := r.WriteToFile(path.Join(global.ImagePath, "guild-images", filename)); err != nil { log.Warnf("下载频道图片时出现错误: %v", err) } } @@ -684,12 +685,11 @@ func (bot *CQBot) checkMedia(e []message.IMessageElement, sourceID int64) { i.Name = strings.ReplaceAll(i.Name, "{", "") i.Name = strings.ReplaceAll(i.Name, "}", "") if !global.PathExists(path.Join(global.VoicePath, i.Name)) { - b, err := global.GetBytes(i.Url) + err := download.Request{URL: i.Url}.WriteToFile(path.Join(global.VoicePath, i.Name)) if err != nil { log.Warnf("语音文件 %v 下载失败: %v", i.Name, err) continue } - _ = os.WriteFile(path.Join(global.VoicePath, i.Name), b, 0o644) } case *message.ShortVideoElement: data := binary.NewWriterF(func(w *binary.Writer) { diff --git a/global/fs.go b/global/fs.go index f985dd5..046ac2c 100644 --- a/global/fs.go +++ b/global/fs.go @@ -16,6 +16,8 @@ import ( b14 "github.com/fumiama/go-base16384" "github.com/segmentio/asm/base64" log "github.com/sirupsen/logrus" + + "github.com/Mrs4s/go-cqhttp/internal/download" ) const ( @@ -82,8 +84,7 @@ func FindFile(file, cache, p string) (data []byte, err error) { if (cache == "" || cache == "1") && PathExists(cacheFile) { return os.ReadFile(cacheFile) } - data, err = GetBytes(file) - _ = os.WriteFile(cacheFile, data, 0o644) + err = download.Request{URL: file}.WriteToFile(cacheFile) if err != nil { return nil, err } diff --git a/global/net.go b/global/net.go index 0327a7e..030cbea 100644 --- a/global/net.go +++ b/global/net.go @@ -1,304 +1,27 @@ package global import ( - "bufio" - "compress/gzip" "fmt" - "io" - "net/http" - "net/url" - "os" - "strconv" - "strings" - "sync" - - "github.com/pkg/errors" "github.com/tidwall/gjson" - "github.com/Mrs4s/go-cqhttp/internal/base" + "github.com/Mrs4s/go-cqhttp/internal/download" ) -var ( - client = &http.Client{ - Transport: &http.Transport{ - Proxy: func(request *http.Request) (u *url.URL, e error) { - if base.Proxy == "" { - return http.ProxyFromEnvironment(request) - } - return url.Parse(base.Proxy) - }, - ForceAttemptHTTP2: true, - MaxConnsPerHost: 0, - MaxIdleConns: 0, - MaxIdleConnsPerHost: 999, - }, - } - - // ErrOverSize 响应主体过大时返回此错误 - ErrOverSize = errors.New("oversize") - - // UserAgent HTTP请求时使用的UA - UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.66" -) - -// GetBytes 对给定URL发送Get请求,返回响应主体 -func GetBytes(url string) ([]byte, error) { - reader, err := HTTPGetReadCloser(url) - if err != nil { - return nil, err - } - defer func() { - _ = reader.Close() - }() - return io.ReadAll(reader) -} - -// DownloadFile 将给定URL对应的文件下载至给定Path -func DownloadFile(url, path string, limit int64, headers map[string]string) error { - file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666) - if err != nil { - return err - } - defer file.Close() - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return err - } - - for k, v := range headers { - req.Header.Set(k, v) - } - - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if limit > 0 && resp.ContentLength > limit { - return ErrOverSize - } - _, err = file.ReadFrom(resp.Body) - if err != nil { - return err - } - return nil -} - -// DownloadFileMultiThreading 使用threadCount个线程将给定URL对应的文件下载至给定Path -func DownloadFileMultiThreading(url, path string, limit int64, threadCount int, headers map[string]string) error { - if threadCount < 2 { - return DownloadFile(url, path, limit, headers) - } - type BlockMetaData struct { - BeginOffset int64 - EndOffset int64 - DownloadedSize int64 - } - var blocks []*BlockMetaData - var contentLength int64 - errUnsupportedMultiThreading := errors.New("unsupported multi-threading") - // 初始化分块或直接下载 - initOrDownload := func() error { - copyStream := func(s io.ReadCloser) error { - file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666) - if err != nil { - return err - } - defer file.Close() - if _, err = file.ReadFrom(s); err != nil { - return err - } - return errUnsupportedMultiThreading - } - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return err - } - - for k, v := range headers { - req.Header.Set(k, v) - } - if _, ok := headers["User-Agent"]; !ok { - req.Header["User-Agent"] = []string{UserAgent} - } - req.Header.Set("range", "bytes=0-") - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) - } - if resp.StatusCode == http.StatusOK { - if limit > 0 && resp.ContentLength > limit { - return ErrOverSize - } - return copyStream(resp.Body) - } - if resp.StatusCode == http.StatusPartialContent { - contentLength = resp.ContentLength - if limit > 0 && resp.ContentLength > limit { - return ErrOverSize - } - blockSize := contentLength - if contentLength > 1024*1024 { - blockSize = (contentLength / int64(threadCount)) - 10 - } - if blockSize == contentLength { - return copyStream(resp.Body) - } - var tmp int64 - for tmp+blockSize < contentLength { - blocks = append(blocks, &BlockMetaData{ - BeginOffset: tmp, - EndOffset: tmp + blockSize - 1, - }) - tmp += blockSize - } - blocks = append(blocks, &BlockMetaData{ - BeginOffset: tmp, - EndOffset: contentLength - 1, - }) - return nil - } - return errors.New("unknown status code") - } - // 下载分块 - downloadBlock := func(block *BlockMetaData) error { - req, _ := http.NewRequest(http.MethodGet, url, nil) - file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666) - if err != nil { - return err - } - defer file.Close() - _, _ = file.Seek(block.BeginOffset, io.SeekStart) - writer := bufio.NewWriter(file) - defer writer.Flush() - - for k, v := range headers { - req.Header.Set(k, v) - } - - if _, ok := headers["User-Agent"]; !ok { - req.Header["User-Agent"] = []string{UserAgent} - } - req.Header.Set("range", "bytes="+strconv.FormatInt(block.BeginOffset, 10)+"-"+strconv.FormatInt(block.EndOffset, 10)) - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) - } - buffer := make([]byte, 1024) - i, err := resp.Body.Read(buffer) - for { - if err != nil && err != io.EOF { - return err - } - i64 := int64(len(buffer[:i])) - needSize := block.EndOffset + 1 - block.BeginOffset - if i64 > needSize { - i64 = needSize - err = io.EOF - } - _, e := writer.Write(buffer[:i64]) - if e != nil { - return e - } - block.BeginOffset += i64 - block.DownloadedSize += i64 - if err == io.EOF || block.BeginOffset > block.EndOffset { - break - } - i, err = resp.Body.Read(buffer) - } - return nil - } - - if err := initOrDownload(); err != nil { - if err == errUnsupportedMultiThreading { - return nil - } - return err - } - wg := sync.WaitGroup{} - wg.Add(len(blocks)) - var lastErr error - for i := range blocks { - go func(b *BlockMetaData) { - defer wg.Done() - if err := downloadBlock(b); err != nil { - lastErr = err - } - }(blocks[i]) - } - wg.Wait() - return lastErr -} - // QQMusicSongInfo 通过给定id在QQ音乐上查找曲目信息 func QQMusicSongInfo(id string) (gjson.Result, error) { - d, err := GetBytes(`https://u.y.qq.com/cgi-bin/musicu.fcg?format=json&inCharset=utf8&outCharset=utf-8¬ice=0&platform=yqq.json&needNewCode=0&data={%22comm%22:{%22ct%22:24,%22cv%22:0},%22songinfo%22:{%22method%22:%22get_song_detail_yqq%22,%22param%22:{%22song_type%22:0,%22song_mid%22:%22%22,%22song_id%22:` + id + `},%22module%22:%22music.pf_song_detail_svr%22}}`) + d, err := download.Request{URL: `https://u.y.qq.com/cgi-bin/musicu.fcg?format=json&inCharset=utf8&outCharset=utf-8¬ice=0&platform=yqq.json&needNewCode=0&data={%22comm%22:{%22ct%22:24,%22cv%22:0},%22songinfo%22:{%22method%22:%22get_song_detail_yqq%22,%22param%22:{%22song_type%22:0,%22song_mid%22:%22%22,%22song_id%22:` + id + `},%22module%22:%22music.pf_song_detail_svr%22}}`}.JSON() if err != nil { return gjson.Result{}, err } - return gjson.ParseBytes(d).Get("songinfo.data"), nil + return d.Get("songinfo.data"), nil } // NeteaseMusicSongInfo 通过给定id在wdd音乐上查找曲目信息 func NeteaseMusicSongInfo(id string) (gjson.Result, error) { - d, err := GetBytes(fmt.Sprintf("http://music.163.com/api/song/detail/?id=%s&ids=%%5B%s%%5D", id, id)) + d, err := download.Request{URL: fmt.Sprintf("http://music.163.com/api/song/detail/?id=%s&ids=%%5B%s%%5D", id, id)}.JSON() if err != nil { return gjson.Result{}, err } - return gjson.ParseBytes(d).Get("songs.0"), nil -} - -type gzipCloser struct { - f io.Closer - r *gzip.Reader -} - -// NewGzipReadCloser 从 io.ReadCloser 创建 gunzip io.ReadCloser -func NewGzipReadCloser(reader io.ReadCloser) (io.ReadCloser, error) { - gzipReader, err := gzip.NewReader(reader) - if err != nil { - return nil, err - } - return &gzipCloser{ - f: reader, - r: gzipReader, - }, nil -} - -// Read impls io.Reader -func (g *gzipCloser) Read(p []byte) (n int, err error) { - return g.r.Read(p) -} - -// Close impls io.Closer -func (g *gzipCloser) Close() error { - _ = g.f.Close() - return g.r.Close() -} - -// HTTPGetReadCloser 从 Http url 获取 io.ReadCloser -func HTTPGetReadCloser(url string) (io.ReadCloser, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, err - } - req.Header["User-Agent"] = []string{UserAgent} - resp, err := client.Do(req) - if err != nil { - return nil, err - } - if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") { - return NewGzipReadCloser(resp.Body) - } - return resp.Body, err + return d.Get("songs.0"), nil } diff --git a/internal/download/download.go b/internal/download/download.go new file mode 100644 index 0000000..b44bb6d --- /dev/null +++ b/internal/download/download.go @@ -0,0 +1,295 @@ +package download + +import ( + "bufio" + "compress/gzip" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + + "github.com/pkg/errors" + "github.com/tidwall/gjson" + + "github.com/Mrs4s/go-cqhttp/internal/base" +) + +var client = &http.Client{ + Transport: &http.Transport{ + Proxy: func(request *http.Request) (u *url.URL, e error) { + if base.Proxy == "" { + return http.ProxyFromEnvironment(request) + } + return url.Parse(base.Proxy) + }, + ForceAttemptHTTP2: false, + MaxConnsPerHost: 0, + MaxIdleConns: 0, + MaxIdleConnsPerHost: 999, + }, +} + +// ErrOverSize 响应主体过大时返回此错误 +var ErrOverSize = errors.New("oversize") + +// UserAgent HTTP请求时使用的UA +const UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.66" + +type Request struct { + URL string + Header map[string]string + Limit int64 +} + +func (r Request) do() (*http.Response, error) { + req, err := http.NewRequest(http.MethodGet, r.URL, nil) + if err != nil { + return nil, err + } + + req.Header["User-Agent"] = []string{UserAgent} + for k, v := range r.Header { + req.Header.Set(k, v) + } + + return client.Do(req) +} + +func (r Request) body() (io.ReadCloser, error) { + resp, err := r.do() + if err != nil { + return nil, err + } + + limit := r.Limit // check file size limit + if limit > 0 && resp.ContentLength > limit { + _ = resp.Body.Close() + return nil, ErrOverSize + } + + if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") { + return gzipReadCloser(resp.Body) + } + return resp.Body, err +} + +// Bytes 对给定URL发送Get请求,返回响应主体 +func (r Request) Bytes() ([]byte, error) { + rd, err := r.body() + if err != nil { + return nil, err + } + defer rd.Close() + return io.ReadAll(rd) +} + +func (r Request) JSON() (gjson.Result, error) { + rd, err := r.body() + if err != nil { + return gjson.Result{}, err + } + defer rd.Close() + + var sb strings.Builder + _, err = io.Copy(&sb, rd) + if err != nil { + return gjson.Result{}, err + } + + return gjson.Parse(sb.String()), nil +} + +func writeToFile(reader io.ReadCloser, path string) error { + file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o644) + if err != nil { + return err + } + _, err = file.ReadFrom(reader) + return err +} + +func (r Request) WriteToFile(path string) error { + rd, err := r.body() + if err != nil { + return err + } + defer rd.Close() + return writeToFile(rd, path) +} + +func (r Request) WriteToFileMultiThreading(path string, thread int) error { + if thread < 2 { + return r.WriteToFile(path) + } + + limit := r.Limit + type BlockMetaData struct { + BeginOffset int64 + EndOffset int64 + DownloadedSize int64 + } + var blocks []*BlockMetaData + var contentLength int64 + errUnsupportedMultiThreading := errors.New("unsupported multi-threading") + // 初始化分块或直接下载 + initOrDownload := func() error { + header := make(map[string]string, len(r.Header)) + for k, v := range r.Header { // copy headers + header[k] = v + } + header["range"] = "bytes=0-" + req := Request{ + URL: r.URL, + Header: header, + } + resp, err := req.do() + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) + } + if resp.StatusCode == http.StatusOK { + if limit > 0 && resp.ContentLength > limit { + return ErrOverSize + } + if err = writeToFile(resp.Body, path); err != nil { + return err + } + return errUnsupportedMultiThreading + } + if resp.StatusCode == http.StatusPartialContent { + contentLength = resp.ContentLength + if limit > 0 && resp.ContentLength > limit { + return ErrOverSize + } + blockSize := contentLength + if contentLength > 1024*1024 { + blockSize = (contentLength / int64(thread)) - 10 + } + if blockSize == contentLength { + return writeToFile(resp.Body, path) + } + var tmp int64 + for tmp+blockSize < contentLength { + blocks = append(blocks, &BlockMetaData{ + BeginOffset: tmp, + EndOffset: tmp + blockSize - 1, + }) + tmp += blockSize + } + blocks = append(blocks, &BlockMetaData{ + BeginOffset: tmp, + EndOffset: contentLength - 1, + }) + return nil + } + return errors.New("unknown status code") + } + // 下载分块 + downloadBlock := func(block *BlockMetaData) error { + file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666) + if err != nil { + return err + } + defer file.Close() + _, _ = file.Seek(block.BeginOffset, io.SeekStart) + writer := bufio.NewWriter(file) + defer writer.Flush() + + header := make(map[string]string, len(r.Header)) + for k, v := range r.Header { // copy headers + header[k] = v + } + header["range"] = fmt.Sprintf("bytes=%d-%d", block.BeginOffset, block.EndOffset) + req := Request{ + URL: r.URL, + Header: header, + } + resp, err := req.do() + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) + } + buffer := make([]byte, 1024) + i, err := resp.Body.Read(buffer) + for { + if err != nil && err != io.EOF { + return err + } + i64 := int64(len(buffer[:i])) + needSize := block.EndOffset + 1 - block.BeginOffset + if i64 > needSize { + i64 = needSize + err = io.EOF + } + _, e := writer.Write(buffer[:i64]) + if e != nil { + return e + } + block.BeginOffset += i64 + block.DownloadedSize += i64 + if err == io.EOF || block.BeginOffset > block.EndOffset { + break + } + i, err = resp.Body.Read(buffer) + } + return nil + } + + if err := initOrDownload(); err != nil { + if err == errUnsupportedMultiThreading { + return nil + } + return err + } + wg := sync.WaitGroup{} + wg.Add(len(blocks)) + var lastErr error + for i := range blocks { + go func(b *BlockMetaData) { + defer wg.Done() + if err := downloadBlock(b); err != nil { + lastErr = err + } + }(blocks[i]) + } + wg.Wait() + return lastErr +} + +type gzipCloser struct { + f io.Closer + r *gzip.Reader +} + +// gzipReadCloser 从 io.ReadCloser 创建 gunzip io.ReadCloser +func gzipReadCloser(reader io.ReadCloser) (io.ReadCloser, error) { + gzipReader, err := gzip.NewReader(reader) + if err != nil { + return nil, err + } + return &gzipCloser{ + f: reader, + r: gzipReader, + }, nil +} + +// Read impls io.Reader +func (g *gzipCloser) Read(p []byte) (n int, err error) { + return g.r.Read(p) +} + +// Close impls io.Closer +func (g *gzipCloser) Close() error { + _ = g.f.Close() + return g.r.Close() +} diff --git a/internal/selfupdate/update.go b/internal/selfupdate/update.go index 8771502..108e5fc 100644 --- a/internal/selfupdate/update.go +++ b/internal/selfupdate/update.go @@ -3,6 +3,7 @@ package selfupdate import ( "bufio" + "bytes" "encoding/hex" "fmt" "hash" @@ -14,10 +15,10 @@ import ( "strings" "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" + "github.com/Mrs4s/go-cqhttp/internal/download" ) func readLine() (str string) { @@ -28,11 +29,11 @@ func readLine() (str string) { } func lastVersion() (string, error) { - r, err := global.GetBytes("https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest") + r, err := download.Request{URL: "https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest"}.JSON() if err != nil { return "", err } - return gjson.GetBytes(r, "tag_name").Str, nil + return r.Get("tag_name").Str, nil } // CheckUpdate 检查更新 @@ -69,12 +70,12 @@ func binaryName() string { func checksum(github, version string) []byte { sumURL := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/go-cqhttp_checksums.txt", github, version) - closer, err := global.HTTPGetReadCloser(sumURL) + sum, err := download.Request{URL: sumURL}.Bytes() if err != nil { return nil } - rd := bufio.NewReader(closer) + rd := bufio.NewReader(bytes.NewReader(sum)) for { str, err := rd.ReadString('\n') if err != nil {