Verified support for using a MySQL database for data management (#69)

KimigaiiWuyi 2025-05-12 04:24:11 +08:00
parent 7e5abc6874
commit ef151d55d3
9 changed files with 136 additions and 157 deletions

View File

@@ -117,7 +117,13 @@ class _Bot:
                 if forward_m.type != 'image_size':
                     message_result.append([forward_m])
             elif enable_forward == '合并为一条消息':
-                _temp_mr.extend(_m.data)
+                _add = []
+                for index, forward_m in enumerate(_m.data):
+                    _add.append(forward_m)
+                    if index < len(_m.data) - 1:
+                        _add.append(MessageSegment.text('\n'))
+                _temp_mr.extend(_add)
             elif enable_forward.isdigit():
                 for forward_m in _m.data[: int(enable_forward)]:
                     if forward_m.type != 'image_size':

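The change above makes the '合并为一条消息' (merge into one message) branch insert a newline text segment between the segments of a forwarded bundle instead of concatenating them back to back. A minimal, self-contained sketch of that joining logic, using a hypothetical Segment stand-in rather than the real MessageSegment:

from dataclasses import dataclass
from typing import List


@dataclass
class Segment:
    """Stand-in for MessageSegment: a type tag plus its payload."""

    type: str
    data: str


def join_with_newlines(segments: List[Segment]) -> List[Segment]:
    """Insert a newline text segment between neighbours, but not after the last one."""
    joined: List[Segment] = []
    for index, seg in enumerate(segments):
        joined.append(seg)
        if index < len(segments) - 1:
            joined.append(Segment('text', '\n'))
    return joined


if __name__ == '__main__':
    parts = [Segment('text', 'first line'), Segment('text', 'second line')]
    for seg in join_with_newlines(parts):
        print(repr(seg.data))  # 'first line', '\n', 'second line'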
View File

@@ -12,52 +12,26 @@ from fastapi import WebSocket, WebSocketDisconnect
 sys.path.append(str(Path(__file__).resolve().parent))
 sys.path.append(str(Path(__file__).resolve().parents[1]))
 
-from gsuid_core.gss import gss  # noqa: E402
-from gsuid_core.bot import _Bot  # noqa: E402
-from gsuid_core.logger import logger  # noqa: E402
-from gsuid_core.web_app import app, site  # noqa: E402
-from gsuid_core.config import core_config  # noqa: E402
-from gsuid_core.handler import handle_event  # noqa: E402
-from gsuid_core.models import MessageReceive  # noqa: E402
-from gsuid_core.utils.database.startup import exec_list  # noqa: E402
-
-HOST = core_config.get_config('HOST')
-PORT = int(core_config.get_config('PORT'))
-ENABLE_HTTP = core_config.get_config('ENABLE_HTTP')
-HTTP_SERVER_STATUS = False
-
-exec_list.extend(
-    [
-        'ALTER TABLE GsBind ADD COLUMN group_id TEXT',
-        'ALTER TABLE GsBind ADD COLUMN sr_uid TEXT',
-        'ALTER TABLE GsUser ADD COLUMN sr_uid TEXT',
-        'ALTER TABLE GsUser ADD COLUMN sr_region TEXT',
-        'ALTER TABLE GsUser ADD COLUMN zzz_region TEXT',
-        'ALTER TABLE GsUser ADD COLUMN bb_region TEXT',
-        'ALTER TABLE GsUser ADD COLUMN bbb_region TEXT',
-        'ALTER TABLE GsUser ADD COLUMN wd_region TEXT',
-        'ALTER TABLE GsUser ADD COLUMN fp TEXT',
-        'ALTER TABLE GsUser ADD COLUMN device_id TEXT',
-        'ALTER TABLE GsUser ADD COLUMN bb_uid TEXT',
-        'ALTER TABLE GsUser ADD COLUMN bbb_uid TEXT',
-        'ALTER TABLE GsUser ADD COLUMN zzz_uid TEXT',
-        'ALTER TABLE GsUser ADD COLUMN wd_uid TEXT',
-        'ALTER TABLE GsBind ADD COLUMN bb_uid TEXT',
-        'ALTER TABLE GsBind ADD COLUMN bbb_uid TEXT',
-        'ALTER TABLE GsBind ADD COLUMN zzz_uid TEXT',
-        'ALTER TABLE GsBind ADD COLUMN wd_uid TEXT',
-        'ALTER TABLE GsUser ADD COLUMN device_info TEXT',
-        'ALTER TABLE GsUser ADD COLUMN sr_sign_switch TEXT DEFAULT "off"',
-        'ALTER TABLE GsUser ADD COLUMN zzz_sign_switch TEXT DEFAULT "off"',
-        'ALTER TABLE GsUser ADD COLUMN sr_push_switch TEXT DEFAULT "off"',
-        'ALTER TABLE GsUser ADD COLUMN zzz_push_switch TEXT DEFAULT "off"',
-        'ALTER TABLE GsUser ADD COLUMN draw_switch TEXT DEFAULT "off"',
-        'ALTER TABLE GsCache ADD COLUMN sr_uid TEXT',
-    ]
-)
-
-
-def main():
+# from gsuid_core.utils.database.startup import exec_list  # noqa: E402
+
+
+async def main():
+    from gsuid_core.utils.database.base_models import init_database
+
+    await init_database()
+
+    from gsuid_core.gss import gss  # noqa: E402
+    from gsuid_core.bot import _Bot  # noqa: E402
+    from gsuid_core.logger import logger  # noqa: E402
+    from gsuid_core.web_app import app, site  # noqa: E402
+    from gsuid_core.config import core_config  # noqa: E402
+    from gsuid_core.handler import handle_event  # noqa: E402
+    from gsuid_core.models import MessageReceive  # noqa: E402
+
+    HOST = core_config.get_config('HOST')
+    PORT = int(core_config.get_config('PORT'))
+    ENABLE_HTTP = core_config.get_config('ENABLE_HTTP')
 
     @app.websocket('/ws/{bot_id}')
     async def websocket_endpoint(websocket: WebSocket, bot_id: str):
         try:
@@ -87,11 +61,6 @@ def main():
 
     @app.post('/api/send_msg')
     async def sendMsg(msg: Dict):
-        global HTTP_SERVER_STATUS
-        if not HTTP_SERVER_STATUS:
-            asyncio.create_task(_bot._process())
-            HTTP_SERVER_STATUS = True
-
         data = msgjson.encode(msg)
         MR = msgjson.Decoder(MessageReceive).decode(data)
         result = await handle_event(_bot, MR, True)
@@ -103,7 +72,7 @@ def main():
     site.gen_plugin_page()
     site.mount_app(app)
 
-    uvicorn.run(
+    config = uvicorn.Config(
         app,
         host=HOST,
         port=PORT,
@@ -122,9 +91,11 @@ def main():
                     'level': 'INFO',
                 },
            },
-        },
+        },  # 你的日志配置
+        loop="asyncio",
     )
+    server = uvicorn.Server(config)
+    await server.serve()
 
-if __name__ == '__main__':
-    main()
+
+asyncio.run(main())

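For reference, the entry point now drives uvicorn from inside an already-running event loop: uvicorn.run() is replaced by uvicorn.Config plus uvicorn.Server, and await init_database() runs before the rest of the core is imported. A minimal sketch of that startup pattern with a generic FastAPI app; the init_database body, host and port here are placeholders, not GsCore's real values:

import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()


async def init_database() -> None:
    # Stand-in for GsCore's init_database(): build the async engine, make sure
    # the database/schema exists, then create the session maker.
    await asyncio.sleep(0)


@app.get('/ping')
async def ping():
    return {'pong': True}


async def main() -> None:
    # Database first, so code that runs afterwards sees an initialised engine.
    await init_database()

    # uvicorn.Config + uvicorn.Server keeps uvicorn inside the loop that
    # asyncio.run() already owns; uvicorn.run() would try to start its own.
    config = uvicorn.Config(app, host='0.0.0.0', port=8765, loop='asyncio')
    server = uvicorn.Server(config)
    await server.serve()


if __name__ == '__main__':
    asyncio.run(main())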
View File

@@ -20,8 +20,8 @@ from gsuid_core.utils.plugins_config.gs_config import core_plugins_config
 auto_install_dep: bool = core_plugins_config.get_config('AutoInstallDep').data
 auto_update_dep: bool = core_plugins_config.get_config('AutoUpdateDep').data
 
-core_start_def = set()
-core_shutdown_def = set()
+core_start_def: set[Callable] = set()
+core_shutdown_def: set[Callable] = set()
 installed_dependencies: Dict[str, str] = {}
 ignore_dep = ['python', 'fastapi', 'pydantic']

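core_start_def and core_shutdown_def are now annotated as set[Callable], so a type checker can flag non-callables being registered as hooks. A small sketch of how such a hook set is typically filled and awaited; the on_start helper below is hypothetical, not GsCore's actual registration API:

import asyncio
from typing import Awaitable, Callable

# Startup hooks are async callables collected at import time.
core_start_def: set[Callable[[], Awaitable[None]]] = set()


def on_start(func: Callable[[], Awaitable[None]]) -> Callable[[], Awaitable[None]]:
    """Hypothetical helper: remember an async function to run at startup."""
    core_start_def.add(func)
    return func


@on_start
async def warm_cache() -> None:
    print('warming cache...')


async def run_startup_hooks() -> None:
    # Await every registered hook concurrently, as the lifespan handler does
    # with asyncio.gather further down in this commit.
    await asyncio.gather(*(hook() for hook in core_start_def))


if __name__ == '__main__':
    asyncio.run(run_startup_hooks())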
View File

@@ -1,5 +1,4 @@
 import re
-import asyncio
 from typing import Dict, Type, Tuple, Union, Literal, Optional, overload
 
 from sqlalchemy import event
@@ -21,7 +20,6 @@ class DBSqla:
     def get_sqla(self, bot_id) -> SQLA:
         sqla = self._get_sqla(bot_id, self.is_sr)
-        asyncio.create_task(sqla.sr_adapter())
         return sqla
 
     def _get_sqla(self, bot_id, is_sr: bool = False) -> SQLA:
@@ -29,7 +27,6 @@ class DBSqla:
         if bot_id not in sqla_list:
             sqla = SQLA(bot_id, is_sr)
             sqla_list[bot_id] = sqla
-            sqla.create_all()
 
 @event.listens_for(engine.sync_engine, "connect")
 def set_sqlite_pragma(dbapi_connection, connection_record):

View File

@@ -11,8 +11,8 @@ from typing import (
     Awaitable,
 )
 
-from sqlalchemy.pool import NullPool
-from sqlalchemy import text, create_engine
+# from sqlalchemy.pool import NullPool
+from sqlalchemy import exc, text, create_engine
 from sqlalchemy.sql.expression import func, null, true
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.ext.asyncio import async_sessionmaker  # type: ignore
@@ -45,28 +45,32 @@ db_pool_recycle: int = database_config.get_config('db_pool_recycle').data
 db_custom_url = database_config.get_config('db_custom_url').data
 db_type: str = database_config.get_config('db_type').data
+db_driver: str = database_config.get_config('db_driver').data
 
 _db_type = db_type.lower()
 db_config = {
     'pool_recycle': db_pool_recycle,
+    # 'pool_pre_ping': True,
+    # 'pool_size': db_pool_size,
     'echo': db_echo,
 }
 
 DB_PATH = get_res_path() / 'GsData.db'
 
-sync_url = ''
+sync_url, engine, finally_url = '', '', ''
+async_maker: async_sessionmaker[AsyncSession] = None  # type: ignore
+server_engine = None
 
 if _db_type == 'sqlite':
-    sync_url = 'sqlite:///'
     base_url = 'sqlite+aiosqlite:///'
     db_url = str(DB_PATH)
     # del db_config['pool_size']
 elif _db_type == 'mysql':
     sync_url = 'mysql+pymysql://'
-    base_url = 'mysql+aiomysql://'
+    base_url = f'mysql+{db_driver}://'
     db_hp = f'{db_host}:{db_port}' if db_port else db_host
     db_url = f'{db_user}:{db_password}@{db_hp}/'
 elif _db_type == 'postgresql':
-    sync_url = 'postgresql+psycopg2://'
+    sync_url = 'postgresql+psycopg://'
     base_url = 'postgresql+asyncpg://'
     db_hp = f'{db_host}:{db_port}' if db_port else db_host
     db_url = f'{db_user}:{db_password}@{db_hp}/'
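To make the branch above concrete, here is how the sync and async URLs come together for a MySQL setup. The credentials, host and database name are made-up placeholders standing in for the database_config values:

# Illustration only: made-up values standing in for the database_config entries.
db_driver = 'aiomysql'                   # new db_driver option ('aiomysql' or 'asyncmy')
db_user, db_password = 'root', 'secret'
db_host, db_port, db_name = '127.0.0.1', '3306', 'GsData'

sync_url = 'mysql+pymysql://'            # sync engine, only used to CREATE DATABASE
base_url = f'mysql+{db_driver}://'       # async engine used for normal operation

db_hp = f'{db_host}:{db_port}' if db_port else db_host
db_url = f'{db_user}:{db_password}@{db_hp}/'

print(f'{sync_url}{db_url}')             # mysql+pymysql://root:secret@127.0.0.1:3306/
print(f'{base_url}{db_url}{db_name}')    # mysql+aiomysql://root:secret@127.0.0.1:3306/GsData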
@@ -77,54 +81,73 @@ else:
     base_url = db_type
     db_url = db_custom_url
 
-try:
-    if _db_type == 'sqlite':
-        engine = create_async_engine(f'{base_url}{db_url}', **db_config)
-        finally_url = f'{base_url}{db_url}'
-    else:
-        server_engine = None
-        try:
-            if _db_type == 'mysql':
-                server_engine = create_engine(
-                    f'{sync_url}{db_url}', **db_config
-                )
-                with server_engine.connect() as conn:
-                    t1 = f"CREATE DATABASE IF NOT EXISTS {db_name} "
-                    t2 = "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
-                    conn.execute(text(t1 + t2))
-                    logger.success(
-                        f"[MySQL] 数据库 {db_name} 创建成功或已存在!"
-                    )
-            elif _db_type == 'postgresql':
-                server_engine = create_engine(
-                    f'{sync_url}{db_url}', **db_config
-                )
-                with server_engine.connect() as conn:
-                    t1 = f"CREATE DATABASE {db_name} WITH ENCODING 'UTF8' "
-                    t2 = "LC_COLLATE 'en_US.UTF-8' LC_CTYPE 'en_US.UTF-8'"
-                    conn.execute(text(t1 + t2))
-                    logger.success(
-                        f"[PostgreSQL] 数据库 {db_name} 创建成功或已存在!"
-                    )
-        finally:
-            if server_engine is not None:
-                server_engine.dispose()
-                logger.debug("[SQL] 同步数据库引擎已释放")
-
-        db_config['poolclass'] = NullPool
-        finally_url = f'{base_url}{db_url}{db_name}'
-        engine = create_async_engine(finally_url, **db_config)
-
-    async_maker = async_sessionmaker(
-        engine,
-        expire_on_commit=False,
-        class_=AsyncSession,
-    )
-except:  # noqa: E722
-    raise ValueError(
-        f'[GsCore] [数据库] [{base_url}] 连接失败, 请检查配置文件!'
-    )
+async def init_database():
+    global engine, finally_url, async_maker
+
+    try:
+        if _db_type == 'sqlite':
+            engine = create_async_engine(f'{base_url}{db_url}', **db_config)
+            finally_url = f'{base_url}{db_url}'
+        else:
+            db_config.update(
+                {
+                    'pool_size': db_pool_size,
+                    'max_overflow': 10,
+                    'pool_timeout': 30,
+                    'isolation_level': "AUTOCOMMIT",
+                }
+            )
+            try:
+                server_engine = None
+                if _db_type == 'mysql':
+                    server_engine = create_engine(
+                        f'{sync_url}{db_url}', **db_config
+                    )
+                    with server_engine.connect() as conn:
+                        t1 = f"CREATE DATABASE IF NOT EXISTS {db_name} "
+                        t2 = "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
+                        conn.execute(text(t1 + t2))
+                        logger.success(
+                            f"[MySQL] 数据库 {db_name} 创建成功或已存在!"
+                        )
+                elif _db_type == 'postgresql':
+                    try:
+                        server_engine = create_engine(
+                            f'{sync_url}{db_url}', **db_config
+                        )
+                        with server_engine.connect() as conn:
+                            t = f"CREATE DATABASE {db_name} WITH ENCODING "
+                            t2 = "'UTF8' LC_COLLATE 'en_US.UTF-8' LC_CTYPE "
+                            t3 = "'en_US.UTF-8' TEMPLATE template0"
+                            conn.execute(text(t + t2 + t3))
+                    except exc.ProgrammingError as e:
+                        if 'already exists' in str(e) or '已经存在' in str(e):
+                            pass
+                    logger.success(
+                        f"[PostgreSQL] 数据库 {db_name} 创建成功或已存在!"
+                    )
+            finally:
+                if server_engine:
+                    server_engine.dispose()
+                    logger.info('[数据库] 临时数据库连接已释放!')
+
+            # db_config['poolclass'] = NullPool
+            finally_url = f'{base_url}{db_url}{db_name}'
+            engine = create_async_engine(finally_url, **db_config)
+
+        async_maker = async_sessionmaker(
+            engine,
+            expire_on_commit=False,
+            close_resets_only=False,
+            class_=AsyncSession,
+        )
+    except Exception as e:  # noqa: E722
+        logger.exception(f'[GsCore] [数据库] 连接失败: {e}')
+        raise ValueError(
+            f'[GsCore] [数据库] [{base_url}] 连接失败, 请检查配置文件!'
+        )
 
 
 def with_session(
@@ -132,14 +155,16 @@ def with_session(
 ) -> Callable[Concatenate[Any, P], Awaitable[R]]:
     @wraps(func)
     async def wrapper(self, *args: P.args, **kwargs: P.kwargs):
-        async with async_maker() as session:
+        max_retries = 3
+        for attempt in range(max_retries):
             try:
-                data = await func(self, session, *args, **kwargs)
-                await session.commit()
-                return data
+                async with async_maker() as session:
+                    data = await func(self, session, *args, **kwargs)
+                    await session.commit()
+                    return data
             except Exception as e:
-                print(e)
-                raise e
+                logger.exception(f"[数据库] 第 {attempt + 1} 次重试失败: {e}")
+                continue
 
     return wrapper  # type: ignore

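The reworked with_session wrapper retries a failed database call up to three times, logging each failure, and as written falls through returning None once the retries are exhausted. A self-contained sketch of that retry-around-an-async-operation pattern, kept free of SQLAlchemy so it runs anywhere; with_retries and flaky_query are illustrative names:

import asyncio
import logging
from functools import wraps
from typing import Any, Awaitable, Callable

logger = logging.getLogger('db-retry-sketch')


def with_retries(max_retries: int = 3):
    """Retry an async operation and log each failure, mirroring the wrapper above."""

    def decorator(func: Callable[..., Awaitable[Any]]):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    logger.exception('attempt %d failed: %s', attempt + 1, e)
                    continue
            # Like the real wrapper, fall through (returning None) once retries run out.
            return None

        return wrapper

    return decorator


_calls = {'n': 0}


@with_retries(max_retries=3)
async def flaky_query() -> str:
    """Stand-in for a session-backed query: fails twice, then succeeds."""
    _calls['n'] += 1
    if _calls['n'] <= 2:
        raise RuntimeError('transient database error')
    return 'row'


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    print(asyncio.run(flaky_query()))  # two logged failures, then prints 'row'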
View File

@@ -1,12 +1,8 @@
 import re
-import asyncio
 from typing import Dict, List, Literal, Optional
 
-from sqlmodel import SQLModel
-from sqlalchemy.sql import text
+from .base_models import async_maker
 
 from .utils import SERVER, SR_SERVER
-from .base_models import engine, async_maker
 from .models import GsBind, GsPush, GsUser, GsCache
@@ -15,40 +11,6 @@ class SQLA:
         self.bot_id = bot_id
         self.is_sr = is_sr
 
-    def create_all(self):
-        try:
-            asyncio.create_task(self._create_all())
-        except RuntimeError:
-            loop = asyncio.get_event_loop()
-            loop.run_until_complete(self._create_all())
-            loop.close()
-
-    async def _create_all(self):
-        async with engine.begin() as conn:
-            await conn.run_sync(SQLModel.metadata.create_all)
-            await self.sr_adapter()
-
-    async def sr_adapter(self):
-        exec_list = [
-            'ALTER TABLE GsBind ADD COLUMN group_id TEXT',
-            'ALTER TABLE GsBind ADD COLUMN sr_uid TEXT',
-            'ALTER TABLE GsUser ADD COLUMN sr_uid TEXT',
-            'ALTER TABLE GsUser ADD COLUMN sr_region TEXT',
-            'ALTER TABLE GsUser ADD COLUMN fp TEXT',
-            'ALTER TABLE GsUser ADD COLUMN device_id TEXT',
-            'ALTER TABLE GsUser ADD COLUMN sr_sign_switch TEXT DEFAULT "off"',
-            'ALTER TABLE GsUser ADD COLUMN sr_push_switch TEXT DEFAULT "off"',
-            'ALTER TABLE GsUser ADD COLUMN draw_switch TEXT DEFAULT "off"',
-            'ALTER TABLE GsCache ADD COLUMN sr_uid TEXT',
-        ]
-        async with async_maker() as session:
-            for _t in exec_list:
-                try:
-                    await session.execute(text(_t))
-                    await session.commit()
-                except:  # noqa: E722
-                    pass
-
     #####################
     # GsBind 部分 #
     #####################
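The removed sr_adapter also documents the project's lightweight migration style: run a list of ALTER TABLE statements and silently skip any that fail because the column already exists (the same statements now live in the startup exec_list). A minimal sketch of that best-effort pattern against a throwaway SQLite database, independent of GsCore's models:

import sqlite3

# Best-effort column migration: each ALTER TABLE is attempted once and errors
# (typically "duplicate column name") are ignored, so re-running is harmless.
exec_list = [
    'ALTER TABLE GsUser ADD COLUMN sr_uid TEXT',
    'ALTER TABLE GsUser ADD COLUMN draw_switch TEXT DEFAULT "off"',
]

with sqlite3.connect(':memory:') as conn:
    conn.execute('CREATE TABLE GsUser (user_id TEXT PRIMARY KEY)')
    for statement in exec_list:
        try:
            conn.execute(statement)
            conn.commit()
        except sqlite3.OperationalError:
            pass  # column already exists (or the backend rejects the DDL): skip

    columns = [row[1] for row in conn.execute('PRAGMA table_info(GsUser)')]
    print(columns)  # ['user_id', 'sr_uid', 'draw_switch']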
@@ -416,3 +378,4 @@ class SQLA:
 
     async def insert_new_user(self, **kwargs):
         await GsUser.full_insert_data(**kwargs)
+        await GsUser.full_insert_data(**kwargs)

View File

@@ -31,7 +31,7 @@ async def move_database():
 
 # @on_core_start
-async def sr_adapter():
+async def trans_adapter():
     async with engine.begin() as conn:
         metadata = MetaData()
         try:

View File

@@ -9,6 +9,12 @@ DATABASE_CONIFG: Dict[str, GSC] = {
         'SQLite',
         ['SQLite', 'MySql', 'PostgreSQL', '自定义'],
     ),
+    'db_driver': GsStrConfig(
+        'MySQL驱动',
+        '设置喜欢的MySQL驱动',
+        'aiomysql',
+        ['aiomysql', 'asyncmy'],
+    ),
     'db_custom_url': GsStrConfig(
         '自定义数据库连接地址 (一般无需填写)',
         '设置自定义数据库连接',
@@ -31,7 +37,7 @@ DATABASE_CONIFG: Dict[str, GSC] = {
         '数据库用户名',
         '设置数据库用户名',
         'root',
-        ['root', 'admin'],
+        ['root', 'admin', 'postgres'],
     ),
     'db_password': GsStrConfig(
         '数据库密码',

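A small illustration of what the new db_driver option means downstream (see base_models.py above, where it is spliced into the async URL). Whether asyncmy is usable depends on it being installed separately, which is an assumption here:

# Illustration only: the async dialect prefix produced for each selectable MySQL driver.
for db_driver in ('aiomysql', 'asyncmy'):
    print(f"db_driver={db_driver!r} -> mysql+{db_driver}://")
# db_driver='aiomysql' -> mysql+aiomysql://
# db_driver='asyncmy'  -> mysql+asyncmy://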
View File

@@ -54,6 +54,10 @@ pic_expire_time = core_plugins_config.get_config('ScheduledCleanPicSrv').data
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     try:
+        logger.info(
+            '[GsCore] 执行启动Hook函数中',
+            [_def.__name__ for _def in core_start_def],
+        )
         _task = [_def() for _def in core_start_def]
         await asyncio.gather(*_task)
     except Exception as e:
@@ -64,9 +68,16 @@ async def lifespan(app: FastAPI):
     await start_check()  # type:ignore
     await start_scheduler()
     asyncio.create_task(clean_log())
+
     yield
+
     await shutdown_scheduler()
+
     try:
+        logger.info(
+            '[GsCore] 执行关闭Hook函数中',
+            [_def.__name__ for _def in core_shutdown_def],
+        )
         _task = [_def() for _def in core_shutdown_def]
         await asyncio.gather(*_task)
     except Exception as e: