mirror of
https://github.com/Mrs4s/go-cqhttp.git
synced 2025-05-04 19:17:37 +08:00
feat: modular database
build with leveldb on default
This commit is contained in:
parent
4e69ef4c2f
commit
5e70c8115e
@ -58,20 +58,6 @@ linters:
|
||||
- revive
|
||||
- forbidigo
|
||||
- makezero
|
||||
#- interfacer
|
||||
|
||||
# don't enable:
|
||||
# - scopelint
|
||||
# - gochecknoglobals
|
||||
# - gocognit
|
||||
# - godot
|
||||
# - godox
|
||||
# - goerr113
|
||||
# - interfacer
|
||||
# - maligned
|
||||
# - nestif
|
||||
# - testpackage
|
||||
# - wsl
|
||||
|
||||
run:
|
||||
# default concurrency is a available CPU number.
|
||||
|
17
coolq/api.go
17
coolq/api.go
@ -465,7 +465,7 @@ func (bot *CQBot) CQSendGroupForwardMessage(groupID int64, m gjson.Result) globa
|
||||
ts.Add(time.Second)
|
||||
if e.Get("data.id").Exists() {
|
||||
i := e.Get("data.id").Int()
|
||||
m, _ := bot.db.GetGroupMessageByGlobalID(int32(i))
|
||||
m, _ := db.GetGroupMessageByGlobalID(int32(i))
|
||||
if m != nil {
|
||||
return &message.ForwardNode{
|
||||
SenderId: m.Attribute.SenderUin,
|
||||
@ -781,7 +781,7 @@ func (bot *CQBot) CQProcessGroupRequest(flag, subType, reason string, approve bo
|
||||
//
|
||||
// https:// git.io/Jtz1y
|
||||
func (bot *CQBot) CQDeleteMessage(messageID int32) global.MSG {
|
||||
msg, err := bot.db.GetMessageByGlobalID(messageID)
|
||||
msg, err := db.GetMessageByGlobalID(messageID)
|
||||
if err != nil {
|
||||
log.Warnf("撤回消息时出现错误: %v", err)
|
||||
return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在")
|
||||
@ -1115,7 +1115,7 @@ func (bot *CQBot) CQGetForwardMessage(resID string) global.MSG {
|
||||
//
|
||||
// https://git.io/Jtz1b
|
||||
func (bot *CQBot) CQGetMessage(messageID int32) global.MSG {
|
||||
msg, err := bot.db.GetMessageByGlobalID(messageID)
|
||||
msg, err := db.GetMessageByGlobalID(messageID)
|
||||
if err != nil {
|
||||
log.Warnf("获取消息时出现错误: %v", err)
|
||||
return Failed(100, "MSG_NOT_FOUND", "消息不存在")
|
||||
@ -1176,11 +1176,8 @@ func (bot *CQBot) CQGetGroupMessageHistory(groupID int64, seq int64) global.MSG
|
||||
}
|
||||
ms := make([]global.MSG, 0, len(msg))
|
||||
for _, m := range msg {
|
||||
id := m.Id
|
||||
bot.checkMedia(m.Elements)
|
||||
if bot.db != nil {
|
||||
id = bot.InsertGroupMessage(m)
|
||||
}
|
||||
id := bot.InsertGroupMessage(m)
|
||||
t := bot.formatGroupMessage(m)
|
||||
t["message_id"] = id
|
||||
ms = append(ms, t)
|
||||
@ -1302,7 +1299,7 @@ func (bot *CQBot) CQGetStatus() global.MSG {
|
||||
//
|
||||
// https://docs.go-cqhttp.org/api/#%E8%AE%BE%E7%BD%AE%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF
|
||||
func (bot *CQBot) CQSetEssenceMessage(messageID int32) global.MSG {
|
||||
msg, err := bot.db.GetGroupMessageByGlobalID(messageID)
|
||||
msg, err := db.GetGroupMessageByGlobalID(messageID)
|
||||
if err != nil {
|
||||
return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在")
|
||||
}
|
||||
@ -1317,7 +1314,7 @@ func (bot *CQBot) CQSetEssenceMessage(messageID int32) global.MSG {
|
||||
//
|
||||
// https://docs.go-cqhttp.org/api/#%E7%A7%BB%E5%87%BA%E7%B2%BE%E5%8D%8E%E6%B6%88%E6%81%AF
|
||||
func (bot *CQBot) CQDeleteEssenceMessage(messageID int32) global.MSG {
|
||||
msg, err := bot.db.GetGroupMessageByGlobalID(messageID)
|
||||
msg, err := db.GetGroupMessageByGlobalID(messageID)
|
||||
if err != nil {
|
||||
return Failed(100, "MESSAGE_NOT_FOUND", "消息不存在")
|
||||
}
|
||||
@ -1436,7 +1433,7 @@ func (bot *CQBot) CQSetModelShow(modelName string, modelShow string) global.MSG
|
||||
|
||||
// CQMarkMessageAsRead 标记消息已读
|
||||
func (bot *CQBot) CQMarkMessageAsRead(msgID int32) global.MSG {
|
||||
m, err := bot.db.GetMessageByGlobalID(msgID)
|
||||
m, err := db.GetMessageByGlobalID(msgID)
|
||||
if err != nil {
|
||||
return Failed(100, "MSG_NOT_FOUND", "消息不存在")
|
||||
}
|
||||
|
30
coolq/bot.go
30
coolq/bot.go
@ -23,7 +23,6 @@ import (
|
||||
|
||||
"github.com/Mrs4s/go-cqhttp/global"
|
||||
"github.com/Mrs4s/go-cqhttp/internal/base"
|
||||
"github.com/Mrs4s/go-cqhttp/modules/config"
|
||||
)
|
||||
|
||||
// CQBot CQBot结构体,存储Bot实例相关配置
|
||||
@ -33,7 +32,6 @@ type CQBot struct {
|
||||
lock sync.RWMutex
|
||||
events []func(*Event)
|
||||
|
||||
db db.IDatabase
|
||||
friendReqCache sync.Map
|
||||
tempSessionCache sync.Map
|
||||
}
|
||||
@ -71,30 +69,6 @@ func NewQQBot(cli *client.QQClient) *CQBot {
|
||||
bot := &CQBot{
|
||||
Client: cli,
|
||||
}
|
||||
levelNode, levelDB := base.Database["leveldb"]
|
||||
mongoNode, mongoDB := base.Database["mongodb"]
|
||||
multiDB := db.NewMultiDatabase()
|
||||
if levelDB {
|
||||
lconf := new(config.LevelDBConfig)
|
||||
_ = levelNode.Decode(lconf)
|
||||
if lconf.Enable {
|
||||
multiDB.UseDB(db.UseLevelDB())
|
||||
}
|
||||
}
|
||||
if mongoDB {
|
||||
lconf := new(config.MongoDBConfig)
|
||||
_ = mongoNode.Decode(lconf)
|
||||
if lconf.Database == "" {
|
||||
lconf.Database = "gocq-database"
|
||||
}
|
||||
if lconf.Enable {
|
||||
multiDB.UseDB(db.UseMongoDB(lconf.URI, lconf.Database))
|
||||
}
|
||||
}
|
||||
if err := multiDB.Open(); err != nil {
|
||||
log.Fatalf("打开数据库失败: %v", err)
|
||||
}
|
||||
bot.db = multiDB
|
||||
bot.Client.OnPrivateMessage(bot.privateMessageEvent)
|
||||
bot.Client.OnGroupMessage(bot.groupMessageEvent)
|
||||
if base.ReportSelfMessage {
|
||||
@ -384,7 +358,7 @@ func (bot *CQBot) InsertGroupMessage(m *message.GroupMessage) int32 {
|
||||
QuotedContent: ToMessageContent(reply.Elements),
|
||||
}
|
||||
}
|
||||
if err := bot.db.InsertGroupMessage(msg); err != nil {
|
||||
if err := db.InsertGroupMessage(msg); err != nil {
|
||||
log.Warnf("记录聊天数据时出现错误: %v", err)
|
||||
return -1
|
||||
}
|
||||
@ -427,7 +401,7 @@ func (bot *CQBot) InsertPrivateMessage(m *message.PrivateMessage) int32 {
|
||||
QuotedContent: ToMessageContent(m.Elements),
|
||||
}
|
||||
}
|
||||
if err := bot.db.InsertPrivateMessage(msg); err != nil {
|
||||
if err := db.InsertPrivateMessage(msg); err != nil {
|
||||
log.Warnf("记录聊天数据时出现错误: %v", err)
|
||||
return -1
|
||||
}
|
||||
|
@ -479,7 +479,7 @@ func (bot *CQBot) ConvertStringMessage(raw string, isGroup bool) (r []message.IM
|
||||
switch {
|
||||
case customText != "":
|
||||
var elem *message.ReplyElement
|
||||
var org db.IStoredMessage
|
||||
var org db.StoredMessage
|
||||
sender, senderErr := strconv.ParseInt(d["qq"], 10, 64)
|
||||
if senderErr != nil && err != nil {
|
||||
log.Warnf("警告: 自定义 Reply 元素中必须包含 Uin 或 id")
|
||||
@ -491,7 +491,7 @@ func (bot *CQBot) ConvertStringMessage(raw string, isGroup bool) (r []message.IM
|
||||
}
|
||||
messageSeq, seqErr := strconv.ParseInt(d["seq"], 10, 64)
|
||||
if err == nil {
|
||||
org, _ = bot.db.GetMessageByGlobalID(int32(mid))
|
||||
org, _ = db.GetMessageByGlobalID(int32(mid))
|
||||
}
|
||||
if org != nil {
|
||||
elem = &message.ReplyElement{
|
||||
@ -519,7 +519,7 @@ func (bot *CQBot) ConvertStringMessage(raw string, isGroup bool) (r []message.IM
|
||||
}
|
||||
r = append([]message.IMessageElement{elem}, r...)
|
||||
case err == nil:
|
||||
org, err := bot.db.GetMessageByGlobalID(int32(mid))
|
||||
org, err := db.GetMessageByGlobalID(int32(mid))
|
||||
if err == nil {
|
||||
r = append([]message.IMessageElement{
|
||||
&message.ReplyElement{
|
||||
@ -651,7 +651,7 @@ func (bot *CQBot) ConvertObjectMessage(m gjson.Result, isGroup bool) (r []messag
|
||||
switch {
|
||||
case customText != "":
|
||||
var elem *message.ReplyElement
|
||||
var org db.IStoredMessage
|
||||
var org db.StoredMessage
|
||||
sender, senderErr := strconv.ParseInt(e.Get("data.[user_id,qq]").String(), 10, 64)
|
||||
if senderErr != nil && err != nil {
|
||||
log.Warnf("警告: 自定义 Reply 元素中必须包含 user_id 或 id")
|
||||
@ -663,7 +663,7 @@ func (bot *CQBot) ConvertObjectMessage(m gjson.Result, isGroup bool) (r []messag
|
||||
}
|
||||
messageSeq, seqErr := strconv.ParseInt(e.Get("data.seq").String(), 10, 64)
|
||||
if err == nil {
|
||||
org, _ = bot.db.GetMessageByGlobalID(int32(mid))
|
||||
org, _ = db.GetMessageByGlobalID(int32(mid))
|
||||
}
|
||||
if org != nil {
|
||||
elem = &message.ReplyElement{
|
||||
@ -691,7 +691,7 @@ func (bot *CQBot) ConvertObjectMessage(m gjson.Result, isGroup bool) (r []messag
|
||||
}
|
||||
r = append([]message.IMessageElement{elem}, r...)
|
||||
case err == nil:
|
||||
org, err := bot.db.GetMessageByGlobalID(int32(mid))
|
||||
org, err := db.GetMessageByGlobalID(int32(mid))
|
||||
if err == nil {
|
||||
r = append([]message.IMessageElement{
|
||||
&message.ReplyElement{
|
||||
|
@ -33,10 +33,7 @@ func ToFormattedMessage(e []message.IMessageElement, groupID int64, isRaw ...boo
|
||||
func (bot *CQBot) privateMessageEvent(c *client.QQClient, m *message.PrivateMessage) {
|
||||
bot.checkMedia(m.Elements)
|
||||
cqm := ToStringMessage(m.Elements, 0, true)
|
||||
id := m.Id
|
||||
if bot.db != nil {
|
||||
id = bot.InsertPrivateMessage(m)
|
||||
}
|
||||
id := bot.InsertPrivateMessage(m)
|
||||
log.Infof("收到好友 %v(%v) 的消息: %v (%v)", m.Sender.DisplayName(), m.Sender.Uin, cqm, id)
|
||||
fm := global.MSG{
|
||||
"post_type": func() string {
|
||||
@ -89,10 +86,7 @@ func (bot *CQBot) groupMessageEvent(c *client.QQClient, m *message.GroupMessage)
|
||||
}
|
||||
}
|
||||
cqm := ToStringMessage(m.Elements, m.GroupCode, true)
|
||||
id := m.Id
|
||||
if bot.db != nil {
|
||||
id = bot.InsertGroupMessage(m)
|
||||
}
|
||||
id := bot.InsertGroupMessage(m)
|
||||
log.Infof("收到群 %v(%v) 内 %v(%v) 的消息: %v (%v)", m.GroupName, m.GroupCode, m.Sender.DisplayName(), m.Sender.Uin, cqm, id)
|
||||
gm := bot.formatGroupMessage(m)
|
||||
if gm == nil {
|
||||
@ -108,10 +102,10 @@ func (bot *CQBot) tempMessageEvent(c *client.QQClient, e *client.TempMessageEven
|
||||
cqm := ToStringMessage(m.Elements, 0, true)
|
||||
bot.tempSessionCache.Store(m.Sender.Uin, e.Session)
|
||||
id := m.Id
|
||||
if bot.db != nil { // nolint
|
||||
// todo(Mrs4s)
|
||||
// if bot.db != nil { // nolint
|
||||
// id = bot.InsertTempMessage(m.Sender.Uin, m)
|
||||
}
|
||||
// }
|
||||
log.Infof("收到来自群 %v(%v) 内 %v(%v) 的临时会话消息: %v", m.GroupName, m.GroupCode, m.Sender.DisplayName(), m.Sender.Uin, cqm)
|
||||
tm := global.MSG{
|
||||
"post_type": "message",
|
||||
|
@ -8,13 +8,13 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
// IDatabase 数据库操作接口定义
|
||||
IDatabase interface {
|
||||
// Database 数据库操作接口定义
|
||||
Database interface {
|
||||
// Open 初始化数据库
|
||||
Open() error
|
||||
|
||||
// GetMessageByGlobalID 通过 GlobalID 来获取消息
|
||||
GetMessageByGlobalID(int32) (IStoredMessage, error)
|
||||
GetMessageByGlobalID(int32) (StoredMessage, error)
|
||||
// GetGroupMessageByGlobalID 通过 GlobalID 来获取群消息
|
||||
GetGroupMessageByGlobalID(int32) (*StoredGroupMessage, error)
|
||||
// GetPrivateMessageByGlobalID 通过 GlobalID 来获取私聊消息
|
||||
@ -26,7 +26,7 @@ type (
|
||||
InsertPrivateMessage(*StoredPrivateMessage) error
|
||||
}
|
||||
|
||||
IStoredMessage interface {
|
||||
StoredMessage interface {
|
||||
GetID() string
|
||||
GetType() string
|
||||
GetGlobalID() int32
|
||||
|
@ -1,4 +1,4 @@
|
||||
package db
|
||||
package leveldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -9,8 +9,11 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/syndtr/goleveldb/leveldb"
|
||||
"github.com/syndtr/goleveldb/leveldb/opt"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/Mrs4s/go-cqhttp/db"
|
||||
"github.com/Mrs4s/go-cqhttp/global"
|
||||
"github.com/Mrs4s/go-cqhttp/modules/config"
|
||||
)
|
||||
|
||||
type LevelDBImpl struct {
|
||||
@ -22,42 +25,50 @@ const (
|
||||
private byte = 0x1
|
||||
)
|
||||
|
||||
func UseLevelDB() *LevelDBImpl {
|
||||
gob.Register(StoredMessageAttribute{})
|
||||
gob.Register(QuotedInfo{})
|
||||
func init() {
|
||||
gob.Register(db.StoredMessageAttribute{})
|
||||
gob.Register(db.QuotedInfo{})
|
||||
gob.Register(global.MSG{})
|
||||
gob.Register(StoredGroupMessage{})
|
||||
gob.Register(StoredPrivateMessage{})
|
||||
gob.Register(db.StoredGroupMessage{})
|
||||
gob.Register(db.StoredPrivateMessage{})
|
||||
|
||||
db.Register("leveldb", func(node yaml.Node) db.Database {
|
||||
conf := new(config.LevelDBConfig)
|
||||
_ = node.Decode(conf)
|
||||
if !conf.Enable {
|
||||
return nil
|
||||
}
|
||||
return &LevelDBImpl{}
|
||||
})
|
||||
}
|
||||
|
||||
func (db *LevelDBImpl) Open() error {
|
||||
func (ldb *LevelDBImpl) Open() error {
|
||||
p := path.Join("data", "leveldb-v2")
|
||||
d, err := leveldb.OpenFile(p, &opt.Options{
|
||||
WriteBuffer: 128 * opt.KiB,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "open level db error")
|
||||
return errors.Wrap(err, "open level ldb error")
|
||||
}
|
||||
db.db = d
|
||||
ldb.db = d
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *LevelDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) {
|
||||
v, err := db.db.Get(binary.ToBytes(id), nil)
|
||||
func (ldb *LevelDBImpl) GetMessageByGlobalID(id int32) (db.StoredMessage, error) {
|
||||
v, err := ldb.db.Get(binary.ToBytes(id), nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get value error")
|
||||
}
|
||||
r := binary.NewReader(v)
|
||||
switch r.ReadByte() {
|
||||
case group:
|
||||
g := &StoredGroupMessage{}
|
||||
g := &db.StoredGroupMessage{}
|
||||
if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(g); err != nil {
|
||||
return nil, errors.Wrap(err, "decode message error")
|
||||
}
|
||||
return g, nil
|
||||
case private:
|
||||
p := &StoredPrivateMessage{}
|
||||
p := &db.StoredPrivateMessage{}
|
||||
if err = gob.NewDecoder(bytes.NewReader(r.ReadAvailable())).Decode(p); err != nil {
|
||||
return nil, errors.Wrap(err, "decode message error")
|
||||
}
|
||||
@ -67,50 +78,50 @@ func (db *LevelDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *LevelDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) {
|
||||
i, err := db.GetMessageByGlobalID(id)
|
||||
func (ldb *LevelDBImpl) GetGroupMessageByGlobalID(id int32) (*db.StoredGroupMessage, error) {
|
||||
i, err := ldb.GetMessageByGlobalID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
g, ok := i.(*StoredGroupMessage)
|
||||
g, ok := i.(*db.StoredGroupMessage)
|
||||
if !ok {
|
||||
return nil, errors.New("message type error")
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (db *LevelDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) {
|
||||
i, err := db.GetMessageByGlobalID(id)
|
||||
func (ldb *LevelDBImpl) GetPrivateMessageByGlobalID(id int32) (*db.StoredPrivateMessage, error) {
|
||||
i, err := ldb.GetMessageByGlobalID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p, ok := i.(*StoredPrivateMessage)
|
||||
p, ok := i.(*db.StoredPrivateMessage)
|
||||
if !ok {
|
||||
return nil, errors.New("message type error")
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (db *LevelDBImpl) InsertGroupMessage(msg *StoredGroupMessage) error {
|
||||
func (ldb *LevelDBImpl) InsertGroupMessage(msg *db.StoredGroupMessage) error {
|
||||
buf := global.NewBuffer()
|
||||
defer global.PutBuffer(buf)
|
||||
if err := gob.NewEncoder(buf).Encode(msg); err != nil {
|
||||
return errors.Wrap(err, "encode message error")
|
||||
}
|
||||
err := db.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) {
|
||||
err := ldb.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) {
|
||||
w.WriteByte(group)
|
||||
w.Write(buf.Bytes())
|
||||
}), nil)
|
||||
return errors.Wrap(err, "put data error")
|
||||
}
|
||||
|
||||
func (db *LevelDBImpl) InsertPrivateMessage(msg *StoredPrivateMessage) error {
|
||||
func (ldb *LevelDBImpl) InsertPrivateMessage(msg *db.StoredPrivateMessage) error {
|
||||
buf := global.NewBuffer()
|
||||
defer global.PutBuffer(buf)
|
||||
if err := gob.NewEncoder(buf).Encode(msg); err != nil {
|
||||
return errors.Wrap(err, "encode message error")
|
||||
}
|
||||
err := db.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) {
|
||||
err := ldb.db.Put(binary.ToBytes(msg.GlobalID), binary.NewWriterF(func(w *binary.Writer) {
|
||||
w.WriteByte(private)
|
||||
w.Write(buf.Bytes())
|
||||
}), nil)
|
@ -1,71 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type MongoDBImpl struct {
|
||||
uri string
|
||||
db string
|
||||
mongo *mongo.Database
|
||||
}
|
||||
|
||||
const (
|
||||
MongoGroupMessageCollection = "group-messages"
|
||||
MongoPrivateMessageCollection = "private-messages"
|
||||
)
|
||||
|
||||
func UseMongoDB(uri, db string) *MongoDBImpl {
|
||||
return &MongoDBImpl{uri: uri, db: db}
|
||||
}
|
||||
|
||||
func (db *MongoDBImpl) Open() error {
|
||||
cli, err := mongo.Connect(context.Background(), options.Client().ApplyURI(db.uri))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "open mongo connection error")
|
||||
}
|
||||
db.mongo = cli.Database(db.db)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *MongoDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) {
|
||||
if r, err := db.GetGroupMessageByGlobalID(id); err == nil {
|
||||
return r, nil
|
||||
}
|
||||
return db.GetPrivateMessageByGlobalID(id)
|
||||
}
|
||||
|
||||
func (db *MongoDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) {
|
||||
coll := db.mongo.Collection(MongoGroupMessageCollection)
|
||||
var ret StoredGroupMessage
|
||||
if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil {
|
||||
return nil, errors.Wrap(err, "query error")
|
||||
}
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func (db *MongoDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) {
|
||||
coll := db.mongo.Collection(MongoPrivateMessageCollection)
|
||||
var ret StoredPrivateMessage
|
||||
if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil {
|
||||
return nil, errors.Wrap(err, "query error")
|
||||
}
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func (db *MongoDBImpl) InsertGroupMessage(msg *StoredGroupMessage) error {
|
||||
coll := db.mongo.Collection(MongoGroupMessageCollection)
|
||||
_, err := coll.InsertOne(context.Background(), msg)
|
||||
return errors.Wrap(err, "insert error")
|
||||
}
|
||||
|
||||
func (db *MongoDBImpl) InsertPrivateMessage(msg *StoredPrivateMessage) error {
|
||||
coll := db.mongo.Collection(MongoPrivateMessageCollection)
|
||||
_, err := coll.InsertOne(context.Background(), msg)
|
||||
return errors.Wrap(err, "insert error")
|
||||
}
|
85
db/mongodb/mongodb.go
Normal file
85
db/mongodb/mongodb.go
Normal file
@ -0,0 +1,85 @@
|
||||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/Mrs4s/go-cqhttp/db"
|
||||
"github.com/Mrs4s/go-cqhttp/modules/config"
|
||||
)
|
||||
|
||||
type MongoDBImpl struct {
|
||||
uri string
|
||||
db string
|
||||
mongo *mongo.Database
|
||||
}
|
||||
|
||||
const (
|
||||
MongoGroupMessageCollection = "group-messages"
|
||||
MongoPrivateMessageCollection = "private-messages"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db.Register("mongodb", func(node yaml.Node) db.Database {
|
||||
conf := new(config.MongoDBConfig)
|
||||
_ = node.Decode(conf)
|
||||
if conf.Database == "" {
|
||||
conf.Database = "gocq-database"
|
||||
}
|
||||
if !conf.Enable {
|
||||
return nil
|
||||
}
|
||||
return &MongoDBImpl{uri: conf.URI, db: conf.Database}
|
||||
})
|
||||
}
|
||||
|
||||
func (m *MongoDBImpl) Open() error {
|
||||
cli, err := mongo.Connect(context.Background(), options.Client().ApplyURI(m.uri))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "open mongo connection error")
|
||||
}
|
||||
m.mongo = cli.Database(m.db)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MongoDBImpl) GetMessageByGlobalID(id int32) (db.StoredMessage, error) {
|
||||
if r, err := m.GetGroupMessageByGlobalID(id); err == nil {
|
||||
return r, nil
|
||||
}
|
||||
return m.GetPrivateMessageByGlobalID(id)
|
||||
}
|
||||
|
||||
func (m *MongoDBImpl) GetGroupMessageByGlobalID(id int32) (*db.StoredGroupMessage, error) {
|
||||
coll := m.mongo.Collection(MongoGroupMessageCollection)
|
||||
var ret db.StoredGroupMessage
|
||||
if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil {
|
||||
return nil, errors.Wrap(err, "query error")
|
||||
}
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func (m *MongoDBImpl) GetPrivateMessageByGlobalID(id int32) (*db.StoredPrivateMessage, error) {
|
||||
coll := m.mongo.Collection(MongoPrivateMessageCollection)
|
||||
var ret db.StoredPrivateMessage
|
||||
if err := coll.FindOne(context.Background(), bson.D{{"globalId", id}}).Decode(&ret); err != nil {
|
||||
return nil, errors.Wrap(err, "query error")
|
||||
}
|
||||
return &ret, nil
|
||||
}
|
||||
|
||||
func (m *MongoDBImpl) InsertGroupMessage(msg *db.StoredGroupMessage) error {
|
||||
coll := m.mongo.Collection(MongoGroupMessageCollection)
|
||||
_, err := coll.InsertOne(context.Background(), msg)
|
||||
return errors.Wrap(err, "insert error")
|
||||
}
|
||||
|
||||
func (m *MongoDBImpl) InsertPrivateMessage(msg *db.StoredPrivateMessage) error {
|
||||
coll := m.mongo.Collection(MongoPrivateMessageCollection)
|
||||
_, err := coll.InsertOne(context.Background(), msg)
|
||||
return errors.Wrap(err, "insert error")
|
||||
}
|
@ -1,26 +1,43 @@
|
||||
package db
|
||||
|
||||
import "github.com/pkg/errors"
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
// MultiDBImpl 多数据库支持
|
||||
"github.com/Mrs4s/go-cqhttp/internal/base"
|
||||
)
|
||||
|
||||
// backends 多数据库支持, 后端支持
|
||||
// 写入会对所有 Backend 进行写入
|
||||
// 读取只会读取第一个库
|
||||
type MultiDBImpl struct {
|
||||
backends []IDatabase
|
||||
var backends []Database
|
||||
|
||||
// drivers 多数据库启动
|
||||
var drivers = make(map[string]func(node yaml.Node) Database)
|
||||
|
||||
// Register 添加数据库后端
|
||||
func Register(name string, init func(yaml.Node) Database) {
|
||||
if _, ok := drivers[name]; ok {
|
||||
panic("database driver conflict: " + name)
|
||||
}
|
||||
drivers[name] = init
|
||||
}
|
||||
|
||||
func NewMultiDatabase(backends ...IDatabase) *MultiDBImpl {
|
||||
return &MultiDBImpl{
|
||||
backends: backends,
|
||||
// Init 加载所有后端配置文件
|
||||
func Init() {
|
||||
backends = make([]Database, 0, len(drivers))
|
||||
for name, init := range drivers {
|
||||
if n, ok := base.Database[name]; ok {
|
||||
db := init(n)
|
||||
if db != nil {
|
||||
backends = append(backends, db)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *MultiDBImpl) UseDB(backend IDatabase) {
|
||||
db.backends = append(db.backends, backend)
|
||||
}
|
||||
|
||||
func (db *MultiDBImpl) Open() error {
|
||||
for _, b := range db.backends {
|
||||
func Open() error {
|
||||
for _, b := range backends {
|
||||
if err := b.Open(); err != nil {
|
||||
return errors.Wrap(err, "open backend error")
|
||||
}
|
||||
@ -28,29 +45,29 @@ func (db *MultiDBImpl) Open() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *MultiDBImpl) GetMessageByGlobalID(id int32) (IStoredMessage, error) {
|
||||
if len(db.backends) == 0 {
|
||||
func GetMessageByGlobalID(id int32) (StoredMessage, error) {
|
||||
if len(backends) == 0 {
|
||||
return nil, errors.New("database disabled")
|
||||
}
|
||||
return db.backends[0].GetMessageByGlobalID(id)
|
||||
return backends[0].GetMessageByGlobalID(id)
|
||||
}
|
||||
|
||||
func (db *MultiDBImpl) GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) {
|
||||
if len(db.backends) == 0 {
|
||||
func GetGroupMessageByGlobalID(id int32) (*StoredGroupMessage, error) {
|
||||
if len(backends) == 0 {
|
||||
return nil, errors.New("database disabled")
|
||||
}
|
||||
return db.backends[0].GetGroupMessageByGlobalID(id)
|
||||
return backends[0].GetGroupMessageByGlobalID(id)
|
||||
}
|
||||
|
||||
func (db *MultiDBImpl) GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) {
|
||||
if len(db.backends) == 0 {
|
||||
func GetPrivateMessageByGlobalID(id int32) (*StoredPrivateMessage, error) {
|
||||
if len(backends) == 0 {
|
||||
return nil, errors.New("database disabled")
|
||||
}
|
||||
return db.backends[0].GetPrivateMessageByGlobalID(id)
|
||||
return backends[0].GetPrivateMessageByGlobalID(id)
|
||||
}
|
||||
|
||||
func (db *MultiDBImpl) InsertGroupMessage(m *StoredGroupMessage) error {
|
||||
for _, b := range db.backends {
|
||||
func InsertGroupMessage(m *StoredGroupMessage) error {
|
||||
for _, b := range backends {
|
||||
if err := b.InsertGroupMessage(m); err != nil {
|
||||
return errors.Wrap(err, "insert message to backend error")
|
||||
}
|
||||
@ -58,8 +75,8 @@ func (db *MultiDBImpl) InsertGroupMessage(m *StoredGroupMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *MultiDBImpl) InsertPrivateMessage(m *StoredPrivateMessage) error {
|
||||
for _, b := range db.backends {
|
||||
func InsertPrivateMessage(m *StoredPrivateMessage) error {
|
||||
for _, b := range backends {
|
||||
if err := b.InsertPrivateMessage(m); err != nil {
|
||||
return errors.Wrap(err, "insert message to backend error")
|
||||
}
|
||||
|
7
main.go
7
main.go
@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/Mrs4s/go-cqhttp/coolq"
|
||||
"github.com/Mrs4s/go-cqhttp/db"
|
||||
"github.com/Mrs4s/go-cqhttp/global"
|
||||
"github.com/Mrs4s/go-cqhttp/global/terminal"
|
||||
"github.com/Mrs4s/go-cqhttp/internal/base"
|
||||
@ -28,6 +29,7 @@ import (
|
||||
"github.com/Mrs4s/go-cqhttp/modules/servers"
|
||||
"github.com/Mrs4s/go-cqhttp/server"
|
||||
|
||||
_ "github.com/Mrs4s/go-cqhttp/db/leveldb" // leveldb
|
||||
_ "github.com/Mrs4s/go-cqhttp/modules/mime" // mime检查模块
|
||||
_ "github.com/Mrs4s/go-cqhttp/modules/pprof" // pprof 性能分析
|
||||
_ "github.com/Mrs4s/go-cqhttp/modules/silk" // silk编码模块
|
||||
@ -84,6 +86,11 @@ func main() {
|
||||
mkCacheDir(global.CachePath, "发送图片")
|
||||
cache.Init()
|
||||
|
||||
db.Init()
|
||||
if err := db.Open(); err != nil {
|
||||
log.Fatalf("打开数据库失败: %v", err)
|
||||
}
|
||||
|
||||
var byteKey []byte
|
||||
arg := os.Args
|
||||
if len(arg) > 1 {
|
||||
|
Loading…
x
Reference in New Issue
Block a user