chore: update Ruff linting rules formatting (#2570)

This commit is contained in:
Nico Miguelino
2025-11-12 10:21:49 -08:00
committed by GitHub
parent b9c206b10a
commit 6705477e5b
58 changed files with 838 additions and 860 deletions

View File

@@ -41,6 +41,10 @@ jobs:
run: |
poetry install --only=dev-host
- name: Analyzing the code with ruff
- name: Run Ruff linting checks
run: |
poetry run ruff check .
- name: Run Ruff formatting checks
run: |
poetry run ruff format --check .

View File

@@ -19,5 +19,5 @@ class AssetAdmin(admin.ModelAdmin):
'is_active',
'nocache',
'play_order',
'skip_asset_check'
'skip_asset_check',
)

View File

@@ -20,9 +20,9 @@ def template(request, template_name, context):
context['date_format'] = settings['date_format']
context['default_duration'] = settings['default_duration']
context['default_streaming_duration'] = (
settings['default_streaming_duration']
)
context['default_streaming_duration'] = settings[
'default_streaming_duration'
]
context['template_settings'] = {
'imports': ['from lib.utils import template_handle_unicode'],
'default_filters': ['template_handle_unicode'],
@@ -40,7 +40,7 @@ def prepare_default_asset(**kwargs):
asset_id = 'default_{}'.format(uuid.uuid4().hex)
duration = (
int(get_video_duration(kwargs['uri']).total_seconds())
if "video" == kwargs['mimetype']
if 'video' == kwargs['mimetype']
else kwargs['duration']
)
@@ -56,7 +56,7 @@ def prepare_default_asset(**kwargs):
'play_order': 0,
'skip_asset_check': 0,
'start_date': kwargs['start_date'],
'uri': kwargs['uri']
'uri': kwargs['uri'],
}
@@ -67,7 +67,7 @@ def add_default_assets():
default_asset_settings = {
'start_date': datetime_now,
'end_date': datetime_now.replace(year=datetime_now.year + 6),
'duration': settings['default_duration']
'duration': settings['default_duration'],
}
default_assets_yaml = path.join(
@@ -79,11 +79,13 @@ def add_default_assets():
default_assets = yaml.safe_load(yaml_file).get('assets')
for default_asset in default_assets:
default_asset_settings.update({
'name': default_asset.get('name'),
'uri': default_asset.get('uri'),
'mimetype': default_asset.get('mimetype')
})
default_asset_settings.update(
{
'name': default_asset.get('name'),
'uri': default_asset.get('uri'),
'mimetype': default_asset.get('mimetype'),
}
)
asset = prepare_default_asset(**default_asset_settings)
if asset:

View File

@@ -10,7 +10,8 @@ def generate_asset_id():
class Asset(models.Model):
asset_id = models.TextField(
primary_key=True, default=generate_asset_id, editable=False)
primary_key=True, default=generate_asset_id, editable=False
)
name = models.TextField(blank=True, null=True)
uri = models.TextField(blank=True, null=True)
md5 = models.TextField(blank=True, null=True)
@@ -33,8 +34,6 @@ class Asset(models.Model):
def is_active(self):
if self.is_enabled and self.start_date and self.end_date:
current_time = timezone.now()
return (
self.start_date < current_time < self.end_date
)
return self.start_date < current_time < self.end_date
return False

View File

@@ -1,3 +1,3 @@
from django.test import TestCase # noqa F401
from django.test import TestCase # noqa F401
# Create your tests here.

View File

@@ -24,9 +24,9 @@ def react(request):
return template(request, 'react.html', {})
@require_http_methods(["GET", "POST"])
@require_http_methods(['GET', 'POST'])
def login(request):
if request.method == "POST":
if request.method == 'POST':
username = request.POST.get('username')
password = request.POST.get('password')
@@ -38,16 +38,16 @@ def login(request):
return redirect(reverse('anthias_app:react'))
else:
messages.error(request, 'Invalid username or password')
return template(request, 'login.html', {
'next': request.GET.get('next', '/')
})
return template(
request, 'login.html', {'next': request.GET.get('next', '/')}
)
return template(request, 'login.html', {
'next': request.GET.get('next', '/')
})
return template(
request, 'login.html', {'next': request.GET.get('next', '/')}
)
@require_http_methods(["GET"])
@require_http_methods(['GET'])
def splash_page(request):
ip_addresses = []
@@ -59,6 +59,6 @@ def splash_page(request):
else:
ip_addresses.append(f'http://{ip_address}')
return template(request, 'splash-page.html', {
'ip_addresses': ip_addresses
})
return template(
request, 'splash-page.html', {'ip_addresses': ip_addresses}
)

View File

@@ -40,18 +40,13 @@ if not DEBUG:
SECRET_KEY = device_settings.get('django_secret_key')
else:
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'django-insecure-7rz*$)g6dk&=h-3imq2xw*iu!zuhfb&w6v482_vs!w@4_gha=j' # noqa: E501
SECRET_KEY = (
'django-insecure-7rz*$)g6dk&=h-3imq2xw*iu!zuhfb&w6v482_vs!w@4_gha=j' # noqa: E501
)
ALLOWED_HOSTS = [
'127.0.0.1',
'localhost',
'anthias',
'anthias-server'
]
ALLOWED_HOSTS = ['127.0.0.1', 'localhost', 'anthias', 'anthias-server']
CSRF_TRUSTED_ORIGINS = [
'http://anthias'
]
CSRF_TRUSTED_ORIGINS = ['http://anthias']
# Application definition
@@ -109,7 +104,8 @@ DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': (
'/data/.screenly/test.db' if getenv('ENVIRONMENT') == 'test'
'/data/.screenly/test.db'
if getenv('ENVIRONMENT') == 'test'
else '/data/.screenly/screenly.db'
),
},
@@ -173,7 +169,7 @@ REST_FRAMEWORK = {
'EXCEPTION_HANDLER': 'api.helpers.custom_exception_handler',
# The project uses custom authentication classes,
# so we need to disable the default ones.
'DEFAULT_AUTHENTICATION_CLASSES': []
'DEFAULT_AUTHENTICATION_CLASSES': [],
}
SPECTACULAR_SETTINGS = {

View File

@@ -13,6 +13,7 @@ Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin
from django.urls import include, path
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView
@@ -31,11 +32,7 @@ urlpatterns = [
path('', include('anthias_app.urls')),
path('api/', include('api.urls')),
path('api/schema/', SpectacularAPIView.as_view(), name='schema'),
path(
'api/docs/',
APIDocView.as_view(url_name='schema'),
name='redoc'
),
path('api/docs/', APIDocView.as_view(url_name='schema'), name='redoc'),
]
# @TODO: Write custom 403 and 404 pages.

View File

@@ -1,3 +1,3 @@
from django.contrib import admin # noqa F401
from django.contrib import admin # noqa F401
# Register your models here.

View File

@@ -1,6 +1,6 @@
def preprocessing_filter_spec(endpoints):
filtered = []
for (path, path_regex, method, callback) in endpoints:
if path.startswith("/api/v2"):
for path, path_regex, method, callback in endpoints:
if path.startswith('/api/v2'):
filtered.append((path, path_regex, method, callback))
return filtered

View File

@@ -15,7 +15,6 @@ class AssetCreationError(Exception):
def update_asset(asset, data):
for key, value in list(data.items()):
if (
key in ['asset_id', 'is_processing', 'mimetype', 'uri']
or key not in asset
@@ -25,19 +24,17 @@ def update_asset(asset, data):
if key in ['start_date', 'end_date']:
value = date_parser.parse(value).replace(tzinfo=None)
if (
key in [
'play_order',
'skip_asset_check',
'is_enabled',
'is_active',
'nocache',
]
):
if key in [
'play_order',
'skip_asset_check',
'is_enabled',
'is_active',
'nocache',
]:
value = int(value)
if key == 'duration':
if "video" not in asset['mimetype']:
if 'video' not in asset['mimetype']:
continue
value = int(value)
@@ -48,8 +45,7 @@ def custom_exception_handler(exc, context):
exception_handler(exc, context)
return Response(
{'error': str(exc)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
{'error': str(exc)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
@@ -59,11 +55,7 @@ def get_active_asset_ids():
start_date__isnull=False,
end_date__isnull=False,
)
return [
asset.asset_id
for asset in enabled_assets
if asset.is_active()
]
return [asset.asset_id for asset in enabled_assets if asset.is_active()]
def save_active_assets_ordering(active_asset_ids):

View File

@@ -30,10 +30,10 @@ def get_unique_name(name):
def validate_uri(uri):
if uri.startswith('/'):
if not path.isfile(uri):
raise Exception("Invalid file path. Failed to add asset.")
raise Exception('Invalid file path. Failed to add asset.')
else:
if not validate_url(uri):
raise Exception("Invalid URL. Failed to add asset.")
raise Exception('Invalid URL. Failed to add asset.')
class AssetSerializer(ModelSerializer):
@@ -77,21 +77,27 @@ class UpdateAssetSerializer(Serializer):
def update(self, instance, validated_data):
instance.name = validated_data.get('name', instance.name)
instance.start_date = validated_data.get(
'start_date', instance.start_date)
'start_date', instance.start_date
)
instance.end_date = validated_data.get('end_date', instance.end_date)
instance.is_enabled = validated_data.get(
'is_enabled', instance.is_enabled)
'is_enabled', instance.is_enabled
)
instance.is_processing = validated_data.get(
'is_processing', instance.is_processing)
'is_processing', instance.is_processing
)
instance.nocache = validated_data.get('nocache', instance.nocache)
instance.play_order = validated_data.get(
'play_order', instance.play_order)
'play_order', instance.play_order
)
instance.skip_asset_check = validated_data.get(
'skip_asset_check', instance.skip_asset_check)
'skip_asset_check', instance.skip_asset_check
)
if 'video' not in instance.mimetype:
instance.duration = validated_data.get(
'duration', instance.duration)
'duration', instance.duration
)
instance.save()

View File

@@ -30,13 +30,9 @@ class CreateAssetSerializerMixin:
'name': name,
'mimetype': data.get('mimetype'),
'is_enabled': data.get(
'is_enabled',
False if version == 'v2' else 0
),
'nocache': data.get(
'nocache',
False if version == 'v2' else 0
'is_enabled', False if version == 'v2' else 0
),
'nocache': data.get('nocache', False if version == 'v2' else 0),
}
uri = (
@@ -44,8 +40,8 @@ class CreateAssetSerializerMixin:
.replace(ampersand_fix, '&')
.replace('<', '&lt;')
.replace('>', '&gt;')
.replace('\'', '&apos;')
.replace('\"', '&quot;')
.replace("'", '&apos;')
.replace('"', '&quot;')
)
validate_uri(uri)
@@ -61,15 +57,15 @@ class CreateAssetSerializerMixin:
uri = new_uri
if 'youtube_asset' in asset['mimetype']:
(
uri, asset['name'], asset['duration']
) = download_video_from_youtube(uri, asset['asset_id'])
(uri, asset['name'], asset['duration']) = (
download_video_from_youtube(uri, asset['asset_id'])
)
asset['mimetype'] = 'video'
asset['is_processing'] = True if version == 'v2' else 1
asset['uri'] = uri
if "video" in asset['mimetype']:
if 'video' in asset['mimetype']:
if int(data.get('duration')) == 0:
original_mimetype = data.get('mimetype')
@@ -84,9 +80,7 @@ class CreateAssetSerializerMixin:
)
else:
# Crashes if it's not an int. We want that.
duration = data.get(
'duration', settings['default_duration']
)
duration = data.get('duration', settings['default_duration'])
if version == 'v2':
asset['duration'] = duration
@@ -131,8 +125,10 @@ class PlaylistOrderSerializerMixin(Serializer):
class BackupViewSerializerMixin(Serializer):
pass
class RebootViewSerializerMixin(Serializer):
pass
class ShutdownViewSerializerMixin(Serializer):
pass

View File

@@ -66,18 +66,19 @@ class CreateAssetSerializerV1_1(Serializer):
uri = path.join(settings['assetdir'], asset['asset_id'])
if 'youtube_asset' in asset['mimetype']:
(
uri, asset['name'], asset['duration']
) = download_video_from_youtube(uri, asset['asset_id'])
(uri, asset['name'], asset['duration']) = (
download_video_from_youtube(uri, asset['asset_id'])
)
asset['mimetype'] = 'video'
asset['is_processing'] = 1
asset['uri'] = uri
if "video" in asset['mimetype']:
if 'video' in asset['mimetype']:
if int(data.get('duration')) == 0:
asset['duration'] = int(
get_video_duration(uri).total_seconds())
get_video_duration(uri).total_seconds()
)
else:
# Crashes if it's not an int. We want that.
asset['duration'] = int(data.get('duration'))
@@ -87,15 +88,15 @@ class CreateAssetSerializerV1_1(Serializer):
if data.get('start_date'):
asset['start_date'] = data.get('start_date').replace(tzinfo=None)
else:
asset['start_date'] = ""
asset['start_date'] = ''
if data.get('end_date'):
asset['end_date'] = data.get('end_date').replace(tzinfo=None)
else:
asset['end_date'] = ""
asset['end_date'] = ''
if not asset['skip_asset_check'] and url_fails(asset['uri']):
raise Exception("Could not retrieve file. Check the asset URL.")
raise Exception('Could not retrieve file. Check the asset URL.')
return asset

View File

