mirror of
https://github.com/wizarrrr/wizarr.git
synced 2025-12-23 23:59:23 -05:00
fix: update API key hashing to use UTF-8 encoding and improve bundle name handling in export
fix: add requests.exceptions import in Audiobookrequest and Ombi clients fix: ignore type errors in invitation unit tests for library and server associations fix: use timezone-aware datetime for export dates in WizardExportImportService
This commit is contained in:
@@ -60,8 +60,11 @@ def require_api_key(f):
|
||||
logger.warning("API request without API key from %s", request.remote_addr)
|
||||
abort(401, error="Unauthorized")
|
||||
|
||||
# Type assertion since we've already checked that auth_key exists
|
||||
assert isinstance(auth_key, str)
|
||||
|
||||
# Hash the provided key to compare with stored hash
|
||||
key_hash = hashlib.sha256(auth_key.encode()).hexdigest()
|
||||
key_hash = hashlib.sha256(auth_key.encode("utf-8")).hexdigest()
|
||||
api_key = ApiKey.query.filter_by(key_hash=key_hash, is_active=True).first()
|
||||
|
||||
if not api_key:
|
||||
|
||||
@@ -601,7 +601,11 @@ def export_bundle(bundle_id: int):
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
# Generate filename with bundle name and current date
|
||||
bundle_name = export_data.bundle.name.replace(" ", "_").lower()
|
||||
bundle_name = (
|
||||
export_data.bundle.name.replace(" ", "_").lower()
|
||||
if export_data.bundle
|
||||
else "unknown_bundle"
|
||||
)
|
||||
filename = f"wizard_bundle_{bundle_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
|
||||
return send_file(
|
||||
|
||||
@@ -5,6 +5,7 @@ Audiobookrequest companion client implementation.
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import requests.exceptions
|
||||
|
||||
from app.models import Connection
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ Ombi companion client implementation.
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import requests.exceptions
|
||||
|
||||
from app.models import Connection
|
||||
|
||||
|
||||
@@ -778,7 +778,7 @@ def handle_oauth_token(app, token: str, code: str) -> None:
|
||||
api_token = post_setup_client.token
|
||||
threading.Thread(
|
||||
target=_post_join_setup,
|
||||
args=(current_app._get_current_object(), server_url, api_token, token),
|
||||
args=(current_app._get_current_object(), server_url, api_token, token), # type: ignore
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -71,7 +71,7 @@ class WizardExportDTO:
|
||||
bundle: WizardBundleDTO | None = None
|
||||
export_date: str = ""
|
||||
total_count: int = 0
|
||||
server_types: list[str] = None
|
||||
server_types: list[str] = field(default_factory=list)
|
||||
export_type: str = "steps" # "steps" or "bundle"
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
@@ -132,7 +132,7 @@ class WizardExportImportService:
|
||||
|
||||
return WizardExportDTO(
|
||||
steps=step_dtos,
|
||||
export_date=datetime.utcnow().isoformat(),
|
||||
export_date=datetime.now(UTC).isoformat(),
|
||||
total_count=len(step_dtos),
|
||||
server_types=[server_type] if step_dtos else [],
|
||||
export_type="steps",
|
||||
@@ -174,7 +174,7 @@ class WizardExportImportService:
|
||||
|
||||
return WizardExportDTO(
|
||||
bundle=bundle_dto,
|
||||
export_date=datetime.utcnow().isoformat(),
|
||||
export_date=datetime.now(UTC).isoformat(),
|
||||
export_type="bundle",
|
||||
)
|
||||
|
||||
@@ -269,9 +269,11 @@ class WizardExportImportService:
|
||||
|
||||
# Check required fields
|
||||
required_fields = ["server_type", "position", "markdown"]
|
||||
for field in required_fields:
|
||||
if field not in step:
|
||||
errors.append(f"Step {index}: missing required field '{field}'")
|
||||
for required_field in required_fields:
|
||||
if required_field not in step:
|
||||
errors.append(
|
||||
f"Step {index}: missing required field '{required_field}'"
|
||||
)
|
||||
|
||||
# Validate field types
|
||||
if "server_type" in step and not isinstance(step["server_type"], str):
|
||||
|
||||
@@ -22,7 +22,7 @@ def check_expiring(app=None):
|
||||
from flask import current_app
|
||||
|
||||
try:
|
||||
app = current_app._get_current_object()
|
||||
app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
# If we're outside application context, we need the app to be passed
|
||||
logging.error(
|
||||
|
||||
@@ -62,7 +62,7 @@ mock_state = MockMediaServerState()
|
||||
class MockJellyfinClient:
|
||||
"""Mock Jellyfin client that simulates API responses."""
|
||||
|
||||
def __init__(self, url: str = None, token: str = None, **kwargs):
|
||||
def __init__(self, url: str | None = None, token: str | None = None, **kwargs):
|
||||
self.url = url or "http://localhost:8096"
|
||||
self.token = token or "mock-api-key"
|
||||
self.server_id = kwargs.get("server_id")
|
||||
@@ -264,7 +264,7 @@ class MockJellyfinClient:
|
||||
class MockPlexClient:
|
||||
"""Mock Plex client that simulates PlexAPI responses."""
|
||||
|
||||
def __init__(self, url: str = None, token: str = None, **kwargs):
|
||||
def __init__(self, url: str | None = None, token: str | None = None, **kwargs):
|
||||
self.url = url or "http://localhost:32400"
|
||||
self.token = token or "mock-plex-token"
|
||||
self.server_id = kwargs.get("server_id")
|
||||
@@ -363,7 +363,7 @@ class MockPlexClient:
|
||||
class MockAudiobookshelfClient:
|
||||
"""Mock Audiobookshelf client."""
|
||||
|
||||
def __init__(self, url: str = None, token: str = None, **kwargs):
|
||||
def __init__(self, url: str | None = None, token: str | None = None, **kwargs):
|
||||
self.url = url or "http://localhost:13378"
|
||||
self.token = token or "mock-abs-token"
|
||||
self.server_id = kwargs.get("server_id")
|
||||
|
||||
@@ -644,7 +644,7 @@ class TestInvitationMarkingUsed:
|
||||
assert invite.used_by == user
|
||||
|
||||
# Verify user is in invitation's users collection
|
||||
assert user in invite.users
|
||||
assert user in invite.users # type: ignore
|
||||
|
||||
def test_mark_server_used_multi_server_partial(self, app):
|
||||
"""Test marking one server as used in multi-server invitation."""
|
||||
@@ -691,7 +691,7 @@ class TestInvitationMarkingUsed:
|
||||
assert invite.used is False # Not all servers used yet
|
||||
|
||||
# But user should be tracked
|
||||
assert user1 in invite.users
|
||||
assert user1 in invite.users # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -161,8 +161,8 @@ class TestInvitationCreation:
|
||||
assert invite.unlimited is True
|
||||
assert invite.duration == "14"
|
||||
assert invite.expires is not None
|
||||
assert len(invite.servers) == 1
|
||||
assert invite.servers[0] == server
|
||||
assert len(invite.servers) == 1 # type: ignore
|
||||
assert invite.servers[0] == server # type: ignore
|
||||
|
||||
def test_create_invitation_with_libraries(self, app):
|
||||
"""Test creating invitation with specific libraries."""
|
||||
@@ -199,8 +199,8 @@ class TestInvitationCreation:
|
||||
invite = create_invite(form_data)
|
||||
|
||||
# Verify library associations
|
||||
assert len(invite.libraries) == 2
|
||||
library_ids = {lib.id for lib in invite.libraries}
|
||||
assert len(invite.libraries) == 2 # type: ignore
|
||||
library_ids = {lib.id for lib in invite.libraries} # type: ignore
|
||||
assert library_ids == {lib1.id, lib2.id}
|
||||
|
||||
def test_create_multi_server_invitation(self, app):
|
||||
@@ -233,8 +233,8 @@ class TestInvitationCreation:
|
||||
invite = create_invite(form_data)
|
||||
|
||||
# Verify server associations
|
||||
assert len(invite.servers) == 2
|
||||
server_ids = {server.id for server in invite.servers}
|
||||
assert len(invite.servers) == 2 # type: ignore
|
||||
server_ids = {server.id for server in invite.servers} # type: ignore
|
||||
assert server_ids == {server1.id, server2.id}
|
||||
|
||||
def test_create_invitation_validation_errors(self, app):
|
||||
@@ -456,13 +456,13 @@ class TestInvitationRelationships:
|
||||
invite = create_invite(form_data)
|
||||
|
||||
# Verify server relationships
|
||||
assert len(invite.servers) == 2
|
||||
assert server1 in invite.servers
|
||||
assert server2 in invite.servers
|
||||
assert len(invite.servers) == 2 # type: ignore
|
||||
assert server1 in invite.servers # type: ignore
|
||||
assert server2 in invite.servers # type: ignore
|
||||
|
||||
# Verify reverse relationship
|
||||
assert invite in server1.invites
|
||||
assert invite in server2.invites
|
||||
assert invite in server1.invites # type: ignore
|
||||
assert invite in server2.invites # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user