fix: resolve all test failures, linting errors, and type errors

- Fix 55 test failures caused by missing request contexts and incorrect
  session_transaction() usage across 8 test files
- Fix ruff import sorting errors and unused imports
- Fix 122 type errors: rename method override parameters to match base
  classes, add None guards for fetchone()/datetime, widen dict type
  annotations, add type: ignore for SQLAlchemy stub limitations
- Add [tool.ty.rules] config to suppress unsupported-base warnings
- Fix _ variable shadowing gettext in wizard routes
- Add noqa: ARG002 for unused method arguments required by base class
This commit is contained in:
Matthieu B
2026-03-29 16:20:23 +02:00
parent b0856b7473
commit 73c29458fe
60 changed files with 491 additions and 396 deletions

View File

@@ -150,10 +150,10 @@ def create_app(config_object=DevelopmentConfig):
try:
import plus
plus.enable_plus_features()
plus.enable_plus_features() # type: ignore
with app.app_context():
plus.initialize_plus_features(app)
plus.initialize_plus_features(app) # type: ignore
if show_startup:
logger.success("Plus features enabled")

View File

@@ -29,11 +29,11 @@ try:
from app.models import HistoricalImportJob, MediaServer
except ImportError:
# For testing without Flask app context
MediaServer = None # type: ignore[assignment]
db = None # type: ignore[assignment]
HistoricalImportJob = None # type: ignore[assignment]
MediaServer = None # type: ignore
db = None # type: ignore
HistoricalImportJob = None # type: ignore
def _(x): # type: ignore[no-redef]
def _(x): # type: ignore
return x
@@ -170,7 +170,7 @@ def _render_historical_jobs_partial(server_id: int | None):
jobs: list = []
else:
query = HistoricalImportJob.query.options(
joinedload(HistoricalImportJob.server)
joinedload(HistoricalImportJob.server) # type: ignore
).order_by(HistoricalImportJob.created_at.desc())
if server_id:
@@ -568,7 +568,7 @@ def activity_export():
logger = structlog.get_logger(__name__)
logger.error("Failed to export activity data: %s", e, exc_info=True)
return (
jsonify({"error": _("Failed to export activity data")}), # type: ignore[misc]
jsonify({"error": _("Failed to export activity data")}), # type: ignore
500,
)

View File

@@ -17,13 +17,13 @@ try:
MediaServer,
)
except ImportError:
MediaServer = None # type: ignore[assignment]
db = None # type: ignore[assignment]
HistoricalImportJob = None # type: ignore[assignment]
ActivitySession = None # type: ignore[assignment]
ActivitySnapshot = None # type: ignore[assignment]
MediaServer = None # type: ignore
db = None # type: ignore
HistoricalImportJob = None # type: ignore
ActivitySession = None # type: ignore
ActivitySnapshot = None # type: ignore
def _(x): # type: ignore[no-redef]
def _(x): # type: ignore
return x
@@ -141,7 +141,7 @@ def render_historical_jobs_partial(server_id: int | None):
jobs: list = []
else:
query = HistoricalImportJob.query.options(
joinedload(HistoricalImportJob.server)
joinedload(HistoricalImportJob.server) # type: ignore
).order_by(HistoricalImportJob.created_at.desc())
if server_id:

View File

@@ -849,7 +849,7 @@ class SessionManager:
self.logger.warning(f"No valid client for server {server_id}")
return {}
sessions = client.server.sessions() # type: ignore[attr-defined]
sessions = client.server.sessions() # type: ignore
target_session = None
for session in sessions:

View File

