diff --git a/gsuid_core/utils/api/mys/request.py b/gsuid_core/utils/api/mys/request.py index a3c9077..6802cad 100644 --- a/gsuid_core/utils/api/mys/request.py +++ b/gsuid_core/utils/api/mys/request.py @@ -13,6 +13,8 @@ from typing import Any, Dict, List, Union, Literal, Optional, cast from aiohttp import ClientSession, ContentTypeError +from gsuid_core.logger import logger + from .api import _API from .tools import ( random_hex, @@ -61,7 +63,7 @@ RECOGNIZE_SERVER = { } -class MysApi: +class BaseMysApi: proxy_url: Optional[str] = None mysVersion = '2.44.1' _HEADER = { @@ -98,6 +100,120 @@ class MysApi: async def get_stoken(self, uid: str) -> Optional[str]: ... + async def simple_mys_req( + self, + URL: str, + uid: Union[str, bool], + params: Dict = {}, + header: Dict = {}, + cookie: Optional[str] = None, + ) -> Union[Dict, int]: + if isinstance(uid, bool): + is_os = uid + server_id = 'cn_qd01' if is_os else 'cn_gf01' + else: + server_id = RECOGNIZE_SERVER.get(uid[0]) + is_os = False if int(uid[0]) < 6 else True + ex_params = '&'.join([f'{k}={v}' for k, v in params.items()]) + if is_os: + _URL = _API[f'{URL}_OS'] + HEADER = copy.deepcopy(self._HEADER_OS) + HEADER['DS'] = generate_os_ds() + else: + _URL = _API[URL] + HEADER = copy.deepcopy(self._HEADER) + HEADER['DS'] = get_ds_token( + ex_params if ex_params else f'role_id={uid}&server={server_id}' + ) + HEADER.update(header) + if cookie is not None: + HEADER['Cookie'] = cookie + elif 'Cookie' not in HEADER and isinstance(uid, str): + ck = await self.get_ck(uid) + if ck is None: + return -51 + HEADER['Cookie'] = ck + data = await self._mys_request( + url=_URL, + method='GET', + header=HEADER, + params=params if params else {'server': server_id, 'role_id': uid}, + use_proxy=True if is_os else False, + ) + return data + + async def _mys_req_get( + self, + url: str, + is_os: bool, + params: Dict, + header: Optional[Dict] = None, + ) -> Union[Dict, int]: + if is_os: + _URL = _API[f'{url}_OS'] + HEADER = copy.deepcopy(self._HEADER_OS) + use_proxy = True + else: + _URL = _API[url] + HEADER = copy.deepcopy(self._HEADER) + use_proxy = False + if header: + HEADER.update(header) + + if 'Cookie' not in HEADER and 'uid' in params: + ck = await self.get_ck(params['uid']) + if ck is None: + return -51 + HEADER['Cookie'] = ck + data = await self._mys_request( + url=_URL, + method='GET', + header=HEADER, + params=params, + use_proxy=use_proxy, + ) + return data + + async def _mys_request( + self, + url: str, + method: Literal['GET', 'POST'] = 'GET', + header: Dict[str, Any] = _HEADER, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + use_proxy: Optional[bool] = False, + ) -> Union[Dict, int]: + async with ClientSession() as client: + async with client.request( + method, + url=url, + headers=header, + params=params, + json=data, + proxy=self.proxy_url if use_proxy else None, + timeout=300, + ) as resp: + try: + raw_data = await resp.json() + except ContentTypeError: + _raw_data = await resp.text() + raw_data = {'retcode': -999, 'data': _raw_data} + logger.debug(raw_data) + if 'retcode' in raw_data: + retcode: int = raw_data['retcode'] + elif 'code' in raw_data: + retcode: int = raw_data['code'] + else: + retcode = 0 + if retcode == 1034: + await self._upass(header) + return retcode + elif retcode != 0: + return retcode + return raw_data + + +class MysApi(BaseMysApi): async def get_upass_link(self, header: Dict) -> Union[int, Dict]: header['DS'] = get_ds_token('is_high=false') return await self._mys_request( @@ -828,115 +944,3 @@ class MysApi: if isinstance(resp, int): return resp return cast(MysOrderCheck, resp['data']) - - async def simple_mys_req( - self, - URL: str, - uid: Union[str, bool], - params: Dict = {}, - header: Dict = {}, - cookie: Optional[str] = None, - ) -> Union[Dict, int]: - if isinstance(uid, bool): - is_os = uid - server_id = 'cn_qd01' if is_os else 'cn_gf01' - else: - server_id = RECOGNIZE_SERVER.get(uid[0]) - is_os = False if int(uid[0]) < 6 else True - ex_params = '&'.join([f'{k}={v}' for k, v in params.items()]) - if is_os: - _URL = _API[f'{URL}_OS'] - HEADER = copy.deepcopy(self._HEADER_OS) - HEADER['DS'] = generate_os_ds() - else: - _URL = _API[URL] - HEADER = copy.deepcopy(self._HEADER) - HEADER['DS'] = get_ds_token( - ex_params if ex_params else f'role_id={uid}&server={server_id}' - ) - HEADER.update(header) - if cookie is not None: - HEADER['Cookie'] = cookie - elif 'Cookie' not in HEADER and isinstance(uid, str): - ck = await self.get_ck(uid) - if ck is None: - return -51 - HEADER['Cookie'] = ck - data = await self._mys_request( - url=_URL, - method='GET', - header=HEADER, - params=params if params else {'server': server_id, 'role_id': uid}, - use_proxy=True if is_os else False, - ) - return data - - async def _mys_req_get( - self, - url: str, - is_os: bool, - params: Dict, - header: Optional[Dict] = None, - ) -> Union[Dict, int]: - if is_os: - _URL = _API[f'{url}_OS'] - HEADER = copy.deepcopy(self._HEADER_OS) - use_proxy = True - else: - _URL = _API[url] - HEADER = copy.deepcopy(self._HEADER) - use_proxy = False - if header: - HEADER.update(header) - - if 'Cookie' not in HEADER and 'uid' in params: - ck = await self.get_ck(params['uid']) - if ck is None: - return -51 - HEADER['Cookie'] = ck - data = await self._mys_request( - url=_URL, - method='GET', - header=HEADER, - params=params, - use_proxy=use_proxy, - ) - return data - - async def _mys_request( - self, - url: str, - method: Literal['GET', 'POST'] = 'GET', - header: Dict[str, Any] = _HEADER, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, - use_proxy: Optional[bool] = False, - ) -> Union[Dict, int]: - async with ClientSession() as client: - async with client.request( - method, - url=url, - headers=header, - params=params, - json=data, - proxy=self.proxy_url if use_proxy else None, - timeout=300, - ) as resp: - try: - raw_data = await resp.json() - except ContentTypeError: - _raw_data = await resp.text() - raw_data = {'retcode': -999, 'data': _raw_data} - print(raw_data) - if 'retcode' in raw_data: - retcode: int = raw_data['retcode'] - elif 'code' in raw_data: - retcode: int = raw_data['code'] - else: - retcode = 0 - if retcode == 1034: - await self._upass(header) - return retcode - elif retcode != 0: - return retcode - return raw_data diff --git a/gsuid_core/utils/database/api.py b/gsuid_core/utils/database/api.py new file mode 100644 index 0000000..17611de --- /dev/null +++ b/gsuid_core/utils/database/api.py @@ -0,0 +1,27 @@ +from typing import Dict + +from sqlalchemy import event + +from gsuid_core.data_store import get_res_path +from gsuid_core.utils.database.dal import SQLA + +is_wal = False + +active_sqla: Dict[str, SQLA] = {} +db_url = str(get_res_path().parent / 'GsData.db') + + +def get_sqla(bot_id) -> SQLA: + if bot_id not in active_sqla: + sqla = SQLA(db_url, bot_id) + active_sqla[bot_id] = sqla + sqla.create_all() + + @event.listens_for(sqla.engine.sync_engine, 'connect') + def engine_connect(conn, branch): + if is_wal: + cursor = conn.cursor() + cursor.execute('PRAGMA journal_mode=WAL') + cursor.close() + + return active_sqla[bot_id] diff --git a/gsuid_core/utils/database/dal.py b/gsuid_core/utils/database/dal.py index 350120e..70d863a 100644 --- a/gsuid_core/utils/database/dal.py +++ b/gsuid_core/utils/database/dal.py @@ -11,13 +11,14 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import func from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from .utils import SERVER +from .utils import SERVER, SR_SERVER from .models import GsBind, GsPush, GsUser, GsCache class SQLA: - def __init__(self, url: str, bot_id: str): + def __init__(self, url: str, bot_id: str, is_sr: bool = False): self.bot_id = bot_id + self.is_sr = is_sr self.url = f'sqlite+aiosqlite:///{url}' self.engine = create_async_engine(self.url, pool_recycle=1500) self.async_session = sessionmaker( @@ -160,7 +161,12 @@ class SQLA: async def switch_uid( self, user_id: str, uid: Optional[str] = None ) -> Optional[List]: - uid_list = await self.get_bind_uid_list(user_id) + uid_list = ( + await self.get_bind_sruid_list(user_id) + if self.is_sr + else await self.get_bind_uid_list(user_id) + ) + id_type = 'sr_uid' if self.is_sr else 'uid' if uid_list and len(uid_list) >= 1: if uid and uid not in uid_list: return None @@ -170,7 +176,7 @@ class SQLA: uid = uid_list[1] uid_list.remove(uid) uid_list.insert(0, uid) - await self.update_bind_data(user_id, {'uid': '_'.join(uid_list)}) + await self.update_bind_data(user_id, {id_type: '_'.join(uid_list)}) return uid_list else: return None @@ -182,21 +188,22 @@ class SQLA: async def select_user_data(self, uid: str) -> Optional[GsUser]: async with self.async_session() as session: async with session.begin(): - sql = select(GsUser).where(GsUser.uid == uid) - result = await session.execute(sql) - return data[0] if (data := result.scalars().all()) else None - - async def select_sr_user_data(self, sr_uid: str) -> Optional[GsUser]: - async with self.async_session() as session: - async with session.begin(): - sql = select(GsUser).where(GsUser.sr_uid == sr_uid) + sql = ( + select(GsUser).where(GsUser.sr_uid == uid) + if self.is_sr + else select(GsUser).where(GsUser.uid == uid) + ) result = await session.execute(sql) return data[0] if (data := result.scalars().all()) else None async def select_cache_cookie(self, uid: str) -> Optional[str]: async with self.async_session() as session: async with session.begin(): - sql = select(GsCache).where(GsCache.uid == uid) + sql = ( + select(GsCache).where(GsCache.sr_uid == uid) + if self.is_sr + else select(GsCache).where(GsCache.uid == uid) + ) result = await session.execute(sql) data: List[GsCache] = result.scalars().all() return data[0].cookie if len(data) >= 1 else None @@ -229,9 +236,9 @@ class SQLA: async def insert_user_data( self, user_id: str, - uid: Optional[str], - sr_uid: Optional[str], - cookie: str, + uid: Optional[str] = None, + sr_uid: Optional[str] = None, + cookie: Optional[str] = None, stoken: Optional[str] = None, ) -> bool: async with self.async_session() as session: @@ -246,10 +253,11 @@ class SQLA: stoken=stoken, bot_id=self.bot_id, user_id=user_id, + sr_uid=sr_uid, ) ) await session.execute(sql) - elif sr_uid and await self.sr_user_exists(sr_uid): + elif sr_uid and await self.user_exists(sr_uid): sql = ( update(GsUser) .where(GsUser.sr_uid == sr_uid) @@ -259,10 +267,14 @@ class SQLA: stoken=stoken, bot_id=self.bot_id, user_id=user_id, + uid=uid, ) ) await session.execute(sql) else: + if cookie is None: + return False + account_id = re.search(r'account_id=(\d*)', cookie) assert account_id is not None account_id = str(account_id.group(1)) @@ -279,7 +291,9 @@ class SQLA: push_switch='off', bbs_switch='off', region=SERVER.get(uid[0], 'cn_gf01') if uid else None, - sr_region=None, + sr_region=SR_SERVER.get(sr_uid[0], None) + if sr_uid + else None, ) session.add(user_data) await session.commit() @@ -288,8 +302,14 @@ class SQLA: async def update_user_data(self, uid: str, data: Optional[Dict]): async with self.async_session() as session: async with session.begin(): - sql = update(GsUser).where( - GsUser.uid == uid, GsUser.bot_id == self.bot_id + sql = ( + update(GsUser).where( + GsUser.sr_uid == uid, GsUser.bot_id == self.bot_id + ) + if self.is_sr + else update(GsUser).where( + GsUser.uid == uid, GsUser.bot_id == self.bot_id + ) ) if data is not None: query = sql.values(**data) @@ -301,7 +321,11 @@ class SQLA: async with self.async_session() as session: async with session.begin(): if await self.user_exists(uid): - sql = delete(GsUser).where(GsUser.uid == uid) + sql = ( + delete(GsUser).where(GsUser.sr_uid == uid) + if self.is_sr + else delete(GsUser).where(GsUser.uid == uid) + ) await session.execute(sql) await session.commit() return True @@ -335,10 +359,6 @@ class SQLA: data = await self.select_user_data(uid) return True if data else False - async def sr_user_exists(self, sr_uid: str) -> bool: - data = await self.select_sr_user_data(sr_uid) - return True if data else False - async def update_user_stoken( self, uid: str, stoken: Optional[str] ) -> bool: @@ -346,9 +366,17 @@ class SQLA: async with session.begin(): if await self.user_exists(uid): sql = ( - update(GsUser) - .where(GsUser.uid == uid) - .values(stoken=stoken) + ( + update(GsUser) + .where(GsUser.sr_uid == uid) + .values(stoken=stoken) + ) + if self.is_sr + else ( + update(GsUser) + .where(GsUser.uid == uid) + .values(stoken=stoken) + ) ) await session.execute(sql) await session.commit() @@ -362,9 +390,17 @@ class SQLA: async with session.begin(): if await self.user_exists(uid): sql = ( - update(GsUser) - .where(GsUser.uid == uid) - .values(cookie=cookie) + ( + update(GsUser) + .where(GsUser.sr_uid == uid) + .values(cookie=cookie) + ) + if self.is_sr + else ( + update(GsUser) + .where(GsUser.uid == uid) + .values(cookie=cookie) + ) ) await session.execute(sql) await session.commit() @@ -376,7 +412,17 @@ class SQLA: async with session.begin(): if await self.user_exists(uid): sql = ( - update(GsUser).where(GsUser.uid == uid).values(**data) + ( + update(GsUser) + .where(GsUser.sr_uid == uid) + .values(**data) + ) + if self.is_sr + else ( + update(GsUser) + .where(GsUser.uid == uid) + .values(**data) + ) ) await session.execute(sql) await session.commit() @@ -458,7 +504,13 @@ class SQLA: user_list: List[GsUser] = data.scalars().all() for user in user_list: if not user.status and user.cookie: - await self.insert_cache_data(user.cookie, uid) # 进入缓存 + # 进入缓存 + if self.is_sr: + await self.insert_cache_data( + user.cookie, sr_uid=uid + ) + else: + await self.insert_cache_data(user.cookie, uid) return user.cookie continue else: @@ -551,7 +603,11 @@ class SQLA: async def refresh_cache(self, uid: str): async with self.async_session() as session: async with session.begin(): - sql = delete(GsCache).where(GsCache.uid == uid) + sql = ( + delete(GsCache).where(GsCache.sr_uid == uid) + if self.is_sr + else delete(GsCache).where(GsCache.uid == uid) + ) await session.execute(sql) return True diff --git a/gsuid_core/utils/database/utils.py b/gsuid_core/utils/database/utils.py index 2a7315b..c46a6d7 100644 --- a/gsuid_core/utils/database/utils.py +++ b/gsuid_core/utils/database/utils.py @@ -7,3 +7,8 @@ SERVER = { '8': 'os_asia', '9': 'os_cht', } + +SR_SERVER = { + '1': 'prod_gf_cn', + '2': 'prod_gf_cn', +}