From 5e70c8115edebb1b0a8940eaa44ecf56f40b0f65 Mon Sep 17 00:00:00 2001 From: wdvxdr Date: Sun, 17 Oct 2021 17:12:24 +0800 Subject: [PATCH] feat: modular database build with leveldb on default --- .golangci.yml | 14 ------ coolq/api.go | 17 +++----- coolq/bot.go | 30 +------------ coolq/cqcode.go | 12 +++--- coolq/event.go | 18 +++----- db/database.go | 8 ++-- db/{ => leveldb}/leveldb.go | 59 ++++++++++++++----------- db/mongodb.go | 71 ------------------------------- db/mongodb/mongodb.go | 85 +++++++++++++++++++++++++++++++++++++ db/multidb.go | 69 ++++++++++++++++++------------ main.go | 7 +++ 11 files changed, 195 insertions(+), 195 deletions(-) rename db/{ => leveldb}/leveldb.go (55%) delete mode 100644 db/mongodb.go create mode 100644 db/mongodb/mongodb.go diff --git a/.golangci.yml b/.golangci.yml index 4950c00..6c466b4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -58,20 +58,6 @@ linters: - revive - forbidigo - makezero - #- interfacer - - # don't enable: - # - scopelint - # - gochecknoglobals - # - gocognit - # - godot - # - godox - # - goerr113 - # - interfacer - # - maligned - # - nestif - # - testpackage - # - wsl run: # default concurrency is a available CPU number. diff --git a/coolq/api.go b/coolq/api.go index 8e8778f..55c7f95 100644 --- a/coolq/api.go +++ b/coolq/api.go @@ -465,7 +465,7 @@ func (bot *CQBot) CQSendGroupForwardMessage(groupID int64, m gjson.Result) globa ts.Add(time.Second) if e.Get("data.id").Exists() { i := e.Get("data.id").Int() - m, _ := bot.db.GetGroupMessageByGlobalID(int32(i)) + m, _ := db.GetGroupMessageByGlobalID(int32(i)) if m != nil { return &message.ForwardNode{ SenderId: m.Attribute.SenderUin, @@ -781,7 +781,7 @@ func (bot *CQBot) CQProcessGroupRequest(flag, subType, reason string, approve bo // // https:// git.io/Jtz1y func (bot *CQBot) CQDeleteMessage(messageID int32) global.MSG { - msg, err := bot.db.GetMessageByGlobalID(messageID) + msg, err := db.GetMessageByGlobalID(messageID) if err != nil { log.Warnf("撤回消息时出现错误: %v", err) return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在") @@ -1115,7 +1115,7 @@ func (bot *CQBot) CQGetForwardMessage(resID string) global.MSG { // // https://git.io/Jtz1b func (bot *CQBot) CQGetMessage(messageID int32) global.MSG { - msg, err := bot.db.GetMessageByGlobalID(messageID) + msg, err := db.GetMessageByGlobalID(messageID) if err != nil { log.Warnf("获取消息时出现错误: %v", err) return Failed(100, "MSG_NOT_FOUND", "消息不存在") @@ -1176,11 +1176,8 @@ func (bot *CQBot) CQGetGroupMessageHistory(groupID int64, seq int64) global.MSG } ms := make([]global.MSG, 0, len(msg)) for _, m := range msg { - id := m.Id bot.checkMedia(m.Elements) - if bot.db != nil { - id = bot.InsertGroupMessage(m) - } + id := bot.InsertGroupMessage(m) t := bot.formatGroupMessage(m) t["message_id"] = id ms = append(ms, t) @@ -1302,7 +1299,7 @@ func (bot *CQBot) CQGetStatus() global.MSG { // // https://docs.go-cqhttp.org/api/#%E8%AE%BE%E7%BD%AE%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF func (bot *CQBot) CQSetEssenceMessage(messageID int32) global.MSG { - msg, err := bot.db.GetGroupMessageByGlobalID(messageID) + msg, err := db.GetGroupMessageByGlobalID(messageID) if err != nil { return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在") } @@ -1317,7 +1314,7 @@ func (bot *CQBot) CQSetEssenceMessage(messageID int32) global.MSG { // // https://docs.go-cqhttp.org/api/#%E7%A7%BB%E5%87%BA%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF func (bot *CQBot) CQDeleteEssenceMessage(messageID int32) global.MSG { - msg, err := bot.db.GetGroupMessageByGlobalID(messageID) + msg, err := db.GetGroupMessageByGlobalID(messageID) if err != nil { return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在") } @@ -1436,7 +1433,7 @@ func (bot *CQBot) CQSetModelShow(modelName string, modelShow string) global.MSG // CQMarkMessageAsRead 标记消息已读 func (bot *CQBot) CQMarkMessageAsRead(msgID int32) global.MSG { - m, err := bot.db.GetMessageByGlobalID(msgID) + m, err := db.GetMessageByGlobalID(msgID) if err != nil { return Failed(100, "MSG_NOT_FOUND", "消息不存在") } diff --git a/coolq/bot.go b/coolq/bot.go index f87a5cf..ed5ecbd 100644 --- a/coolq/bot.go +++ b/coolq/bot.go @@ -23,7 +23,6 @@ import ( "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/internal/base" - "github.com/Mrs4s/go-cqhttp/modules/config" ) // CQBot CQBot结构体,存储Bot实例相关配置 @@ -33,7 +32,6 @@ type CQBot struct { lock sync.RWMutex events []func(*Event) - db db.IDatabase friendReqCache sync.Map tempSessionCache sync.Map } @@ -71,30 +69,6 @@ func NewQQBot(cli *client.QQClient) *CQBot { bot := &CQBot{ Client: cli, } - levelNode, levelDB := base.Database["leveldb"] - mongoNode, mongoDB := base.Database["mongodb"] - multiDB := db.NewMultiDatabase() - if levelDB { - lconf := new(config.LevelDBConfig) - _ = levelNode.Decode(lconf) - if lconf.Enable { - multiDB.UseDB(db.UseLevelDB()) - } - } - if mongoDB { - lconf := new(config.MongoDBConfig) - _ = mongoNode.Decode(lconf) - if lconf.Database == "" { - lconf.Database = "gocq-database" - } - if lconf.Enable { - multiDB.UseDB(db.UseMongoDB(lconf.URI, lconf.Database)) - } - } - if err := multiDB.Open(); err != nil { - log.Fatalf("打开数据库失败: %v", err) - } - bot.db = multiDB bot.Client.OnPrivateMessage(bot.privateMessageEvent) bot.Client.OnGroupMessage(bot.groupMessageEvent) if base.ReportSelfMessage { @@ -384,7 +358,7 @@ func (bot *CQBot) InsertGroupMessage(m *message.GroupMessage) int32 { QuotedContent: ToMessageContent(reply.Elements), } } - if err := bot.db.InsertGroupMessage(msg); err != nil { + if err := db.InsertGroupMessage(msg); err != nil { log.Warnf("记录聊天数据时出现错误: %v", err) return -1 } @@ -427,7 +401,7 @@ func (bot *CQBot) InsertPrivateMessage(m *message.PrivateMessage) int32 { QuotedContent: ToMessageContent(m.Elements), } } - if err := bot.db.InsertPrivateMessage(msg); err != nil { + if err := db.InsertPrivateMessage(msg); err != nil { log.Warnf("记录聊天数据时出现错误: %v", err) return -1 } diff --git a/coolq/cqcode.go b/coolq/cqcode.go index bc9ca09..f567463 100644 --- a/coolq/cqcode.go +++ b/coolq/cqcode.go @@ -479,7 +479,7 @@ func (bot *CQBot) ConvertStringMessage(raw string, isGroup bool) (r []message.IM switch { case customText != "": var elem *message.ReplyElement - var org db.IStoredMessage + var org db.StoredMessage sender, senderErr := strconv.ParseInt(d["qq"], 10, 64) if senderErr != nil && err != nil { log.Warnf("警告: 自定义 Reply 元素中必须包含 Uin 或 id") @@ -491,7 +491,7 @@ func (bot *CQBot) ConvertStringMessage(raw string, isGroup bool) (r []message.IM } messageSeq, seqErr := strconv.ParseInt(d["seq"], 10, 64) if err == nil { - org, _ = bot.db.GetMessageByGlobalID(int32(mid)) + org, _ = db.GetMessageByGlobalID(int32(mid)) } if org != nil { elem = &message.ReplyElement{ @@ -519,7 +519,7 @@ func (bot *CQBot) ConvertStringMessage(raw string, isGroup bool) (r []message.IM } r = append([]message.IMessageElement{elem}, r...) case err == nil: - org, err := bot.db.GetMessageByGlobalID(int32(mid)) + org, err := db.GetMessageByGlobalID(int32(mid)) if err == nil { r = append([]message.IMessageElement{ &message.ReplyElement{ @@ -651,7 +651,7 @@ func (bot *CQBot) ConvertObjectMessage(m gjson.Result, isGroup bool) (r []messag switch { case customText != "": var elem *message.ReplyElement - var org db.IStoredMessage + var org db.StoredMessage sender, senderErr := strconv.ParseInt(e.Get("data.[user_id,qq]").String(), 10, 64) if senderErr != nil && err != nil { log.Warnf("警告: 自定义 Reply 元素中必须包含 user_id 或 id") @@ -663,7 +663,7 @@ func (bot *CQBot) ConvertObjectMessage(m gjson.Result, isGroup bool) (r []messag } messageSeq, seqErr := strconv.ParseInt(e.Get("data.seq").String(), 10, 64) if err == nil { - org, _ = bot.db.GetMessageByGlobalID(int32(mid)) + org, _ = db.GetMessageByGlobalID(int32(mid)) } if org != nil { elem = &message.ReplyElement{ @@ -691,7 +691,7 @@ func (bot *CQBot) ConvertObjectMessage(m gjson.Result, isGroup bool) (r []messag } r = append([]message.IMessageElement{elem}, r...) case err == nil: - org, err := bot.db.GetMessageByGlobalID(int32(mid)) + org, err := db.GetMessageByGlobalID(int32(mid)) if err == nil { r = append([]message.IMessageElement{ &message.ReplyElement{ diff --git a/coolq/event.go b/coolq/event.go index fe9e425..feac543 100644 --- a/coolq/event.go +++ b/coolq/event.go @@ -33,10 +33,7 @@ func ToFormattedMessage(e []message.IMessageElement, groupID int64, isRaw ...boo func (bot *CQBot) privateMessageEvent(c *client.QQClient, m *message.PrivateMessage) { bot.checkMedia(m.Elements) cqm := ToStringMessage(m.Elements, 0, true) - id := m.Id - if bot.db != nil { - id = bot.InsertPrivateMessage(m) - } + id := bot.InsertPrivateMessage(m) log.Infof("收到好友 %v(%v) 的消息: %v (%v)", m.Sender.DisplayName(), m.Sender.Uin, cqm, id) fm := global.MSG{ "post_type": func() string { @@ -89,10 +86,7 @@ func (bot *CQBot) groupMessageEvent(c *client.QQClient, m *message.GroupMessage) } } cqm := ToStringMessage(m.Elements, m.GroupCode, true) - id := m.Id - if bot.db != nil { - id = bot.InsertGroupMessage(m) - } + id := bot.InsertGroupMessage(m) log.Infof("收到群 %v(%v) 内 %v(%v) 的消息: %v (%v)", m.GroupName, m.GroupCode, m.Sender.DisplayName(), m.Sender.Uin, cqm, id) gm := bot.formatGroupMessage(m) if gm == nil { @@ -108,10 +102,10 @@ func (bot *CQBot) tempMessageEvent(c *client.QQClient, e *client.TempMessageEven cqm := ToStringMessage(m.Elements, 0, true) bot.tempSessionCache.Store(m.Sender.Uin, e.Session) id := m.Id - if bot.db != nil { // nolint - // todo(Mrs4s) - // id = bot.InsertTempMessage(m.Sender.Uin, m) - } + // todo(Mrs4s) + // if bot.db != nil { // nolint + // id = bot.InsertTempMessage(m.Sender.Uin, m) + // } log.Infof("收到来自群 %v(%v) 内 %v(%v) 的临时会话消息: %v", m.GroupName, m.GroupCode, m.Sender.DisplayName(), m.Sender.Uin, cqm) tm := global.MSG{ "post_type": "message", diff --git a/db/database.go b/db/database.go index 08a852b..21445e2 100644 --- a/db/database.go +++ b/db/database.go @@ -8,13 +8,13 @@ import ( ) type ( - // IDatabase 数据库操作接口定义 - IDatabase interface { + // Database 数据库操作接口定义 + Database interface { // Open 初始化数据库 Open() error // GetMessageByGlobalID 通过 GlobalID 来获取消息 - GetMessageByGlobalID(int32) (IStoredMessage, error) + GetMessageByGlobalID(int32) (StoredMessage, error) // GetGroupMessageByGlobalID 通过 GlobalID 来获取群消息 GetGroupMessageByGlobalID(int32) (*StoredGroupMessage, error) // GetPrivateMessageByGlobalID 通过 GlobalID 来获取私聊消息 @@ -26,7 +26,7 @@ type ( InsertPrivateMessage(*StoredPrivateMessage) error } - IStoredMessage interface { + StoredMessage interface { GetID() string GetType() string GetGlobalID() int32 diff --git a/db/leveldb.go b/db/leveldb/leveldb.go similarity index 55% rename from db/leveldb.go rename to db/leveldb/leveldb.go index c38d52e..ec76d43 100644 --- a/db/leveldb.go +++ b/db/leveldb/leveldb.go @@ -1,4 +1,4 @@ -package db +package leveldb import ( "bytes" @@ -9,8 +9,11 @@ import ( "github.com/pkg/errors" "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/opt" + "gopkg.in/yaml.v3" + "github.com/Mrs4s/go-cqhttp/db" "github.com/Mrs4s/go-cqhttp/global" + "github.com/Mrs4s/go-cqhttp/modules/config" ) type LevelDBImpl struct { @@ -22,42 +25,50 @@ const ( private byte = 0x1 ) -func UseLevelDB() *LevelDBImpl { - gob.Register(StoredMessageAttribute{}) - gob.Register(QuotedInfo{}) +func init() { + gob.Register(db.StoredMessageAttribute{}) + gob.Register(db.QuotedInfo{}) gob.Register(global.MSG{}) - gob.Register(StoredGroupMessage{}) - gob.Register(StoredPrivateMessage{}) - return &LevelDBImpl{} + gob.Register(db.StoredGroupMessage{}) + gob.Register(db.StoredPrivateMessage{}) + + db.Register("leveldb", func(node yaml.Node) db.Database { + conf := new(config.LevelDBConfig) + _ = node.Decode(conf) + if !conf.Enable { + return nil + } + return &LevelDBImpl{} + }) } -func (db *LevelDBImpl) Open() error { +func (ldb *LevelDBImpl) Open() error { p := path.Join("data", "leveldb-v2") d, err := leveldb.OpenFile(p, &opt.Options{ WriteBuffer: 128 * opt.KiB, }) if err != nil { - return errors.Wrap(err, "open level db error") + return errors.Wrap(err, "open level ldb error") } - db.db = d + ldb.db = d return nil } -func (db *LevelDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) { - v, err := db.db.Get(binary.ToBytes(id), nil) +func (ldb *LevelDBImpl) GetMessageByGlobalID(id int32) (db.StoredMessage, error) { + v, err := ldb.db.Get(binary.ToBytes(id), nil) if err != nil { return nil, errors.Wrap(err, "get value error") } r := binary.NewReader(v) switch r.ReadByte() { case group: - g := &StoredGroupMessage{} + g := &db.StoredGroupMessage{} if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(g); err != nil { return nil, errors.Wrap(err, "decode message error") } return g, nil case private: - p := &StoredPrivateMessage{} + p := &db.StoredPrivateMessage{} if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(p); err != nil { return nil, errors.Wrap(err, "decode message error") } @@ -67,50 +78,50 @@ func (db *LevelDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) { } } -func (db *LevelDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { - i, err := db.GetMessageByGlobalID(id) +func (ldb *LevelDBImpl) GetGroupMessageByGlobalID(id int32) (*db.StoredGroupMessage, error) { + i, err := ldb.GetMessageByGlobalID(id) if err != nil { return nil, err } - g, ok := i.(*StoredGroupMessage) + g, ok := i.(*db.StoredGroupMessage) if !ok { return nil, errors.New("message type error") } return g, nil } -func (db *LevelDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { - i, err := db.GetMessageByGlobalID(id) +func (ldb *LevelDBImpl) GetPrivateMessageByGlobalID(id int32) (*db.StoredPrivateMessage, error) { + i, err := ldb.GetMessageByGlobalID(id) if err != nil { return nil, err } - p, ok := i.(*StoredPrivateMessage) + p, ok := i.(*db.StoredPrivateMessage) if !ok { return nil, errors.New("message type error") } return p, nil } -func (db *LevelDBImpl) InsertGroupMessage(msg *StoredGroupMessage) error { +func (ldb *LevelDBImpl) InsertGroupMessage(msg *db.StoredGroupMessage) error { buf := global.NewBuffer() defer global.PutBuffer(buf) if err := gob.NewEncoder(buf).Encode(msg); err != nil { return errors.Wrap(err, "encode message error") } - err := db.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) { + err := ldb.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) { w.WriteByte(group) w.Write(buf.Bytes()) }), nil) return errors.Wrap(err, "put data error") } -func (db *LevelDBImpl) InsertPrivateMessage(msg *StoredPrivateMessage) error { +func (ldb *LevelDBImpl) InsertPrivateMessage(msg *db.StoredPrivateMessage) error { buf := global.NewBuffer() defer global.PutBuffer(buf) if err := gob.NewEncoder(buf).Encode(msg); err != nil { return errors.Wrap(err, "encode message error") } - err := db.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) { + err := ldb.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) { w.WriteByte(private) w.Write(buf.Bytes()) }), nil) diff --git a/db/mongodb.go b/db/mongodb.go deleted file mode 100644 index 00677b1..0000000 --- a/db/mongodb.go +++ /dev/null @@ -1,71 +0,0 @@ -package db - -import ( - "context" - - "github.com/pkg/errors" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" -) - -type MongoDBImpl struct { - uri string - db string - mongo *mongo.Database -} - -const ( - MongoGroupMessageCollection = "group-messages" - MongoPrivateMessageCollection = "private-messages" -) - -func UseMongoDB(uri, db string) *MongoDBImpl { - return &MongoDBImpl{uri: uri, db: db} -} - -func (db *MongoDBImpl) Open() error { - cli, err := mongo.Connect(context.Background(), options.Client().ApplyURI(db.uri)) - if err != nil { - return errors.Wrap(err, "open mongo connection error") - } - db.mongo = cli.Database(db.db) - return nil -} - -func (db *MongoDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) { - if r, err := db.GetGroupMessageByGlobalID(id); err == nil { - return r, nil - } - return db.GetPrivateMessageByGlobalID(id) -} - -func (db *MongoDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { - coll := db.mongo.Collection(MongoGroupMessageCollection) - var ret StoredGroupMessage - if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil { - return nil, errors.Wrap(err, "query error") - } - return &ret, nil -} - -func (db *MongoDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { - coll := db.mongo.Collection(MongoPrivateMessageCollection) - var ret StoredPrivateMessage - if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil { - return nil, errors.Wrap(err, "query error") - } - return &ret, nil -} - -func (db *MongoDBImpl) InsertGroupMessage(msg *StoredGroupMessage) error { - coll := db.mongo.Collection(MongoGroupMessageCollection) - _, err := coll.InsertOne(context.Background(), msg) - return errors.Wrap(err, "insert error") -} - -func (db *MongoDBImpl) InsertPrivateMessage(msg *StoredPrivateMessage) error { - coll := db.mongo.Collection(MongoPrivateMessageCollection) - _, err := coll.InsertOne(context.Background(), msg) - return errors.Wrap(err, "insert error") -} diff --git a/db/mongodb/mongodb.go b/db/mongodb/mongodb.go new file mode 100644 index 0000000..6387d5a --- /dev/null +++ b/db/mongodb/mongodb.go @@ -0,0 +1,85 @@ +package mongodb + +import ( + "context" + + "github.com/pkg/errors" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "gopkg.in/yaml.v3" + + "github.com/Mrs4s/go-cqhttp/db" + "github.com/Mrs4s/go-cqhttp/modules/config" +) + +type MongoDBImpl struct { + uri string + db string + mongo *mongo.Database +} + +const ( + MongoGroupMessageCollection = "group-messages" + MongoPrivateMessageCollection = "private-messages" +) + +func init() { + db.Register("mongodb", func(node yaml.Node) db.Database { + conf := new(config.MongoDBConfig) + _ = node.Decode(conf) + if conf.Database == "" { + conf.Database = "gocq-database" + } + if !conf.Enable { + return nil + } + return &MongoDBImpl{uri: conf.URI, db: conf.Database} + }) +} + +func (m *MongoDBImpl) Open() error { + cli, err := mongo.Connect(context.Background(), options.Client().ApplyURI(m.uri)) + if err != nil { + return errors.Wrap(err, "open mongo connection error") + } + m.mongo = cli.Database(m.db) + return nil +} + +func (m *MongoDBImpl) GetMessageByGlobalID(id int32) (db.StoredMessage, error) { + if r, err := m.GetGroupMessageByGlobalID(id); err == nil { + return r, nil + } + return m.GetPrivateMessageByGlobalID(id) +} + +func (m *MongoDBImpl) GetGroupMessageByGlobalID(id int32) (*db.StoredGroupMessage, error) { + coll := m.mongo.Collection(MongoGroupMessageCollection) + var ret db.StoredGroupMessage + if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil { + return nil, errors.Wrap(err, "query error") + } + return &ret, nil +} + +func (m *MongoDBImpl) GetPrivateMessageByGlobalID(id int32) (*db.StoredPrivateMessage, error) { + coll := m.mongo.Collection(MongoPrivateMessageCollection) + var ret db.StoredPrivateMessage + if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil { + return nil, errors.Wrap(err, "query error") + } + return &ret, nil +} + +func (m *MongoDBImpl) InsertGroupMessage(msg *db.StoredGroupMessage) error { + coll := m.mongo.Collection(MongoGroupMessageCollection) + _, err := coll.InsertOne(context.Background(), msg) + return errors.Wrap(err, "insert error") +} + +func (m *MongoDBImpl) InsertPrivateMessage(msg *db.StoredPrivateMessage) error { + coll := m.mongo.Collection(MongoPrivateMessageCollection) + _, err := coll.InsertOne(context.Background(), msg) + return errors.Wrap(err, "insert error") +} diff --git a/db/multidb.go b/db/multidb.go index 5d7a78b..f878c9c 100644 --- a/db/multidb.go +++ b/db/multidb.go @@ -1,26 +1,43 @@ package db -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" + "gopkg.in/yaml.v3" -// MultiDBImpl 多数据库支持 + "github.com/Mrs4s/go-cqhttp/internal/base" +) + +// backends 多数据库支持, 后端支持 // 写入会对所有 Backend 进行写入 // 读取只会读取第一个库 -type MultiDBImpl struct { - backends []IDatabase +var backends []Database + +// drivers 多数据库启动 +var drivers = make(map[string]func(node yaml.Node) Database) + +// Register 添加数据库后端 +func Register(name string, init func(yaml.Node) Database) { + if _, ok := drivers[name]; ok { + panic("database driver conflict: " + name) + } + drivers[name] = init } -func NewMultiDatabase(backends ...IDatabase) *MultiDBImpl { - return &MultiDBImpl{ - backends: backends, +// Init 加载所有后端配置文件 +func Init() { + backends = make([]Database, 0, len(drivers)) + for name, init := range drivers { + if n, ok := base.Database[name]; ok { + db := init(n) + if db != nil { + backends = append(backends, db) + } + } } } -func (db *MultiDBImpl) UseDB(backend IDatabase) { - db.backends = append(db.backends, backend) -} - -func (db *MultiDBImpl) Open() error { - for _, b := range db.backends { +func Open() error { + for _, b := range backends { if err := b.Open(); err != nil { return errors.Wrap(err, "open backend error") } @@ -28,29 +45,29 @@ func (db *MultiDBImpl) Open() error { return nil } -func (db *MultiDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) { - if len(db.backends) == 0 { +func GetMessageByGlobalID(id int32) (StoredMessage, error) { + if len(backends) == 0 { return nil, errors.New("database disabled") } - return db.backends[0].GetMessageByGlobalID(id) + return backends[0].GetMessageByGlobalID(id) } -func (db *MultiDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { - if len(db.backends) == 0 { +func GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { + if len(backends) == 0 { return nil, errors.New("database disabled") } - return db.backends[0].GetGroupMessageByGlobalID(id) + return backends[0].GetGroupMessageByGlobalID(id) } -func (db *MultiDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { - if len(db.backends) == 0 { +func GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { + if len(backends) == 0 { return nil, errors.New("database disabled") } - return db.backends[0].GetPrivateMessageByGlobalID(id) + return backends[0].GetPrivateMessageByGlobalID(id) } -func (db *MultiDBImpl) InsertGroupMessage(m *StoredGroupMessage) error { - for _, b := range db.backends { +func InsertGroupMessage(m *StoredGroupMessage) error { + for _, b := range backends { if err := b.InsertGroupMessage(m); err != nil { return errors.Wrap(err, "insert message to backend error") } @@ -58,8 +75,8 @@ func (db *MultiDBImpl) InsertGroupMessage(m *StoredGroupMessage) error { return nil } -func (db *MultiDBImpl) InsertPrivateMessage(m *StoredPrivateMessage) error { - for _, b := range db.backends { +func InsertPrivateMessage(m *StoredPrivateMessage) error { + for _, b := range backends { if err := b.InsertPrivateMessage(m); err != nil { return errors.Wrap(err, "insert message to backend error") } diff --git a/main.go b/main.go index c4e236d..2e66c63 100644 --- a/main.go +++ b/main.go @@ -19,6 +19,7 @@ import ( "golang.org/x/term" "github.com/Mrs4s/go-cqhttp/coolq" + "github.com/Mrs4s/go-cqhttp/db" "github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global/terminal" "github.com/Mrs4s/go-cqhttp/internal/base" @@ -28,6 +29,7 @@ import ( "github.com/Mrs4s/go-cqhttp/modules/servers" "github.com/Mrs4s/go-cqhttp/server" + _ "github.com/Mrs4s/go-cqhttp/db/leveldb" // leveldb _ "github.com/Mrs4s/go-cqhttp/modules/mime" // mime检查模块 _ "github.com/Mrs4s/go-cqhttp/modules/pprof" // pprof 性能分析 _ "github.com/Mrs4s/go-cqhttp/modules/silk" // silk编码模块 @@ -84,6 +86,11 @@ func main() { mkCacheDir(global.CachePath, "发送图片") cache.Init() + db.Init() + if err := db.Open(); err != nil { + log.Fatalf("打开数据库失败: %v", err) + } + var byteKey []byte arg := os.Args if len(arg) > 1 {