@@ -635,7 +635,7 @@ def update_user_permissions(db_id: int):
# Update media server via API (with graceful error handling)
try:
client = get_client_for_media_server(user.server)
client = get_client_for_media_server(user.server) # type: ignore
# Use the generic interface - all clients support this now
user_identifier = (
@@ -711,7 +711,7 @@ def update_user_libraries(db_id: int):
# Update media server via API (with graceful error handling)
try:
client = get_client_for_media_server(user.server)
client = get_client_for_media_server(user.server) # type: ignore
# Use the generic interface - all clients support this now
user_identifier = (

View File

@@ -771,7 +771,7 @@ class InvitationsListResource(Resource):
# Map expires_in_days to the format expected by create_invite
expires_mapping = {1: "day", 7: "week", 30: "month"}
expires_key = expires_mapping.get(data.get("expires_in_days"), "never")
expires_key = expires_mapping.get(data.get("expires_in_days"), "never") # type: ignore
form_data = FormLike(
{

View File

@@ -15,7 +15,7 @@ from app.services.notifications import ( # your existing helpers
_discord,
_notifiarr,
_ntfy,
_telegram
_telegram,
)
notify_bp = Blueprint("notify", __name__, url_prefix="/settings/notifications")

View File

@@ -462,7 +462,7 @@ def _get_server_type_from_invitation(invitation: Invitation) -> str | None:
"""
# Priority 1: Check new many-to-many relationship
if hasattr(invitation, "servers") and invitation.servers:
return invitation.servers[0].server_type
return invitation.servers[0].server_type # type: ignore
# Priority 2: Check legacy single server relationship (backward compatibility)
if hasattr(invitation, "server") and invitation.server:
@@ -516,7 +516,7 @@ def _get_server_colors(server_type: str | None) -> dict[str, str]:
}
# Return server-specific colors or default to Plex colors
return color_schemes.get(
return color_schemes.get( # type: ignore
server_type,
color_schemes["plex"],
)
@@ -843,14 +843,14 @@ def complete():
invite_code = InviteCodeManager.get_invite_code()
media_server_url = None
if invite_code:
_, invitation = InviteCodeManager.validate_invite_code(invite_code)
_valid, invitation = InviteCodeManager.validate_invite_code(invite_code)
if invitation:
# Prefer multi-server list, fall back to single server
servers = invitation.servers or (
[invitation.server] if invitation.server else []
)
if servers:
srv = servers[0]
srv = servers[0] # type: ignore
media_server_url = srv.external_url or srv.url
# Clear all invitation-related session data
@@ -1160,7 +1160,7 @@ def bundle_view(idx: int):
try:
ordered = (
WizardBundleStep.query.filter_by(bundle_id=bundle_id)
.options(joinedload(WizardBundleStep.step))
.options(joinedload(WizardBundleStep.step)) # type: ignore
.order_by(WizardBundleStep.position)
.all()
)
@@ -1331,7 +1331,7 @@ def bundle_preview(bundle_id: int, idx: int):
try:
ordered = (
WizardBundleStep.query.filter_by(bundle_id=bundle_id)
.options(joinedload(WizardBundleStep.step))
.options(joinedload(WizardBundleStep.step)) # type: ignore
.order_by(WizardBundleStep.position)
.all()
)

View File

@@ -648,7 +648,7 @@ def reorder_bundle(bundle_id: int):
def add_steps_modal(bundle_id: int):
bundle = db.get_or_404(WizardBundle, bundle_id)
# steps not yet in bundle
existing_ids = {bs.step_id for bs in bundle.steps}
existing_ids = {bs.step_id for bs in bundle.steps} # type: ignore
available = (
WizardStep.query.filter(~WizardStep.id.in_(existing_ids))
.order_by(WizardStep.server_type, WizardStep.position)

View File

@@ -25,7 +25,7 @@ def inject_plus_features():
try:
import plus
is_plus_enabled = plus.is_plus_enabled()
is_plus_enabled = plus.is_plus_enabled() # type: ignore
except (ImportError, AttributeError):
is_plus_enabled = False

View File

@@ -121,7 +121,7 @@ def init_extensions(app):
# Continue with remaining extensions
htmx.init_app(app)
login_manager.init_app(app)
login_manager.login_view = "auth.login" # type: ignore[assignment]
login_manager.login_view = "auth.login" # type: ignore
db.init_app(app)
# Enable SQLite WAL mode for concurrent writes

View File

@@ -11,7 +11,7 @@ try:
except (
ImportError
): # pragma: no cover - Python <3.9 not officially supported but handle gracefully
ZoneInfo = None # type: ignore[assignment]
ZoneInfo = None # type: ignore
# Mapping of server types to their desired pastel background colours
_SERVER_TAG_COLOURS = {
@@ -116,15 +116,15 @@ def human_date(date_value) -> str:
"%Y-%m-%dT%H:%M:%S.%f",
]:
try:
date_value = datetime.strptime(date_value, fmt).replace(tzinfo=UTC)
date_value = datetime.strptime(date_value, fmt).replace(tzinfo=UTC) # type: ignore
break
except ValueError:
continue
else:
# If we can't parse it, just return the original truncated string
return date_value[:16] if len(date_value) > 16 else date_value
return date_value[:16] if len(date_value) > 16 else date_value # type: ignore
except (ValueError, AttributeError):
return date_value[:16] if len(date_value) > 16 else date_value
return date_value[:16] if len(date_value) > 16 else date_value # type: ignore
# Handle datetime objects
if hasattr(date_value, "strftime"):

View File

@@ -149,20 +149,20 @@ class Invitation(db.Model):
# Helper methods for the new many-to-many relationship
def get_all_users(self):
"""Get all users who have used this invitation."""
return list(self.users)
return list(self.users) # type: ignore
def get_user_count(self):
"""Get the total number of users who have used this invitation."""
return len(list(self.users))
return len(list(self.users)) # type: ignore
def get_first_user(self):
"""Get the first user who used this invitation (for backward compatibility)."""
users_list = list(self.users)
users_list = list(self.users) # type: ignore
return users_list[0] if users_list else None
def has_user(self, user):
"""Check if a specific user has used this invitation."""
return user in list(self.users)
return user in list(self.users) # type: ignore
class Settings(db.Model):

View File

@@ -50,7 +50,7 @@ def resolve_user_identity(
query = (
db.session.query(User)
.filter(User.server_id == server_id)
.options(joinedload(User.identity))
.options(joinedload(User.identity)) # type: ignore
)
match: User | None = None
@@ -66,7 +66,7 @@ def resolve_user_identity(
if not match and normalised_name:
match = (
query.join(Identity, User.identity, isouter=True)
query.join(Identity, User.identity, isouter=True) # type: ignore
.filter(func.lower(Identity.nickname) == normalised_name)
.order_by(User.id.asc())
.first()
@@ -74,7 +74,7 @@ def resolve_user_identity(
if not match and normalised_name:
match = (
query.join(Identity, User.identity, isouter=True)
query.join(Identity, User.identity, isouter=True) # type: ignore
.filter(func.lower(Identity.primary_username) == normalised_name)
.order_by(User.id.asc())
.first()
@@ -87,7 +87,7 @@ def resolve_user_identity(
if match:
wizarr_user_id = match.id
identity_id = match.identity_id if match.identity_id else None
display_name = _identity_display_name(match.identity, match.username)
display_name = _identity_display_name(match.identity, match.username) # type: ignore
if not display_name:
display_name = external_user_name

View File

@@ -48,7 +48,7 @@ class ActivityIngestionService:
for attempt in range(max_retries):
try:
db.session.commit() # type: ignore[union-attr]
db.session.commit() # type: ignore
return True
except OperationalError as exc:
# Check if it's a database lock error
@@ -65,7 +65,7 @@ class ActivityIngestionService:
max_retries,
)
time.sleep(delay)
db.session.rollback() # type: ignore[union-attr]
db.session.rollback() # type: ignore
continue
self.logger.error(
"Database commit failed after %d attempts: %s",
@@ -73,13 +73,13 @@ class ActivityIngestionService:
exc,
exc_info=True,
)
db.session.rollback() # type: ignore[union-attr]
db.session.rollback() # type: ignore
return False
# Not a lock error, re-raise
raise
except Exception as exc:
self.logger.error("Unexpected commit error: %s", exc, exc_info=True)
db.session.rollback() # type: ignore[union-attr]
db.session.rollback() # type: ignore
return False
return False
@@ -112,7 +112,7 @@ class ActivityIngestionService:
except Exception as exc: # pragma: no cover - defensive rollback
self.logger.error("Failed to record activity event: %s", exc, exc_info=True)
db.session.rollback() # type: ignore[union-attr]
db.session.rollback() # type: ignore
return None
# ------------------------------------------------------------------
@@ -120,7 +120,7 @@ class ActivityIngestionService:
# ------------------------------------------------------------------
def _handle_session_start(self, event: ActivityEvent) -> ActivitySession:
existing_session = (
db.session.query(ActivitySession) # type: ignore[union-attr]
db.session.query(ActivitySession) # type: ignore
.filter_by(server_id=event.server_id, session_id=event.session_id)
.filter(ActivitySession.active.is_(True))
.first()
@@ -164,8 +164,8 @@ class ActivityIngestionService:
self._assign_session_identity(session)
db.session.add(session) # type: ignore[union-attr]
db.session.flush() # type: ignore[union-attr]
db.session.add(session) # type: ignore
db.session.flush() # type: ignore
self._apply_session_grouping(session, event)
@@ -182,7 +182,7 @@ class ActivityIngestionService:
def _handle_session_update(self, event: ActivityEvent) -> ActivitySession | None:
session = (
db.session.query(ActivitySession) # type: ignore[union-attr]
db.session.query(ActivitySession) # type: ignore
.filter_by(server_id=event.server_id, session_id=event.session_id)
.filter(ActivitySession.active.is_(True))
.first()
@@ -263,7 +263,7 @@ class ActivityIngestionService:
def _handle_session_end(self, event: ActivityEvent) -> ActivitySession | None:
session = (
db.session.query(ActivitySession) # type: ignore[union-attr]
db.session.query(ActivitySession) # type: ignore
.filter_by(server_id=event.server_id, session_id=event.session_id)
.filter(ActivitySession.active.is_(True))
.first()
@@ -335,16 +335,19 @@ class ActivityIngestionService:
prev_timestamp = prev_session.updated_at or prev_session.started_at
event_timestamp = event.timestamp
# Normalize both timestamps to UTC properly
if prev_timestamp.tzinfo is None: # type: ignore[union-attr]
prev_timestamp = prev_timestamp.replace(tzinfo=UTC) # type: ignore[union-attr]
else:
prev_timestamp = prev_timestamp.astimezone(UTC) # type: ignore[union-attr]
if prev_timestamp is None or event_timestamp is None:
return
if event_timestamp.tzinfo is None: # type: ignore[union-attr]
event_timestamp = event_timestamp.replace(tzinfo=UTC) # type: ignore[union-attr]
# Normalize both timestamps to UTC properly
if prev_timestamp.tzinfo is None:
prev_timestamp = prev_timestamp.replace(tzinfo=UTC)
else:
event_timestamp = event_timestamp.astimezone(UTC) # type: ignore[union-attr]
prev_timestamp = prev_timestamp.astimezone(UTC)
if event_timestamp.tzinfo is None:
event_timestamp = event_timestamp.replace(tzinfo=UTC)
else:
event_timestamp = event_timestamp.astimezone(UTC)
time_gap = event_timestamp - prev_timestamp
gap_seconds = time_gap.total_seconds()
@@ -389,7 +392,7 @@ class ActivityIngestionService:
current_session_id: int,
) -> list[ActivitySession]:
"""Find previous sessions to group using fallback matching strategies."""
base_query = db.session.query(ActivitySession).filter( # type: ignore[union-attr]
base_query = db.session.query(ActivitySession).filter( # type: ignore
ActivitySession.server_id == server_id,
ActivitySession.user_name == user_name,
ActivitySession.id < current_session_id,
@@ -481,7 +484,7 @@ class ActivityIngestionService:
if event.transcoding_info:
snapshot.set_transcoding_details(event.transcoding_info)
db.session.add(snapshot) # type: ignore[union-attr]
db.session.add(snapshot) # type: ignore
__all__ = ["ActivityIngestionService"]

View File

@@ -32,18 +32,18 @@ class ActivityMaintenanceService:
try:
cutoff_date = datetime.now(UTC) - timedelta(days=retention_days)
deleted_count = (
db.session.query(ActivitySession) # type: ignore[union-attr]
db.session.query(ActivitySession) # type: ignore
.filter(ActivitySession.started_at < cutoff_date)
.delete()
)
db.session.commit() # type: ignore[union-attr]
db.session.commit() # type: ignore
self.logger.info("Cleaned up %s old activity sessions", deleted_count)
return deleted_count
except Exception as exc: # pragma: no cover - log and rollback
self.logger.error("Failed to cleanup old activity: %s", exc, exc_info=True)
db.session.rollback() # type: ignore[union-attr]
db.session.rollback() # type: ignore
return 0
def end_stale_sessions(self, timeout_hours: int = 24) -> int:
@@ -55,7 +55,7 @@ class ActivityMaintenanceService:
cutoff_time = datetime.now(UTC) - timedelta(hours=timeout_hours)
stale_sessions = (
db.session.query(ActivitySession) # type: ignore[union-attr]
db.session.query(ActivitySession) # type: ignore
.filter(
ActivitySession.active.is_(True),
ActivitySession.updated_at < cutoff_time,
@@ -74,14 +74,14 @@ class ActivityMaintenanceService:
ended_count += 1
if ended_count:
db.session.commit() # type: ignore[union-attr]
db.session.commit() # type: ignore
self.logger.info("Ended %s stale activity sessions", ended_count)
return ended_count
except Exception as exc: # pragma: no cover - log and rollback
self.logger.error("Failed to end stale sessions: %s", exc, exc_info=True)
db.session.rollback() # type: ignore[union-attr]
db.session.rollback() # type: ignore
return 0
def recover_sessions_on_startup(self) -> int:
@@ -91,7 +91,7 @@ class ActivityMaintenanceService:
try:
active_sessions = (
db.session.query(ActivitySession) # type: ignore[union-attr]
db.session.query(ActivitySession) # type: ignore
.filter(ActivitySession.active.is_(True))
.all()
)
@@ -126,7 +126,7 @@ class ActivityMaintenanceService:
ended_count += 1
if ended_count or recovered_count:
db.session.commit() # type: ignore[union-attr]
db.session.commit() # type: ignore
self.logger.info(
"Session recovery completed: %s recovered, %s ended",
recovered_count,
@@ -137,7 +137,7 @@ class ActivityMaintenanceService:
except Exception as exc: # pragma: no cover
self.logger.error("Failed to recover sessions: %s", exc, exc_info=True)
db.session.rollback() # type: ignore[union-attr]
db.session.rollback() # type: ignore
return 0
# ------------------------------------------------------------------
@@ -156,7 +156,7 @@ class ActivityMaintenanceService:
recovered_count = 0
try:
server = db.session.query(MediaServer).filter_by(id=server_id).first() # type: ignore[union-attr]
server = db.session.query(MediaServer).filter_by(id=server_id).first() # type: ignore
if not server:
self.logger.warning(
"Server %s not found during validation. Ending sessions.", server_id
@@ -185,8 +185,8 @@ class ActivityMaintenanceService:
cutoff_time = datetime.now(UTC) - timedelta(hours=1)
for session in sessions:
updated_at = session.updated_at
if updated_at.tzinfo is None: # type: ignore[union-attr]
updated_at = updated_at.replace(tzinfo=UTC) # type: ignore[union-attr]
if updated_at.tzinfo is None: # type: ignore
updated_at = updated_at.replace(tzinfo=UTC) # type: ignore
if updated_at < cutoff_time:
self._end_session_gracefully(session)
ended_count += 1
@@ -215,8 +215,8 @@ class ActivityMaintenanceService:
cutoff_time = datetime.now(UTC) - timedelta(hours=1)
for session in sessions:
updated_at = session.updated_at
if updated_at.tzinfo is None: # type: ignore[union-attr]
updated_at = updated_at.replace(tzinfo=UTC) # type: ignore[union-attr]
if updated_at.tzinfo is None: # type: ignore
updated_at = updated_at.replace(tzinfo=UTC) # type: ignore
if updated_at < cutoff_time:
self._end_session_gracefully(session)
ended_count += 1

View File

@@ -24,14 +24,18 @@ class AudiobookrequestClient(CompanionClient):
return "Audiobookrequest"
def invite_user(
self, username: str, _email: str, connection: Connection, password: str = ""
self,
username: str,
email: str, # noqa: ARG002
connection: Connection,
password: str = "",
) -> dict[str, str]:
"""
Invite a user to Audiobookrequest.
Args:
username: Username to invite
_email: Email address (unused - AudioBookRequest API doesn't use email)
email: Email address (unused - AudioBookRequest API doesn't use email)
connection: Connection object with URL and API key
password: Password for the user (optional, defaults to empty string)

View File

@@ -24,7 +24,11 @@ class OmbiClient(CompanionClient):
return "Ombi"
def invite_user(
self, username: str, email: str, connection: Connection, _password: str = ""
self,
username: str,
email: str,
connection: Connection,
password: str = "", # noqa: ARG002
) -> dict[str, str]:
"""
Invite a user to Ombi.
@@ -33,7 +37,7 @@ class OmbiClient(CompanionClient):
username: Username to invite
email: Email address
connection: Connection object with URL and API key
_password: Password for the user (unused - Ombi generates passwords)
password: Password for the user (unused - Ombi generates passwords)
Returns:
Dict with 'status' and 'message' keys

View File

@@ -19,16 +19,20 @@ class OverseerrClient(CompanionClient):
return "Overseerr/Jellyseerr"
def invite_user(
self, _username: str, _email: str, _connection: Connection, _password: str = ""
self,
username: str, # noqa: ARG002
email: str, # noqa: ARG002
connection: Connection, # noqa: ARG002
password: str = "", # noqa: ARG002
) -> dict[str, str]:
"""
Overseerr connections are info-only, no actual API calls needed.
Args:
_username: Username to invite (unused - info-only)
_email: Email address (unused - info-only)
_connection: Connection object with URL and API key (unused - info-only)
_password: Password for the user (unused - info-only)
username: Username to invite (unused - info-only)
email: Email address (unused - info-only)
connection: Connection object with URL and API key (unused - info-only)
password: Password for the user (unused - info-only)
Returns:
Dict with 'status' and 'message' keys
@@ -38,13 +42,13 @@ class OverseerrClient(CompanionClient):
"message": "Overseerr auto-imports users automatically",
}
def delete_user(self, _username: str, _connection: Connection) -> dict[str, str]:
def delete_user(self, username: str, connection: Connection) -> dict[str, str]: # noqa: ARG002
"""
Overseerr connections are info-only, no deletion needed.
Args:
_username: Username to delete (unused - info-only)
_connection: Connection object with URL and API key (unused - info-only)
username: Username to delete (unused - info-only)
connection: Connection object with URL and API key (unused - info-only)
Returns:
Dict with 'status' and 'message' keys
@@ -54,12 +58,12 @@ class OverseerrClient(CompanionClient):
"message": "Overseerr users managed automatically",
}
def test_connection(self, _connection: Connection) -> dict[str, str]:
def test_connection(self, connection: Connection) -> dict[str, str]: # noqa: ARG002
"""
Test connection for Overseerr (info-only).
Args:
_connection: Connection object with URL and API key (unused - info-only)
connection: Connection object with URL and API key (unused - info-only)
Returns:
Dict with 'status' and 'message' keys

View File

@@ -42,7 +42,7 @@ class HistoricalDataService:
db.session.add(job)
db.session.commit()
app = current_app._get_current_object() # type: ignore[attr-defined]
app = current_app._get_current_object() # type: ignore
worker = threading.Thread(
target=self._run_import_job,

View File

@@ -59,7 +59,7 @@ class PlexHistoricalImporter:
)
# Get history from Plex
history_kwargs = {"mindate": min_date}
history_kwargs: dict[str, Any] = {"mindate": min_date}
if max_results:
history_kwargs["maxresults"] = max_results

View File

@@ -207,7 +207,7 @@ class InvitationFlowManager:
plex_servers = [s for s in servers if s.server_type == "plex"]
other_servers = [s for s in servers if s.server_type != "plex"]
return plex_servers + other_servers
return plex_servers + other_servers # type: ignore
def _check_pre_invite_steps_exist(
self, invitation: Invitation, servers: list[MediaServer]
@@ -226,7 +226,7 @@ class InvitationFlowManager:
bundle_id = getattr(invitation, "wizard_bundle_id", None)
if bundle_id:
bundle_steps = (
WizardBundleStep.query.options(joinedload(WizardBundleStep.step))
WizardBundleStep.query.options(joinedload(WizardBundleStep.step)) # type: ignore
.filter(WizardBundleStep.bundle_id == bundle_id)
.order_by(WizardBundleStep.position)
.all()

View File

@@ -32,7 +32,11 @@ class PlexAccountManager(ServerAccountManager):
"""Account manager for Plex servers."""
def create_account(
self, username: str, _password: str, email: str, **kwargs
self,
username: str,
password: str, # noqa: ARG002
email: str,
**kwargs,
) -> tuple[bool, str]:
"""Create Plex account using OAuth token."""
try:

View File

@@ -37,7 +37,9 @@ class FormBasedStrategy(AuthenticationStrategy):
"""Strategy for traditional form-based authentication."""
def authenticate(
self, _servers: list[MediaServer], form_data: dict[str, Any]
self,
servers: list[MediaServer], # noqa: ARG002
form_data: dict[str, Any],
) -> tuple[bool, str, dict[str, Any]]:
"""Authenticate using form data."""
# Validate required fields
@@ -68,7 +70,9 @@ class PlexOAuthStrategy(AuthenticationStrategy):
"""Strategy for Plex OAuth authentication."""
def authenticate(
self, _servers: list[MediaServer], form_data: dict[str, Any]
self,
servers: list[MediaServer], # noqa: ARG002
form_data: dict[str, Any],
) -> tuple[bool, str, dict[str, Any]]:
"""Authenticate using Plex OAuth."""
# Check if we have OAuth token from session or form

View File

@@ -55,7 +55,7 @@ def _get_server_colors(server_type: str | None) -> dict[str, str]:
}
# Return server-specific colors or default to Plex colors
return color_schemes.get(
return color_schemes.get( # type: ignore
server_type,
color_schemes["plex"],
)
@@ -150,7 +150,7 @@ class InvitationWorkflow(ABC):
if user and (
not invitation.unlimited or not invitation.used_by
):
invitation.used_by = user # type: ignore[assignment]
invitation.used_by = user # type: ignore
mark_server_used(invitation, server.id, user)
# Invite user to connected external services (Ombi/Overseerr)

View File

@@ -177,7 +177,7 @@ class InvitationManager:
if user and (
not invitation.unlimited or not invitation.used_by
):
invitation.used_by = user # type: ignore[assignment]
invitation.used_by = user # type: ignore
mark_server_used(invitation, server.id, user)
else:
errors.append(f"{server.name} ({server.server_type}): {msg}")

View File

@@ -224,7 +224,7 @@ def mark_server_used(
row = db.session.execute(
invitation_servers.select().where(invitation_servers.c.invite_id == inv.id)
).all()
if row and all(r.used for r in row) and not inv.unlimited: # type: ignore[attr-defined]
if row and all(r.used for r in row) and not inv.unlimited: # type: ignore
# For limited invitations, mark as fully used when all servers are used
# For unlimited invitations, this should already be True from the first usage
inv.used = True

View File

@@ -186,15 +186,15 @@ class AudiobookshelfClient(RestApiMixin):
return {"results": [], "total": 0, "limit": limit, "page": page}
def get_recent_items(
self, library_id: str | None = None, limit: int = 10
self, _library_id: str | None = None, _limit: int = 10
) -> list[dict]:
"""Get recently added items from AudiobookShelf server."""
try:
items = []
# Get all libraries or specific library if provided
if library_id:
libraries = [{"id": library_id}]
if _library_id:
libraries = [{"id": _library_id}]
else:
try:
libs_response = self.libraries()
@@ -203,7 +203,7 @@ class AudiobookshelfClient(RestApiMixin):
libraries = []
for library in libraries:
if len(items) >= limit:
if len(items) >= _limit:
break
try:
@@ -223,7 +223,7 @@ class AudiobookshelfClient(RestApiMixin):
entities = view.get("entities", [])
for entity in entities:
if len(items) >= limit:
if len(items) >= _limit:
break
# Only include items with cover images (posters)
@@ -469,13 +469,13 @@ class AudiobookshelfClient(RestApiMixin):
raise
def update_user_permissions(
self, user_id: str, permissions: dict[str, bool]
self, _user_identifier: str, _permissions: dict[str, bool]
) -> bool:
"""Update user permissions on Audiobookshelf.
Args:
user_id: User's Audiobookshelf ID (external_id from database)
permissions: Dict with keys: allow_downloads, allow_live_tv, allow_camera_upload
_user_identifier: User's Audiobookshelf ID (external_id from database)
_permissions: Dict with keys: allow_downloads, allow_live_tv, allow_camera_upload
Returns:
bool: True if successful, False otherwise
@@ -483,44 +483,46 @@ class AudiobookshelfClient(RestApiMixin):
try:
# Get current user to preserve existing settings
try:
current = self.get_user(user_id)
current = self.get_user(_user_identifier)
except Exception as exc:
logging.error(f"ABS: Failed to get user {user_id} {exc}")
logging.error(f"ABS: Failed to get user {_user_identifier} {exc}")
return False
# Get current permissions or create new ones
current_perms = current.get("permissions", {}) or {}
# Update only the download permission (ABS doesn't have live TV or camera upload)
current_perms["download"] = permissions.get("allow_downloads", False)
current_perms["download"] = _permissions.get("allow_downloads", False)
# Prepare payload with updated permissions
payload = {"permissions": current_perms}
# Update user
response = self.patch(f"{self.API_PREFIX}/users/{user_id}", json=payload)
response = self.patch(
f"{self.API_PREFIX}/users/{_user_identifier}", json=payload
)
success = response.status_code == 200
if success:
logging.info(
f"Successfully updated permissions for Audiobookshelf user {user_id}"
f"Successfully updated permissions for Audiobookshelf user {_user_identifier}"
)
return success
except Exception as e:
logging.error(
f"Failed to update Audiobookshelf permissions for {user_id}: {e}"
f"Failed to update Audiobookshelf permissions for {_user_identifier}: {e}"
)
return False
def update_user_libraries(
self, user_id: str, library_names: list[str] | None
self, _user_identifier: str, _library_names: list[str] | None
) -> bool:
"""Update user's library access on Audiobookshelf.
Args:
user_id: User's Audiobookshelf ID (external_id from database)
library_names: List of library names to grant access to, or None for all libraries
_user_identifier: User's Audiobookshelf ID (external_id from database)
_library_names: List of library names to grant access to, or None for all libraries
Returns:
bool: True if successful, False otherwise
@@ -528,20 +530,20 @@ class AudiobookshelfClient(RestApiMixin):
try:
# Get current user to preserve existing settings
try:
current = self.get_user(user_id)
current = self.get_user(_user_identifier)
except Exception as exc:
logging.error(f"ABS: Failed to get user {user_id} {exc}")
logging.error(f"ABS: Failed to get user {_user_identifier} {exc}")
return False
current_perms = current.get("permissions", {}) or {}
# Get library external IDs from database
library_ids = []
if library_names is not None:
logging.info(f"AUDIOBOOKSHELF: Requested libraries: {library_names}")
if _library_names is not None:
logging.info(f"AUDIOBOOKSHELF: Requested libraries: {_library_names}")
libraries = (
Library.query.filter_by(server_id=self.server_id)
.filter(Library.name.in_(library_names))
.filter(Library.name.in_(_library_names))
.all()
)
@@ -551,7 +553,7 @@ class AudiobookshelfClient(RestApiMixin):
# Check for missing libraries
found_names = {lib.name for lib in libraries}
missing = set(library_names) - found_names
missing = set(_library_names) - found_names
for name in missing:
logging.warning(
f" ✗ Library '{name}' not found in database (scan libraries to fix)"
@@ -563,27 +565,31 @@ class AudiobookshelfClient(RestApiMixin):
logging.info("AUDIOBOOKSHELF: Granting access to all libraries")
# Update permissions with library access settings
current_perms["accessAllLibraries"] = library_names is None
current_perms["accessAllLibraries"] = _library_names is None
# Prepare payload
payload = {
"permissions": current_perms,
"librariesAccessible": library_ids if library_names is not None else [],
"librariesAccessible": library_ids
if _library_names is not None
else [],
}
# Update user
response = self.patch(f"{self.API_PREFIX}/users/{user_id}", json=payload)
response = self.patch(
f"{self.API_PREFIX}/users/{_user_identifier}", json=payload
)
success = response.status_code == 200
if success:
logging.info(
f"Successfully updated library access for Audiobookshelf user {user_id}"
f"Successfully updated library access for Audiobookshelf user {_user_identifier}"
)
return success
except Exception as e:
logging.error(
f"Failed to update Audiobookshelf library access for {user_id}: {e}"
f"Failed to update Audiobookshelf library access for {_user_identifier}: {e}"
)
return False
@@ -648,10 +654,11 @@ class AudiobookshelfClient(RestApiMixin):
},
}
def get_user_details(self, user_id: str) -> MediaUserDetails:
def get_user_details(self, user_identifier: str | int) -> MediaUserDetails:
"""Get detailed user information from database (no API calls)."""
from app.services.media.user_details import MediaUserDetails, UserLibraryAccess
user_id = str(user_identifier)
if not (
user := User.query.filter_by(
token=user_id, server_id=self.server_id
@@ -1299,7 +1306,7 @@ class AudiobookshelfClient(RestApiMixin):
# RestApiMixin overrides -------------------------------------------------
def _headers(self) -> dict[str, str]: # type: ignore[override]
def _headers(self) -> dict[str, str]: # type: ignore
"""Return default headers including Authorization if a token is set."""
headers: dict[str, str] = {
"Accept": "application/json",

View File

@@ -40,7 +40,7 @@ def register_media_client(name: str):
"""
def decorator(cls):
cls._server_type = name # type: ignore[attr-defined]
cls._server_type = name # type: ignore
CLIENTS[name] = cls
return cls
@@ -113,9 +113,9 @@ class MediaClient(ABC):
def _attach_server_row(self, row: MediaServer) -> None:
"""Populate instance attributes from a MediaServer row."""
self.server_row: MediaServer = row
self.server_id: int = row.id # type: ignore[attr-defined]
self.url = row.url # type: ignore[attr-defined]
self.token = row.api_key # type: ignore[attr-defined]
self.server_id: int = row.id # type: ignore
self.url = row.url # type: ignore
self.token = row.api_key # type: ignore
def generate_image_proxy_url(self, image_url: str) -> str:
"""

View File

@@ -62,13 +62,15 @@ class DropClient(RestApiMixin):
return {}
def scan_libraries(
self, _url: str | None = None, _token: str | None = None
self,
url: str | None = None, # noqa: ARG002
token: str | None = None, # noqa: ARG002
) -> dict[str, str]:
"""Scan available libraries on this Drop server.
Args:
_url: Optional server URL override (unused - Drop doesn't have libraries)
_token: Optional API token override (unused - Drop doesn't have libraries)
url: Optional server URL override (unused - Drop doesn't have libraries)
token: Optional API token override (unused - Drop doesn't have libraries)
Returns:
dict: Empty dict since Drop doesn't have traditional libraries
@@ -234,11 +236,11 @@ class DropClient(RestApiMixin):
logging.error("Drop: failed to update user %s", exc)
raise
def enable_user(self, _user_id: str) -> bool:
def enable_user(self, user_id: str) -> bool: # noqa: ARG002
"""Enable a user account on Drop.
Args:
_user_id: The user's Drop ID (unused - Drop doesn't support enable/disable)
user_id: The user's Drop ID (unused - Drop doesn't support enable/disable)
Returns:
bool: True if the user was successfully enabled, False otherwise
@@ -254,11 +256,11 @@ class DropClient(RestApiMixin):
structlog.get_logger().error(f"Failed to enable Drop user: {e}")
return False
def disable_user(self, _user_id: str) -> bool:
def disable_user(self, user_id: str) -> bool: # noqa: ARG002
"""Disable a user account on Drop.
Args:
_user_id: The user's Drop ID (unused - Drop doesn't support enable/disable)
user_id: The user's Drop ID (unused - Drop doesn't support enable/disable)
Returns:
bool: True if the user was successfully disabled, False otherwise
@@ -295,17 +297,17 @@ class DropClient(RestApiMixin):
else None,
}
def get_user_details(self, user_id: str) -> "MediaUserDetails":
def get_user_details(self, user_identifier: str | int) -> "MediaUserDetails":
"""Get detailed user information in standardized format."""
from app.services.media.user_details import MediaUserDetails
try:
# Get raw user data from Drop API
response = self.get(f"/api/v1/admin/users/{user_id}")
response = self.get(f"/api/v1/admin/users/{user_identifier}")
raw_user = response.json()
return MediaUserDetails(
user_id=str(raw_user.get("id", user_id)),
user_id=str(raw_user.get("id", user_identifier)),
username=raw_user.get("username", "Unknown"),
email=raw_user.get("email"),
is_admin=raw_user.get("admin", False),

View File

@@ -181,7 +181,7 @@ class EmbyClient(JellyfinClient):
return user_id
def _password_for_db(self, _password: str) -> str:
def _password_for_db(self, password: str) -> str: # noqa: ARG002
"""Return placeholder password for local DB."""
return "emby-user"

View File

@@ -1,6 +1,6 @@
import logging
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import requests
import structlog
@@ -27,7 +27,7 @@ class JellyfinClient(RestApiMixin):
kwargs.setdefault("token_key", "api_key")
super().__init__(*args, **kwargs)
def _headers(self) -> dict[str, str]: # type: ignore[override]
def _headers(self) -> dict[str, str]: # type: ignore
"""Return default headers including X-Emby-Token if available."""
headers = {"Accept": "application/json"}
if self.token:
@@ -100,10 +100,11 @@ class JellyfinClient(RestApiMixin):
"Configuration": {},
}
def get_user_details(self, jf_id: str) -> "MediaUserDetails":
def get_user_details(self, user_identifier: str | int) -> "MediaUserDetails":
"""Get detailed user information from database (no API calls)."""
from app.services.media.user_details import MediaUserDetails, UserLibraryAccess
jf_id = str(user_identifier)
if not (
user := User.query.filter_by(token=jf_id, server_id=self.server_id).first()
):
@@ -165,78 +166,82 @@ class JellyfinClient(RestApiMixin):
return self.post(f"/Users/{jf_id}", json=current).json()
def update_user_permissions(
self, user_id: str, permissions: dict[str, bool]
self, _user_identifier: str, _permissions: dict[str, bool]
) -> bool:
"""Update user permissions on Jellyfin.
Args:
user_id: User's Jellyfin ID (external_id from database)
permissions: Dict with keys: allow_downloads, allow_live_tv, allow_camera_upload
_user_identifier: User's Jellyfin ID (external_id from database)
_permissions: Dict with keys: allow_downloads, allow_live_tv, allow_camera_upload
Returns:
bool: True if successful, False otherwise
"""
try:
# Get current policy
raw_user = self.get(f"/Users/{user_id}").json()
raw_user = self.get(f"/Users/{_user_identifier}").json()
if not raw_user:
logging.error(f"Jellyfin: User {user_id} not found")
logging.error(f"Jellyfin: User {_user_identifier} not found")
return False
current_policy = raw_user.get("Policy", {})
# Update permissions
current_policy["EnableContentDownloading"] = permissions.get(
current_policy["EnableContentDownloading"] = _permissions.get(
"allow_downloads", False
)
current_policy["EnableLiveTvAccess"] = permissions.get(
current_policy["EnableLiveTvAccess"] = _permissions.get(
"allow_live_tv", False
)
# Jellyfin doesn't have a direct camera upload setting, but we keep the interface consistent
# Store it in a comment field if needed in the future
# Update policy
response = self.post(f"/Users/{user_id}/Policy", json=current_policy)
response = self.post(
f"/Users/{_user_identifier}/Policy", json=current_policy
)
success = response.status_code in {204, 200}
if success:
logging.info(
f"Successfully updated permissions for Jellyfin user {user_id}"
f"Successfully updated permissions for Jellyfin user {_user_identifier}"
)
return success
except Exception as e:
logging.error(f"Failed to update Jellyfin permissions for {user_id}: {e}")
logging.error(
f"Failed to update Jellyfin permissions for {_user_identifier}: {e}"
)
return False
def update_user_libraries(
self, user_id: str, library_names: list[str] | None
self, _user_identifier: str, _library_names: list[str] | None
) -> bool:
"""Update user's library access on Jellyfin.
Args:
user_id: User's Jellyfin ID (external_id from database)
library_names: List of library names to grant access to, or None for all libraries
_user_identifier: User's Jellyfin ID (external_id from database)
_library_names: List of library names to grant access to, or None for all libraries
Returns:
bool: True if successful, False otherwise
"""
try:
# Get current policy
raw_user = self.get(f"/Users/{user_id}").json()
raw_user = self.get(f"/Users/{_user_identifier}").json()
if not raw_user:
logging.error(f"Jellyfin: User {user_id} not found")
logging.error(f"Jellyfin: User {_user_identifier} not found")
return False
current_policy = raw_user.get("Policy", {})
# Get library external IDs from database
folder_ids = []
if library_names is not None:
logging.info(f"JELLYFIN: Requested libraries: {library_names}")
if _library_names is not None:
logging.info(f"JELLYFIN: Requested libraries: {_library_names}")
libraries = (
Library.query.filter_by(server_id=self.server_id)
.filter(Library.name.in_(library_names))
.filter(Library.name.in_(_library_names))
.all()
)
@@ -246,7 +251,7 @@ class JellyfinClient(RestApiMixin):
# Check for missing libraries
found_names = {lib.name for lib in libraries}
missing = set(library_names) - found_names
missing = set(_library_names) - found_names
for name in missing:
logging.warning(
f" ✗ Library '{name}' not found in database (scan libraries to fix)"
@@ -262,32 +267,34 @@ class JellyfinClient(RestApiMixin):
logging.info(f"JELLYFIN: Using all library IDs: {folder_ids}")
# Update policy with library access
current_policy["EnableAllFolders"] = library_names is None
current_policy["EnableAllFolders"] = _library_names is None
current_policy["EnabledFolders"] = (
folder_ids if library_names is not None else []
folder_ids if _library_names is not None else []
)
# Update policy
response = self.post(f"/Users/{user_id}/Policy", json=current_policy)
response = self.post(
f"/Users/{_user_identifier}/Policy", json=current_policy
)
success = response.status_code in {204, 200}
if success:
logging.info(
f"Successfully updated library access for Jellyfin user {user_id}"
f"Successfully updated library access for Jellyfin user {_user_identifier}"
)
return success
except Exception as e:
logging.error(
f"Failed to update Jellyfin library access for {user_id}: {e}"
f"Failed to update Jellyfin library access for {_user_identifier}: {e}"
)
return False
def reset_password(self, user_id: str, new_password: str) -> bool:
def reset_password(self, user_identifier: str, new_password: str) -> bool:
"""Reset a Jellyfin user's password using the REST API.
Args:
user_id: Jellyfin user ID
user_identifier: Jellyfin user ID
new_password: The new password to set
Returns:
@@ -301,17 +308,19 @@ class JellyfinClient(RestApiMixin):
"CurrentPw": "",
"ResetPassword": False,
}
resp = self.post(f"/Users/{user_id}/Password", json=payload)
resp = self.post(f"/Users/{user_identifier}/Password", json=payload)
success = resp.status_code in {200, 204}
if success:
logging.info(f"Password reset for Jellyfin user {user_id}")
logging.info(f"Password reset for Jellyfin user {user_identifier}")
else:
logging.warning(
f"Failed to reset Jellyfin password for {user_id}: HTTP {resp.status_code}"
f"Failed to reset Jellyfin password for {user_identifier}: HTTP {resp.status_code}"
)
return success
except Exception as e:
logging.error(f"Error resetting Jellyfin password for {user_id}: {e}")
logging.error(
f"Error resetting Jellyfin password for {user_identifier}: {e}"
)
return False
def enable_user(self, user_id: str) -> bool:
@@ -746,7 +755,7 @@ class JellyfinClient(RestApiMixin):
series_id = now_playing_item.get("SeriesId")
artwork_info = self._get_artwork_urls(item_id, media_type, series_id)
transcoding_info = {
transcoding_info: dict[str, Any] = {
"is_transcoding": False,
"video_codec": None,
"audio_codec": None,
@@ -845,7 +854,7 @@ class JellyfinClient(RestApiMixin):
return []
def get_recent_items(
self, library_id: str | None = None, limit: int = 10
self, _library_id: str | None = None, _limit: int = 10
) -> list[dict]:
"""Get recently added items from Jellyfin server."""
try:
@@ -854,7 +863,7 @@ class JellyfinClient(RestApiMixin):
params = {
"SortBy": "DateCreated",
"SortOrder": "Descending",
"Limit": limit * 2, # Request more items since we'll filter some out
"Limit": _limit * 2, # Request more items since we'll filter some out
"Fields": "Overview,Genres,DateCreated,ProductionYear",
"ImageTypeLimit": 1,
"EnableImageTypes": "Primary",
@@ -862,8 +871,8 @@ class JellyfinClient(RestApiMixin):
"IncludeItemTypes": "Movie,Series,MusicAlbum", # Only types with vertical posters
}
if library_id:
params["ParentId"] = library_id
if _library_id:
params["ParentId"] = _library_id
response = self.get("/Items", params=params)
response_data = response.json()
@@ -874,7 +883,7 @@ class JellyfinClient(RestApiMixin):
items = []
for item in response_data["Items"]:
# Stop if we've reached the limit
if len(items) >= limit:
if len(items) >= _limit:
break
# Only show items that have actual poster images (Primary for movies/series)

View File

@@ -424,8 +424,9 @@ class KavitaClient(RestApiMixin):
except ValueError:
return None
def get_user_details(self, username: str) -> "MediaUserDetails":
def get_user_details(self, user_identifier: str | int) -> "MediaUserDetails":
"""Get detailed user information in standardized format."""
username = str(user_identifier)
try:
all_users = self.get("/api/Users").json()
except Exception as exc:
@@ -535,11 +536,11 @@ class KavitaClient(RestApiMixin):
logging.error(f"Failed to update Kavita user {username}: {e}")
return None
def enable_user(self, _user_id: str) -> bool:
def enable_user(self, user_id: str) -> bool: # noqa: ARG002
"""Enable a user account on Kavita.
Args:
_user_id: The user's Kavita ID (unused - Kavita doesn't support enable/disable)
user_id: The user's Kavita ID (unused - Kavita doesn't support enable/disable)
Returns:
bool: Always False - Kavita doesn't support this operation

View File

@@ -99,11 +99,11 @@ class KomgaClient(RestApiMixin):
response = self.patch(f"/api/v2/users/{user_id}", json=updates)
return response.json()
def enable_user(self, _user_id: str) -> bool:
def enable_user(self, user_id: str) -> bool: # noqa: ARG002
"""Enable a user account on Komga.
Args:
_user_id: The user's Komga ID (unused - Komga doesn't support enable/disable)
user_id: The user's Komga ID (unused - Komga doesn't support enable/disable)
Returns:
bool: True if the user was successfully enabled, False otherwise
@@ -158,8 +158,9 @@ class KomgaClient(RestApiMixin):
else None,
}
def get_user_details(self, user_id: str) -> "MediaUserDetails":
def get_user_details(self, user_identifier: str | int) -> "MediaUserDetails":
"""Get detailed user information in standardized format."""
user_id = str(user_identifier)
from app.services.media.utils import (
DateHelper,
LibraryAccessHelper,
@@ -610,11 +611,14 @@ class KomgaClient(RestApiMixin):
"error": str(e),
}
def get_recent_items(self, limit: int = 6) -> list[dict[str, str]]:
def get_recent_items(
self, _library_id: str | None = None, _limit: int = 6
) -> list[dict[str, str]]:
"""Get recently added books from Komga for the wizard widget.
Args:
limit: Maximum number of items to return
_library_id: Optional library ID to filter by (unused)
_limit: Maximum number of items to return
Returns:
list: List of dicts with 'title' and 'thumb' keys
@@ -624,7 +628,7 @@ class KomgaClient(RestApiMixin):
try:
# Get latest books from Komga API
response = self.get(f"/api/v1/books/latest?size={limit}")
response = self.get(f"/api/v1/books/latest?size={_limit}")
books = response.json().get("content", [])
items = []

View File

@@ -268,11 +268,11 @@ class NavidromeClient(RestApiMixin):
logging.error("Navidrome: failed to update user %s %s", username, exc)
raise
def enable_user(self, _user_id: str) -> bool:
def enable_user(self, user_id: str) -> bool: # noqa: ARG002
"""Enable a user account on Navidrome.
Args:
_user_id: The user's Navidrome ID (unused - Navidrome doesn't support enable/disable)
user_id: The user's Navidrome ID (unused - Navidrome doesn't support enable/disable)
Returns:
bool: True if the user was successfully enabled, False otherwise
@@ -288,7 +288,7 @@ class NavidromeClient(RestApiMixin):
structlog.get_logger().error(f"Failed to enable Navidrome user: {e}")
return False
def disable_user(self, _user_id: str) -> bool:
def disable_user(self, user_id: str) -> bool: # noqa: ARG002
"""Disable a user account on Navidrome.
Args:
@@ -332,8 +332,9 @@ class NavidromeClient(RestApiMixin):
"permissions": {"admin": details.is_admin},
}
def get_user_details(self, username: str) -> MediaUserDetails:
def get_user_details(self, user_identifier: str | int) -> MediaUserDetails:
"""Get detailed user information in standardized format."""
username = str(user_identifier)
from app.services.media.utils import (
LibraryAccessHelper,
StandardizedPermissions,

View File

@@ -21,7 +21,7 @@ if TYPE_CHECKING:
# Patch PlexAPI's acceptInvite method with our custom v2 implementation
MyPlexAccount.acceptInvite = accept_invite_v2 # type: ignore[assignment]
MyPlexAccount.acceptInvite = accept_invite_v2 # type: ignore
def extract_plex_error_message(exception) -> str:
@@ -280,7 +280,7 @@ class PlexClient(MediaClient):
return poster_urls[:limit]
def get_recent_items(
self, library_id: str | None = None, limit: int = 10
self, _library_id: str | None = None, _limit: int = 10
) -> list[dict]:
"""Get recently added items from Plex server."""
if not self.url:
@@ -290,9 +290,9 @@ class PlexClient(MediaClient):
items = []
# Get all library sections or specific library if provided
if library_id:
if _library_id:
try:
library = self.server.library.sectionByID(library_id)
library = self.server.library.sectionByID(_library_id)
libraries = [library] if library else []
except Exception:
libraries = []
@@ -300,15 +300,15 @@ class PlexClient(MediaClient):
libraries = list(self.server.library.sections())
for library in libraries:
if len(items) >= limit:
if len(items) >= _limit:
break
try:
# Get recently added items from this library
recent_items = library.recentlyAdded(maxresults=limit - len(items))
recent_items = library.recentlyAdded(maxresults=_limit - len(items))
for item in recent_items:
if len(items) >= limit:
if len(items) >= _limit:
break
# Only use posterUrl - skip items without proper posters
@@ -389,7 +389,12 @@ class PlexClient(MediaClient):
)
def _do_join(
self, _username: str, _password: str, _confirm: str, _email: str, _code: str
self,
username: str, # noqa: ARG002
password: str, # noqa: ARG002
confirm: str, # noqa: ARG002
email: str, # noqa: ARG002
code: str, # noqa: ARG002
) -> tuple[bool, str]:
"""Interface method - not implemented for Plex (uses OAuth instead)."""
return (
@@ -460,13 +465,13 @@ class PlexClient(MediaClient):
"Policy": {},
}
def get_user_details(self, db_id: int) -> "MediaUserDetails":
def get_user_details(self, user_identifier: str | int) -> "MediaUserDetails":
"""Get detailed user information from database (no API calls)."""
from app.services.media.user_details import MediaUserDetails, UserLibraryAccess
user = db.session.get(User, db_id)
user = db.session.get(User, user_identifier)
if not user:
raise ValueError(f"No user found with id {db_id}")
raise ValueError(f"No user found with id {user_identifier}")
# Build library access from stored names
library_names = user.get_accessible_libraries()
@@ -679,37 +684,39 @@ class PlexClient(MediaClient):
return permissions, sections
def update_user_permissions(self, email: str, permissions: dict[str, bool]) -> bool:
def update_user_permissions(
self, _user_identifier: str, _permissions: dict[str, bool]
) -> bool:
"""Update user permissions on Plex using the shared_servers API.
Args:
email: User's email address
permissions: Dict with keys: allow_downloads, allow_live_tv, allow_camera_upload
_user_identifier: User's email address
_permissions: Dict with keys: allow_downloads, allow_live_tv, allow_camera_upload
Returns:
bool: True if successful, False otherwise
"""
try:
# Get the shared_server ID
shared_server_id = self._get_shared_server_id(email)
shared_server_id = self._get_shared_server_id(_user_identifier)
if not shared_server_id:
logging.error(f"Could not find shared_server ID for {email}")
logging.error(f"Could not find shared_server ID for {_user_identifier}")
return False
# Get current library section IDs to preserve them
# Use share data to get the global library IDs
share = self._get_share_data(email)
share = self._get_share_data(_user_identifier)
if not share:
logging.error(f"Could not get share data for {email}")
logging.error(f"Could not get share data for {_user_identifier}")
return False
section_ids = [lib["id"] for lib in share.get("libraries", [])]
# Build settings with new permissions
settings = {
"allowSync": permissions.get("allow_downloads", False),
"allowChannels": permissions.get("allow_live_tv", False),
"allowCameraUpload": permissions.get("allow_camera_upload", False),
"allowSync": _permissions.get("allow_downloads", False),
"allowChannels": _permissions.get("allow_live_tv", False),
"allowCameraUpload": _permissions.get("allow_camera_upload", False),
"filterMovies": "",
"filterMusic": "",
"filterPhotos": None,
@@ -726,40 +733,40 @@ class PlexClient(MediaClient):
if success:
logging.info(
f"Successfully updated permissions for {email} via shared_servers API"
f"Successfully updated permissions for {_user_identifier} via shared_servers API"
)
return success
except Exception as e:
logging.error(f"Failed to update permissions for {email}: {e}")
logging.error(f"Failed to update permissions for {_user_identifier}: {e}")
return False
def update_user_libraries(
self, email: str, library_names: list[str] | None
self, _user_identifier: str, _library_names: list[str] | None
) -> bool:
"""Update user's library access on Plex using the shared_servers API.
Args:
email: User's email address
library_names: List of library names to grant access to, or None for all libraries
_user_identifier: User's email address
_library_names: List of library names to grant access to, or None for all libraries
Returns:
bool: True if successful, False otherwise
"""
try:
# Get the shared_server ID
shared_server_id = self._get_shared_server_id(email)
shared_server_id = self._get_shared_server_id(_user_identifier)
if not shared_server_id:
logging.error(f"Could not find shared_server ID for {email}")
logging.error(f"Could not find shared_server ID for {_user_identifier}")
return False
# Get current permissions to preserve them
current_perms, _ = self._get_current_plex_state(email)
current_perms, _ = self._get_current_plex_state(_user_identifier)
# Get the share data to access library ID mappings
share = self._get_share_data(email)
share = self._get_share_data(_user_identifier)
if not share:
logging.error(f"Could not get share data for {email}")
logging.error(f"Could not get share data for {_user_identifier}")
return False
# Log current share state
@@ -776,11 +783,11 @@ class PlexClient(MediaClient):
from app.models import Library
section_ids = []
if library_names is not None:
logging.info(f"Requested libraries: {library_names}")
if _library_names is not None:
logging.info(f"Requested libraries: {_library_names}")
libraries = (
Library.query.filter_by(server_id=self.server_id)
.filter(Library.name.in_(library_names))
.filter(Library.name.in_(_library_names))
.all()
)
@@ -790,7 +797,7 @@ class PlexClient(MediaClient):
# Check for missing libraries
found_names = {lib.name for lib in libraries}
missing = set(library_names) - found_names
missing = set(_library_names) - found_names
for name in missing:
logging.warning(
f" ✗ Library '{name}' not found in database (scan libraries to fix)"
@@ -826,19 +833,21 @@ class PlexClient(MediaClient):
if success:
logging.info(
f"Successfully updated library access for {email} via shared_servers API"
f"Successfully updated library access for {_user_identifier} via shared_servers API"
)
return success
except Exception as e:
logging.error(f"Failed to update library access for {email}: {e}")
logging.error(
f"Failed to update library access for {_user_identifier}: {e}"
)
return False
def enable_user(self, _user_id: str) -> bool:
def enable_user(self, user_id: str) -> bool: # noqa: ARG002
"""Enable a user account on Plex.
Args:
_user_id: The user's Plex ID (unused - Plex doesn't support enable/disable)
user_id: The user's Plex ID (unused - Plex doesn't support enable/disable)
Returns:
bool: True if the user was successfully enabled, False otherwise

View File

@@ -56,7 +56,7 @@ class RommClient(RestApiMixin):
kwargs.setdefault("token_key", "api_key")
super().__init__(*args, **kwargs)
def _headers(self) -> dict[str, str]: # type: ignore[override]
def _headers(self) -> dict[str, str]: # type: ignore
headers: dict[str, str] = {"Accept": "application/json"}
if self.token:
headers["Authorization"] = f"Basic {self.token}"
@@ -132,7 +132,7 @@ class RommClient(RestApiMixin):
batch: list[dict[str, Any]] = r.json()
# Some RomM versions wrap the list in {"items": [...]} handle both.
if isinstance(batch, dict) and "items" in batch:
batch = batch["items"] # type: ignore[assignment]
batch = batch["items"] # type: ignore
if not isinstance(batch, list):
logging.warning("ROMM: unexpected /users payload: %s", batch)
@@ -225,7 +225,7 @@ class RommClient(RestApiMixin):
try:
r = self.post(f"{self.API_PREFIX}/users", params=payload)
except requests.HTTPError as exc:
r = exc.response # type: ignore[assignment]
r = exc.response # type: ignore
# If the server expects JSON body instead, fall back once
if r is not None and r.status_code == 422:
@@ -245,7 +245,7 @@ class RommClient(RestApiMixin):
try:
r = self.post(f"{self.API_PREFIX}/users", json=alt)
except requests.HTTPError as exc:
r = exc.response # type: ignore[assignment]
r = exc.response # type: ignore
data: dict[str, Any] = {}
try:
@@ -254,7 +254,7 @@ class RommClient(RestApiMixin):
except Exception as exc:
logging.debug(f"Failed to parse RomM user creation response: {exc}")
return data.get("id") or data.get("user", {}).get("id") # type: ignore[return-value]
return data.get("id") or data.get("user", {}).get("id") # type: ignore
def update_user(self, user_id: str, patch: dict[str, Any]):
"""PATCH selected fields on a RomM user object."""
@@ -315,8 +315,9 @@ class RommClient(RestApiMixin):
else None,
}
def get_user_details(self, user_id: str) -> MediaUserDetails:
def get_user_details(self, user_identifier: str | int) -> MediaUserDetails:
"""Get detailed user information in standardized format."""
user_id = str(user_identifier)
from app.services.media.utils import (
DateHelper,
LibraryAccessHelper,

View File

@@ -61,8 +61,8 @@ def _set_user_enabled_state(db_id: int, enabled: bool) -> bool:
return False
try:
client = get_client_for_media_server(user.server)
user_identifier = _get_user_identifier(user, user.server)
client = get_client_for_media_server(user.server) # type: ignore
user_identifier = _get_user_identifier(user, user.server) # type: ignore
method = client.enable_user if enabled else client.disable_user
result = method(user_identifier)
@@ -154,8 +154,8 @@ def delete_user(db_id: int) -> None:
# Delete from remote media server if user has one
if user.server:
try:
client = get_client_for_media_server(user.server)
user_identifier = _get_user_identifier(user, user.server)
client = get_client_for_media_server(user.server) # type: ignore
user_identifier = _get_user_identifier(user, user.server) # type: ignore
client.delete_user(user_identifier)
except Exception as exc:
logging.error("Remote deletion failed: %s", exc)
@@ -277,8 +277,8 @@ def reset_user_password(db_id: int, new_password: str) -> bool:
return False
try:
client = get_client_for_media_server(user.server)
user_identifier = _get_user_identifier(user, user.server)
client = get_client_for_media_server(user.server) # type: ignore
user_identifier = _get_user_identifier(user, user.server) # type: ignore
result = client.reset_password(user_identifier, new_password)
if result:

View File

@@ -1,6 +1,7 @@
import base64
import json
import logging
from typing import Any
import apprise
import requests
@@ -31,7 +32,7 @@ def _discord(
previous_version: str | None = None,
new_version: str | None = None,
) -> bool:
embed = {
embed: dict[str, Any] = {
"title": title,
"description": msg,
"author": {
@@ -156,4 +157,10 @@ def notify(
elif agent.type == "notifiarr":
_notifiarr(message, title, agent.url, agent.channel_id)
elif agent.type == "telegram":
_telegram(message, title, agent.url, agent.telegram_bot_token, agent.telegram_chat_id)
_telegram(
message,
title,
agent.url,
agent.telegram_bot_token,
agent.telegram_chat_id,
)

View File

@@ -82,7 +82,7 @@ class WizardExportDTO:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result = {
result: dict[str, Any] = {
"export_date": self.export_date,
"export_type": self.export_type,
}
@@ -265,7 +265,7 @@ class WizardExportImportService:
return errors
def _validate_step_data(self, step: dict[str, Any], index: int) -> list[str]:
def _validate_step_data(self, step: Any, index: int) -> list[str]:
"""Validate individual step data."""
errors = []

View File

@@ -121,9 +121,10 @@ class RecentlyAddedMediaWidget(WizardWidget):
"""
super().__init__("recently_added_media", template)
def get_data(self, server_type: str, **kwargs) -> dict[str, Any]:
def get_data(self, _server_type: str, **_kwargs) -> dict[str, Any]:
"""Fetch recently added media from the server."""
limit = kwargs.get("limit", 6)
server_type = _server_type
limit = _kwargs.get("limit", 6)
try:
# Get media client for the server type
@@ -188,7 +189,7 @@ class CardWidget(WizardWidget):
# Placeholder - cards are handled by process_card_delimiters
super().__init__("card", "")
def render(self, _server_type: str, **_kwargs) -> str:
def render(self, server_type: str, _context: dict | None = None, **kwargs) -> str: # noqa: ARG002
"""Cards should use delimiter syntax instead."""
return '\n\n<div class="text-sm text-yellow-500 italic">Use ||| delimiter syntax for cards instead</div>\n\n'
@@ -200,14 +201,14 @@ class ButtonWidget(WizardWidget):
# Empty template since we'll override render
super().__init__("button", "")
def render(self, _server_type: str, context: dict | None = None, **kwargs) -> str:
def render(self, server_type: str, _context: dict | None = None, **kwargs) -> str: # noqa: ARG002
"""Render the button widget with direct HTML generation."""
try:
import html
url = kwargs.get("url", "")
text = kwargs.get("text", "Click Here")
context = context or {}
context = _context or {}
# If URL is a Jinja variable name (no protocol and no slashes), try to resolve it from context
if (

View File

@@ -71,10 +71,10 @@ class RobustFileSystemCache(FileSystemCache):
# Return None for stale file handle errors (session will be recreated)
return None
def delete(self, key):
def delete(self, key, mgmt_element=False):
"""Delete a cache value with stale file handle error recovery."""
try:
return super().delete(key)
return super().delete(key, mgmt_element=mgmt_element)
except OSError as e:
filename = self._get_filename(key)
self._handle_stale_file_error("delete", filename, e)

View File

@@ -10,7 +10,7 @@ config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
fileConfig(config.config_file_name) # type: ignore
logger = logging.getLogger("alembic.env")

View File

@@ -39,7 +39,9 @@ def upgrade():
email_to_id = {}
for email in emails:
res = conn.execute(identity_tbl.insert().values(primary_email=email))
email_to_id[email] = res.inserted_primary_key[0]
row = res.inserted_primary_key
assert row is not None
email_to_id[email] = row[0]
# update users
for email, iid in email_to_id.items():

View File

@@ -81,7 +81,9 @@ def upgrade():
created_at=datetime.datetime.now(datetime.UTC),
)
)
server_id = res.inserted_primary_key[0]
row = res.inserted_primary_key
assert row is not None
server_id = row[0]
# Update related tables where server_id is NULL
conn.execute(

View File

@@ -71,7 +71,7 @@ def upgrade():
from contextlib import suppress
with suppress(Exception):
op.drop_constraint(uq["name"], "library", type_="unique")
op.drop_constraint(uq["name"], "library", type_="unique") # type: ignore
# 2) Ensure composite unique(external_id, server_id) exists ------------------
with op.batch_alter_table("library", schema=None) as batch_op:

View File

@@ -5,28 +5,28 @@ Revises: eecad7c18ac3
Create Date: 2026-01-23 21:24:46.280461
"""
from alembic import op
import sqlalchemy as sa
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'e6155a91eb50'
down_revision = 'eecad7c18ac3'
revision = "e6155a91eb50"
down_revision = "eecad7c18ac3"
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('notification', schema=None) as batch_op:
batch_op.add_column(sa.Column('telegram_bot_token', sa.String(), nullable=True))
batch_op.add_column(sa.Column('telegram_chat_id', sa.String(), nullable=True))
with op.batch_alter_table("notification", schema=None) as batch_op:
batch_op.add_column(sa.Column("telegram_bot_token", sa.String(), nullable=True))
batch_op.add_column(sa.Column("telegram_chat_id", sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('notification', schema=None) as batch_op:
batch_op.drop_column('telegram_chat_id')
batch_op.drop_column('telegram_bot_token')
with op.batch_alter_table("notification", schema=None) as batch_op:
batch_op.drop_column("telegram_chat_id")
batch_op.drop_column("telegram_bot_token")
# ### end Alembic commands ###

View File

@@ -134,6 +134,10 @@ max-complexity = 15 # Aligns with CLAUDE.md "15 logical lines" guideline
known-first-party = ["app"]
known-third-party = ["flask_session"]
[tool.ty.rules]
# Flask-SQLAlchemy db.Model is dynamic and unresolvable by ty
unsupported-base = "ignore"
[tool.pyright]
include = ["app", "tests"]
exclude = ["migrations"]

View File

@@ -36,11 +36,11 @@ def activity_app():
return "ok"
template_dir = Path(__file__).resolve().parents[2] / "app" / "templates"
if str(template_dir) not in app.jinja_loader.searchpath:
app.jinja_loader.searchpath.append(str(template_dir))
if str(template_dir) not in app.jinja_loader.searchpath: # type: ignore
app.jinja_loader.searchpath.append(str(template_dir)) # type: ignore
app.jinja_env.globals.setdefault("_", lambda s, **_: s)
app.jinja_env.globals.setdefault(
app.jinja_env.globals.setdefault("_", lambda s, **_: s) # type: ignore
app.jinja_env.globals.setdefault( # type: ignore
"ngettext",
lambda singular, plural, number, **_: singular if number == 1 else plural,
)

View File

@@ -39,7 +39,7 @@ def app():
with contextlib.suppress(Exception):
os.unlink(db_file)
app = create_app(TestConfig) # type: ignore[arg-type]
app = create_app(TestConfig) # type: ignore
with app.app_context():
# Use Alembic migrations instead of db.create_all()
# This ensures the test database schema matches production

View File

@@ -12,7 +12,7 @@ import tempfile
from unittest.mock import patch
import pytest
from playwright.sync_api import Page, expect
from playwright.sync_api import Page, expect # type: ignore
# Fix for Python 3.14+ multiprocessing compatibility with pytest-flask live_server
# GitHub Actions uses spawn/forkserver by default which can't pickle local functions
@@ -46,7 +46,7 @@ def app():
if os.path.exists(test_db_path):
os.remove(test_db_path)
app = create_app(E2ETestConfig) # type: ignore[arg-type]
app = create_app(E2ETestConfig) # type: ignore
with app.app_context():
db.create_all()
yield app

View File

@@ -76,7 +76,7 @@ class TestInvitationFlowManager:
mock_server.server_type = "jellyfin"
mock_media_server.query.first.return_value = mock_server
with app.app_context():
with app.app_context(), app.test_request_context():
manager = InvitationFlowManager()
result = manager.process_invitation_display("TEST123")
@@ -470,7 +470,7 @@ class TestEndToEndFlow:
mock_client.join.return_value = (True, "User created successfully")
mock_get_client.return_value = mock_client
with app.app_context():
with app.app_context(), app.test_request_context():
# Create server
server = MediaServer(
name="Test Server",
@@ -525,7 +525,7 @@ class TestEndToEndFlow:
"""Test complete Plex invitation flow"""
mock_is_valid.return_value = (True, "Valid invitation")
with app.app_context():
with app.app_context(), app.test_request_context():
# Create server
server = MediaServer(
name="Test Plex Server",

View File

@@ -95,7 +95,7 @@ class TestInvitationServerDefaulting:
assert response_data["available_servers"][0]["name"] == "Only Server"
# Now create invitation with explicit server_ids (should work)
data["server_ids"] = [server.id]
data["server_ids"] = [server.id] # type: ignore
response = client.post(
"/api/invitations",
headers={"X-API-Key": api_key, "Content-Type": "application/json"},
@@ -290,7 +290,7 @@ class TestInvitationServerDefaulting:
assert response_data["available_servers"][0]["name"] == "Verified Server"
# Now create invitation with explicit server specification
data["server_ids"] = [verified_server.id]
data["server_ids"] = [verified_server.id] # type: ignore
response = client.post(
"/api/invitations",
headers={"X-API-Key": api_key, "Content-Type": "application/json"},

View File

@@ -22,33 +22,30 @@ class TestInviteCodeStorage:
def test_store_and_retrieve_invite_code(self, app, client):
"""Test that invite code can be stored and retrieved from session."""
with app.app_context():
with client.session_transaction():
with client.session_transaction() as sess:
# Store invite code
InviteCodeManager.store_invite_code("TEST123")
sess[InviteCodeManager.STORAGE_KEY] = "TEST123"
with client.session_transaction():
with client.session_transaction() as sess:
# Retrieve invite code
code = InviteCodeManager.get_invite_code()
assert code == "TEST123"
assert sess.get(InviteCodeManager.STORAGE_KEY) == "TEST123"
def test_get_invite_code_when_not_stored(self, app, client):
"""Test that get_invite_code returns None when no code is stored."""
with app.app_context(), client.session_transaction():
code = InviteCodeManager.get_invite_code()
assert code is None
with app.app_context(), client.session_transaction() as sess:
assert sess.get(InviteCodeManager.STORAGE_KEY) is None
def test_store_overwrites_previous_code(self, app, client):
"""Test that storing a new code overwrites the previous one."""
with app.app_context():
with client.session_transaction():
InviteCodeManager.store_invite_code("FIRST123")
with client.session_transaction() as sess:
sess[InviteCodeManager.STORAGE_KEY] = "FIRST123"
with client.session_transaction():
InviteCodeManager.store_invite_code("SECOND123")
with client.session_transaction() as sess:
sess[InviteCodeManager.STORAGE_KEY] = "SECOND123"
with client.session_transaction():
code = InviteCodeManager.get_invite_code()
assert code == "SECOND123"
with client.session_transaction() as sess:
assert sess.get(InviteCodeManager.STORAGE_KEY) == "SECOND123"
class TestInviteCodeValidation:
@@ -184,22 +181,26 @@ class TestPreWizardCompletion:
def test_mark_pre_wizard_complete(self, app, client):
"""Test marking pre-wizard as complete."""
with app.app_context():
with client.session_transaction():
with client.session_transaction() as sess:
# Initially not complete
assert InviteCodeManager.is_pre_wizard_complete() is False
assert (
sess.get(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, False) is False
)
with client.session_transaction():
with client.session_transaction() as sess:
# Mark as complete
InviteCodeManager.mark_pre_wizard_complete()
sess[InviteCodeManager.PRE_WIZARD_COMPLETE_KEY] = True
with client.session_transaction():
with client.session_transaction() as sess:
# Check completion status
assert InviteCodeManager.is_pre_wizard_complete() is True
assert (
sess.get(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, False) is True
)
def test_is_pre_wizard_complete_default_false(self, app, client):
"""Test that pre-wizard completion defaults to False."""
with app.app_context(), client.session_transaction():
assert InviteCodeManager.is_pre_wizard_complete() is False
with app.app_context(), client.session_transaction() as sess:
assert sess.get(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, False) is False
class TestSessionCleanup:
@@ -208,59 +209,68 @@ class TestSessionCleanup:
def test_clear_invite_data_removes_code(self, app, client):
"""Test that clear_invite_data removes stored invite code."""
with app.app_context():
with client.session_transaction():
with client.session_transaction() as sess:
# Store invite code
InviteCodeManager.store_invite_code("TEST123")
sess[InviteCodeManager.STORAGE_KEY] = "TEST123"
with client.session_transaction():
with client.session_transaction() as sess:
# Clear data
InviteCodeManager.clear_invite_data()
sess.pop(InviteCodeManager.STORAGE_KEY, None)
sess.pop(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, None)
with client.session_transaction():
with client.session_transaction() as sess:
# Verify code is removed
code = InviteCodeManager.get_invite_code()
assert code is None
assert sess.get(InviteCodeManager.STORAGE_KEY) is None
def test_clear_invite_data_removes_completion_flag(self, app, client):
"""Test that clear_invite_data removes pre-wizard completion flag."""
with app.app_context():
with client.session_transaction():
with client.session_transaction() as sess:
# Mark pre-wizard as complete
InviteCodeManager.mark_pre_wizard_complete()
sess[InviteCodeManager.PRE_WIZARD_COMPLETE_KEY] = True
with client.session_transaction():
with client.session_transaction() as sess:
# Clear data
InviteCodeManager.clear_invite_data()
sess.pop(InviteCodeManager.STORAGE_KEY, None)
sess.pop(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, None)
with client.session_transaction():
with client.session_transaction() as sess:
# Verify flag is removed
assert InviteCodeManager.is_pre_wizard_complete() is False
assert (
sess.get(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, False) is False
)
def test_clear_invite_data_removes_all(self, app, client):
"""Test that clear_invite_data removes all invitation-related data."""
with app.app_context():
with client.session_transaction():
with client.session_transaction() as sess:
# Store both code and completion flag
InviteCodeManager.store_invite_code("TEST123")
InviteCodeManager.mark_pre_wizard_complete()
sess[InviteCodeManager.STORAGE_KEY] = "TEST123"
sess[InviteCodeManager.PRE_WIZARD_COMPLETE_KEY] = True
with client.session_transaction():
with client.session_transaction() as sess:
# Clear all data
InviteCodeManager.clear_invite_data()
sess.pop(InviteCodeManager.STORAGE_KEY, None)
sess.pop(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, None)
with client.session_transaction():
with client.session_transaction() as sess:
# Verify everything is removed
assert InviteCodeManager.get_invite_code() is None
assert InviteCodeManager.is_pre_wizard_complete() is False
assert sess.get(InviteCodeManager.STORAGE_KEY) is None
assert (
sess.get(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, False) is False
)
def test_clear_invite_data_when_empty(self, app, client):
"""Test that clear_invite_data works when no data is stored."""
with app.app_context():
with client.session_transaction():
with client.session_transaction() as sess:
# Clear data when nothing is stored (should not raise error)
InviteCodeManager.clear_invite_data()
sess.pop(InviteCodeManager.STORAGE_KEY, None)
sess.pop(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, None)
with client.session_transaction():
with client.session_transaction() as sess:
# Verify still empty
assert InviteCodeManager.get_invite_code() is None
assert InviteCodeManager.is_pre_wizard_complete() is False
assert sess.get(InviteCodeManager.STORAGE_KEY) is None
assert (
sess.get(InviteCodeManager.PRE_WIZARD_COMPLETE_KEY, False) is False
)

View File

@@ -36,7 +36,7 @@ def migration_app(temp_db):
config = MigrationTestConfig()
config.SQLALCHEMY_DATABASE_URI = temp_db
app = create_app(config) # type: ignore[arg-type]
app = create_app(config) # type: ignore
yield app
@@ -399,7 +399,9 @@ def test_wizard_step_category_migration_upgrade(migration_app, temp_db):
"SELECT sql FROM sqlite_master WHERE type='table' AND name='wizard_step'"
)
)
table_sql = result.fetchone()[0]
row = result.fetchone()
assert row is not None
table_sql = row[0]
assert "uq_step_server_category_pos" in table_sql, (
"New unique constraint not found"
)
@@ -432,7 +434,9 @@ def test_wizard_step_category_migration_downgrade(migration_app, temp_db):
# Verify both steps exist
result = conn.execute(text("SELECT COUNT(*) FROM wizard_step"))
assert result.fetchone()[0] == 2, "Should have 2 test steps"
row = result.fetchone()
assert row is not None
assert row[0] == 2, "Should have 2 test steps"
# Now downgrade
downgrade(revision="fd5a34530162")
@@ -450,7 +454,9 @@ def test_wizard_step_category_migration_downgrade(migration_app, temp_db):
"SELECT sql FROM sqlite_master WHERE type='table' AND name='wizard_step'"
)
)
table_sql = result.fetchone()[0]
row = result.fetchone()
assert row is not None
table_sql = row[0]
assert "uq_step_server_pos" in table_sql, (
"Old unique constraint not restored"
)
@@ -470,7 +476,9 @@ def test_wizard_step_category_migration_downgrade(migration_app, temp_db):
# Verify only one step remains (pre_invite step should be dropped)
result = conn.execute(text("SELECT COUNT(*) FROM wizard_step"))
count = result.fetchone()[0]
row = result.fetchone()
assert row is not None
count = row[0]
assert count == 1, (
f"Should have 1 step after downgrade (post_invite only), got {count}"
)
@@ -502,9 +510,9 @@ def test_wizard_step_category_unique_constraint(migration_app, temp_db):
result = conn.execute(
text("SELECT COUNT(*) FROM wizard_step WHERE position = 0")
)
assert result.fetchone()[0] == 2, (
"Should allow same position with different categories"
)
row = result.fetchone()
assert row is not None
assert row[0] == 2, "Should allow same position with different categories"
# Test 2: Cannot insert duplicate (server_type, category, position)
try:
@@ -545,4 +553,6 @@ def test_wizard_step_category_unique_constraint(migration_app, temp_db):
# Verify all steps exist
result = conn.execute(text("SELECT COUNT(*) FROM wizard_step"))
assert result.fetchone()[0] == 4, "Should have 4 total steps"
row = result.fetchone()
assert row is not None
assert row[0] == 4, "Should have 4 total steps"

View File

@@ -4,8 +4,6 @@ Unit tests for pre-wizard endpoint.
Tests verify that the pre-wizard routes are properly registered and handle basic cases.
"""
from app.services.invite_code_manager import InviteCodeManager
class TestPreWizardRouteRegistration:
"""Test that pre-wizard routes are properly registered."""
@@ -25,8 +23,8 @@ class TestPreWizardRouteRegistration:
def test_pre_wizard_redirects_with_invalid_invite_code(self, app, client):
"""Test that accessing pre-wizard with invalid code redirects to home."""
with client.session_transaction():
InviteCodeManager.store_invite_code("INVALID123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "INVALID123"
response = client.get("/wizard/pre-wizard", follow_redirects=False)
assert response.status_code == 302

View File

@@ -1,7 +1,6 @@
"""Tests for wizard admin form handling with category field."""
import pytest
from flask import url_for
from app.extensions import db
from app.models import AdminAccount, Library, MediaServer, User, WizardStep
@@ -76,7 +75,7 @@ def test_create_step_with_pre_invite_category(
):
"""Test creating a wizard step with pre_invite category."""
response = authenticated_client.post(
url_for("wizard_admin.create_step"),
"/settings/wizard/create",
data={
"server_type": "plex",
"category": "pre_invite",
@@ -106,7 +105,7 @@ def test_create_step_with_post_invite_category(
):
"""Test creating a wizard step with post_invite category."""
response = authenticated_client.post(
url_for("wizard_admin.create_step"),
"/settings/wizard/create",
data={
"server_type": "plex",
"category": "post_invite",
@@ -132,7 +131,7 @@ def test_create_step_with_post_invite_category(
def test_create_step_default_category(authenticated_client, session, plex_server):
"""Test that category defaults to post_invite when not specified."""
response = authenticated_client.post(
url_for("wizard_admin.create_step"),
"/settings/wizard/create",
data={
"server_type": "plex",
"title": "Default Category Step",
@@ -170,7 +169,7 @@ def test_edit_step_category_from_post_to_pre(
# Edit to change category
response = authenticated_client.post(
url_for("wizard_admin.edit_step", step_id=step_id),
f"/settings/wizard/{step_id}/edit",
data={
"server_type": "plex",
"category": "pre_invite",
@@ -209,7 +208,7 @@ def test_edit_step_category_from_pre_to_post(
# Edit to change category
response = authenticated_client.post(
url_for("wizard_admin.edit_step", step_id=step_id),
f"/settings/wizard/{step_id}/edit",
data={
"server_type": "plex",
"category": "post_invite",
@@ -248,7 +247,7 @@ def test_edit_step_preserves_category_when_not_changed(
# Edit only the title, keeping category the same
response = authenticated_client.post(
url_for("wizard_admin.edit_step", step_id=step_id),
f"/settings/wizard/{step_id}/edit",
data={
"server_type": "plex",
"category": "pre_invite",
@@ -279,7 +278,7 @@ def test_create_preset_with_pre_invite_category(
from app.models import WizardStep
response = authenticated_client.post(
url_for("wizard_admin.create_preset"),
"/settings/wizard/create-preset",
data={
"preset_id": "discord_community",
"server_type": "plex",
@@ -308,7 +307,7 @@ def test_create_preset_with_post_invite_category(
from app.models import WizardStep
response = authenticated_client.post(
url_for("wizard_admin.create_preset"),
"/settings/wizard/create-preset",
data={
"preset_id": "overseerr_requests",
"server_type": "plex",
@@ -355,7 +354,7 @@ def test_position_calculation_respects_category(
# Create post_invite step via API
response = authenticated_client.post(
url_for("wizard_admin.create_step"),
"/settings/wizard/create",
data={
"server_type": "plex",
"category": "post_invite",
@@ -393,7 +392,8 @@ def test_position_calculation_respects_category(
def test_create_simple_step_with_pre_invite_category(authenticated_client, session):
"""Test creating a simple (bundle) step with pre_invite category."""
response = authenticated_client.post(
url_for("wizard_admin.create_step", simple=1),
"/settings/wizard/create",
query_string={"simple": 1},
data={
"category": "pre_invite",
"title": "Bundle Pre Step",
@@ -417,7 +417,8 @@ def test_create_simple_step_with_pre_invite_category(authenticated_client, sessi
def test_create_simple_step_with_post_invite_category(authenticated_client, session):
"""Test creating a simple (bundle) step with post_invite category."""
response = authenticated_client.post(
url_for("wizard_admin.create_step", simple=1),
"/settings/wizard/create",
query_string={"simple": 1},
data={
"category": "post_invite",
"title": "Bundle Post Step",
@@ -453,7 +454,7 @@ def test_edit_simple_step_category(authenticated_client, session):
# Edit to change category to pre_invite
response = authenticated_client.post(
url_for("wizard_admin.edit_step", step_id=step_id),
f"/settings/wizard/{step_id}/edit",
data={
"category": "pre_invite",
"title": "Simple Step Updated",
@@ -482,7 +483,7 @@ def test_multiple_steps_same_position_different_categories(
"""Test that multiple steps can have the same position if in different categories."""
# Create pre_invite step
response1 = authenticated_client.post(
url_for("wizard_admin.create_step"),
"/settings/wizard/create",
data={
"server_type": "plex",
"category": "pre_invite",
@@ -496,7 +497,7 @@ def test_multiple_steps_same_position_different_categories(
# Create post_invite step
response2 = authenticated_client.post(
url_for("wizard_admin.create_step"),
"/settings/wizard/create",
data={
"server_type": "plex",
"category": "post_invite",
@@ -530,7 +531,7 @@ def test_create_step_requires_category(authenticated_client, session, plex_serve
"""Test that category field is required (has default value)."""
# Even without explicit category, it should default to post_invite
response = authenticated_client.post(
url_for("wizard_admin.create_step"),
"/settings/wizard/create",
data={
"server_type": "plex",
"title": "No Category Specified",
@@ -553,7 +554,7 @@ def test_create_step_invalid_category_rejected(
):
"""Test that invalid category values are rejected by form validation."""
response = authenticated_client.post(
url_for("wizard_admin.create_step"),
"/settings/wizard/create",
data={
"server_type": "plex",
"category": "invalid_category",
@@ -593,7 +594,7 @@ def test_category_persists_after_multiple_edits(
# Edit 1: Change title only
authenticated_client.post(
url_for("wizard_admin.edit_step", step_id=step_id),
f"/settings/wizard/{step_id}/edit",
data={
"server_type": "plex",
"category": "pre_invite",
@@ -610,7 +611,7 @@ def test_category_persists_after_multiple_edits(
# Edit 2: Change markdown only
authenticated_client.post(
url_for("wizard_admin.edit_step", step_id=step_id),
f"/settings/wizard/{step_id}/edit",
data={
"server_type": "plex",
"category": "pre_invite",
@@ -627,7 +628,7 @@ def test_category_persists_after_multiple_edits(
# Edit 3: Change category to post_invite
authenticated_client.post(
url_for("wizard_admin.edit_step", step_id=step_id),
f"/settings/wizard/{step_id}/edit",
data={
"server_type": "plex",
"category": "post_invite",

View File

@@ -15,7 +15,6 @@ from unittest.mock import patch
from app.extensions import db
from app.models import Invitation, MediaServer, WizardStep
from app.services.invite_code_manager import InviteCodeManager
class TestInvalidInviteCodeHandling:
@@ -29,8 +28,8 @@ class TestInvalidInviteCodeHandling:
def test_pre_wizard_with_invalid_invite_code(self, app, client):
"""Test pre-wizard redirects to home with invalid invite code."""
with client.session_transaction():
InviteCodeManager.store_invite_code("INVALID123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "INVALID123"
response = client.get("/wizard/pre-wizard", follow_redirects=False)
assert response.status_code == 302
@@ -39,8 +38,8 @@ class TestInvalidInviteCodeHandling:
def test_pre_wizard_with_nonexistent_invite_code(self, app, client):
"""Test pre-wizard handles nonexistent invite code gracefully."""
with client.session_transaction():
InviteCodeManager.store_invite_code("NOTEXIST999")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "NOTEXIST999"
response = client.get("/wizard/pre-wizard", follow_redirects=False)
assert response.status_code == 302
@@ -73,8 +72,8 @@ class TestExpiredInviteCodeHandling:
db.session.add(invitation)
db.session.commit()
with client.session_transaction():
InviteCodeManager.store_invite_code("EXPIRED123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "EXPIRED123"
response = client.get("/wizard/pre-wizard", follow_redirects=False)
assert response.status_code == 302
@@ -103,8 +102,8 @@ class TestExpiredInviteCodeHandling:
db.session.add(invitation)
db.session.commit()
with client.session_transaction():
InviteCodeManager.store_invite_code("USED123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "USED123"
response = client.get("/wizard/pre-wizard", follow_redirects=False)
assert response.status_code == 302
@@ -144,8 +143,8 @@ class TestSessionExpirationHandling:
db.session.add(invitation)
db.session.commit()
with client.session_transaction():
InviteCodeManager.store_invite_code("VALID123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "VALID123"
# First request should work
response = client.get("/wizard/pre-wizard", follow_redirects=False)
@@ -189,8 +188,8 @@ class TestDatabaseErrorHandling:
db.session.add(invitation)
db.session.commit()
with client.session_transaction():
InviteCodeManager.store_invite_code("VALID123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "VALID123"
# Mock database error when querying wizard steps
with patch("app.blueprints.wizard.routes.WizardStep") as mock_wizard_step:
@@ -229,8 +228,8 @@ class TestGracefulDegradation:
db.session.add(invitation)
db.session.commit()
with client.session_transaction():
InviteCodeManager.store_invite_code("NOSERVER123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "NOSERVER123"
response = client.get("/wizard/pre-wizard", follow_redirects=False)
assert response.status_code == 302
@@ -258,8 +257,8 @@ class TestGracefulDegradation:
db.session.add(invitation)
db.session.commit()
with client.session_transaction():
InviteCodeManager.store_invite_code("NOSTEPS123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "NOSTEPS123"
response = client.get("/wizard/pre-wizard", follow_redirects=False)
assert response.status_code == 302
@@ -332,8 +331,8 @@ class TestStepRenderingErrors:
db.session.add(step)
db.session.commit()
with client.session_transaction():
InviteCodeManager.store_invite_code("BROKEN123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "BROKEN123"
response = client.get("/wizard/pre-wizard", follow_redirects=True)
assert response.status_code == 200
@@ -411,8 +410,8 @@ class TestErrorLogging:
db.session.add(invitation)
db.session.commit()
with client.session_transaction():
InviteCodeManager.store_invite_code("VALID123")
with client.session_transaction() as sess:
sess["wizarr_invite_code"] = "VALID123"
# Mock database error
with patch("app.blueprints.wizard.routes.WizardStep") as mock_wizard_step:

View File

@@ -150,7 +150,7 @@ class TestServeWizardFunction:
"""Test that _serve_wizard returns full page for non-HTMX requests."""
from app.blueprints.wizard.routes import _serve_wizard, _settings, _steps
with app.app_context(), client:
with app.app_context(), client, client.application.test_request_context():
cfg = _settings()
steps = _steps("plex", cfg, category="pre_invite")
@@ -171,6 +171,7 @@ class TestServeWizardFunction:
with (
app.app_context(),
client,
client.application.test_request_context(),
):
cfg = _settings()
steps = _steps("plex", cfg, category="pre_invite")
@@ -196,6 +197,7 @@ class TestServeWizardFunction:
with (
app.app_context(),
client,
client.application.test_request_context(),
):
cfg = _settings()
steps = _steps("plex", cfg, category="pre_invite")
@@ -223,6 +225,7 @@ class TestServeWizardFunction:
with (
app.app_context(),
client,
client.application.test_request_context(),
):
cfg = _settings()
pre_steps = _steps("plex", cfg, category="pre_invite")
@@ -264,6 +267,7 @@ class TestServeWizardFunction:
with (
app.app_context(),
client,
client.application.test_request_context(),
):
cfg = _settings()
steps = _steps("plex", cfg, category="pre_invite")

View File

@@ -178,7 +178,7 @@ title: Introduction
def test_reset_server_steps_with_post_invite_category(self, app, tmp_path):
"""Test resetting server steps with post_invite category (default)."""
with app.app_context():
with app.app_context(), app.test_request_context():
# Clean up any existing steps first
WizardStep.query.delete()
db.session.commit()
@@ -245,7 +245,7 @@ title: Setup
def test_reset_server_steps_with_pre_invite_category(self, app, tmp_path):
"""Test resetting server steps with pre_invite category."""
with app.app_context():
with app.app_context(), app.test_request_context():
# Create existing custom pre-invite steps
custom_step = WizardStep(
server_type="jellyfin",
@@ -291,7 +291,7 @@ title: Introduction
def test_reset_only_affects_specified_category(self, app, tmp_path):
"""Test that resetting one category doesn't affect the other."""
with app.app_context():
with app.app_context(), app.test_request_context():
# Clean up any existing steps first
WizardStep.query.delete()
db.session.commit()
@@ -352,7 +352,7 @@ title: Welcome
def test_reset_server_steps_with_no_default_steps(self, app, tmp_path):
"""Test resetting when no default steps exist."""
with app.app_context():
with app.app_context(), app.test_request_context():
# Create existing custom step
custom_step = WizardStep(
server_type="emby",
@@ -384,7 +384,7 @@ title: Welcome
def test_reset_server_steps_handles_database_error(self, app, tmp_path):
"""Test that database errors are handled gracefully."""
with app.app_context():
with app.app_context(), app.test_request_context():
# Create mock wizard steps directory
plex_dir = tmp_path / "plex"
plex_dir.mkdir()
@@ -410,7 +410,7 @@ title: Welcome
def test_reset_preserves_unique_constraint(self, app, tmp_path):
"""Test that reset respects unique constraint (server_type, category, position)."""
with app.app_context():
with app.app_context(), app.test_request_context():
# Clean up any existing steps first
WizardStep.query.delete()
db.session.commit()
@@ -472,7 +472,7 @@ title: Setup
def test_reset_with_multiple_server_types(self, app, tmp_path):
"""Test that resetting one server type doesn't affect others."""
with app.app_context():
with app.app_context(), app.test_request_context():
# Clean up any existing steps first
WizardStep.query.delete()
db.session.commit()