Files
MediaManager/media_manager/common/repository.py
Maximilian Dorninger ebb6cb790a add C90 linting rule
2026-05-07 16:16:50 +02:00

176 lines
5.6 KiB
Python

import logging
from typing import Any, TypeVar
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import Session
from media_manager.exceptions import ConflictError, NotFoundError
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Module-level type variable. BaseRepository declares its own type parameter
# named T in its class header, so annotations inside that class refer to the
# class-scoped parameter rather than to this object.
# NOTE(review): appears unused within this file - confirm no external importer
# relies on it before removing.
T = TypeVar("T")
# Identifier types accepted by the repository methods that look up by id.
EntityId = UUID | int | str
class BaseRepository[T, S: BaseModel]:
"""
Base repository providing common CRUD operations for media models.
"""
def __init__(self, db: Session, model: type[T], schema: type[S]) -> None:
self.db = db
self.model = model
self.schema = schema
def get_by_id(self, entity_id: EntityId) -> S:
result = self.db.get(self.model, entity_id)
if not result:
msg = f"{self.model.__name__} with id {entity_id} not found."
raise NotFoundError(msg)
return self.schema.model_validate(result)
def get_by_external_id(self, external_id: int, metadata_provider: str) -> S:
stmt = select(self.model).where(
self.model.external_id == external_id,
self.model.metadata_provider == metadata_provider,
)
result = self.db.execute(stmt).scalar_one_or_none()
if not result:
msg = f"{self.model.__name__} with external_id {external_id} and provider {metadata_provider} not found."
raise NotFoundError(msg)
return self.schema.model_validate(result)
def get_all(self) -> list[S]:
stmt = select(self.model)
results = self.db.execute(stmt).scalars().unique().all()
return [self.schema.model_validate(r) for r in results]
def delete(self, entity_id: EntityId) -> None:
obj = self.db.get(self.model, entity_id)
if not obj:
msg = f"{self.model.__name__} with id {entity_id} not found."
raise NotFoundError(msg)
self.db.delete(obj)
self.db.commit()
def set_library(self, entity_id: EntityId, library: str) -> None:
obj = self.db.get(self.model, entity_id)
if not obj:
msg = f"{self.model.__name__} with id {entity_id} not found."
raise NotFoundError(msg)
obj.library = library
self.db.commit()
def save_media_base(
self,
media_schema: S,
model_class: type[T],
exclude: set[str] | None = None,
) -> S:
"""
Generic save method for media models.
"""
if exclude is None:
exclude = set()
db_obj = self.db.get(model_class, media_schema.id) if media_schema.id else None
if db_obj:
# Update existing
# Always exclude "id" from updates
update_exclude = exclude | {"id"}
for key, value in media_schema.model_dump(exclude=update_exclude).items():
if hasattr(db_obj, key):
setattr(db_obj, key, value)
else:
# Insert new
db_obj = model_class(**media_schema.model_dump(exclude=exclude))
self.db.add(db_obj)
try:
self.db.commit()
self.db.refresh(db_obj)
except IntegrityError as e:
self.db.rollback()
msg = f"Integrity error while saving {model_class.__name__}: {e.orig}"
raise ConflictError(msg) from e
except SQLAlchemyError:
self.db.rollback()
log.exception(f"Database error while saving {model_class.__name__}")
raise
else:
return self.schema.model_validate(db_obj)
def update_media_attributes_base(
self,
media_id: EntityId,
model_class: type[T],
**attributes: Any, # noqa: ANN401
) -> S:
"""
Generic update method for media attributes.
"""
db_obj = self.db.get(model_class, media_id)
if not db_obj:
msg = f"{model_class.__name__} with id {media_id} not found."
raise NotFoundError(msg)
updated = False
for key, value in attributes.items():
if (
value is not None
and hasattr(db_obj, key)
and getattr(db_obj, key) != value
):
setattr(db_obj, key, value)
updated = True
if updated:
try:
self.db.commit()
self.db.refresh(db_obj)
except SQLAlchemyError:
self.db.rollback()
raise
return self.schema.model_validate(db_obj)
def add_media_file_base(
self, file_schema: S, model_class: type[T], schema_class: type[S]
) -> S:
"""
Generic method to add a media file record.
"""
db_model = model_class(**file_schema.model_dump())
try:
self.db.add(db_model)
self.db.commit()
self.db.refresh(db_model)
except IntegrityError:
self.db.rollback()
raise
except SQLAlchemyError:
self.db.rollback()
raise
else:
return schema_class.model_validate(db_model)
def remove_files_by_torrent_id_base(
self, torrent_id: EntityId, model_class: type[T]
) -> int:
"""
Generic method to remove media files by torrent ID.
"""
try:
stmt = delete(model_class).where(model_class.torrent_id == torrent_id)
result = self.db.execute(stmt)
self.db.commit()
except SQLAlchemyError:
self.db.rollback()
raise
else:
return result.rowcount