Files
MediaManager/media_manager/tv/repository.py

742 lines
28 KiB
Python

from sqlalchemy import select, delete, func
from sqlalchemy.exc import (
IntegrityError,
SQLAlchemyError,
) # Keep SQLAlchemyError for broader exception handling
from sqlalchemy.orm import Session, joinedload
from media_manager.torrent.models import Torrent
from media_manager.torrent.schemas import TorrentId, Torrent as TorrentSchema
from media_manager.tv import log
from media_manager.tv.models import Season, Show, Episode, SeasonRequest, SeasonFile
from media_manager.exceptions import NotFoundError, MediaAlreadyExists
from media_manager.tv.schemas import (
Season as SeasonSchema,
SeasonId,
Show as ShowSchema,
ShowId,
Episode as EpisodeSchema, # Added EpisodeSchema import
SeasonRequest as SeasonRequestSchema,
SeasonFile as SeasonFileSchema,
SeasonNumber,
SeasonRequestId,
RichSeasonRequest as RichSeasonRequestSchema,
EpisodeId,
)
class TvRepository:
"""
Repository for managing TV shows, seasons, and episodes in the database.
Provides methods to retrieve, save, and delete shows and seasons.
"""
def __init__(self, db: Session):
self.db = db
def get_show_by_id(self, show_id: ShowId) -> ShowSchema:
"""
Retrieve a show by its ID, including seasons and episodes.
:param show_id: The ID of the show to retrieve.
:return: A Show object if found.
:raises NotFoundError: If the show with the given ID is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = (
select(Show)
.where(Show.id == show_id)
.options(joinedload(Show.seasons).joinedload(Season.episodes))
)
result = self.db.execute(stmt).unique().scalar_one_or_none()
if not result:
raise NotFoundError(f"Show with id {show_id} not found.")
return ShowSchema.model_validate(result)
except SQLAlchemyError as e:
log.error(f"Database error while retrieving show {show_id}: {e}")
raise
def get_show_by_external_id(
self, external_id: int, metadata_provider: str
) -> ShowSchema:
"""
Retrieve a show by its external ID, including nested seasons and episodes.
:param external_id: The ID of the show to retrieve.
:param metadata_provider: The metadata provider associated with the ID.
:return: A Show object if found.
:raises NotFoundError: If the show with the given external ID and provider is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = (
select(Show)
.where(Show.external_id == external_id)
.where(Show.metadata_provider == metadata_provider)
.options(joinedload(Show.seasons).joinedload(Season.episodes))
)
result = self.db.execute(stmt).unique().scalar_one_or_none()
if not result:
raise NotFoundError(
f"Show with external_id {external_id} and provider {metadata_provider} not found."
)
return ShowSchema.model_validate(result)
except SQLAlchemyError as e:
log.error(
f"Database error while retrieving show by external_id {external_id}: {e}"
)
raise
def get_shows(self) -> list[ShowSchema]:
"""
Retrieve all shows from the database.
:return: A list of Show objects.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = select(Show).options(
joinedload(Show.seasons).joinedload(Season.episodes)
) # Eager load seasons and episodes
results = self.db.execute(stmt).scalars().unique().all()
return [ShowSchema.model_validate(show) for show in results]
except SQLAlchemyError as e:
log.error(f"Database error while retrieving all shows: {e}")
raise
def get_total_downloaded_episodes_count(self) -> int:
try:
stmt = (
select(func.count()).select_from(Episode).join(Season).join(SeasonFile)
)
total_count = self.db.execute(stmt).scalar_one_or_none()
return total_count
except SQLAlchemyError as e:
log.error(
f"Database error while calculating downloaded episodes count: {e}"
)
raise e
def save_show(self, show: ShowSchema) -> ShowSchema:
"""
Save a new show or update an existing one in the database.
:param show: The Show object to save.
:return: The saved Show object.
:raises ValueError: If a show with the same primary key already exists (on insert).
:raises SQLAlchemyError: If a database error occurs.
"""
db_show = self.db.get(Show, show.id) if show.id else None
if db_show: # Update existing show
db_show.external_id = show.external_id
db_show.metadata_provider = show.metadata_provider
db_show.name = show.name
db_show.overview = show.overview
db_show.year = show.year
db_show.original_language = show.original_language
else: # Insert new show
db_show = Show(
id=show.id,
external_id=show.external_id,
metadata_provider=show.metadata_provider,
name=show.name,
overview=show.overview,
year=show.year,
ended=show.ended,
original_language=show.original_language,
seasons=[
Season(
id=season.id,
show_id=show.id,
number=season.number,
external_id=season.external_id,
name=season.name,
overview=season.overview,
episodes=[
Episode(
id=episode.id,
season_id=season.id,
number=episode.number,
external_id=episode.external_id,
title=episode.title,
)
for episode in season.episodes
],
)
for season in show.seasons
],
)
self.db.add(db_show)
try:
self.db.commit()
self.db.refresh(db_show)
return ShowSchema.model_validate(db_show)
except IntegrityError as e:
self.db.rollback()
raise MediaAlreadyExists(
f"Show with this primary key or unique constraint violation: {e.orig}"
) from e
except SQLAlchemyError as e:
self.db.rollback()
log.error(f"Database error while saving show {show.name}: {e}")
raise
def delete_show(self, show_id: ShowId) -> None:
"""
Delete a show by its ID.
:param show_id: The ID of the show to delete.
:raises NotFoundError: If the show with the given ID is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
show = self.db.get(Show, show_id)
if not show:
raise NotFoundError(f"Show with id {show_id} not found.")
self.db.delete(show)
self.db.commit()
except SQLAlchemyError as e:
self.db.rollback()
log.error(f"Database error while deleting show {show_id}: {e}")
raise
def get_season(self, season_id: SeasonId) -> SeasonSchema:
"""
Retrieve a season by its ID.
:param season_id: The ID of the season to get.
:return: A Season object.
:raises NotFoundError: If the season with the given ID is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
season = self.db.get(Season, season_id)
if not season:
raise NotFoundError(f"Season with id {season_id} not found.")
return SeasonSchema.model_validate(season)
except SQLAlchemyError as e:
log.error(f"Database error while retrieving season {season_id}: {e}")
raise
def add_season_request(
self, season_request: SeasonRequestSchema
) -> SeasonRequestSchema:
"""
Adds a Season to the SeasonRequest table, which marks it as requested.
:param season_request: The SeasonRequest object to add.
:return: The added SeasonRequest object.
:raises IntegrityError: If a similar request already exists or violates constraints.
:raises SQLAlchemyError: If a database error occurs.
"""
db_model = SeasonRequest(
id=season_request.id,
season_id=season_request.season_id,
wanted_quality=season_request.wanted_quality,
min_quality=season_request.min_quality,
requested_by_id=season_request.requested_by.id
if season_request.requested_by
else None,
authorized=season_request.authorized,
authorized_by_id=season_request.authorized_by.id
if season_request.authorized_by
else None,
)
try:
self.db.add(db_model)
self.db.commit()
self.db.refresh(db_model)
return SeasonRequestSchema.model_validate(db_model)
except IntegrityError as e:
self.db.rollback()
log.error(f"Integrity error while adding season request: {e}")
raise
except SQLAlchemyError as e:
self.db.rollback()
log.error(f"Database error while adding season request: {e}")
raise
def delete_season_request(self, season_request_id: SeasonRequestId) -> None:
"""
Removes a SeasonRequest by its ID.
:param season_request_id: The ID of the season request to delete.
:raises NotFoundError: If the season request is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = delete(SeasonRequest).where(SeasonRequest.id == season_request_id)
result = self.db.execute(stmt)
if result.rowcount == 0:
self.db.rollback()
raise NotFoundError(
f"SeasonRequest with id {season_request_id} not found."
)
self.db.commit()
except SQLAlchemyError as e:
self.db.rollback()
log.error(
f"Database error while deleting season request {season_request_id}: {e}"
)
raise
def get_season_by_number(self, season_number: int, show_id: ShowId) -> SeasonSchema:
"""
Retrieve a season by its number and show ID.
:param season_number: The number of the season.
:param show_id: The ID of the show.
:return: A Season object.
:raises NotFoundError: If the season is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = (
select(Season)
.where(Season.show_id == show_id)
.where(Season.number == season_number)
.options(joinedload(Season.episodes), joinedload(Season.show))
)
result = self.db.execute(stmt).unique().scalar_one_or_none()
if not result:
raise NotFoundError(
f"Season number {season_number} for show_id {show_id} not found."
)
return SeasonSchema.model_validate(result)
except SQLAlchemyError as e:
log.error(
f"Database error retrieving season {season_number} for show {show_id}: {e}"
)
raise
def get_season_requests(self) -> list[RichSeasonRequestSchema]:
"""
Retrieve all season requests.
:return: A list of RichSeasonRequest objects.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = select(SeasonRequest).options(
joinedload(SeasonRequest.requested_by),
joinedload(SeasonRequest.authorized_by),
joinedload(SeasonRequest.season).joinedload(Season.show),
)
results = self.db.execute(stmt).scalars().unique().all()
return [
RichSeasonRequestSchema(
id=x.id,
min_quality=x.min_quality,
wanted_quality=x.wanted_quality,
season_id=x.season_id,
show=x.season.show,
season=x.season,
requested_by=x.requested_by,
authorized_by=x.authorized_by,
authorized=x.authorized,
)
for x in results
]
except SQLAlchemyError as e:
log.error(f"Database error while retrieving season requests: {e}")
raise
def add_season_file(self, season_file: SeasonFileSchema) -> SeasonFileSchema:
"""
Adds a season file record to the database.
:param season_file: The SeasonFile object to add.
:return: The added SeasonFile object.
:raises IntegrityError: If the record violates constraints.
:raises SQLAlchemyError: If a database error occurs.
"""
db_model = SeasonFile(**season_file.model_dump())
try:
self.db.add(db_model)
self.db.commit()
self.db.refresh(db_model)
return SeasonFileSchema.model_validate(db_model)
except IntegrityError as e:
self.db.rollback()
log.error(f"Integrity error while adding season file: {e}")
raise
except SQLAlchemyError as e:
self.db.rollback()
log.error(f"Database error while adding season file: {e}")
raise
def remove_season_files_by_torrent_id(self, torrent_id: TorrentId) -> int:
"""
Removes season file records associated with a given torrent ID.
:param torrent_id: The ID of the torrent whose season files are to be removed.
:return: The number of season files removed.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = delete(SeasonFile).where(SeasonFile.torrent_id == torrent_id)
result = self.db.execute(stmt)
self.db.commit()
deleted_count = result.rowcount # rowcount is an int, not a callable
return deleted_count
except SQLAlchemyError as e:
self.db.rollback()
log.error(
f"Database error removing season files for torrent_id {torrent_id}: {e}"
)
raise
def set_show_library(self, show_id: ShowId, library: str) -> None:
"""
Sets the library for a show.
:param show_id: The ID of the show to update.
:param library: The library path to set for the show.
:raises NotFoundError: If the show with the given ID is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
show = self.db.get(Show, show_id)
if not show:
raise NotFoundError(f"Show with id {show_id} not found.")
show.library = library
self.db.commit()
except SQLAlchemyError as e:
self.db.rollback()
log.error(f"Database error setting library for show {show_id}: {e}")
raise
def get_season_files_by_season_id(
self, season_id: SeasonId
) -> list[SeasonFileSchema]:
"""
Retrieve all season files for a given season ID.
:param season_id: The ID of the season.
:return: A list of SeasonFile objects.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = select(SeasonFile).where(SeasonFile.season_id == season_id)
results = self.db.execute(stmt).scalars().all()
return [SeasonFileSchema.model_validate(sf) for sf in results]
except SQLAlchemyError as e:
log.error(
f"Database error retrieving season files for season_id {season_id}: {e}"
)
raise
def get_torrents_by_show_id(self, show_id: ShowId) -> list[TorrentSchema]:
"""
Retrieve all torrents associated with a given show ID.
:param show_id: The ID of the show.
:return: A list of Torrent objects.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = (
select(Torrent)
.distinct()
.join(SeasonFile, SeasonFile.torrent_id == Torrent.id)
.join(Season, Season.id == SeasonFile.season_id)
.where(Season.show_id == show_id)
)
results = self.db.execute(stmt).scalars().unique().all()
return [TorrentSchema.model_validate(torrent) for torrent in results]
except SQLAlchemyError as e:
log.error(f"Database error retrieving torrents for show_id {show_id}: {e}")
raise
def get_all_shows_with_torrents(self) -> list[ShowSchema]:
"""
Retrieve all shows that are associated with a torrent, ordered alphabetically by show name.
:return: A list of Show objects.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = (
select(Show)
.distinct()
.join(Season, Show.id == Season.show_id)
.join(SeasonFile, Season.id == SeasonFile.season_id)
.join(Torrent, SeasonFile.torrent_id == Torrent.id)
.options(joinedload(Show.seasons).joinedload(Season.episodes))
.order_by(Show.name)
)
results = self.db.execute(stmt).scalars().unique().all()
return [ShowSchema.model_validate(show) for show in results]
except SQLAlchemyError as e:
log.error(f"Database error retrieving all shows with torrents: {e}")
raise
def get_seasons_by_torrent_id(self, torrent_id: TorrentId) -> list[SeasonNumber]:
"""
Retrieve season numbers associated with a given torrent ID.
:param torrent_id: The ID of the torrent.
:return: A list of SeasonNumber objects.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = (
select(Season.number)
.distinct()
.join(SeasonFile, Season.id == SeasonFile.season_id)
.where(SeasonFile.torrent_id == torrent_id)
)
results = self.db.execute(stmt).scalars().unique().all()
return [SeasonNumber(x) for x in results]
except SQLAlchemyError as e:
log.error(
f"Database error retrieving season numbers for torrent_id {torrent_id}: {e}"
)
raise
def get_season_request(
self, season_request_id: SeasonRequestId
) -> SeasonRequestSchema:
"""
Retrieve a season request by its ID.
:param season_request_id: The ID of the season request.
:return: A SeasonRequest object.
:raises NotFoundError: If the season request is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
request = self.db.get(SeasonRequest, season_request_id)
if not request:
log.warning(f"Season request with id {season_request_id} not found.")
raise NotFoundError(
f"Season request with id {season_request_id} not found."
)
return SeasonRequestSchema.model_validate(request)
except SQLAlchemyError as e:
log.error(
f"Database error retrieving season request {season_request_id}: {e}"
)
raise
def get_show_by_season_id(self, season_id: SeasonId) -> ShowSchema:
"""
Retrieve a show by one of its season's ID.
:param season_id: The ID of the season to retrieve the show for.
:return: A Show object.
:raises NotFoundError: If the show for the given season ID is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
try:
stmt = (
select(Show)
.join(Season, Show.id == Season.show_id)
.where(Season.id == season_id)
.options(joinedload(Show.seasons).joinedload(Season.episodes))
)
result = self.db.execute(stmt).unique().scalar_one_or_none()
if not result:
raise NotFoundError(f"Show for season_id {season_id} not found.")
return ShowSchema.model_validate(result)
except SQLAlchemyError as e:
log.error(f"Database error retrieving show by season_id {season_id}: {e}")
raise
def add_season_to_show(
self, show_id: ShowId, season_data: SeasonSchema
) -> SeasonSchema:
"""
Adds a new season and its episodes to a show.
If the season number already exists for the show, it returns the existing season.
:param show_id: The ID of the show to add the season to.
:param season_data: The SeasonSchema object for the new season.
:return: The added or existing SeasonSchema object.
:raises NotFoundError: If the show is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
db_show = self.db.get(Show, show_id)
if not db_show:
raise NotFoundError(f"Show with id {show_id} not found.")
stmt = (
select(Season)
.where(Season.show_id == show_id)
.where(Season.number == season_data.number)
)
existing_db_season = self.db.execute(stmt).scalar_one_or_none()
if existing_db_season:
return SeasonSchema.model_validate(existing_db_season)
db_season = Season(
id=season_data.id,
show_id=show_id,
number=season_data.number,
external_id=season_data.external_id,
name=season_data.name,
overview=season_data.overview,
episodes=[
Episode(
id=ep_schema.id,
# season_id will be implicitly set by SQLAlchemy relationship
number=ep_schema.number,
external_id=ep_schema.external_id,
title=ep_schema.title,
)
for ep_schema in season_data.episodes
],
)
self.db.add(db_season)
self.db.commit()
self.db.refresh(db_season)
return SeasonSchema.model_validate(db_season)
def add_episode_to_season(
self, season_id: SeasonId, episode_data: EpisodeSchema
) -> EpisodeSchema:
"""
Adds a new episode to a season.
If the episode number already exists for the season, it returns the existing episode.
:param season_id: The ID of the season to add the episode to.
:param episode_data: The EpisodeSchema object for the new episode.
:return: The added or existing EpisodeSchema object.
:raises NotFoundError: If the season is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
db_season = self.db.get(Season, season_id)
if not db_season:
raise NotFoundError(f"Season with id {season_id} not found.")
stmt = (
select(Episode)
.where(Episode.season_id == season_id)
.where(Episode.number == episode_data.number)
)
existing_db_episode = self.db.execute(stmt).scalar_one_or_none()
if existing_db_episode:
return EpisodeSchema.model_validate(existing_db_episode)
db_episode = Episode(
id=episode_data.id,
season_id=season_id,
number=episode_data.number,
external_id=episode_data.external_id,
title=episode_data.title,
)
self.db.add(db_episode)
self.db.commit()
self.db.refresh(db_episode)
return EpisodeSchema.model_validate(db_episode)
def update_show_attributes(
self,
show_id: ShowId,
name: str | None = None,
overview: str | None = None,
year: int | None = None,
ended: bool | None = None,
continuous_download: bool | None = None,
) -> ShowSchema: # Removed poster_url from params
"""
Update attributes of an existing show.
:param show_id: The ID of the show to update.
:param name: The new name for the show.
:param overview: The new overview for the show.
:param year: The new year for the show.
:param ended: The new ended status for the show.
:return: The updated ShowSchema object.
"""
db_show = self.db.get(Show, show_id)
if not db_show:
raise NotFoundError(f"Show with id {show_id} not found.")
updated = False
if name is not None and db_show.name != name:
db_show.name = name
updated = True
if overview is not None and db_show.overview != overview:
db_show.overview = overview
updated = True
if year is not None and db_show.year != year:
db_show.year = year
updated = True
if ended is not None and db_show.ended != ended:
db_show.ended = ended
updated = True
if (
continuous_download is not None
and db_show.continuous_download != continuous_download
):
db_show.continuous_download = continuous_download
updated = True
if updated:
self.db.commit()
self.db.refresh(db_show)
return ShowSchema.model_validate(db_show)
def update_season_attributes(
self, season_id: SeasonId, name: str | None = None, overview: str | None = None
) -> SeasonSchema:
"""
Update attributes of an existing season.
:param season_id: The ID of the season to update.
:param name: The new name for the season.
:param overview: The new overview for the season.
:param external_id: The new external ID for the season.
:return: The updated SeasonSchema object.
:raises NotFoundError: If the season is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
db_season = self.db.get(Season, season_id)
if not db_season:
raise NotFoundError(f"Season with id {season_id} not found.")
updated = False
if name is not None and db_season.name != name:
db_season.name = name
updated = True
if overview is not None and db_season.overview != overview:
db_season.overview = overview
updated = True
if updated:
self.db.commit()
self.db.refresh(db_season)
return SeasonSchema.model_validate(db_season)
def update_episode_attributes(
self, episode_id: EpisodeId, title: str | None = None
) -> EpisodeSchema:
"""
Update attributes of an existing episode.
:param episode_id: The ID of the episode to update.
:param title: The new title for the episode.
:param external_id: The new external ID for the episode.
:return: The updated EpisodeSchema object.
:raises NotFoundError: If the episode is not found.
:raises SQLAlchemyError: If a database error occurs.
"""
db_episode = self.db.get(Episode, episode_id)
if not db_episode:
raise NotFoundError(f"Episode with id {episode_id} not found.")
updated = False
if title is not None and db_episode.title != title:
db_episode.title = title
updated = True
if updated:
self.db.commit()
self.db.refresh(db_episode)
return EpisodeSchema.model_validate(db_episode)