🎨 修改get_sqla的方式

This commit is contained in:
Wuyi无疑 2023-03-05 16:35:41 +08:00
parent 8462ae66a9
commit e6db065d5d
18 changed files with 62 additions and 65 deletions

View File

@ -9,7 +9,7 @@ from gsuid_core.aps import scheduler
from ..utils.mys_api import mys_api
from .backup_data import data_backup
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from ..gsuid_utils.database.models import GsUser
@ -26,7 +26,7 @@ async def send_backup_msg(bot: Bot):
@SV('数据管理', pm=2).on_fullmatch(('校验全部Cookies'))
async def send_check_cookie(bot: Bot, ev: Event):
user_list = await active_sqla[bot.bot_id].get_all_user()
user_list = await get_sqla(bot.bot_id).get_all_user()
invalid_user: List[GsUser] = []
for user in user_list:
if user.cookie and user.mys_id:
@ -36,7 +36,7 @@ async def send_check_cookie(bot: Bot, ev: Event):
True if int(user.uid[0]) > 5 else False,
)
if isinstance(mys_data, int):
await active_sqla[bot.bot_id].delete_user_data(user.uid)
await get_sqla(bot.bot_id).delete_user_data(user.uid)
invalid_user.append(user)
continue
for i in mys_data:
@ -76,7 +76,7 @@ async def send_check_cookie(bot: Bot, ev: Event):
@SV('数据管理', pm=2).on_fullmatch(('校验全部Stoken'))
async def send_check_stoken(bot: Bot, ev: Event):
user_list = await active_sqla[bot.bot_id].get_all_user()
user_list = await get_sqla(bot.bot_id).get_all_user()
invalid_user: List[GsUser] = []
for user in user_list:
if user.stoken and user.mys_id:
@ -85,9 +85,7 @@ async def send_check_stoken(bot: Bot, ev: Event):
user.mys_id,
)
if isinstance(mys_data, int):
await active_sqla[bot.bot_id].update_user_stoken(
user.uid, None
)
await get_sqla(bot.bot_id).update_user_stoken(user.uid, None)
invalid_user.append(user)
continue
if len(user_list) > 4:

View File

@ -4,7 +4,7 @@ from shutil import copyfile
from nonebot.log import logger
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from ..utils.resource.RESOURCE_PATH import TEMP_PATH
@ -24,8 +24,10 @@ async def data_backup():
f.unlink()
except OSError as e:
print("Error: %s : %s" % (f, e.strerror))
for bot_id in active_sqla:
await active_sqla[bot_id].delete_cache()
sqla = get_sqla('TEMP')
await sqla.delete_cache()
await sqla.close()
del sqla
logger.info('————缓存成功清除————')
except Exception:
logger.info('————数据库备份失败————')

View File

@ -4,7 +4,7 @@ from gsuid_core.sv import SV
from gsuid_core.bot import Bot
from gsuid_core.models import Event
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from ..utils.error_reply import UID_HINT
from .draw_config_card import draw_config_img
from .set_config import set_push_value, set_config_func
@ -21,7 +21,7 @@ async def send_config_card(bot: Bot, ev: Event):
async def send_config_ev(bot: Bot, ev: Event):
await bot.logger.info('开始执行[设置阈值信息]')
sqla = active_sqla[ev.bot_id]
sqla = get_sqla(ev.bot_id)
uid = await sqla.get_bind_uid(ev.user_id)
if uid is None:
return await bot.send(UID_HINT)
@ -41,7 +41,7 @@ async def send_config_ev(bot: Bot, ev: Event):
# 开启 自动签到 和 推送树脂提醒 功能
@SV('原神配置').on_prefix(('gs开启', 'gs关闭'))
async def open_switch_func(bot: Bot, ev: Event):
sqla = active_sqla[ev.bot_id]
sqla = get_sqla(ev.bot_id)
user_id = ev.user_id
config_name = ev.text

View File