@@ -1,6 +1,7 @@
"""
Tests for asset-related API endpoints.
"""
import mock
from django.test import TestCase
from django.urls import reverse
@@ -36,14 +37,13 @@ class CRUDAssetEndpointsTest(TestCase, ParametrizedTestCase):
def create_asset(self, data, version):
asset_list_url = reverse(f'api:asset_list_{version}')
return self.client.post(
asset_list_url,
data=get_request_data(data, version)
asset_list_url, data=get_request_data(data, version)
).data
def update_asset(self, asset_id, data, version):
return self.client.put(
reverse(f'api:asset_detail_{version}', args=[asset_id]),
data=get_request_data(data, version)
data=get_request_data(data, version),
).data
def get_asset(self, asset_id, version):
@@ -55,7 +55,9 @@ class CRUDAssetEndpointsTest(TestCase, ParametrizedTestCase):
return self.client.delete(url)
@parametrize_version
def test_get_assets_when_first_time_setup_should_initially_return_empty(self, version): # noqa: E501
def test_get_assets_when_first_time_setup_should_initially_return_empty(
self, version
): # noqa: E501
asset_list_url = reverse(f'api:asset_list_{version}')
response = self.client.get(asset_list_url)
assets = response.data
@@ -67,8 +69,7 @@ class CRUDAssetEndpointsTest(TestCase, ParametrizedTestCase):
def test_create_asset_should_return_201(self, version):
asset_list_url = reverse(f'api:asset_list_{version}')
response = self.client.post(
asset_list_url,
data=get_request_data(ASSET_CREATION_DATA, version)
asset_list_url, data=get_request_data(ASSET_CREATION_DATA, version)
)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@@ -83,9 +84,7 @@ class CRUDAssetEndpointsTest(TestCase, ParametrizedTestCase):
@mock.patch('api.serializers.mixins.rename')
@mock.patch('api.serializers.mixins.validate_uri')
def test_create_video_asset_v2_with_non_zero_duration_should_fail(
self,
mock_validate_uri,
mock_rename
self, mock_validate_uri, mock_rename
):
"""Test that v2 rejects video assets with non-zero duration."""
mock_validate_uri.return_value = True
@@ -101,22 +100,18 @@ class CRUDAssetEndpointsTest(TestCase, ParametrizedTestCase):
'is_enabled': True,
'nocache': False,
'play_order': 0,
'skip_asset_check': False
'skip_asset_check': False,
}
response = self.client.post(
asset_list_url,
data=test_data,
format='json'
asset_list_url, data=test_data, format='json'
)
self.assertEqual(
response.status_code,
status.HTTP_500_INTERNAL_SERVER_ERROR
response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR
)
self.assertIn(
'Duration must be zero for video assets',
str(response.data)
'Duration must be zero for video assets', str(response.data)
)
self.assertEqual(mock_rename.call_count, 1)

View File

@@ -1,6 +1,7 @@
"""
Common test utilities and constants for the Anthias API tests.
"""
import json
from django.urls import reverse
@@ -16,7 +17,7 @@ ASSET_CREATION_DATA = {
'is_enabled': 0,
'nocache': 0,
'play_order': 0,
'skip_asset_check': 0
'skip_asset_check': 0,
}
ASSET_UPDATE_DATA_V1_2 = {
'name': 'Anthias',
@@ -28,7 +29,7 @@ ASSET_UPDATE_DATA_V1_2 = {
'is_enabled': 1,
'nocache': 0,
'play_order': 0,
'skip_asset_check': 0
'skip_asset_check': 0,
}
ASSET_UPDATE_DATA_V2 = {
**ASSET_UPDATE_DATA_V1_2,
@@ -38,11 +39,10 @@ ASSET_UPDATE_DATA_V2 = {
'skip_asset_check': False,
}
def get_request_data(data, version):
"""Helper function to format request data based on API version."""
if version in ['v1', 'v1_1']:
return {
'model': json.dumps(data)
}
return {'model': json.dumps(data)}
else:
return data

View File

@@ -1,6 +1,7 @@
"""
Tests for Info API endpoints (v1 and v2).
"""
from unittest import mock
from django.test import TestCase
@@ -25,23 +26,13 @@ class InfoEndpointsTest(TestCase):
for key, expected_value in expected_data.items():
self.assertEqual(data[key], expected_value)
@mock.patch(
'api.views.mixins.is_up_to_date',
return_value=False
)
@mock.patch(
'lib.diagnostics.get_load_avg',
return_value={'15 min': 0.11}
)
@mock.patch('api.views.mixins.is_up_to_date', return_value=False)
@mock.patch('lib.diagnostics.get_load_avg', return_value={'15 min': 0.11})
@mock.patch('api.views.mixins.size', return_value='15G')
@mock.patch('api.views.mixins.statvfs', mock.MagicMock())
@mock.patch('api.views.mixins.r.get', return_value='off')
def test_info_v1_endpoint(
self,
redis_get_mock,
size_mock,
get_load_avg_mock,
is_up_to_date_mock
self, redis_get_mock, size_mock, get_load_avg_mock, is_up_to_date_mock
):
response = self.client.get(self.info_url_v1)
data = response.data
@@ -50,12 +41,9 @@ class InfoEndpointsTest(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Assert mock calls
self._assert_mock_calls([
redis_get_mock,
size_mock,
get_load_avg_mock,
is_up_to_date_mock
])
self._assert_mock_calls(
[redis_get_mock, size_mock, get_load_avg_mock, is_up_to_date_mock]
)
# Assert response data
expected_data = {
@@ -63,57 +51,42 @@ class InfoEndpointsTest(TestCase):
'loadavg': 0.11,
'free_space': '15G',
'display_power': 'off',
'up_to_date': False
'up_to_date': False,
}
self._assert_response_data(data, expected_data)
@mock.patch(
'api.views.v2.is_up_to_date',
return_value=True
)
@mock.patch(
'lib.diagnostics.get_load_avg',
return_value={'15 min': 0.25}
)
@mock.patch('api.views.v2.is_up_to_date', return_value=True)
@mock.patch('lib.diagnostics.get_load_avg', return_value={'15 min': 0.25})
@mock.patch('api.views.v2.size', return_value='20G')
@mock.patch('api.views.v2.statvfs', mock.MagicMock())
@mock.patch('api.views.v2.r.get', return_value='on')
@mock.patch('api.views.v2.diagnostics.get_git_branch', return_value='main')
@mock.patch(
'api.views.v2.diagnostics.get_git_short_hash',
return_value='a1b2c3d'
'api.views.v2.diagnostics.get_git_short_hash', return_value='a1b2c3d'
)
@mock.patch(
'api.views.v2.device_helper.parse_cpu_info',
return_value={'model': 'Raspberry Pi 4'}
)
@mock.patch(
'api.views.v2.diagnostics.get_uptime',
return_value=86400
return_value={'model': 'Raspberry Pi 4'},
)
@mock.patch('api.views.v2.diagnostics.get_uptime', return_value=86400)
@mock.patch(
'api.views.v2.psutil.virtual_memory',
return_value=mock.MagicMock(
total=8192 << 20, # 8GB
used=4096 << 20, # 4GB
free=4096 << 20, # 4GB
used=4096 << 20, # 4GB
free=4096 << 20, # 4GB
shared=0,
buffers=1024 << 20, # 1GB
available=7168 << 20 # 7GB
)
buffers=1024 << 20, # 1GB
available=7168 << 20, # 7GB
),
)
@mock.patch(
'api.views.v2.get_node_mac_address',
return_value='00:11:22:33:44:55'
'api.views.v2.get_node_mac_address', return_value='00:11:22:33:44:55'
)
@mock.patch(
'api.views.v2.get_node_ip',
return_value='192.168.1.100 10.0.0.50'
)
@mock.patch(
'api.views.v2.getenv',
return_value='testuser'
'api.views.v2.get_node_ip', return_value='192.168.1.100 10.0.0.50'
)
@mock.patch('api.views.v2.getenv', return_value='testuser')
def test_info_v2_endpoint(
self,
getenv_mock,
@@ -127,7 +100,7 @@ class InfoEndpointsTest(TestCase):
redis_get_mock,
size_mock,
get_load_avg_mock,
is_up_to_date_mock
is_up_to_date_mock,
):
response = self.client.get(self.info_url_v2)
data = response.data
@@ -136,20 +109,22 @@ class InfoEndpointsTest(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Assert mock calls
self._assert_mock_calls([
redis_get_mock,
size_mock,
get_load_avg_mock,
is_up_to_date_mock,
get_git_branch_mock,
get_git_short_hash_mock,
parse_cpu_info_mock,
get_uptime_mock,
virtual_memory_mock,
mac_address_mock,
get_node_ip_mock,
getenv_mock
])
self._assert_mock_calls(
[
redis_get_mock,
size_mock,
get_load_avg_mock,
is_up_to_date_mock,
get_git_branch_mock,
get_git_short_hash_mock,
parse_cpu_info_mock,
get_uptime_mock,
virtual_memory_mock,
mac_address_mock,
get_node_ip_mock,
getenv_mock,
]
)
# Assert response data
expected_data = {
@@ -160,20 +135,17 @@ class InfoEndpointsTest(TestCase):
'up_to_date': True,
'anthias_version': 'main@a1b2c3d',
'device_model': 'Raspberry Pi 4',
'uptime': {
'days': 1,
'hours': 0.0
},
'uptime': {'days': 1, 'hours': 0.0},
'memory': {
'total': 8192,
'used': 4096,
'free': 4096,
'shared': 0,
'buff': 1024,
'available': 7168
'available': 7168,
},
'ip_addresses': ['http://192.168.1.100', 'http://10.0.0.50'],
'mac_address': '00:11:22:33:44:55',
'host_user': 'testuser'
'host_user': 'testuser',
}
self._assert_response_data(data, expected_data)

View File

@@ -1,6 +1,7 @@
"""
Tests for V1 API endpoints.
"""
import os
from pathlib import Path
from unittest import mock
@@ -66,24 +67,22 @@ class V1EndpointsTest(TestCase, ParametrizedTestCase):
playlist_order_url = reverse('api:playlist_order_v1')
for asset_name in ['Asset #1', 'Asset #2', 'Asset #3']:
Asset.objects.create(**{
**ASSET_CREATION_DATA,
'name': asset_name,
})
Asset.objects.create(
**{
**ASSET_CREATION_DATA,
'name': asset_name,
}
)
self.assertTrue(
all([
asset.play_order == 0
for asset in Asset.objects.all()
])
all([asset.play_order == 0 for asset in Asset.objects.all()])
)
asset_1, asset_2, asset_3 = Asset.objects.all()
asset_ids = [asset_1.asset_id, asset_2.asset_id, asset_3.asset_id]
response = self.client.post(
playlist_order_url,
data={'ids': ','.join(asset_ids)}
playlist_order_url, data={'ids': ','.join(asset_ids)}
)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
@@ -114,7 +113,7 @@ class V1EndpointsTest(TestCase, ParametrizedTestCase):
@mock.patch(
'api.views.mixins.reboot_anthias.apply_async',
side_effect=(lambda: None)
side_effect=(lambda: None),
)
def test_reboot(self, reboot_anthias_mock):
reboot_url = reverse('api:reboot_v1')
@@ -125,7 +124,7 @@ class V1EndpointsTest(TestCase, ParametrizedTestCase):
@mock.patch(
'api.views.mixins.shutdown_anthias.apply_async',
side_effect=(lambda: None)
side_effect=(lambda: None),
)
def test_shutdown(self, shutdown_anthias_mock):
shutdown_url = reverse('api:shutdown_v1')
@@ -136,19 +135,17 @@ class V1EndpointsTest(TestCase, ParametrizedTestCase):
@mock.patch('api.views.v1.ZmqPublisher.send_to_viewer', return_value=None)
def test_viewer_current_asset(self, send_to_viewer_mock):
asset = Asset.objects.create(**{
**ASSET_CREATION_DATA,
'is_enabled': 1,
})
asset = Asset.objects.create(
**{
**ASSET_CREATION_DATA,
'is_enabled': 1,
}
)
asset_id = asset.asset_id
with (
mock.patch(
'api.views.v1.ZmqCollector.recv_json',
side_effect=(lambda _: {
'current_asset_id': asset_id
})
)
with mock.patch(
'api.views.v1.ZmqCollector.recv_json',
side_effect=(lambda _: {'current_asset_id': asset_id}),
):
viewer_current_asset_url = reverse('api:viewer_current_asset_v1')
response = self.client.get(viewer_current_asset_url)

View File

@@ -1,6 +1,7 @@
"""
Tests for V2 API endpoints.
"""
import hashlib
from unittest import mock
from unittest.mock import patch
@@ -78,15 +79,14 @@ class DeviceSettingsViewV2Test(TestCase):
}
response = self.client.patch(
self.device_settings_url,
data=data,
format='json'
self.device_settings_url, data=data, format='json'
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn('auth_backend', response.data)
self.assertIn(
'is not a valid choice', str(response.data['auth_backend']))
'is not a valid choice', str(response.data['auth_backend'])
)
settings_mock.load.assert_not_called()
settings_mock.save.assert_not_called()
@@ -94,9 +94,7 @@ class DeviceSettingsViewV2Test(TestCase):
@mock.patch('api.views.v2.settings')
@mock.patch('api.views.v2.ZmqPublisher')
def test_patch_device_settings_success(
self,
publisher_mock,
settings_mock
self, publisher_mock, settings_mock
):
settings_mock.load = mock.MagicMock()
settings_mock.save = mock.MagicMock()
@@ -128,22 +126,17 @@ class DeviceSettingsViewV2Test(TestCase):
}
response = self.client.patch(
self.device_settings_url,
data=data,
format='json'
self.device_settings_url, data=data, format='json'
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.data['message'],
'Settings were successfully saved.'
response.data['message'], 'Settings were successfully saved.'
)
settings_mock.load.assert_called_once()
settings_mock.save.assert_called_once()
self.assertEqual(
settings_mock.__setitem__.call_count, 5
)
self.assertEqual(settings_mock.__setitem__.call_count, 5)
publisher_instance.send_to_viewer.assert_called_once_with('reload')
@@ -155,9 +148,7 @@ class DeviceSettingsViewV2Test(TestCase):
}
response = self.client.patch(
self.device_settings_url,
data=data,
format='json'
self.device_settings_url, data=data, format='json'
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
@@ -214,18 +205,16 @@ class DeviceSettingsViewV2Test(TestCase):
}
expected_hashed_password = hashlib.sha256(
'testpass'.encode('utf-8')).hexdigest()
'testpass'.encode('utf-8')
).hexdigest()
response = self.client.patch(
self.device_settings_url,
data=data,
format='json'
self.device_settings_url, data=data, format='json'
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.data['message'],
'Settings were successfully saved.'
response.data['message'], 'Settings were successfully saved.'
)
settings_mock.load.assert_called_once()
@@ -233,7 +222,8 @@ class DeviceSettingsViewV2Test(TestCase):
settings_mock.__setitem__.assert_any_call('auth_backend', 'auth_basic')
settings_mock.__setitem__.assert_any_call('user', 'testuser')
settings_mock.__setitem__.assert_any_call(
'password', expected_hashed_password)
'password', expected_hashed_password
)
publisher_instance.send_to_viewer.assert_called_once_with('reload')
@@ -274,15 +264,12 @@ class DeviceSettingsViewV2Test(TestCase):
}
response = self.client.patch(
self.device_settings_url,
data=data,
format='json'
self.device_settings_url, data=data, format='json'
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.data['message'],
'Settings were successfully saved.'
response.data['message'], 'Settings were successfully saved.'
)
settings_mock.load.assert_called_once()
@@ -303,7 +290,7 @@ class DeviceSettingsViewV2Test(TestCase):
remove_default_assets_mock,
add_default_assets_mock,
publisher_mock,
settings_mock
settings_mock,
):
settings_mock.load = mock.MagicMock()
settings_mock.save = mock.MagicMock()
@@ -331,15 +318,12 @@ class DeviceSettingsViewV2Test(TestCase):
}
response = self.client.patch(
self.device_settings_url,
data=data,
format='json'
self.device_settings_url, data=data, format='json'
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.data['message'],
'Settings were successfully saved.'
response.data['message'], 'Settings were successfully saved.'
)
settings_mock.load.assert_called_once()
@@ -378,15 +362,12 @@ class DeviceSettingsViewV2Test(TestCase):
}
response = self.client.patch(
self.device_settings_url,
data=data,
format='json'
self.device_settings_url, data=data, format='json'
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.data['message'],
'Settings were successfully saved.'
response.data['message'], 'Settings were successfully saved.'
)
settings_mock.load.assert_called_once()
@@ -405,9 +386,7 @@ class TestIntegrationsViewV2(TestCase):
@patch('api.views.v2.is_balena_app')
@patch('api.views.v2.getenv')
def test_integrations_balena_environment(
self,
mock_getenv,
mock_is_balena
self, mock_getenv, mock_is_balena
):
# Mock Balena environment
mock_is_balena.side_effect = lambda: True
@@ -422,15 +401,18 @@ class TestIntegrationsViewV2(TestCase):
response = self.client.get(self.integrations_url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {
'is_balena': True,
'balena_device_id': 'test-device-uuid',
'balena_app_id': 'test-app-id',
'balena_app_name': 'test-app-name',
'balena_supervisor_version': 'test-supervisor-version',
'balena_host_os_version': 'test-host-os-version',
'balena_device_name_at_init': 'test-device-name',
})
self.assertEqual(
response.json(),
{
'is_balena': True,
'balena_device_id': 'test-device-uuid',
'balena_app_id': 'test-app-id',
'balena_app_name': 'test-app-name',
'balena_supervisor_version': 'test-supervisor-version',
'balena_host_os_version': 'test-host-os-version',
'balena_device_name_at_init': 'test-device-name',
},
)
@patch('api.views.v2.is_balena_app')
def test_integrations_non_balena_environment(self, mock_is_balena):
@@ -439,12 +421,15 @@ class TestIntegrationsViewV2(TestCase):
response = self.client.get(self.integrations_url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {
'is_balena': False,
'balena_device_id': None,
'balena_app_id': None,
'balena_app_name': None,
'balena_supervisor_version': None,
'balena_host_os_version': None,
'balena_device_name_at_init': None,
})
self.assertEqual(
response.json(),
{
'is_balena': False,
'balena_device_id': None,
'balena_app_id': None,
'balena_app_name': None,
'balena_supervisor_version': None,
'balena_host_os_version': None,
'balena_device_name_at_init': None,
},
)

