1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-07 04:23:49 +08:00

feature download_file.

This commit is contained in:
Mrs4s 2021-01-06 21:27:07 +08:00
parent 416a3460ab
commit 2bbf969cfa
5 changed files with 84 additions and 2 deletions

View File

@ -1,9 +1,12 @@
package coolq
import (
"crypto/md5"
"encoding/hex"
"io/ioutil"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
@ -741,6 +744,25 @@ func (bot *CQBot) CQGetImage(file string) MSG {
}
}
func (bot *CQBot) CQDownloadFile(url string, headers map[string]string, threadCount int) MSG {
hash := md5.Sum([]byte(url))
file := path.Join(global.CACHE_PATH, hex.EncodeToString(hash[:])+".cache")
if global.PathExists(file) {
if err := os.Remove(file); err != nil {
log.Warnf("删除缓存文件 %v 时出现错误: %v", file, err)
return Failed(100, "DELETE_FILE_ERROR", err.Error())
}
}
if err := global.DownloadFileMultiThreading(url, file, 0, threadCount, headers); err != nil {
log.Warnf("下载链接 %v 时出现错误: %v", url, err)
return Failed(100, "DOWNLOAD_FILE_ERROR", err.Error())
}
abs, _ := filepath.Abs(file)
return OK(MSG{
"file": abs,
})
}
func (bot *CQBot) CQGetForwardMessage(resId string) MSG {
m := bot.Client.GetForwardMessage(resId)
if m == nil {

View File

@ -833,7 +833,7 @@ func (bot *CQBot) makeImageElem(d map[string]string, group bool) (message.IMessa
_ = os.Remove(cacheFile)
}
thread, _ := strconv.Atoi(c)
if err := global.DownloadFileMultiThreading(f, cacheFile, maxImageSize, thread); err != nil {
if err := global.DownloadFileMultiThreading(f, cacheFile, maxImageSize, thread, nil); err != nil {
return nil, err
}
return &LocalImageElement{File: cacheFile}, nil

View File

@ -99,7 +99,7 @@ func DownloadFile(url, path string, limit int64) error {
return nil
}
func DownloadFileMultiThreading(url, path string, limit int64, threadCount int) error {
func DownloadFileMultiThreading(url, path string, limit int64, threadCount int, headers map[string]string) error {
if threadCount < 2 {
return DownloadFile(url, path, limit)
}
@ -128,6 +128,11 @@ func DownloadFileMultiThreading(url, path string, limit int64, threadCount int)
if err != nil {
return err
}
if headers != nil {
for k, v := range headers {
req.Header.Set(k, v)
}
}
req.Header.Set("range", "bytes=0-")
resp, err := client.Do(req)
if err != nil {
@ -184,6 +189,11 @@ func DownloadFileMultiThreading(url, path string, limit int64, threadCount int)
_, _ = file.Seek(block.BeginOffset, io.SeekStart)
writer := bufio.NewWriter(file)
defer writer.Flush()
if headers != nil {
for k, v := range headers {
req.Header.Set(k, v)
}
}
req.Header.Set("range", "bytes="+strconv.FormatInt(block.BeginOffset, 10)+"-"+strconv.FormatInt(block.EndOffset, 10))
resp, err := client.Do(req)
if err != nil {

View File

@ -428,6 +428,33 @@ func HandleQuickOperation(s *httpServer, c *gin.Context) {
}
}
func DownloadFile(s *httpServer, c *gin.Context) {
url := getParam(c, "url")
tc, _ := strconv.Atoi(getParam(c, "thread_count"))
h, t := getParamWithType(c, "headers")
headers := map[string]string{}
if t == gjson.Null || t == gjson.String {
lines := strings.Split(h, "\r\n")
for _, sub := range lines {
str := strings.SplitN(sub, "=", 2)
if len(str) == 2 {
headers[str[0]] = str[1]
}
}
}
if t == gjson.JSON {
arr := gjson.Parse(h)
for _, sub := range arr.Array() {
str := strings.SplitN(sub.String(), "=", 2)
if len(str) == 2 {
headers[str[0]] = str[1]
}
}
}
println(url, tc, h, t)
c.JSON(200, s.bot.CQDownloadFile(url, headers, tc))
}
func OcrImage(s *httpServer, c *gin.Context) {
img := getParam(c, "image")
c.JSON(200, s.bot.CQOcrImage(img))
@ -535,6 +562,7 @@ var httpApi = map[string]func(s *httpServer, c *gin.Context){
"reload_event_filter": ReloadEventFilter,
"set_group_portrait": SetGroupPortrait,
"set_group_anonymous_ban": SetGroupAnonymousBan,
"download_file": DownloadFile,
".handle_quick_operation": HandleQuickOperation,
".ocr_image": OcrImage,
"ocr_image": OcrImage,

View File

@ -490,6 +490,28 @@ var wsApi = map[string]func(*coolq.CQBot, gjson.Result) coolq.MSG{
"get_msg": func(bot *coolq.CQBot, p gjson.Result) coolq.MSG {
return bot.CQGetMessage(int32(p.Get("message_id").Int()))
},
"download_file": func(bot *coolq.CQBot, p gjson.Result) coolq.MSG {
headers := map[string]string{}
headersToken := p.Get("headers")
if headersToken.IsArray() {
for _, sub := range headersToken.Array() {
str := strings.SplitN(sub.String(), "=", 2)
if len(str) == 2 {
headers[str[0]] = str[1]
}
}
}
if headersToken.Type == gjson.String {
lines := strings.Split(headersToken.String(), "\r\n")
for _, sub := range lines {
str := strings.SplitN(sub, "=", 2)
if len(str) == 2 {
headers[str[0]] = str[1]
}
}
}
return bot.CQDownloadFile(p.Get("url").Str, headers, int(p.Get("thread_count").Int()))
},
"get_group_honor_info": func(bot *coolq.CQBot, p gjson.Result) coolq.MSG {
return bot.CQGetGroupHonorInfo(p.Get("group_id").Int(), p.Get("type").Str)
},