diff --git a/coolq/bot.go b/coolq/bot.go index 2ba3785..c21262e 100644 --- a/coolq/bot.go +++ b/coolq/bot.go @@ -107,16 +107,14 @@ func NewQQBot(cli *client.QQClient, conf *config.Config) *CQBot { _ = node.Decode(lconf) enableLevelDB = lconf.Enable } + multiDB := db.NewMultiDatabase() if enableLevelDB { - level := db.UseLevelDB() - if err := level.Open(); err != nil { - log.Fatalf("打开数据库失败: %v", err) - } - bot.db = level - log.Info("信息数据库初始化完成.") - } else { - log.Warn("警告: 信息数据库已关闭,将无法使用 [回复/撤回] 等功能。") + multiDB.UseDB(db.UseLevelDB()) } + if err := multiDB.Open(); err != nil { + log.Fatalf("打开数据库失败: %v", err) + } + bot.db = multiDB bot.Client.OnPrivateMessage(bot.privateMessageEvent) bot.Client.OnGroupMessage(bot.groupMessageEvent) if conf.Message.ReportSelfMessage { diff --git a/db/database.go b/db/database.go index 0801609..23a69e3 100644 --- a/db/database.go +++ b/db/database.go @@ -72,10 +72,6 @@ type ( PrevGlobalID int32 `bson:"prevGlobalId"` QuotedContent []global.MSG `bson:"quotedContent"` } - - // MultiDatabase todo - MultiDatabase struct { - } ) // ToGlobalID 构建`code`-`msgID`的字符串并返回其CRC32 Checksum的值 diff --git a/db/multidb.go b/db/multidb.go new file mode 100644 index 0000000..b4fc5cd --- /dev/null +++ b/db/multidb.go @@ -0,0 +1,68 @@ +package db + +import "github.com/pkg/errors" + +// MultiDatabase 多数据库支持 +// 写入会对所有 Backend 进行写入 +// 读取只会读取第一个库 +type MultiDatabase struct { + backends []IDatabase +} + +func NewMultiDatabase(backends ...IDatabase) *MultiDatabase { + return &MultiDatabase{ + backends: backends, + } +} + +func (db *MultiDatabase) UseDB(backend IDatabase) { + db.backends = append(db.backends, backend) +} + +func (db *MultiDatabase) Open() error { + for _, b := range db.backends { + if err := b.Open(); err != nil { + return errors.Wrap(err, "open backend error") + } + } + return nil +} + +func (db *MultiDatabase) GetMessageByGlobalID(id int32) (IStoredMessage, error) { + if len(db.backends) == 0 { + return nil, errors.New("database disabled") + } + return db.backends[0].GetMessageByGlobalID(id) +} + +func (db *MultiDatabase) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { + if len(db.backends) == 0 { + return nil, errors.New("database disabled") + } + return db.backends[0].GetGroupMessageByGlobalID(id) +} + +func (db *MultiDatabase) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { + if len(db.backends) == 0 { + return nil, errors.New("database disabled") + } + return db.backends[0].GetPrivateMessageByGlobalID(id) +} + +func (db *MultiDatabase) InsertGroupMessage(m *StoredGroupMessage) error { + for _, b := range db.backends { + if err := b.InsertGroupMessage(m); err != nil { + return errors.Wrap(err, "insert message to backend error") + } + } + return nil +} + +func (db *MultiDatabase) InsertPrivateMessage(m *StoredPrivateMessage) error { + for _, b := range db.backends { + if err := b.InsertPrivateMessage(m); err != nil { + return errors.Wrap(err, "insert message to backend error") + } + } + return nil +}