1
0
mirror of https://github.com/Mrs4s/go-cqhttp.git synced 2025-05-04 19:17:37 +08:00
go-cqhttp/db/mongodb/mongodb.go

102 lines
3.0 KiB
Go

package mongodb
import (
"context"
"github.com/pkg/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"gopkg.in/yaml.v3"
"github.com/Mrs4s/go-cqhttp/db"
"github.com/Mrs4s/go-cqhttp/modules/config"
)
// MongoDBImpl is the MongoDB-backed message store registered under the
// "mongodb" name in this package's init function.
type MongoDBImpl struct {
// uri is the connection string handed to the mongo client in Open.
uri string
// db is the name of the database selected on the client in Open.
db string
// mongo is the database handle assigned by Open; every query and insert
// method reads it, so Open must succeed before they are called.
mongo *mongo.Database
}
// Names of the collections the store's query and insert methods operate on.
const (
// MongoGroupMessageCollection is the collection used for group messages.
MongoGroupMessageCollection = "group-messages"
// MongoPrivateMessageCollection is the collection used for private messages.
MongoPrivateMessageCollection = "private-messages"
// MongoGuildChannelMessageCollection is the collection used for guild channel messages.
MongoGuildChannelMessageCollection = "guild-channel-messages"
)
func init() {
db.Register("mongodb", func(node yaml.Node) db.Database {
conf := new(config.MongoDBConfig)
_ = node.Decode(conf)
if conf.Database == "" {
conf.Database = "gocq-database"
}
if !conf.Enable {
return nil
}
return &MongoDBImpl{uri: conf.URI, db: conf.Database}
})
}
// Open creates a mongo client for the configured URI and stores a handle to
// the configured database on the receiver. It returns a wrapped error when
// the client cannot be created.
func (m *MongoDBImpl) Open() error {
	ctx := context.Background()
	opts := options.Client().ApplyURI(m.uri)
	client, err := mongo.Connect(ctx, opts)
	if err != nil {
		return errors.Wrap(err, "open mongo connection error")
	}
	m.mongo = client.Database(m.db)
	return nil
}
// GetMessageByGlobalID looks up a message by its global id, trying the group
// message collection first and falling back to the private message
// collection.
//
// When neither lookup succeeds, the returned message is a nil interface and
// the error is the one produced by the private message lookup; the group
// lookup error is discarded.
func (m *MongoDBImpl) GetMessageByGlobalID(id int32) (db.StoredMessage, error) {
	if group, err := m.GetGroupMessageByGlobalID(id); err == nil {
		return group, nil
	}
	private, err := m.GetPrivateMessageByGlobalID(id)
	if err != nil {
		// Return a literal nil: forwarding the nil *db.StoredPrivateMessage
		// directly would wrap it in a non-nil db.StoredMessage interface.
		return nil, err
	}
	return private, nil
}
// GetGroupMessageByGlobalID fetches the document whose "globalId" field equals
// id from the group message collection and decodes it into a
// db.StoredGroupMessage. Any lookup or decode failure is returned wrapped as
// a "query error".
func (m *MongoDBImpl) GetGroupMessageByGlobalID(id int32) (*db.StoredGroupMessage, error) {
	filter := bson.D{{Key: "globalId", Value: id}}
	result := m.mongo.Collection(MongoGroupMessageCollection).FindOne(context.Background(), filter)
	msg := new(db.StoredGroupMessage)
	if err := result.Decode(msg); err != nil {
		return nil, errors.Wrap(err, "query error")
	}
	return msg, nil
}
// GetPrivateMessageByGlobalID fetches the document whose "globalId" field
// equals id from the private message collection and decodes it into a
// db.StoredPrivateMessage. Any lookup or decode failure is returned wrapped
// as a "query error".
func (m *MongoDBImpl) GetPrivateMessageByGlobalID(id int32) (*db.StoredPrivateMessage, error) {
	filter := bson.D{{Key: "globalId", Value: id}}
	result := m.mongo.Collection(MongoPrivateMessageCollection).FindOne(context.Background(), filter)
	msg := new(db.StoredPrivateMessage)
	if err := result.Decode(msg); err != nil {
		return nil, errors.Wrap(err, "query error")
	}
	return msg, nil
}
// GetGuildChannelMessageByID fetches the document whose "_id" field equals id
// from the guild channel message collection and decodes it into a
// db.StoredGuildChannelMessage. Any lookup or decode failure is returned
// wrapped as a "query error".
func (m *MongoDBImpl) GetGuildChannelMessageByID(id string) (*db.StoredGuildChannelMessage, error) {
	filter := bson.D{{Key: "_id", Value: id}}
	result := m.mongo.Collection(MongoGuildChannelMessageCollection).FindOne(context.Background(), filter)
	msg := new(db.StoredGuildChannelMessage)
	if err := result.Decode(msg); err != nil {
		return nil, errors.Wrap(err, "query error")
	}
	return msg, nil
}
// InsertGroupMessage writes msg into the group message collection. It returns
// nil on success and the driver error wrapped as an "insert error" otherwise.
func (m *MongoDBImpl) InsertGroupMessage(msg *db.StoredGroupMessage) error {
	ctx := context.Background()
	if _, err := m.mongo.Collection(MongoGroupMessageCollection).InsertOne(ctx, msg); err != nil {
		return errors.Wrap(err, "insert error")
	}
	return nil
}
// InsertPrivateMessage writes msg into the private message collection. It
// returns nil on success and the driver error wrapped as an "insert error"
// otherwise.
func (m *MongoDBImpl) InsertPrivateMessage(msg *db.StoredPrivateMessage) error {
	ctx := context.Background()
	if _, err := m.mongo.Collection(MongoPrivateMessageCollection).InsertOne(ctx, msg); err != nil {
		return errors.Wrap(err, "insert error")
	}
	return nil
}
// InsertGuildChannelMessage writes msg into the guild channel message
// collection. It returns nil on success and the driver error wrapped as an
// "insert error" otherwise.
func (m *MongoDBImpl) InsertGuildChannelMessage(msg *db.StoredGuildChannelMessage) error {
	ctx := context.Background()
	if _, err := m.mongo.Collection(MongoGuildChannelMessageCollection).InsertOne(ctx, msg); err != nil {
		return errors.Wrap(err, "insert error")
	}
	return nil
}