@ -5,7 +5,7 @@ from typing import Union
from nonebot.log import logger
from PIL import Image, ImageDraw
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from .config_default import CONIFG_DEFAULT
from ..utils.image.convert import convert_img
from ..utils.image.image_tools import CustomizeImage
@ -22,7 +22,7 @@ second_color = (57, 57, 57)
async def draw_config_img(bot_id: str) -> Union[bytes, str]:
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
# 获取背景图片各项参数
based_w = 850
based_h = 850 + 155 * (len(CONIFG_DEFAULT) - 5)

View File

@ -3,7 +3,7 @@ from typing import Optional
from nonebot.log import logger
from .gs_config import gsconfig
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from .config_default import CONIFG_DEFAULT
PUSH_MAP = {
@ -20,7 +20,7 @@ PRIV_MAP = {
async def set_push_value(bot_id: str, func: str, uid: str, value: int):
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
if func in PUSH_MAP:
status = PUSH_MAP[func]
else:
@ -41,7 +41,7 @@ async def set_config_func(
query: Optional[bool] = None,
is_admin: bool = False,
):
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
# 这里将传入的中文config_name转换为英文status
for _name in CONIFG_DEFAULT:
config = CONIFG_DEFAULT[_name]

View File

@ -4,7 +4,7 @@ from gsuid_core.models import Event
from .note_text import award
from ..utils.convert import get_uid
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from ..utils.error_reply import UID_HINT
from .draw_note_card import draw_note_img
@ -12,7 +12,7 @@ from .draw_note_card import draw_note_img
# 群聊内 每月统计 功能
@SV('查询札记').on_fullmatch(('每月统计'))
async def send_monthly_data(bot: Bot, ev: Event):
sqla = active_sqla[bot.bot_id]
sqla = get_sqla(ev.bot_id)
uid = await sqla.get_bind_uid(ev.user_id)
if uid is None:
return UID_HINT

View File

@ -7,7 +7,7 @@ from nonebot.log import logger
from PIL import Image, ImageDraw
from ..utils.mys_api import mys_api
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from ..utils.image.convert import convert_img
from ..genshinuid_enka.to_data import get_enka_info
from ..gsuid_utils.api.mys.models import Expedition
@ -70,7 +70,7 @@ async def _draw_task_img(
async def get_resin_img(bot_id: str, user_id: str):
try:
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
uid_list: List = await sqla.get_bind_uid_list(user_id)
logger.info('[每日信息]UID: {}'.format(uid_list))
# 进行校验UID是否绑定CK

View File

@ -4,7 +4,7 @@ from gsuid_core.gss import gss
from nonebot.log import logger
from ..utils.mys_api import mys_api
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from ..genshinuid_config.gs_config import gsconfig
from ..gsuid_utils.api.mys.models import DailyNoteData
@ -21,7 +21,7 @@ NOTICE = {
async def get_notice_list() -> Dict[str, Dict[str, Dict]]:
msg_dict: Dict[str, Dict[str, Dict]] = {}
for bot_id in gss.active_bot:
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
user_list = await sqla.get_all_push_user_list()
for user in user_list:
raw_data = await mys_api.get_daily_data(user.uid)
@ -48,7 +48,7 @@ async def all_check(
user_id: str,
uid: str,
) -> Dict[str, Dict[str, Dict]]:
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
for mode in NOTICE.keys():
# 检查条件
if push_data[f'{mode}_is_push'] == 'on':

View File

@ -8,8 +8,8 @@ from gsuid_core.models import Event
from gsuid_core.aps import scheduler
from gsuid_core.logger import logger
from ..utils.database import get_sqla
from .sign import sign_in, daily_sign
from ..utils.database import active_sqla
from ..utils.error_reply import UID_HINT
from ..genshinuid_config.gs_config import gsconfig
@ -25,7 +25,7 @@ async def sign_at_night():
@SV('原神签到').on_fullmatch('签到')
async def get_sign_func(bot: Bot, ev: Event):
await bot.logger.info('[签到]QQ号: {}'.format(ev.user_id))
sqla = active_sqla[ev.bot_id]
sqla = get_sqla(ev.bot_id)
uid = await sqla.get_bind_uid(ev.user_id)
if uid is None:
return await bot.send(UID_HINT)

View File

@ -2,10 +2,11 @@ import random
import asyncio
from copy import deepcopy
from gsuid_core.gss import gss
from nonebot.log import logger
from ..utils.mys_api import mys_api
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from ..genshinuid_config.gs_config import gsconfig
private_msg_list = {}
@ -158,8 +159,8 @@ async def daily_sign():
"""
global already
tasks = []
for bot_id in active_sqla:
sqla = active_sqla[bot_id]
for bot_id in gss.active_bot:
sqla = get_sqla(bot_id)
user_list = await sqla.get_all_user()
for user in user_list:
if user.sign_switch != 'off':

View File

@ -6,7 +6,7 @@ from gsuid_core.models import Event
from gsuid_core.segment import MessageSegment
from .qrlogin import qrcode_login
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from .get_ck_help_msg import get_ck_help
from ..utils.message import send_diff_msg
from .draw_user_card import get_user_card
@ -30,7 +30,7 @@ async def send_refresh_ck_msg(bot: Bot, ev: Event):
@SV('扫码登陆').on_fullmatch(('扫码登陆', '扫码登录'))
async def send_qrcode_login(bot: Bot, ev: Event):
await bot.logger.info('开始执行[扫码登陆]')
im = await qrcode_login(bot, ev.user_id)
im = await qrcode_login(bot, ev, ev.user_id)
if not im:
return
im = await deal_ck(ev.bot_id, im, ev.user_id)
@ -57,7 +57,7 @@ async def send_link_uid_msg(bot: Bot, ev: Event):
qid = ev.user_id
await bot.logger.info('[绑定/解绑]UserID: {}'.format(qid))
sqla = active_sqla[ev.bot_id]
sqla = get_sqla(ev.bot_id)
uid = ev.text
if ev.command.startswith('绑定'):

View File

@ -3,7 +3,7 @@ from typing import Dict, List
from http.cookies import SimpleCookie
from ..utils.mys_api import mys_api
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from ..utils.error_reply import UID_HINT
pic_path = Path(__file__).parent / 'pic'
@ -28,7 +28,7 @@ lt_list = ['login_ticket', 'login_ticket_v2']
async def get_ck_by_all_stoken(bot_id: str):
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
uid_list: List = await sqla.get_all_uid_list()
uid_dict = {}
for uid in uid_list:
@ -40,7 +40,7 @@ async def get_ck_by_all_stoken(bot_id: str):
async def get_ck_by_stoken(bot_id: str, user_id: str):
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
uid_list: List = await sqla.get_bind_uid_list(user_id)
uid_dict = {uid: user_id for uid in uid_list}
im = await refresh_ck_by_uid_list(bot_id, uid_dict)
@ -48,7 +48,7 @@ async def get_ck_by_stoken(bot_id: str, user_id: str):
async def refresh_ck_by_uid_list(bot_id: str, uid_dict: Dict):
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
uid_num = len(uid_dict)
if uid_num == 0:
return '请先绑定一个UID噢~'
@ -113,7 +113,7 @@ async def get_account_id(simp_dict: SimpleCookie) -> str:
async def _deal_ck(bot_id: str, mes: str, user_id: str) -> str:
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
simp_dict = SimpleCookie(mes)
uid = await sqla.get_bind_uid(user_id)
if uid is None:

View File

@ -3,7 +3,7 @@ from pathlib import Path
from PIL import Image, ImageDraw
from ..utils.database import active_sqla
from ..utils.database import get_sqla
from ..utils.image.convert import convert_img
from ..gsuid_utils.database.models import GsUser
from ..utils.image.image_tools import get_simple_bg
@ -26,7 +26,7 @@ gs_font_26 = genshin_font_origin(26)
async def get_user_card(bot_id: str, user_id: str) -> bytes:
sqla = active_sqla[bot_id]
sqla = get_sqla(bot_id)
uid_list: List = await sqla.get_bind_uid_list(user_id)
w, h = 500, len(uid_list) * 210 + 330
img = await get_simple_bg(w, h)

View File

@ -7,12 +7,13 @@ from typing import Any, List, Tuple, Union, Literal
import qrcode
from gsuid_core.bot import Bot
from gsuid_core.models import Event
from gsuid_core.logger import logger
from qrcode.constants import ERROR_CORRECT_L
from gsuid_core.segment import MessageSegment
from ..utils.mys_api import mys_api
from ..utils.database import active_sqla
from ..utils.database import get_sqla
disnote = '''免责声明:您将通过扫码完成获取米游社sk以及ck。
本Bot将不会保存您的登录状态
@ -61,8 +62,8 @@ async def refresh(
return True, json.loads(status_data['payload']['raw'])
async def qrcode_login(bot: Bot, user_id: str) -> str:
sqla = active_sqla[bot.bot_id]
async def qrcode_login(bot: Bot, ev: Event, user_id: str) -> str:
sqla = get_sqla(ev.bot_id)
async def send_msg(msg: str):
await bot.send(msg)

View File

@ -380,3 +380,6 @@ class SQLA:
sql = delete(GsCache).where(GsCache.uid == uid)
await self.session.execute(sql)
return True
async def close(self):
await self.session.close()

View File

@ -5,7 +5,7 @@ from gsuid_core.bot import Bot
from gsuid_core.models import Event
from .mys_api import mys_api
from .database import active_sqla
from .database import get_sqla
from .error_reply import VERIFY_HINT
@ -15,7 +15,7 @@ async def get_uid(bot: Bot, ev: Event):
if uid:
uid = uid[0]
else:
sqla = active_sqla[bot.bot_id]
sqla = get_sqla(ev.bot_id)
uid = await sqla.get_bind_uid(user_id)
return uid
@ -25,9 +25,7 @@ class GsCookie:
self.cookie: Optional[str] = None
self.uid: Optional[str] = None
self.raw_data = None
for bot_id in active_sqla:
self.sqla = active_sqla[bot_id]
break
self.sqla = get_sqla('TEMP')
async def get_cookie(self, uid: str) -> str:
self.uid = uid

View File

@ -1,7 +1,6 @@
from typing import Dict
from sqlalchemy import event
from gsuid_core.gss import gss
from ..gsuid_utils.database.dal import SQLA
@ -11,9 +10,8 @@ active_sqla: Dict[str, SQLA] = {}
db_url = 'GsData.db'
@gss.on_bot_connect
async def refresh_sqla():
for bot_id in gss.active_bot:
def get_sqla(bot_id) -> SQLA:
if bot_id not in active_sqla:
sqla = SQLA(db_url, bot_id)
active_sqla[bot_id] = sqla
sqla.create_all()
@ -24,3 +22,5 @@ async def refresh_sqla():
cursor = conn.cursor()
cursor.execute('PRAGMA journal_mode=WAL')
cursor.close()
return active_sqla[bot_id]

View File

@ -1,6 +1,6 @@
from typing import Dict, Literal, Optional
from .database import active_sqla
from .database import get_sqla
from ..gsuid_utils.api.mys import MysApi
from ..genshinuid_config.gs_config import gsconfig
@ -47,21 +47,15 @@ class _MysApi(MysApi):
async def get_ck(
self, uid: str, mode: Literal['OWNER', 'RANDOM'] = 'RANDOM'
) -> Optional[str]:
for bot_id in active_sqla:
sqla = active_sqla[bot_id]
if mode == 'RANDOM':
return await sqla.get_random_cookie(uid)
else:
return await sqla.get_user_cookie(uid)
sqla = get_sqla('TEMP')
if mode == 'RANDOM':
return await sqla.get_random_cookie(uid)
else:
return None
return await sqla.get_user_cookie(uid)
async def get_stoken(self, uid: str) -> Optional[str]:
for bot_id in active_sqla:
sqla = active_sqla[bot_id]
return await sqla.get_user_stoken(uid)
else:
return None
sqla = get_sqla('TEMP')
return await sqla.get_user_stoken(uid)
mys_api = _MysApi()