mirror of
https://github.com/Genshin-bots/gsuid_core.git
synced 2025-06-01 13:09:47 +08:00
✨ 实测支持使用MySQL
数据库进行管理 (#69)
This commit is contained in:
parent
7e5abc6874
commit
ef151d55d3
@ -117,7 +117,13 @@ class _Bot:
|
||||
if forward_m.type != 'image_size':
|
||||
message_result.append([forward_m])
|
||||
elif enable_forward == '合并为一条消息':
|
||||
_temp_mr.extend(_m.data)
|
||||
_add = []
|
||||
for index, forward_m in enumerate(_m.data):
|
||||
_add.append(forward_m)
|
||||
if index < len(_m.data) - 1:
|
||||
_add.append(MessageSegment.text('\n'))
|
||||
_temp_mr.extend(_add)
|
||||
|
||||
elif enable_forward.isdigit():
|
||||
for forward_m in _m.data[: int(enable_forward)]:
|
||||
if forward_m.type != 'image_size':
|
||||
|
@ -12,6 +12,14 @@ from fastapi import WebSocket, WebSocketDisconnect
|
||||
sys.path.append(str(Path(__file__).resolve().parent))
|
||||
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
||||
|
||||
# from gsuid_core.utils.database.startup import exec_list # noqa: E402
|
||||
|
||||
|
||||
async def main():
|
||||
from gsuid_core.utils.database.base_models import init_database
|
||||
|
||||
await init_database()
|
||||
|
||||
from gsuid_core.gss import gss # noqa: E402
|
||||
from gsuid_core.bot import _Bot # noqa: E402
|
||||
from gsuid_core.logger import logger # noqa: E402
|
||||
@ -19,45 +27,11 @@ from gsuid_core.web_app import app, site # noqa: E402
|
||||
from gsuid_core.config import core_config # noqa: E402
|
||||
from gsuid_core.handler import handle_event # noqa: E402
|
||||
from gsuid_core.models import MessageReceive # noqa: E402
|
||||
from gsuid_core.utils.database.startup import exec_list # noqa: E402
|
||||
|
||||
HOST = core_config.get_config('HOST')
|
||||
PORT = int(core_config.get_config('PORT'))
|
||||
ENABLE_HTTP = core_config.get_config('ENABLE_HTTP')
|
||||
HTTP_SERVER_STATUS = False
|
||||
|
||||
exec_list.extend(
|
||||
[
|
||||
'ALTER TABLE GsBind ADD COLUMN group_id TEXT',
|
||||
'ALTER TABLE GsBind ADD COLUMN sr_uid TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN sr_uid TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN sr_region TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN zzz_region TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN bb_region TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN bbb_region TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN wd_region TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN fp TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN device_id TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN bb_uid TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN bbb_uid TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN zzz_uid TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN wd_uid TEXT',
|
||||
'ALTER TABLE GsBind ADD COLUMN bb_uid TEXT',
|
||||
'ALTER TABLE GsBind ADD COLUMN bbb_uid TEXT',
|
||||
'ALTER TABLE GsBind ADD COLUMN zzz_uid TEXT',
|
||||
'ALTER TABLE GsBind ADD COLUMN wd_uid TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN device_info TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN sr_sign_switch TEXT DEFAULT "off"',
|
||||
'ALTER TABLE GsUser ADD COLUMN zzz_sign_switch TEXT DEFAULT "off"',
|
||||
'ALTER TABLE GsUser ADD COLUMN sr_push_switch TEXT DEFAULT "off"',
|
||||
'ALTER TABLE GsUser ADD COLUMN zzz_push_switch TEXT DEFAULT "off"',
|
||||
'ALTER TABLE GsUser ADD COLUMN draw_switch TEXT DEFAULT "off"',
|
||||
'ALTER TABLE GsCache ADD COLUMN sr_uid TEXT',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
@app.websocket('/ws/{bot_id}')
|
||||
async def websocket_endpoint(websocket: WebSocket, bot_id: str):
|
||||
try:
|
||||
@ -87,11 +61,6 @@ def main():
|
||||
|
||||
@app.post('/api/send_msg')
|
||||
async def sendMsg(msg: Dict):
|
||||
global HTTP_SERVER_STATUS
|
||||
if not HTTP_SERVER_STATUS:
|
||||
asyncio.create_task(_bot._process())
|
||||
HTTP_SERVER_STATUS = True
|
||||
|
||||
data = msgjson.encode(msg)
|
||||
MR = msgjson.Decoder(MessageReceive).decode(data)
|
||||
result = await handle_event(_bot, MR, True)
|
||||
@ -103,7 +72,7 @@ def main():
|
||||
site.gen_plugin_page()
|
||||
site.mount_app(app)
|
||||
|
||||
uvicorn.run(
|
||||
config = uvicorn.Config(
|
||||
app,
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
@ -122,9 +91,11 @@ def main():
|
||||
'level': 'INFO',
|
||||
},
|
||||
},
|
||||
},
|
||||
}, # 你的日志配置
|
||||
loop="asyncio",
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
asyncio.run(main())
|
||||
|
@ -20,8 +20,8 @@ from gsuid_core.utils.plugins_config.gs_config import core_plugins_config
|
||||
auto_install_dep: bool = core_plugins_config.get_config('AutoInstallDep').data
|
||||
auto_update_dep: bool = core_plugins_config.get_config('AutoUpdateDep').data
|
||||
|
||||
core_start_def = set()
|
||||
core_shutdown_def = set()
|
||||
core_start_def: set[Callable] = set()
|
||||
core_shutdown_def: set[Callable] = set()
|
||||
installed_dependencies: Dict[str, str] = {}
|
||||
ignore_dep = ['python', 'fastapi', 'pydantic']
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import re
|
||||
import asyncio
|
||||
from typing import Dict, Type, Tuple, Union, Literal, Optional, overload
|
||||
|
||||
from sqlalchemy import event
|
||||
@ -21,7 +20,6 @@ class DBSqla:
|
||||
|
||||
def get_sqla(self, bot_id) -> SQLA:
|
||||
sqla = self._get_sqla(bot_id, self.is_sr)
|
||||
asyncio.create_task(sqla.sr_adapter())
|
||||
return sqla
|
||||
|
||||
def _get_sqla(self, bot_id, is_sr: bool = False) -> SQLA:
|
||||
@ -29,7 +27,6 @@ class DBSqla:
|
||||
if bot_id not in sqla_list:
|
||||
sqla = SQLA(bot_id, is_sr)
|
||||
sqla_list[bot_id] = sqla
|
||||
sqla.create_all()
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
|
@ -11,8 +11,8 @@ from typing import (
|
||||
Awaitable,
|
||||
)
|
||||
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy import text, create_engine
|
||||
# from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy import exc, text, create_engine
|
||||
from sqlalchemy.sql.expression import func, null, true
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker # type: ignore
|
||||
@ -45,28 +45,32 @@ db_pool_recycle: int = database_config.get_config('db_pool_recycle').data
|
||||
db_custom_url = database_config.get_config('db_custom_url').data
|
||||
db_type: str = database_config.get_config('db_type').data
|
||||
|
||||
db_driver: str = database_config.get_config('db_driver').data
|
||||
|
||||
_db_type = db_type.lower()
|
||||
db_config = {
|
||||
'pool_recycle': db_pool_recycle,
|
||||
# 'pool_pre_ping': True,
|
||||
# 'pool_size': db_pool_size,
|
||||
'echo': db_echo,
|
||||
}
|
||||
|
||||
DB_PATH = get_res_path() / 'GsData.db'
|
||||
sync_url = ''
|
||||
|
||||
sync_url, engine, finally_url = '', '', ''
|
||||
async_maker: async_sessionmaker[AsyncSession] = None # type: ignore
|
||||
server_engine = None
|
||||
|
||||
if _db_type == 'sqlite':
|
||||
sync_url = 'sqlite:///'
|
||||
base_url = 'sqlite+aiosqlite:///'
|
||||
db_url = str(DB_PATH)
|
||||
# del db_config['pool_size']
|
||||
elif _db_type == 'mysql':
|
||||
sync_url = 'mysql+pymysql://'
|
||||
base_url = 'mysql+aiomysql://'
|
||||
base_url = f'mysql+{db_driver}://'
|
||||
db_hp = f'{db_host}:{db_port}' if db_port else db_host
|
||||
db_url = f'{db_user}:{db_password}@{db_hp}/'
|
||||
elif _db_type == 'postgresql':
|
||||
sync_url = 'postgresql+psycopg2://'
|
||||
sync_url = 'postgresql+psycopg://'
|
||||
base_url = 'postgresql+asyncpg://'
|
||||
db_hp = f'{db_host}:{db_port}' if db_port else db_host
|
||||
db_url = f'{db_user}:{db_password}@{db_hp}/'
|
||||
@ -77,13 +81,25 @@ else:
|
||||
base_url = db_type
|
||||
db_url = db_custom_url
|
||||
|
||||
|
||||
async def init_database():
|
||||
global engine, finally_url, async_maker
|
||||
|
||||
try:
|
||||
if _db_type == 'sqlite':
|
||||
engine = create_async_engine(f'{base_url}{db_url}', **db_config)
|
||||
finally_url = f'{base_url}{db_url}'
|
||||
else:
|
||||
server_engine = None
|
||||
db_config.update(
|
||||
{
|
||||
'pool_size': db_pool_size,
|
||||
'max_overflow': 10,
|
||||
'pool_timeout': 30,
|
||||
'isolation_level': "AUTOCOMMIT",
|
||||
}
|
||||
)
|
||||
try:
|
||||
server_engine = None
|
||||
if _db_type == 'mysql':
|
||||
server_engine = create_engine(
|
||||
f'{sync_url}{db_url}', **db_config
|
||||
@ -97,31 +113,38 @@ try:
|
||||
f"[MySQL] 数据库 {db_name} 创建成功或已存在!"
|
||||
)
|
||||
elif _db_type == 'postgresql':
|
||||
try:
|
||||
server_engine = create_engine(
|
||||
f'{sync_url}{db_url}', **db_config
|
||||
)
|
||||
with server_engine.connect() as conn:
|
||||
t1 = f"CREATE DATABASE {db_name} WITH ENCODING 'UTF8' "
|
||||
t2 = "LC_COLLATE 'en_US.UTF-8' LC_CTYPE 'en_US.UTF-8'"
|
||||
conn.execute(text(t1 + t2))
|
||||
t = f"CREATE DATABASE {db_name} WITH ENCODING "
|
||||
t2 = "'UTF8' LC_COLLATE 'en_US.UTF-8' LC_CTYPE "
|
||||
t3 = "'en_US.UTF-8' TEMPLATE template0"
|
||||
conn.execute(text(t + t2 + t3))
|
||||
except exc.ProgrammingError as e:
|
||||
if 'already exists' in str(e) or '已经存在' in str(e):
|
||||
pass
|
||||
logger.success(
|
||||
f"[PostgreSQL] 数据库 {db_name} 创建成功或已存在!"
|
||||
)
|
||||
finally:
|
||||
if server_engine is not None:
|
||||
if server_engine:
|
||||
server_engine.dispose()
|
||||
logger.debug("[SQL] 同步数据库引擎已释放")
|
||||
logger.info('[数据库] 临时数据库连接已释放!')
|
||||
|
||||
db_config['poolclass'] = NullPool
|
||||
# db_config['poolclass'] = NullPool
|
||||
finally_url = f'{base_url}{db_url}{db_name}'
|
||||
engine = create_async_engine(finally_url, **db_config)
|
||||
|
||||
async_maker = async_sessionmaker(
|
||||
engine,
|
||||
expire_on_commit=False,
|
||||
close_resets_only=False,
|
||||
class_=AsyncSession,
|
||||
)
|
||||
except: # noqa: E722
|
||||
except Exception as e: # noqa: E722
|
||||
logger.exception(f'[GsCore] [数据库] 连接失败: {e}')
|
||||
raise ValueError(
|
||||
f'[GsCore] [数据库] [{base_url}] 连接失败, 请检查配置文件!'
|
||||
)
|
||||
@ -132,14 +155,16 @@ def with_session(
|
||||
) -> Callable[Concatenate[Any, P], Awaitable[R]]:
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs):
|
||||
async with async_maker() as session:
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with async_maker() as session:
|
||||
data = await func(self, session, *args, **kwargs)
|
||||
await session.commit()
|
||||
return data
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
logger.exception(f"[数据库] 第 {attempt + 1} 次重试失败: {e}")
|
||||
continue
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
@ -1,12 +1,8 @@
|
||||
import re
|
||||
import asyncio
|
||||
from typing import Dict, List, Literal, Optional
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
from .base_models import async_maker
|
||||
from .utils import SERVER, SR_SERVER
|
||||
from .base_models import engine, async_maker
|
||||
from .models import GsBind, GsPush, GsUser, GsCache
|
||||
|
||||
|
||||
@ -15,40 +11,6 @@ class SQLA:
|
||||
self.bot_id = bot_id
|
||||
self.is_sr = is_sr
|
||||
|
||||
def create_all(self):
|
||||
try:
|
||||
asyncio.create_task(self._create_all())
|
||||
except RuntimeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self._create_all())
|
||||
loop.close()
|
||||
|
||||
async def _create_all(self):
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
await self.sr_adapter()
|
||||
|
||||
async def sr_adapter(self):
|
||||
exec_list = [
|
||||
'ALTER TABLE GsBind ADD COLUMN group_id TEXT',
|
||||
'ALTER TABLE GsBind ADD COLUMN sr_uid TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN sr_uid TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN sr_region TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN fp TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN device_id TEXT',
|
||||
'ALTER TABLE GsUser ADD COLUMN sr_sign_switch TEXT DEFAULT "off"',
|
||||
'ALTER TABLE GsUser ADD COLUMN sr_push_switch TEXT DEFAULT "off"',
|
||||
'ALTER TABLE GsUser ADD COLUMN draw_switch TEXT DEFAULT "off"',
|
||||
'ALTER TABLE GsCache ADD COLUMN sr_uid TEXT',
|
||||
]
|
||||
async with async_maker() as session:
|
||||
for _t in exec_list:
|
||||
try:
|
||||
await session.execute(text(_t))
|
||||
await session.commit()
|
||||
except: # noqa: E722
|
||||
pass
|
||||
|
||||
#####################
|
||||
# GsBind 部分 #
|
||||
#####################
|
||||
@ -416,3 +378,4 @@ class SQLA:
|
||||
|
||||
async def insert_new_user(self, **kwargs):
|
||||
await GsUser.full_insert_data(**kwargs)
|
||||
await GsUser.full_insert_data(**kwargs)
|
||||
|
@ -31,7 +31,7 @@ async def move_database():
|
||||
|
||||
|
||||
# @on_core_start
|
||||
async def sr_adapter():
|
||||
async def trans_adapter():
|
||||
async with engine.begin() as conn:
|
||||
metadata = MetaData()
|
||||
try:
|
||||
|
@ -9,6 +9,12 @@ DATABASE_CONIFG: Dict[str, GSC] = {
|
||||
'SQLite',
|
||||
['SQLite', 'MySql', 'PostgreSQL', '自定义'],
|
||||
),
|
||||
'db_driver': GsStrConfig(
|
||||
'MySQL驱动',
|
||||
'设置喜欢的MySQL驱动',
|
||||
'aiomysql',
|
||||
['aiomysql', 'asyncmy'],
|
||||
),
|
||||
'db_custom_url': GsStrConfig(
|
||||
'自定义数据库连接地址 (一般无需填写)',
|
||||
'设置自定义数据库连接',
|
||||
@ -31,7 +37,7 @@ DATABASE_CONIFG: Dict[str, GSC] = {
|
||||
'数据库用户名',
|
||||
'设置数据库用户名',
|
||||
'root',
|
||||
['root', 'admin'],
|
||||
['root', 'admin', 'postgres'],
|
||||
),
|
||||
'db_password': GsStrConfig(
|
||||
'数据库密码',
|
||||
|
@ -54,6 +54,10 @@ pic_expire_time = core_plugins_config.get_config('ScheduledCleanPicSrv').data
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
try:
|
||||
logger.info(
|
||||
'[GsCore] 执行启动Hook函数中!',
|
||||
[_def.__name__ for _def in core_start_def],
|
||||
)
|
||||
_task = [_def() for _def in core_start_def]
|
||||
await asyncio.gather(*_task)
|
||||
except Exception as e:
|
||||
@ -64,9 +68,16 @@ async def lifespan(app: FastAPI):
|
||||
await start_check() # type:ignore
|
||||
await start_scheduler()
|
||||
asyncio.create_task(clean_log())
|
||||
|
||||
yield
|
||||
|
||||
await shutdown_scheduler()
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
'[GsCore] 执行关闭Hook函数中!',
|
||||
[_def.__name__ for _def in core_shutdown_def],
|
||||
)
|
||||
_task = [_def() for _def in core_shutdown_def]
|
||||
await asyncio.gather(*_task)
|
||||
except Exception as e:
|
||||
|
Loading…
x
Reference in New Issue
Block a user