diff --git a/.golangci.yml b/.golangci.yml index 3ac1f01..4950c00 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -78,7 +78,9 @@ run: # concurrency: 4 # explicitly omit this value to fully utilize available resources. deadline: 5m issues-exit-code: 1 - tests: false + skip-dirs: + - db + tests: true # output configuration options output: @@ -92,4 +94,4 @@ issues: fix: true exclude-use-default: false exclude: - - "Error return value of .((os.)?std(out|err)..*|.*Close|.*Flush|os.Remove(All)?|.*print(f|ln)?|os.(Un)?Setenv). is not check" + - "Error return value of .((os.)?std(out|err)..*|.*Close|.*Seek|.*Flush|os.Remove(All)?|.*print(f|ln)?|os.(Un)?Setenv). is not check" diff --git a/coolq/api.go b/coolq/api.go index af68a3e..8e8778f 100644 --- a/coolq/api.go +++ b/coolq/api.go @@ -14,6 +14,9 @@ import ( "strings" "time" + "github.com/Mrs4s/go-cqhttp/db" + "github.com/Mrs4s/go-cqhttp/internal/cache" + "github.com/Mrs4s/MiraiGo/binary" "github.com/Mrs4s/MiraiGo/client" "github.com/Mrs4s/MiraiGo/message" @@ -462,20 +465,19 @@ func (bot *CQBot) CQSendGroupForwardMessage(groupID int64, m gjson.Result) globa ts.Add(time.Second) if e.Get("data.id").Exists() { i := e.Get("data.id").Int() - m := bot.GetMessage(int32(i)) + m, _ := bot.db.GetGroupMessageByGlobalID(int32(i)) if m != nil { - sender := m["sender"].(message.Sender) return &message.ForwardNode{ - SenderId: sender.Uin, - SenderName: (&sender).DisplayName(), + SenderId: m.Attribute.SenderUin, + SenderName: m.Attribute.SenderName, Time: func() int32 { - msgTime := m["time"].(int32) + msgTime := m.Attribute.Timestamp if hasCustom && msgTime == 0 { return int32(ts.Unix()) } - return msgTime + return int32(msgTime) }(), - Message: resolveElement(bot.ConvertStringMessage(m["message"].(string), true)), + Message: resolveElement(bot.ConvertContentMessage(m.Content, true)), } } log.Warnf("警告: 引用消息 %v 错误或数据库未开启.", e.Get("data.id").Str) @@ -779,29 +781,28 @@ func (bot *CQBot) CQProcessGroupRequest(flag, subType, reason string, approve bo // // https:// git.io/Jtz1y func (bot *CQBot) CQDeleteMessage(messageID int32) global.MSG { - msg := bot.GetMessage(messageID) - if msg == nil { + msg, err := bot.db.GetMessageByGlobalID(messageID) + if err != nil { + log.Warnf("撤回消息时出现错误: %v", err) return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在") } - if _, ok := msg["group"]; ok { - if msg["internal-id"] == nil { - // TODO 撤回临时对话消息 - log.Warnf("撤回 %v 失败: 无法撤回临时对话消息", messageID) - return Failed(100, "CANNOT_RECALL_TEMP_MSG", "无法撤回临时对话消息") - } - if err := bot.Client.RecallGroupMessage(msg["group"].(int64), msg["message-id"].(int32), msg["internal-id"].(int32)); err != nil { + switch o := msg.(type) { + case *db.StoredGroupMessage: + if err = bot.Client.RecallGroupMessage(o.GroupCode, o.Attribute.MessageSeq, o.Attribute.InternalID); err != nil { log.Warnf("撤回 %v 失败: %v", messageID, err) return Failed(100, "RECALL_API_ERROR", err.Error()) } - } else { - if msg["sender"].(message.Sender).Uin != bot.Client.Uin { + case *db.StoredPrivateMessage: + if o.Attribute.SenderUin != bot.Client.Uin { log.Warnf("撤回 %v 失败: 好友会话无法撤回对方消息.", messageID) return Failed(100, "CANNOT_RECALL_FRIEND_MSG", "无法撤回对方消息") } - if err := bot.Client.RecallPrivateMessage(msg["target"].(int64), int64(msg["time"].(int32)), msg["message-id"].(int32), msg["internal-id"].(int32)); err != nil { + if err = bot.Client.RecallPrivateMessage(o.TargetUin, o.Attribute.Timestamp, o.Attribute.MessageSeq, o.Attribute.InternalID); err != nil { log.Warnf("撤回 %v 失败: %v", messageID, err) return Failed(100, "RECALL_API_ERROR", err.Error()) } + default: + return Failed(100, "UNKNOWN_ERROR") } return OK(nil) } @@ -1025,10 +1026,21 @@ func (bot *CQBot) CQHandleQuickOperation(context, operation gjson.Result) global // // https://docs.go-cqhttp.org/api/#%E8%8E%B7%E5%8F%96%E5%9B%BE%E7%89%87%E4%BF%A1%E6%81%AF func (bot *CQBot) CQGetImage(file string) global.MSG { - if !global.PathExists(path.Join(global.ImagePath, file)) { - return Failed(100) + var b []byte + var err error + if cache.EnableCacheDB && strings.HasSuffix(file, ".image") { + var f []byte + f, err = hex.DecodeString(strings.TrimSuffix(file, ".image")) + b = cache.Image.Get(f) } - b, err := os.ReadFile(path.Join(global.ImagePath, file)) + + if b == nil { + if !global.PathExists(path.Join(global.ImagePath, file)) { + return Failed(100) + } + b, err = os.ReadFile(path.Join(global.ImagePath, file)) + } + if err == nil { r := binary.NewReader(b) r.ReadBytes(16) @@ -1103,38 +1115,32 @@ func (bot *CQBot) CQGetForwardMessage(resID string) global.MSG { // // https://git.io/Jtz1b func (bot *CQBot) CQGetMessage(messageID int32) global.MSG { - msg := bot.GetMessage(messageID) - if msg == nil { + msg, err := bot.db.GetMessageByGlobalID(messageID) + if err != nil { + log.Warnf("获取消息时出现错误: %v", err) return Failed(100, "MSG_NOT_FOUND", "消息不存在") } - sender := msg["sender"].(message.Sender) - gid, isGroup := msg["group"] - raw := msg["message"].(string) - return OK(global.MSG{ - "message_id": messageID, - "real_id": msg["message-id"], - "message_seq": msg["message-id"], - "group": isGroup, - "group_id": gid, - "message_type": func() string { - if isGroup { - return "group" - } - return "private" - }(), + m := global.MSG{ + "message_id": msg.GetGlobalID(), + "message_id_v2": msg.GetID(), + "message_type": msg.GetType(), + "real_id": msg.GetAttribute().MessageSeq, + "message_seq": msg.GetAttribute().MessageSeq, + "group": msg.GetType() == "group", "sender": global.MSG{ - "user_id": sender.Uin, - "nickname": sender.Nickname, + "user_id": msg.GetAttribute().SenderUin, + "nickname": msg.GetAttribute().SenderName, }, - "time": msg["time"], - "raw_message": raw, - "message": ToFormattedMessage(bot.ConvertStringMessage(raw, isGroup), func() int64 { - if isGroup { - return gid.(int64) - } - return 0 - }(), false), - }) + "time": msg.GetAttribute().Timestamp, + } + switch o := msg.(type) { + case *db.StoredGroupMessage: + m["group_id"] = o.GroupCode + m["message"] = ToFormattedMessage(bot.ConvertContentMessage(o.Content, true), o.GroupCode, false) + case *db.StoredPrivateMessage: + m["message"] = ToFormattedMessage(bot.ConvertContentMessage(o.Content, false), 0, false) + } + return OK(m) } // CQGetGroupSystemMessages 扩展API-获取群文件系统消息 @@ -1296,18 +1302,13 @@ func (bot *CQBot) CQGetStatus() global.MSG { // // https://docs.go-cqhttp.org/api/#%E8%AE%BE%E7%BD%AE%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF func (bot *CQBot) CQSetEssenceMessage(messageID int32) global.MSG { - msg := bot.GetMessage(messageID) - if msg == nil { + msg, err := bot.db.GetGroupMessageByGlobalID(messageID) + if err != nil { return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在") } - if _, ok := msg["group"]; ok { - if err := bot.Client.SetEssenceMessage(msg["group"].(int64), msg["message-id"].(int32), msg["internal-id"].(int32)); err != nil { - log.Warnf("设置精华消息 %v 失败: %v", messageID, err) - return Failed(100, "SET_ESSENCE_MSG_ERROR", err.Error()) - } - } else { - log.Warnf("设置精华消息 %v 失败: 非群聊", messageID) - return Failed(100, "SET_ESSENCE_MSG_ERROR", "非群聊") + if err := bot.Client.SetEssenceMessage(msg.GroupCode, msg.Attribute.MessageSeq, msg.Attribute.InternalID); err != nil { + log.Warnf("设置精华消息 %v 失败: %v", messageID, err) + return Failed(100, "SET_ESSENCE_MSG_ERROR", err.Error()) } return OK(nil) } @@ -1316,18 +1317,13 @@ func (bot *CQBot) CQSetEssenceMessage(messageID int32) global.MSG { // // https://docs.go-cqhttp.org/api/#%E7%A7%BB%E5%87%BA%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF func (bot *CQBot) CQDeleteEssenceMessage(messageID int32) global.MSG { - msg := bot.GetMessage(messageID) - if msg == nil { + msg, err := bot.db.GetGroupMessageByGlobalID(messageID) + if err != nil { return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在") } - if _, ok := msg["group"]; ok { - if err := bot.Client.DeleteEssenceMessage(msg["group"].(int64), msg["message-id"].(int32), msg["internal-id"].(int32)); err != nil { - log.Warnf("移出精华消息 %v 失败: %v", messageID, err) - return Failed(100, "DEL_ESSENCE_MSG_ERROR", err.Error()) - } - } else { - log.Warnf("移出精华消息 %v 失败: 非群聊", messageID) - return Failed(100, "DEL_ESSENCE_MSG_ERROR", "非群聊") + if err := bot.Client.DeleteEssenceMessage(msg.GroupCode, msg.Attribute.MessageSeq, msg.Attribute.InternalID); err != nil { + log.Warnf("删除精华消息 %v 失败: %v", messageID, err) + return Failed(100, "SET_ESSENCE_MSG_ERROR", err.Error()) } return OK(nil) } @@ -1354,7 +1350,7 @@ func (bot *CQBot) CQGetEssenceMessageList(groupCode int64) global.MSG { "sender_id": m.SenderUin, "operator_id": m.AddDigestUin, } - msg["message_id"] = toGlobalID(groupCode, int32(m.MessageID)) + msg["message_id"] = db.ToGlobalID(groupCode, int32(m.MessageID)) list = append(list, msg) } return OK(list) @@ -1440,18 +1436,17 @@ func (bot *CQBot) CQSetModelShow(modelName string, modelShow string) global.MSG // CQMarkMessageAsRead 标记消息已读 func (bot *CQBot) CQMarkMessageAsRead(msgID int32) global.MSG { - m := bot.GetMessage(msgID) - if m == nil { + m, err := bot.db.GetMessageByGlobalID(msgID) + if err != nil { return Failed(100, "MSG_NOT_FOUND", "消息不存在") } - if _, ok := m["group"]; ok { - bot.Client.MarkGroupMessageReaded(m["group"].(int64), int64(m["message-id"].(int32))) + switch o := m.(type) { + case *db.StoredGroupMessage: + bot.Client.MarkGroupMessageReaded(o.GroupCode, int64(o.Attribute.MessageSeq)) return OK(nil) + case *db.StoredPrivateMessage: + bot.Client.MarkPrivateMessageReaded(o.SessionUin, o.Attribute.Timestamp) } - if _, ok := m["from-group"]; ok { - return Failed(100, "MSG_TYPE_ERROR", "不支持标记临时会话") - } - bot.Client.MarkPrivateMessageReaded(m["sender"].(message.Sender).Uin, int64(m["time"].(int32))) return OK(nil) } diff --git a/coolq/bot.go b/coolq/bot.go index f5507a4..f87a5cf 100644 --- a/coolq/bot.go +++ b/coolq/bot.go @@ -2,11 +2,9 @@ package coolq import ( "bytes" - "encoding/gob" "encoding/hex" "encoding/json" "fmt" - "hash/crc32" "io" "os" "path" @@ -14,14 +12,14 @@ import ( "sync" "time" + "github.com/Mrs4s/go-cqhttp/db" + "github.com/Mrs4s/MiraiGo/binary" "github.com/Mrs4s/MiraiGo/client" "github.com/Mrs4s/MiraiGo/message" "github.com/Mrs4s/MiraiGo/utils" "github.com/pkg/errors" log "github.com/sirupsen/logrus" - "github.com/syndtr/goleveldb/leveldb" - "github.com/syndtr/goleveldb/leveldb/opt" "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" @@ -35,7 +33,7 @@ type CQBot struct { lock sync.RWMutex events []func(*Event) - db *leveldb.DB + db db.IDatabase friendReqCache sync.Map tempSessionCache sync.Map } @@ -73,27 +71,30 @@ func NewQQBot(cli *client.QQClient) *CQBot { bot := &CQBot{ Client: cli, } - enableLevelDB := false - node, ok := base.Database["leveldb"] - if ok { + levelNode, levelDB := base.Database["leveldb"] + mongoNode, mongoDB := base.Database["mongodb"] + multiDB := db.NewMultiDatabase() + if levelDB { lconf := new(config.LevelDBConfig) - _ = node.Decode(lconf) - enableLevelDB = lconf.Enable - } - if enableLevelDB { - p := path.Join("data", "leveldb") - db, err := leveldb.OpenFile(p, &opt.Options{ - WriteBuffer: 128 * opt.KiB, - }) - if err != nil { - log.Fatalf("打开数据库失败, 如果频繁遇到此问题请清理 data/leveldb 文件夹或关闭数据库功能。") + _ = levelNode.Decode(lconf) + if lconf.Enable { + multiDB.UseDB(db.UseLevelDB()) } - bot.db = db - gob.Register(message.Sender{}) - log.Info("信息数据库初始化完成.") - } else { - log.Warn("警告: 信息数据库已关闭,将无法使用 [回复/撤回] 等功能。") } + if mongoDB { + lconf := new(config.MongoDBConfig) + _ = mongoNode.Decode(lconf) + if lconf.Database == "" { + lconf.Database = "gocq-database" + } + if lconf.Enable { + multiDB.UseDB(db.UseMongoDB(lconf.URI, lconf.Database)) + } + } + if err := multiDB.Open(); err != nil { + log.Fatalf("打开数据库失败: %v", err) + } + bot.db = multiDB bot.Client.OnPrivateMessage(bot.privateMessageEvent) bot.Client.OnGroupMessage(bot.groupMessageEvent) if base.ReportSelfMessage { @@ -148,22 +149,6 @@ func (bot *CQBot) OnEventPush(f func(e *Event)) { bot.lock.Unlock() } -// GetMessage 获取给定消息id对应的消息 -func (bot *CQBot) GetMessage(mid int32) global.MSG { - if bot.db != nil { - m := global.MSG{} - data, err := bot.db.Get(binary.ToBytes(mid), nil) - if err == nil { - err = gob.NewDecoder(bytes.NewReader(data)).Decode(&m) - if err == nil { - return m - } - } - log.Warnf("获取信息时出现错误: %v id: %v", err, mid) - } - return nil -} - // UploadLocalImageAsGroup 上传本地图片至群聊 func (bot *CQBot) UploadLocalImageAsGroup(groupCode int64, img *LocalImageElement) (i *message.GroupImageElement, err error) { if img.File != "" { @@ -332,8 +317,9 @@ func (bot *CQBot) SendPrivateMessage(target int64, groupID int64, m *message.Sen default: if session == nil && groupID != 0 { msg := bot.Client.SendGroupTempMessage(groupID, target, m) - if msg != nil { - id = bot.InsertTempMessage(target, msg) + if msg != nil { // nolint + // todo(Mrs4s) + // id = bot.InsertTempMessage(target, msg) } break } @@ -342,8 +328,9 @@ func (bot *CQBot) SendPrivateMessage(target int64, groupID int64, m *message.Sen log.Errorf("发送临时会话消息失败: %v", err) break } - if msg != nil { - id = bot.InsertTempMessage(target, msg) + if msg != nil { // nolint + // todo(Mrs4s) + // id = bot.InsertTempMessage(target, msg) } } case unidirectionalFriendExists(): // 单向好友 @@ -363,57 +350,91 @@ func (bot *CQBot) SendPrivateMessage(target int64, groupID int64, m *message.Sen // InsertGroupMessage 群聊消息入数据库 func (bot *CQBot) InsertGroupMessage(m *message.GroupMessage) int32 { - val := global.MSG{ - "message-id": m.Id, - "internal-id": m.InternalId, - "group": m.GroupCode, - "group-name": m.GroupName, - "sender": m.Sender, - "time": m.Time, - "message": ToStringMessage(m.Elements, m.GroupCode, true), + t := &message.SendingMessage{Elements: m.Elements} + replyElem := t.FirstOrNil(func(e message.IMessageElement) bool { + _, ok := e.(*message.ReplyElement) + return ok + }) + msg := &db.StoredGroupMessage{ + ID: encodeMessageID(m.GroupCode, m.Id), + GlobalID: db.ToGlobalID(m.GroupCode, m.Id), + SubType: "normal", + Attribute: &db.StoredMessageAttribute{ + MessageSeq: m.Id, + InternalID: m.InternalId, + SenderUin: m.Sender.Uin, + SenderName: m.Sender.DisplayName(), + Timestamp: int64(m.Time), + }, + GroupCode: m.GroupCode, + AnonymousID: func() string { + if m.Sender.IsAnonymous() { + return m.Sender.AnonymousInfo.AnonymousId + } + return "" + }(), + Content: ToMessageContent(m.Elements), } - id := toGlobalID(m.GroupCode, m.Id) - if bot.db != nil { - buf := global.NewBuffer() - defer global.PutBuffer(buf) - if err := gob.NewEncoder(buf).Encode(val); err != nil { - log.Warnf("记录聊天数据时出现错误: %v", err) - return -1 - } - if err := bot.db.Put(binary.ToBytes(id), buf.Bytes(), nil); err != nil { - log.Warnf("记录聊天数据时出现错误: %v", err) - return -1 + if replyElem != nil { + reply := replyElem.(*message.ReplyElement) + msg.SubType = "quote" + msg.QuotedInfo = &db.QuotedInfo{ + PrevID: encodeMessageID(m.GroupCode, reply.ReplySeq), + PrevGlobalID: db.ToGlobalID(m.GroupCode, reply.ReplySeq), + QuotedContent: ToMessageContent(reply.Elements), } } - return id + if err := bot.db.InsertGroupMessage(msg); err != nil { + log.Warnf("记录聊天数据时出现错误: %v", err) + return -1 + } + return msg.GlobalID } // InsertPrivateMessage 私聊消息入数据库 func (bot *CQBot) InsertPrivateMessage(m *message.PrivateMessage) int32 { - val := global.MSG{ - "message-id": m.Id, - "internal-id": m.InternalId, - "target": m.Target, - "sender": m.Sender, - "time": m.Time, - "message": ToStringMessage(m.Elements, 0, true), + t := &message.SendingMessage{Elements: m.Elements} + replyElem := t.FirstOrNil(func(e message.IMessageElement) bool { + _, ok := e.(*message.ReplyElement) + return ok + }) + msg := &db.StoredPrivateMessage{ + ID: encodeMessageID(m.Sender.Uin, m.Id), + GlobalID: db.ToGlobalID(m.Sender.Uin, m.Id), + SubType: "normal", + Attribute: &db.StoredMessageAttribute{ + MessageSeq: m.Id, + InternalID: m.InternalId, + SenderUin: m.Sender.Uin, + SenderName: m.Sender.DisplayName(), + Timestamp: int64(m.Time), + }, + SessionUin: func() int64 { + if m.Sender.Uin == m.Self { + return m.Target + } + return m.Sender.Uin + }(), + TargetUin: m.Target, + Content: ToMessageContent(m.Elements), } - id := toGlobalID(m.Sender.Uin, m.Id) - if bot.db != nil { - buf := global.NewBuffer() - defer global.PutBuffer(buf) - if err := gob.NewEncoder(buf).Encode(val); err != nil { - log.Warnf("记录聊天数据时出现错误: %v", err) - return -1 - } - if err := bot.db.Put(binary.ToBytes(id), buf.Bytes(), nil); err != nil { - log.Warnf("记录聊天数据时出现错误: %v", err) - return -1 + if replyElem != nil { + reply := replyElem.(*message.ReplyElement) + msg.SubType = "quote" + msg.QuotedInfo = &db.QuotedInfo{ + PrevID: encodeMessageID(reply.Sender, reply.ReplySeq), + PrevGlobalID: db.ToGlobalID(reply.Sender, reply.ReplySeq), + QuotedContent: ToMessageContent(m.Elements), } } - return id + if err := bot.db.InsertPrivateMessage(msg); err != nil { + log.Warnf("记录聊天数据时出现错误: %v", err) + return -1 + } + return msg.GlobalID } +/* // InsertTempMessage 临时消息入数据库 func (bot *CQBot) InsertTempMessage(target int64, m *message.TempMessage) int32 { val := global.MSG{ @@ -426,7 +447,7 @@ func (bot *CQBot) InsertTempMessage(target int64, m *message.TempMessage) int32 "time": int32(time.Now().Unix()), "message": ToStringMessage(m.Elements, 0, true), } - id := toGlobalID(m.Sender.Uin, m.Id) + id := db.ToGlobalID(m.Sender.Uin, m.Id) if bot.db != nil { buf := global.NewBuffer() defer global.PutBuffer(buf) @@ -441,17 +462,11 @@ func (bot *CQBot) InsertTempMessage(target int64, m *message.TempMessage) int32 } return id } - -// toGlobalID 构建`code`-`msgID`的字符串并返回其CRC32 Checksum的值 -func toGlobalID(code int64, msgID int32) int32 { - return int32(crc32.ChecksumIEEE([]byte(fmt.Sprintf("%d-%d", code, msgID)))) -} +*/ // Release 释放Bot实例 func (bot *CQBot) Release() { - if bot.db != nil { - _ = bot.db.Close() - } + } func (bot *CQBot) dispatchEventMessage(m global.MSG) { @@ -580,3 +595,11 @@ func (bot *CQBot) uploadMedia(raw message.IMessageElement, target int64, group b } return nil, errors.New("unsupported message element type") } + +// encodeMessageID 临时先这样, 暂时用不上 +func encodeMessageID(target int64, seq int32) string { + return hex.EncodeToString(binary.NewWriterF(func(w *binary.Writer) { + w.WriteUInt64(uint64(target)) + w.WriteUInt32(uint32(seq)) + })) +} diff --git a/coolq/cqcode.go b/coolq/cqcode.go index ef4aa78..bc9ca09 100644 --- a/coolq/cqcode.go +++ b/coolq/cqcode.go @@ -23,8 +23,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" + "github.com/Mrs4s/go-cqhttp/db" "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" + "github.com/Mrs4s/go-cqhttp/internal/cache" "github.com/Mrs4s/go-cqhttp/internal/param" ) @@ -132,7 +134,7 @@ func ToArrayMessage(e []message.IMessageElement, groupID int64) (r []global.MSG) r = append(r, global.MSG{ "type": "reply", "data": map[string]string{ - "id": strconv.FormatInt(int64(toGlobalID(rid, replyElem.ReplySeq)), 10), + "id": strconv.FormatInt(int64(db.ToGlobalID(rid, replyElem.ReplySeq)), 10), "seq": strconv.FormatInt(int64(replyElem.ReplySeq), 10), "qq": strconv.FormatInt(replyElem.Sender, 10), "time": strconv.FormatInt(int64(replyElem.Time), 10), @@ -142,7 +144,7 @@ func ToArrayMessage(e []message.IMessageElement, groupID int64) (r []global.MSG) } else { r = append(r, global.MSG{ "type": "reply", - "data": map[string]string{"id": strconv.FormatInt(int64(toGlobalID(rid, replyElem.ReplySeq)), 10)}, + "data": map[string]string{"id": strconv.FormatInt(int64(db.ToGlobalID(rid, replyElem.ReplySeq)), 10)}, }) } } @@ -275,11 +277,11 @@ func ToStringMessage(e []message.IMessageElement, groupID int64, isRaw ...bool) } if base.ExtraReplyData { write("[CQ:reply,id=%d,seq=%d,qq=%d,time=%d,text=%s]", - toGlobalID(rid, replyElem.ReplySeq), + db.ToGlobalID(rid, replyElem.ReplySeq), replyElem.ReplySeq, replyElem.Sender, replyElem.Time, CQCodeEscapeValue(ToStringMessage(replyElem.Elements, groupID))) } else { - write("[CQ:reply,id=%d]", toGlobalID(rid, replyElem.ReplySeq)) + write("[CQ:reply,id=%d]", db.ToGlobalID(rid, replyElem.ReplySeq)) } } for i, elem := range e { @@ -355,6 +357,110 @@ func ToStringMessage(e []message.IMessageElement, groupID int64, isRaw ...bool) return } +// ToMessageContent 将消息转换成 Content. 忽略 Reply +// 不同于 onebot 的 Array Message, 此函数转换出来的 Content 的 data 段为实际类型 +// 方便数据库查询 +func ToMessageContent(e []message.IMessageElement) (r []global.MSG) { + for _, elem := range e { + var m global.MSG + switch o := elem.(type) { + case *message.TextElement: + m = global.MSG{ + "type": "text", + "data": global.MSG{"text": o.Content}, + } + case *message.LightAppElement: + m = global.MSG{ + "type": "json", + "data": global.MSG{"data": o.Content}, + } + case *message.AtElement: + if o.Target == 0 { + m = global.MSG{ + "type": "at", + "data": global.MSG{ + "subType": "all", + }, + } + } else { + m = global.MSG{ + "type": "at", + "data": global.MSG{ + "subType": "user", + "target": o.Target, + "display": o.Display, + }, + } + } + case *message.RedBagElement: + m = global.MSG{ + "type": "redbag", + "data": global.MSG{"title": o.Title, "type": o.MsgType}, + } + case *message.ForwardElement: + m = global.MSG{ + "type": "forward", + "data": global.MSG{"id": o.ResId}, + } + case *message.FaceElement: + m = global.MSG{ + "type": "face", + "data": global.MSG{"id": o.Index}, + } + case *message.VoiceElement: + m = global.MSG{ + "type": "record", + "data": global.MSG{"file": o.Name, "url": o.Url}, + } + case *message.ShortVideoElement: + m = global.MSG{ + "type": "video", + "data": global.MSG{"file": o.Name, "url": o.Url}, + } + case *message.GroupImageElement: + data := global.MSG{"file": hex.EncodeToString(o.Md5) + ".image", "url": o.Url, "subType": uint32(o.ImageBizType)} + switch { + case o.Flash: + data["type"] = "flash" + case o.EffectID != 0: + data["type"] = "show" + data["id"] = o.EffectID + } + m = global.MSG{ + "type": "image", + "data": data, + } + case *message.FriendImageElement: + data := global.MSG{"file": hex.EncodeToString(o.Md5) + ".image", "url": o.Url} + if o.Flash { + data["type"] = "flash" + } + m = global.MSG{ + "type": "image", + "data": data, + } + case *message.ServiceElement: + if isOk := strings.Contains(o.Content, "= 40006 { + id = 40000 + } + } + } + switch img := e.(type) { + case *LocalImageElement: + img.Flash = flash + img.EffectID = id + case *message.GroupImageElement: + img.Flash = flash + img.EffectID = id + img.ImageBizType = message.ImageBizType(data["subType"].(uint32)) + case *message.FriendImageElement: + img.Flash = flash + } + r = append(r, e) + case "at": + switch data["subType"].(string) { + case "all": + r = append(r, message.NewAt(0)) + case "user": + r = append(r, message.NewAt(data["target"].(int64), data["display"].(string))) + default: + continue + } + case "redbag": + r = append(r, &message.RedBagElement{ + MsgType: message.RedBagMessageType(data["type"].(int)), + Title: data["title"].(string), + }) + case "forward": + r = append(r, &message.ForwardElement{ + ResId: data["id"].(string), + }) + case "face": + r = append(r, message.NewFace(data["id"].(int32))) + case "video": + e, err := bot.makeImageOrVideoElem(map[string]string{"file": data["file"].(string)}, true, group) + if err != nil { + log.Warnf("make image elem error: %v", err) + continue + } + r = append(r, e) + } + } + return +} + // ToElement 将解码后的CQCode转换为Element. // // 返回 interface{} 存在三种类型 @@ -1120,23 +1295,31 @@ func (bot *CQBot) makeImageOrVideoElem(d map[string]string, video, group bool) ( } rawPath := path.Join(global.ImagePath, f) if video { + if strings.HasSuffix(f, ".video") && cache.EnableCacheDB { + hash, err := hex.DecodeString(strings.TrimSuffix(f, ".video")) + if err == nil { + if b := cache.Video.Get(hash); b != nil { + return bot.readVideoCache(b), nil + } + } + } rawPath = path.Join(global.VideoPath, f) if !global.PathExists(rawPath) { return nil, errors.New("invalid video") } - if path.Ext(rawPath) == ".video" { - b, _ := os.ReadFile(rawPath) - r := binary.NewReader(b) - return &message.ShortVideoElement{ // todo 检查缓存是否有效 - Md5: r.ReadBytes(16), - ThumbMd5: r.ReadBytes(16), - Size: r.ReadInt32(), - ThumbSize: r.ReadInt32(), - Name: r.ReadString(), - Uuid: r.ReadAvailable(), - }, nil + if path.Ext(rawPath) != ".video" { + return &LocalVideoElement{File: rawPath}, nil + } + b, _ := os.ReadFile(rawPath) + return bot.readVideoCache(b), nil + } + if strings.HasSuffix(f, ".image") && cache.EnableCacheDB { + hash, err := hex.DecodeString(strings.TrimSuffix(f, ".image")) + if err == nil { + if b := cache.Image.Get(hash); b != nil { + return bot.readImageCache(b, group) + } } - return &LocalVideoElement{File: rawPath}, nil } exist := global.PathExists(rawPath) if !exist && global.PathExists(path.Join(global.ImagePathOld, f)) { @@ -1156,8 +1339,13 @@ func (bot *CQBot) makeImageOrVideoElem(d map[string]string, video, group bool) ( if err != nil { return nil, err } + return bot.readImageCache(b, group) +} + +func (bot *CQBot) readImageCache(b []byte, group bool) (message.IMessageElement, error) { + var err error if len(b) < 20 { - return nil, errors.New("invalid local file") + return nil, errors.New("invalid cache") } r := binary.NewReader(b) hash := r.ReadBytes(16) @@ -1189,6 +1377,18 @@ ok: return rsp, nil } +func (bot *CQBot) readVideoCache(b []byte) message.IMessageElement { + r := binary.NewReader(b) + return &message.ShortVideoElement{ // todo 检查缓存是否有效 + Md5: r.ReadBytes(16), + ThumbMd5: r.ReadBytes(16), + Size: r.ReadInt32(), + ThumbSize: r.ReadInt32(), + Name: r.ReadString(), + Uuid: r.ReadAvailable(), + } +} + // makeShowPic 一种xml 方式发送的群消息图片 func (bot *CQBot) makeShowPic(elem message.IMessageElement, source string, brief string, icon string, minWidth int64, minHeight int64, maxWidth int64, maxHeight int64, group bool) ([]message.IMessageElement, error) { xml := "" diff --git a/coolq/cqcode_test.go b/coolq/cqcode_test.go index 5bde41a..08658aa 100644 --- a/coolq/cqcode_test.go +++ b/coolq/cqcode_test.go @@ -41,7 +41,7 @@ const bText = `123456789[]&987654321[]&987654321[]&987654321[]&987654321[]&98765 func BenchmarkCQCodeEscapeText(b *testing.B) { for i := 0; i < b.N; i++ { ret := bText - ret = CQCodeEscapeText(ret) + CQCodeEscapeText(ret) } } @@ -50,7 +50,7 @@ func BenchmarkCQCodeEscapeTextBefore(b *testing.B) { ret := bText ret = strings.ReplaceAll(ret, "&", "&") ret = strings.ReplaceAll(ret, "[", "[") - ret = strings.ReplaceAll(ret, "]", "]") + strings.ReplaceAll(ret, "]", "]") } } diff --git a/coolq/event.go b/coolq/event.go index f00567b..fe9e425 100644 --- a/coolq/event.go +++ b/coolq/event.go @@ -8,6 +8,9 @@ import ( "strings" "time" + "github.com/Mrs4s/go-cqhttp/db" + "github.com/Mrs4s/go-cqhttp/internal/cache" + "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" @@ -105,8 +108,9 @@ func (bot *CQBot) tempMessageEvent(c *client.QQClient, e *client.TempMessageEven cqm := ToStringMessage(m.Elements, 0, true) bot.tempSessionCache.Store(m.Sender.Uin, e.Session) id := m.Id - if bot.db != nil { - id = bot.InsertTempMessage(m.Sender.Uin, m) + if bot.db != nil { // nolint + // todo(Mrs4s) + // id = bot.InsertTempMessage(m.Sender.Uin, m) } log.Infof("收到来自群 %v(%v) 内 %v(%v) 的临时会话消息: %v", m.GroupName, m.GroupCode, m.Sender.DisplayName(), m.Sender.Uin, cqm) tm := global.MSG{ @@ -172,7 +176,7 @@ func (bot *CQBot) groupMutedEvent(c *client.QQClient, e *client.GroupMuteEvent) func (bot *CQBot) groupRecallEvent(c *client.QQClient, e *client.GroupMessageRecalledEvent) { g := c.FindGroup(e.GroupCode) - gid := toGlobalID(e.GroupCode, e.MessageId) + gid := db.ToGlobalID(e.GroupCode, e.MessageId) log.Infof("群 %v 内 %v 撤回了 %v 的消息: %v.", formatGroupName(g), formatMemberName(g.FindMember(e.OperatorUin)), formatMemberName(g.FindMember(e.AuthorUin)), gid) bot.dispatchEventMessage(global.MSG{ @@ -289,7 +293,7 @@ func (bot *CQBot) memberTitleUpdatedEvent(c *client.QQClient, e *client.MemberSp func (bot *CQBot) friendRecallEvent(c *client.QQClient, e *client.FriendMessageRecalledEvent) { f := c.FindFriend(e.FriendUin) - gid := toGlobalID(e.FriendUin, e.MessageId) + gid := db.ToGlobalID(e.FriendUin, e.MessageId) if f != nil { log.Infof("好友 %v(%v) 撤回了消息: %v", f.Nickname, f.Uin, gid) } else { @@ -466,7 +470,7 @@ func (bot *CQBot) otherClientStatusChangedEvent(c *client.QQClient, e *client.Ot func (bot *CQBot) groupEssenceMsg(c *client.QQClient, e *client.GroupDigestEvent) { g := c.FindGroup(e.GroupCode) - gid := toGlobalID(e.GroupCode, e.MessageID) + gid := db.ToGlobalID(e.GroupCode, e.MessageID) if e.OperationType == 1 { log.Infof( "群 %v 内 %v 将 %v 的消息(%v)设为了精华消息.", @@ -545,29 +549,37 @@ func (bot *CQBot) groupDecrease(groupCode, userUin int64, operator *client.Group } func (bot *CQBot) checkMedia(e []message.IMessageElement) { + // TODO(wdvxdr): remove these old cache file in v1.0.0 for _, elem := range e { switch i := elem.(type) { case *message.GroupImageElement: + data := binary.NewWriterF(func(w *binary.Writer) { + w.Write(i.Md5) + w.WriteUInt32(uint32(i.Size)) + w.WriteString(i.ImageId) + w.WriteString(i.Url) + }) filename := hex.EncodeToString(i.Md5) + ".image" - if !global.PathExists(path.Join(global.ImagePath, filename)) { - _ = os.WriteFile(path.Join(global.ImagePath, filename), binary.NewWriterF(func(w *binary.Writer) { - w.Write(i.Md5) - w.WriteUInt32(uint32(i.Size)) - w.WriteString(i.ImageId) - w.WriteString(i.Url) - }), 0o644) + if cache.EnableCacheDB { + cache.Image.Insert(i.Md5, data) + } else if !global.PathExists(path.Join(global.ImagePath, filename)) { + _ = os.WriteFile(path.Join(global.ImagePath, filename), data, 0o644) } case *message.FriendImageElement: + data := binary.NewWriterF(func(w *binary.Writer) { + w.Write(i.Md5) + w.WriteUInt32(uint32(i.Size)) + w.WriteString(i.ImageId) + w.WriteString(i.Url) + }) filename := hex.EncodeToString(i.Md5) + ".image" - if !global.PathExists(path.Join(global.ImagePath, filename)) { - _ = os.WriteFile(path.Join(global.ImagePath, filename), binary.NewWriterF(func(w *binary.Writer) { - w.Write(i.Md5) - w.WriteUInt32(uint32(i.Size)) - w.WriteString(i.ImageId) - w.WriteString(i.Url) - }), 0o644) + if cache.EnableCacheDB { + cache.Image.Insert(i.Md5, data) + } else if !global.PathExists(path.Join(global.ImagePath, filename)) { + _ = os.WriteFile(path.Join(global.ImagePath, filename), data, 0o644) } case *message.VoiceElement: + // todo: don't download original file? i.Name = strings.ReplaceAll(i.Name, "{", "") i.Name = strings.ReplaceAll(i.Name, "}", "") if !global.PathExists(path.Join(global.VoicePath, i.Name)) { @@ -579,16 +591,19 @@ func (bot *CQBot) checkMedia(e []message.IMessageElement) { _ = os.WriteFile(path.Join(global.VoicePath, i.Name), b, 0o644) } case *message.ShortVideoElement: + data := binary.NewWriterF(func(w *binary.Writer) { + w.Write(i.Md5) + w.Write(i.ThumbMd5) + w.WriteUInt32(uint32(i.Size)) + w.WriteUInt32(uint32(i.ThumbSize)) + w.WriteString(i.Name) + w.Write(i.Uuid) + }) filename := hex.EncodeToString(i.Md5) + ".video" - if !global.PathExists(path.Join(global.VideoPath, filename)) { - _ = os.WriteFile(path.Join(global.VideoPath, filename), binary.NewWriterF(func(w *binary.Writer) { - w.Write(i.Md5) - w.Write(i.ThumbMd5) - w.WriteUInt32(uint32(i.Size)) - w.WriteUInt32(uint32(i.ThumbSize)) - w.WriteString(i.Name) - w.Write(i.Uuid) - }), 0o644) + if cache.EnableCacheDB { + cache.Video.Insert(i.Md5, data) + } else if !global.PathExists(path.Join(global.VideoPath, filename)) { + _ = os.WriteFile(path.Join(global.VideoPath, filename), data, 0o644) } i.Name = filename i.Url = bot.Client.GetShortVideoUrl(i.Uuid, i.Md5) diff --git a/db/database.go b/db/database.go new file mode 100644 index 0000000..08a852b --- /dev/null +++ b/db/database.go @@ -0,0 +1,101 @@ +package db + +import ( + "fmt" + "hash/crc32" + + "github.com/Mrs4s/go-cqhttp/global" +) + +type ( + // IDatabase 数据库操作接口定义 + IDatabase interface { + // Open 初始化数据库 + Open() error + + // GetMessageByGlobalID 通过 GlobalID 来获取消息 + GetMessageByGlobalID(int32) (IStoredMessage, error) + // GetGroupMessageByGlobalID 通过 GlobalID 来获取群消息 + GetGroupMessageByGlobalID(int32) (*StoredGroupMessage, error) + // GetPrivateMessageByGlobalID 通过 GlobalID 来获取私聊消息 + GetPrivateMessageByGlobalID(int32) (*StoredPrivateMessage, error) + + // InsertGroupMessage 向数据库写入新的群消息 + InsertGroupMessage(*StoredGroupMessage) error + // InsertPrivateMessage 向数据库写入新的私聊消息 + InsertPrivateMessage(*StoredPrivateMessage) error + } + + IStoredMessage interface { + GetID() string + GetType() string + GetGlobalID() int32 + GetAttribute() *StoredMessageAttribute + GetContent() []global.MSG + } + + // StoredGroupMessage 持久化群消息 + StoredGroupMessage struct { + ID string `bson:"_id"` + GlobalID int32 `bson:"globalId"` + Attribute *StoredMessageAttribute `bson:"attribute"` + SubType string `bson:"subType"` + QuotedInfo *QuotedInfo `bson:"quotedInfo"` + GroupCode int64 `bson:"groupCode"` + AnonymousID string `bson:"anonymousId"` + Content []global.MSG `bson:"content"` + } + + // StoredPrivateMessage 持久化私聊消息 + StoredPrivateMessage struct { + ID string `bson:"_id"` + GlobalID int32 `bson:"globalId"` + Attribute *StoredMessageAttribute `bson:"attribute"` + SubType string `bson:"subType"` + QuotedInfo *QuotedInfo `bson:"quotedInfo"` + SessionUin int64 `bson:"sessionUin"` + TargetUin int64 `bson:"targetUin"` + Content []global.MSG `bson:"content"` + } + + // StoredMessageAttribute 持久化消息属性 + StoredMessageAttribute struct { + MessageSeq int32 `bson:"messageSeq"` + InternalID int32 `bson:"internalId"` + SenderUin int64 `bson:"senderUin"` + SenderName string `bson:"senderName"` + Timestamp int64 `bson:"timestamp"` + } + + // QuotedInfo 引用回复 + QuotedInfo struct { + PrevID string `bson:"prevId"` + PrevGlobalID int32 `bson:"prevGlobalId"` + QuotedContent []global.MSG `bson:"quotedContent"` + } +) + +// ToGlobalID 构建`code`-`msgID`的字符串并返回其CRC32 Checksum的值 +func ToGlobalID(code int64, msgID int32) int32 { + return int32(crc32.ChecksumIEEE([]byte(fmt.Sprintf("%d-%d", code, msgID)))) +} + +func (m *StoredGroupMessage) GetID() string { return m.ID } + +func (m *StoredGroupMessage) GetType() string { return "group" } + +func (m *StoredGroupMessage) GetGlobalID() int32 { return m.GlobalID } + +func (m *StoredGroupMessage) GetAttribute() *StoredMessageAttribute { return m.Attribute } + +func (m *StoredGroupMessage) GetContent() []global.MSG { return m.Content } + +func (m *StoredPrivateMessage) GetID() string { return m.ID } + +func (m *StoredPrivateMessage) GetType() string { return "private" } + +func (m *StoredPrivateMessage) GetGlobalID() int32 { return m.GlobalID } + +func (m *StoredPrivateMessage) GetAttribute() *StoredMessageAttribute { return m.Attribute } + +func (m *StoredPrivateMessage) GetContent() []global.MSG { return m.Content } diff --git a/db/leveldb.go b/db/leveldb.go new file mode 100644 index 0000000..c38d52e --- /dev/null +++ b/db/leveldb.go @@ -0,0 +1,118 @@ +package db + +import ( + "bytes" + "encoding/gob" + "path" + + "github.com/Mrs4s/MiraiGo/binary" + "github.com/pkg/errors" + "github.com/syndtr/goleveldb/leveldb" + "github.com/syndtr/goleveldb/leveldb/opt" + + "github.com/Mrs4s/go-cqhttp/global" +) + +type LevelDBImpl struct { + db *leveldb.DB +} + +const ( + group byte = 0x0 + private byte = 0x1 +) + +func UseLevelDB() *LevelDBImpl { + gob.Register(StoredMessageAttribute{}) + gob.Register(QuotedInfo{}) + gob.Register(global.MSG{}) + gob.Register(StoredGroupMessage{}) + gob.Register(StoredPrivateMessage{}) + return &LevelDBImpl{} +} + +func (db *LevelDBImpl) Open() error { + p := path.Join("data", "leveldb-v2") + d, err := leveldb.OpenFile(p, &opt.Options{ + WriteBuffer: 128 * opt.KiB, + }) + if err != nil { + return errors.Wrap(err, "open level db error") + } + db.db = d + return nil +} + +func (db *LevelDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) { + v, err := db.db.Get(binary.ToBytes(id), nil) + if err != nil { + return nil, errors.Wrap(err, "get value error") + } + r := binary.NewReader(v) + switch r.ReadByte() { + case group: + g := &StoredGroupMessage{} + if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(g); err != nil { + return nil, errors.Wrap(err, "decode message error") + } + return g, nil + case private: + p := &StoredPrivateMessage{} + if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(p); err != nil { + return nil, errors.Wrap(err, "decode message error") + } + return p, nil + default: + return nil, errors.New("unknown message flag") + } +} + +func (db *LevelDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { + i, err := db.GetMessageByGlobalID(id) + if err != nil { + return nil, err + } + g, ok := i.(*StoredGroupMessage) + if !ok { + return nil, errors.New("message type error") + } + return g, nil +} + +func (db *LevelDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { + i, err := db.GetMessageByGlobalID(id) + if err != nil { + return nil, err + } + p, ok := i.(*StoredPrivateMessage) + if !ok { + return nil, errors.New("message type error") + } + return p, nil +} + +func (db *LevelDBImpl) InsertGroupMessage(msg *StoredGroupMessage) error { + buf := global.NewBuffer() + defer global.PutBuffer(buf) + if err := gob.NewEncoder(buf).Encode(msg); err != nil { + return errors.Wrap(err, "encode message error") + } + err := db.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) { + w.WriteByte(group) + w.Write(buf.Bytes()) + }), nil) + return errors.Wrap(err, "put data error") +} + +func (db *LevelDBImpl) InsertPrivateMessage(msg *StoredPrivateMessage) error { + buf := global.NewBuffer() + defer global.PutBuffer(buf) + if err := gob.NewEncoder(buf).Encode(msg); err != nil { + return errors.Wrap(err, "encode message error") + } + err := db.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) { + w.WriteByte(private) + w.Write(buf.Bytes()) + }), nil) + return errors.Wrap(err, "put data error") +} diff --git a/db/mongodb.go b/db/mongodb.go new file mode 100644 index 0000000..00677b1 --- /dev/null +++ b/db/mongodb.go @@ -0,0 +1,71 @@ +package db + +import ( + "context" + + "github.com/pkg/errors" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +type MongoDBImpl struct { + uri string + db string + mongo *mongo.Database +} + +const ( + MongoGroupMessageCollection = "group-messages" + MongoPrivateMessageCollection = "private-messages" +) + +func UseMongoDB(uri, db string) *MongoDBImpl { + return &MongoDBImpl{uri: uri, db: db} +} + +func (db *MongoDBImpl) Open() error { + cli, err := mongo.Connect(context.Background(), options.Client().ApplyURI(db.uri)) + if err != nil { + return errors.Wrap(err, "open mongo connection error") + } + db.mongo = cli.Database(db.db) + return nil +} + +func (db *MongoDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) { + if r, err := db.GetGroupMessageByGlobalID(id); err == nil { + return r, nil + } + return db.GetPrivateMessageByGlobalID(id) +} + +func (db *MongoDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { + coll := db.mongo.Collection(MongoGroupMessageCollection) + var ret StoredGroupMessage + if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil { + return nil, errors.Wrap(err, "query error") + } + return &ret, nil +} + +func (db *MongoDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { + coll := db.mongo.Collection(MongoPrivateMessageCollection) + var ret StoredPrivateMessage + if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil { + return nil, errors.Wrap(err, "query error") + } + return &ret, nil +} + +func (db *MongoDBImpl) InsertGroupMessage(msg *StoredGroupMessage) error { + coll := db.mongo.Collection(MongoGroupMessageCollection) + _, err := coll.InsertOne(context.Background(), msg) + return errors.Wrap(err, "insert error") +} + +func (db *MongoDBImpl) InsertPrivateMessage(msg *StoredPrivateMessage) error { + coll := db.mongo.Collection(MongoPrivateMessageCollection) + _, err := coll.InsertOne(context.Background(), msg) + return errors.Wrap(err, "insert error") +} diff --git a/db/multidb.go b/db/multidb.go new file mode 100644 index 0000000..5d7a78b --- /dev/null +++ b/db/multidb.go @@ -0,0 +1,68 @@ +package db + +import "github.com/pkg/errors" + +// MultiDBImpl 多数据库支持 +// 写入会对所有 Backend 进行写入 +// 读取只会读取第一个库 +type MultiDBImpl struct { + backends []IDatabase +} + +func NewMultiDatabase(backends ...IDatabase) *MultiDBImpl { + return &MultiDBImpl{ + backends: backends, + } +} + +func (db *MultiDBImpl) UseDB(backend IDatabase) { + db.backends = append(db.backends, backend) +} + +func (db *MultiDBImpl) Open() error { + for _, b := range db.backends { + if err := b.Open(); err != nil { + return errors.Wrap(err, "open backend error") + } + } + return nil +} + +func (db *MultiDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) { + if len(db.backends) == 0 { + return nil, errors.New("database disabled") + } + return db.backends[0].GetMessageByGlobalID(id) +} + +func (db *MultiDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { + if len(db.backends) == 0 { + return nil, errors.New("database disabled") + } + return db.backends[0].GetGroupMessageByGlobalID(id) +} + +func (db *MultiDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { + if len(db.backends) == 0 { + return nil, errors.New("database disabled") + } + return db.backends[0].GetPrivateMessageByGlobalID(id) +} + +func (db *MultiDBImpl) InsertGroupMessage(m *StoredGroupMessage) error { + for _, b := range db.backends { + if err := b.InsertGroupMessage(m); err != nil { + return errors.Wrap(err, "insert message to backend error") + } + } + return nil +} + +func (db *MultiDBImpl) InsertPrivateMessage(m *StoredPrivateMessage) error { + for _, b := range db.backends { + if err := b.InsertPrivateMessage(m); err != nil { + return errors.Wrap(err, "insert message to backend error") + } + } + return nil +} diff --git a/go.mod b/go.mod index 3fc6602..8ca13f9 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/tidwall/gjson v1.8.1 github.com/tuotoo/qrcode v0.0.0-20190222102259-ac9c44189bf2 github.com/wdvxdr1123/go-silk v0.0.0-20210316130616-d47b553def60 + go.mongodb.org/mongo-driver v1.7.2 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac @@ -32,7 +33,8 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect + github.com/go-stack/stack v1.8.0 // indirect + github.com/golang/snappy v0.0.1 // indirect github.com/google/uuid v1.1.1 // indirect github.com/gorilla/websocket v1.4.2 // indirect github.com/jonboulle/clockwork v0.2.2 // indirect @@ -46,9 +48,14 @@ require ( github.com/tidwall/match v1.0.3 // indirect github.com/tidwall/pretty v1.1.0 // indirect github.com/willf/bitset v1.2.0 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.0.2 // indirect + github.com/xdg-go/stringprep v1.0.2 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect + golang.org/x/text v0.3.6 // indirect google.golang.org/protobuf v1.27.1 // indirect modernc.org/libc v1.8.1 // indirect modernc.org/mathutil v1.2.2 // indirect diff --git a/go.sum b/go.sum index 6e94c31..ee13811 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/Baozisoftware/qrcode-terminal-go v0.0.0-20170407111555-c0650d8dff0f h1:2dk3eOnYllh+wUOuDhOoC2vUVoJF/5z478ryJ+wzEII= github.com/Baozisoftware/qrcode-terminal-go v0.0.0-20170407111555-c0650d8dff0f/go.mod h1:4a58ifQTEe2uwwsaqbh3i2un5/CBPg+At/qHpt18Tmk= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Microsoft/go-winio v0.5.1 h1:aPJp2QD7OOrhO5tQXqQoGSJc+DjDtWTGLOmNyAm6FgY= github.com/Microsoft/go-winio v0.5.1/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= github.com/Mrs4s/MiraiGo v0.0.0-20211016141121-3b85d69fd3e1 h1:KED5vry8RkK1oVmXjaL2mKikuZDxYGQKZ9uqQI7cHoc= @@ -27,6 +28,32 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= +github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= +github.com/gobuffalo/depgen v0.0.0-20190329151759-d478694a28d3/go.mod h1:3STtPUQYuzV0gBVOY3vy6CfMm/ljR4pABfrTeHNLHUY= +github.com/gobuffalo/depgen v0.1.0/go.mod h1:+ifsuy7fhi15RWncXQQKjWS9JPkdah5sZvtHc2RXGlg= +github.com/gobuffalo/envy v1.6.15/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= +github.com/gobuffalo/flect v0.1.0/go.mod h1:d2ehjJqGOH/Kjqcoz+F7jHTBbmDb38yXA598Hb50EGs= +github.com/gobuffalo/flect v0.1.1/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/flect v0.1.3/go.mod h1:8JCgGVbRjJhVgD6399mQr4fx5rRfGKVzFjbj6RE/9UI= +github.com/gobuffalo/genny v0.0.0-20190329151137-27723ad26ef9/go.mod h1:rWs4Z12d1Zbf19rlsn0nurr75KqhYp52EAGGxTbBhNk= +github.com/gobuffalo/genny v0.0.0-20190403191548-3ca520ef0d9e/go.mod h1:80lIj3kVJWwOrXWWMRzzdhW3DsrdjILVil/SFKBzF28= +github.com/gobuffalo/genny v0.1.0/go.mod h1:XidbUqzak3lHdS//TPu2OgiFB+51Ur5f7CSnXZ/JDvo= +github.com/gobuffalo/genny v0.1.1/go.mod h1:5TExbEyY48pfunL4QSXxlDOmdsD44RRq4mVZ0Ex28Xk= +github.com/gobuffalo/gitgen v0.0.0-20190315122116-cc086187d211/go.mod h1:vEHJk/E9DmhejeLeNt7UVvlSGv3ziL+djtTr3yyzcOw= +github.com/gobuffalo/gogen v0.0.0-20190315121717-8f38393713f5/go.mod h1:V9QVDIxsgKNZs6L2IYiGR8datgMhB577vzTDqypH360= +github.com/gobuffalo/gogen v0.1.0/go.mod h1:8NTelM5qd8RZ15VjQTFkAW6qOMx5wBbW4dSCS3BY8gg= +github.com/gobuffalo/gogen v0.1.1/go.mod h1:y8iBtmHmGc4qa3urIyo1shvOD8JftTtfcKi+71xfDNE= +github.com/gobuffalo/logger v0.0.0-20190315122211-86e12af44bc2/go.mod h1:QdxcLw541hSGtBnhUc4gaNIXRjiDppFGaDqzbrBd3v8= +github.com/gobuffalo/mapi v1.0.1/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/mapi v1.0.2/go.mod h1:4VAGh89y6rVOvm5A8fKFxYG+wIW6LO1FMTG9hnKStFc= +github.com/gobuffalo/packd v0.0.0-20190315124812-a385830c7fc0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packd v0.1.0/go.mod h1:M2Juc+hhDXf/PnmBANFCqx4DM3wRbgDvnVWeG2RIxq4= +github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= +github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= +github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= @@ -38,9 +65,11 @@ github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaW github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -52,17 +81,29 @@ github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0U github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA= github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4= +github.com/karrick/godirwalk v1.10.3/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= +github.com/klauspost/compress v1.9.5/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/cpuid/v2 v2.0.6 h1:dQ5ueTiftKxp0gyjKSx5+8BtPWkyQbd95m8Gys/RarI= github.com/klauspost/cpuid/v2 v2.0.6/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/lestrrat-go/envload v0.0.0-20180220234015-a3eb8ddeffcc h1:RKf14vYWi2ttpEmkA4aQ3j4u9dStX2t4M8UM6qqNsG8= @@ -71,6 +112,8 @@ github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible h1:Y6sqxHMyB1D2YSzWkL github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible/go.mod h1:ZQnN8lSECaebrkQytbHj4xNgtg8CR7RYXnPok8e0EHA= github.com/lestrrat-go/strftime v1.0.5 h1:A7H3tT8DhTz8u65w+JRpiBxM4dINQhUXAZnhBa2xeOE= github.com/lestrrat-go/strftime v1.0.5/go.mod h1:E1nN3pCbtMSu1yjSVeyuRFVm/U0xoR76fd03sz+Qz4g= +github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= +github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= github.com/maruel/rs v0.0.0-20150922171536-2c81c4312fe4 h1:u9jwvcKbQpghIXgNl/EOL8hzhAFXh4ePrEP493W3tNA= github.com/maruel/rs v0.0.0-20150922171536-2c81c4312fe4/go.mod h1:kcRFpEzolcEklV6rD7W95mG49/sbdX/PlFmd7ni3RvA= github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= @@ -81,11 +124,14 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -93,17 +139,27 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/segmentio/asm v1.1.0 h1:fkVr8k5J4sKoFjTGVD6r1yKvDKqmvrEh3K7iyVxgBs8= github.com/segmentio/asm v1.1.0/go.mod h1:4EUJGaKsB8ImLUwOGORVsNd9vTRDeh44JGsY4aKp5I4= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= @@ -112,6 +168,7 @@ github.com/tidwall/gjson v1.8.1 h1:8j5EE9Hrh3l9Od1OIEDAb7IpezNA20UdRngNAj5N0WU= github.com/tidwall/gjson v1.8.1/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk= github.com/tidwall/match v1.0.3 h1:FQUVvBImDutD8wJLN6c5eMzWtjgONK9MwIBCOrUJKeE= github.com/tidwall/match v1.0.3/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.1.0 h1:K3hMW5epkdAVwibsQEfR/7Zj0Qgt4DxtNumTq/VloO8= github.com/tidwall/pretty v1.1.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tuotoo/qrcode v0.0.0-20190222102259-ac9c44189bf2 h1:BWVtt2VBY+lmVDu9MGKqLGKl04B+iRHcrW1Ptyi/8tg= @@ -122,17 +179,44 @@ github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/wdvxdr1123/go-silk v0.0.0-20210316130616-d47b553def60 h1:lRKf10iIOW0VsH5WDF621ihzR+R2wEBZVtNRHuLLCb4= github.com/wdvxdr1123/go-silk v0.0.0-20210316130616-d47b553def60/go.mod h1:ecFKZPX81BaB70I6ruUgEwYcDOtuNgJGnjdK+MIl5ko= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.0.2 h1:akYIkZ28e6A96dkWNJQu3nmCzH3YfwMPQExUYDaRv7w= +github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= +github.com/xdg-go/stringprep v1.0.2 h1:6iq84/ryjjeRmMJwxutI51F2GIPlP5BfTvXHeYjyhBc= +github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.7.2 h1:pFttQyIiJUHEn50YfZgC9ECjITMT44oiN36uArf/OFg= +go.mongodb.org/mongo-driver v1.7.2/go.mod h1:Q4oFMbo1+MSNqICAdYMlC/zSTrwCogR4R8NzkI+yfU8= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210505024714-0287a6fb4125/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190419153524-e8e3143a4f4a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190531175056-4c3a928424d2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -148,19 +232,26 @@ golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuX golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac h1:7zkz7BUtwNFFqcowJ+RIgu2MaV/MapERkDIy+mwPyjs= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190329151228-23e29df326fe/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= diff --git a/internal/btree/btree.go b/internal/btree/btree.go new file mode 100644 index 0000000..ccbdb00 --- /dev/null +++ b/internal/btree/btree.go @@ -0,0 +1,524 @@ +// Package btree provide a disk-based btree +package btree + +import ( + "io" + "math/rand" + "os" + "unsafe" + + "github.com/pkg/errors" +) + +const ( + hashSize = 16 // md5 hash + tableSize = (1024 - 1) / int(unsafe.Sizeof(item{})) + cacheSlots = 11 // prime + superSize = int(unsafe.Sizeof(super{})) + tableStructSize = int(unsafe.Sizeof(table{})) +) + +type item struct { + hash [hashSize]byte + offset int64 + child int64 +} + +type table struct { + items [tableSize]item + size int +} + +type cache struct { + table *table + offset int64 +} + +type super struct { + top int64 + freeTop int64 + alloc int64 +} + +// DB ... +type DB struct { + fd *os.File + top int64 + freeTop int64 + alloc int64 + cache [cacheSlots]cache + + inAllocator bool + deleteLarger bool + fqueue [freeQueueLen]chunk + fqueueLen int +} + +func (d *DB) get(offset int64) *table { + assert(offset != 0) + + // take from cache + slot := &d.cache[offset%cacheSlots] + if slot.offset == offset { + return slot.table + } + + table := new(table) + + d.fd.Seek(offset, io.SeekStart) + err := readTable(d.fd, table) + if err != nil { + panic(errors.Wrap(err, "btree I/O error")) + } + return table +} + +func (d *DB) put(t *table, offset int64) { + assert(offset != 0) + + // overwrite cache + slot := &d.cache[offset%cacheSlots] + slot.table = t + slot.offset = offset +} + +func (d *DB) flush(t *table, offset int64) { + assert(offset != 0) + + d.fd.Seek(offset, io.SeekStart) + err := writeTable(d.fd, t) + if err != nil { + panic(errors.Wrap(err, "btree I/O error")) + } + d.put(t, offset) +} + +func (d *DB) flushSuper() { + d.fd.Seek(0, io.SeekStart) + super := super{ + top: d.top, + freeTop: d.freeTop, + alloc: d.alloc, + } + err := writeSuper(d.fd, &super) + if err != nil { + panic(errors.Wrap(err, "btree I/O error")) + } +} + +// Open opens an existed btree file +func Open(name string) (*DB, error) { + btree := new(DB) + fd, err := os.OpenFile(name, os.O_RDWR, 0o644) + if err != nil { + return nil, errors.Wrap(err, "btree open file failed") + } + btree.fd = fd + + super := super{} + err = readSuper(fd, &super) + btree.top = super.top + btree.freeTop = super.freeTop + btree.alloc = super.alloc + return btree, errors.Wrap(err, "btree read meta info failed") +} + +// Create creates a database +func Create(name string) (*DB, error) { + btree := new(DB) + fd, err := os.OpenFile(name, os.O_RDWR|os.O_TRUNC|os.O_CREATE, 0o644) + if err != nil { + return nil, errors.Wrap(err, "btree open file failed") + } + + btree.fd = fd + btree.alloc = int64(superSize) + btree.flushSuper() + return btree, nil +} + +// Close closes the database +func (d *DB) Close() error { + _ = d.fd.Sync() + err := d.fd.Close() + for i := 0; i < cacheSlots; i++ { + d.cache[i] = cache{} + } + return errors.Wrap(err, "btree close failed") +} + +func collapse(bt *DB, offset int64) int64 { + table := bt.get(offset) + if table.size != 0 { + /* unable to collapse */ + bt.put(table, offset) + return offset + } + ret := table.items[0].child + bt.put(table, offset) + + /* + * WARNING: this is dangerous as the chunk is added to allocation tree + * before the references to it are removed! + */ + bt.freeChunk(offset, int(unsafe.Sizeof(table))) + return ret +} + +// split a table. The pivot item is stored to 'hash' and 'offset'. +// Returns offset to the new table. +func (d *DB) split(t *table, hash *byte, offset *int64) int64 { + copyhash(hash, &t.items[tableSize/2].hash[0]) + *offset = t.items[tableSize/2].offset + + ntable := new(table) + ntable.size = t.size - tableSize/2 - 1 + + t.size = tableSize / 2 + + copy(ntable.items[:ntable.size+1], t.items[tableSize/2+1:]) + + noff := d.allocChunk(tableStructSize) + d.flush(ntable, noff) + + // make sure data is written before a reference is added to it + _ = d.fd.Sync() + return noff +} + +// takeSmallest find and remove the smallest item from the given table. The key of the item +// is stored to 'hash'. Returns offset to the item +func (d *DB) takeSmallest(toff int64, sha1 *byte) int64 { + table := d.get(toff) + assert(table.size > 0) + + var off int64 + child := table.items[0].child + if child == 0 { + off = d.remove(table, 0, sha1) + } else { + /* recursion */ + off = d.takeSmallest(child, sha1) + table.items[0].child = collapse(d, child) + } + d.flush(table, toff) + + // make sure data is written before a reference is added to it + _ = d.fd.Sync() + return off +} + +// takeLargest find and remove the largest item from the given table. The key of the item +// is stored to 'hash'. Returns offset to the item +func (d *DB) takeLargest(toff int64, hash *byte) int64 { + table := d.get(toff) + assert(table.size > 0) + + var off int64 + child := table.items[table.size].child + if child == 0 { + off = d.remove(table, table.size-1, hash) + } else { + /* recursion */ + off = d.takeLargest(child, hash) + table.items[table.size].child = collapse(d, child) + } + d.flush(table, toff) + + // make sure data is written before a reference is added to it + _ = d.fd.Sync() + return off +} + +// remove an item in position 'i' from the given table. The key of the +// removed item is stored to 'hash'. Returns offset to the item. +func (d *DB) remove(t *table, i int, hash *byte) int64 { + assert(i < t.size) + + if hash != nil { + copyhash(hash, &t.items[i].hash[0]) + } + + offset := t.items[i].offset + lc := t.items[i].child + rc := t.items[i+1].child + + if lc != 0 && rc != 0 { + /* replace the removed item by taking an item from one of the + child tables */ + var noff int64 + if rand.Int()&1 != 0 { + noff = d.takeLargest(lc, &t.items[i].hash[0]) + t.items[i].child = collapse(d, lc) + } else { + noff = d.takeSmallest(rc, &t.items[i].hash[0]) + t.items[i+1].child = collapse(d, rc) + } + t.items[i].child = noff + } else { + // memmove(&table->items[i], &table->items[i + 1], + // (table->size - i) * sizeof(struct btree_item)); + copy(t.items[i:], t.items[i+1:]) + t.size-- + + if lc != 0 { + t.items[i].child = lc + } else { + t.items[i].child = rc + } + } + return offset +} + +func (d *DB) insert(toff int64, hash *byte, data []byte, size int) int64 { + table := d.get(toff) + assert(table.size < tableSize-1) + + left, right := 0, table.size + for left < right { + mid := (right-left)>>1 + left + switch cmp := cmp(hash, &table.items[mid].hash[0]); { + case cmp == 0: + // already in the table + ret := table.items[mid].offset + d.put(table, toff) + return ret + case cmp < 0: + right = mid + default: + left = mid + 1 + } + } + i := left + + var off, rc, ret int64 + lc := table.items[i].child + if lc != 0 { + /* recursion */ + ret = d.insert(lc, hash, data, size) + + /* check if we need to split */ + child := d.get(lc) + if child.size < tableSize-1 { + /* nothing to do */ + d.put(table, toff) + d.put(child, lc) + return ret + } + /* overwrites SHA-1 */ + rc = d.split(child, hash, &off) + /* flush just in case changes happened */ + d.flush(child, lc) + + // make sure data is written before a reference is added to it + _ = d.fd.Sync() + } else { + off = d.insertData(data, size) + ret = off + } + + table.size++ + // memmove(&table->items[i + 1], &table->items[i], + // (table->size - i) * sizeof(struct btree_item)); + copy(table.items[i+1:], table.items[i:]) + copyhash(&table.items[i].hash[0], hash) + table.items[i].offset = off + table.items[i].child = lc + table.items[i+1].child = rc + + d.flush(table, toff) + return ret +} + +func (d *DB) insertData(data []byte, size int) int64 { + if data == nil { + return int64(size) + } + assert(len(data) == size) + + offset := d.allocChunk(4 + len(data)) + + d.fd.Seek(offset, io.SeekStart) + err := write32(d.fd, int32(len(data))) + if err != nil { + panic(errors.Wrap(err, "btree I/O error")) + } + _, err = d.fd.Write(data) + if err != nil { + panic(errors.Wrap(err, "btree I/O error")) + } + + // make sure data is written before a reference is added to it + _ = d.fd.Sync() + return offset +} + +// delete remove an item with key 'hash' from the given table. The offset to the +// removed item is returned. +// Please note that 'hash' is overwritten when called inside the allocator. +func (d *DB) delete(offset int64, hash *byte) int64 { + if offset == 0 { + return 0 + } + table := d.get(offset) + + left, right := 0, table.size + for left < right { + i := (right-left)>>1 + left + switch cmp := cmp(hash, &table.items[i].hash[0]); { + case cmp == 0: + // found + ret := d.remove(table, i, hash) + d.flush(table, offset) + return ret + case cmp < 0: + right = i + default: + left = i + 1 + } + } + + // not found - recursion + i := left + child := table.items[i].child + ret := d.delete(child, hash) + if ret != 0 { + table.items[i].child = collapse(d, child) + } + + if ret == 0 && d.deleteLarger && i < table.size { + ret = d.remove(table, i, hash) + } + if ret != 0 { + /* flush just in case changes happened */ + d.flush(table, offset) + } else { + d.put(table, offset) + } + return ret +} + +func (d *DB) insertTopLevel(toff *int64, hash *byte, data []byte, size int) int64 { // nolint:unparam + var off, ret, rc int64 + if *toff != 0 { + ret = d.insert(*toff, hash, data, size) + + /* check if we need to split */ + table := d.get(*toff) + if table.size < tableSize-1 { + /* nothing to do */ + d.put(table, *toff) + return ret + } + rc = d.split(table, hash, &off) + d.flush(table, *toff) + } else { + off = d.insertData(data, size) + ret = off + } + + /* create new top level table */ + t := new(table) + t.size = 1 + copyhash(&t.items[0].hash[0], hash) + t.items[0].offset = off + t.items[0].child = *toff + t.items[1].child = rc + + ntoff := d.allocChunk(tableStructSize) + d.flush(t, ntoff) + + *toff = ntoff + + // make sure data is written before a reference is added to it + _ = d.fd.Sync() + return ret +} + +func (d *DB) lookup(toff int64, hash *byte) int64 { + if toff == 0 { + return 0 + } + table := d.get(toff) + + left, right := 0, table.size + for left < right { + mid := (right-left)>>1 + left + switch cmp := cmp(hash, &table.items[mid].hash[0]); { + case cmp == 0: + // found + ret := table.items[mid].offset + d.put(table, toff) + return ret + case cmp < 0: + right = mid + default: + left = mid + 1 + } + } + + i := left + child := table.items[i].child + d.put(table, toff) + return d.lookup(child, hash) +} + +// Insert a new item with key 'hash' with the contents in 'data' to the +// database file. +func (d *DB) Insert(chash *byte, data []byte) { + /* SHA-1 must be in writable memory */ + var hash [hashSize]byte + copyhash(&hash[0], chash) + + _ = d.insertTopLevel(&d.top, &hash[0], data, len(data)) + freeQueued(d) + d.flushSuper() +} + +// Get look up item with the given key 'hash' in the database file. Length of the +// item is stored in 'len'. Returns a pointer to the contents of the item. +// The returned pointer should be released with free() after use. +func (d *DB) Get(hash *byte) []byte { + off := d.lookup(d.top, hash) + if off == 0 { + return nil + } + + d.fd.Seek(off, io.SeekStart) + length, err := read32(d.fd) + if err != nil { + return nil + } + data := make([]byte, length) + n, err := io.ReadFull(d.fd, data) + if err != nil { + return nil + } + return data[:n] +} + +// Delete remove item with the given key 'hash' from the database file. +func (d *DB) Delete(hash *byte) error { + var h [hashSize]byte + copyhash(&h[0], hash) + + off := d.delete(d.top, &h[0]) + if off == 0 { + return nil // not found key + } + + d.top = collapse(d, d.top) + freeQueued(d) + d.flushSuper() + + d.fd.Seek(off, io.SeekStart) + length, err := read32(d.fd) // len: 0 + if err != nil { + return errors.Wrap(err, "btree I/O error") + } + + d.freeChunk(off, int(length+4)) + freeQueued(d) + d.flushSuper() + return nil +} diff --git a/internal/btree/btree_test.go b/internal/btree/btree_test.go new file mode 100644 index 0000000..5afc4ae --- /dev/null +++ b/internal/btree/btree_test.go @@ -0,0 +1,61 @@ +package btree + +import ( + "crypto/sha1" + "os" + "testing" + + assert2 "github.com/stretchr/testify/assert" +) + +func tempfile(t *testing.T) string { + temp, err := os.CreateTemp(".", "temp.*.db") + assert2.NoError(t, temp.Close()) + assert2.NoError(t, err) + return temp.Name() +} + +func TestCreate(t *testing.T) { + f := tempfile(t) + _, err := Create(f) + assert2.NoError(t, err) + defer os.Remove(f) +} + +func TestBtree(t *testing.T) { + f := tempfile(t) + defer os.Remove(f) + bt, err := Create(f) + assert2.NoError(t, err) + + var tests = []string{ + "hello world", + "123", + "We are met on a great battle-field of that war.", + "Abraham Lincoln, November 19, 1863, Gettysburg, Pennsylvania", + } + var sha = make([]*byte, len(tests)) + for i, tt := range tests { + var hash = sha1.New() + hash.Write([]byte(tt)) + sha[i] = &hash.Sum(nil)[0] + bt.Insert(sha[i], []byte(tt)) + } + assert2.NoError(t, bt.Close()) + + bt, err = Open(f) + assert2.NoError(t, err) + for i, tt := range tests { + assert2.Equal(t, []byte(tt), bt.Get(sha[i])) + } + + for i := range tests { + assert2.NoError(t, bt.Delete(sha[i])) + } + + for i := range tests { + assert2.Equal(t, []byte(nil), bt.Get(sha[i])) + } + + assert2.NoError(t, bt.Close()) +} diff --git a/internal/btree/chunk.go b/internal/btree/chunk.go new file mode 100644 index 0000000..4b38c14 --- /dev/null +++ b/internal/btree/chunk.go @@ -0,0 +1,122 @@ +package btree + +import ( + "math/rand" + "unsafe" +) + +type chunk struct { + offset int64 + len int +} + +const freeQueueLen = 64 + +func freeQueued(bt *DB) { + for i := 0; i < bt.fqueueLen; i++ { + chunk := &bt.fqueue[i] + bt.freeChunk(chunk.offset, chunk.len) + } + bt.fqueueLen = 0 +} + +func (d *DB) allocChunk(size int) int64 { + assert(size > 0) + + size = power2(size) + + var offset int64 + if d.inAllocator { + const i32s = unsafe.Sizeof(int32(0)) + + /* create fake size SHA-1 */ + var sha1 [hashSize]byte + p := unsafe.Pointer(&sha1[0]) + *(*int32)(p) = -1 // *(uint32_t *) hash = -1; + *(*uint32)(unsafe.Add(p, i32s)) = uint32(size) // ((__be32 *) hash)[1] = to_be32(size); + + /* find free chunk with the larger or the same size/SHA-1 */ + d.inAllocator = true + d.deleteLarger = true + offset = d.delete(d.freeTop, &sha1[0]) + d.deleteLarger = false + if offset != 0 { + assert(*(*int32)(p) == -1) // assert(*(uint32_t *) hash == (uint32_t) -1) + flen := int(*(*uint32)(unsafe.Add(p, i32s))) // size_t free_len = from_be32(((__be32 *) hash)[1]) + assert(power2(flen) == flen) + assert(flen >= size) + + /* delete buddy information */ + resethash(&sha1[0]) + *(*int64)(p) = offset + buddyLen := d.delete(d.freeTop, &sha1[0]) + assert(buddyLen == int64(size)) + + d.freeTop = collapse(d, d.freeTop) + + d.inAllocator = false + + /* free extra space at the end of the chunk */ + for flen > size { + flen >>= 1 + d.freeChunk(offset+int64(flen), flen) + } + } else { + d.inAllocator = false + } + } + if offset == 0 { + /* not found, allocate from the end of the file */ + offset = d.alloc + /* TODO: this wastes memory.. */ + if offset&int64(size-1) != 0 { + offset += int64(size) - (offset & (int64(size) - 1)) + } + d.alloc = offset + int64(size) + } + d.flushSuper() + + // make sure the allocation tree is up-to-date before using the chunk + _ = d.fd.Sync() + return offset +} + +/* Mark a chunk as unused in the database file */ +func (d *DB) freeChunk(offset int64, size int) { + assert(size > 0) + assert(offset != 0) + size = power2(size) + assert(offset&int64(size-1) == 0) + + if d.inAllocator { + chunk := &d.fqueue[d.fqueueLen] + d.fqueueLen++ + chunk.offset = offset + chunk.len = size + return + } + + /* create fake offset SHA-1 for buddy allocation */ + var sha1 [hashSize]byte + p := unsafe.Pointer(&sha1[0]) + d.inAllocator = true + + const i32s = unsafe.Sizeof(int32(0)) + + /* add buddy information */ + resethash(&sha1[0]) + *(*int32)(p) = -1 // *(uint32_t *) hash = -1; + *(*uint32)(unsafe.Add(p, i32s)) = uint32(size) // ((__be32 *) hash)[1] = to_be32(size); + *(*uint32)(unsafe.Add(p, i32s*2)) = rand.Uint32() /* to make SHA-1 unique */ + *(*uint32)(unsafe.Add(p, i32s*3)) = rand.Uint32() + + // insert_toplevel(btree, &btree->free_top, hash, NULL, offset); + _ = d.insertTopLevel(&d.freeTop, &sha1[0], nil, int(offset)) + d.inAllocator = false + + d.flushSuper() + + // make sure the allocation tree is up-to-date before removing + // references to the chunk + _ = d.fd.Sync() +} diff --git a/internal/btree/helper.go b/internal/btree/helper.go new file mode 100644 index 0000000..4a97610 --- /dev/null +++ b/internal/btree/helper.go @@ -0,0 +1,106 @@ +package btree + +import ( + "io" + "reflect" + "unsafe" +) + +func assert(cond bool) { + if !cond { + panic("assert failed!") + } +} + +// power2 returns a value that is greater or equal to 'val' and is power-of-two. +func power2(val int) int { + i := 1 + for i < val { + i <<= 1 + } + return i +} + +// helpers for hash + +func cmp(a, b *byte) int64 { + pa, pb := unsafe.Pointer(a), unsafe.Pointer(b) + if *(*uint64)(pa) != *(*uint64)(pb) { + return int64(*(*uint64)(pa) - *(*uint64)(pb)) + } + pa, pb = unsafe.Add(pa, 8), unsafe.Add(pb, 8) + return int64(*(*uint64)(pa) - *(*uint64)(pb)) +} + +func copyhash(dst *byte, src *byte) { + pa, pb := unsafe.Pointer(dst), unsafe.Pointer(src) + *(*[hashSize]byte)(pa) = *(*[hashSize]byte)(pb) +} + +func resethash(sha1 *byte) { + p := unsafe.Pointer(sha1) + *(*[hashSize]byte)(p) = [hashSize]byte{} +} + +// reading table + +func read32(r io.Reader) (int32, error) { + var b = make([]byte, 4) + _, err := r.Read(b) + if err != nil { + return 0, err + } + return *(*int32)(unsafe.Pointer(&b[0])), nil +} + +func readTable(r io.Reader, t *table) error { + buf := make([]byte, tableStructSize) + _, err := r.Read(buf) + if err != nil { + return err + } + *t = *(*table)(unsafe.Pointer(&buf[0])) + return nil +} + +func readSuper(r io.Reader, s *super) error { + buf := make([]byte, superSize) + _, err := r.Read(buf) + if err != nil { + return err + } + *s = *(*super)(unsafe.Pointer(&buf[0])) + return nil +} + +// write table + +func write32(w io.Writer, t int32) error { + var p []byte + ph := (*reflect.SliceHeader)(unsafe.Pointer(&p)) + ph.Data = uintptr(unsafe.Pointer(&t)) + ph.Len = 4 + ph.Cap = 4 + _, err := w.Write(p) + return err +} + +func writeTable(w io.Writer, t *table) error { + var p []byte + ph := (*reflect.SliceHeader)(unsafe.Pointer(&p)) + ph.Data = uintptr(unsafe.Pointer(t)) + ph.Len = tableStructSize + ph.Cap = tableStructSize + _, err := w.Write(p) + return err +} + +func writeSuper(w io.Writer, s *super) error { + var p []byte + ph := (*reflect.SliceHeader)(unsafe.Pointer(&p)) + ph.Data = uintptr(unsafe.Pointer(s)) + ph.Len = superSize + ph.Cap = superSize + _, err := w.Write(p) + return err +} diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..d4f4e7f --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,97 @@ +// Package cache impl the cache for gocq +package cache + +import ( + "fmt" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/internal/base" + "github.com/Mrs4s/go-cqhttp/internal/btree" +) + +// todo(wdvxdr): always enable db-cache in v1.0.0 + +// EnableCacheDB 是否启用 btree db缓存图片等 +var EnableCacheDB bool + +// Media Cache DBs +var ( + Image Cache + Video Cache + // todo: Voice? +) + +// Cache wraps the btree.DB for concurrent safe +type Cache struct { + lock sync.RWMutex + db *btree.DB +} + +// Insert 添加媒体缓存 +func (c *Cache) Insert(md5, data []byte) { + c.lock.Lock() + defer c.lock.Unlock() + + var hash [16]byte + copy(hash[:], md5) + c.db.Insert(&hash[0], data) +} + +// Get 获取缓存信息 +func (c *Cache) Get(md5 []byte) []byte { + c.lock.RLock() + defer c.lock.RUnlock() + + var hash [16]byte + copy(hash[:], md5) + return c.db.Get(&hash[0]) +} + +// Delete 删除指定缓存 +func (c *Cache) Delete(md5 []byte) { + c.lock.Lock() + defer c.lock.Unlock() + + var hash [16]byte + copy(hash[:], md5) + _ = c.db.Delete(&hash[0]) +} + +// Init 初始化 Cache +func Init() { + node, ok := base.Database["cache"] + if !ok { + return + } + EnableCacheDB = true + var conf map[string]string + err := node.Decode(&conf) + if err != nil { + log.Fatalf("failed to read cache config: %v", err) + } + + var open = func(typ string, cache *Cache) { + file := conf[typ] + if file == "" { + file = fmt.Sprintf("data/%s.db", typ) + } + if global.PathExists(file) { + db, err := btree.Open(file) + if err != nil { + log.Fatalf("open %s cache failed: %v", typ, err) + } + cache.db = db + } else { + db, err := btree.Create(file) + if err != nil { + log.Fatalf("create %s cache failed: %v", typ, err) + } + cache.db = db + } + } + open("image", &Image) + open("video", &Video) +} diff --git a/main.go b/main.go index d60d78d..c4e236d 100644 --- a/main.go +++ b/main.go @@ -10,8 +10,6 @@ import ( "sync" "time" - "github.com/Mrs4s/go-cqhttp/internal/selfdiagnosis" - "github.com/Mrs4s/MiraiGo/binary" "github.com/Mrs4s/MiraiGo/client" para "github.com/fumiama/go-hide-param" @@ -24,6 +22,8 @@ import ( "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global/terminal" "github.com/Mrs4s/go-cqhttp/internal/base" + "github.com/Mrs4s/go-cqhttp/internal/cache" + "github.com/Mrs4s/go-cqhttp/internal/selfdiagnosis" "github.com/Mrs4s/go-cqhttp/internal/selfupdate" "github.com/Mrs4s/go-cqhttp/modules/servers" "github.com/Mrs4s/go-cqhttp/server" @@ -82,6 +82,7 @@ func main() { mkCacheDir(global.VoicePath, "语音") mkCacheDir(global.VideoPath, "视频") mkCacheDir(global.CachePath, "发送图片") + cache.Init() var byteKey []byte arg := os.Args diff --git a/modules/config/config.go b/modules/config/config.go index 31c4462..72a39b0 100644 --- a/modules/config/config.go +++ b/modules/config/config.go @@ -137,6 +137,13 @@ type LevelDBConfig struct { Enable bool `yaml:"enable"` } +// MongoDBConfig mongodb 相关配置 +type MongoDBConfig struct { + Enable bool `yaml:"enable"` + URI string `yaml:"uri"` + Database string `yaml:"database"` +} + // Parse 从默认配置文件路径中获取 func Parse(path string) *Config { fromEnv := os.Getenv("GCQ_UIN") != "" diff --git a/modules/config/default_config.yml b/modules/config/default_config.yml index ca7affd..abd2382 100644 --- a/modules/config/default_config.yml +++ b/modules/config/default_config.yml @@ -75,6 +75,11 @@ database: # 数据库相关设置 # 关闭将无法使用 撤回 回复 get_msg 等上下文相关功能 enable: true + # 媒体文件缓存, 删除此项则使用缓存文件(旧版行为) + cache: + image: data/image.db + video: data/video.db + # 连接服务列表 servers: # 添加方式,同一连接方式可添加多个,具体配置说明请查看文档 diff --git a/modules/mime/mime.go b/modules/mime/mime.go index 4563099..50d762d 100644 --- a/modules/mime/mime.go +++ b/modules/mime/mime.go @@ -43,7 +43,7 @@ func check(r io.ReadSeeker, list []string) (bool, string) { return true, "" } _, _ = r.Seek(0, io.SeekStart) - defer r.Seek(0, io.SeekStart) // nolint + defer r.Seek(0, io.SeekStart) t, err := mimetype.DetectReader(r) if err != nil { logrus.Debugf("扫描 Mime 时出现问题: %v", err)