diff --git a/src/galaxy/api/plugin.py b/src/galaxy/api/plugin.py index dbe746b..23d3f33 100644 --- a/src/galaxy/api/plugin.py +++ b/src/galaxy/api/plugin.py @@ -1,21 +1,18 @@ import asyncio +import dataclasses import json import logging import logging.handlers -import dataclasses -from enum import Enum -from collections import OrderedDict -from itertools import count import sys +from collections import OrderedDict +from enum import Enum +from itertools import count +from typing import Any, Dict, List, Optional, Set, Union -from typing import Any, List, Dict, Optional, Union - -from galaxy.api.types import Achievement, Game, LocalGame, FriendInfo, GameTime - -from galaxy.api.jsonrpc import Server, NotificationClient, ApplicationError from galaxy.api.consts import Feature -from galaxy.api.errors import UnknownError, ImportInProgress -from galaxy.api.types import Authentication, NextStep +from galaxy.api.errors import ImportInProgress, UnknownError +from galaxy.api.jsonrpc import ApplicationError, NotificationClient, Server +from galaxy.api.types import Achievement, Authentication, FriendInfo, Game, GameTime, LocalGame, NextStep class JSONEncoder(json.JSONEncoder): @@ -24,6 +21,7 @@ class JSONEncoder(json.JSONEncoder): # filter None values def dict_factory(elements): return {k: v for k, v in elements if v is not None} + return dataclasses.asdict(o, dict_factory=dict_factory) if isinstance(o, Enum): return o.value @@ -32,12 +30,13 @@ class JSONEncoder(json.JSONEncoder): class Plugin: """Use and override methods of this class to create a new platform integration.""" + def __init__(self, platform, version, reader, writer, handshake_token): logging.info("Creating plugin for platform %s, version %s", platform.value, version) self._platform = platform self._version = version - self._feature_methods = OrderedDict() + self._features: Set[Feature] = set() self._active = True self._pass_control_task = None @@ -50,6 +49,7 @@ class Plugin: def eof_handler(): self._shutdown() + self._server.register_eof(eof_handler) self._achievements_import_in_progress = False @@ -85,63 +85,47 @@ class Plugin: self._register_method( "import_owned_games", self.get_owned_games, - result_name="owned_games", - feature=Feature.ImportOwnedGames + result_name="owned_games" ) + self._detect_feature(Feature.ImportOwnedGames, ["get_owned_games"]) + self._register_method( "import_unlocked_achievements", self.get_unlocked_achievements, - result_name="unlocked_achievements", - feature=Feature.ImportAchievements - ) - self._register_method( - "start_achievements_import", - self.start_achievements_import, - ) - self._register_method( - "import_local_games", - self.get_local_games, - result_name="local_games", - feature=Feature.ImportInstalledGames - ) - self._register_notification("launch_game", self.launch_game, feature=Feature.LaunchGame) - self._register_notification("install_game", self.install_game, feature=Feature.InstallGame) - self._register_notification( - "uninstall_game", - self.uninstall_game, - feature=Feature.UninstallGame - ) - self._register_notification( - "shutdown_platform_client", - self.shutdown_platform_client, - feature=Feature.ShutdownPlatformClient - ) - self._register_method( - "import_friends", - self.get_friends, - result_name="friend_info_list", - feature=Feature.ImportFriends - ) - self._register_method( - "import_game_times", - self.get_game_times, - result_name="game_times", - feature=Feature.ImportGameTime - ) - self._register_method( - "start_game_times_import", - self.start_game_times_import, + result_name="unlocked_achievements" ) + self._detect_feature(Feature.ImportAchievements, ["get_unlocked_achievements"]) + + self._register_method("start_achievements_import", self.start_achievements_import) + self._detect_feature(Feature.ImportAchievements, ["import_games_achievements"]) + + self._register_method("import_local_games", self.get_local_games, result_name="local_games") + self._detect_feature(Feature.ImportInstalledGames, ["get_local_games"]) + + self._register_notification("launch_game", self.launch_game) + self._detect_feature(Feature.LaunchGame, ["launch_game"]) + + self._register_notification("install_game", self.install_game) + self._detect_feature(Feature.InstallGame, ["install_game"]) + + self._register_notification("uninstall_game", self.uninstall_game) + self._detect_feature(Feature.UninstallGame, ["uninstall_game"]) + + self._register_notification("shutdown_platform_client", self.shutdown_platform_client) + self._detect_feature(Feature.ShutdownPlatformClient, ["shutdown_platform_client"]) + + self._register_method("import_friends", self.get_friends, result_name="friend_info_list") + self._detect_feature(Feature.ImportFriends, ["get_friends"]) + + self._register_method("import_game_times", self.get_game_times, result_name="game_times") + self._detect_feature(Feature.ImportGameTime, ["get_game_times"]) + + self._register_method("start_game_times_import", self.start_game_times_import) + self._detect_feature(Feature.ImportGameTime, ["import_game_times"]) @property - def features(self): - features = [] - if self.__class__ != Plugin: - for feature, handlers in self._feature_methods.items(): - if self._implements(handlers): - features.append(feature) - - return features + def features(self) -> List[Feature]: + return list(self._features) @property def persistent_cache(self) -> Dict: @@ -149,13 +133,17 @@ class Plugin: """ return self._persistent_cache - def _implements(self, handlers): - for handler in handlers: - if handler.__name__ not in self.__class__.__dict__: + def _implements(self, methods: List[str]) -> bool: + for method in methods: + if method not in self.__class__.__dict__: return False return True - def _register_method(self, name, handler, result_name=None, internal=False, sensitive_params=False, feature=None): + def _detect_feature(self, feature: Feature, methods: List[str]): + if self._implements(methods): + self._features.add(feature) + + def _register_method(self, name, handler, result_name=None, internal=False, sensitive_params=False): if internal: def method(*args, **kwargs): result = handler(*args, **kwargs) @@ -164,6 +152,7 @@ class Plugin: result_name: result } return result + self._server.register_method(name, method, True, sensitive_params) else: async def method(*args, **kwargs): @@ -173,17 +162,12 @@ class Plugin: result_name: result } return result + self._server.register_method(name, method, False, sensitive_params) - if feature is not None: - self._feature_methods.setdefault(feature, []).append(handler) - - def _register_notification(self, name, handler, internal=False, sensitive_params=False, feature=None): + def _register_notification(self, name, handler, internal=False, sensitive_params=False): self._server.register_notification(name, handler, internal, sensitive_params) - if feature is not None: - self._feature_methods.setdefault(feature, []).append(handler) - async def run(self): """Plugin's main coroutine.""" await self._server.run() @@ -192,6 +176,7 @@ class Plugin: def create_task(self, coro, description): """Wrapper around asyncio.create_task - takes care of canceling tasks on shutdown""" + async def task_wrapper(task_id): try: return await coro @@ -524,7 +509,7 @@ class Plugin: raise NotImplementedError() async def pass_login_credentials(self, step: str, credentials: Dict[str, str], cookies: List[Dict[str, str]]) \ - -> Union[NextStep, Authentication]: + -> Union[NextStep, Authentication]: """This method is called if we return galaxy.api.types.NextStep from authenticate or from pass_login_credentials. This method's parameters provide the data extracted from the web page navigation that previous NextStep finished on. This method should either return galaxy.api.types.Authentication if the authentication is finished @@ -607,6 +592,7 @@ class Plugin: :param game_ids: ids of the games for which to import unlocked achievements """ + async def import_game_achievements(game_id): try: achievements = await self.get_unlocked_achievements(game_id) diff --git a/tests/conftest.py b/tests/conftest.py index 4c74e42..9cd4b9a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,6 +58,7 @@ def plugin(reader, writer): stack.enter_context(patch.object(Plugin, method)) yield Plugin(Platform.Generic, "0.1", reader, writer, "token") + @pytest.fixture(autouse=True) def my_caplog(caplog): caplog.set_level(logging.DEBUG) diff --git a/tests/test_features.py b/tests/test_features.py index dce5fb0..1b518db 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -1,21 +1,49 @@ +from galaxy.api.consts import Feature, Platform from galaxy.api.plugin import Plugin -from galaxy.api.consts import Platform, Feature + def test_base_class(): plugin = Plugin(Platform.Generic, "0.1", None, None, None) - assert plugin.features == [] + assert set(plugin.features) == { + Feature.ImportInstalledGames, + Feature.ImportOwnedGames, + Feature.LaunchGame, + Feature.InstallGame, + Feature.UninstallGame, + Feature.ImportAchievements, + Feature.ImportGameTime, + Feature.ImportFriends, + Feature.ShutdownPlatformClient + } + def test_no_overloads(): - class PluginImpl(Plugin): #pylint: disable=abstract-method + class PluginImpl(Plugin): # pylint: disable=abstract-method pass plugin = PluginImpl(Platform.Generic, "0.1", None, None, None) assert plugin.features == [] + def test_one_method_feature(): - class PluginImpl(Plugin): #pylint: disable=abstract-method + class PluginImpl(Plugin): # pylint: disable=abstract-method async def get_owned_games(self): pass plugin = PluginImpl(Platform.Generic, "0.1", None, None, None) - assert plugin.features == [Feature.ImportOwnedGames] \ No newline at end of file + assert plugin.features == [Feature.ImportOwnedGames] + + +def test_multi_features(): + class PluginImpl(Plugin): # pylint: disable=abstract-method + async def get_owned_games(self): + pass + + async def import_games_achievements(self, game_ids) -> None: + pass + + async def start_game_times_import(self, game_ids) -> None: + pass + + plugin = PluginImpl(Platform.Generic, "0.1", None, None, None) + assert set(plugin.features) == {Feature.ImportAchievements, Feature.ImportOwnedGames}