Files
MediaManager/media_manager/database/__init__.py
Maximilian Dorninger a39e0d204a Ruff enable type annotations rule (#362)
This PR enables the ruff rule for return type annotations (ANN), and
adds the ty package for type checking.
2026-01-06 17:07:19 +01:00

104 lines
2.5 KiB
Python

import logging
import os
from contextvars import ContextVar
from typing import Annotated, Any, Generator, Optional
from fastapi import Depends
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from media_manager.database.config import DbConfig
# Logger used by the engine/session helpers below.
log = logging.getLogger(__name__)

# Declarative base class for ORM model definitions.
Base = declarative_base()

# Process-wide engine and session factory. Both stay None until
# init_engine() assigns them.
engine: Engine | None = None
SessionLocal: sessionmaker | None = None
def build_db_url(
    user: str,
    password: str,
    host: str,
    port: int | str,
    dbname: str,
) -> URL:
    """Assemble a ``postgresql+psycopg`` SQLAlchemy URL from its components.

    Args:
        user: Database user name.
        password: Database password.
        host: Database host name or address.
        port: Port as an int or a numeric string; coerced with ``int()``, so a
            non-numeric string raises ``ValueError``.
        dbname: Name of the database to connect to.

    Returns:
        A ``URL`` object built via ``URL.create``.
    """
    return URL.create(
        drivername="postgresql+psycopg",
        username=user,
        password=password,
        host=host,
        port=int(port),
        database=dbname,
    )
def init_engine(
    db_config: DbConfig | None = None,
    url: str | URL | None = None,
) -> Engine:
    """
    Create the module-level SQLAlchemy engine and ``SessionLocal`` factory.

    The connection target is resolved in this order:
      1. ``url``, when given, is used as-is;
      2. otherwise the fields of ``db_config`` are passed to ``build_db_url``;
      3. otherwise the ``DATABASE_URL`` environment variable is read.

    Only the first successful call creates anything. Later calls return the
    existing engine and ignore their arguments.

    Raises:
        RuntimeError: if neither ``url`` nor ``db_config`` is given and
            ``DATABASE_URL`` is unset or empty.
    """
    global engine, SessionLocal

    if engine is not None:
        # Already initialized; this call's arguments are ignored.
        return engine

    if url is None and db_config is not None:
        url = build_db_url(
            db_config.user,
            db_config.password,
            db_config.host,
            db_config.port,
            db_config.dbname,
        )
    elif url is None:
        url = os.getenv("DATABASE_URL")
        if not url:
            msg = "DB config or `DATABASE_URL` must be provided"
            raise RuntimeError(msg)

    # Pool: 10 persistent connections plus up to 10 overflow, a 30 s wait for
    # a free connection, and connections recycled after 1800 s.
    new_engine = create_engine(
        url,
        echo=False,
        pool_size=10,
        max_overflow=10,
        pool_timeout=30,
        pool_recycle=1800,
    )
    engine = new_engine
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=new_engine)
    log.debug("SQLAlchemy engine initialized")
    return new_engine
def get_engine() -> Engine:
    """Return the engine previously created by ``init_engine``.

    Raises:
        RuntimeError: if ``init_engine`` has not been called yet.
    """
    if engine is not None:
        return engine
    msg = "Engine not initialized. Call init_engine(...) first."
    raise RuntimeError(msg)
def get_session() -> Generator[Session, Any, None]:
    """
    Yield a Session from ``SessionLocal``.

    The session is committed when the consumer resumes the generator
    normally. If an ``Exception`` is raised at the ``yield`` or during the
    commit, it is rolled back, logged at CRITICAL with its traceback, and
    re-raised. The session is always closed. This generator is wired up as
    a FastAPI dependency via ``DbSessionDependency`` in this module.

    Raises:
        RuntimeError: if ``init_engine`` has not been called yet, so
            ``SessionLocal`` is still None.
    """
    if SessionLocal is None:
        msg = "Session factory not initialized. Call init_engine(...) first."
        raise RuntimeError(msg)
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception as e:
        db.rollback()
        # Lazy %-args avoid eager formatting; exc_info keeps the traceback,
        # which the previous f-string message discarded.
        log.critical("error occurred: %s", e, exc_info=True)
        raise
    finally:
        db.close()
# FastAPI dependency alias: annotating an endpoint parameter with this type
# makes FastAPI inject a Session produced by get_session().
DbSessionDependency = Annotated[Session, Depends(get_session)]

# Context variable intended to hold a Session. It is not set or read in this
# file; presumably it is used elsewhere in the project (verify).
db_session: ContextVar[Session] = ContextVar("db_session")