Merge pull request '[database] fix invalid transaction being left open' (#187) from beerpsi/artemis:fix/invalid-transaction into develop
Reviewed-on: https://gitea.tendokyu.moe/Hay1tsme/artemis/pulls/187pull/194/head
commit
6a305d2514
|
@ -1,12 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import ssl
|
|
||||||
import string
|
import string
|
||||||
import warnings
|
import warnings
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from logging.handlers import TimedRotatingFileHandler
|
from logging.handlers import TimedRotatingFileHandler
|
||||||
from typing import Any, ClassVar, Optional
|
from typing import ClassVar, Optional
|
||||||
|
|
||||||
import alembic.config
|
import alembic.config
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
@ -17,6 +16,7 @@ from sqlalchemy.ext.asyncio import (
|
||||||
AsyncSession,
|
AsyncSession,
|
||||||
create_async_engine,
|
create_async_engine,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.config import CoreConfig
|
from core.config import CoreConfig
|
||||||
from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata
|
from core.data.schema import ArcadeData, BaseData, CardData, UserData, metadata
|
||||||
|
@ -25,7 +25,7 @@ from core.utils import MISSING, Utils
|
||||||
|
|
||||||
class Data:
|
class Data:
|
||||||
engine: ClassVar[AsyncEngine] = MISSING
|
engine: ClassVar[AsyncEngine] = MISSING
|
||||||
session: ClassVar[AsyncSession] = MISSING
|
session: ClassVar["sessionmaker[AsyncSession]"] = MISSING
|
||||||
user: ClassVar[UserData] = MISSING
|
user: ClassVar[UserData] = MISSING
|
||||||
arcade: ClassVar[ArcadeData] = MISSING
|
arcade: ClassVar[ArcadeData] = MISSING
|
||||||
card: ClassVar[CardData] = MISSING
|
card: ClassVar[CardData] = MISSING
|
||||||
|
@ -53,7 +53,7 @@ class Data:
|
||||||
self.__engine = Data.engine
|
self.__engine = Data.engine
|
||||||
|
|
||||||
if Data.session is MISSING:
|
if Data.session is MISSING:
|
||||||
Data.session = AsyncSession(Data.engine, expire_on_commit=False)
|
Data.session = sessionmaker(Data.engine, expire_on_commit=False, class_=AsyncSession)
|
||||||
|
|
||||||
if Data.user is MISSING:
|
if Data.user is MISSING:
|
||||||
Data.user = UserData(self.config, self.session)
|
Data.user = UserData(self.config, self.session)
|
||||||
|
|
|
@ -9,6 +9,7 @@ from sqlalchemy.engine import Row
|
||||||
from sqlalchemy.engine.cursor import CursorResult
|
from sqlalchemy.engine.cursor import CursorResult
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
from sqlalchemy.schema import ForeignKey
|
from sqlalchemy.schema import ForeignKey
|
||||||
from sqlalchemy.sql import func, text
|
from sqlalchemy.sql import func, text
|
||||||
from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String
|
from sqlalchemy.types import INTEGER, JSON, TEXT, TIMESTAMP, Integer, String
|
||||||
|
@ -38,7 +39,7 @@ event_log: Table = Table(
|
||||||
|
|
||||||
|
|
||||||
class BaseData:
|
class BaseData:
|
||||||
def __init__(self, cfg: CoreConfig, conn: AsyncSession) -> None:
|
def __init__(self, cfg: CoreConfig, conn: "sessionmaker[AsyncSession]") -> None:
|
||||||
self.config = cfg
|
self.config = cfg
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
self.logger = logging.getLogger("database")
|
self.logger = logging.getLogger("database")
|
||||||
|
@ -46,9 +47,10 @@ class BaseData:
|
||||||
async def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]:
|
async def execute(self, sql: str, opts: Dict[str, Any] = {}) -> Optional[CursorResult]:
|
||||||
res = None
|
res = None
|
||||||
|
|
||||||
|
async with self.conn() as session:
|
||||||
try:
|
try:
|
||||||
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
|
self.logger.debug(f"SQL Execute: {''.join(str(sql).splitlines())}")
|
||||||
res = await self.conn.execute(text(sql), opts)
|
res = await session.execute(text(sql), opts)
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
self.logger.error(f"SQLAlchemy error {e}")
|
self.logger.error(f"SQLAlchemy error {e}")
|
||||||
|
@ -60,7 +62,7 @@ class BaseData:
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
res = await self.conn.execute(sql, opts)
|
res = await session.execute(sql, opts)
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
self.logger.error(f"SQLAlchemy error {e}")
|
self.logger.error(f"SQLAlchemy error {e}")
|
||||||
|
|
Loading…
Reference in New Issue