mirror of
https://github.com/maxdorninger/MediaManager.git
synced 2026-02-19 23:49:11 -05:00
This PR enables the ruff rule for return type annotations (ANN), and adds the ty package for type checking.
104 lines
2.5 KiB
Python
104 lines
2.5 KiB
Python
import logging
|
|
import os
|
|
from contextvars import ContextVar
|
|
from typing import Annotated, Any, Generator, Optional
|
|
|
|
from fastapi import Depends
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.engine.url import URL
|
|
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
|
|
|
from media_manager.database.config import DbConfig
|
|
|
|
log = logging.getLogger(__name__)

# Declarative base produced by SQLAlchemy's declarative_base(); ORM-mapped
# classes are presumably defined as subclasses of it elsewhere in the project.
Base = declarative_base()

# Process-wide engine and session factory. Both stay None until init_engine()
# assigns them; get_engine() and get_session() raise RuntimeError while unset.
engine: Optional[Engine] = None
# Parametrized so type checkers know SessionLocal() produces a Session
# (sessionmaker is Generic in SQLAlchemy 2.0, which the psycopg driver needs).
SessionLocal: Optional[sessionmaker[Session]] = None
|
|
|
|
|
|
def build_db_url(
    user: str,
    password: str,
    host: str,
    port: int | str,
    dbname: str,
) -> URL:
    """
    Assemble a SQLAlchemy connection URL for PostgreSQL via the psycopg driver.

    :param user: Database user name.
    :param password: Password for ``user``.
    :param host: Database server host name or address.
    :param port: Server port; string values are converted with ``int()``,
        so a non-numeric string raises ``ValueError``.
    :param dbname: Name of the database to connect to.
    :return: A ``URL`` whose driver name is ``postgresql+psycopg``.
    """
    return URL.create(
        drivername="postgresql+psycopg",
        username=user,
        password=password,
        host=host,
        port=int(port),
        database=dbname,
    )
|
|
|
|
|
|
def init_engine(
    db_config: DbConfig | None = None,
    url: str | URL | None = None,
) -> Engine:
    """
    Create the module-level SQLAlchemy engine and session factory on first use.

    The connection target is chosen in this order: an explicit ``url``, then
    ``db_config`` (converted with ``build_db_url``), then the ``DATABASE_URL``
    environment variable. Once an engine exists, later calls return it as-is
    and ignore their arguments.

    :param db_config: Object exposing ``user``, ``password``, ``host``,
        ``port`` and ``dbname`` attributes.
    :param url: Complete database URL; takes precedence over ``db_config``.
    :return: The module-level engine.
    :raises RuntimeError: If neither ``url`` nor ``db_config`` is given and
        ``DATABASE_URL`` is unset or empty.
    """
    global engine, SessionLocal

    # Only the first call builds anything; subsequent calls are no-ops.
    if engine is not None:
        return engine

    target = url
    if target is None and db_config is not None:
        target = build_db_url(
            db_config.user,
            db_config.password,
            db_config.host,
            db_config.port,
            db_config.dbname,
        )
    elif target is None:
        # No explicit URL and no config object: fall back to the environment.
        target = os.getenv("DATABASE_URL")
        if not target:
            msg = "DB config or `DATABASE_URL` must be provided"
            raise RuntimeError(msg)

    engine = create_engine(
        target,
        echo=False,
        pool_size=10,
        max_overflow=10,
        pool_timeout=30,
        pool_recycle=1800,
    )
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    log.debug("SQLAlchemy engine initialized")
    return engine
|
|
|
|
|
|
def get_engine() -> Engine:
    """
    Return the engine previously created by ``init_engine``.

    :raises RuntimeError: If ``init_engine`` has not been called yet.
    """
    if engine is not None:
        return engine
    msg = "Engine not initialized. Call init_engine(...) first."
    raise RuntimeError(msg)
|
|
|
|
|
|
def get_session() -> Generator[Session, Any, None]:
    """
    Yield a session from the module-level factory for one unit of work.

    Used as the FastAPI dependency behind ``DbSessionDependency``. After the
    consumer resumes the generator without error, the session is committed.
    If an exception is thrown back into the generator (or the commit itself
    fails), the transaction is rolled back, the error is logged at CRITICAL
    level with its traceback, and the exception is re-raised. The session is
    closed in every case.

    :raises RuntimeError: If ``init_engine`` has not been called yet.
    """
    if SessionLocal is None:
        msg = "Session factory not initialized. Call init_engine(...) first."
        raise RuntimeError(msg)
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception as e:
        db.rollback()
        # Lazy %-args instead of an f-string, and exc_info so the traceback
        # of the failure is preserved rather than just str(e).
        log.critical("error occurred: %s", e, exc_info=True)
        raise
    finally:
        db.close()
|
|
|
|
|
|
# Annotated alias for FastAPI endpoint parameters: the annotated parameter is
# filled with the Session yielded by get_session.
DbSessionDependency = Annotated[Session, Depends(dependency=get_session)]

# Context variable able to hold a Session for the current context.
# NOTE(review): nothing in this module sets or reads it — confirm where it is
# populated before relying on it.
db_session: ContextVar[Session] = ContextVar("db_session")
|