View File

@@ -6,9 +6,7 @@ from api.views.v1_1 import AssetListViewV1_1, AssetViewV1_1
def get_url_patterns():
return [
path(
'v1.1/assets',
AssetListViewV1_1.as_view(),
name='asset_list_v1_1'
'v1.1/assets', AssetListViewV1_1.as_view(), name='asset_list_v1_1'
),
path(
'v1.1/assets/<str:asset_id>',

View File

@@ -6,9 +6,7 @@ from api.views.v1_2 import AssetListViewV1_2, AssetViewV1_2
def get_url_patterns():
return [
path(
'v1.2/assets',
AssetListViewV1_2.as_view(),
name='asset_list_v1_2'
'v1.2/assets', AssetListViewV1_2.as_view(), name='asset_list_v1_2'
),
path(
'v1.2/assets/<str:asset_id>',

View File

@@ -33,7 +33,7 @@ def get_url_patterns():
path(
'v2/assets/<str:asset_id>',
AssetViewV2.as_view(),
name='asset_detail_v2'
name='asset_detail_v2',
),
path('v2/backup', BackupViewV2.as_view(), name='backup_v2'),
path('v2/recover', RecoverViewV2.as_view(), name='recover_v2'),

View File

@@ -61,9 +61,9 @@ class BackupViewMixin(APIView):
201: {
'type': 'string',
'example': 'anthias-backup-2021-09-16T15-00-00.tar.gz',
'description': 'Backup file name'
'description': 'Backup file name',
}
}
},
)
@authorized
def post(self, request):
@@ -82,11 +82,8 @@ class RecoverViewMixin(APIView):
'multipart/form-data': {
'type': 'object',
'properties': {
'backup_upload': {
'type': 'string',
'format': 'binary'
}
}
'backup_upload': {'type': 'string', 'format': 'binary'}
},
}
},
responses={
@@ -99,21 +96,21 @@ class RecoverViewMixin(APIView):
@authorized
def post(self, request):
publisher = ZmqPublisher.get_instance()
file_upload = (request.data.get('backup_upload'))
file_upload = request.data.get('backup_upload')
filename = file_upload.name
if guess_type(filename)[0] != 'application/x-tar':
raise Exception("Incorrect file extension.")
raise Exception('Incorrect file extension.')
try:
publisher.send_to_viewer('stop')
location = path.join("static", filename)
location = path.join('static', filename)
with open(location, 'wb') as f:
f.write(file_upload.read())
backup_helper.recover(location)
return Response("Recovery successful.")
return Response('Recovery successful.')
finally:
publisher.send_to_viewer('play')
@@ -145,11 +142,8 @@ class FileAssetViewMixin(APIView):
'multipart/form-data': {
'type': 'object',
'properties': {
'file_upload': {
'type': 'string',
'format': 'binary'
}
}
'file_upload': {'type': 'string', 'format': 'binary'}
},
}
},
responses={
@@ -157,10 +151,10 @@ class FileAssetViewMixin(APIView):
'type': 'object',
'properties': {
'uri': {'type': 'string'},
'ext': {'type': 'string'}
}
'ext': {'type': 'string'},
},
}
}
},
)
@authorized
def post(self, request):
@@ -169,15 +163,18 @@ class FileAssetViewMixin(APIView):
file_type = guess_type(filename)[0]
if not file_type:
raise Exception("Invalid file type.")
raise Exception('Invalid file type.')
if file_type.split('/')[0] not in ['image', 'video']:
raise Exception("Invalid file type.")
raise Exception('Invalid file type.')
file_path = path.join(
settings['assetdir'],
uuid.uuid5(uuid.NAMESPACE_URL, filename).hex,
) + ".tmp"
file_path = (
path.join(
settings['assetdir'],
uuid.uuid5(uuid.NAMESPACE_URL, filename).hex,
)
+ '.tmp'
)
if 'Content-Range' in request.headers:
range_str = request.headers['Content-Range']
@@ -211,9 +208,9 @@ class AssetContentViewMixin(APIView):
'filename': {'type': 'string'},
'mimetype': {'type': 'string'},
'content': {'type': 'string'},
}
},
}
}
},
)
@authorized
def get(self, request, asset_id, format=None):
@@ -233,13 +230,10 @@ class AssetContentViewMixin(APIView):
'type': 'file',
'filename': filename,
'content': b64encode(content).decode(),
'mimetype': mimetype
'mimetype': mimetype,
}
else:
result = {
'type': 'url',
'url': asset.uri
}
result = {'type': 'url', 'url': asset.uri}
return Response(result)
@@ -248,7 +242,7 @@ class PlaylistOrderViewMixin(APIView):
@extend_schema(
summary='Update playlist order',
request=PlaylistOrderSerializerMixin,
responses={204: None}
responses={204: None},
)
@authorized
def post(self, request):
@@ -280,13 +274,13 @@ class AssetsControlViewMixin(APIView):
type=OpenApiTypes.STR,
enum=['next', 'previous', 'asset&{asset_id}'],
)
]
],
)
@authorized
def get(self, request, command):
publisher = ZmqPublisher.get_instance()
publisher.send_to_viewer(command)
return Response("Asset switched")
return Response('Asset switched')
class InfoViewMixin(APIView):
@@ -300,31 +294,33 @@ class InfoViewMixin(APIView):
'loadavg': {'type': 'number'},
'free_space': {'type': 'string'},
'display_power': {'type': 'string'},
'up_to_date': {'type': 'boolean'}
'up_to_date': {'type': 'boolean'},
},
'example': {
'viewlog': 'Not yet implemented',
'loadavg': 0.1,
'free_space': '10G',
'display_power': 'on',
'up_to_date': True
}
'up_to_date': True,
},
}
}
},
)
@authorized
def get(self, request):
viewlog = "Not yet implemented"
viewlog = 'Not yet implemented'
# Calculate disk space
slash = statvfs("/")
slash = statvfs('/')
free_space = size(slash.f_bavail * slash.f_frsize)
display_power = r.get('display_power')
return Response({
'viewlog': viewlog,
'loadavg': diagnostics.get_load_avg()['15 min'],
'free_space': free_space,
'display_power': display_power,
'up_to_date': is_up_to_date()
})
return Response(
{
'viewlog': viewlog,
'loadavg': diagnostics.get_load_avg()['15 min'],
'free_space': free_space,
'display_power': display_power,
'up_to_date': is_up_to_date(),
}
)

View File

@@ -1,4 +1,3 @@
from drf_spectacular.utils import (
OpenApiExample,
OpenApiRequest,
@@ -68,8 +67,7 @@ V1_ASSET_REQUEST = OpenApiRequest(
),
examples=[
OpenApiExample(
name='Example 1',
value={'model': MODEL_STRING_EXAMPLE}
name='Example 1', value={'model': MODEL_STRING_EXAMPLE}
),
],
)
@@ -87,9 +85,7 @@ class AssetViewV1(APIView, DeleteAssetViewMixin):
@extend_schema(
summary='Update asset',
request=V1_ASSET_REQUEST,
responses={
201: AssetSerializer
}
responses={201: AssetSerializer},
)
@authorized
def put(self, request, asset_id, format=None):
@@ -102,7 +98,8 @@ class AssetViewV1(APIView, DeleteAssetViewMixin):
serializer.save()
else:
return Response(
serializer.errors, status=status.HTTP_400_BAD_REQUEST)
serializer.errors, status=status.HTTP_400_BAD_REQUEST
)
asset.refresh_from_db()
return Response(AssetSerializer(asset).data)
@@ -116,10 +113,7 @@ class AssetListViewV1(APIView):
serializer_class = AssetSerializer
@extend_schema(
summary='List assets',
responses={
200: AssetSerializer(many=True)
}
summary='List assets', responses={200: AssetSerializer(many=True)}
)
@authorized
def get(self, request, format=None):
@@ -130,9 +124,7 @@ class AssetListViewV1(APIView):
@extend_schema(
summary='Create asset',
request=V1_ASSET_REQUEST,
responses={
201: AssetSerializer
}
responses={201: AssetSerializer},
)
@authorized
def post(self, request, format=None):
@@ -148,7 +140,8 @@ class AssetListViewV1(APIView):
asset = Asset.objects.create(**serializer.data)
return Response(
AssetSerializer(asset).data, status=status.HTTP_201_CREATED)
AssetSerializer(asset).data, status=status.HTTP_201_CREATED
)
class FileAssetViewV1(FileAssetViewMixin):
@@ -187,7 +180,7 @@ class ViewerCurrentAssetViewV1(APIView):
@extend_schema(
summary='Get current asset',
description='Get the current asset being displayed on the screen',
responses={200: AssetSerializer}
responses={200: AssetSerializer},
)
@authorized
def get(self, request):

View File

@@ -17,10 +17,7 @@ from lib.auth import authorized
class AssetListViewV1_1(APIView):
@extend_schema(
summary='List assets',
responses={
200: AssetSerializer(many=True)
}
summary='List assets', responses={200: AssetSerializer(many=True)}
)
@authorized
def get(self, request):
@@ -31,9 +28,7 @@ class AssetListViewV1_1(APIView):
@extend_schema(
summary='Create asset',
request=V1_ASSET_REQUEST,
responses={
201: AssetSerializer
}
responses={201: AssetSerializer},
)
@authorized
def post(self, request):
@@ -49,7 +44,8 @@ class AssetListViewV1_1(APIView):
asset = Asset.objects.create(**serializer.data)
return Response(
AssetSerializer(asset).data, status=status.HTTP_201_CREATED)
AssetSerializer(asset).data, status=status.HTTP_201_CREATED
)
class AssetViewV1_1(APIView, DeleteAssetViewMixin):
@@ -57,7 +53,7 @@ class AssetViewV1_1(APIView, DeleteAssetViewMixin):
summary='Get asset',
responses={
200: AssetSerializer,
}
},
)
@authorized
def get(self, request, asset_id):
@@ -67,9 +63,7 @@ class AssetViewV1_1(APIView, DeleteAssetViewMixin):
@extend_schema(
summary='Update asset',
request=V1_ASSET_REQUEST,
responses={
200: AssetSerializer
}
responses={200: AssetSerializer},
)
@authorized
def put(self, request, asset_id):
@@ -82,7 +76,8 @@ class AssetViewV1_1(APIView, DeleteAssetViewMixin):
serializer.save()
else:
return Response(
serializer.errors, status=status.HTTP_400_BAD_REQUEST)
serializer.errors, status=status.HTTP_400_BAD_REQUEST
)
asset.refresh_from_db()
return Response(AssetSerializer(asset).data)

View File

@@ -22,10 +22,7 @@ class AssetListViewV1_2(APIView):
serializer_class = AssetSerializer
@extend_schema(
summary='List assets',
responses={
200: AssetSerializer(many=True)
}
summary='List assets', responses={200: AssetSerializer(many=True)}
)
@authorized
def get(self, request):
@@ -36,15 +33,14 @@ class AssetListViewV1_2(APIView):
@extend_schema(
summary='Create asset',
request=CreateAssetSerializerV1_2,
responses={
201: AssetSerializer
}
responses={201: AssetSerializer},
)
@authorized
def post(self, request):
try:
serializer = CreateAssetSerializerV1_2(
data=request.data, unique_name=True)
data=request.data, unique_name=True
)
if not serializer.is_valid():
raise AssetCreationError(serializer.errors)
@@ -79,13 +75,15 @@ class AssetViewV1_2(APIView, DeleteAssetViewMixin):
def update(self, request, asset_id, partial=False):
asset = Asset.objects.get(asset_id=asset_id)
serializer = UpdateAssetSerializer(
asset, data=request.data, partial=partial)
asset, data=request.data, partial=partial
)
if serializer.is_valid():
serializer.save()
else:
return Response(
serializer.errors, status=status.HTTP_400_BAD_REQUEST)
serializer.errors, status=status.HTTP_400_BAD_REQUEST
)
active_asset_ids = get_active_asset_ids()
@@ -107,9 +105,7 @@ class AssetViewV1_2(APIView, DeleteAssetViewMixin):
@extend_schema(
summary='Update asset',
request=UpdateAssetSerializer,
responses={
200: AssetSerializer
}
responses={200: AssetSerializer},
)
@authorized
def patch(self, request, asset_id):
@@ -118,9 +114,7 @@ class AssetViewV1_2(APIView, DeleteAssetViewMixin):
@extend_schema(
summary='Update asset',
request=UpdateAssetSerializer,
responses={
200: AssetSerializer
}
responses={200: AssetSerializer},
)
@authorized
def put(self, request, asset_id):

View File

@@ -57,10 +57,7 @@ class AssetListViewV2(APIView):
serializer_class = AssetSerializerV2
@extend_schema(
summary='List assets',
responses={
200: AssetSerializerV2(many=True)
}
summary='List assets', responses={200: AssetSerializerV2(many=True)}
)
@authorized
def get(self, request):
@@ -71,15 +68,14 @@ class AssetListViewV2(APIView):
@extend_schema(
summary='Create asset',
request=CreateAssetSerializerV2,
responses={
201: AssetSerializerV2
}
responses={201: AssetSerializerV2},
)
@authorized
def post(self, request):
try:
serializer = CreateAssetSerializerV2(
data=request.data, unique_name=True)
data=request.data, unique_name=True
)
if not serializer.is_valid():
raise AssetCreationError(serializer.errors)
@@ -115,13 +111,15 @@ class AssetViewV2(APIView, DeleteAssetViewMixin):
def update(self, request, asset_id, partial=False):
asset = Asset.objects.get(asset_id=asset_id)
serializer = UpdateAssetSerializerV2(
asset, data=request.data, partial=partial)
asset, data=request.data, partial=partial
)
if serializer.is_valid():
serializer.save()
else:
return Response(
serializer.errors, status=status.HTTP_400_BAD_REQUEST)
serializer.errors, status=status.HTTP_400_BAD_REQUEST
)
active_asset_ids = get_active_asset_ids()
@@ -143,9 +141,7 @@ class AssetViewV2(APIView, DeleteAssetViewMixin):
@extend_schema(
summary='Update asset',
request=UpdateAssetSerializerV2,
responses={
200: AssetSerializerV2
}
responses={200: AssetSerializerV2},
)
@authorized
def patch(self, request, asset_id):
@@ -154,9 +150,7 @@ class AssetViewV2(APIView, DeleteAssetViewMixin):
@extend_schema(
summary='Update asset',
request=UpdateAssetSerializerV2,
responses={
200: AssetSerializerV2
}
responses={200: AssetSerializerV2},
)
@authorized
def put(self, request, asset_id):
@@ -198,9 +192,7 @@ class AssetsControlViewV2(AssetsControlViewMixin):
class DeviceSettingsViewV2(APIView):
@extend_schema(
summary='Get device settings',
responses={
200: DeviceSettingsSerializerV2
}
responses={200: DeviceSettingsSerializerV2},
)
@authorized
def get(self, request):
@@ -211,25 +203,28 @@ class DeviceSettingsViewV2(APIView):
logging.error(f'Failed to reload settings: {str(e)}')
# Continue with existing settings if reload fails
return Response({
'player_name': settings['player_name'],
'audio_output': settings['audio_output'],
'default_duration': int(settings['default_duration']),
'default_streaming_duration': int(
settings['default_streaming_duration']
),
'date_format': settings['date_format'],
'auth_backend': settings['auth_backend'],
'show_splash': settings['show_splash'],
'default_assets': settings['default_assets'],
'shuffle_playlist': settings['shuffle_playlist'],
'use_24_hour_clock': settings['use_24_hour_clock'],
'debug_logging': settings['debug_logging'],
'username': (
settings['user'] if settings['auth_backend'] == 'auth_basic'
else ''
),
})
return Response(
{
'player_name': settings['player_name'],
'audio_output': settings['audio_output'],
'default_duration': int(settings['default_duration']),
'default_streaming_duration': int(
settings['default_streaming_duration']
),
'date_format': settings['date_format'],
'auth_backend': settings['auth_backend'],
'show_splash': settings['show_splash'],
'default_assets': settings['default_assets'],
'shuffle_playlist': settings['shuffle_playlist'],
'use_24_hour_clock': settings['use_24_hour_clock'],
'debug_logging': settings['debug_logging'],
'username': (
settings['user']
if settings['auth_backend'] == 'auth_basic'
else ''
),
}
)
def update_auth_settings(self, data, auth_backend, current_pass_correct):
if auth_backend == '':
@@ -248,36 +243,36 @@ class DeviceSettingsViewV2(APIView):
if new_user != settings['user']:
if current_pass_correct is None:
raise ValueError(
"Must supply current password to change username"
'Must supply current password to change username'
)
if not current_pass_correct:
raise ValueError("Incorrect current password.")
raise ValueError('Incorrect current password.')
settings['user'] = new_user
if new_pass:
if current_pass_correct is None:
raise ValueError(
"Must supply current password to change password"
'Must supply current password to change password'
)
if not current_pass_correct:
raise ValueError("Incorrect current password.")
raise ValueError('Incorrect current password.')
if new_pass2 != new_pass:
raise ValueError("New passwords do not match!")
raise ValueError('New passwords do not match!')
settings['password'] = new_pass
else:
if new_user:
if new_pass and new_pass != new_pass2:
raise ValueError("New passwords do not match!")
raise ValueError('New passwords do not match!')
if not new_pass:
raise ValueError("Must provide password")
raise ValueError('Must provide password')
settings['user'] = new_user
settings['password'] = new_pass
else:
raise ValueError("Must provide username")
raise ValueError('Must provide username')
@extend_schema(
summary='Update device settings',
@@ -285,15 +280,13 @@ class DeviceSettingsViewV2(APIView):
responses={
200: {
'type': 'object',
'properties': {
'message': {'type': 'string'}
}
'properties': {'message': {'type': 'string'}},
},
400: {
'type': 'object',
'properties': {'error': {'type': 'string'}}
}
}
'properties': {'error': {'type': 'string'}},
},
},
)
@authorized
def patch(self, request):
@@ -314,25 +307,24 @@ class DeviceSettingsViewV2(APIView):
):
if not current_password:
raise ValueError(
"Must supply current password to change "
"authentication method"
'Must supply current password to change '
'authentication method'
)
if not settings.auth.check_password(current_password):
raise ValueError("Incorrect current password.")
raise ValueError('Incorrect current password.')
prev_auth_backend = settings['auth_backend']
if not current_password and prev_auth_backend:
current_pass_correct = None
else:
current_pass_correct = (
settings
.auth_backends[prev_auth_backend]
.check_password(current_password)
)
current_pass_correct = settings.auth_backends[
prev_auth_backend
].check_password(current_password)
next_auth_backend = settings.auth_backends[auth_backend]
self.update_auth_settings(
data, next_auth_backend.name, current_pass_correct)
data, next_auth_backend.name, current_pass_correct
)
settings['auth_backend'] = auth_backend
# Update settings
@@ -341,9 +333,9 @@ class DeviceSettingsViewV2(APIView):
if 'default_duration' in data:
settings['default_duration'] = data['default_duration']
if 'default_streaming_duration' in data:
settings['default_streaming_duration'] = (
data['default_streaming_duration']
)
settings['default_streaming_duration'] = data[
'default_streaming_duration'
]
if 'audio_output' in data:
settings['audio_output'] = data['audio_output']
if 'date_format' in data:
@@ -371,7 +363,7 @@ class DeviceSettingsViewV2(APIView):
except Exception as e:
return Response(
{'error': f'An error occurred while saving settings: {e}'},
status=400
status=400,
)
@@ -408,7 +400,7 @@ class InfoViewV2(InfoViewMixin):
'free': virtual_memory.free >> 20,
'shared': virtual_memory.shared >> 20,
'buff': virtual_memory.buffers >> 20,
'available': virtual_memory.available >> 20
'available': virtual_memory.available >> 20,
}
def get_ip_addresses(self):
@@ -445,8 +437,8 @@ class InfoViewV2(InfoViewMixin):
'type': 'object',
'properties': {
'days': {'type': 'integer'},
'hours': {'type': 'number'}
}
'hours': {'type': 'number'},
},
},
'memory': {
'type': 'object',
@@ -456,41 +448,44 @@ class InfoViewV2(InfoViewMixin):
'free': {'type': 'integer'},
'shared': {'type': 'integer'},
'buff': {'type': 'integer'},
'available': {'type': 'integer'}
}
'available': {'type': 'integer'},
},
},
'ip_addresses': {
'type': 'array', 'items': {'type': 'string'}
'type': 'array',
'items': {'type': 'string'},
},
'mac_address': {'type': 'string'},
'host_user': {'type': 'string'}
}
'host_user': {'type': 'string'},
},
}
}
},
)
@authorized
def get(self, request):
viewlog = "Not yet implemented"
viewlog = 'Not yet implemented'
# Calculate disk space
slash = statvfs("/")
slash = statvfs('/')
free_space = size(slash.f_bavail * slash.f_frsize)
display_power = r.get('display_power')
return Response({
'viewlog': viewlog,
'loadavg': diagnostics.get_load_avg()['15 min'],
'free_space': free_space,
'display_power': display_power,
'up_to_date': is_up_to_date(),
'anthias_version': self.get_anthias_version(),
'device_model': self.get_device_model(),
'uptime': self.get_uptime(),
'memory': self.get_memory(),
'ip_addresses': self.get_ip_addresses(),
'mac_address': get_node_mac_address(),
'host_user': getenv('HOST_USER'),
})
return Response(
{
'viewlog': viewlog,
'loadavg': diagnostics.get_load_avg()['15 min'],
'free_space': free_space,
'display_power': display_power,
'up_to_date': is_up_to_date(),
'anthias_version': self.get_anthias_version(),
'device_model': self.get_device_model(),
'uptime': self.get_uptime(),
'memory': self.get_memory(),
'ip_addresses': self.get_ip_addresses(),
'mac_address': get_node_mac_address(),
'host_user': getenv('HOST_USER'),
}
)
class IntegrationsViewV2(APIView):
@@ -498,9 +493,7 @@ class IntegrationsViewV2(APIView):
@extend_schema(
summary='Get integrations information',
responses={
200: IntegrationsSerializerV2
}
responses={200: IntegrationsSerializerV2},
)
@authorized
def get(self, request):
@@ -509,20 +502,22 @@ class IntegrationsViewV2(APIView):
}
if data['is_balena']:
data.update({
'balena_device_id': getenv('BALENA_DEVICE_UUID'),
'balena_app_id': getenv('BALENA_APP_ID'),
'balena_app_name': getenv('BALENA_APP_NAME'),
'balena_supervisor_version': (
getenv('BALENA_SUPERVISOR_VERSION')
),
'balena_host_os_version': (
getenv('BALENA_HOST_OS_VERSION')
),
'balena_device_name_at_init': (
getenv('BALENA_DEVICE_NAME_AT_INIT')
),
})
data.update(
{
'balena_device_id': getenv('BALENA_DEVICE_UUID'),
'balena_app_id': getenv('BALENA_APP_ID'),
'balena_app_name': getenv('BALENA_APP_NAME'),
'balena_supervisor_version': (
getenv('BALENA_SUPERVISOR_VERSION')
),
'balena_host_os_version': (
getenv('BALENA_HOST_OS_VERSION')
),
'balena_device_name_at_init': (
getenv('BALENA_DEVICE_NAME_AT_INIT')
),
}
)
serializer = self.serializer_class(data=data)
serializer.is_valid(raise_exception=True)

View File

@@ -22,13 +22,14 @@ except Exception:
pass
__author__ = "Screenly, Inc"
__copyright__ = "Copyright 2012-2024, Screenly, Inc"
__license__ = "Dual License: GPLv2 and Commercial License"
__author__ = 'Screenly, Inc'
__copyright__ = 'Copyright 2012-2024, Screenly, Inc'
__license__ = 'Dual License: GPLv2 and Commercial License'
CELERY_RESULT_BACKEND = getenv(
'CELERY_RESULT_BACKEND', 'redis://localhost:6379/0')
'CELERY_RESULT_BACKEND', 'redis://localhost:6379/0'
)
CELERY_BROKER_URL = getenv('CELERY_BROKER_URL', 'redis://localhost:6379/0')
CELERY_TASK_RESULT_EXPIRES = timedelta(hours=6)
@@ -37,7 +38,7 @@ celery = Celery(
'Anthias Celery Worker',
backend=CELERY_RESULT_BACKEND,
broker=CELERY_BROKER_URL,
result_expires=CELERY_TASK_RESULT_EXPIRES
result_expires=CELERY_TASK_RESULT_EXPIRES,
)
@@ -45,7 +46,9 @@ celery = Celery(
def setup_periodic_tasks(sender, **kwargs):
# Calls cleanup() every hour.
sender.add_periodic_task(3600, cleanup.s(), name='cleanup')
sender.add_periodic_task(60*5, get_display_power.s(), name='display_power')
sender.add_periodic_task(
60 * 5, get_display_power.s(), name='display_power'
)
@celery.task(time_limit=30)
@@ -58,7 +61,10 @@ def get_display_power():
def cleanup():
sh.find(
path.join(getenv('HOME'), 'screenly_assets'),
'-name', '*.tmp', '-delete')
'-name',
'*.tmp',
'-delete',
)
@celery.task

View File

