diff --git a/GenshinUID/genshinuid_check/__init__.py b/GenshinUID/genshinuid_check/__init__.py index 36d0f40f..b0d2e154 100644 --- a/GenshinUID/genshinuid_check/__init__.py +++ b/GenshinUID/genshinuid_check/__init__.py @@ -9,7 +9,7 @@ from gsuid_core.aps import scheduler from ..utils.mys_api import mys_api from .backup_data import data_backup -from ..utils.database import active_sqla +from ..utils.database import get_sqla from ..gsuid_utils.database.models import GsUser @@ -26,7 +26,7 @@ async def send_backup_msg(bot: Bot): @SV('数据管理', pm=2).on_fullmatch(('校验全部Cookies')) async def send_check_cookie(bot: Bot, ev: Event): - user_list = await active_sqla[bot.bot_id].get_all_user() + user_list = await get_sqla(bot.bot_id).get_all_user() invalid_user: List[GsUser] = [] for user in user_list: if user.cookie and user.mys_id: @@ -36,7 +36,7 @@ async def send_check_cookie(bot: Bot, ev: Event): True if int(user.uid[0]) > 5 else False, ) if isinstance(mys_data, int): - await active_sqla[bot.bot_id].delete_user_data(user.uid) + await get_sqla(bot.bot_id).delete_user_data(user.uid) invalid_user.append(user) continue for i in mys_data: @@ -76,7 +76,7 @@ async def send_check_cookie(bot: Bot, ev: Event): @SV('数据管理', pm=2).on_fullmatch(('校验全部Stoken')) async def send_check_stoken(bot: Bot, ev: Event): - user_list = await active_sqla[bot.bot_id].get_all_user() + user_list = await get_sqla(bot.bot_id).get_all_user() invalid_user: List[GsUser] = [] for user in user_list: if user.stoken and user.mys_id: @@ -85,9 +85,7 @@ async def send_check_stoken(bot: Bot, ev: Event): user.mys_id, ) if isinstance(mys_data, int): - await active_sqla[bot.bot_id].update_user_stoken( - user.uid, None - ) + await get_sqla(bot.bot_id).update_user_stoken(user.uid, None) invalid_user.append(user) continue if len(user_list) > 4: diff --git a/GenshinUID/genshinuid_check/backup_data.py b/GenshinUID/genshinuid_check/backup_data.py index ff70a06f..97e9a3cd 100644 --- a/GenshinUID/genshinuid_check/backup_data.py +++ b/GenshinUID/genshinuid_check/backup_data.py @@ -4,7 +4,7 @@ from shutil import copyfile from nonebot.log import logger -from ..utils.database import active_sqla +from ..utils.database import get_sqla from ..utils.resource.RESOURCE_PATH import TEMP_PATH @@ -24,8 +24,10 @@ async def data_backup(): f.unlink() except OSError as e: print("Error: %s : %s" % (f, e.strerror)) - for bot_id in active_sqla: - await active_sqla[bot_id].delete_cache() + sqla = get_sqla('TEMP') + await sqla.delete_cache() + await sqla.close() + del sqla logger.info('————缓存成功清除————') except Exception: logger.info('————数据库备份失败————') diff --git a/GenshinUID/genshinuid_config/__init__.py b/GenshinUID/genshinuid_config/__init__.py index 15784e8e..5786387a 100644 --- a/GenshinUID/genshinuid_config/__init__.py +++ b/GenshinUID/genshinuid_config/__init__.py @@ -4,7 +4,7 @@ from gsuid_core.sv import SV from gsuid_core.bot import Bot from gsuid_core.models import Event -from ..utils.database import active_sqla +from ..utils.database import get_sqla from ..utils.error_reply import UID_HINT from .draw_config_card import draw_config_img from .set_config import set_push_value, set_config_func @@ -21,7 +21,7 @@ async def send_config_card(bot: Bot, ev: Event): async def send_config_ev(bot: Bot, ev: Event): await bot.logger.info('开始执行[设置阈值信息]') - sqla = active_sqla[ev.bot_id] + sqla = get_sqla(ev.bot_id) uid = await sqla.get_bind_uid(ev.user_id) if uid is None: return await bot.send(UID_HINT) @@ -41,7 +41,7 @@ async def send_config_ev(bot: Bot, ev: Event): # 开启 自动签到 和 推送树脂提醒 功能 @SV('原神配置').on_prefix(('gs开启', 'gs关闭')) async def open_switch_func(bot: Bot, ev: Event): - sqla = active_sqla[ev.bot_id] + sqla = get_sqla(ev.bot_id) user_id = ev.user_id config_name = ev.text diff --git a/GenshinUID/genshinuid_config/draw_config_card.py b/GenshinUID/genshinuid_config/draw_config_card.py index f5933714..4560a5d9 100644 --- a/GenshinUID/genshinuid_config/draw_config_card.py +++ b/GenshinUID/genshinuid_config/draw_config_card.py @@ -5,7 +5,7 @@ from typing import Union from nonebot.log import logger from PIL import Image, ImageDraw -from ..utils.database import active_sqla +from ..utils.database import get_sqla from .config_default import CONIFG_DEFAULT from ..utils.image.convert import convert_img from ..utils.image.image_tools import CustomizeImage @@ -22,7 +22,7 @@ second_color = (57, 57, 57) async def draw_config_img(bot_id: str) -> Union[bytes, str]: - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) # 获取背景图片各项参数 based_w = 850 based_h = 850 + 155 * (len(CONIFG_DEFAULT) - 5) diff --git a/GenshinUID/genshinuid_config/set_config.py b/GenshinUID/genshinuid_config/set_config.py index 405fc970..0f60720b 100644 --- a/GenshinUID/genshinuid_config/set_config.py +++ b/GenshinUID/genshinuid_config/set_config.py @@ -3,7 +3,7 @@ from typing import Optional from nonebot.log import logger from .gs_config import gsconfig -from ..utils.database import active_sqla +from ..utils.database import get_sqla from .config_default import CONIFG_DEFAULT PUSH_MAP = { @@ -20,7 +20,7 @@ PRIV_MAP = { async def set_push_value(bot_id: str, func: str, uid: str, value: int): - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) if func in PUSH_MAP: status = PUSH_MAP[func] else: @@ -41,7 +41,7 @@ async def set_config_func( query: Optional[bool] = None, is_admin: bool = False, ): - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) # 这里将传入的中文config_name转换为英文status for _name in CONIFG_DEFAULT: config = CONIFG_DEFAULT[_name] diff --git a/GenshinUID/genshinuid_note/__init__.py b/GenshinUID/genshinuid_note/__init__.py index 14517c7e..820c71c6 100644 --- a/GenshinUID/genshinuid_note/__init__.py +++ b/GenshinUID/genshinuid_note/__init__.py @@ -4,7 +4,7 @@ from gsuid_core.models import Event from .note_text import award from ..utils.convert import get_uid -from ..utils.database import active_sqla +from ..utils.database import get_sqla from ..utils.error_reply import UID_HINT from .draw_note_card import draw_note_img @@ -12,7 +12,7 @@ from .draw_note_card import draw_note_img # 群聊内 每月统计 功能 @SV('查询札记').on_fullmatch(('每月统计')) async def send_monthly_data(bot: Bot, ev: Event): - sqla = active_sqla[bot.bot_id] + sqla = get_sqla(ev.bot_id) uid = await sqla.get_bind_uid(ev.user_id) if uid is None: return UID_HINT diff --git a/GenshinUID/genshinuid_resin/draw_resin_card.py b/GenshinUID/genshinuid_resin/draw_resin_card.py index aee53a82..152da557 100644 --- a/GenshinUID/genshinuid_resin/draw_resin_card.py +++ b/GenshinUID/genshinuid_resin/draw_resin_card.py @@ -7,7 +7,7 @@ from nonebot.log import logger from PIL import Image, ImageDraw from ..utils.mys_api import mys_api -from ..utils.database import active_sqla +from ..utils.database import get_sqla from ..utils.image.convert import convert_img from ..genshinuid_enka.to_data import get_enka_info from ..gsuid_utils.api.mys.models import Expedition @@ -70,7 +70,7 @@ async def _draw_task_img( async def get_resin_img(bot_id: str, user_id: str): try: - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) uid_list: List = await sqla.get_bind_uid_list(user_id) logger.info('[每日信息]UID: {}'.format(uid_list)) # 进行校验UID是否绑定CK diff --git a/GenshinUID/genshinuid_resin/notice.py b/GenshinUID/genshinuid_resin/notice.py index d1d0fe12..7093d48a 100644 --- a/GenshinUID/genshinuid_resin/notice.py +++ b/GenshinUID/genshinuid_resin/notice.py @@ -4,7 +4,7 @@ from gsuid_core.gss import gss from nonebot.log import logger from ..utils.mys_api import mys_api -from ..utils.database import active_sqla +from ..utils.database import get_sqla from ..genshinuid_config.gs_config import gsconfig from ..gsuid_utils.api.mys.models import DailyNoteData @@ -21,7 +21,7 @@ NOTICE = { async def get_notice_list() -> Dict[str, Dict[str, Dict]]: msg_dict: Dict[str, Dict[str, Dict]] = {} for bot_id in gss.active_bot: - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) user_list = await sqla.get_all_push_user_list() for user in user_list: raw_data = await mys_api.get_daily_data(user.uid) @@ -48,7 +48,7 @@ async def all_check( user_id: str, uid: str, ) -> Dict[str, Dict[str, Dict]]: - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) for mode in NOTICE.keys(): # 检查条件 if push_data[f'{mode}_is_push'] == 'on': diff --git a/GenshinUID/genshinuid_signin/__init__.py b/GenshinUID/genshinuid_signin/__init__.py index 9c442951..e63712e3 100644 --- a/GenshinUID/genshinuid_signin/__init__.py +++ b/GenshinUID/genshinuid_signin/__init__.py @@ -8,8 +8,8 @@ from gsuid_core.models import Event from gsuid_core.aps import scheduler from gsuid_core.logger import logger +from ..utils.database import get_sqla from .sign import sign_in, daily_sign -from ..utils.database import active_sqla from ..utils.error_reply import UID_HINT from ..genshinuid_config.gs_config import gsconfig @@ -25,7 +25,7 @@ async def sign_at_night(): @SV('原神签到').on_fullmatch('签到') async def get_sign_func(bot: Bot, ev: Event): await bot.logger.info('[签到]QQ号: {}'.format(ev.user_id)) - sqla = active_sqla[ev.bot_id] + sqla = get_sqla(ev.bot_id) uid = await sqla.get_bind_uid(ev.user_id) if uid is None: return await bot.send(UID_HINT) diff --git a/GenshinUID/genshinuid_signin/sign.py b/GenshinUID/genshinuid_signin/sign.py index a0da4041..cc2d53e7 100644 --- a/GenshinUID/genshinuid_signin/sign.py +++ b/GenshinUID/genshinuid_signin/sign.py @@ -2,10 +2,11 @@ import random import asyncio from copy import deepcopy +from gsuid_core.gss import gss from nonebot.log import logger from ..utils.mys_api import mys_api -from ..utils.database import active_sqla +from ..utils.database import get_sqla from ..genshinuid_config.gs_config import gsconfig private_msg_list = {} @@ -158,8 +159,8 @@ async def daily_sign(): """ global already tasks = [] - for bot_id in active_sqla: - sqla = active_sqla[bot_id] + for bot_id in gss.active_bot: + sqla = get_sqla(bot_id) user_list = await sqla.get_all_user() for user in user_list: if user.sign_switch != 'off': diff --git a/GenshinUID/genshinuid_user/__init__.py b/GenshinUID/genshinuid_user/__init__.py index 456d12bc..3d301bb4 100644 --- a/GenshinUID/genshinuid_user/__init__.py +++ b/GenshinUID/genshinuid_user/__init__.py @@ -6,7 +6,7 @@ from gsuid_core.models import Event from gsuid_core.segment import MessageSegment from .qrlogin import qrcode_login -from ..utils.database import active_sqla +from ..utils.database import get_sqla from .get_ck_help_msg import get_ck_help from ..utils.message import send_diff_msg from .draw_user_card import get_user_card @@ -30,7 +30,7 @@ async def send_refresh_ck_msg(bot: Bot, ev: Event): @SV('扫码登陆').on_fullmatch(('扫码登陆', '扫码登录')) async def send_qrcode_login(bot: Bot, ev: Event): await bot.logger.info('开始执行[扫码登陆]') - im = await qrcode_login(bot, ev.user_id) + im = await qrcode_login(bot, ev, ev.user_id) if not im: return im = await deal_ck(ev.bot_id, im, ev.user_id) @@ -57,7 +57,7 @@ async def send_link_uid_msg(bot: Bot, ev: Event): qid = ev.user_id await bot.logger.info('[绑定/解绑]UserID: {}'.format(qid)) - sqla = active_sqla[ev.bot_id] + sqla = get_sqla(ev.bot_id) uid = ev.text if ev.command.startswith('绑定'): diff --git a/GenshinUID/genshinuid_user/add_ck.py b/GenshinUID/genshinuid_user/add_ck.py index dfae703c..cd637b93 100644 --- a/GenshinUID/genshinuid_user/add_ck.py +++ b/GenshinUID/genshinuid_user/add_ck.py @@ -3,7 +3,7 @@ from typing import Dict, List from http.cookies import SimpleCookie from ..utils.mys_api import mys_api -from ..utils.database import active_sqla +from ..utils.database import get_sqla from ..utils.error_reply import UID_HINT pic_path = Path(__file__).parent / 'pic' @@ -28,7 +28,7 @@ lt_list = ['login_ticket', 'login_ticket_v2'] async def get_ck_by_all_stoken(bot_id: str): - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) uid_list: List = await sqla.get_all_uid_list() uid_dict = {} for uid in uid_list: @@ -40,7 +40,7 @@ async def get_ck_by_all_stoken(bot_id: str): async def get_ck_by_stoken(bot_id: str, user_id: str): - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) uid_list: List = await sqla.get_bind_uid_list(user_id) uid_dict = {uid: user_id for uid in uid_list} im = await refresh_ck_by_uid_list(bot_id, uid_dict) @@ -48,7 +48,7 @@ async def get_ck_by_stoken(bot_id: str, user_id: str): async def refresh_ck_by_uid_list(bot_id: str, uid_dict: Dict): - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) uid_num = len(uid_dict) if uid_num == 0: return '请先绑定一个UID噢~' @@ -113,7 +113,7 @@ async def get_account_id(simp_dict: SimpleCookie) -> str: async def _deal_ck(bot_id: str, mes: str, user_id: str) -> str: - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) simp_dict = SimpleCookie(mes) uid = await sqla.get_bind_uid(user_id) if uid is None: diff --git a/GenshinUID/genshinuid_user/draw_user_card.py b/GenshinUID/genshinuid_user/draw_user_card.py index 2c28f84a..a810a32a 100644 --- a/GenshinUID/genshinuid_user/draw_user_card.py +++ b/GenshinUID/genshinuid_user/draw_user_card.py @@ -3,7 +3,7 @@ from pathlib import Path from PIL import Image, ImageDraw -from ..utils.database import active_sqla +from ..utils.database import get_sqla from ..utils.image.convert import convert_img from ..gsuid_utils.database.models import GsUser from ..utils.image.image_tools import get_simple_bg @@ -26,7 +26,7 @@ gs_font_26 = genshin_font_origin(26) async def get_user_card(bot_id: str, user_id: str) -> bytes: - sqla = active_sqla[bot_id] + sqla = get_sqla(bot_id) uid_list: List = await sqla.get_bind_uid_list(user_id) w, h = 500, len(uid_list) * 210 + 330 img = await get_simple_bg(w, h) diff --git a/GenshinUID/genshinuid_user/qrlogin.py b/GenshinUID/genshinuid_user/qrlogin.py index eb44f1d3..d054c30e 100644 --- a/GenshinUID/genshinuid_user/qrlogin.py +++ b/GenshinUID/genshinuid_user/qrlogin.py @@ -7,12 +7,13 @@ from typing import Any, List, Tuple, Union, Literal import qrcode from gsuid_core.bot import Bot +from gsuid_core.models import Event from gsuid_core.logger import logger from qrcode.constants import ERROR_CORRECT_L from gsuid_core.segment import MessageSegment from ..utils.mys_api import mys_api -from ..utils.database import active_sqla +from ..utils.database import get_sqla disnote = '''免责声明:您将通过扫码完成获取米游社sk以及ck。 本Bot将不会保存您的登录状态。 @@ -61,8 +62,8 @@ async def refresh( return True, json.loads(status_data['payload']['raw']) -async def qrcode_login(bot: Bot, user_id: str) -> str: - sqla = active_sqla[bot.bot_id] +async def qrcode_login(bot: Bot, ev: Event, user_id: str) -> str: + sqla = get_sqla(ev.bot_id) async def send_msg(msg: str): await bot.send(msg) diff --git a/GenshinUID/gsuid_utils/database/dal.py b/GenshinUID/gsuid_utils/database/dal.py index abee4867..8ab2756b 100644 --- a/GenshinUID/gsuid_utils/database/dal.py +++ b/GenshinUID/gsuid_utils/database/dal.py @@ -380,3 +380,6 @@ class SQLA: sql = delete(GsCache).where(GsCache.uid == uid) await self.session.execute(sql) return True + + async def close(self): + await self.session.close() diff --git a/GenshinUID/utils/convert.py b/GenshinUID/utils/convert.py index 9836ef4c..68df2167 100644 --- a/GenshinUID/utils/convert.py +++ b/GenshinUID/utils/convert.py @@ -5,7 +5,7 @@ from gsuid_core.bot import Bot from gsuid_core.models import Event from .mys_api import mys_api -from .database import active_sqla +from .database import get_sqla from .error_reply import VERIFY_HINT @@ -15,7 +15,7 @@ async def get_uid(bot: Bot, ev: Event): if uid: uid = uid[0] else: - sqla = active_sqla[bot.bot_id] + sqla = get_sqla(ev.bot_id) uid = await sqla.get_bind_uid(user_id) return uid @@ -25,9 +25,7 @@ class GsCookie: self.cookie: Optional[str] = None self.uid: Optional[str] = None self.raw_data = None - for bot_id in active_sqla: - self.sqla = active_sqla[bot_id] - break + self.sqla = get_sqla('TEMP') async def get_cookie(self, uid: str) -> str: self.uid = uid diff --git a/GenshinUID/utils/database.py b/GenshinUID/utils/database.py index c54fddda..9afb662c 100644 --- a/GenshinUID/utils/database.py +++ b/GenshinUID/utils/database.py @@ -1,7 +1,6 @@ from typing import Dict from sqlalchemy import event -from gsuid_core.gss import gss from ..gsuid_utils.database.dal import SQLA @@ -11,9 +10,8 @@ active_sqla: Dict[str, SQLA] = {} db_url = 'GsData.db' -@gss.on_bot_connect -async def refresh_sqla(): - for bot_id in gss.active_bot: +def get_sqla(bot_id) -> SQLA: + if bot_id not in active_sqla: sqla = SQLA(db_url, bot_id) active_sqla[bot_id] = sqla sqla.create_all() @@ -24,3 +22,5 @@ async def refresh_sqla(): cursor = conn.cursor() cursor.execute('PRAGMA journal_mode=WAL') cursor.close() + + return active_sqla[bot_id] diff --git a/GenshinUID/utils/mys_api.py b/GenshinUID/utils/mys_api.py index 8ffdd916..a10a69a6 100644 --- a/GenshinUID/utils/mys_api.py +++ b/GenshinUID/utils/mys_api.py @@ -1,6 +1,6 @@ from typing import Dict, Literal, Optional -from .database import active_sqla +from .database import get_sqla from ..gsuid_utils.api.mys import MysApi from ..genshinuid_config.gs_config import gsconfig @@ -47,21 +47,15 @@ class _MysApi(MysApi): async def get_ck( self, uid: str, mode: Literal['OWNER', 'RANDOM'] = 'RANDOM' ) -> Optional[str]: - for bot_id in active_sqla: - sqla = active_sqla[bot_id] - if mode == 'RANDOM': - return await sqla.get_random_cookie(uid) - else: - return await sqla.get_user_cookie(uid) + sqla = get_sqla('TEMP') + if mode == 'RANDOM': + return await sqla.get_random_cookie(uid) else: - return None + return await sqla.get_user_cookie(uid) async def get_stoken(self, uid: str) -> Optional[str]: - for bot_id in active_sqla: - sqla = active_sqla[bot_id] - return await sqla.get_user_stoken(uid) - else: - return None + sqla = get_sqla('TEMP') + return await sqla.get_user_stoken(uid) mys_api = _MysApi()