Files
MediaManager/media_manager/database/__init__.py
Marcel Hellwig 96b84d45db Adding some more new lints (#393)
Enable `UP` and `TRY` lint
2026-02-01 18:04:15 +01:00

105 lines
2.5 KiB
Python

import logging
import os
from collections.abc import Generator
from contextvars import ContextVar
from typing import Annotated
from fastapi import Depends
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from media_manager.database.config import DbConfig
# Process-wide database handles. Both stay None until init_engine() runs.
engine: Engine | None = None
SessionLocal: sessionmaker | None = None

log = logging.getLogger(__name__)

# Shared declarative base class for ORM models built on this module.
Base = declarative_base()
def build_db_url(
    user: str,
    password: str,
    host: str,
    port: int | str,
    dbname: str,
) -> URL:
    """Assemble a PostgreSQL connection URL for the ``postgresql+psycopg`` driver.

    Args:
        user: Database user name.
        password: Password for ``user``.
        host: Database server host name or address.
        port: Server port, as an int or a numeric string; coerced with ``int()``.
        dbname: Name of the database to connect to.

    Returns:
        A SQLAlchemy ``URL`` object built from the given parts.

    Raises:
        ValueError: if ``port`` cannot be converted to an int.
    """
    numeric_port = int(port)
    return URL.create(
        drivername="postgresql+psycopg",
        username=user,
        password=password,
        host=host,
        port=numeric_port,
        database=dbname,
    )
def init_engine(
    db_config: DbConfig | None = None,
    url: str | URL | None = None,
) -> Engine:
    """Create the global SQLAlchemy engine and session factory, once.

    The connection target is resolved in this order:
      1. ``url``, if given;
      2. otherwise a URL assembled from ``db_config``;
      3. otherwise the ``DATABASE_URL`` environment variable.

    If the engine already exists, it is returned unchanged and both arguments
    are ignored.

    Returns:
        The module-level ``Engine``.

    Raises:
        RuntimeError: if neither ``url`` nor ``db_config`` is supplied and
            ``DATABASE_URL`` is unset or empty.
    """
    global engine, SessionLocal

    # Idempotent: later calls reuse the existing engine and its pool.
    if engine is not None:
        return engine

    target = url
    if target is None and db_config is not None:
        target = build_db_url(
            db_config.user,
            db_config.password,
            db_config.host,
            db_config.port,
            db_config.dbname,
        )
    elif target is None:
        target = os.getenv("DATABASE_URL")
        if not target:
            msg = "DB config or `DATABASE_URL` must be provided"
            raise RuntimeError(msg)

    # Pool allows up to 10 persistent + 10 overflow connections, waits up to
    # 30 s for a free one, and recycles connections older than 1800 s.
    engine = create_engine(
        target,
        echo=False,
        pool_size=10,
        max_overflow=10,
        pool_timeout=30,
        pool_recycle=1800,
    )
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    log.debug("SQLAlchemy engine initialized")
    return engine
def get_engine() -> Engine:
    """Return the global engine created by ``init_engine``.

    Raises:
        RuntimeError: if ``init_engine`` has not been called yet.
    """
    if engine is not None:
        return engine
    msg = "Engine not initialized. Call init_engine(...) first."
    raise RuntimeError(msg)
def get_session() -> Generator[Session]:
    """Yield one database session and finish the unit of work around it.

    Used as the provider behind ``DbSessionDependency``. After the consumer
    resumes the generator normally, the session is committed. If an exception
    is thrown into the generator (or the commit itself fails), the session is
    rolled back, the error is logged with its traceback, and the exception is
    re-raised. The session is always closed.

    Raises:
        RuntimeError: if ``init_engine`` has not been called yet.
    """
    if SessionLocal is None:
        msg = "Session factory not initialized. Call init_engine(...) first."
        raise RuntimeError(msg)
    db = SessionLocal()
    try:
        yield db
        # Only reached when the consumer finished without raising.
        db.commit()
    except Exception:
        # Covers failures from the consumer and from commit() alike.
        db.rollback()
        # A descriptive message (instead of an empty string) keeps the record
        # identifiable in logs; level stays CRITICAL as before.
        log.critical(
            "Database session rolled back after an unhandled exception",
            exc_info=True,
        )
        raise
    finally:
        db.close()
# Annotated alias for FastAPI endpoint parameters: FastAPI resolves it by
# calling get_session() and injecting the yielded Session.
DbSessionDependency = Annotated[Session, Depends(get_session)]

# Context-local Session slot. It has no default, so db_session.get() raises
# LookupError until a value is set.
# NOTE(review): nothing in this module sets it; confirm where it is populated.
db_session: ContextVar[Session] = ContextVar("db_session")