@@ -3,8 +3,8 @@
from __future__ import unicode_literals
__author__ = "Nash Kaminski"
__license__ = "Dual License: GPLv2 and Commercial License"
__author__ = 'Nash Kaminski'
__license__ = 'Dual License: GPLv2 and Commercial License'
import ipaddress
import json
@@ -22,7 +22,7 @@ from tenacity import (
wait_fixed,
)
REDIS_ARGS = dict(host="127.0.0.1", port=6379, db=0)
REDIS_ARGS = dict(host='127.0.0.1', port=6379, db=0)
# Name of redis channel to listen to
CHANNEL_NAME = b'hostcmd'
SUPPORTED_INTERFACES = (
@@ -39,8 +39,8 @@ def get_ip_addresses():
for interface in netifaces.interfaces()
if interface.startswith(SUPPORTED_INTERFACES)
for ip in (
netifaces.ifaddresses(interface).get(netifaces.AF_INET, []) +
netifaces.ifaddresses(interface).get(netifaces.AF_INET6, [])
netifaces.ifaddresses(interface).get(netifaces.AF_INET, [])
+ netifaces.ifaddresses(interface).get(netifaces.AF_INET6, [])
)
if not ipaddress.ip_address(ip['addr']).is_link_local
]
@@ -75,7 +75,7 @@ def set_ip_addresses():
CMD_TO_ARGV = {
b'reboot': ['/usr/bin/sudo', '-n', '/usr/bin/systemctl', 'reboot'],
b'shutdown': ['/usr/bin/sudo', '-n', '/usr/bin/systemctl', 'poweroff'],
b'set_ip_addresses': set_ip_addresses
b'set_ip_addresses': set_ip_addresses,
}
@@ -83,17 +83,18 @@ def execute_host_command(cmd_name):
cmd = CMD_TO_ARGV.get(cmd_name, None)
if cmd is None:
logging.warning(
"Unable to perform host command %s: no such command!", cmd_name)
'Unable to perform host command %s: no such command!', cmd_name
)
elif os.getenv('TESTING'):
logging.warning(
"Would have executed %s but not doing so as TESTING is defined",
'Would have executed %s but not doing so as TESTING is defined',
cmd,
)
elif cmd_name in [b'reboot', b'shutdown']:
logging.info("Executing host command %s", cmd_name)
logging.info('Executing host command %s', cmd_name)
phandle = subprocess.run(cmd)
logging.info(
"Host command %s (%s) returned %s",
'Host command %s (%s) returned %s',
cmd_name,
cmd,
phandle.returncode,
@@ -110,18 +111,19 @@ def process_message(message):
):
execute_host_command(message.get('data', b''))
else:
logging.info("Received unsolicited message: %s", message)
logging.info('Received unsolicited message: %s', message)
def subscriber_loop():
# Connect to redis on localhost and wait for messages
logging.info("Connecting to redis...")
logging.info('Connecting to redis...')
rdb = redis.Redis(**REDIS_ARGS)
pubsub = rdb.pubsub(ignore_subscribe_messages=True)
pubsub.subscribe(CHANNEL_NAME)
rdb.set('host_agent_ready', 'true')
logging.info(
"Subscribed to channel %s, ready to process messages", CHANNEL_NAME)
'Subscribed to channel %s, ready to process messages', CHANNEL_NAME
)
for message in pubsub.listen():
process_message(message)

View File

@@ -46,7 +46,8 @@ class Auth(with_metaclass(ABCMeta, object)):
return self.authenticate()
except ValueError as e:
return HttpResponse(
"Authorization backend is unavailable: " + str(e), status=503)
'Authorization backend is unavailable: ' + str(e), status=503
)
def update_settings(self, request, current_pass_correct):
"""
@@ -95,12 +96,7 @@ class NoAuth(Auth):
class BasicAuth(Auth):
display_name = 'Basic'
name = 'auth_basic'
config = {
'auth_basic': {
'user': '',
'password': ''
}
}
config = {'auth_basic': {'user': '', 'password': ''}}
def __init__(self, settings):
self.settings = settings
@@ -112,8 +108,8 @@ class BasicAuth(Auth):
:param password: str
:return: True if the check passes.
"""
return (
self.settings['user'] == username and self.check_password(password)
return self.settings['user'] == username and self.check_password(
password
)
def check_password(self, password):
@@ -151,6 +147,7 @@ class BasicAuth(Auth):
def authenticate(self):
from django.shortcuts import redirect
from django.urls import reverse
return redirect(reverse('anthias_app:login'))
def update_settings(self, request, current_pass_correct):
@@ -166,34 +163,36 @@ class BasicAuth(Auth):
# Optionally may change password.
if current_pass_correct is None:
raise ValueError(
"Must supply current password to change username")
'Must supply current password to change username'
)
if not current_pass_correct:
raise ValueError("Incorrect current password.")
raise ValueError('Incorrect current password.')
self.settings['user'] = new_user
if new_pass:
if current_pass_correct is None:
raise ValueError(
"Must supply current password to change password")
'Must supply current password to change password'
)
if not current_pass_correct:
raise ValueError("Incorrect current password.")
raise ValueError('Incorrect current password.')
if new_pass2 != new_pass: # changing password
raise ValueError("New passwords do not match!")
raise ValueError('New passwords do not match!')
self.settings['password'] = new_pass
else: # no current password
if new_user: # setting username and password
if new_pass and new_pass != new_pass2:
raise ValueError("New passwords do not match!")
raise ValueError('New passwords do not match!')
if not new_pass:
raise ValueError("Must provide password")
raise ValueError('Must provide password')
self.settings['user'] = new_user
self.settings['password'] = new_pass
else:
raise ValueError("Must provide username")
raise ValueError('Must provide username')
def authorized(orig):
@@ -214,11 +213,11 @@ def authorized(orig):
if not isinstance(request, (HttpRequest, Request)):
raise ValueError(
'Request object is not of type HttpRequest or Request')
'Request object is not of type HttpRequest or Request'
)
return (
settings.auth.authenticate_if_needed(request) or
orig(*args, **kwargs)
return settings.auth.authenticate_if_needed(request) or orig(
*args, **kwargs
)
return decorated

View File

@@ -7,15 +7,15 @@ from datetime import datetime
from os import getenv, makedirs, path, remove
directories = ['.screenly', 'screenly_assets']
default_archive_name = "anthias-backup"
static_dir = "screenly/staticfiles"
default_archive_name = 'anthias-backup'
static_dir = 'screenly/staticfiles'
def create_backup(name=default_archive_name):
home = getenv('HOME')
archive_name = "{}-{}.tar.gz".format(
archive_name = '{}-{}.tar.gz'.format(
name if name else default_archive_name,
datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
datetime.now().strftime('%Y-%m-%dT%H-%M-%S'),
)
file_path = path.join(home, static_dir, archive_name)
@@ -26,8 +26,7 @@ def create_backup(name=default_archive_name):
remove(file_path)
try:
with tarfile.open(file_path, "w:gz") as tar:
with tarfile.open(file_path, 'w:gz') as tar:
for directory in directories:
path_to_dir = path.join(home, directory)
tar.add(path_to_dir, arcname=directory)
@@ -46,10 +45,10 @@ def recover(file_path):
# or we can create a new class that extends Exception.
sys.exit(1)
with tarfile.open(file_path, "r:gz") as tar:
with tarfile.open(file_path, 'r:gz') as tar:
for directory in directories:
if directory not in tar.getnames():
raise Exception("Archive is wrong.")
raise Exception('Archive is wrong.')
tar.extractall(path=getenv('HOME'))

View File

@@ -6,9 +6,7 @@ def parse_cpu_info():
Extracts the various Raspberry Pi related data
from the CPU.
"""
cpu_info = {
'cpu_count': 0
}
cpu_info = {'cpu_count': 0}
with open('/proc/cpuinfo', 'r') as cpuinfo:
for line in cpuinfo:

View File

@@ -98,8 +98,8 @@ def get_debian_version():
def get_raspberry_code():
return device_helper.parse_cpu_info().get('hardware', "Unknown")
return device_helper.parse_cpu_info().get('hardware', 'Unknown')
def get_raspberry_model():
return device_helper.parse_cpu_info().get('model', "Unknown")
return device_helper.parse_cpu_info().get('model', 'Unknown')

View File

@@ -20,14 +20,14 @@ r = connect_to_redis()
# Availability and HEAD commit of the remote branch to be checked
# every 24 hours.
REMOTE_BRANCH_STATUS_TTL = (60 * 60 * 24)
REMOTE_BRANCH_STATUS_TTL = 60 * 60 * 24
# Suspend all external requests if we enconter an error other than
# a ConnectionError for 5 minutes.
ERROR_BACKOFF_TTL = (60 * 5)
ERROR_BACKOFF_TTL = 60 * 5
# Availability of the cached Docker Hub hash
DOCKER_HUB_HASH_TTL = (10 * 60)
DOCKER_HUB_HASH_TTL = 10 * 60
# Google Analytics data
ANALYTICS_MEASURE_ID = 'G-S3VX8HTPK7'
@@ -59,13 +59,13 @@ def remote_branch_available(branch):
# Make sure we havent recently failed before allowing fetch
if r.get('github-api-error') is not None:
logging.warning("GitHub requests suspended due to prior error")
logging.warning('GitHub requests suspended due to prior error')
return None
# Check for cached remote branch status
remote_branch_cache = r.get('remote-branch-available')
if remote_branch_cache is not None:
return remote_branch_cache == "1"
return remote_branch_cache == '1'
try:
resp = requests_get(
@@ -73,7 +73,7 @@ def remote_branch_available(branch):
headers={
'Accept': 'application/vnd.github.loki-preview+json',
},
timeout=DEFAULT_REQUESTS_TIMEOUT
timeout=DEFAULT_REQUESTS_TIMEOUT,
)
resp.raise_for_status()
except exceptions.RequestException as exc:
@@ -116,7 +116,7 @@ def fetch_remote_hash():
try:
resp = requests_get(
f'https://api.github.com/repos/screenly/anthias/git/refs/heads/{branch}', # noqa: E501
timeout=DEFAULT_REQUESTS_TIMEOUT
timeout=DEFAULT_REQUESTS_TIMEOUT,
)
resp.raise_for_status()
except exceptions.RequestException as exc:
@@ -163,7 +163,8 @@ def get_latest_docker_hub_hash(device_type):
if len(reduced) == 0:
logging.warning(
'No commit hash found for device type: %s', device_type)
'No commit hash found for device type: %s', device_type
)
return None
docker_hub_hash = reduced[0]
@@ -209,25 +210,27 @@ def is_up_to_date():
ga_url = f'{ga_base_url}?{ga_query_params}'
payload = {
'client_id': device_id,
'events': [{
'name': 'version',
'params': {
'Branch': str(git_branch),
'Hash': str(git_short_hash),
'NOOBS': os.path.isfile('/boot/os_config.json'),
'Balena': is_balena_app(),
'Docker': is_docker(),
'Pi_Version': parse_cpu_info().get('model', "Unknown")
}
}]
'events': [
{
'name': 'version',
'params': {
'Branch': str(git_branch),
'Hash': str(git_short_hash),
'NOOBS': os.path.isfile('/boot/os_config.json'),
'Balena': is_balena_app(),
'Docker': is_docker(),
'Pi_Version': parse_cpu_info().get(
'model', 'Unknown'
),
},
}
],
}
headers = {'content-type': 'application/json'}
try:
requests_post(
ga_url,
data=json.dumps(payload),
headers=headers
ga_url, data=json.dumps(payload), headers=headers
)
except exceptions.ConnectionError:
pass
@@ -235,7 +238,6 @@ def is_up_to_date():
device_type = os.getenv('DEVICE_TYPE')
latest_docker_hub_hash = get_latest_docker_hub_hash(device_type)
return (
(latest_sha == git_hash) or
(latest_docker_hub_hash == git_short_hash)
return (latest_sha == git_hash) or (
latest_docker_hub_hash == git_short_hash
)

View File

@@ -84,12 +84,15 @@ def validate_url(string):
def get_balena_supervisor_api_response(method, action, **kwargs):
version = kwargs.get('version', 'v1')
return getattr(requests, method)('{}/{}/{}?apikey={}'.format(
os.getenv('BALENA_SUPERVISOR_ADDRESS'),
version,
action,
os.getenv('BALENA_SUPERVISOR_API_KEY'),
), headers={'Content-Type': 'application/json'})
return getattr(requests, method)(
'{}/{}/{}?apikey={}'.format(
os.getenv('BALENA_SUPERVISOR_ADDRESS'),
version,
action,
os.getenv('BALENA_SUPERVISOR_API_KEY'),
),
headers={'Content-Type': 'application/json'},
)
def get_balena_device_info():
@@ -106,7 +109,8 @@ def reboot_via_balena_supervisor():
def get_balena_supervisor_version():
response = get_balena_supervisor_api_response(
method='get', action='version', version='v2')
method='get', action='version', version='v2'
)
if response.ok:
return response.json()['version']
else:
@@ -169,11 +173,10 @@ def get_node_ip():
break
else:
raise Exception(
'Internet connection is not available.')
'Internet connection is not available.'
)
except RetryError:
logging.warning(
'Internet connection is not available. '
)
logging.warning('Internet connection is not available. ')
ip_addresses = r.get('ip_addresses')
@@ -194,10 +197,12 @@ def get_node_mac_address():
balena_supervisor_api_key = os.getenv('BALENA_SUPERVISOR_API_KEY')
headers = {'Content-Type': 'application/json'}
r = requests.get('{}/v1/device?apikey={}'.format(
balena_supervisor_address,
balena_supervisor_api_key
), headers=headers)
r = requests.get(
'{}/v1/device?apikey={}'.format(
balena_supervisor_address, balena_supervisor_api_key
),
headers=headers,
)
if r.ok:
return r.json()['mac_address']
@@ -220,35 +225,45 @@ def get_active_connections(bus, fields=None):
try:
nm_proxy = bus.get(
"org.freedesktop.NetworkManager",
"/org/freedesktop/NetworkManager",
'org.freedesktop.NetworkManager',
'/org/freedesktop/NetworkManager',
)
except Exception:
return None
nm_properties = nm_proxy["org.freedesktop.DBus.Properties"]
nm_properties = nm_proxy['org.freedesktop.DBus.Properties']
active_connections = nm_properties.Get(
"org.freedesktop.NetworkManager", "ActiveConnections")
'org.freedesktop.NetworkManager', 'ActiveConnections'
)
for active_connection in active_connections:
active_connection_proxy = bus.get(
"org.freedesktop.NetworkManager", active_connection)
active_connection_properties = (
active_connection_proxy["org.freedesktop.DBus.Properties"])
'org.freedesktop.NetworkManager', active_connection
)
active_connection_properties = active_connection_proxy[
'org.freedesktop.DBus.Properties'
]
connection = dict()
for field in fields:
field_value = active_connection_properties.Get(
"org.freedesktop.NetworkManager.Connection.Active", field)
'org.freedesktop.NetworkManager.Connection.Active', field
)
if field == 'Devices':
devices = list()
for device_path in field_value:
device_proxy = bus.get(
"org.freedesktop.NetworkManager", device_path)
device_properties = (
device_proxy["org.freedesktop.DBus.Properties"])
devices.append(device_properties.Get(
"org.freedesktop.NetworkManager.Device", "Interface"))
'org.freedesktop.NetworkManager', device_path
)
device_properties = device_proxy[
'org.freedesktop.DBus.Properties'
]
devices.append(
device_properties.Get(
'org.freedesktop.NetworkManager.Device',
'Interface',
)
)
field_value = devices
connection.update({field: field_value})
@@ -266,19 +281,21 @@ def remove_connection(bus, uuid):
"""
try:
nm_proxy = bus.get(
"org.freedesktop.NetworkManager",
"/org/freedesktop/NetworkManager/Settings",
'org.freedesktop.NetworkManager',
'/org/freedesktop/NetworkManager/Settings',
)
except Exception:
return False
nm_settings = nm_proxy["org.freedesktop.NetworkManager.Settings"]
nm_settings = nm_proxy['org.freedesktop.NetworkManager.Settings']
connection_path = nm_settings.GetConnectionByUuid(uuid)
connection_proxy = bus.get(
"org.freedesktop.NetworkManager", connection_path)
connection = (
connection_proxy["org.freedesktop.NetworkManager.Settings.Connection"])
'org.freedesktop.NetworkManager', connection_path
)
connection = connection_proxy[
'org.freedesktop.NetworkManager.Settings.Connection'
]
connection.Delete()
return True
@@ -332,7 +349,8 @@ def url_fails(url):
"""
if urlparse(url).scheme in ('rtsp', 'rtmp'):
run_mplayer = mplayer( # noqa: F821
'-identify', '-frames', '0', '-nosound', url)
'-identify', '-frames', '0', '-nosound', url
)
for line in run_mplayer.split('\n'):
if 'Clip info:' in line:
return False
@@ -361,7 +379,7 @@ def url_fails(url):
allow_redirects=True,
headers=headers,
timeout=10,
verify=verify
verify=verify,
).ok:
return False
@@ -370,7 +388,7 @@ def url_fails(url):
allow_redirects=True,
headers=headers,
timeout=10,
verify=verify
verify=verify,
).ok:
return False
@@ -386,11 +404,7 @@ def download_video_from_youtube(uri, asset_id):
info = json.loads(check_output(['yt-dlp', '-j', uri]))
duration = info['duration']
location = path.join(
home,
'screenly_assets',
f'{asset_id}.mp4'
)
location = path.join(home, 'screenly_assets', f'{asset_id}.mp4')
thread = YoutubeDownloadThread(location, uri, asset_id)
thread.daemon = True
thread.start()
@@ -407,14 +421,16 @@ class YoutubeDownloadThread(Thread):
def run(self):
publisher = ZmqPublisher.get_instance()
call([
'yt-dlp',
'-S',
'vcodec:h264,fps,res:1080,acodec:m4a',
'-o',
self.location,
self.uri,
])
call(
[
'yt-dlp',
'-S',
'vcodec:h264,fps,res:1080,acodec:m4a',
'-o',
self.location,
self.uri,
]
)
try:
asset = Asset.objects.get(asset_id=self.asset_id)
@@ -448,11 +464,14 @@ def generate_perfect_paper_password(pw_length=10, has_symbols=True):
:param has_symbols: bool
:return: string
"""
ppp_letters = '!#%+23456789:=?@ABCDEFGHJKLMNPRSTUVWXYZabcdefghjkmnopqrstuvwxyz' # noqa: E501
ppp_letters = (
'!#%+23456789:=?@ABCDEFGHJKLMNPRSTUVWXYZabcdefghjkmnopqrstuvwxyz' # noqa: E501
)
if not has_symbols:
ppp_letters = ''.join(set(ppp_letters) - set(string.punctuation))
return "".join(
random.SystemRandom().choice(ppp_letters) for _ in range(pw_length))
return ''.join(
random.SystemRandom().choice(ppp_letters) for _ in range(pw_length)
)
def connect_to_redis():

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
"""Django's command-line utility for administrative tasks."""
import os
import sys
@@ -12,8 +13,8 @@ def main():
except ImportError as exc:
raise ImportError(
"Couldn't import Django. Are you sure it's installed and "
"available on your PYTHONPATH environment variable? Did you "
"forget to activate a virtual environment?"
'available on your PYTHONPATH environment variable? Did you '
'forget to activate a virtual environment?'
) from exc
execute_from_command_line(sys.argv)

View File

@@ -21,34 +21,3 @@ python-on-whales = '^0.79.0'
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
# Exclude files/directories
exclude = ["anthias_app/migrations/*.py"]
# Line length configuration
line-length = 79
# Python target version
target-version = "py39"
# Enable all rules by default
lint.select = [
"E", # pycodestyle
"F", # pyflakes
"W", # pycodestyle warnings
"I", # isort
"N", # pep8-naming
"B", # flake8-bugbear
"A", # flake8-builtins
]
# Ignore specific rules
lint.ignore = [
"N801", # Ignore class naming convention for API versioning throughout the codebase
"A002",
"A004",
]
[tool.ruff.lint.per-file-ignores]
"bin/migrate.py" = ["E501"]

View File

@@ -6,17 +6,16 @@ import json
import requests
BASE_URL = "https://api.github.com/repos/Screenly/Anthias"
BASE_URL = 'https://api.github.com/repos/Screenly/Anthias'
GITHUB_HEADERS = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28"
'Accept': 'application/vnd.github+json',
'X-GitHub-Api-Version': '2022-11-28',
}
def get_latest_tag():
response = requests.get(
"{}/releases/latest".format(BASE_URL),
headers=GITHUB_HEADERS
'{}/releases/latest'.format(BASE_URL), headers=GITHUB_HEADERS
)
return response.json()['tag_name']
@@ -25,8 +24,8 @@ def get_latest_tag():
def get_asset_list(release_tag):
asset_urls = []
response = requests.get(
"{}/releases/tags/{}".format(BASE_URL, release_tag),
headers=GITHUB_HEADERS
'{}/releases/tags/{}'.format(BASE_URL, release_tag),
headers=GITHUB_HEADERS,
)
for url in response.json()['assets']:
@@ -39,8 +38,7 @@ def get_asset_list(release_tag):
def retrieve_and_patch_json(url):
image_json = requests.get(
url.replace('.img.zst', '.json'),
headers=GITHUB_HEADERS
url.replace('.img.zst', '.json'), headers=GITHUB_HEADERS
).json()
image_json['url'] = url
@@ -52,7 +50,7 @@ def retrieve_and_patch_json(url):
def main():
latest_release = get_latest_tag()
release_assets = get_asset_list(latest_release)
pi_imager_json = {"os_list": []}
pi_imager_json = {'os_list': []}
for url in release_assets:
pi_imager_json['os_list'].append(retrieve_and_patch_json(url))
@@ -60,5 +58,5 @@ def main():
print(json.dumps(pi_imager_json))
if __name__ == "__main__":
if __name__ == '__main__':
main()

