diff --git a/src/galaxy/api/jsonrpc.py b/src/galaxy/api/jsonrpc.py index 87bff71..8b14ca7 100644 --- a/src/galaxy/api/jsonrpc.py +++ b/src/galaxy/api/jsonrpc.py @@ -6,6 +6,7 @@ import inspect import json from galaxy.reader import StreamLineReader +from galaxy.task_manager import TaskManager class JsonRpcError(Exception): def __init__(self, code, message, data=None): @@ -52,7 +53,8 @@ class UnknownError(ApplicationError): super().__init__(0, "Unknown error", data) Request = namedtuple("Request", ["method", "params", "id"], defaults=[{}, None]) -Method = namedtuple("Method", ["callback", "signature", "internal", "sensitive_params"]) +Method = namedtuple("Method", ["callback", "signature", "immediate", "sensitive_params"]) + def anonymise_sensitive_params(params, sensitive_params): anomized_data = "****" @@ -74,9 +76,9 @@ class Server(): self._encoder = encoder self._methods = {} self._notifications = {} - self._eof_listeners = [] + self._task_manager = TaskManager("jsonrpc server") - def register_method(self, name, callback, internal, sensitive_params=False): + def register_method(self, name, callback, immediate, sensitive_params=False): """ Register method @@ -86,9 +88,9 @@ class Server(): :param sensitive_params: list of parameters that are anonymized before logging; \ if False - no params are considered sensitive, if True - all params are considered sensitive """ - self._methods[name] = Method(callback, inspect.signature(callback), internal, sensitive_params) + self._methods[name] = Method(callback, inspect.signature(callback), immediate, sensitive_params) - def register_notification(self, name, callback, internal, sensitive_params=False): + def register_notification(self, name, callback, immediate, sensitive_params=False): """ Register notification @@ -98,10 +100,7 @@ class Server(): :param sensitive_params: list of parameters that are anonymized before logging; \ if False - no params are considered sensitive, if True - all params are considered sensitive """ - self._notifications[name] = Method(callback, inspect.signature(callback), internal, sensitive_params) - - def register_eof(self, callback): - self._eof_listeners.append(callback) + self._notifications[name] = Method(callback, inspect.signature(callback), immediate, sensitive_params) async def run(self): while self._active: @@ -118,14 +117,16 @@ class Server(): self._handle_input(data) await asyncio.sleep(0) # To not starve task queue - def stop(self): + def close(self): + logging.info("Closing JSON-RPC server - not more messages will be read") self._active = False + async def wait_closed(self): + await self._task_manager.wait() + def _eof(self): logging.info("Received EOF") - self.stop() - for listener in self._eof_listeners: - listener() + self.close() def _handle_input(self, data): try: @@ -145,7 +146,7 @@ class Server(): logging.error("Received unknown notification: %s", request.method) return - callback, signature, internal, sensitive_params = method + callback, signature, immediate, sensitive_params = method self._log_request(request, sensitive_params) try: @@ -153,12 +154,11 @@ class Server(): except TypeError: self._send_error(request.id, InvalidParams()) - if internal: - # internal requests are handled immediately + if immediate: callback(*bound_args.args, **bound_args.kwargs) else: try: - asyncio.create_task(callback(*bound_args.args, **bound_args.kwargs)) + self._task_manager.create_task(callback(*bound_args.args, **bound_args.kwargs), request.method) except Exception: logging.exception("Unexpected exception raised in notification handler") @@ -169,7 +169,7 @@ class Server(): self._send_error(request.id, MethodNotFound()) return - callback, signature, internal, sensitive_params = method + callback, signature, immediate, sensitive_params = method self._log_request(request, sensitive_params) try: @@ -177,8 +177,7 @@ class Server(): except TypeError: self._send_error(request.id, InvalidParams()) - if internal: - # internal requests are handled immediately + if immediate: response = callback(*bound_args.args, **bound_args.kwargs) self._send_response(request.id, response) else: @@ -190,11 +189,13 @@ class Server(): self._send_error(request.id, MethodNotFound()) except JsonRpcError as error: self._send_error(request.id, error) + except asyncio.CancelledError: + self._send_error(request.id, Aborted()) except Exception as e: #pylint: disable=broad-except logging.exception("Unexpected exception raised in plugin handler") self._send_error(request.id, UnknownError(str(e))) - asyncio.create_task(handle()) + self._task_manager.create_task(handle(), request.method) @staticmethod def _parse_request(data): @@ -215,7 +216,7 @@ class Server(): logging.debug("Sending data: %s", line) data = (line + "\n").encode("utf-8") self._writer.write(data) - asyncio.create_task(self._writer.drain()) + self._task_manager.create_task(self._writer.drain(), "drain") except TypeError as error: logging.error(str(error)) @@ -255,6 +256,7 @@ class NotificationClient(): self._writer = writer self._encoder = encoder self._methods = {} + self._task_manager = TaskManager("notification client") def notify(self, method, params, sensitive_params=False): """ @@ -273,13 +275,16 @@ class NotificationClient(): self._log(method, params, sensitive_params) self._send(notification) + async def close(self): + await self._task_manager.wait() + def _send(self, data): try: line = self._encoder.encode(data) data = (line + "\n").encode("utf-8") logging.debug("Sending %d byte of data", len(data)) self._writer.write(data) - asyncio.create_task(self._writer.drain()) + self._task_manager.create_task(self._writer.drain(), "drain") except TypeError as error: logging.error("Failed to parse outgoing message: %s", str(error)) diff --git a/src/galaxy/api/plugin.py b/src/galaxy/api/plugin.py index 8add108..f573ebb 100644 --- a/src/galaxy/api/plugin.py +++ b/src/galaxy/api/plugin.py @@ -4,16 +4,14 @@ import json import logging import logging.handlers import sys -from collections import OrderedDict from enum import Enum -from itertools import count from typing import Any, Dict, List, Optional, Set, Union from galaxy.api.consts import Feature from galaxy.api.errors import ImportInProgress, UnknownError from galaxy.api.jsonrpc import ApplicationError, NotificationClient, Server from galaxy.api.types import Achievement, Authentication, FriendInfo, Game, GameTime, LocalGame, NextStep - +from galaxy.task_manager import TaskManager class JSONEncoder(json.JSONEncoder): def default(self, o): # pylint: disable=method-hidden @@ -38,7 +36,6 @@ class Plugin: self._features: Set[Feature] = set() self._active = True - self._pass_control_task = None self._reader, self._writer = reader, writer self._handshake_token = handshake_token @@ -47,29 +44,25 @@ class Plugin: self._server = Server(self._reader, self._writer, encoder) self._notification_client = NotificationClient(self._writer, encoder) - def eof_handler(): - self._shutdown() - - self._server.register_eof(eof_handler) - self._achievements_import_in_progress = False self._game_times_import_in_progress = False self._persistent_cache = dict() - self._tasks = OrderedDict() - self._task_counter = count() + self._internal_task_manager = TaskManager("plugin internal") + self._external_task_manager = TaskManager("plugin external") # internal self._register_method("shutdown", self._shutdown, internal=True) - self._register_method("get_capabilities", self._get_capabilities, internal=True) + self._register_method("get_capabilities", self._get_capabilities, internal=True, immediate=True) self._register_method( "initialize_cache", self._initialize_cache, internal=True, + immediate=True, sensitive_params="data" ) - self._register_method("ping", self._ping, internal=True) + self._register_method("ping", self._ping, internal=True, immediate=True) # implemented by developer self._register_method( @@ -116,6 +109,13 @@ class Plugin: self._register_method("start_game_times_import", self._start_game_times_import) self._detect_feature(Feature.ImportGameTime, ["get_game_time"]) + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + self.close() + await self.wait_closed() + @property def features(self) -> List[Feature]: return list(self._features) @@ -136,55 +136,65 @@ class Plugin: if self._implements(methods): self._features.add(feature) - def _register_method(self, name, handler, result_name=None, internal=False, sensitive_params=False): - if internal: + def _register_method(self, name, handler, result_name=None, internal=False, immediate=False, sensitive_params=False): + def wrap_result(result): + if result_name: + result = { + result_name: result + } + return result + + if immediate: def method(*args, **kwargs): result = handler(*args, **kwargs) - if result_name: - result = { - result_name: result - } - return result + return wrap_result(result) self._server.register_method(name, method, True, sensitive_params) else: async def method(*args, **kwargs): - result = await handler(*args, **kwargs) - if result_name: - result = { - result_name: result - } - return result + if not internal: + handler_ = self._wrap_external_method(handler, name) + else: + handler_ = handler + result = await handler_(*args, **kwargs) + return wrap_result(result) self._server.register_method(name, method, False, sensitive_params) - def _register_notification(self, name, handler, internal=False, sensitive_params=False): - self._server.register_notification(name, handler, internal, sensitive_params) + def _register_notification(self, name, handler, internal=False, immediate=False, sensitive_params=False): + if not internal and not immediate: + handler = self._wrap_external_method(handler, name) + self._server.register_notification(name, handler, immediate, sensitive_params) + + def _wrap_external_method(self, handler, name: str): + async def wrapper(*args, **kwargs): + return await self._external_task_manager.create_task(handler(*args, **kwargs), name, False) + return wrapper async def run(self): """Plugin's main coroutine.""" await self._server.run() - if self._pass_control_task is not None: - await self._pass_control_task + await self._external_task_manager.wait() + + def close(self) -> None: + if not self._active: + return + + logging.info("Closing plugin") + self._server.close() + self._external_task_manager.cancel() + self._internal_task_manager.create_task(self.shutdown(), "shutdown") + self._active = False + + async def wait_closed(self) -> None: + await self._external_task_manager.wait() + await self._internal_task_manager.wait() + await self._server.wait_closed() + await self._notification_client.close() def create_task(self, coro, description): """Wrapper around asyncio.create_task - takes care of canceling tasks on shutdown""" - - async def task_wrapper(task_id): - try: - return await coro - except asyncio.CancelledError: - logging.debug("Canceled task %d (%s)", task_id, description) - except Exception: - logging.exception("Exception raised in task %d (%s)", task_id, description) - finally: - del self._tasks[task_id] - - task_id = next(self._task_counter) - logging.debug("Creating task %d (%s)", task_id, description) - task = asyncio.create_task(task_wrapper(task_id)) - self._tasks[task_id] = task - return task + return self._external_task_manager.create_task(coro, description) async def _pass_control(self): while self._active: @@ -194,13 +204,11 @@ class Plugin: logging.exception("Unexpected exception raised in plugin tick") await asyncio.sleep(1) - def _shutdown(self): + async def _shutdown(self): logging.info("Shutting down") - self._server.stop() - self._active = False - self.shutdown() - for task in self._tasks.values(): - task.cancel() + self.close() + await self._external_task_manager.wait() + await self._internal_task_manager.wait() def _get_capabilities(self): return { @@ -215,7 +223,7 @@ class Plugin: self.handshake_complete() except Exception: logging.exception("Unhandled exception during `handshake_complete` step") - self._pass_control_task = asyncio.create_task(self._pass_control()) + self._internal_task_manager.create_task(self._pass_control(), "tick") @staticmethod def _ping(): @@ -444,7 +452,7 @@ class Plugin: """ - def shutdown(self) -> None: + async def shutdown(self) -> None: """This method is called on integration shutdown. Override it to implement tear down. This method is called by the GOG Galaxy Client.""" @@ -552,7 +560,11 @@ class Plugin: self._achievements_import_in_progress = False self.achievements_import_complete() - self.create_task(import_games_achievements(game_ids, context), "Games unlocked achievements import") + self._external_task_manager.create_task( + import_games_achievements(game_ids, context), + "unlocked achievements import", + handle_exceptions=False + ) self._achievements_import_in_progress = True async def prepare_achievements_context(self, game_ids: List[str]) -> Any: @@ -712,7 +724,11 @@ class Plugin: self._game_times_import_in_progress = False self.game_times_import_complete() - self.create_task(import_game_times(game_ids, context), "Game times import") + self._external_task_manager.create_task( + import_game_times(game_ids, context), + "game times import", + handle_exceptions=False + ) self._game_times_import_in_progress = True async def prepare_game_times_context(self, game_ids: List[str]) -> Any: @@ -783,8 +799,8 @@ def create_and_run_plugin(plugin_class, argv): reader, writer = await asyncio.open_connection("127.0.0.1", port) extra_info = writer.get_extra_info("sockname") logging.info("Using local address: %s:%u", *extra_info) - plugin = plugin_class(reader, writer, token) - await plugin.run() + async with plugin_class(reader, writer, token) as plugin: + await plugin.run() try: if sys.platform == "win32": diff --git a/src/galaxy/task_manager.py b/src/galaxy/task_manager.py new file mode 100644 index 0000000..1f6d457 --- /dev/null +++ b/src/galaxy/task_manager.py @@ -0,0 +1,49 @@ +import asyncio +import logging +from collections import OrderedDict +from itertools import count + +class TaskManager: + def __init__(self, name): + self._name = name + self._tasks = OrderedDict() + self._task_counter = count() + + def create_task(self, coro, description, handle_exceptions=True): + """Wrapper around asyncio.create_task - takes care of canceling tasks on shutdown""" + + async def task_wrapper(task_id): + try: + result = await coro + logging.debug("Task manager %s: finished task %d (%s)", self._name, task_id, description) + return result + except asyncio.CancelledError: + if handle_exceptions: + logging.debug("Task manager %s: canceled task %d (%s)", self._name, task_id, description) + else: + raise + except Exception: + if handle_exceptions: + logging.exception("Task manager %s: exception raised in task %d (%s)", self._name, task_id, description) + else: + raise + finally: + del self._tasks[task_id] + + task_id = next(self._task_counter) + logging.debug("Task manager %s: creating task %d (%s)", self._name, task_id, description) + task = asyncio.create_task(task_wrapper(task_id)) + self._tasks[task_id] = task + return task + + def cancel(self): + for task in self._tasks.values(): + task.cancel() + + async def wait(self): + # Tasks can spawn other tasks + while True: + tasks = self._tasks.values() + if not tasks: + return + await asyncio.gather(*tasks, return_exceptions=True) diff --git a/tests/__init__.py b/tests/__init__.py index 140adbd..357b407 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -7,11 +7,8 @@ def create_message(request): def get_messages(write_mock): messages = [] - print("call_args_list", write_mock.call_args_list) for call_args in write_mock.call_args_list: - print("call_args", call_args) data = call_args[0][0] - print("data", data) for line in data.splitlines(): message = json.loads(line) messages.append(message) diff --git a/tests/conftest.py b/tests/conftest.py index 23bdadc..a8fce46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ import pytest from galaxy.api.plugin import Plugin from galaxy.api.consts import Platform +from galaxy.unittest.mock import async_return_value @pytest.fixture() def reader(): @@ -16,8 +17,7 @@ def reader(): @pytest.fixture() async def writer(): stream = MagicMock(name="stream_writer") - stream.write = MagicMock() - stream.drain = MagicMock() + stream.drain.side_effect = lambda: async_return_value(None) yield stream @pytest.fixture() @@ -29,7 +29,7 @@ def write(writer): yield writer.write @pytest.fixture() -def plugin(reader, writer): +async def plugin(reader, writer): """Return plugin instance with all feature methods mocked""" methods = ( "handshake_complete", @@ -55,7 +55,10 @@ def plugin(reader, writer): with ExitStack() as stack: for method in methods: stack.enter_context(patch.object(Plugin, method)) - yield Plugin(Platform.Generic, "0.1", reader, writer, "token") + + async with Plugin(Platform.Generic, "0.1", reader, writer, "token") as plugin: + plugin.shutdown.return_value = async_return_value(None) + yield plugin @pytest.fixture(autouse=True) diff --git a/tests/test_achievements.py b/tests/test_achievements.py index 1d7abda..593fb22 100644 --- a/tests/test_achievements.py +++ b/tests/test_achievements.py @@ -136,6 +136,7 @@ async def test_prepare_get_unlocked_achievements_context_error(plugin, read, wri } } read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")] + await plugin.run() assert get_messages(write) == [ @@ -153,6 +154,7 @@ async def test_prepare_get_unlocked_achievements_context_error(plugin, read, wri @pytest.mark.asyncio async def test_import_in_progress(plugin, read, write): plugin.prepare_achievements_context.return_value = async_return_value(None) + plugin.get_unlocked_achievements.return_value = async_return_value([]) requests = [ { "jsonrpc": "2.0", @@ -179,21 +181,20 @@ async def test_import_in_progress(plugin, read, write): await plugin.run() - assert get_messages(write) == [ - { - "jsonrpc": "2.0", - "id": "3", - "result": None - }, - { - "jsonrpc": "2.0", - "id": "4", - "error": { - "code": 600, - "message": "Import already in progress" - } + messages = get_messages(write) + assert { + "jsonrpc": "2.0", + "id": "3", + "result": None + } in messages + assert { + "jsonrpc": "2.0", + "id": "4", + "error": { + "code": 600, + "message": "Import already in progress" } - ] + } in messages @pytest.mark.asyncio diff --git a/tests/test_game_times.py b/tests/test_game_times.py index 877bcc1..12cf496 100644 --- a/tests/test_game_times.py +++ b/tests/test_game_times.py @@ -179,21 +179,20 @@ async def test_import_in_progress(plugin, read, write): await plugin.run() - assert get_messages(write) == [ - { - "jsonrpc": "2.0", - "id": "3", - "result": None - }, - { - "jsonrpc": "2.0", - "id": "4", - "error": { - "code": 600, - "message": "Import already in progress" - } + messages = get_messages(write) + assert { + "jsonrpc": "2.0", + "id": "3", + "result": None + } in messages + assert { + "jsonrpc": "2.0", + "id": "4", + "error": { + "code": 600, + "message": "Import already in progress" } - ] + } in messages @pytest.mark.asyncio diff --git a/tests/test_internal.py b/tests/test_internal.py index bfd42a0..b98b77a 100644 --- a/tests/test_internal.py +++ b/tests/test_internal.py @@ -46,6 +46,7 @@ async def test_shutdown(plugin, read, write): } read.side_effect = [async_return_value(create_message(request))] await plugin.run() + await plugin.wait_closed() plugin.shutdown.assert_called_with() assert get_messages(write) == [ {