diff --git a/.gitignore b/.gitignore index 43a051c..9e66494 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,7 @@ device.json data/ logs/ internal/btree/*.lock -internal/btree/*.db \ No newline at end of file +internal/btree/*.db + +# binary builds +go-cqhttp diff --git a/cmd/gocq/login.go b/cmd/gocq/login.go index d963d25..3fc3813 100644 --- a/cmd/gocq/login.go +++ b/cmd/gocq/login.go @@ -15,9 +15,9 @@ import ( "github.com/mattn/go-colorable" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/internal/download" ) var console = bufio.NewReader(os.Stdin) @@ -243,12 +243,11 @@ func getTicket(u string) (str string) { } func fetchCaptcha(id string) string { - data, err := global.GetBytes("https://captcha.go-cqhttp.org/captcha/ticket?id=" + id) + g, err := download.Request{URL: "https://captcha.go-cqhttp.org/captcha/ticket?id=" + id}.JSON() if err != nil { log.Warnf("获取 Ticket 时出现错误: %v", err) return "" } - g := gjson.ParseBytes(data) if g.Get("ticket").Exists() { return g.Get("ticket").String() } diff --git a/cmd/gocq/main.go b/cmd/gocq/main.go index 442f73e..50699a3 100644 --- a/cmd/gocq/main.go +++ b/cmd/gocq/main.go @@ -55,11 +55,15 @@ func Main() { switch { case base.LittleH: base.Help() - case base.LittleWD != "": - base.ResetWorkingDir() case base.LittleD: server.Daemon() } + if base.LittleWD != "" { + err := os.Chdir(base.LittleWD) + if err != nil { + log.Fatalf("重置工作目录时出现错误: %v", err) + } + } base.Init() rotateOptions := []rotatelogs.Option{ diff --git a/coolq/api.go b/coolq/api.go index 526389c..d7275ce 100644 --- a/coolq/api.go +++ b/coolq/api.go @@ -27,6 +27,7 @@ import ( "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/cache" + "github.com/Mrs4s/go-cqhttp/internal/download" "github.com/Mrs4s/go-cqhttp/internal/param" "github.com/Mrs4s/go-cqhttp/modules/filter" ) @@ -1545,12 +1546,8 @@ func (bot *CQBot) CQGetImage(file string) global.MSG { } local := path.Join(global.CachePath, file+path.Ext(msg["filename"].(string))) if !global.PathExists(local) { - if body, err := global.HTTPGetReadCloser(msg["url"].(string)); err == nil { - f, _ := os.OpenFile(local, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o0644) - _, _ = f.ReadFrom(body) - _ = body.Close() - _ = f.Close() - } else { + r := download.Request{URL: msg["url"].(string)} + if err := r.WriteToFile(local); err != nil { log.Warnf("下载图片 %v 时出现错误: %v", msg["url"], err) return Failed(100, "DOWNLOAD_IMAGE_ERROR", err.Error()) } @@ -1593,7 +1590,8 @@ func (bot *CQBot) CQDownloadFile(url string, headers gjson.Result, threadCount i return Failed(100, "DELETE_FILE_ERROR", err.Error()) } } - if err := global.DownloadFileMultiThreading(url, file, 0, threadCount, h); err != nil { + r := download.Request{URL: url, Header: h} + if err := r.WriteToFileMultiThreading(file, threadCount); err != nil { log.Warnf("下载链接 %v 时出现错误: %v", url, err) return Failed(100, "DOWNLOAD_FILE_ERROR", err.Error()) } diff --git a/coolq/cqcode.go b/coolq/cqcode.go index 571b9b1..c241e34 100644 --- a/coolq/cqcode.go +++ b/coolq/cqcode.go @@ -29,6 +29,7 @@ import ( "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/cache" + "github.com/Mrs4s/go-cqhttp/internal/download" "github.com/Mrs4s/go-cqhttp/internal/mime" "github.com/Mrs4s/go-cqhttp/internal/param" ) @@ -87,7 +88,9 @@ func replyID(r *message.ReplyElement, source message.Source) int32 { if r.GroupID != 0 { id = r.GroupID } - if source.SourceType == message.SourcePrivate && r.Sender == source.PrimaryID { + // 私聊时,部分(不确定)的账号会在 ReplyElement 中带有 GroupID 字段。 + // 这里需要判断是由于 “直接回复” 功能,GroupID 为触发直接回复的来源那个群。 + if source.SourceType == message.SourcePrivate && (r.Sender == source.PrimaryID || r.GroupID == source.PrimaryID) { // 私聊似乎腾讯服务器有bug? seq = int32(uint16(seq)) id = r.Sender @@ -893,9 +896,9 @@ func (bot *CQBot) ToElement(t string, d map[string]string, sourceType message.So name := info.Get("track_info.name").Str mid := info.Get("track_info.mid").Str albumMid := info.Get("track_info.album.mid").Str - pinfo, _ := global.GetBytes("http://u.y.qq.com/cgi-bin/musicu.fcg?g_tk=2034008533&uin=0&format=json&data={\"comm\":{\"ct\":23,\"cv\":0},\"url_mid\":{\"module\":\"vkey.GetVkeyServer\",\"method\":\"CgiGetVkey\",\"param\":{\"guid\":\"4311206557\",\"songmid\":[\"" + mid + "\"],\"songtype\":[0],\"uin\":\"0\",\"loginflag\":1,\"platform\":\"23\"}}}&_=1599039471576") + pinfo, _ := download.Request{URL: "http://u.y.qq.com/cgi-bin/musicu.fcg?g_tk=2034008533&uin=0&format=json&data={\"comm\":{\"ct\":23,\"cv\":0},\"url_mid\":{\"module\":\"vkey.GetVkeyServer\",\"method\":\"CgiGetVkey\",\"param\":{\"guid\":\"4311206557\",\"songmid\":[\"" + mid + "\"],\"songtype\":[0],\"uin\":\"0\",\"loginflag\":1,\"platform\":\"23\"}}}&_=1599039471576"}.JSON() jumpURL := "https://i.y.qq.com/v8/playsong.html?platform=11&appshare=android_qq&appversion=10030010&hosteuin=oKnlNenz7i-s7c**&songmid=" + mid + "&type=0&appsongtype=1&_wv=1&source=qq&ADTAG=qfshare" - purl := gjson.ParseBytes(pinfo).Get("url_mid.data.midurlinfo.0.purl").Str + purl := pinfo.Get("url_mid.data.midurlinfo.0.purl").Str preview := "http://y.gtimg.cn/music/photo_new/T002R180x180M000" + albumMid + ".jpg" content := info.Get("track_info.singer.0.name").Str if d["content"] != "" { @@ -1087,8 +1090,11 @@ func (bot *CQBot) makeImageOrVideoElem(d map[string]string, video bool, sourceTy if exist { _ = os.Remove(cacheFile) } - if err := global.DownloadFileMultiThreading(f, cacheFile, maxSize, thread, nil); err != nil { - return nil, err + { + r := download.Request{URL: f, Limit: maxSize} + if err := r.WriteToFileMultiThreading(cacheFile, thread); err != nil { + return nil, err + } } useCacheFile: if video { diff --git a/coolq/event.go b/coolq/event.go index ebdc960..34500e2 100644 --- a/coolq/event.go +++ b/coolq/event.go @@ -4,7 +4,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "os" "path" "strconv" "strings" @@ -18,6 +17,7 @@ import ( "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/cache" + "github.com/Mrs4s/go-cqhttp/internal/download" ) // ToFormattedMessage 将给定[]message.IMessageElement转换为通过coolq.SetMessageFormat所定义的消息上报格式 @@ -666,7 +666,8 @@ func (bot *CQBot) checkMedia(e []message.IMessageElement, sourceID int64) { filename := hex.EncodeToString(i.Md5) + ".image" cache.Image.Insert(i.Md5, data) if i.Url != "" && !global.PathExists(path.Join(global.ImagePath, "guild-images", filename)) { - if err := global.DownloadFile(i.Url, path.Join(global.ImagePath, "guild-images", filename), -1, nil); err != nil { + r := download.Request{URL: i.Url} + if err := r.WriteToFile(path.Join(global.ImagePath, "guild-images", filename)); err != nil { log.Warnf("下载频道图片时出现错误: %v", err) } } @@ -684,12 +685,11 @@ func (bot *CQBot) checkMedia(e []message.IMessageElement, sourceID int64) { i.Name = strings.ReplaceAll(i.Name, "{", "") i.Name = strings.ReplaceAll(i.Name, "}", "") if !global.PathExists(path.Join(global.VoicePath, i.Name)) { - b, err := global.GetBytes(i.Url) + err := download.Request{URL: i.Url}.WriteToFile(path.Join(global.VoicePath, i.Name)) if err != nil { log.Warnf("语音文件 %v 下载失败: %v", i.Name, err) continue } - _ = os.WriteFile(path.Join(global.VoicePath, i.Name), b, 0o644) } case *message.ShortVideoElement: data := binary.NewWriterF(func(w *binary.Writer) { diff --git a/global/fs.go b/global/fs.go index f985dd5..046ac2c 100644 --- a/global/fs.go +++ b/global/fs.go @@ -16,6 +16,8 @@ import ( b14 "github.com/fumiama/go-base16384" "github.com/segmentio/asm/base64" log "github.com/sirupsen/logrus" + + "github.com/Mrs4s/go-cqhttp/internal/download" ) const ( @@ -82,8 +84,7 @@ func FindFile(file, cache, p string) (data []byte, err error) { if (cache == "" || cache == "1") && PathExists(cacheFile) { return os.ReadFile(cacheFile) } - data, err = GetBytes(file) - _ = os.WriteFile(cacheFile, data, 0o644) + err = download.Request{URL: file}.WriteToFile(cacheFile) if err != nil { return nil, err } diff --git a/global/net.go b/global/net.go index 0327a7e..030cbea 100644 --- a/global/net.go +++ b/global/net.go @@ -1,304 +1,27 @@ package global import ( - "bufio" - "compress/gzip" "fmt" - "io" - "net/http" - "net/url" - "os" - "strconv" - "strings" - "sync" - - "github.com/pkg/errors" "github.com/tidwall/gjson" - "github.com/Mrs4s/go-cqhttp/internal/base" + "github.com/Mrs4s/go-cqhttp/internal/download" ) -var ( - client = &http.Client{ - Transport: &http.Transport{ - Proxy: func(request *http.Request) (u *url.URL, e error) { - if base.Proxy == "" { - return http.ProxyFromEnvironment(request) - } - return url.Parse(base.Proxy) - }, - ForceAttemptHTTP2: true, - MaxConnsPerHost: 0, - MaxIdleConns: 0, - MaxIdleConnsPerHost: 999, - }, - } - - // ErrOverSize 响应主体过大时返回此错误 - ErrOverSize = errors.New("oversize") - - // UserAgent HTTP请求时使用的UA - UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.66" -) - -// GetBytes 对给定URL发送Get请求,返回响应主体 -func GetBytes(url string) ([]byte, error) { - reader, err := HTTPGetReadCloser(url) - if err != nil { - return nil, err - } - defer func() { - _ = reader.Close() - }() - return io.ReadAll(reader) -} - -// DownloadFile 将给定URL对应的文件下载至给定Path -func DownloadFile(url, path string, limit int64, headers map[string]string) error { - file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666) - if err != nil { - return err - } - defer file.Close() - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return err - } - - for k, v := range headers { - req.Header.Set(k, v) - } - - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if limit > 0 && resp.ContentLength > limit { - return ErrOverSize - } - _, err = file.ReadFrom(resp.Body) - if err != nil { - return err - } - return nil -} - -// DownloadFileMultiThreading 使用threadCount个线程将给定URL对应的文件下载至给定Path -func DownloadFileMultiThreading(url, path string, limit int64, threadCount int, headers map[string]string) error { - if threadCount < 2 { - return DownloadFile(url, path, limit, headers) - } - type BlockMetaData struct { - BeginOffset int64 - EndOffset int64 - DownloadedSize int64 - } - var blocks []*BlockMetaData - var contentLength int64 - errUnsupportedMultiThreading := errors.New("unsupported multi-threading") - // 初始化分块或直接下载 - initOrDownload := func() error { - copyStream := func(s io.ReadCloser) error { - file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666) - if err != nil { - return err - } - defer file.Close() - if _, err = file.ReadFrom(s); err != nil { - return err - } - return errUnsupportedMultiThreading - } - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return err - } - - for k, v := range headers { - req.Header.Set(k, v) - } - if _, ok := headers["User-Agent"]; !ok { - req.Header["User-Agent"] = []string{UserAgent} - } - req.Header.Set("range", "bytes=0-") - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) - } - if resp.StatusCode == http.StatusOK { - if limit > 0 && resp.ContentLength > limit { - return ErrOverSize - } - return copyStream(resp.Body) - } - if resp.StatusCode == http.StatusPartialContent { - contentLength = resp.ContentLength - if limit > 0 && resp.ContentLength > limit { - return ErrOverSize - } - blockSize := contentLength - if contentLength > 1024*1024 { - blockSize = (contentLength / int64(threadCount)) - 10 - } - if blockSize == contentLength { - return copyStream(resp.Body) - } - var tmp int64 - for tmp+blockSize < contentLength { - blocks = append(blocks, &BlockMetaData{ - BeginOffset: tmp, - EndOffset: tmp + blockSize - 1, - }) - tmp += blockSize - } - blocks = append(blocks, &BlockMetaData{ - BeginOffset: tmp, - EndOffset: contentLength - 1, - }) - return nil - } - return errors.New("unknown status code") - } - // 下载分块 - downloadBlock := func(block *BlockMetaData) error { - req, _ := http.NewRequest(http.MethodGet, url, nil) - file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666) - if err != nil { - return err - } - defer file.Close() - _, _ = file.Seek(block.BeginOffset, io.SeekStart) - writer := bufio.NewWriter(file) - defer writer.Flush() - - for k, v := range headers { - req.Header.Set(k, v) - } - - if _, ok := headers["User-Agent"]; !ok { - req.Header["User-Agent"] = []string{UserAgent} - } - req.Header.Set("range", "bytes="+strconv.FormatInt(block.BeginOffset, 10)+"-"+strconv.FormatInt(block.EndOffset, 10)) - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) - } - buffer := make([]byte, 1024) - i, err := resp.Body.Read(buffer) - for { - if err != nil && err != io.EOF { - return err - } - i64 := int64(len(buffer[:i])) - needSize := block.EndOffset + 1 - block.BeginOffset - if i64 > needSize { - i64 = needSize - err = io.EOF - } - _, e := writer.Write(buffer[:i64]) - if e != nil { - return e - } - block.BeginOffset += i64 - block.DownloadedSize += i64 - if err == io.EOF || block.BeginOffset > block.EndOffset { - break - } - i, err = resp.Body.Read(buffer) - } - return nil - } - - if err := initOrDownload(); err != nil { - if err == errUnsupportedMultiThreading { - return nil - } - return err - } - wg := sync.WaitGroup{} - wg.Add(len(blocks)) - var lastErr error - for i := range blocks { - go func(b *BlockMetaData) { - defer wg.Done() - if err := downloadBlock(b); err != nil { - lastErr = err - } - }(blocks[i]) - } - wg.Wait() - return lastErr -} - // QQMusicSongInfo 通过给定id在QQ音乐上查找曲目信息 func QQMusicSongInfo(id string) (gjson.Result, error) { - d, err := GetBytes(`https://u.y.qq.com/cgi-bin/musicu.fcg?format=json&inCharset=utf8&outCharset=utf-8¬ice=0&platform=yqq.json&needNewCode=0&data={%22comm%22:{%22ct%22:24,%22cv%22:0},%22songinfo%22:{%22method%22:%22get_song_detail_yqq%22,%22param%22:{%22song_type%22:0,%22song_mid%22:%22%22,%22song_id%22:` + id + `},%22module%22:%22music.pf_song_detail_svr%22}}`) + d, err := download.Request{URL: `https://u.y.qq.com/cgi-bin/musicu.fcg?format=json&inCharset=utf8&outCharset=utf-8¬ice=0&platform=yqq.json&needNewCode=0&data={%22comm%22:{%22ct%22:24,%22cv%22:0},%22songinfo%22:{%22method%22:%22get_song_detail_yqq%22,%22param%22:{%22song_type%22:0,%22song_mid%22:%22%22,%22song_id%22:` + id + `},%22module%22:%22music.pf_song_detail_svr%22}}`}.JSON() if err != nil { return gjson.Result{}, err } - return gjson.ParseBytes(d).Get("songinfo.data"), nil + return d.Get("songinfo.data"), nil } // NeteaseMusicSongInfo 通过给定id在wdd音乐上查找曲目信息 func NeteaseMusicSongInfo(id string) (gjson.Result, error) { - d, err := GetBytes(fmt.Sprintf("http://music.163.com/api/song/detail/?id=%s&ids=%%5B%s%%5D", id, id)) + d, err := download.Request{URL: fmt.Sprintf("http://music.163.com/api/song/detail/?id=%s&ids=%%5B%s%%5D", id, id)}.JSON() if err != nil { return gjson.Result{}, err } - return gjson.ParseBytes(d).Get("songs.0"), nil -} - -type gzipCloser struct { - f io.Closer - r *gzip.Reader -} - -// NewGzipReadCloser 从 io.ReadCloser 创建 gunzip io.ReadCloser -func NewGzipReadCloser(reader io.ReadCloser) (io.ReadCloser, error) { - gzipReader, err := gzip.NewReader(reader) - if err != nil { - return nil, err - } - return &gzipCloser{ - f: reader, - r: gzipReader, - }, nil -} - -// Read impls io.Reader -func (g *gzipCloser) Read(p []byte) (n int, err error) { - return g.r.Read(p) -} - -// Close impls io.Closer -func (g *gzipCloser) Close() error { - _ = g.f.Close() - return g.r.Close() -} - -// HTTPGetReadCloser 从 Http url 获取 io.ReadCloser -func HTTPGetReadCloser(url string) (io.ReadCloser, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, err - } - req.Header["User-Agent"] = []string{UserAgent} - resp, err := client.Do(req) - if err != nil { - return nil, err - } - if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") { - return NewGzipReadCloser(resp.Body) - } - return resp.Body, err + return d.Get("songs.0"), nil } diff --git a/global/terminal/double_click_windows.go b/global/terminal/double_click_windows.go index 1fa746c..4c3f4e8 100644 --- a/global/terminal/double_click_windows.go +++ b/global/terminal/double_click_windows.go @@ -6,15 +6,16 @@ package terminal import ( "os" "path/filepath" - "syscall" "unsafe" + "golang.org/x/sys/windows" + "github.com/pkg/errors" ) // RunningByDoubleClick 检查是否通过双击直接运行 func RunningByDoubleClick() bool { - kernel32 := syscall.NewLazyDLL("kernel32.dll") + kernel32 := windows.NewLazySystemDLL("kernel32.dll") lp := kernel32.NewProc("GetConsoleProcessList") if lp != nil { var ids [2]uint32 @@ -29,7 +30,8 @@ func RunningByDoubleClick() bool { // NoMoreDoubleClick 提示用户不要双击运行,并生成安全启动脚本 func NoMoreDoubleClick() error { - r := boxW(0, "请勿通过双击直接运行本程序, 这将导致一些非预料的后果.\n请在shell中运行./go-cqhttp.exe\n点击确认将释出安全启动脚本,点击取消则关闭程序", "警告", 0x00000030|0x00000001) + toHighDPI() + r := boxW(getConsoleWindows(), "请勿通过双击直接运行本程序, 这将导致一些非预料的后果.\n请在shell中运行./go-cqhttp.exe\n点击确认将释出安全启动脚本,点击取消则关闭程序", "警告", 0x00000030|0x00000001) if r == 2 { return nil } @@ -59,9 +61,10 @@ func NoMoreDoubleClick() error { // BoxW of Win32 API. Check https://docs.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-messageboxw for more detail. func boxW(hwnd uintptr, caption, title string, flags uint) int { - captionPtr, _ := syscall.UTF16PtrFromString(caption) - titlePtr, _ := syscall.UTF16PtrFromString(title) - ret, _, _ := syscall.NewLazyDLL("user32.dll").NewProc("MessageBoxW").Call( + captionPtr, _ := windows.UTF16PtrFromString(caption) + titlePtr, _ := windows.UTF16PtrFromString(title) + u32 := windows.NewLazySystemDLL("user32.dll") + ret, _, _ := u32.NewProc("MessageBoxW").Call( hwnd, uintptr(unsafe.Pointer(captionPtr)), uintptr(unsafe.Pointer(titlePtr)), @@ -69,3 +72,23 @@ func boxW(hwnd uintptr, caption, title string, flags uint) int { return int(ret) } + +// GetConsoleWindows retrieves the window handle used by the console associated with the calling process. +func getConsoleWindows() (hWnd uintptr) { + hWnd, _, _ = windows.NewLazySystemDLL("kernel32.dll").NewProc("GetConsoleWindow").Call() + return +} + +// toHighDPI tries to raise DPI awareness context to DPI_AWARENESS_CONTEXT_UNAWARE_GDISCALED +func toHighDPI() { + systemAware := ^uintptr(2) + 1 + unawareGDIScaled := ^uintptr(5) + 1 + u32 := windows.NewLazySystemDLL("user32.dll") + proc := u32.NewProc("SetThreadDpiAwarenessContext") + if proc.Find() != nil { + return + } + for i := unawareGDIScaled; i <= systemAware; i++ { + _, _, _ = u32.NewProc("SetThreadDpiAwarenessContext").Call(i) + } +} diff --git a/internal/base/flag.go b/internal/base/flag.go index 24ad361..38790ee 100644 --- a/internal/base/flag.go +++ b/internal/base/flag.go @@ -5,13 +5,8 @@ import ( "flag" "fmt" "os" - "os/exec" - "path" - "path/filepath" - "strings" "time" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" @@ -58,9 +53,7 @@ var ( // Parse parse flags func Parse() { - wd, _ := os.Getwd() - dc := path.Join(wd, "config.yml") - flag.StringVar(&LittleC, "c", dc, "configuration filename") + flag.StringVar(&LittleC, "c", "config.yml", "configuration filename") flag.BoolVar(&LittleD, "d", false, "running as a daemon") flag.BoolVar(&LittleH, "h", false, "this Help") flag.StringVar(&LittleWD, "w", "", "cover the working directory") @@ -128,32 +121,3 @@ Options: flag.PrintDefaults() os.Exit(0) } - -// ResetWorkingDir 重设工作路径 -func ResetWorkingDir() { - wd := LittleWD - args := make([]string, 0, len(os.Args)) - for i := 1; i < len(os.Args); i++ { - if os.Args[i] == "-w" { - i++ // skip value field - } else if !strings.HasPrefix(os.Args[i], "-w") { - args = append(args, os.Args[i]) - } - } - ex, _ := os.Executable() - p, _ := filepath.Abs(ex) - _, err := os.Stat(p) - if !(err == nil || errors.Is(err, os.ErrExist)) { - log.Fatalf("重置工作目录时出现错误: 无法找到路径 %v", p) - } - proc := exec.Command(p, args...) - proc.Stdin = os.Stdin - proc.Stdout = os.Stdout - proc.Stderr = os.Stderr - proc.Dir = wd - err = proc.Run() - if err != nil { - panic(err) - } - os.Exit(0) -} diff --git a/internal/download/download.go b/internal/download/download.go new file mode 100644 index 0000000..1387168 --- /dev/null +++ b/internal/download/download.go @@ -0,0 +1,300 @@ +// Package download provide download utility functions +package download + +import ( + "bufio" + "compress/gzip" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + + "github.com/pkg/errors" + "github.com/tidwall/gjson" + + "github.com/Mrs4s/go-cqhttp/internal/base" +) + +var client = &http.Client{ + Transport: &http.Transport{ + Proxy: func(request *http.Request) (u *url.URL, e error) { + if base.Proxy == "" { + return http.ProxyFromEnvironment(request) + } + return url.Parse(base.Proxy) + }, + ForceAttemptHTTP2: false, + MaxConnsPerHost: 0, + MaxIdleConns: 0, + MaxIdleConnsPerHost: 999, + }, +} + +// ErrOverSize 响应主体过大时返回此错误 +var ErrOverSize = errors.New("oversize") + +// UserAgent HTTP请求时使用的UA +const UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36 Edg/87.0.664.66" + +// Request is a file download request +type Request struct { + URL string + Header map[string]string + Limit int64 +} + +func (r Request) do() (*http.Response, error) { + req, err := http.NewRequest(http.MethodGet, r.URL, nil) + if err != nil { + return nil, err + } + + req.Header["User-Agent"] = []string{UserAgent} + for k, v := range r.Header { + req.Header.Set(k, v) + } + + return client.Do(req) +} + +func (r Request) body() (io.ReadCloser, error) { + resp, err := r.do() + if err != nil { + return nil, err + } + + limit := r.Limit // check file size limit + if limit > 0 && resp.ContentLength > limit { + _ = resp.Body.Close() + return nil, ErrOverSize + } + + if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") { + return gzipReadCloser(resp.Body) + } + return resp.Body, err +} + +// Bytes 对给定URL发送Get请求,返回响应主体 +func (r Request) Bytes() ([]byte, error) { + rd, err := r.body() + if err != nil { + return nil, err + } + defer rd.Close() + return io.ReadAll(rd) +} + +// JSON 发送GET请求, 并转换响应为JSON +func (r Request) JSON() (gjson.Result, error) { + rd, err := r.body() + if err != nil { + return gjson.Result{}, err + } + defer rd.Close() + + var sb strings.Builder + _, err = io.Copy(&sb, rd) + if err != nil { + return gjson.Result{}, err + } + + return gjson.Parse(sb.String()), nil +} + +func writeToFile(reader io.ReadCloser, path string) error { + file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o644) + if err != nil { + return err + } + _, err = file.ReadFrom(reader) + return err +} + +// WriteToFile 下载到制定目录 +func (r Request) WriteToFile(path string) error { + rd, err := r.body() + if err != nil { + return err + } + defer rd.Close() + return writeToFile(rd, path) +} + +// WriteToFileMultiThreading 多线程下载到制定目录 +func (r Request) WriteToFileMultiThreading(path string, thread int) error { + if thread < 2 { + return r.WriteToFile(path) + } + + limit := r.Limit + type BlockMetaData struct { + BeginOffset int64 + EndOffset int64 + DownloadedSize int64 + } + var blocks []*BlockMetaData + var contentLength int64 + errUnsupportedMultiThreading := errors.New("unsupported multi-threading") + // 初始化分块或直接下载 + initOrDownload := func() error { + header := make(map[string]string, len(r.Header)) + for k, v := range r.Header { // copy headers + header[k] = v + } + header["range"] = "bytes=0-" + req := Request{ + URL: r.URL, + Header: header, + } + resp, err := req.do() + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) + } + if resp.StatusCode == http.StatusOK { + if limit > 0 && resp.ContentLength > limit { + return ErrOverSize + } + if err = writeToFile(resp.Body, path); err != nil { + return err + } + return errUnsupportedMultiThreading + } + if resp.StatusCode == http.StatusPartialContent { + contentLength = resp.ContentLength + if limit > 0 && resp.ContentLength > limit { + return ErrOverSize + } + blockSize := contentLength + if contentLength > 1024*1024 { + blockSize = (contentLength / int64(thread)) - 10 + } + if blockSize == contentLength { + return writeToFile(resp.Body, path) + } + var tmp int64 + for tmp+blockSize < contentLength { + blocks = append(blocks, &BlockMetaData{ + BeginOffset: tmp, + EndOffset: tmp + blockSize - 1, + }) + tmp += blockSize + } + blocks = append(blocks, &BlockMetaData{ + BeginOffset: tmp, + EndOffset: contentLength - 1, + }) + return nil + } + return errors.New("unknown status code") + } + // 下载分块 + downloadBlock := func(block *BlockMetaData) error { + file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0o666) + if err != nil { + return err + } + defer file.Close() + _, _ = file.Seek(block.BeginOffset, io.SeekStart) + writer := bufio.NewWriter(file) + defer writer.Flush() + + header := make(map[string]string, len(r.Header)) + for k, v := range r.Header { // copy headers + header[k] = v + } + header["range"] = fmt.Sprintf("bytes=%d-%d", block.BeginOffset, block.EndOffset) + req := Request{ + URL: r.URL, + Header: header, + } + resp, err := req.do() + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return errors.New("response status unsuccessful: " + strconv.FormatInt(int64(resp.StatusCode), 10)) + } + buffer := make([]byte, 1024) + i, err := resp.Body.Read(buffer) + for { + if err != nil && err != io.EOF { + return err + } + i64 := int64(len(buffer[:i])) + needSize := block.EndOffset + 1 - block.BeginOffset + if i64 > needSize { + i64 = needSize + err = io.EOF + } + _, e := writer.Write(buffer[:i64]) + if e != nil { + return e + } + block.BeginOffset += i64 + block.DownloadedSize += i64 + if err == io.EOF || block.BeginOffset > block.EndOffset { + break + } + i, err = resp.Body.Read(buffer) + } + return nil + } + + if err := initOrDownload(); err != nil { + if err == errUnsupportedMultiThreading { + return nil + } + return err + } + wg := sync.WaitGroup{} + wg.Add(len(blocks)) + var lastErr error + for i := range blocks { + go func(b *BlockMetaData) { + defer wg.Done() + if err := downloadBlock(b); err != nil { + lastErr = err + } + }(blocks[i]) + } + wg.Wait() + return lastErr +} + +type gzipCloser struct { + f io.Closer + r *gzip.Reader +} + +// gzipReadCloser 从 io.ReadCloser 创建 gunzip io.ReadCloser +func gzipReadCloser(reader io.ReadCloser) (io.ReadCloser, error) { + gzipReader, err := gzip.NewReader(reader) + if err != nil { + return nil, err + } + return &gzipCloser{ + f: reader, + r: gzipReader, + }, nil +} + +// Read impls io.Reader +func (g *gzipCloser) Read(p []byte) (n int, err error) { + return g.r.Read(p) +} + +// Close impls io.Closer +func (g *gzipCloser) Close() error { + _ = g.f.Close() + return g.r.Close() +} diff --git a/internal/selfupdate/update.go b/internal/selfupdate/update.go index 8771502..108e5fc 100644 --- a/internal/selfupdate/update.go +++ b/internal/selfupdate/update.go @@ -3,6 +3,7 @@ package selfupdate import ( "bufio" + "bytes" "encoding/hex" "fmt" "hash" @@ -14,10 +15,10 @@ import ( "strings" "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" + "github.com/Mrs4s/go-cqhttp/internal/download" ) func readLine() (str string) { @@ -28,11 +29,11 @@ func readLine() (str string) { } func lastVersion() (string, error) { - r, err := global.GetBytes("https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest") + r, err := download.Request{URL: "https://api.github.com/repos/Mrs4s/go-cqhttp/releases/latest"}.JSON() if err != nil { return "", err } - return gjson.GetBytes(r, "tag_name").Str, nil + return r.Get("tag_name").Str, nil } // CheckUpdate 检查更新 @@ -69,12 +70,12 @@ func binaryName() string { func checksum(github, version string) []byte { sumURL := fmt.Sprintf("%v/Mrs4s/go-cqhttp/releases/download/%v/go-cqhttp_checksums.txt", github, version) - closer, err := global.HTTPGetReadCloser(sumURL) + sum, err := download.Request{URL: sumURL}.Bytes() if err != nil { return nil } - rd := bufio.NewReader(closer) + rd := bufio.NewReader(bytes.NewReader(sum)) for { str, err := rd.ReadString('\n') if err != nil {