diff --git a/gsuid_core/bot.py b/gsuid_core/bot.py index c066fc8..9845056 100644 --- a/gsuid_core/bot.py +++ b/gsuid_core/bot.py @@ -117,7 +117,13 @@ class _Bot: if forward_m.type != 'image_size': message_result.append([forward_m]) elif enable_forward == '合并为一条消息': - _temp_mr.extend(_m.data) + _add = [] + for index, forward_m in enumerate(_m.data): + _add.append(forward_m) + if index < len(_m.data) - 1: + _add.append(MessageSegment.text('\n')) + _temp_mr.extend(_add) + elif enable_forward.isdigit(): for forward_m in _m.data[: int(enable_forward)]: if forward_m.type != 'image_size': diff --git a/gsuid_core/core.py b/gsuid_core/core.py index 55012f5..52ed1ce 100644 --- a/gsuid_core/core.py +++ b/gsuid_core/core.py @@ -12,52 +12,26 @@ from fastapi import WebSocket, WebSocketDisconnect sys.path.append(str(Path(__file__).resolve().parent)) sys.path.append(str(Path(__file__).resolve().parents[1])) -from gsuid_core.gss import gss # noqa: E402 -from gsuid_core.bot import _Bot # noqa: E402 -from gsuid_core.logger import logger # noqa: E402 -from gsuid_core.web_app import app, site # noqa: E402 -from gsuid_core.config import core_config # noqa: E402 -from gsuid_core.handler import handle_event # noqa: E402 -from gsuid_core.models import MessageReceive # noqa: E402 -from gsuid_core.utils.database.startup import exec_list # noqa: E402 - -HOST = core_config.get_config('HOST') -PORT = int(core_config.get_config('PORT')) -ENABLE_HTTP = core_config.get_config('ENABLE_HTTP') -HTTP_SERVER_STATUS = False - -exec_list.extend( - [ - 'ALTER TABLE GsBind ADD COLUMN group_id TEXT', - 'ALTER TABLE GsBind ADD COLUMN sr_uid TEXT', - 'ALTER TABLE GsUser ADD COLUMN sr_uid TEXT', - 'ALTER TABLE GsUser ADD COLUMN sr_region TEXT', - 'ALTER TABLE GsUser ADD COLUMN zzz_region TEXT', - 'ALTER TABLE GsUser ADD COLUMN bb_region TEXT', - 'ALTER TABLE GsUser ADD COLUMN bbb_region TEXT', - 'ALTER TABLE GsUser ADD COLUMN wd_region TEXT', - 'ALTER TABLE GsUser ADD COLUMN fp TEXT', - 'ALTER TABLE GsUser ADD COLUMN device_id TEXT', - 'ALTER TABLE GsUser ADD COLUMN bb_uid TEXT', - 'ALTER TABLE GsUser ADD COLUMN bbb_uid TEXT', - 'ALTER TABLE GsUser ADD COLUMN zzz_uid TEXT', - 'ALTER TABLE GsUser ADD COLUMN wd_uid TEXT', - 'ALTER TABLE GsBind ADD COLUMN bb_uid TEXT', - 'ALTER TABLE GsBind ADD COLUMN bbb_uid TEXT', - 'ALTER TABLE GsBind ADD COLUMN zzz_uid TEXT', - 'ALTER TABLE GsBind ADD COLUMN wd_uid TEXT', - 'ALTER TABLE GsUser ADD COLUMN device_info TEXT', - 'ALTER TABLE GsUser ADD COLUMN sr_sign_switch TEXT DEFAULT "off"', - 'ALTER TABLE GsUser ADD COLUMN zzz_sign_switch TEXT DEFAULT "off"', - 'ALTER TABLE GsUser ADD COLUMN sr_push_switch TEXT DEFAULT "off"', - 'ALTER TABLE GsUser ADD COLUMN zzz_push_switch TEXT DEFAULT "off"', - 'ALTER TABLE GsUser ADD COLUMN draw_switch TEXT DEFAULT "off"', - 'ALTER TABLE GsCache ADD COLUMN sr_uid TEXT', - ] -) +# from gsuid_core.utils.database.startup import exec_list # noqa: E402 -def main(): +async def main(): + from gsuid_core.utils.database.base_models import init_database + + await init_database() + + from gsuid_core.gss import gss # noqa: E402 + from gsuid_core.bot import _Bot # noqa: E402 + from gsuid_core.logger import logger # noqa: E402 + from gsuid_core.web_app import app, site # noqa: E402 + from gsuid_core.config import core_config # noqa: E402 + from gsuid_core.handler import handle_event # noqa: E402 + from gsuid_core.models import MessageReceive # noqa: E402 + + HOST = core_config.get_config('HOST') + PORT = int(core_config.get_config('PORT')) + ENABLE_HTTP = core_config.get_config('ENABLE_HTTP') + @app.websocket('/ws/{bot_id}') async def websocket_endpoint(websocket: WebSocket, bot_id: str): try: @@ -87,11 +61,6 @@ def main(): @app.post('/api/send_msg') async def sendMsg(msg: Dict): - global HTTP_SERVER_STATUS - if not HTTP_SERVER_STATUS: - asyncio.create_task(_bot._process()) - HTTP_SERVER_STATUS = True - data = msgjson.encode(msg) MR = msgjson.Decoder(MessageReceive).decode(data) result = await handle_event(_bot, MR, True) @@ -103,7 +72,7 @@ def main(): site.gen_plugin_page() site.mount_app(app) - uvicorn.run( + config = uvicorn.Config( app, host=HOST, port=PORT, @@ -122,9 +91,11 @@ def main(): 'level': 'INFO', }, }, - }, + }, # 你的日志配置 + loop="asyncio", ) + server = uvicorn.Server(config) + await server.serve() -if __name__ == '__main__': - main() +asyncio.run(main()) diff --git a/gsuid_core/server.py b/gsuid_core/server.py index 26eaf02..992931f 100644 --- a/gsuid_core/server.py +++ b/gsuid_core/server.py @@ -20,8 +20,8 @@ from gsuid_core.utils.plugins_config.gs_config import core_plugins_config auto_install_dep: bool = core_plugins_config.get_config('AutoInstallDep').data auto_update_dep: bool = core_plugins_config.get_config('AutoUpdateDep').data -core_start_def = set() -core_shutdown_def = set() +core_start_def: set[Callable] = set() +core_shutdown_def: set[Callable] = set() installed_dependencies: Dict[str, str] = {} ignore_dep = ['python', 'fastapi', 'pydantic'] diff --git a/gsuid_core/utils/database/api.py b/gsuid_core/utils/database/api.py index 3a30a95..ca1891a 100644 --- a/gsuid_core/utils/database/api.py +++ b/gsuid_core/utils/database/api.py @@ -1,5 +1,4 @@ import re -import asyncio from typing import Dict, Type, Tuple, Union, Literal, Optional, overload from sqlalchemy import event @@ -21,7 +20,6 @@ class DBSqla: def get_sqla(self, bot_id) -> SQLA: sqla = self._get_sqla(bot_id, self.is_sr) - asyncio.create_task(sqla.sr_adapter()) return sqla def _get_sqla(self, bot_id, is_sr: bool = False) -> SQLA: @@ -29,7 +27,6 @@ class DBSqla: if bot_id not in sqla_list: sqla = SQLA(bot_id, is_sr) sqla_list[bot_id] = sqla - sqla.create_all() @event.listens_for(engine.sync_engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): diff --git a/gsuid_core/utils/database/base_models.py b/gsuid_core/utils/database/base_models.py index 7cbe4d3..1c03f1f 100644 --- a/gsuid_core/utils/database/base_models.py +++ b/gsuid_core/utils/database/base_models.py @@ -11,8 +11,8 @@ from typing import ( Awaitable, ) -from sqlalchemy.pool import NullPool -from sqlalchemy import text, create_engine +# from sqlalchemy.pool import NullPool +from sqlalchemy import exc, text, create_engine from sqlalchemy.sql.expression import func, null, true from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker # type: ignore @@ -45,28 +45,32 @@ db_pool_recycle: int = database_config.get_config('db_pool_recycle').data db_custom_url = database_config.get_config('db_custom_url').data db_type: str = database_config.get_config('db_type').data +db_driver: str = database_config.get_config('db_driver').data + _db_type = db_type.lower() db_config = { 'pool_recycle': db_pool_recycle, - # 'pool_pre_ping': True, - # 'pool_size': db_pool_size, 'echo': db_echo, } DB_PATH = get_res_path() / 'GsData.db' -sync_url = '' + +sync_url, engine, finally_url = '', '', '' +async_maker: async_sessionmaker[AsyncSession] = None # type: ignore +server_engine = None if _db_type == 'sqlite': + sync_url = 'sqlite:///' base_url = 'sqlite+aiosqlite:///' db_url = str(DB_PATH) # del db_config['pool_size'] elif _db_type == 'mysql': sync_url = 'mysql+pymysql://' - base_url = 'mysql+aiomysql://' + base_url = f'mysql+{db_driver}://' db_hp = f'{db_host}:{db_port}' if db_port else db_host db_url = f'{db_user}:{db_password}@{db_hp}/' elif _db_type == 'postgresql': - sync_url = 'postgresql+psycopg2://' + sync_url = 'postgresql+psycopg://' base_url = 'postgresql+asyncpg://' db_hp = f'{db_host}:{db_port}' if db_port else db_host db_url = f'{db_user}:{db_password}@{db_hp}/' @@ -77,54 +81,73 @@ else: base_url = db_type db_url = db_custom_url -try: - if _db_type == 'sqlite': - engine = create_async_engine(f'{base_url}{db_url}', **db_config) - finally_url = f'{base_url}{db_url}' - else: - server_engine = None - try: - if _db_type == 'mysql': - server_engine = create_engine( - f'{sync_url}{db_url}', **db_config - ) - with server_engine.connect() as conn: - t1 = f"CREATE DATABASE IF NOT EXISTS {db_name} " - t2 = "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" - conn.execute(text(t1 + t2)) - logger.success( - f"[MySQL] 数据库 {db_name} 创建成功或已存在!" +async def init_database(): + global engine, finally_url, async_maker + + try: + if _db_type == 'sqlite': + engine = create_async_engine(f'{base_url}{db_url}', **db_config) + finally_url = f'{base_url}{db_url}' + else: + db_config.update( + { + 'pool_size': db_pool_size, + 'max_overflow': 10, + 'pool_timeout': 30, + 'isolation_level': "AUTOCOMMIT", + } + ) + try: + server_engine = None + if _db_type == 'mysql': + server_engine = create_engine( + f'{sync_url}{db_url}', **db_config ) - elif _db_type == 'postgresql': - server_engine = create_engine( - f'{sync_url}{db_url}', **db_config - ) - with server_engine.connect() as conn: - t1 = f"CREATE DATABASE {db_name} WITH ENCODING 'UTF8' " - t2 = "LC_COLLATE 'en_US.UTF-8' LC_CTYPE 'en_US.UTF-8'" - conn.execute(text(t1 + t2)) - logger.success( - f"[PostgreSQL] 数据库 {db_name} 创建成功或已存在!" - ) - finally: - if server_engine is not None: - server_engine.dispose() - logger.debug("[SQL] 同步数据库引擎已释放") - db_config['poolclass'] = NullPool - finally_url = f'{base_url}{db_url}{db_name}' - engine = create_async_engine(finally_url, **db_config) + with server_engine.connect() as conn: + t1 = f"CREATE DATABASE IF NOT EXISTS {db_name} " + t2 = "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" + conn.execute(text(t1 + t2)) + logger.success( + f"[MySQL] 数据库 {db_name} 创建成功或已存在!" + ) + elif _db_type == 'postgresql': + try: + server_engine = create_engine( + f'{sync_url}{db_url}', **db_config + ) + with server_engine.connect() as conn: + t = f"CREATE DATABASE {db_name} WITH ENCODING " + t2 = "'UTF8' LC_COLLATE 'en_US.UTF-8' LC_CTYPE " + t3 = "'en_US.UTF-8' TEMPLATE template0" + conn.execute(text(t + t2 + t3)) + except exc.ProgrammingError as e: + if 'already exists' in str(e) or '已经存在' in str(e): + pass + logger.success( + f"[PostgreSQL] 数据库 {db_name} 创建成功或已存在!" + ) + finally: + if server_engine: + server_engine.dispose() + logger.info('[数据库] 临时数据库连接已释放!') - async_maker = async_sessionmaker( - engine, - expire_on_commit=False, - class_=AsyncSession, - ) -except: # noqa: E722 - raise ValueError( - f'[GsCore] [数据库] [{base_url}] 连接失败, 请检查配置文件!' - ) + # db_config['poolclass'] = NullPool + finally_url = f'{base_url}{db_url}{db_name}' + engine = create_async_engine(finally_url, **db_config) + + async_maker = async_sessionmaker( + engine, + expire_on_commit=False, + close_resets_only=False, + class_=AsyncSession, + ) + except Exception as e: # noqa: E722 + logger.exception(f'[GsCore] [数据库] 连接失败: {e}') + raise ValueError( + f'[GsCore] [数据库] [{base_url}] 连接失败, 请检查配置文件!' + ) def with_session( @@ -132,14 +155,16 @@ def with_session( ) -> Callable[Concatenate[Any, P], Awaitable[R]]: @wraps(func) async def wrapper(self, *args: P.args, **kwargs: P.kwargs): - async with async_maker() as session: + max_retries = 3 + for attempt in range(max_retries): try: - data = await func(self, session, *args, **kwargs) - await session.commit() - return data + async with async_maker() as session: + data = await func(self, session, *args, **kwargs) + await session.commit() + return data except Exception as e: - print(e) - raise e + logger.exception(f"[数据库] 第 {attempt + 1} 次重试失败: {e}") + continue return wrapper # type: ignore diff --git a/gsuid_core/utils/database/dal.py b/gsuid_core/utils/database/dal.py index f25b0ec..6edcbd6 100644 --- a/gsuid_core/utils/database/dal.py +++ b/gsuid_core/utils/database/dal.py @@ -1,12 +1,8 @@ import re -import asyncio from typing import Dict, List, Literal, Optional -from sqlmodel import SQLModel -from sqlalchemy.sql import text - +from .base_models import async_maker from .utils import SERVER, SR_SERVER -from .base_models import engine, async_maker from .models import GsBind, GsPush, GsUser, GsCache @@ -15,40 +11,6 @@ class SQLA: self.bot_id = bot_id self.is_sr = is_sr - def create_all(self): - try: - asyncio.create_task(self._create_all()) - except RuntimeError: - loop = asyncio.get_event_loop() - loop.run_until_complete(self._create_all()) - loop.close() - - async def _create_all(self): - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - await self.sr_adapter() - - async def sr_adapter(self): - exec_list = [ - 'ALTER TABLE GsBind ADD COLUMN group_id TEXT', - 'ALTER TABLE GsBind ADD COLUMN sr_uid TEXT', - 'ALTER TABLE GsUser ADD COLUMN sr_uid TEXT', - 'ALTER TABLE GsUser ADD COLUMN sr_region TEXT', - 'ALTER TABLE GsUser ADD COLUMN fp TEXT', - 'ALTER TABLE GsUser ADD COLUMN device_id TEXT', - 'ALTER TABLE GsUser ADD COLUMN sr_sign_switch TEXT DEFAULT "off"', - 'ALTER TABLE GsUser ADD COLUMN sr_push_switch TEXT DEFAULT "off"', - 'ALTER TABLE GsUser ADD COLUMN draw_switch TEXT DEFAULT "off"', - 'ALTER TABLE GsCache ADD COLUMN sr_uid TEXT', - ] - async with async_maker() as session: - for _t in exec_list: - try: - await session.execute(text(_t)) - await session.commit() - except: # noqa: E722 - pass - ##################### # GsBind 部分 # ##################### @@ -416,3 +378,4 @@ class SQLA: async def insert_new_user(self, **kwargs): await GsUser.full_insert_data(**kwargs) + await GsUser.full_insert_data(**kwargs) diff --git a/gsuid_core/utils/database/startup.py b/gsuid_core/utils/database/startup.py index 32598c9..17beaab 100644 --- a/gsuid_core/utils/database/startup.py +++ b/gsuid_core/utils/database/startup.py @@ -31,7 +31,7 @@ async def move_database(): # @on_core_start -async def sr_adapter(): +async def trans_adapter(): async with engine.begin() as conn: metadata = MetaData() try: diff --git a/gsuid_core/utils/plugins_config/database_config.py b/gsuid_core/utils/plugins_config/database_config.py index 95a43c2..7a85e71 100644 --- a/gsuid_core/utils/plugins_config/database_config.py +++ b/gsuid_core/utils/plugins_config/database_config.py @@ -9,6 +9,12 @@ DATABASE_CONIFG: Dict[str, GSC] = { 'SQLite', ['SQLite', 'MySql', 'PostgreSQL', '自定义'], ), + 'db_driver': GsStrConfig( + 'MySQL驱动', + '设置喜欢的MySQL驱动', + 'aiomysql', + ['aiomysql', 'asyncmy'], + ), 'db_custom_url': GsStrConfig( '自定义数据库连接地址 (一般无需填写)', '设置自定义数据库连接', @@ -31,7 +37,7 @@ DATABASE_CONIFG: Dict[str, GSC] = { '数据库用户名', '设置数据库用户名', 'root', - ['root', 'admin'], + ['root', 'admin', 'postgres'], ), 'db_password': GsStrConfig( '数据库密码', diff --git a/gsuid_core/web_app.py b/gsuid_core/web_app.py index 96cf250..61fa076 100644 --- a/gsuid_core/web_app.py +++ b/gsuid_core/web_app.py @@ -54,6 +54,10 @@ pic_expire_time = core_plugins_config.get_config('ScheduledCleanPicSrv').data @asynccontextmanager async def lifespan(app: FastAPI): try: + logger.info( + '[GsCore] 执行启动Hook函数中!', + [_def.__name__ for _def in core_start_def], + ) _task = [_def() for _def in core_start_def] await asyncio.gather(*_task) except Exception as e: @@ -64,9 +68,16 @@ async def lifespan(app: FastAPI): await start_check() # type:ignore await start_scheduler() asyncio.create_task(clean_log()) + yield + await shutdown_scheduler() + try: + logger.info( + '[GsCore] 执行关闭Hook函数中!', + [_def.__name__ for _def in core_shutdown_def], + ) _task = [_def() for _def in core_shutdown_def] await asyncio.gather(*_task) except Exception as e: