From 78f1d5a4ccbc60a3670351934cefaea8aaf03b8d Mon Sep 17 00:00:00 2001 From: Rafal Makagon Date: Thu, 31 Oct 2019 15:15:14 +0100 Subject: [PATCH] Add refresh_credentials method to plugin --- src/galaxy/api/jsonrpc.py | 166 ++++++++++++++++++++---------- src/galaxy/api/plugin.py | 73 ++++++------- tests/test_refresh_credentials.py | 72 +++++++++++++ 3 files changed, 219 insertions(+), 92 deletions(-) create mode 100644 tests/test_refresh_credentials.py diff --git a/src/galaxy/api/jsonrpc.py b/src/galaxy/api/jsonrpc.py index be5a9c3..7403c70 100644 --- a/src/galaxy/api/jsonrpc.py +++ b/src/galaxy/api/jsonrpc.py @@ -64,6 +64,7 @@ class UnknownError(ApplicationError): super().__init__(0, "Unknown error", data) Request = namedtuple("Request", ["method", "params", "id"], defaults=[{}, None]) +Response = namedtuple("Response", ["id", "result", "error"], defaults=[None, {}, {}]) Method = namedtuple("Method", ["callback", "signature", "immediate", "sensitive_params"]) @@ -79,7 +80,7 @@ def anonymise_sensitive_params(params, sensitive_params): return params -class Server(): +class Connection(): def __init__(self, reader, writer, encoder=json.JSONEncoder()): self._active = True self._reader = StreamLineReader(reader) @@ -89,6 +90,8 @@ class Server(): self._notifications = {} self._task_manager = TaskManager("jsonrpc server") self._write_lock = asyncio.Lock() + self._last_request_id = 0 + self._requests_futures = {} def register_method(self, name, callback, immediate, sensitive_params=False): """ @@ -114,6 +117,47 @@ class Server(): """ self._notifications[name] = Method(callback, inspect.signature(callback), immediate, sensitive_params) + async def send_request(self, method, params, sensitive_params): + """ + Send request + + :param method: + :param params: + :param sensitive_params: list of parameters that are anonymized before logging; \ + if False - no params are considered sensitive, if True - all params are considered sensitive + """ + self._last_request_id += 1 + request_id = str(self._last_request_id) + + loop = asyncio.get_running_loop() + future = loop.create_future() + self._requests_futures[self._last_request_id] = (future, sensitive_params) + + logging.info( + "Sending request: id=%s, method=%s, params=%s", + request_id, method, anonymise_sensitive_params(params, sensitive_params) + ) + + self._send_request(request_id, method, params) + return await future + + def send_notification(self, method, params, sensitive_params=False): + """ + Send notification + + :param method: + :param params: + :param sensitive_params: list of parameters that are anonymized before logging; \ + if False - no params are considered sensitive, if True - all params are considered sensitive + """ + + logging.info( + "Sending notification: method=%s, params=%s", + method, anonymise_sensitive_params(params, sensitive_params) + ) + + self._send_notification(method, params) + async def run(self): while self._active: try: @@ -143,15 +187,40 @@ class Server(): def _handle_input(self, data): try: - request = self._parse_request(data) + message = self._parse_message(data) except JsonRpcError as error: self._send_error(None, error) return - if request.id is not None: - self._handle_request(request) - else: - self._handle_notification(request) + if isinstance(message, Request): + if message.id is not None: + self._handle_request(message) + else: + self._handle_notification(message) + elif isinstance(message, Response): + self._handle_response(message) + + def _handle_response(self, response): + request_future = self._requests_futures.get(int(response.id)) + if request_future is None: + response_type = "response" if response.result is not None else "error" + logging.warning("Received %s for unknown request: %s", response_type, response.id) + return + + future, sensitive_params = request_future + + if response.error: + error = JsonRpcError( + response.error.setdefault("code", 0), + response.error.setdefault("message", ""), + response.error.setdefault("data", None) + ) + self._log_error(response, error, sensitive_params) + future.set_exception(error) + return + + self._log_response(response, sensitive_params) + future.set_result(response.result) def _handle_notification(self, request): method = self._notifications.get(request.method) @@ -211,13 +280,17 @@ class Server(): self._task_manager.create_task(handle(), request.method) @staticmethod - def _parse_request(data): + def _parse_message(data): try: - jsonrpc_request = json.loads(data, encoding="utf-8") - if jsonrpc_request.get("jsonrpc") != "2.0": + jsonrpc_message = json.loads(data, encoding="utf-8") + if jsonrpc_message.get("jsonrpc") != "2.0": raise InvalidRequest() - del jsonrpc_request["jsonrpc"] - return Request(**jsonrpc_request) + del jsonrpc_message["jsonrpc"] + if "result" in jsonrpc_message.keys() or "error" in jsonrpc_message.keys(): + return Response(**jsonrpc_message) + else: + return Request(**jsonrpc_message) + except json.JSONDecodeError: raise ParseError() except TypeError: @@ -254,6 +327,23 @@ class Server(): self._send(response) + def _send_request(self, request_id, method, params): + request = { + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params + } + self._send(request) + + def _send_notification(self, method, params): + notification = { + "jsonrpc": "2.0", + "method": method, + "params": params + } + self._send(notification) + @staticmethod def _log_request(request, sensitive_params): params = anonymise_sensitive_params(request.params, sensitive_params) @@ -262,50 +352,14 @@ class Server(): else: logging.info("Handling notification: method=%s, params=%s", request.method, params) -class NotificationClient(): - def __init__(self, writer, encoder=json.JSONEncoder()): - self._writer = writer - self._encoder = encoder - self._methods = {} - self._task_manager = TaskManager("notification client") - self._write_lock = asyncio.Lock() - - def notify(self, method, params, sensitive_params=False): - """ - Send notification - - :param method: - :param params: - :param sensitive_params: list of parameters that are anonymized before logging; \ - if False - no params are considered sensitive, if True - all params are considered sensitive - """ - notification = { - "jsonrpc": "2.0", - "method": method, - "params": params - } - self._log(method, params, sensitive_params) - self._send(notification) - - async def close(self): - self._task_manager.cancel() - await self._task_manager.wait() - - def _send(self, data): - async def send_task(data_): - async with self._write_lock: - self._writer.write(data_) - await self._writer.drain() - - try: - line = self._encoder.encode(data) - data = (line + "\n").encode("utf-8") - logging.debug("Sending %d byte of data", len(data)) - self._task_manager.create_task(send_task(data), "send") - except TypeError as error: - logging.error("Failed to parse outgoing message: %s", str(error)) + @staticmethod + def _log_response(response, sensitive_params): + result = anonymise_sensitive_params(response.result, sensitive_params) + logging.info("Handling response: id=%s, result=%s", response.id, result) @staticmethod - def _log(method, params, sensitive_params): - params = anonymise_sensitive_params(params, sensitive_params) - logging.info("Sending notification: method=%s, params=%s", method, params) + def _log_error(response, error, sensitive_params): + data = anonymise_sensitive_params(error.data, sensitive_params) + logging.info("Handling error: id=%s, code=%s, description=%s, data=%s", + response.id, error.code, error.message, data + ) diff --git a/src/galaxy/api/plugin.py b/src/galaxy/api/plugin.py index 55cfad8..65d408b 100644 --- a/src/galaxy/api/plugin.py +++ b/src/galaxy/api/plugin.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Union from galaxy.api.consts import Feature, OSCompatibility from galaxy.api.errors import ImportInProgress, UnknownError -from galaxy.api.jsonrpc import ApplicationError, NotificationClient, Server +from galaxy.api.jsonrpc import ApplicationError, Connection from galaxy.api.types import ( Achievement, Authentication, Game, GameLibrarySettings, GameTime, LocalGame, NextStep, UserInfo, UserPresence ) @@ -44,8 +44,7 @@ class Plugin: self._handshake_token = handshake_token encoder = JSONEncoder() - self._server = Server(self._reader, self._writer, encoder) - self._notification_client = NotificationClient(self._writer, encoder) + self._connection = Connection(self._reader, self._writer, encoder) self._achievements_import_in_progress = False self._game_times_import_in_progress = False @@ -164,7 +163,7 @@ class Plugin: result = handler(*args, **kwargs) return wrap_result(result) - self._server.register_method(name, method, True, sensitive_params) + self._connection.register_method(name, method, True, sensitive_params) else: async def method(*args, **kwargs): if not internal: @@ -174,12 +173,12 @@ class Plugin: result = await handler_(*args, **kwargs) return wrap_result(result) - self._server.register_method(name, method, False, sensitive_params) + self._connection.register_method(name, method, False, sensitive_params) def _register_notification(self, name, handler, internal=False, immediate=False, sensitive_params=False): if not internal and not immediate: handler = self._wrap_external_method(handler, name) - self._server.register_notification(name, handler, immediate, sensitive_params) + self._connection.register_notification(name, handler, immediate, sensitive_params) def _wrap_external_method(self, handler, name: str): async def wrapper(*args, **kwargs): @@ -189,7 +188,7 @@ class Plugin: async def run(self): """Plugin's main coroutine.""" - await self._server.run() + await self._connection.run() logging.debug("Plugin run loop finished") def close(self) -> None: @@ -197,7 +196,7 @@ class Plugin: return logging.info("Closing plugin") - self._server.close() + self._connection.close() self._external_task_manager.cancel() self._internal_task_manager.create_task(self.shutdown(), "shutdown") self._active = False @@ -206,8 +205,7 @@ class Plugin: logging.debug("Waiting for plugin to close") await self._external_task_manager.wait() await self._internal_task_manager.wait() - await self._server.wait_closed() - await self._notification_client.close() + await self._connection.wait_closed() logging.debug("Plugin closed") def create_task(self, coro, description): @@ -273,7 +271,7 @@ class Plugin: # temporary solution for persistent_cache vs credentials issue self.persistent_cache["credentials"] = credentials # type: ignore - self._notification_client.notify("store_credentials", credentials, sensitive_params=True) + self._connection.send_notification("store_credentials", credentials, sensitive_params=True) def add_game(self, game: Game) -> None: """Notify the client to add game to the list of owned games @@ -295,7 +293,7 @@ class Plugin: """ params = {"owned_game": game} - self._notification_client.notify("owned_game_added", params) + self._connection.send_notification("owned_game_added", params) def remove_game(self, game_id: str) -> None: """Notify the client to remove game from the list of owned games @@ -317,7 +315,7 @@ class Plugin: """ params = {"game_id": game_id} - self._notification_client.notify("owned_game_removed", params) + self._connection.send_notification("owned_game_removed", params) def update_game(self, game: Game) -> None: """Notify the client to update the status of a game @@ -326,7 +324,7 @@ class Plugin: :param game: Game to update """ params = {"owned_game": game} - self._notification_client.notify("owned_game_updated", params) + self._connection.send_notification("owned_game_updated", params) def unlock_achievement(self, game_id: str, achievement: Achievement) -> None: """Notify the client to unlock an achievement for a specific game. @@ -338,24 +336,24 @@ class Plugin: "game_id": game_id, "achievement": achievement } - self._notification_client.notify("achievement_unlocked", params) + self._connection.send_notification("achievement_unlocked", params) def _game_achievements_import_success(self, game_id: str, achievements: List[Achievement]) -> None: params = { "game_id": game_id, "unlocked_achievements": achievements } - self._notification_client.notify("game_achievements_import_success", params) + self._connection.send_notification("game_achievements_import_success", params) def _game_achievements_import_failure(self, game_id: str, error: ApplicationError) -> None: params = { "game_id": game_id, "error": error.json() } - self._notification_client.notify("game_achievements_import_failure", params) + self._connection.send_notification("game_achievements_import_failure", params) def _achievements_import_finished(self) -> None: - self._notification_client.notify("achievements_import_finished", None) + self._connection.send_notification("achievements_import_finished", None) def update_local_game_status(self, local_game: LocalGame) -> None: """Notify the client to update the status of a local game. @@ -381,7 +379,7 @@ class Plugin: self._check_statuses_task = asyncio.create_task(self._check_statuses()) """ params = {"local_game": local_game} - self._notification_client.notify("local_game_status_changed", params) + self._connection.send_notification("local_game_status_changed", params) def add_friend(self, user: UserInfo) -> None: """Notify the client to add a user to friends list of the currently authenticated user. @@ -389,7 +387,7 @@ class Plugin: :param user: UserInfo of a user that the client will add to friends list """ params = {"friend_info": user} - self._notification_client.notify("friend_added", params) + self._connection.send_notification("friend_added", params) def remove_friend(self, user_id: str) -> None: """Notify the client to remove a user from friends list of the currently authenticated user. @@ -397,7 +395,7 @@ class Plugin: :param user_id: id of the user to remove from friends list """ params = {"user_id": user_id} - self._notification_client.notify("friend_removed", params) + self._connection.send_notification("friend_removed", params) def update_game_time(self, game_time: GameTime) -> None: """Notify the client to update game time for a game. @@ -405,38 +403,38 @@ class Plugin: :param game_time: game time to update """ params = {"game_time": game_time} - self._notification_client.notify("game_time_updated", params) + self._connection.send_notification("game_time_updated", params) def _game_time_import_success(self, game_time: GameTime) -> None: params = {"game_time": game_time} - self._notification_client.notify("game_time_import_success", params) + self._connection.send_notification("game_time_import_success", params) def _game_time_import_failure(self, game_id: str, error: ApplicationError) -> None: params = { "game_id": game_id, "error": error.json() } - self._notification_client.notify("game_time_import_failure", params) + self._connection.send_notification("game_time_import_failure", params) def _game_times_import_finished(self) -> None: - self._notification_client.notify("game_times_import_finished", None) + self._connection.send_notification("game_times_import_finished", None) def _game_library_settings_import_success(self, game_library_settings: GameLibrarySettings) -> None: params = {"game_library_settings": game_library_settings} - self._notification_client.notify("game_library_settings_import_success", params) + self._connection.send_notification("game_library_settings_import_success", params) def _game_library_settings_import_failure(self, game_id: str, error: ApplicationError) -> None: params = { "game_id": game_id, "error": error.json() } - self._notification_client.notify("game_library_settings_import_failure", params) + self._connection.send_notification("game_library_settings_import_failure", params) def _game_library_settings_import_finished(self) -> None: - self._notification_client.notify("game_library_settings_import_finished", None) + self._connection.send_notification("game_library_settings_import_finished", None) def _os_compatibility_import_success(self, game_id: str, os_compatibility: Optional[OSCompatibility]) -> None: - self._notification_client.notify( + self._connection.send_notification( "os_compatibility_import_success", { "game_id": game_id, @@ -445,7 +443,7 @@ class Plugin: ) def _os_compatibility_import_failure(self, game_id: str, error: ApplicationError) -> None: - self._notification_client.notify( + self._connection.send_notification( "os_compatibility_import_failure", { "game_id": game_id, @@ -454,10 +452,10 @@ class Plugin: ) def _os_compatibility_import_finished(self) -> None: - self._notification_client.notify("os_compatibility_import_finished", None) + self._connection.send_notification("os_compatibility_import_finished", None) def _user_presence_import_success(self, user_id: str, user_presence: UserPresence) -> None: - self._notification_client.notify( + self._connection.send_notification( "user_presence_import_success", { "user_id": user_id, @@ -466,7 +464,7 @@ class Plugin: ) def _user_presence_import_failure(self, user_id: str, error: ApplicationError) -> None: - self._notification_client.notify( + self._connection.send_notification( "user_presence_import_failure", { "user_id": user_id, @@ -475,23 +473,26 @@ class Plugin: ) def _user_presence_import_finished(self) -> None: - self._notification_client.notify("user_presence_import_finished", None) + self._connection.send_notification("user_presence_import_finished", None) def lost_authentication(self) -> None: """Notify the client that integration has lost authentication for the current user and is unable to perform actions which would require it. """ - self._notification_client.notify("authentication_lost", None) + self._connection.send_notification("authentication_lost", None) def push_cache(self) -> None: """Push local copy of the persistent cache to the GOG Galaxy Client replacing existing one. """ - self._notification_client.notify( + self._connection.send_notification( "push_cache", params={"data": self._persistent_cache}, sensitive_params="data" ) + async def refresh_credentials(self, params: Dict[str, Any], sensitive_params) -> Dict[str, Any]: + return await self._connection.send_request("refresh_credentials", params, sensitive_params) + # handlers def handshake_complete(self) -> None: """This method is called right after the handshake with the GOG Galaxy Client is complete and diff --git a/tests/test_refresh_credentials.py b/tests/test_refresh_credentials.py new file mode 100644 index 0000000..4b5f7c1 --- /dev/null +++ b/tests/test_refresh_credentials.py @@ -0,0 +1,72 @@ +import pytest +import asyncio + +from galaxy.unittest.mock import async_return_value +from tests import create_message, get_messages +from galaxy.api.errors import ( + BackendNotAvailable, BackendTimeout, BackendError, InvalidCredentials, NetworkError, AccessDenied +) +from galaxy.api.jsonrpc import JsonRpcError +@pytest.mark.asyncio +async def test_refresh_credentials_success(plugin, read, write): + + run_task = asyncio.create_task(plugin.run()) + + refreshed_credentials = { + "access_token": "new_access_token" + } + + response = { + "jsonrpc": "2.0", + "id": "1", + "result": refreshed_credentials + } + # 2 loop iterations delay is to force sending response after request has been sent + read.side_effect = [async_return_value(create_message(response), loop_iterations_delay=2)] + + result = await plugin.refresh_credentials({}, False) + assert get_messages(write) == [ + { + "jsonrpc": "2.0", + "method": "refresh_credentials", + "params": { + }, + "id": "1" + } + ] + + assert result == refreshed_credentials + await run_task + +@pytest.mark.asyncio +@pytest.mark.parametrize("exception", [ + BackendNotAvailable, BackendTimeout, BackendError, InvalidCredentials, NetworkError, AccessDenied +]) +async def test_refresh_credentials_failure(exception, plugin, read, write): + + run_task = asyncio.create_task(plugin.run()) + error = exception() + response = { + "jsonrpc": "2.0", + "id": "1", + "error": error.json() + } + + # 2 loop iterations delay is to force sending response after request has been sent + read.side_effect = [async_return_value(create_message(response), loop_iterations_delay=2)] + + with pytest.raises(JsonRpcError) as e: + await plugin.refresh_credentials({}, False) + + assert error == e.value + assert get_messages(write) == [ + { + "jsonrpc": "2.0", + "method": "refresh_credentials", + "params": { + }, + "id": "1" + } + ] + + await run_task