1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-05 03:23:49 +08:00

feat: modular database

build with leveldb on default
This commit is contained in:
wdvxdr 2021-10-17 17:12:24 +08:00
parent 4e69ef4c2f
commit 5e70c8115e
No known key found for this signature in database
GPG Key ID: 703F8C071DE7A1B6
11 changed files with 195 additions and 195 deletions

View File

@ -58,20 +58,6 @@ linters:
- revive - revive
- forbidigo - forbidigo
- makezero - makezero
#- interfacer
# don't enable:
# - scopelint
# - gochecknoglobals
# - gocognit
# - godot
# - godox
# - goerr113
# - interfacer
# - maligned
# - nestif
# - testpackage
# - wsl
run: run:
# default concurrency is a available CPU number. # default concurrency is a available CPU number.

View File

@ -465,7 +465,7 @@ func (bot *CQBot) CQSendGroupForwardMessage(groupID int64, m gjson.Result) globa
ts.Add(time.Second) ts.Add(time.Second)
if e.Get("data.id").Exists() { if e.Get("data.id").Exists() {
i := e.Get("data.id").Int() i := e.Get("data.id").Int()
m, _ := bot.db.GetGroupMessageByGlobalID(int32(i)) m, _ := db.GetGroupMessageByGlobalID(int32(i))
if m != nil { if m != nil {
return &message.ForwardNode{ return &message.ForwardNode{
SenderId: m.Attribute.SenderUin, SenderId: m.Attribute.SenderUin,
@ -781,7 +781,7 @@ func (bot *CQBot) CQProcessGroupRequest(flag, subType, reason string, approve bo
// //
// https:// git.io/Jtz1y // https:// git.io/Jtz1y
func (bot *CQBot) CQDeleteMessage(messageID int32) global.MSG { func (bot *CQBot) CQDeleteMessage(messageID int32) global.MSG {
msg, err := bot.db.GetMessageByGlobalID(messageID) msg, err := db.GetMessageByGlobalID(messageID)
if err != nil { if err != nil {
log.Warnf("撤回消息时出现错误: %v", err) log.Warnf("撤回消息时出现错误: %v", err)
return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在") return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在")
@ -1115,7 +1115,7 @@ func (bot *CQBot) CQGetForwardMessage(resID string) global.MSG {
// //
// https://git.io/Jtz1b // https://git.io/Jtz1b
func (bot *CQBot) CQGetMessage(messageID int32) global.MSG { func (bot *CQBot) CQGetMessage(messageID int32) global.MSG {
msg, err := bot.db.GetMessageByGlobalID(messageID) msg, err := db.GetMessageByGlobalID(messageID)
if err != nil { if err != nil {
log.Warnf("获取消息时出现错误: %v", err) log.Warnf("获取消息时出现错误: %v", err)
return Failed(100, "MSG_NOT_FOUND", "消息不存在") return Failed(100, "MSG_NOT_FOUND", "消息不存在")
@ -1176,11 +1176,8 @@ func (bot *CQBot) CQGetGroupMessageHistory(groupID int64, seq int64) global.MSG
} }
ms := make([]global.MSG, 0, len(msg)) ms := make([]global.MSG, 0, len(msg))
for _, m := range msg { for _, m := range msg {
id := m.Id
bot.checkMedia(m.Elements) bot.checkMedia(m.Elements)
if bot.db != nil { id := bot.InsertGroupMessage(m)
id = bot.InsertGroupMessage(m)
}
t := bot.formatGroupMessage(m) t := bot.formatGroupMessage(m)
t["message_id"] = id t["message_id"] = id
ms = append(ms, t) ms = append(ms, t)
@ -1302,7 +1299,7 @@ func (bot *CQBot) CQGetStatus() global.MSG {
// //
// https://docs.go-cqhttp.org/api/#%E8%AE%BE%E7%BD%AE%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF // https://docs.go-cqhttp.org/api/#%E8%AE%BE%E7%BD%AE%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF
func (bot *CQBot) CQSetEssenceMessage(messageID int32) global.MSG { func (bot *CQBot) CQSetEssenceMessage(messageID int32) global.MSG {
msg, err := bot.db.GetGroupMessageByGlobalID(messageID) msg, err := db.GetGroupMessageByGlobalID(messageID)
if err != nil { if err != nil {
return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在") return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在")
} }
@ -1317,7 +1314,7 @@ func (bot *CQBot) CQSetEssenceMessage(messageID int32) global.MSG {
// //
// https://docs.go-cqhttp.org/api/#%E7%A7%BB%E5%87%BA%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF // https://docs.go-cqhttp.org/api/#%E7%A7%BB%E5%87%BA%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF
func (bot *CQBot) CQDeleteEssenceMessage(messageID int32) global.MSG { func (bot *CQBot) CQDeleteEssenceMessage(messageID int32) global.MSG {
msg, err := bot.db.GetGroupMessageByGlobalID(messageID) msg, err := db.GetGroupMessageByGlobalID(messageID)
if err != nil { if err != nil {
return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在") return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在")
} }
@ -1436,7 +1433,7 @@ func (bot *CQBot) CQSetModelShow(modelName string, modelShow string) global.MSG
// CQMarkMessageAsRead 标记消息已读 // CQMarkMessageAsRead 标记消息已读
func (bot *CQBot) CQMarkMessageAsRead(msgID int32) global.MSG { func (bot *CQBot) CQMarkMessageAsRead(msgID int32) global.MSG {
m, err := bot.db.GetMessageByGlobalID(msgID) m, err := db.GetMessageByGlobalID(msgID)
if err != nil { if err != nil {
return Failed(100, "MSG_NOT_FOUND", "消息不存在") return Failed(100, "MSG_NOT_FOUND", "消息不存在")
} }

View File

@ -23,7 +23,6 @@ import (
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/base"
"github.com/Mrs4s/go-cqhttp/modules/config"
) )
// CQBot CQBot结构体,存储Bot实例相关配置 // CQBot CQBot结构体,存储Bot实例相关配置
@ -33,7 +32,6 @@ type CQBot struct {
lock sync.RWMutex lock sync.RWMutex
events []func(*Event) events []func(*Event)
db db.IDatabase
friendReqCache sync.Map friendReqCache sync.Map
tempSessionCache sync.Map tempSessionCache sync.Map
} }
@ -71,30 +69,6 @@ func NewQQBot(cli *client.QQClient) *CQBot {
bot := &CQBot{ bot := &CQBot{
Client: cli, Client: cli,
} }
levelNode, levelDB := base.Database["leveldb"]
mongoNode, mongoDB := base.Database["mongodb"]
multiDB := db.NewMultiDatabase()
if levelDB {
lconf := new(config.LevelDBConfig)
_ = levelNode.Decode(lconf)
if lconf.Enable {
multiDB.UseDB(db.UseLevelDB())
}
}
if mongoDB {
lconf := new(config.MongoDBConfig)
_ = mongoNode.Decode(lconf)
if lconf.Database == "" {
lconf.Database = "gocq-database"
}
if lconf.Enable {
multiDB.UseDB(db.UseMongoDB(lconf.URI, lconf.Database))
}
}
if err := multiDB.Open(); err != nil {
log.Fatalf("打开数据库失败: %v", err)
}
bot.db = multiDB
bot.Client.OnPrivateMessage(bot.privateMessageEvent) bot.Client.OnPrivateMessage(bot.privateMessageEvent)
bot.Client.OnGroupMessage(bot.groupMessageEvent) bot.Client.OnGroupMessage(bot.groupMessageEvent)
if base.ReportSelfMessage { if base.ReportSelfMessage {
@ -384,7 +358,7 @@ func (bot *CQBot) InsertGroupMessage(m *message.GroupMessage) int32 {
QuotedContent: ToMessageContent(reply.Elements), QuotedContent: ToMessageContent(reply.Elements),
} }
} }
if err := bot.db.InsertGroupMessage(msg); err != nil { if err := db.InsertGroupMessage(msg); err != nil {
log.Warnf("记录聊天数据时出现错误: %v", err) log.Warnf("记录聊天数据时出现错误: %v", err)
return -1 return -1
} }
@ -427,7 +401,7 @@ func (bot *CQBot) InsertPrivateMessage(m *message.PrivateMessage) int32 {
QuotedContent: ToMessageContent(m.Elements), QuotedContent: ToMessageContent(m.Elements),
} }
} }
if err := bot.db.InsertPrivateMessage(msg); err != nil { if err := db.InsertPrivateMessage(msg); err != nil {
log.Warnf("记录聊天数据时出现错误: %v", err) log.Warnf("记录聊天数据时出现错误: %v", err)
return -1 return -1
} }

View File

@ -479,7 +479,7 @@ func (bot *CQBot) ConvertStringMessage(raw string, isGroup bool) (r []message.IM
switch { switch {
case customText != "": case customText != "":
var elem *message.ReplyElement var elem *message.ReplyElement
var org db.IStoredMessage var org db.StoredMessage
sender, senderErr := strconv.ParseInt(d["qq"], 10, 64) sender, senderErr := strconv.ParseInt(d["qq"], 10, 64)
if senderErr != nil && err != nil { if senderErr != nil && err != nil {
log.Warnf("警告: 自定义 Reply 元素中必须包含 Uin 或 id") log.Warnf("警告: 自定义 Reply 元素中必须包含 Uin 或 id")
@ -491,7 +491,7 @@ func (bot *CQBot) ConvertStringMessage(raw string, isGroup bool) (r []message.IM
} }
messageSeq, seqErr := strconv.ParseInt(d["seq"], 10, 64) messageSeq, seqErr := strconv.ParseInt(d["seq"], 10, 64)
if err == nil { if err == nil {
org, _ = bot.db.GetMessageByGlobalID(int32(mid)) org, _ = db.GetMessageByGlobalID(int32(mid))
} }
if org != nil { if org != nil {
elem = &message.ReplyElement{ elem = &message.ReplyElement{
@ -519,7 +519,7 @@ func (bot *CQBot) ConvertStringMessage(raw string, isGroup bool) (r []message.IM
} }
r = append([]message.IMessageElement{elem}, r...) r = append([]message.IMessageElement{elem}, r...)
case err == nil: case err == nil:
org, err := bot.db.GetMessageByGlobalID(int32(mid)) org, err := db.GetMessageByGlobalID(int32(mid))
if err == nil { if err == nil {
r = append([]message.IMessageElement{ r = append([]message.IMessageElement{
&message.ReplyElement{ &message.ReplyElement{
@ -651,7 +651,7 @@ func (bot *CQBot) ConvertObjectMessage(m gjson.Result, isGroup bool) (r []messag
switch { switch {
case customText != "": case customText != "":
var elem *message.ReplyElement var elem *message.ReplyElement
var org db.IStoredMessage var org db.StoredMessage
sender, senderErr := strconv.ParseInt(e.Get("data.[user_id,qq]").String(), 10, 64) sender, senderErr := strconv.ParseInt(e.Get("data.[user_id,qq]").String(), 10, 64)
if senderErr != nil && err != nil { if senderErr != nil && err != nil {
log.Warnf("警告: 自定义 Reply 元素中必须包含 user_id 或 id") log.Warnf("警告: 自定义 Reply 元素中必须包含 user_id 或 id")
@ -663,7 +663,7 @@ func (bot *CQBot) ConvertObjectMessage(m gjson.Result, isGroup bool) (r []messag
} }
messageSeq, seqErr := strconv.ParseInt(e.Get("data.seq").String(), 10, 64) messageSeq, seqErr := strconv.ParseInt(e.Get("data.seq").String(), 10, 64)
if err == nil { if err == nil {
org, _ = bot.db.GetMessageByGlobalID(int32(mid)) org, _ = db.GetMessageByGlobalID(int32(mid))
} }
if org != nil { if org != nil {
elem = &message.ReplyElement{ elem = &message.ReplyElement{
@ -691,7 +691,7 @@ func (bot *CQBot) ConvertObjectMessage(m gjson.Result, isGroup bool) (r []messag
} }
r = append([]message.IMessageElement{elem}, r...) r = append([]message.IMessageElement{elem}, r...)
case err == nil: case err == nil:
org, err := bot.db.GetMessageByGlobalID(int32(mid)) org, err := db.GetMessageByGlobalID(int32(mid))
if err == nil { if err == nil {
r = append([]message.IMessageElement{ r = append([]message.IMessageElement{
&message.ReplyElement{ &message.ReplyElement{

View File

@ -33,10 +33,7 @@ func ToFormattedMessage(e []message.IMessageElement, groupID int64, isRaw ...boo
func (bot *CQBot) privateMessageEvent(c *client.QQClient, m *message.PrivateMessage) { func (bot *CQBot) privateMessageEvent(c *client.QQClient, m *message.PrivateMessage) {
bot.checkMedia(m.Elements) bot.checkMedia(m.Elements)
cqm := ToStringMessage(m.Elements, 0, true) cqm := ToStringMessage(m.Elements, 0, true)
id := m.Id id := bot.InsertPrivateMessage(m)
if bot.db != nil {
id = bot.InsertPrivateMessage(m)
}
log.Infof("收到好友 %v(%v) 的消息: %v (%v)", m.Sender.DisplayName(), m.Sender.Uin, cqm, id) log.Infof("收到好友 %v(%v) 的消息: %v (%v)", m.Sender.DisplayName(), m.Sender.Uin, cqm, id)
fm := global.MSG{ fm := global.MSG{
"post_type": func() string { "post_type": func() string {
@ -89,10 +86,7 @@ func (bot *CQBot) groupMessageEvent(c *client.QQClient, m *message.GroupMessage)
} }
} }
cqm := ToStringMessage(m.Elements, m.GroupCode, true) cqm := ToStringMessage(m.Elements, m.GroupCode, true)
id := m.Id id := bot.InsertGroupMessage(m)
if bot.db != nil {
id = bot.InsertGroupMessage(m)
}
log.Infof("收到群 %v(%v) 内 %v(%v) 的消息: %v (%v)", m.GroupName, m.GroupCode, m.Sender.DisplayName(), m.Sender.Uin, cqm, id) log.Infof("收到群 %v(%v) 内 %v(%v) 的消息: %v (%v)", m.GroupName, m.GroupCode, m.Sender.DisplayName(), m.Sender.Uin, cqm, id)
gm := bot.formatGroupMessage(m) gm := bot.formatGroupMessage(m)
if gm == nil { if gm == nil {
@ -108,10 +102,10 @@ func (bot *CQBot) tempMessageEvent(c *client.QQClient, e *client.TempMessageEven
cqm := ToStringMessage(m.Elements, 0, true) cqm := ToStringMessage(m.Elements, 0, true)
bot.tempSessionCache.Store(m.Sender.Uin, e.Session) bot.tempSessionCache.Store(m.Sender.Uin, e.Session)
id := m.Id id := m.Id
if bot.db != nil { // nolint // todo(Mrs4s)
// todo(Mrs4s) // if bot.db != nil { // nolint
// id = bot.InsertTempMessage(m.Sender.Uin, m) // id = bot.InsertTempMessage(m.Sender.Uin, m)
} // }
log.Infof("收到来自群 %v(%v) 内 %v(%v) 的临时会话消息: %v", m.GroupName, m.GroupCode, m.Sender.DisplayName(), m.Sender.Uin, cqm) log.Infof("收到来自群 %v(%v) 内 %v(%v) 的临时会话消息: %v", m.GroupName, m.GroupCode, m.Sender.DisplayName(), m.Sender.Uin, cqm)
tm := global.MSG{ tm := global.MSG{
"post_type": "message", "post_type": "message",

View File

@ -8,13 +8,13 @@ import (
) )
type ( type (
// IDatabase 数据库操作接口定义 // Database 数据库操作接口定义
IDatabase interface { Database interface {
// Open 初始化数据库 // Open 初始化数据库
Open() error Open() error
// GetMessageByGlobalID 通过 GlobalID 来获取消息 // GetMessageByGlobalID 通过 GlobalID 来获取消息
GetMessageByGlobalID(int32) (IStoredMessage, error) GetMessageByGlobalID(int32) (StoredMessage, error)
// GetGroupMessageByGlobalID 通过 GlobalID 来获取群消息 // GetGroupMessageByGlobalID 通过 GlobalID 来获取群消息
GetGroupMessageByGlobalID(int32) (*StoredGroupMessage, error) GetGroupMessageByGlobalID(int32) (*StoredGroupMessage, error)
// GetPrivateMessageByGlobalID 通过 GlobalID 来获取私聊消息 // GetPrivateMessageByGlobalID 通过 GlobalID 来获取私聊消息
@ -26,7 +26,7 @@ type (
InsertPrivateMessage(*StoredPrivateMessage) error InsertPrivateMessage(*StoredPrivateMessage) error
} }
IStoredMessage interface { StoredMessage interface {
GetID() string GetID() string
GetType() string GetType() string
GetGlobalID() int32 GetGlobalID() int32

View File

@ -1,4 +1,4 @@
package db package leveldb
import ( import (
"bytes" "bytes"
@ -9,8 +9,11 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/opt" "github.com/syndtr/goleveldb/leveldb/opt"
"gopkg.in/yaml.v3"
"github.com/Mrs4s/go-cqhttp/db"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/modules/config"
) )
type LevelDBImpl struct { type LevelDBImpl struct {
@ -22,42 +25,50 @@ const (
private byte = 0x1 private byte = 0x1
) )
func UseLevelDB() *LevelDBImpl { func init() {
gob.Register(StoredMessageAttribute{}) gob.Register(db.StoredMessageAttribute{})
gob.Register(QuotedInfo{}) gob.Register(db.QuotedInfo{})
gob.Register(global.MSG{}) gob.Register(global.MSG{})
gob.Register(StoredGroupMessage{}) gob.Register(db.StoredGroupMessage{})
gob.Register(StoredPrivateMessage{}) gob.Register(db.StoredPrivateMessage{})
return &LevelDBImpl{}
db.Register("leveldb", func(node yaml.Node) db.Database {
conf := new(config.LevelDBConfig)
_ = node.Decode(conf)
if !conf.Enable {
return nil
}
return &LevelDBImpl{}
})
} }
func (db *LevelDBImpl) Open() error { func (ldb *LevelDBImpl) Open() error {
p := path.Join("data", "leveldb-v2") p := path.Join("data", "leveldb-v2")
d, err := leveldb.OpenFile(p, &opt.Options{ d, err := leveldb.OpenFile(p, &opt.Options{
WriteBuffer: 128 * opt.KiB, WriteBuffer: 128 * opt.KiB,
}) })
if err != nil { if err != nil {
return errors.Wrap(err, "open level db error") return errors.Wrap(err, "open level ldb error")
} }
db.db = d ldb.db = d
return nil return nil
} }
func (db *LevelDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) { func (ldb *LevelDBImpl) GetMessageByGlobalID(id int32) (db.StoredMessage, error) {
v, err := db.db.Get(binary.ToBytes(id), nil) v, err := ldb.db.Get(binary.ToBytes(id), nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "get value error") return nil, errors.Wrap(err, "get value error")
} }
r := binary.NewReader(v) r := binary.NewReader(v)
switch r.ReadByte() { switch r.ReadByte() {
case group: case group:
g := &StoredGroupMessage{} g := &db.StoredGroupMessage{}
if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(g); err != nil { if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(g); err != nil {
return nil, errors.Wrap(err, "decode message error") return nil, errors.Wrap(err, "decode message error")
} }
return g, nil return g, nil
case private: case private:
p := &StoredPrivateMessage{} p := &db.StoredPrivateMessage{}
if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(p); err != nil { if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(p); err != nil {
return nil, errors.Wrap(err, "decode message error") return nil, errors.Wrap(err, "decode message error")
} }
@ -67,50 +78,50 @@ func (db *LevelDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) {
} }
} }
func (db *LevelDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { func (ldb *LevelDBImpl) GetGroupMessageByGlobalID(id int32) (*db.StoredGroupMessage, error) {
i, err := db.GetMessageByGlobalID(id) i, err := ldb.GetMessageByGlobalID(id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
g, ok := i.(*StoredGroupMessage) g, ok := i.(*db.StoredGroupMessage)
if !ok { if !ok {
return nil, errors.New("message type error") return nil, errors.New("message type error")
} }
return g, nil return g, nil
} }
func (db *LevelDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { func (ldb *LevelDBImpl) GetPrivateMessageByGlobalID(id int32) (*db.StoredPrivateMessage, error) {
i, err := db.GetMessageByGlobalID(id) i, err := ldb.GetMessageByGlobalID(id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p, ok := i.(*StoredPrivateMessage) p, ok := i.(*db.StoredPrivateMessage)
if !ok { if !ok {
return nil, errors.New("message type error") return nil, errors.New("message type error")
} }
return p, nil return p, nil
} }
func (db *LevelDBImpl) InsertGroupMessage(msg *StoredGroupMessage) error { func (ldb *LevelDBImpl) InsertGroupMessage(msg *db.StoredGroupMessage) error {
buf := global.NewBuffer() buf := global.NewBuffer()
defer global.PutBuffer(buf) defer global.PutBuffer(buf)
if err := gob.NewEncoder(buf).Encode(msg); err != nil { if err := gob.NewEncoder(buf).Encode(msg); err != nil {
return errors.Wrap(err, "encode message error") return errors.Wrap(err, "encode message error")
} }
err := db.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) { err := ldb.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) {
w.WriteByte(group) w.WriteByte(group)
w.Write(buf.Bytes()) w.Write(buf.Bytes())
}), nil) }), nil)
return errors.Wrap(err, "put data error") return errors.Wrap(err, "put data error")
} }
func (db *LevelDBImpl) InsertPrivateMessage(msg *StoredPrivateMessage) error { func (ldb *LevelDBImpl) InsertPrivateMessage(msg *db.StoredPrivateMessage) error {
buf := global.NewBuffer() buf := global.NewBuffer()
defer global.PutBuffer(buf) defer global.PutBuffer(buf)
if err := gob.NewEncoder(buf).Encode(msg); err != nil { if err := gob.NewEncoder(buf).Encode(msg); err != nil {
return errors.Wrap(err, "encode message error") return errors.Wrap(err, "encode message error")
} }
err := db.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) { err := ldb.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) {
w.WriteByte(private) w.WriteByte(private)
w.Write(buf.Bytes()) w.Write(buf.Bytes())
}), nil) }), nil)

View File

@ -1,71 +0,0 @@
package db
import (
"context"
"github.com/pkg/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type MongoDBImpl struct {
uri string
db string
mongo *mongo.Database
}
const (
MongoGroupMessageCollection = "group-messages"
MongoPrivateMessageCollection = "private-messages"
)
func UseMongoDB(uri, db string) *MongoDBImpl {
return &MongoDBImpl{uri: uri, db: db}
}
func (db *MongoDBImpl) Open() error {
cli, err := mongo.Connect(context.Background(), options.Client().ApplyURI(db.uri))
if err != nil {
return errors.Wrap(err, "open mongo connection error")
}
db.mongo = cli.Database(db.db)
return nil
}
func (db *MongoDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) {
if r, err := db.GetGroupMessageByGlobalID(id); err == nil {
return r, nil
}
return db.GetPrivateMessageByGlobalID(id)
}
func (db *MongoDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) {
coll := db.mongo.Collection(MongoGroupMessageCollection)
var ret StoredGroupMessage
if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil {
return nil, errors.Wrap(err, "query error")
}
return &ret, nil
}
func (db *MongoDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) {
coll := db.mongo.Collection(MongoPrivateMessageCollection)
var ret StoredPrivateMessage
if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil {
return nil, errors.Wrap(err, "query error")
}
return &ret, nil
}
func (db *MongoDBImpl) InsertGroupMessage(msg *StoredGroupMessage) error {
coll := db.mongo.Collection(MongoGroupMessageCollection)
_, err := coll.InsertOne(context.Background(), msg)
return errors.Wrap(err, "insert error")
}
func (db *MongoDBImpl) InsertPrivateMessage(msg *StoredPrivateMessage) error {
coll := db.mongo.Collection(MongoPrivateMessageCollection)
_, err := coll.InsertOne(context.Background(), msg)
return errors.Wrap(err, "insert error")
}

85
db/mongodb/mongodb.go Normal file
View File

@ -0,0 +1,85 @@
package mongodb
import (
"context"
"github.com/pkg/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"gopkg.in/yaml.v3"
"github.com/Mrs4s/go-cqhttp/db"
"github.com/Mrs4s/go-cqhttp/modules/config"
)
type MongoDBImpl struct {
uri string
db string
mongo *mongo.Database
}
const (
MongoGroupMessageCollection = "group-messages"
MongoPrivateMessageCollection = "private-messages"
)
func init() {
db.Register("mongodb", func(node yaml.Node) db.Database {
conf := new(config.MongoDBConfig)
_ = node.Decode(conf)
if conf.Database == "" {
conf.Database = "gocq-database"
}
if !conf.Enable {
return nil
}
return &MongoDBImpl{uri: conf.URI, db: conf.Database}
})
}
func (m *MongoDBImpl) Open() error {
cli, err := mongo.Connect(context.Background(), options.Client().ApplyURI(m.uri))
if err != nil {
return errors.Wrap(err, "open mongo connection error")
}
m.mongo = cli.Database(m.db)
return nil
}
func (m *MongoDBImpl) GetMessageByGlobalID(id int32) (db.StoredMessage, error) {
if r, err := m.GetGroupMessageByGlobalID(id); err == nil {
return r, nil
}
return m.GetPrivateMessageByGlobalID(id)
}
func (m *MongoDBImpl) GetGroupMessageByGlobalID(id int32) (*db.StoredGroupMessage, error) {
coll := m.mongo.Collection(MongoGroupMessageCollection)
var ret db.StoredGroupMessage
if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil {
return nil, errors.Wrap(err, "query error")
}
return &ret, nil
}
func (m *MongoDBImpl) GetPrivateMessageByGlobalID(id int32) (*db.StoredPrivateMessage, error) {
coll := m.mongo.Collection(MongoPrivateMessageCollection)
var ret db.StoredPrivateMessage
if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil {
return nil, errors.Wrap(err, "query error")
}
return &ret, nil
}
func (m *MongoDBImpl) InsertGroupMessage(msg *db.StoredGroupMessage) error {
coll := m.mongo.Collection(MongoGroupMessageCollection)
_, err := coll.InsertOne(context.Background(), msg)
return errors.Wrap(err, "insert error")
}
func (m *MongoDBImpl) InsertPrivateMessage(msg *db.StoredPrivateMessage) error {
coll := m.mongo.Collection(MongoPrivateMessageCollection)
_, err := coll.InsertOne(context.Background(), msg)
return errors.Wrap(err, "insert error")
}

View File

@ -1,26 +1,43 @@
package db package db
import "github.com/pkg/errors" import (
"github.com/pkg/errors"
"gopkg.in/yaml.v3"
// MultiDBImpl 多数据库支持 "github.com/Mrs4s/go-cqhttp/internal/base"
)
// backends 多数据库支持, 后端支持
// 写入会对所有 Backend 进行写入 // 写入会对所有 Backend 进行写入
// 读取只会读取第一个库 // 读取只会读取第一个库
type MultiDBImpl struct { var backends []Database
backends []IDatabase
// drivers 多数据库启动
var drivers = make(map[string]func(node yaml.Node) Database)
// Register 添加数据库后端
func Register(name string, init func(yaml.Node) Database) {
if _, ok := drivers[name]; ok {
panic("database driver conflict: " + name)
}
drivers[name] = init
} }
func NewMultiDatabase(backends ...IDatabase) *MultiDBImpl { // Init 加载所有后端配置文件
return &MultiDBImpl{ func Init() {
backends: backends, backends = make([]Database, 0, len(drivers))
for name, init := range drivers {
if n, ok := base.Database[name]; ok {
db := init(n)
if db != nil {
backends = append(backends, db)
}
}
} }
} }
func (db *MultiDBImpl) UseDB(backend IDatabase) { func Open() error {
db.backends = append(db.backends, backend) for _, b := range backends {
}
func (db *MultiDBImpl) Open() error {
for _, b := range db.backends {
if err := b.Open(); err != nil { if err := b.Open(); err != nil {
return errors.Wrap(err, "open backend error") return errors.Wrap(err, "open backend error")
} }
@ -28,29 +45,29 @@ func (db *MultiDBImpl) Open() error {
return nil return nil
} }
func (db *MultiDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) { func GetMessageByGlobalID(id int32) (StoredMessage, error) {
if len(db.backends) == 0 { if len(backends) == 0 {
return nil, errors.New("database disabled") return nil, errors.New("database disabled")
} }
return db.backends[0].GetMessageByGlobalID(id) return backends[0].GetMessageByGlobalID(id)
} }
func (db *MultiDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) { func GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) {
if len(db.backends) == 0 { if len(backends) == 0 {
return nil, errors.New("database disabled") return nil, errors.New("database disabled")
} }
return db.backends[0].GetGroupMessageByGlobalID(id) return backends[0].GetGroupMessageByGlobalID(id)
} }
func (db *MultiDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) { func GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) {
if len(db.backends) == 0 { if len(backends) == 0 {
return nil, errors.New("database disabled") return nil, errors.New("database disabled")
} }
return db.backends[0].GetPrivateMessageByGlobalID(id) return backends[0].GetPrivateMessageByGlobalID(id)
} }
func (db *MultiDBImpl) InsertGroupMessage(m *StoredGroupMessage) error { func InsertGroupMessage(m *StoredGroupMessage) error {
for _, b := range db.backends { for _, b := range backends {
if err := b.InsertGroupMessage(m); err != nil { if err := b.InsertGroupMessage(m); err != nil {
return errors.Wrap(err, "insert message to backend error") return errors.Wrap(err, "insert message to backend error")
} }
@ -58,8 +75,8 @@ func (db *MultiDBImpl) InsertGroupMessage(m *StoredGroupMessage) error {
return nil return nil
} }
func (db *MultiDBImpl) InsertPrivateMessage(m *StoredPrivateMessage) error { func InsertPrivateMessage(m *StoredPrivateMessage) error {
for _, b := range db.backends { for _, b := range backends {
if err := b.InsertPrivateMessage(m); err != nil { if err := b.InsertPrivateMessage(m); err != nil {
return errors.Wrap(err, "insert message to backend error") return errors.Wrap(err, "insert message to backend error")
} }

View File

@ -19,6 +19,7 @@ import (
"golang.org/x/term" "golang.org/x/term"
"github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/coolq"
"github.com/Mrs4s/go-cqhttp/db"
"github.com/Mrs4s/go-cqhttp/global" "github.com/Mrs4s/go-cqhttp/global"
"github.com/Mrs4s/go-cqhttp/global/terminal" "github.com/Mrs4s/go-cqhttp/global/terminal"
"github.com/Mrs4s/go-cqhttp/internal/base" "github.com/Mrs4s/go-cqhttp/internal/base"
@ -28,6 +29,7 @@ import (
"github.com/Mrs4s/go-cqhttp/modules/servers" "github.com/Mrs4s/go-cqhttp/modules/servers"
"github.com/Mrs4s/go-cqhttp/server" "github.com/Mrs4s/go-cqhttp/server"
_ "github.com/Mrs4s/go-cqhttp/db/leveldb" // leveldb
_ "github.com/Mrs4s/go-cqhttp/modules/mime" // mime检查模块 _ "github.com/Mrs4s/go-cqhttp/modules/mime" // mime检查模块
_ "github.com/Mrs4s/go-cqhttp/modules/pprof" // pprof 性能分析 _ "github.com/Mrs4s/go-cqhttp/modules/pprof" // pprof 性能分析
_ "github.com/Mrs4s/go-cqhttp/modules/silk" // silk编码模块 _ "github.com/Mrs4s/go-cqhttp/modules/silk" // silk编码模块
@ -84,6 +86,11 @@ func main() {
mkCacheDir(global.CachePath, "发送图片") mkCacheDir(global.CachePath, "发送图片")
cache.Init() cache.Init()
db.Init()
if err := db.Open(); err != nil {
log.Fatalf("打开数据库失败: %v", err)
}
var byteKey []byte var byteKey []byte
arg := os.Args arg := os.Args
if len(arg) > 1 { if len(arg) > 1 {