6
ruff.toml Normal file
View File

@@ -0,0 +1,6 @@
line-length = 79
exclude = ["anthias_app/migrations/*.py"]
target-version = "py311"
[format]
quote-style = "single"

View File

@@ -38,7 +38,8 @@ def get_ip_addresses():
i['addr']
for interface_name in interfaces()
for i in ifaddresses(interface_name).setdefault(
AF_INET, [{'addr': None}])
AF_INET, [{'addr': None}]
)
if interface_name in ['eth0', 'wlan0']
if i['addr'] is not None
]

View File

@@ -28,7 +28,7 @@ DEFAULTS = {
'use_ssl': False,
'auth_backend': '',
'websocket_port': '9999',
'django_secret_key': ''
'django_secret_key': '',
},
'viewer': {
'audio_output': 'hdmi',
@@ -40,25 +40,26 @@ DEFAULTS = {
'show_splash': True,
'shuffle_playlist': False,
'verify_ssl': True,
'default_assets': False
}
'default_assets': False,
},
}
CONFIGURABLE_SETTINGS = DEFAULTS['viewer'].copy()
CONFIGURABLE_SETTINGS['use_24_hour_clock'] = (
DEFAULTS['main']['use_24_hour_clock'])
CONFIGURABLE_SETTINGS['use_24_hour_clock'] = DEFAULTS['main'][
'use_24_hour_clock'
]
CONFIGURABLE_SETTINGS['date_format'] = DEFAULTS['main']['date_format']
PORT = int(getenv('PORT', 8080))
LISTEN = getenv('LISTEN', '127.0.0.1')
# Initiate logging
logging.basicConfig(level=logging.INFO,
format='%(message)s',
datefmt='%a, %d %b %Y %H:%M:%S')
logging.basicConfig(
level=logging.INFO, format='%(message)s', datefmt='%a, %d %b %Y %H:%M:%S'
)
# Silence urllib info messages ('Starting new HTTP connection')
# that are triggered by the remote url availability check in view_web
requests_log = logging.getLogger("requests")
requests_log = logging.getLogger('requests')
requests_log.setLevel(logging.WARNING)
logging.debug('Starting viewer')
@@ -79,7 +80,8 @@ class AnthiasSettings(UserDict):
if not path.isfile(self.conf_file):
logging.error(
'Config-file %s missing. Using defaults.', self.conf_file)
'Config-file %s missing. Using defaults.', self.conf_file
)
self.use_defaults()
self.save()
else:
@@ -95,9 +97,9 @@ class AnthiasSettings(UserDict):
self[field] = config.get(section, field)
# Likely not a hashed password
if (
field == 'password' and
self[field] != '' and
len(self[field]) != 64
field == 'password'
and self[field] != ''
and len(self[field]) != 64
):
# Hash the original password.
self[field] = hashlib.sha256(self[field]).hexdigest()
@@ -105,7 +107,10 @@ class AnthiasSettings(UserDict):
logging.debug(
"Could not parse setting '%s.%s': %s. "
"Using default value: '%s'.",
section, field, str(e), default
section,
field,
str(e),
default,
)
self[field] = default
if field in ['database', 'assetdir']:
@@ -114,7 +119,8 @@ class AnthiasSettings(UserDict):
def _set(self, config, section, field, default):
if isinstance(default, bool):
config.set(
section, field, self.get(field, default) and 'on' or 'off')
section, field, self.get(field, default) and 'on' or 'off'
)
else:
config.set(section, field, str(self.get(field, default)))
@@ -140,7 +146,7 @@ class AnthiasSettings(UserDict):
config.add_section(section)
for field, default in list(defaults.items()):
self._set(config, section, field, default)
with open(self.conf_file, "w") as f:
with open(self.conf_file, 'w') as f:
config.write(f)
self.load()
@@ -165,7 +171,7 @@ class ZmqPublisher(object):
def __init__(self):
if self.INSTANCE is not None:
raise ValueError("An instance already exists!")
raise ValueError('An instance already exists!')
self.context = zmq.Context()
@@ -180,10 +186,10 @@ class ZmqPublisher(object):
return cls.INSTANCE
def send_to_ws_server(self, msg):
self.socket.send("ws_server {}".format(msg).encode('utf-8'))
self.socket.send('ws_server {}'.format(msg).encode('utf-8'))
def send_to_viewer(self, msg):
self.socket.send_string("viewer {}".format(msg))
self.socket.send_string('viewer {}'.format(msg))
class ZmqConsumer(object):
@@ -205,7 +211,7 @@ class ZmqCollector(object):
def __init__(self):
if self.INSTANCE is not None:
raise ValueError("An instance already exists!")
raise ValueError('An instance already exists!')
self.context = zmq.Context()

View File

@@ -29,7 +29,7 @@ asset_x = {
'is_enabled': 0,
'nocache': 0,
'play_order': 1,
'skip_asset_check': 0
'skip_asset_check': 0,
}
asset_y = {
@@ -43,7 +43,7 @@ asset_y = {
'is_enabled': 1,
'nocache': 0,
'play_order': 0,
'skip_asset_check': 0
'skip_asset_check': 0,
}
@@ -95,12 +95,15 @@ class WebTest(TestCase):
browser.visit(main_page_url)
wait_for_and_do(
browser, '#add-asset-button', lambda btn: btn.click())
browser, '#add-asset-button', lambda btn: btn.click()
)
sleep(1)
wait_for_and_do(
browser, 'input[name="uri"]',
lambda field: field.fill('https://example.com'))
browser,
'input[name="uri"]',
lambda field: field.fill('https://example.com'),
)
sleep(1)
wait_for_and_do(browser, '#tab-uri', lambda form: form.click())
@@ -125,12 +128,15 @@ class WebTest(TestCase):
with get_browser() as browser:
browser.visit(main_page_url)
wait_for_and_do(
browser, '.edit-asset-button', lambda btn: btn.click())
browser, '.edit-asset-button', lambda btn: btn.click()
)
sleep(1)
wait_for_and_do(
browser, 'input[name="duration"]',
lambda field: field.fill('333'))
browser,
'input[name="duration"]',
lambda field: field.fill('333'),
)
sleep(1)
wait_for_and_do(browser, '#edit-form', lambda form: form.click())
@@ -139,7 +145,7 @@ class WebTest(TestCase):
wait_for_and_do(
browser,
'.edit-asset-modal #save-asset',
lambda btn: btn.click()
lambda btn: btn.click(),
)
sleep(3)
@@ -159,10 +165,13 @@ class WebTest(TestCase):
sleep(1)
wait_for_and_do(
browser, '.nav-link.upload-asset-tab', lambda tab: tab.click())
browser, '.nav-link.upload-asset-tab', lambda tab: tab.click()
)
wait_for_and_do(
browser, 'input[name="file_upload"]',
lambda file_input: file_input.fill(image_file))
browser,
'input[name="file_upload"]',
lambda file_input: file_input.fill(image_file),
)
sleep(1)
sleep(3)
@@ -176,9 +185,9 @@ class WebTest(TestCase):
self.assertEqual(asset.duration, settings['default_duration'])
def test_add_asset_video_upload(self):
with (
TemporaryCopy('tests/assets/asset.mov', 'video.mov') as video_file
):
with TemporaryCopy(
'tests/assets/asset.mov', 'video.mov'
) as video_file:
with get_browser() as browser:
browser.visit(main_page_url)
@@ -186,11 +195,15 @@ class WebTest(TestCase):
sleep(1)
wait_for_and_do(
browser, '.nav-link.upload-asset-tab',
lambda tab: tab.click())
browser,
'.nav-link.upload-asset-tab',
lambda tab: tab.click(),
)
wait_for_and_do(
browser, 'input[name="file_upload"]',
lambda file_input: file_input.fill(video_file))
browser,
'input[name="file_upload"]',
lambda file_input: file_input.fill(video_file),
)
sleep(1) # Wait for the new-asset panel animation.
sleep(3) # The backend needs time to process the request.
@@ -207,7 +220,8 @@ class WebTest(TestCase):
with (
TemporaryCopy('tests/assets/asset.mov', 'video.mov') as video_file,
TemporaryCopy(
'static/img/standby.png', 'standby.png') as image_file,
'static/img/standby.png', 'standby.png'
) as image_file,
):
with get_browser() as browser:
browser.visit(main_page_url)
@@ -216,14 +230,20 @@ class WebTest(TestCase):
sleep(1)
wait_for_and_do(
browser, '.nav-link.upload-asset-tab',
lambda tab: tab.click())
browser,
'.nav-link.upload-asset-tab',
lambda tab: tab.click(),
)
wait_for_and_do(
browser, 'input[name="file_upload"]',
lambda file_input: file_input.fill(image_file))
browser,
'input[name="file_upload"]',
lambda file_input: file_input.fill(image_file),
)
wait_for_and_do(
browser, 'input[name="file_upload"]',
lambda file_input: file_input.fill(video_file))
browser,
'input[name="file_upload"]',
lambda file_input: file_input.fill(video_file),
)
sleep(3)
@@ -233,8 +253,7 @@ class WebTest(TestCase):
self.assertEqual(assets[0].name, 'standby.png')
self.assertEqual(assets[0].mimetype, 'image')
self.assertEqual(
assets[0].duration, settings['default_duration'])
self.assertEqual(assets[0].duration, settings['default_duration'])
self.assertEqual(assets[1].name, 'video.mov')
self.assertEqual(assets[1].mimetype, 'video')
@@ -246,12 +265,15 @@ class WebTest(TestCase):
browser.visit(main_page_url)
wait_for_and_do(
browser, '#add-asset-button', lambda btn: btn.click())
browser, '#add-asset-button', lambda btn: btn.click()
)
sleep(1)
wait_for_and_do(
browser, 'input[name="uri"]',
lambda field: field.fill('rtsp://localhost:8091/asset.mov'))
browser,
'input[name="uri"]',
lambda field: field.fill('rtsp://localhost:8091/asset.mov'),
)
sleep(1)
wait_for_and_do(browser, '#add-form', lambda form: form.click())
@@ -268,7 +290,8 @@ class WebTest(TestCase):
self.assertEqual(asset.uri, 'rtsp://localhost:8091/asset.mov')
self.assertEqual(asset.mimetype, 'streaming')
self.assertEqual(
asset.duration, settings['default_streaming_duration'])
asset.duration, settings['default_streaming_duration']
)
@skip('migrate to React-based tests')
def test_remove_asset(self):
@@ -278,9 +301,11 @@ class WebTest(TestCase):
browser.visit(main_page_url)
wait_for_and_do(
browser, '.delete-asset-button', lambda btn: btn.click())
browser, '.delete-asset-button', lambda btn: btn.click()
)
wait_for_and_do(
browser, '.confirm-delete', lambda btn: btn.click())
browser, '.confirm-delete', lambda btn: btn.click()
)
sleep(3)
self.assertEqual(Asset.objects.count(), 0)
@@ -297,21 +322,23 @@ class WebTest(TestCase):
'.form-switch input[type="checkbox"]'
).first
browser.execute_script(
"arguments[0].scrollIntoView(true);",
toggle_element._element
'arguments[0].scrollIntoView(true);', toggle_element._element
)
sleep(1)
# Click the input to trigger the toggle
browser.execute_script(
"arguments[0].click();", toggle_element._element)
'arguments[0].click();', toggle_element._element
)
sleep(2)
# Re-find the element after React re-renders it
toggle_element_after = browser.find_by_css(
'.form-switch input[type="checkbox"]').first
'.form-switch input[type="checkbox"]'
).first
browser.execute_script(
"return arguments[0].checked;", toggle_element_after._element)
'return arguments[0].checked;', toggle_element_after._element
)
# Wait longer for API call to complete
sleep(5)
@@ -326,10 +353,7 @@ class WebTest(TestCase):
# Clear any existing assets first
Asset.objects.all().delete()
Asset.objects.create(**{
**asset_x,
'is_enabled': 1
})
Asset.objects.create(**{**asset_x, 'is_enabled': 1})
with get_browser() as browser:
browser.visit(main_page_url)
@@ -337,22 +361,25 @@ class WebTest(TestCase):
# Find the toggle element and scroll it into view
toggle_element = browser.find_by_css(
'.form-switch input[type="checkbox"]').first
'.form-switch input[type="checkbox"]'
).first
browser.execute_script(
"arguments[0].scrollIntoView(true);", toggle_element._element)
'arguments[0].scrollIntoView(true);', toggle_element._element
)
sleep(1)
# Click the input to trigger the toggle
browser.execute_script(
"arguments[0].click();", toggle_element._element)
'arguments[0].click();', toggle_element._element
)
sleep(2)
# Re-find the element after React re-renders it
toggle_element_after = browser.find_by_css(
'.form-switch input[type="checkbox"]').first
'.form-switch input[type="checkbox"]'
).first
browser.execute_script(
"return arguments[0].checked;",
toggle_element_after._element
'return arguments[0].checked;', toggle_element_after._element
)
# Wait longer for API call to complete
@@ -366,10 +393,7 @@ class WebTest(TestCase):
@skip('migrate to React-based tests')
def test_reorder_asset(self):
Asset.objects.create(**{
**asset_x,
'is_enabled': 1
})
Asset.objects.create(**{**asset_x, 'is_enabled': 1})
Asset.objects.create(**asset_y)
with get_browser() as browser:
@@ -394,12 +418,12 @@ class WebTest(TestCase):
self.assertEqual(
(
'Error: 500 Internal Server Error' in browser.html or
'Error: 504 Gateway Time-out' in browser.html or
'Error: 504 Gateway Timeout' in browser.html
'Error: 500 Internal Server Error' in browser.html
or 'Error: 504 Gateway Time-out' in browser.html
or 'Error: 504 Gateway Timeout' in browser.html
),
False,
'5xx: not expected'
'5xx: not expected',
)
def test_system_info_page_should_work(self):

View File

@@ -14,7 +14,8 @@ class BackupHelperTest(unittest.TestCase):
def setUp(self):
self.dt = datetime(2016, 7, 19, 12, 42, 12)
self.expected_archive_name = (
'anthias-backup-2016-07-19T12-42-12.tar.gz')
'anthias-backup-2016-07-19T12-42-12.tar.gz'
)
self.assertFalse(path.isdir(path.join(home, static_dir)))
def tearDown(self):

View File

@@ -15,7 +15,8 @@ class CeleryTasksTestCase(unittest.TestCase):
celeryapp.conf.update(
CELERY_ALWAYS_EAGER=True,
CELERY_RESULT_BACKEND='',
CELERY_BROKER_URL='')
CELERY_BROKER_URL='',
)
def download_image(self, image_url, image_path):
system('curl {} > {}'.format(image_url, image_path))
@@ -30,7 +31,8 @@ class TestCleanup(CeleryTasksTestCase):
def test_cleanup(self):
cleanup.apply()
tmp_files = [
x for x in listdir(self.assets_path) if x.endswith('.tmp')]
x for x in listdir(self.assets_path) if x.endswith('.tmp')
]
self.assertEqual(len(tmp_files), 0)
def tearDown(self):

View File

@@ -25,12 +25,10 @@ ASSET_X = {
'nocache': 0,
'is_processing': 0,
'play_order': 1,
'skip_asset_check': 0
'skip_asset_check': 0,
}
ASSET_X_DIFF = {
'duration': 10
}
ASSET_X_DIFF = {'duration': 10}
ASSET_Y = {
'mimetype': 'image',
@@ -44,7 +42,7 @@ ASSET_Y = {
'nocache': 0,
'is_processing': 0,
'play_order': 0,
'skip_asset_check': 0
'skip_asset_check': 0,
}
ASSET_Z = {
@@ -59,7 +57,7 @@ ASSET_Z = {
'nocache': 0,
'is_processing': 0,
'play_order': 2,
'skip_asset_check': 0
'skip_asset_check': 0,
}
ASSET_TOMORROW = {
@@ -74,7 +72,7 @@ ASSET_TOMORROW = {
'nocache': 0,
'is_processing': 0,
'play_order': 2,
'skip_asset_check': 0
'skip_asset_check': 0,
}
FAKE_DB_PATH = '/tmp/fakedb'
@@ -88,7 +86,9 @@ class SchedulerTest(TestCase):
for asset in assets:
Asset.objects.create(**asset)
def test_generate_asset_list_assets_should_return_list_sorted_by_play_order(self): # noqa: E501
def test_generate_asset_list_assets_should_return_list_sorted_by_play_order(
self,
): # noqa: E501
self.create_assets([ASSET_X, ASSET_Y])
assets, _ = generate_asset_list()
self.assertEqual(assets, [ASSET_Y, ASSET_X])

View File

@@ -52,6 +52,7 @@ def fake_settings(raw):
try:
import settings
yield (settings, settings.settings)
del sys.modules['settings']
finally:
@@ -89,19 +90,24 @@ class SettingsTest(TestCase):
with fake_settings(empty_settings) as (mod_settings, settings):
self.assertEqual(
settings['player_name'],
mod_settings.DEFAULTS['viewer']['player_name'])
mod_settings.DEFAULTS['viewer']['player_name'],
)
self.assertEqual(
settings['show_splash'],
mod_settings.DEFAULTS['viewer']['show_splash'])
mod_settings.DEFAULTS['viewer']['show_splash'],
)
self.assertEqual(
settings['shuffle_playlist'],
mod_settings.DEFAULTS['viewer']['shuffle_playlist'])
mod_settings.DEFAULTS['viewer']['shuffle_playlist'],
)
self.assertEqual(
settings['debug_logging'],
mod_settings.DEFAULTS['viewer']['debug_logging'])
mod_settings.DEFAULTS['viewer']['debug_logging'],
)
self.assertEqual(
settings['default_duration'],
mod_settings.DEFAULTS['viewer']['default_duration'])
mod_settings.DEFAULTS['viewer']['default_duration'],
)
def broken_settings_should_raise_value_error(self):
with self.assertRaises(ValueError):

View File

@@ -22,7 +22,9 @@ class UpdateTest(ParametrizedTestCase):
'lib.github.fetch_remote_hash',
mock.MagicMock(return_value=(None, False)),
)
def test__if_git_branch_env_does_not_exist__is_up_to_date_should_return_true(self): # noqa: E501
def test__if_git_branch_env_does_not_exist__is_up_to_date_should_return_true(
self,
): # noqa: E501
self.assertEqual(is_up_to_date(), True)
@parametrize(
@@ -71,7 +73,8 @@ class UpdateTest(ParametrizedTestCase):
mock.MagicMock(return_value='master'),
)
def test_is_up_to_date_should_return_value_depending_on_git_hashes(
self, hashes, expected):
self, hashes, expected
):
os.environ['GIT_BRANCH'] = 'master'
os.environ['DEVICE_TYPE'] = 'pi4'

View File

@@ -14,10 +14,10 @@ uri_ = '/home/user/file'
class UtilsTest(unittest.TestCase):
def test_unicode_correctness_in_bottle_templates(self):
self.assertEqual(template_handle_unicode('hello'), u'hello')
self.assertEqual(template_handle_unicode('hello'), 'hello')
self.assertEqual(
template_handle_unicode('Привет'),
u'\u041f\u0440\u0438\u0432\u0435\u0442',
'\u041f\u0440\u0438\u0432\u0435\u0442',
)
def test_json_tz(self):

View File

@@ -23,18 +23,21 @@ class ViewerTestCase(unittest.TestCase):
self.m_scheduler = mock.Mock(name='m_scheduler')
self.p_scheduler = mock.patch.object(
self.u, 'Scheduler', self.m_scheduler)
self.u, 'Scheduler', self.m_scheduler
)
self.m_cmd = mock.Mock(name='m_cmd')
self.p_cmd = mock.patch.object(self.u.sh, 'Command', self.m_cmd)
self.m_killall = mock.Mock(name='killall')
self.p_killall = mock.patch.object(
self.u.sh, 'killall', self.m_killall)
self.u.sh, 'killall', self.m_killall
)
self.m_reload = mock.Mock(name='reload')
self.p_reload = mock.patch.object(
self.u, 'load_settings', self.m_reload)
self.u, 'load_settings', self.m_reload
)
self.m_sleep = mock.Mock(name='sleep')
self.p_sleep = mock.patch.object(self.u, 'sleep', self.m_sleep)

View File

