From 10b666d6919f8904b88fa2582e43cc1de8d8d3a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98KimigaiiWuyi=E2=80=99?= <444835641@qq.com> Date: Thu, 27 Apr 2023 17:42:37 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20=E6=8F=90=E4=BE=9B`DBSqla`?= =?UTF-8?q?=E7=B1=BB=20(KimigaiiWuyi/GenshinUID#526)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gsuid_core/utils/database/api.py | 39 ++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/gsuid_core/utils/database/api.py b/gsuid_core/utils/database/api.py index 17611de..2e63e7c 100644 --- a/gsuid_core/utils/database/api.py +++ b/gsuid_core/utils/database/api.py @@ -8,20 +8,35 @@ from gsuid_core.utils.database.dal import SQLA is_wal = False active_sqla: Dict[str, SQLA] = {} +active_sr_sqla: Dict[str, SQLA] = {} db_url = str(get_res_path().parent / 'GsData.db') -def get_sqla(bot_id) -> SQLA: - if bot_id not in active_sqla: - sqla = SQLA(db_url, bot_id) - active_sqla[bot_id] = sqla - sqla.create_all() +class DBSqla: + def __init__(self, is_sr: bool = False) -> None: + self.is_sr = is_sr - @event.listens_for(sqla.engine.sync_engine, 'connect') - def engine_connect(conn, branch): - if is_wal: - cursor = conn.cursor() - cursor.execute('PRAGMA journal_mode=WAL') - cursor.close() + def get_sqla(self, bot_id) -> SQLA: + return self._get_sqla(bot_id, self.is_sr) - return active_sqla[bot_id] + def _get_sqla(self, bot_id, is_sr: bool = False) -> SQLA: + sqla_list = active_sr_sqla if is_sr else active_sqla + if bot_id not in sqla_list: + sqla = SQLA(db_url, bot_id, is_sr) + sqla_list[bot_id] = sqla + sqla.create_all() + + @event.listens_for(sqla.engine.sync_engine, 'connect') + def engine_connect(conn, branch): + if is_wal: + cursor = conn.cursor() + cursor.execute('PRAGMA journal_mode=WAL') + cursor.close() + + return sqla_list[bot_id] + + def get_gs_sqla(self, bot_id): + return self._get_sqla(bot_id, False) + + def get_sr_sqla(self, bot_id): + return self._get_sqla(bot_id, True)