1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-04 19:17:37 +08:00

internal/download: move download logic to this package

also disable http2 when downloading, may fix some issue

For #1563
This commit is contained in:
wdvxdr 2022-10-13 21:27:36 +08:00
parent 1dd12df26a
commit f4117bfb70
8 changed files with 328 additions and 307 deletions

View File

@ -15,9 +15,9 @@ import (
"github.com/mattn/go-colorable" "github.com/mattn/go-colorable"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/download"
) )
var console = bufio.NewReader(os.Stdin) var console = bufio.NewReader(os.Stdin)
@ -243,12 +243,11 @@ func getTicket(u string) (str string) {
} }
func fetchCaptcha(id string) string { func fetchCaptcha(id string) string {
data, err := global.GetBytes("https://captcha.go-cqhttp.org/captcha/ticket?id=" + id) g, err := download.Request{URL: "https://captcha.go-cqhttp.org/captcha/ticket?id=" + id}.JSON()
if err != nil { if err != nil {
log.Warnf("获取 Ticket 时出现错误: %v", err) log.Warnf("获取 Ticket 时出现错误: %v", err)
return "" return ""
} }
g := gjson.ParseBytes(data)
if g.Get("ticket").Exists() { if g.Get("ticket").Exists() {
return g.Get("ticket").String() return g.Get("ticket").String()
} }

View File

@ -27,6 +27,7 @@ import (
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/Mrs4s/go-cqhttp/internal/cache" "github.com/Mrs4s/go-cqhttp/internal/cache"
"github.com/Mrs4s/go-cqhttp/internal/download"
"github.com/Mrs4s/go-cqhttp/internal/param" "github.com/Mrs4s/go-cqhttp/internal/param"
"github.com/Mrs4s/go-cqhttp/modules/filter" "github.com/Mrs4s/go-cqhttp/modules/filter"
) )
@ -1545,12 +1546,8 @@ func (bot *CQBot) CQGetImage(file string) global.MSG {
} }
local := path.Join(global.CachePath, file+path.Ext(msg["filename"].(string))) local := path.Join(global.CachePath, file+path.Ext(msg["filename"].(string)))
if !global.PathExists(local) { if !global.PathExists(local) {
if body, err := global.HTTPGetReadCloser(msg["url"].(string)); err == nil { r := download.Request{URL: msg["url"].(string)}
f, _ := os.OpenFile(local, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o0644) if err := r.WriteToFile(local); err != nil {
_, _ = f.ReadFrom(body)
_ = body.Close()
_ = f.Close()
} else {
log.Warnf("下载图片 %v 时出现错误: %v", msg["url"], err) log.Warnf("下载图片 %v 时出现错误: %v", msg["url"], err)
return Failed(100, "DOWNLOAD_IMAGE_ERROR", err.Error()) return Failed(100, "DOWNLOAD_IMAGE_ERROR", err.Error())
} }
@ -1593,7 +1590,8 @@ func (bot *CQBot) CQDownloadFile(url string, headers gjson.Result, threadCount i
return Failed(100, "DELETE_FILE_ERROR", err.Error()) return Failed(100, "DELETE_FILE_ERROR", err.Error())
} }
} }
if err := global.DownloadFileMultiThreading(url, file, 0, threadCount, h); err != nil { r := download.Request{URL: url, Header: h}
if err := r.WriteToFileMultiThreading(file, threadCount); err != nil {
log.Warnf("下载链接 %v 时出现错误: %v", url, err) log.Warnf("下载链接 %v 时出现错误: %v", url, err)
return Failed(100, "DOWNLOAD_FILE_ERROR", err.Error()) return Failed(100, "DOWNLOAD_FILE_ERROR", err.Error())
} }

View File

@ -29,6 +29,7 @@ import (
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/Mrs4s/go-cqhttp/internal/cache" "github.com/Mrs4s/go-cqhttp/internal/cache"
"github.com/Mrs4s/go-cqhttp/internal/download"
"github.com/Mrs4s/go-cqhttp/internal/mime" "github.com/Mrs4s/go-cqhttp/internal/mime"
"github.com/Mrs4s/go-cqhttp/internal/param" "github.com/Mrs4s/go-cqhttp/internal/param"
) )
@ -895,9 +896,9 @@ func (bot *CQBot) ToElement(t string, d map[string]string, sourceType message.So
name := info.Get("track_info.name").Str name := info.Get("track_info.name").Str
mid := info.Get("track_info.mid").Str mid := info.Get("track_info.mid").Str
albumMid := info.Get("track_info.album.mid").Str albumMid := info.Get("track_info.album.mid").Str
pinfo, _ := global.GetBytes("http://u.y.qq.com/cgi-bin/musicu.fcg?g_tk=2034008533&uin=0&format=json&data={\"comm\":{\"ct\":23,\"cv\":0},\"url_mid\":{\"module\":\"vkey.GetVkeyServer\",\"method\":\"CgiGetVkey\",\"param\":{\"guid\":\"4311206557\",\"songmid\":[\"" + mid + "\"],\"songtype\":[0],\"uin\":\"0\",\"loginflag\":1,\"platform\":\"23\"}}}&_=1599039471576") pinfo, _ := download.Request{URL: "http://u.y.qq.com/cgi-bin/musicu.fcg?g_tk=2034008533&uin=0&format=json&data={\"comm\":{\"ct\":23,\"cv\":0},\"url_mid\":{\"module\":\"vkey.GetVkeyServer\",\"method\":\"CgiGetVkey\",\"param\":{\"guid\":\"4311206557\",\"songmid\":[\"" + mid + "\"],\"songtype\":[0],\"uin\":\"0\",\"loginflag\":1,\"platform\":\"23\"}}}&_=1599039471576"}.JSON()
jumpURL := "https://i.y.qq.com/v8/playsong.html?platform=11&appshare=android_qq&appversion=10030010&hosteuin=oKnlNenz7i-s7c**&songmid=" + mid + "&type=0&appsongtype=1&_wv=1&source=qq&ADTAG=qfshare" jumpURL := "https://i.y.qq.com/v8/playsong.html?platform=11&appshare=android_qq&appversion=10030010&hosteuin=oKnlNenz7i-s7c**&songmid=" + mid + "&type=0&appsongtype=1&_wv=1&source=qq&ADTAG=qfshare"
purl := gjson.ParseBytes(pinfo).Get("url_mid.data.midurlinfo.0.purl").Str purl := pinfo.Get("url_mid.data.midurlinfo.0.purl").Str
preview := "http://y.gtimg.cn/music/photo_new/T002R180x180M000" + albumMid + ".jpg" preview := "http://y.gtimg.cn/music/photo_new/T002R180x180M000" + albumMid + ".jpg"
content := info.Get("track_info.singer.0.name").Str content := info.Get("track_info.singer.0.name").Str
if d["content"] != "" { if d["content"] != "" {
@ -1089,9 +1090,12 @@ func (bot *CQBot) makeImageOrVideoElem(d map[string]string, video bool, sourceTy
if exist { if exist {
_ = os.Remove(cacheFile) _ = os.Remove(cacheFile)
} }
if err := global.DownloadFileMultiThreading(f, cacheFile, maxSize, thread, nil); err != nil { {
r := download.Request{URL: f, Limit: maxSize}
if err := r.WriteToFileMultiThreading(cacheFile, thread); err != nil {
return nil, err return nil, err
} }
}
useCacheFile: useCacheFile:
if video { if video {
return &LocalVideoElement{File: cacheFile}, nil return &LocalVideoElement{File: cacheFile}, nil

View File

@ -4,7 +4,6 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os"
"path" "path"
"strconv" "strconv"
"strings" "strings"
@ -18,6 +17,7 @@ import (
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/Mrs4s/go-cqhttp/internal/cache" "github.com/Mrs4s/go-cqhttp/internal/cache"
"github.com/Mrs4s/go-cqhttp/internal/download"
) )
// ToFormattedMessage 将给定[]message.IMessageElement转换为通过coolq.SetMessageFormat所定义的消息上报格式 // ToFormattedMessage 将给定[]message.IMessageElement转换为通过coolq.SetMessageFormat所定义的消息上报格式
@ -666,7 +666,8 @@ func (bot *CQBot) checkMedia(e []message.IMessageElement, sourceID int64) {
filename := hex.EncodeToString(i.Md5) + ".image" filename := hex.EncodeToString(i.Md5) + ".image"
cache.Image.Insert(i.Md5, data) cache.Image.Insert(i.Md5, data)
if i.Url != "" && !global.PathExists(path.Join(global.ImagePath, "guild-images", filename)) { if i.Url != "" && !global.PathExists(path.Join(global.ImagePath, "guild-images", filename)) {
if err := global.DownloadFile(i.Url, path.Join(global.ImagePath, "guild-images", filename), -1, nil); err != nil { r := download.Request{URL: i.Url}
if err := r.WriteToFile(path.Join(global.ImagePath, "guild-images", filename)); err != nil {
log.Warnf("下载频道图片时出现错误: %v", err) log.Warnf("下载频道图片时出现错误: %v", err)
} }
} }
@ -684,12 +685,11 @@ func (bot *CQBot) checkMedia(e []message.IMessageElement, sourceID int64) {
i.Name = strings.ReplaceAll(i.Name, "{", "") i.Name = strings.ReplaceAll(i.Name, "{", "")
i.Name = strings.ReplaceAll(i.Name, "}", "") i.Name = strings.ReplaceAll(i.Name, "}", "")
if !global.PathExists(path.Join(global.VoicePath, i.Name)) { if !global.PathExists(path.Join(global.VoicePath, i.Name)) {
b, err := global.GetBytes(i.Url) err := download.Request{URL: i.Url}.WriteToFile(path.Join(global.VoicePath, i.Name))
if err != nil { if err != nil {
log.Warnf("语音文件 %v 下载失败: %v", i.Name, err) log.Warnf("语音文件 %v 下载失败: %v", i.Name, err)
continue continue
} }
_ = os.WriteFile(path.Join(global.VoicePath, i.Name), b, 0o644)
} }
case *message.ShortVideoElement: case *message.ShortVideoElement:
data := binary.NewWriterF(func(w *binary.Writer) { data := binary.NewWriterF(func(w *binary.Writer) {

View File

@ -16,6 +16,8 @@ import (
b14 "github.com/fumiama/go-base16384" b14 "github.com/fumiama/go-base16384"
"github.com/segmentio/asm/base64" "github.com/segmentio/asm/base64"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/Mrs4s/go-cqhttp/internal/download"
) )
const ( const (
@ -82,8 +84,7 @@ func FindFile(file, cache, p string) (data []byte, err error) {
if (cache == "" || cache == "1") && PathExists(cacheFile) { if (cache == "" || cache == "1") && PathExists(cacheFile) {
return os.ReadFile(cacheFile) return os.ReadFile(cacheFile)
} }
data, err = GetBytes(file) err = download.Request{URL: file}.WriteToFile(cacheFile)
_ = os.WriteFile(cacheFile, data, 0o644)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,304 +1,27 @@
package global package global
import ( import (
"bufio"
"compress/gzip"
"fmt" "fmt"
"io"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"github.com/pkg/errors"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/download"
) )
var (
client = &http.Client{
Transport: &http.Transport{
Proxy: func(request *http.Request) (u *url.URL, e error) {
if base.Proxy == "" {
return http.ProxyFromEnvironment(request)
}
return url.Parse(base.Proxy)
},
ForceAttemptHTTP2: true,
MaxConnsPerHost: 0,
MaxIdleConns: 0,
MaxIdleConnsPerHost: 999,
},
}
// ErrOverSize 响应主体过大时返回此错误
ErrOverSize = errors.New("oversize")
// UserAgent HTTP请求时使用的UA
UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.66"
)
// GetBytes 对给定URL发送Get请求返回响应主体
func GetBytes(url string) ([]byte, error) {
reader, err := HTTPGetReadCloser(url)
if err != nil {
return nil, err
}
defer func() {
_ = reader.Close()
}()
return io.ReadAll(reader)
}
// DownloadFile 将给定URL对应的文件下载至给定Path
func DownloadFile(url, path string, limit int64, headers map[string]string) error {
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666)
if err != nil {
return err
}
defer file.Close()
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return err
}
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if limit > 0 && resp.ContentLength > limit {
return ErrOverSize
}
_, err = file.ReadFrom(resp.Body)
if err != nil {
return err
}
return nil
}
// DownloadFileMultiThreading 使用threadCount个线程将给定URL对应的文件下载至给定Path
func DownloadFileMultiThreading(url, path string, limit int64, threadCount int, headers map[string]string) error {
if threadCount < 2 {
return DownloadFile(url, path, limit, headers)
}
type BlockMetaData struct {
BeginOffset int64
EndOffset int64
DownloadedSize int64
}
var blocks []*BlockMetaData
var contentLength int64
errUnsupportedMultiThreading := errors.New("unsupported multi-threading")
// 初始化分块或直接下载
initOrDownload := func() error {
copyStream := func(s io.ReadCloser) error {
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666)
if err != nil {
return err
}
defer file.Close()
if _, err = file.ReadFrom(s); err != nil {
return err
}
return errUnsupportedMultiThreading
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return err
}
for k, v := range headers {
req.Header.Set(k, v)
}
if _, ok := headers["User-Agent"]; !ok {
req.Header["User-Agent"] = []string{UserAgent}
}
req.Header.Set("range", "bytes=0-")
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10))
}
if resp.StatusCode == http.StatusOK {
if limit > 0 && resp.ContentLength > limit {
return ErrOverSize
}
return copyStream(resp.Body)
}
if resp.StatusCode == http.StatusPartialContent {
contentLength = resp.ContentLength
if limit > 0 && resp.ContentLength > limit {
return ErrOverSize
}
blockSize := contentLength
if contentLength > 1024*1024 {
blockSize = (contentLength / int64(threadCount)) - 10
}
if blockSize == contentLength {
return copyStream(resp.Body)
}
var tmp int64
for tmp+blockSize < contentLength {
blocks = append(blocks, &BlockMetaData{
BeginOffset: tmp,
EndOffset: tmp + blockSize - 1,
})
tmp += blockSize
}
blocks = append(blocks, &BlockMetaData{
BeginOffset: tmp,
EndOffset: contentLength - 1,
})
return nil
}
return errors.New("unknown status code")
}
// 下载分块
downloadBlock := func(block *BlockMetaData) error {
req, _ := http.NewRequest(http.MethodGet, url, nil)
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666)
if err != nil {
return err
}
defer file.Close()
_, _ = file.Seek(block.BeginOffset, io.SeekStart)
writer := bufio.NewWriter(file)
defer writer.Flush()
for k, v := range headers {
req.Header.Set(k, v)
}
if _, ok := headers["User-Agent"]; !ok {
req.Header["User-Agent"] = []string{UserAgent}
}
req.Header.Set("range", "bytes="+strconv.FormatInt(block.BeginOffset, 10)+"-"+strconv.FormatInt(block.EndOffset, 10))
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10))
}
buffer := make([]byte, 1024)
i, err := resp.Body.Read(buffer)
for {
if err != nil && err != io.EOF {
return err
}
i64 := int64(len(buffer[:i]))
needSize := block.EndOffset + 1 - block.BeginOffset
if i64 > needSize {
i64 = needSize
err = io.EOF
}
_, e := writer.Write(buffer[:i64])
if e != nil {
return e
}
block.BeginOffset += i64
block.DownloadedSize += i64
if err == io.EOF || block.BeginOffset > block.EndOffset {
break
}
i, err = resp.Body.Read(buffer)
}
return nil
}
if err := initOrDownload(); err != nil {
if err == errUnsupportedMultiThreading {
return nil
}
return err
}
wg := sync.WaitGroup{}
wg.Add(len(blocks))
var lastErr error
for i := range blocks {
go func(b *BlockMetaData) {
defer wg.Done()
if err := downloadBlock(b); err != nil {
lastErr = err
}
}(blocks[i])
}
wg.Wait()
return lastErr
}
// QQMusicSongInfo 通过给定id在QQ音乐上查找曲目信息 // QQMusicSongInfo 通过给定id在QQ音乐上查找曲目信息
func QQMusicSongInfo(id string) (gjson.Result, error) { func QQMusicSongInfo(id string) (gjson.Result, error) {
d, err := GetBytes(`https://u.y.qq.com/cgi-bin/musicu.fcg?format=json&inCharset=utf8&outCharset=utf-8&notice=0&platform=yqq.json&needNewCode=0&data={%22comm%22:{%22ct%22:24,%22cv%22:0},%22songinfo%22:{%22method%22:%22get_song_detail_yqq%22,%22param%22:{%22song_type%22:0,%22song_mid%22:%22%22,%22song_id%22:` + id + `},%22module%22:%22music.pf_song_detail_svr%22}}`) d, err := download.Request{URL: `https://u.y.qq.com/cgi-bin/musicu.fcg?format=json&inCharset=utf8&outCharset=utf-8&notice=0&platform=yqq.json&needNewCode=0&data={%22comm%22:{%22ct%22:24,%22cv%22:0},%22songinfo%22:{%22method%22:%22get_song_detail_yqq%22,%22param%22:{%22song_type%22:0,%22song_mid%22:%22%22,%22song_id%22:` + id + `},%22module%22:%22music.pf_song_detail_svr%22}}`}.JSON()
if err != nil { if err != nil {
return gjson.Result{}, err return gjson.Result{}, err
} }
return gjson.ParseBytes(d).Get("songinfo.data"), nil return d.Get("songinfo.data"), nil
} }
// NeteaseMusicSongInfo 通过给定id在wdd音乐上查找曲目信息 // NeteaseMusicSongInfo 通过给定id在wdd音乐上查找曲目信息
func NeteaseMusicSongInfo(id string) (gjson.Result, error) { func NeteaseMusicSongInfo(id string) (gjson.Result, error) {
d, err := GetBytes(fmt.Sprintf("http://music.163.com/api/song/detail/?id=%s&ids=%%5B%s%%5D", id, id)) d, err := download.Request{URL: fmt.Sprintf("http://music.163.com/api/song/detail/?id=%s&ids=%%5B%s%%5D", id, id)}.JSON()
if err != nil { if err != nil {
return gjson.Result{}, err return gjson.Result{}, err
} }
return gjson.ParseBytes(d).Get("songs.0"), nil return d.Get("songs.0"), nil
}
type gzipCloser struct {
f io.Closer
r *gzip.Reader
}
// NewGzipReadCloser 从 io.ReadCloser 创建 gunzip io.ReadCloser
func NewGzipReadCloser(reader io.ReadCloser) (io.ReadCloser, error) {
gzipReader, err := gzip.NewReader(reader)
if err != nil {
return nil, err
}
return &gzipCloser{
f: reader,
r: gzipReader,
}, nil
}
// Read impls io.Reader
func (g *gzipCloser) Read(p []byte) (n int, err error) {
return g.r.Read(p)
}
// Close impls io.Closer
func (g *gzipCloser) Close() error {
_ = g.f.Close()
return g.r.Close()
}
// HTTPGetReadCloser 从 Http url 获取 io.ReadCloser
func HTTPGetReadCloser(url string) (io.ReadCloser, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header["User-Agent"] = []string{UserAgent}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") {
return NewGzipReadCloser(resp.Body)
}
return resp.Body, err
} }

View File

@ -0,0 +1,295 @@
package download
import (
"bufio"
"compress/gzip"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"github.com/pkg/errors"
"github.com/tidwall/gjson"
"github.com/Mrs4s/go-cqhttp/internal/base"
)
var client = &http.Client{
Transport: &http.Transport{
Proxy: func(request *http.Request) (u *url.URL, e error) {
if base.Proxy == "" {
return http.ProxyFromEnvironment(request)
}
return url.Parse(base.Proxy)
},
ForceAttemptHTTP2: false,
MaxConnsPerHost: 0,
MaxIdleConns: 0,
MaxIdleConnsPerHost: 999,
},
}
// ErrOverSize 响应主体过大时返回此错误
var ErrOverSize = errors.New("oversize")
// UserAgent HTTP请求时使用的UA
const UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.66"
type Request struct {
URL string
Header map[string]string
Limit int64
}
func (r Request) do() (*http.Response, error) {
req, err := http.NewRequest(http.MethodGet, r.URL, nil)
if err != nil {
return nil, err
}
req.Header["User-Agent"] = []string{UserAgent}
for k, v := range r.Header {
req.Header.Set(k, v)
}
return client.Do(req)
}
func (r Request) body() (io.ReadCloser, error) {
resp, err := r.do()
if err != nil {
return nil, err
}
limit := r.Limit // check file size limit
if limit > 0 && resp.ContentLength > limit {
_ = resp.Body.Close()
return nil, ErrOverSize
}
if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") {
return gzipReadCloser(resp.Body)
}
return resp.Body, err
}
// Bytes 对给定URL发送Get请求返回响应主体
func (r Request) Bytes() ([]byte, error) {
rd, err := r.body()
if err != nil {
return nil, err
}
defer rd.Close()
return io.ReadAll(rd)
}
func (r Request) JSON() (gjson.Result, error) {
rd, err := r.body()
if err != nil {
return gjson.Result{}, err
}
defer rd.Close()
var sb strings.Builder
_, err = io.Copy(&sb, rd)
if err != nil {
return gjson.Result{}, err
}
return gjson.Parse(sb.String()), nil
}
func writeToFile(reader io.ReadCloser, path string) error {
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o644)
if err != nil {
return err
}
_, err = file.ReadFrom(reader)
return err
}
func (r Request) WriteToFile(path string) error {
rd, err := r.body()
if err != nil {
return err
}
defer rd.Close()
return writeToFile(rd, path)
}
func (r Request) WriteToFileMultiThreading(path string, thread int) error {
if thread < 2 {
return r.WriteToFile(path)
}
limit := r.Limit
type BlockMetaData struct {
BeginOffset int64
EndOffset int64
DownloadedSize int64
}
var blocks []*BlockMetaData
var contentLength int64
errUnsupportedMultiThreading := errors.New("unsupported multi-threading")
// 初始化分块或直接下载
initOrDownload := func() error {
header := make(map[string]string, len(r.Header))
for k, v := range r.Header { // copy headers
header[k] = v
}
header["range"] = "bytes=0-"
req := Request{
URL: r.URL,
Header: header,
}
resp, err := req.do()
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10))
}
if resp.StatusCode == http.StatusOK {
if limit > 0 && resp.ContentLength > limit {
return ErrOverSize
}
if err = writeToFile(resp.Body, path); err != nil {
return err
}
return errUnsupportedMultiThreading
}
if resp.StatusCode == http.StatusPartialContent {
contentLength = resp.ContentLength
if limit > 0 && resp.ContentLength > limit {
return ErrOverSize
}
blockSize := contentLength
if contentLength > 1024*1024 {
blockSize = (contentLength / int64(thread)) - 10
}
if blockSize == contentLength {
return writeToFile(resp.Body, path)
}
var tmp int64
for tmp+blockSize < contentLength {
blocks = append(blocks, &BlockMetaData{
BeginOffset: tmp,
EndOffset: tmp + blockSize - 1,
})
tmp += blockSize
}
blocks = append(blocks, &BlockMetaData{
BeginOffset: tmp,
EndOffset: contentLength - 1,
})
return nil
}
return errors.New("unknown status code")
}
// 下载分块
downloadBlock := func(block *BlockMetaData) error {
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666)
if err != nil {
return err
}
defer file.Close()
_, _ = file.Seek(block.BeginOffset, io.SeekStart)
writer := bufio.NewWriter(file)
defer writer.Flush()
header := make(map[string]string, len(r.Header))
for k, v := range r.Header { // copy headers
header[k] = v
}
header["range"] = fmt.Sprintf("bytes=%d-%d", block.BeginOffset, block.EndOffset)
req := Request{
URL: r.URL,
Header: header,
}
resp, err := req.do()
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10))
}
buffer := make([]byte, 1024)
i, err := resp.Body.Read(buffer)
for {
if err != nil && err != io.EOF {
return err
}
i64 := int64(len(buffer[:i]))
needSize := block.EndOffset + 1 - block.BeginOffset
if i64 > needSize {
i64 = needSize
err = io.EOF
}
_, e := writer.Write(buffer[:i64])
if e != nil {
return e
}
block.BeginOffset += i64
block.DownloadedSize += i64
if err == io.EOF || block.BeginOffset > block.EndOffset {
break
}
i, err = resp.Body.Read(buffer)
}
return nil
}
if err := initOrDownload(); err != nil {
if err == errUnsupportedMultiThreading {
return nil
}
return err
}
wg := sync.WaitGroup{}
wg.Add(len(blocks))
var lastErr error
for i := range blocks {
go func(b *BlockMetaData) {
defer wg.Done()
if err := downloadBlock(b); err != nil {
lastErr = err
}
}(blocks[i])
}
wg.Wait()
return lastErr
}
type gzipCloser struct {
f io.Closer
r *gzip.Reader
}
// gzipReadCloser 从 io.ReadCloser 创建 gunzip io.ReadCloser
func gzipReadCloser(reader io.ReadCloser) (io.ReadCloser, error) {
gzipReader, err := gzip.NewReader(reader)
if err != nil {
return nil, err
}
return &gzipCloser{
f: reader,
r: gzipReader,
}, nil
}
// Read impls io.Reader
func (g *gzipCloser) Read(p []byte) (n int, err error) {
return g.r.Read(p)
}
// Close impls io.Closer
func (g *gzipCloser) Close() error {
_ = g.f.Close()
return g.r.Close()
}

View File

@ -3,6 +3,7 @@ package selfupdate
import ( import (
"bufio" "bufio"
"bytes"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"hash" "hash"
@ -14,10 +15,10 @@ import (
"strings" "strings"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/Mrs4s/go-cqhttp/internal/download"
) )
func readLine() (str string) { func readLine() (str string) {
@ -28,11 +29,11 @@ func readLine() (str string) {
} }
func lastVersion() (string, error) { func lastVersion() (string, error) {
r, err := global.GetBytes("https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest") r, err := download.Request{URL: "https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest"}.JSON()
if err != nil { if err != nil {
return "", err return "", err
} }
return gjson.GetBytes(r, "tag_name").Str, nil return r.Get("tag_name").Str, nil
} }
// CheckUpdate 检查更新 // CheckUpdate 检查更新
@ -69,12 +70,12 @@ func binaryName() string {
func checksum(github, version string) []byte { func checksum(github, version string) []byte {
sumURL := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/go-cqhttp_checksums.txt", github, version) sumURL := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/go-cqhttp_checksums.txt", github, version)
closer, err := global.HTTPGetReadCloser(sumURL) sum, err := download.Request{URL: sumURL}.Bytes()
if err != nil { if err != nil {
return nil return nil
} }
rd := bufio.NewReader(closer) rd := bufio.NewReader(bytes.NewReader(sum))
for { for {
str, err := rd.ReadString('\n') str, err := rd.ReadString('\n')
if err != nil { if err != nil {