@@ -50,8 +50,7 @@ def build_image(
cache_dir.mkdir(parents=True, exist_ok=True)
except Exception as e:
click.secho(
f'Warning: Failed to create cache directory: {e}',
fg='yellow'
f'Warning: Failed to create cache directory: {e}', fg='yellow'
)
base_apt_dependencies = [
@@ -94,19 +93,22 @@ def build_image(
elif service == 'wifi-connect':
context.update(get_wifi_connect_context(target_platform))
generate_dockerfile(service, {
'base_image': base_image,
'base_image_tag': 'bookworm',
'base_apt_dependencies': base_apt_dependencies,
'board': board,
'debian_version': 'bookworm',
'disable_cache_mounts': disable_cache_mounts,
'environment': environment,
'git_branch': git_branch,
'git_hash': git_hash,
'git_short_hash': git_short_hash,
**context,
})
generate_dockerfile(
service,
{
'base_image': base_image,
'base_image_tag': 'bookworm',
'base_apt_dependencies': base_apt_dependencies,
'board': board,
'debian_version': 'bookworm',
'disable_cache_mounts': disable_cache_mounts,
'environment': environment,
'git_branch': git_branch,
'git_hash': git_hash,
'git_short_hash': git_short_hash,
**context,
},
)
if service == 'test':
click.secho(f'Skipping test service for {board}...', fg='yellow')
@@ -127,12 +129,16 @@ def build_image(
cache_from={
'type': 'local',
'src': str(cache_dir),
} if not clean_build else None,
}
if not clean_build
else None,
cache_to={
'type': 'local',
'dest': str(cache_dir),
'mode': 'max',
} if not clean_build else None,
}
if not clean_build
else None,
builder='multiarch-builder',
file=f'docker/Dockerfile.{service}',
load=True,
@@ -160,10 +166,12 @@ def build_image(
@click.option(
'--service',
default=['all'],
type=click.Choice((
'all',
*SERVICES,
)),
type=click.Choice(
(
'all',
*SERVICES,
)
),
multiple=True,
)
@click.option(
@@ -213,7 +221,8 @@ def main(
# Define tag components
namespaces = ['screenly/anthias', 'screenly/srly-ose']
version_suffix = (
f'{board}-64' if board == 'pi4' and platform == 'linux/arm64/v8'
f'{board}-64'
if board == 'pi4' and platform == 'linux/arm64/v8'
else f'{board}'
)
@@ -244,5 +253,5 @@ def main(
)
if __name__ == "__main__":
if __name__ == '__main__':
main()

View File

@@ -227,18 +227,22 @@ def get_viewer_context(board: str) -> dict:
]
if board in ['pi5', 'x86']:
apt_dependencies.extend([
'qt6-base-dev',
'qt6-webengine-dev',
])
apt_dependencies.extend(
[
'qt6-base-dev',
'qt6-webengine-dev',
]
)
if board not in ['x86', 'pi5']:
apt_dependencies.extend([
'libraspberrypi0',
'libgst-dev',
'libsqlite0-dev',
'libsrtp0-dev',
])
apt_dependencies.extend(
[
'libraspberrypi0',
'libgst-dev',
'libsqlite0-dev',
'libsrtp0-dev',
]
)
if board != 'pi1':
apt_dependencies.extend(['libssl1.1'])
@@ -269,29 +273,25 @@ def get_wifi_connect_context(target_platform: str) -> dict:
return {}
wc_download_url = (
'https://api.github.com/repos/balena-os/wifi-connect/'
'releases/93025295'
'https://api.github.com/repos/balena-os/wifi-connect/releases/93025295'
)
try:
response = requests.get(wc_download_url)
response.raise_for_status()
data = response.json()
assets = [
asset['browser_download_url'] for asset in data['assets']
]
assets = [asset['browser_download_url'] for asset in data['assets']]
try:
archive_url = next(
asset for asset in assets
if f'linux-{architecture}' in asset
asset for asset in assets if f'linux-{architecture}' in asset
)
except StopIteration:
click.secho(
'No wifi-connect release found for this architecture.',
fg='red',
)
archive_url = ""
archive_url = ''
except requests.exceptions.RequestException as e:
click.secho(f'Failed to get wifi-connect release: {e}', fg='red')

View File

@@ -25,6 +25,7 @@ token = None
# Utilities #
#############
def progress_bar(count, total, asset_name='', previous_asset_name=''):
"""
This simple console progress bar
@@ -35,9 +36,8 @@ def progress_bar(count, total, asset_name='', previous_asset_name=''):
# displayed, if the current asset name is shorter than the previous one.
text = f'{asset_name}'.ljust(len(previous_asset_name))
progress_line = (
'#' * int(round(50 * count / float(total))) +
'-' * (50 - int(round(50 * count / float(total))))
progress_line = '#' * int(round(50 * count / float(total))) + '-' * (
50 - int(round(50 * count / float(total)))
)
percent = round(100.0 * count / float(total), 1)
sys.stdout.write(f'[{progress_line}] {percent}% {text}\r')
@@ -53,6 +53,7 @@ def set_token(value):
# Database #
############
def get_assets_by_anthias_api():
if click.confirm('Do you need authentication to access Anthias API?'):
login = click.prompt('Login')
@@ -70,6 +71,7 @@ def get_assets_by_anthias_api():
# Requests #
############
@retry
def get_post_response(endpoint_url, **kwargs):
return requests.post(endpoint_url, **kwargs)
@@ -80,23 +82,17 @@ def send_asset(asset):
asset_uri = asset['uri']
post_kwargs = {
'data': {'title': asset['name']},
'headers': {
'Authorization': token,
'Prefer': 'return=representation'
}
'headers': {'Authorization': token, 'Prefer': 'return=representation'},
}
try:
if asset['mimetype'] in ['image', 'video']:
if asset_uri.startswith('/data'):
asset_uri = os.path.join(
HOME, 'screenly_assets', os.path.basename(asset_uri))
HOME, 'screenly_assets', os.path.basename(asset_uri)
)
post_kwargs.update({
'files': {
'file': open(asset_uri, 'rb')
}
})
post_kwargs.update({'files': {'file': open(asset_uri, 'rb')}})
else:
post_kwargs['data'].update({'source_url': asset_uri})
except FileNotFoundError as error:
@@ -115,9 +111,7 @@ def send_asset(asset):
def check_validate_token(api_key):
endpoint_url = f'{BASE_API_SCREENLY_URL}/api/v4/assets'
headers = {
'Authorization': f'Token {api_key}'
}
headers = {'Authorization': f'Token {api_key}'}
response = requests.get(endpoint_url, headers=headers)
if response.status_code == 200:
return api_key
@@ -129,6 +123,7 @@ def check_validate_token(api_key):
# Main #
########
def start_migration():
if click.confirm('Do you want to start assets migration?'):
assets_migration()
@@ -154,7 +149,7 @@ def assets_migration():
index + 1,
assets_length,
asset_name=shortened_asset_name,
previous_asset_name=previous_asset_name
previous_asset_name=previous_asset_name,
)
previous_asset_name = shortened_asset_name
@@ -185,7 +180,7 @@ def assets_migration():
Your choice
"""
),
type=click.Choice(['1', '2'])
type=click.Choice(['1', '2']),
)
def main(method):
try:
@@ -208,7 +203,8 @@ def main(method):
if __name__ == '__main__':
click.secho(cleandoc("""
click.secho(
cleandoc("""
d8888 888 888
d88888 888 888 888
d88P888 888 888
@@ -217,7 +213,9 @@ if __name__ == '__main__':
d88P 888 888 888 888 888 888 888 .d888888 'Y8888b.
d8888888888 888 888 Y88b. 888 888 888 888 888 X88
d88P 888 888 888 Y888 888 888 888 'Y888888 88888P'
"""), fg='cyan')
"""),
fg='cyan',
)
click.echo()

View File

@@ -58,9 +58,9 @@ except Exception:
standard_library.install_aliases()
__author__ = "Screenly, Inc"
__copyright__ = "Copyright 2012-2024, Screenly, Inc"
__license__ = "Dual License: GPLv2 and Commercial License"
__author__ = 'Screenly, Inc'
__copyright__ = 'Copyright 2012-2024, Screenly, Inc'
__license__ = 'Dual License: GPLv2 and Commercial License'
current_browser_url = None
@@ -146,7 +146,7 @@ commands = {
'setup_wifi': lambda data: setup_wifi(data),
'show_splash': lambda data: show_splash(data),
'unknown': lambda _: command_not_found(),
'current_asset_id': lambda _: send_current_asset_id_to_server()
'current_asset_id': lambda _: send_current_asset_id_to_server(),
}
@@ -155,8 +155,8 @@ def load_browser():
logging.info('Loading browser...')
browser = sh.Command('ScreenlyWebview')(_bg=True, _err_to_out=True)
while (
'Screenly service start' not in browser.process.stdout.decode('utf-8')
while 'Screenly service start' not in browser.process.stdout.decode(
'utf-8'
):
sleep(1)
@@ -227,21 +227,22 @@ def asset_loop(scheduler):
if asset is None:
logging.info(
'Playlist is empty. Sleeping for %s seconds', EMPTY_PL_DELAY)
'Playlist is empty. Sleeping for %s seconds', EMPTY_PL_DELAY
)
view_image(STANDBY_SCREEN)
skip_event = get_skip_event()
skip_event.clear()
if skip_event.wait(timeout=EMPTY_PL_DELAY):
# Skip was triggered, continue immediately to next iteration
logging.info(
'Skip detected during empty playlist wait, continuing')
'Skip detected during empty playlist wait, continuing'
)
else:
# Duration elapsed normally, continue to next iteration
pass
elif (
path.isfile(asset['uri']) or
(not url_fails(asset['uri']) or asset['skip_asset_check'])
elif path.isfile(asset['uri']) or (
not url_fails(asset['uri']) or asset['skip_asset_check']
):
name, mime, uri = asset['name'], asset['mimetype'], asset['uri']
logging.info('Showing asset %s (%s)', name, mime)
@@ -270,14 +271,18 @@ def asset_loop(scheduler):
pass
else:
logging.info('Asset %s at %s is not available, skipping.',
asset['name'], asset['uri'])
logging.info(
'Asset %s at %s is not available, skipping.',
asset['name'],
asset['uri'],
)
skip_event = get_skip_event()
skip_event.clear()
if skip_event.wait(timeout=0.5):
# Skip was triggered, continue immediately to next iteration
logging.info(
'Skip detected during asset unavailability wait, continuing')
'Skip detected during asset unavailability wait, continuing'
)
else:
# Duration elapsed normally, continue to next iteration
pass

View File

@@ -2,9 +2,9 @@ import logging
from viewer import main
if __name__ == "__main__":
if __name__ == '__main__':
try:
main()
except Exception:
logging.exception("Viewer crashed.")
logging.exception('Viewer crashed.')
raise

View File

@@ -11,7 +11,7 @@ from settings import settings
VIDEO_TIMEOUT = 20 # secs
class MediaPlayer():
class MediaPlayer:
def __init__(self):
pass
@@ -40,7 +40,7 @@ class FFMPEGMediaPlayer(MediaPlayer):
self.process = subprocess.Popen(
['ffplay', '-autoexit', self.uri],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL
stderr=subprocess.DEVNULL,
)
def stop(self):
@@ -90,7 +90,8 @@ class VLCMediaPlayer(MediaPlayer):
self.player.set_mrl(uri)
settings.load()
self.player.audio_output_device_set(
'alsa', self.get_alsa_audio_device())
'alsa', self.get_alsa_audio_device()
)
def play(self):
self.player.play()
@@ -100,10 +101,13 @@ class VLCMediaPlayer(MediaPlayer):
def is_playing(self):
return self.player.get_state() in [
vlc.State.Playing, vlc.State.Buffering, vlc.State.Opening]
vlc.State.Playing,
vlc.State.Buffering,
vlc.State.Opening,
]
class MediaPlayerProxy():
class MediaPlayerProxy:
INSTANCE = None
@classmethod

View File

@@ -19,16 +19,14 @@ def get_specific_asset(asset_id):
def generate_asset_list():
"""Choose deadline via:
1. Map assets to deadlines with rule: if asset is active then
'end_date' else 'start_date'
2. Get nearest deadline
1. Map assets to deadlines with rule: if asset is active then
'end_date' else 'start_date'
2. Get nearest deadline
"""
logging.info('Generating asset-list...')
assets = Asset.objects.all()
deadlines = [
asset.end_date
if asset.is_active()
else asset.start_date
asset.end_date if asset.is_active() else asset.start_date
for asset in assets
]
@@ -38,10 +36,7 @@ def generate_asset_list():
end_date__isnull=False,
).order_by('play_order')
playlist = [
{
k: v for k, v in asset.__dict__.items()
if k not in ['_state', 'md5']
}
{k: v for k, v in asset.__dict__.items() if k not in ['_state', 'md5']}
for asset in enabled_assets
if asset.is_active()
]
@@ -76,7 +71,7 @@ class Scheduler(object):
self.current_asset_id = self.extra_asset
self.extra_asset = None
return asset
logging.error("Asset not found or processed")
logging.error('Asset not found or processed')
self.extra_asset = None
self.refresh_playlist()
@@ -94,7 +89,9 @@ class Scheduler(object):
logging.debug(
'get_next_asset counter %s returning asset %s of %s',
self.counter, idx + 1, len(self.assets),
self.counter,
idx + 1,
len(self.assets),
)
if settings['shuffle_playlist'] and self.index == 0:
@@ -110,7 +107,9 @@ class Scheduler(object):
logging.debug(
'refresh: counter: (%s) deadline (%s) timecur (%s)',
self.counter, self.deadline, time_cur
self.counter,
self.deadline,
time_cur,
)
if self.get_db_mtime() > self.last_update_db_mtime:
@@ -137,7 +136,10 @@ class Scheduler(object):
self.index = self.index % len(self.assets) if self.assets else 0
logging.debug(
'update_playlist done, count %s, counter %s, index %s, deadline %s', # noqa: E501
len(self.assets), self.counter, self.index, self.deadline
len(self.assets),
self.counter,
self.index,
self.deadline,
)
def get_db_mtime(self):

View File

@@ -14,7 +14,7 @@ def sigalrm(signum, frame):
"""
Signal just throw an SigalrmError
"""
raise SigalrmError("SigalrmError")
raise SigalrmError('SigalrmError')
def get_skip_event():
@@ -22,11 +22,12 @@ def get_skip_event():
Get the global skip event for instant asset switching.
"""
from viewer.playback import skip_event
return skip_event
def command_not_found():
logging.error("Command not found")
logging.error('Command not found')
def watchdog():

View File

@@ -47,12 +47,15 @@ class AnthiasServerListener(Thread):
socket_outgoing.send(msg)
if __name__ == "__main__":
if __name__ == '__main__':
context = zmq.Context()
listener = AnthiasServerListener(context)
listener.start()
port = int(settings['websocket_port'])
server = pywsgi.WSGIServer(("", port), WebSocketTranslator(context),
handler_class=WebSocketHandler)
server = pywsgi.WSGIServer(
('', port),
WebSocketTranslator(context),
handler_class=WebSocketHandler,
)
server.serve_forever()