From bc7524c8fcfaf435ff214de6cbbd6033f3f6a332 Mon Sep 17 00:00:00 2001 From: beerpsi Date: Thu, 14 Nov 2024 12:36:22 +0700 Subject: [PATCH] fix: make database async --- core/data/alembic/env.py | 68 +++++++++++----- core/data/database.py | 156 ++++++++++++++++++++++++------------- core/data/schema/arcade.py | 30 +++---- core/data/schema/base.py | 23 +++--- core/data/schema/card.py | 11 +-- core/data/schema/user.py | 24 +++--- core/utils.py | 104 ++++++++++++++++++++----- dbutils.py | 11 +-- read.py | 24 +++--- 9 files changed, 297 insertions(+), 154 deletions(-) diff --git a/core/data/alembic/env.py b/core/data/alembic/env.py index d532093..f2a8182 100644 --- a/core/data/alembic/env.py +++ b/core/data/alembic/env.py @@ -1,8 +1,14 @@ from __future__ import with_statement -from alembic import context -from sqlalchemy import engine_from_config, pool + +import asyncio +import threading from logging.config import fileConfig +from alembic import context +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import async_engine_from_config + from core.data.schema.base import metadata # this is the Alembic Config object, which provides @@ -37,20 +43,29 @@ def run_migrations_offline(): script output. """ - raise Exception('Not implemented or configured!') + raise Exception("Not implemented or configured!") url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, target_metadata=target_metadata, literal_binds=True) + context.configure(url=url, target_metadata=target_metadata, literal_binds=True) with context.begin_transaction(): context.run_migrations() -def run_migrations_online(): - """Run migrations in 'online' mode. +def do_run_migrations(connection: Connection) -> None: + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + compare_server_default=True, + ) - In this scenario we need to create an Engine + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """In this scenario we need to create an Engine and associate a connection with the context. """ @@ -59,21 +74,32 @@ def run_migrations_online(): for override in overrides: ini_section[override] = overrides[override] - connectable = engine_from_config( - ini_section, - prefix='sqlalchemy.', - poolclass=pool.NullPool) + connectable = async_engine_from_config( + ini_section, prefix="sqlalchemy.", poolclass=pool.NullPool + ) - with connectable.connect() as connection: - context.configure( - connection=connection, - target_metadata=target_metadata, - compare_type=True, - compare_server_default=True, - ) + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online(): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # there's no event loop + asyncio.run(run_async_migrations()) + else: + # there's currently an event loop and trying to wait for a coroutine + # to finish without using `await` is pretty wormy. nested event loops + # are explicitly forbidden by asyncio. + # + # take the easy way out, spawn it in another thread. + thread = threading.Thread(target=asyncio.run, args=(run_async_migrations(),)) + thread.start() + thread.join() - with context.begin_transaction(): - context.run_migrations() if context.is_offline_mode(): run_migrations_offline() diff --git a/core/data/database.py b/core/data/database.py index b4f3cc0..16bd67b 100644 --- a/core/data/database.py +++ b/core/data/database.py @@ -1,54 +1,65 @@ -import logging, coloredlogs -from typing import Optional -from sqlalchemy.orm import scoped_session, sessionmaker -from sqlalchemy import create_engine -from logging.handlers import TimedRotatingFileHandler +import asyncio +import logging import os -import secrets, string -import bcrypt +import secrets +import string +import warnings from hashlib import sha256 +from logging.handlers import TimedRotatingFileHandler +from typing import ClassVar, Optional + import alembic.config -import glob +import bcrypt +import coloredlogs +import pymysql.err +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_scoped_session, + create_async_engine, +) +from sqlalchemy.orm import sessionmaker from core.config import CoreConfig -from core.data.schema import * -from core.utils import Utils +from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata +from core.utils import MISSING, Utils class Data: - engine = None - session = None - user = None - arcade = None - card = None - base = None + engine: ClassVar[AsyncEngine] = MISSING + session: ClassVar[AsyncSession] = MISSING + user: ClassVar[UserData] = MISSING + arcade: ClassVar[ArcadeData] = MISSING + card: ClassVar[CardData] = MISSING + base: ClassVar[BaseData] = MISSING + def __init__(self, cfg: CoreConfig) -> None: self.config = cfg if self.config.database.sha2_password: passwd = sha256(self.config.database.password.encode()).digest() - self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" + self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{passwd.hex()}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" else: - self.__url = f"{self.config.database.protocol}://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" + self.__url = f"{self.config.database.protocol}+aiomysql://{self.config.database.username}:{self.config.database.password}@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}?charset=utf8mb4&ssl={str(self.config.database.ssl_enabled).lower()}" - if Data.engine is None: - Data.engine = create_engine(self.__url, pool_recycle=3600) + if Data.engine is MISSING: + Data.engine = create_async_engine(self.__url, pool_recycle=3600, isolation_level="AUTOCOMMIT") self.__engine = Data.engine - if Data.session is None: - s = sessionmaker(bind=Data.engine, autoflush=True, autocommit=True) - Data.session = scoped_session(s) + if Data.session is MISSING: + s = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession) + Data.session = async_scoped_session(s, asyncio.current_task) - if Data.user is None: + if Data.user is MISSING: Data.user = UserData(self.config, self.session) - if Data.arcade is None: + if Data.arcade is MISSING: Data.arcade = ArcadeData(self.config, self.session) - if Data.card is None: + if Data.card is MISSING: Data.card = CardData(self.config, self.session) - if Data.base is None: + if Data.base is MISSING: Data.base = BaseData(self.config, self.session) self.logger = logging.getLogger("database") @@ -94,40 +105,73 @@ class Data: alembic.config.main(argv=alembicArgs) os.chdir(old_dir) - def create_database(self): + async def create_database(self): self.logger.info("Creating databases...") - metadata.create_all( - self.engine, - checkfirst=True, - ) - for _, mod in Utils.get_all_titles().items(): - if hasattr(mod, "database"): - mod.database(self.config) - metadata.create_all( - self.engine, - checkfirst=True, - ) + with warnings.catch_warnings(): + # SQLAlchemy will generate a nice primary key constraint name, but in + # MySQL/MariaDB the constraint name is always PRIMARY. Every time a + # custom primary key name is generated, a warning is emitted from pymysql, + # which we don't care about. Other warnings may be helpful though, don't + # suppress everything. + warnings.filterwarnings( + action="ignore", + message=r"Name '(.+)' ignored for PRIMARY key\.", + category=pymysql.err.Warning, + ) - # Stamp the end revision as if alembic had created it, so it can take off after this. - self.__alembic_cmd( - "stamp", - "head", - ) + async with self.engine.begin() as conn: + await conn.run_sync(metadata.create_all, checkfirst=True) - def schema_upgrade(self, ver: str = None): - self.__alembic_cmd( - "upgrade", - "head" if not ver else ver, - ) + for _, mod in Utils.get_all_titles().items(): + if hasattr(mod, "database"): + mod.database(self.config) + + await conn.run_sync(metadata.create_all, checkfirst=True) + + # Stamp the end revision as if alembic had created it, so it can take off after this. + self.__alembic_cmd( + "stamp", + "head", + ) + + def schema_upgrade(self, ver: Optional[str] = None): + with warnings.catch_warnings(): + # SQLAlchemy will generate a nice primary key constraint name, but in + # MySQL/MariaDB the constraint name is always PRIMARY. Every time a + # custom primary key name is generated, a warning is emitted from pymysql, + # which we don't care about. Other warnings may be helpful though, don't + # suppress everything. + warnings.filterwarnings( + action="ignore", + message=r"Name '(.+)' ignored for PRIMARY key\.", + category=pymysql.err.Warning, + ) + + self.__alembic_cmd( + "upgrade", + "head" if not ver else ver, + ) def schema_downgrade(self, ver: str): - self.__alembic_cmd( - "downgrade", - ver, - ) + with warnings.catch_warnings(): + # SQLAlchemy will generate a nice primary key constraint name, but in + # MySQL/MariaDB the constraint name is always PRIMARY. Every time a + # custom primary key name is generated, a warning is emitted from pymysql, + # which we don't care about. Other warnings may be helpful though, don't + # suppress everything. + warnings.filterwarnings( + action="ignore", + message=r"Name '(.+)' ignored for PRIMARY key\.", + category=pymysql.err.Warning, + ) - async def create_owner(self, email: Optional[str] = None, code: Optional[str] = "00000000000000000000") -> None: + self.__alembic_cmd( + "downgrade", + ver, + ) + + async def create_owner(self, email: Optional[str] = None, code: str = "00000000000000000000") -> None: pw = "".join( secrets.choice(string.ascii_letters + string.digits) for i in range(20) ) @@ -150,12 +194,12 @@ class Data: async def migrate(self) -> None: exist = await self.base.execute("SELECT * FROM alembic_version") if exist is not None: - self.logger.warn("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!") + self.logger.warning("No need to migrate as you have already migrated to alembic. If you are trying to upgrade the schema, use `upgrade` instead!") return self.logger.info("Upgrading to latest with legacy system") if not await self.legacy_upgrade(): - self.logger.warn("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!") + self.logger.warning("No need to migrate as you have already deleted the old schema_versions system. If you are trying to upgrade the schema, use `upgrade` instead!") return self.logger.info("Done") diff --git a/core/data/schema/arcade.py b/core/data/schema/arcade.py index 5b570a1..653fe7c 100644 --- a/core/data/schema/arcade.py +++ b/core/data/schema/arcade.py @@ -1,16 +1,16 @@ -from typing import Optional, Dict, List -from sqlalchemy import Table, Column, and_, or_ -from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint -from sqlalchemy.types import Integer, String, Boolean, JSON -from sqlalchemy.sql import func, select +import re +from typing import List, Optional + +from sqlalchemy import Column, Table, and_, or_ from sqlalchemy.dialects.mysql import insert from sqlalchemy.engine import Row -import re +from sqlalchemy.sql import func, select +from sqlalchemy.sql.schema import ForeignKey, PrimaryKeyConstraint +from sqlalchemy.types import JSON, Boolean, Integer, String from core.data.schema.base import BaseData, metadata -from core.const import * -arcade = Table( +arcade: Table = Table( "arcade", metadata, Column("id", Integer, primary_key=True, nullable=False), @@ -26,7 +26,7 @@ arcade = Table( mysql_charset="utf8mb4", ) -machine = Table( +machine: Table = Table( "machine", metadata, Column("id", Integer, primary_key=True, nullable=False), @@ -47,7 +47,7 @@ machine = Table( mysql_charset="utf8mb4", ) -arcade_owner = Table( +arcade_owner: Table = Table( "arcade_owner", metadata, Column( @@ -69,7 +69,7 @@ arcade_owner = Table( class ArcadeData(BaseData): - async def get_machine(self, serial: str = None, id: int = None) -> Optional[Row]: + async def get_machine(self, serial: Optional[str] = None, id: Optional[int] = None) -> Optional[Row]: if serial is not None: serial = serial.replace("-", "") if len(serial) == 11: @@ -98,8 +98,8 @@ class ArcadeData(BaseData): self, arcade_id: int, serial: str = "", - board: str = None, - game: str = None, + board: Optional[str] = None, + game: Optional[str] = None, is_cab: bool = False, ) -> Optional[int]: if not arcade_id: @@ -150,8 +150,8 @@ class ArcadeData(BaseData): async def create_arcade( self, - name: str = None, - nickname: str = None, + name: Optional[str] = None, + nickname: Optional[str] = None, country: str = "JPN", country_id: int = 1, state: str = "", diff --git a/core/data/schema/base.py b/core/data/schema/base.py index d74198b..cb44272 100644 --- a/core/data/schema/base.py +++ b/core/data/schema/base.py @@ -1,22 +1,23 @@ +import asyncio import json import logging from random import randrange -from typing import Any, Optional, Dict, List +from typing import Any, Dict, List, Optional + +from sqlalchemy import Column, MetaData, Table from sqlalchemy.engine import Row from sqlalchemy.engine.cursor import CursorResult -from sqlalchemy.engine.base import Connection -from sqlalchemy.sql import text, func, select from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy import MetaData, Table, Column -from sqlalchemy.types import Integer, String, TIMESTAMP, JSON, INTEGER, TEXT +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.schema import ForeignKey -from sqlalchemy.dialects.mysql import insert +from sqlalchemy.sql import func, text +from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String from core.config import CoreConfig metadata = MetaData() -event_log = Table( +event_log: Table = Table( "event_log", metadata, Column("id", Integer, primary_key=True, nullable=False), @@ -37,7 +38,7 @@ event_log = Table( class BaseData: - def __init__(self, cfg: CoreConfig, conn: Connection) -> None: + def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None: self.config = cfg self.conn = conn self.logger = logging.getLogger("database") @@ -47,7 +48,7 @@ class BaseData: try: self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}") - res = self.conn.execute(text(sql), opts) + res = await self.conn.execute(text(sql), opts) except SQLAlchemyError as e: self.logger.error(f"SQLAlchemy error {e}") @@ -59,7 +60,7 @@ class BaseData: except Exception: try: - res = self.conn.execute(sql, opts) + res = await self.conn.execute(sql, opts) except SQLAlchemyError as e: self.logger.error(f"SQLAlchemy error {e}") @@ -83,7 +84,7 @@ class BaseData: async def log_event( self, system: str, type: str, severity: int, message: str, details: Dict = {}, user: int = None, - arcade: int = None, machine: int = None, ip: str = None, game: str = None, version: str = None + arcade: int = None, machine: int = None, ip: Optional[str] = None, game: Optional[str] = None, version: Optional[str] = None ) -> Optional[int]: sql = event_log.insert().values( system=system, diff --git a/core/data/schema/card.py b/core/data/schema/card.py index 1865539..254b19e 100644 --- a/core/data/schema/card.py +++ b/core/data/schema/card.py @@ -1,13 +1,14 @@ from typing import Dict, List, Optional -from sqlalchemy import Table, Column, UniqueConstraint -from sqlalchemy.types import Integer, String, Boolean, TIMESTAMP, BIGINT, VARCHAR -from sqlalchemy.sql.schema import ForeignKey -from sqlalchemy.sql import func + +from sqlalchemy import Column, Table, UniqueConstraint from sqlalchemy.engine import Row +from sqlalchemy.sql import func +from sqlalchemy.sql.schema import ForeignKey +from sqlalchemy.types import BIGINT, TIMESTAMP, VARCHAR, Boolean, Integer, String from core.data.schema.base import BaseData, metadata -aime_card = Table( +aime_card: Table = Table( "aime_card", metadata, Column("id", Integer, primary_key=True, nullable=False), diff --git a/core/data/schema/user.py b/core/data/schema/user.py index 8c3695c..8686f08 100644 --- a/core/data/schema/user.py +++ b/core/data/schema/user.py @@ -1,15 +1,15 @@ -from typing import Optional, List -from sqlalchemy import Table, Column -from sqlalchemy.types import Integer, String, TIMESTAMP -from sqlalchemy.sql import func -from sqlalchemy.dialects.mysql import insert -from sqlalchemy.sql import func, select -from sqlalchemy.engine import Row +from typing import List, Optional + import bcrypt +from sqlalchemy import Column, Table +from sqlalchemy.dialects.mysql import insert +from sqlalchemy.engine import Row +from sqlalchemy.sql import func, select +from sqlalchemy.types import TIMESTAMP, Integer, String from core.data.schema.base import BaseData, metadata -aime_user = Table( +aime_user: Table = Table( "aime_user", metadata, Column("id", Integer, nullable=False, primary_key=True, autoincrement=True), @@ -26,10 +26,10 @@ aime_user = Table( class UserData(BaseData): async def create_user( self, - id: int = None, - username: str = None, - email: str = None, - password: str = None, + id: Optional[int] = None, + username: Optional[str] = None, + email: Optional[str] = None, + password: Optional[str] = None, permission: int = 1, ) -> Optional[int]: if id is None: diff --git a/core/utils.py b/core/utils.py index 24c174c..af96451 100644 --- a/core/utils.py +++ b/core/utils.py @@ -1,18 +1,47 @@ -from typing import Dict, Any, Optional -from types import ModuleType -from starlette.requests import Request -import logging import importlib -from os import walk -import jwt +import logging from base64 import b64decode from datetime import datetime, timezone +from os import walk +from types import ModuleType +from typing import Any, Dict, Optional + +import jwt +from starlette.requests import Request from .config import CoreConfig + +class _MissingSentinel: + __slots__: tuple[str, ...] = () + + def __eq__(self, other) -> bool: + return False + + def __bool__(self) -> bool: + return False + + def __hash__(self) -> int: + return 0 + + def __repr__(self): + return "..." + + +MISSING: Any = _MissingSentinel() +"""This is different from `None` in that its type is `Any`, and so it can be used +as a placeholder for values that are *definitely* going to be initialized, +so they don't have to be typed as `T | None`, which makes type checkers +angry when an attribute is accessed. + +This can also be used for when `None` has actual meaning as a value, and so a +separate value is needed to mean "unset".""" + + class Utils: real_title_port = None real_title_port_ssl = None + @classmethod def get_all_titles(cls) -> Dict[str, ModuleType]: ret: Dict[str, Any] = {} @@ -36,27 +65,56 @@ class Utils: def get_ip_addr(cls, req: Request) -> str: ip = req.headers.get("x-forwarded-for", req.client.host) return ip.split(", ")[0] - + @classmethod def get_title_port(cls, cfg: CoreConfig): - if cls.real_title_port is not None: return cls.real_title_port + if cls.real_title_port is not None: + return cls.real_title_port + + cls.real_title_port = ( + cfg.server.proxy_port + if cfg.server.is_using_proxy and cfg.server.proxy_port + else cfg.server.port + ) - cls.real_title_port = cfg.server.proxy_port if cfg.server.is_using_proxy and cfg.server.proxy_port else cfg.server.port - return cls.real_title_port - + @classmethod def get_title_port_ssl(cls, cfg: CoreConfig): - if cls.real_title_port_ssl is not None: return cls.real_title_port_ssl + if cls.real_title_port_ssl is not None: + return cls.real_title_port_ssl + + cls.real_title_port_ssl = ( + cfg.server.proxy_port_ssl + if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl + else 443 + ) - cls.real_title_port_ssl = cfg.server.proxy_port_ssl if cfg.server.is_using_proxy and cfg.server.proxy_port_ssl else 443 - return cls.real_title_port_ssl -def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str, b64_secret: str, exp_seconds: int = 86400, err_logger: str = 'aimedb') -> Optional[str]: + +def create_sega_auth_key( + aime_id: int, + game: str, + place_id: int, + keychip_id: str, + b64_secret: str, + exp_seconds: int = 86400, + err_logger: str = "aimedb", +) -> Optional[str]: logger = logging.getLogger(err_logger) try: - return jwt.encode({ "aime_id": aime_id, "game": game, "place_id": place_id, "keychip_id": keychip_id, "exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds }, b64decode(b64_secret), algorithm="HS256") + return jwt.encode( + { + "aime_id": aime_id, + "game": game, + "place_id": place_id, + "keychip_id": keychip_id, + "exp": int(datetime.now(tz=timezone.utc).timestamp()) + exp_seconds, + }, + b64decode(b64_secret), + algorithm="HS256", + ) except jwt.InvalidKeyError: logger.error("Failed to encode Sega Auth Key because the secret is invalid!") return None @@ -64,10 +122,19 @@ def create_sega_auth_key(aime_id: int, game: str, place_id: int, keychip_id: str logger.error(f"Unknown exception occoured when encoding Sega Auth Key! {e}") return None -def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb') -> Optional[Dict]: + +def decode_sega_auth_key( + token: str, b64_secret: str, err_logger: str = "aimedb" +) -> Optional[Dict]: logger = logging.getLogger(err_logger) try: - return jwt.decode(token, "secret", b64decode(b64_secret), algorithms=["HS256"], options={"verify_signature": True}) + return jwt.decode( + token, + "secret", + b64decode(b64_secret), + algorithms=["HS256"], + options={"verify_signature": True}, + ) except jwt.ExpiredSignatureError: logger.error("Sega Auth Key failed to validate due to an expired signature!") return None @@ -83,4 +150,3 @@ def decode_sega_auth_key(token: str, b64_secret: str, err_logger: str = 'aimedb' except Exception as e: logger.error(f"Unknown exception occoured when decoding Sega Auth Key! {e}") return None - \ No newline at end of file diff --git a/dbutils.py b/dbutils.py index 9314f8e..9080afc 100644 --- a/dbutils.py +++ b/dbutils.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 import argparse -import logging -from os import mkdir, path, access, W_OK, environ -import yaml import asyncio +import logging +from os import W_OK, access, environ, mkdir, path + +import yaml -from core.data import Data from core.config import CoreConfig +from core.data import Data if __name__ == "__main__": parser = argparse.ArgumentParser(description="Database utilities") @@ -46,7 +47,7 @@ if __name__ == "__main__": loop = asyncio.get_event_loop() if args.action == "create": - data.create_database() + loop.run_until_complete(data.create_database()) elif args.action == "upgrade": data.schema_upgrade(args.version) diff --git a/read.py b/read.py index 8a0ae72..c6950a2 100644 --- a/read.py +++ b/read.py @@ -1,16 +1,16 @@ #!/usr/bin/env python3 import argparse -import re -import os -import yaml -from os import path -import logging -import coloredlogs import asyncio - +import logging +import os +import re from logging.handlers import TimedRotatingFileHandler +from os import path from typing import List, Optional +import coloredlogs +import yaml + from core import CoreConfig, Utils @@ -44,7 +44,7 @@ class BaseReader: pass -if __name__ == "__main__": +async def main(): parser = argparse.ArgumentParser(description="Import Game Information") parser.add_argument( "--game", @@ -140,8 +140,12 @@ if __name__ == "__main__": for dir, mod in titles.items(): if args.game in mod.game_codes: handler = mod.reader(config, args.version, bin_arg, opt_arg, args.extra) - loop = asyncio.get_event_loop() - loop.run_until_complete(handler.read()) + + await handler.read() logger.info("Done") + + +if __name__ == "__main__": + asyncio.run(main())