SDK-3023: Introduce task managers

This commit is contained in:
Romuald Juchnowicz-Bierbasz
2019-08-21 12:50:08 +02:00
committed by Romuald Bierbasz
parent 0ab00e4119
commit 0294e2a1f1
8 changed files with 187 additions and 116 deletions

View File

@@ -6,6 +6,7 @@ import inspect
import json
from galaxy.reader import StreamLineReader
from galaxy.task_manager import TaskManager
class JsonRpcError(Exception):
def __init__(self, code, message, data=None):
@@ -52,7 +53,8 @@ class UnknownError(ApplicationError):
super().__init__(0, "Unknown error", data)
Request = namedtuple("Request", ["method", "params", "id"], defaults=[{}, None])
Method = namedtuple("Method", ["callback", "signature", "internal", "sensitive_params"])
Method = namedtuple("Method", ["callback", "signature", "immediate", "sensitive_params"])
def anonymise_sensitive_params(params, sensitive_params):
anomized_data = "****"
@@ -74,9 +76,9 @@ class Server():
self._encoder = encoder
self._methods = {}
self._notifications = {}
self._eof_listeners = []
self._task_manager = TaskManager("jsonrpc server")
def register_method(self, name, callback, internal, sensitive_params=False):
def register_method(self, name, callback, immediate, sensitive_params=False):
"""
Register method
@@ -86,9 +88,9 @@ class Server():
:param sensitive_params: list of parameters that are anonymized before logging; \
if False - no params are considered sensitive, if True - all params are considered sensitive
"""
self._methods[name] = Method(callback, inspect.signature(callback), internal, sensitive_params)
self._methods[name] = Method(callback, inspect.signature(callback), immediate, sensitive_params)
def register_notification(self, name, callback, internal, sensitive_params=False):
def register_notification(self, name, callback, immediate, sensitive_params=False):
"""
Register notification
@@ -98,10 +100,7 @@ class Server():
:param sensitive_params: list of parameters that are anonymized before logging; \
if False - no params are considered sensitive, if True - all params are considered sensitive
"""
self._notifications[name] = Method(callback, inspect.signature(callback), internal, sensitive_params)
def register_eof(self, callback):
self._eof_listeners.append(callback)
self._notifications[name] = Method(callback, inspect.signature(callback), immediate, sensitive_params)
async def run(self):
while self._active:
@@ -118,14 +117,16 @@ class Server():
self._handle_input(data)
await asyncio.sleep(0) # To not starve task queue
def stop(self):
def close(self):
logging.info("Closing JSON-RPC server - not more messages will be read")
self._active = False
async def wait_closed(self):
await self._task_manager.wait()
def _eof(self):
logging.info("Received EOF")
self.stop()
for listener in self._eof_listeners:
listener()
self.close()
def _handle_input(self, data):
try:
@@ -145,7 +146,7 @@ class Server():
logging.error("Received unknown notification: %s", request.method)
return
callback, signature, internal, sensitive_params = method
callback, signature, immediate, sensitive_params = method
self._log_request(request, sensitive_params)
try:
@@ -153,12 +154,11 @@ class Server():
except TypeError:
self._send_error(request.id, InvalidParams())
if internal:
# internal requests are handled immediately
if immediate:
callback(*bound_args.args, **bound_args.kwargs)
else:
try:
asyncio.create_task(callback(*bound_args.args, **bound_args.kwargs))
self._task_manager.create_task(callback(*bound_args.args, **bound_args.kwargs), request.method)
except Exception:
logging.exception("Unexpected exception raised in notification handler")
@@ -169,7 +169,7 @@ class Server():
self._send_error(request.id, MethodNotFound())
return
callback, signature, internal, sensitive_params = method
callback, signature, immediate, sensitive_params = method
self._log_request(request, sensitive_params)
try:
@@ -177,8 +177,7 @@ class Server():
except TypeError:
self._send_error(request.id, InvalidParams())
if internal:
# internal requests are handled immediately
if immediate:
response = callback(*bound_args.args, **bound_args.kwargs)
self._send_response(request.id, response)
else:
@@ -190,11 +189,13 @@ class Server():
self._send_error(request.id, MethodNotFound())
except JsonRpcError as error:
self._send_error(request.id, error)
except asyncio.CancelledError:
self._send_error(request.id, Aborted())
except Exception as e: #pylint: disable=broad-except
logging.exception("Unexpected exception raised in plugin handler")
self._send_error(request.id, UnknownError(str(e)))
asyncio.create_task(handle())
self._task_manager.create_task(handle(), request.method)
@staticmethod
def _parse_request(data):
@@ -215,7 +216,7 @@ class Server():
logging.debug("Sending data: %s", line)
data = (line + "\n").encode("utf-8")
self._writer.write(data)
asyncio.create_task(self._writer.drain())
self._task_manager.create_task(self._writer.drain(), "drain")
except TypeError as error:
logging.error(str(error))
@@ -255,6 +256,7 @@ class NotificationClient():
self._writer = writer
self._encoder = encoder
self._methods = {}
self._task_manager = TaskManager("notification client")
def notify(self, method, params, sensitive_params=False):
"""
@@ -273,13 +275,16 @@ class NotificationClient():
self._log(method, params, sensitive_params)
self._send(notification)
async def close(self):
await self._task_manager.wait()
def _send(self, data):
try:
line = self._encoder.encode(data)
data = (line + "\n").encode("utf-8")
logging.debug("Sending %d byte of data", len(data))
self._writer.write(data)
asyncio.create_task(self._writer.drain())
self._task_manager.create_task(self._writer.drain(), "drain")
except TypeError as error:
logging.error("Failed to parse outgoing message: %s", str(error))

View File

@@ -4,16 +4,14 @@ import json
import logging
import logging.handlers
import sys
from collections import OrderedDict
from enum import Enum
from itertools import count
from typing import Any, Dict, List, Optional, Set, Union
from galaxy.api.consts import Feature
from galaxy.api.errors import ImportInProgress, UnknownError
from galaxy.api.jsonrpc import ApplicationError, NotificationClient, Server
from galaxy.api.types import Achievement, Authentication, FriendInfo, Game, GameTime, LocalGame, NextStep
from galaxy.task_manager import TaskManager
class JSONEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=method-hidden
@@ -38,7 +36,6 @@ class Plugin:
self._features: Set[Feature] = set()
self._active = True
self._pass_control_task = None
self._reader, self._writer = reader, writer
self._handshake_token = handshake_token
@@ -47,29 +44,25 @@ class Plugin:
self._server = Server(self._reader, self._writer, encoder)
self._notification_client = NotificationClient(self._writer, encoder)
def eof_handler():
self._shutdown()
self._server.register_eof(eof_handler)
self._achievements_import_in_progress = False
self._game_times_import_in_progress = False
self._persistent_cache = dict()
self._tasks = OrderedDict()
self._task_counter = count()
self._internal_task_manager = TaskManager("plugin internal")
self._external_task_manager = TaskManager("plugin external")
# internal
self._register_method("shutdown", self._shutdown, internal=True)
self._register_method("get_capabilities", self._get_capabilities, internal=True)
self._register_method("get_capabilities", self._get_capabilities, internal=True, immediate=True)
self._register_method(
"initialize_cache",
self._initialize_cache,
internal=True,
immediate=True,
sensitive_params="data"
)
self._register_method("ping", self._ping, internal=True)
self._register_method("ping", self._ping, internal=True, immediate=True)
# implemented by developer
self._register_method(
@@ -116,6 +109,13 @@ class Plugin:
self._register_method("start_game_times_import", self._start_game_times_import)
self._detect_feature(Feature.ImportGameTime, ["get_game_time"])
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
self.close()
await self.wait_closed()
@property
def features(self) -> List[Feature]:
return list(self._features)
@@ -136,55 +136,65 @@ class Plugin:
if self._implements(methods):
self._features.add(feature)
def _register_method(self, name, handler, result_name=None, internal=False, sensitive_params=False):
if internal:
def _register_method(self, name, handler, result_name=None, internal=False, immediate=False, sensitive_params=False):
def wrap_result(result):
if result_name:
result = {
result_name: result
}
return result
if immediate:
def method(*args, **kwargs):
result = handler(*args, **kwargs)
if result_name:
result = {
result_name: result
}
return result
return wrap_result(result)
self._server.register_method(name, method, True, sensitive_params)
else:
async def method(*args, **kwargs):
result = await handler(*args, **kwargs)
if result_name:
result = {
result_name: result
}
return result
if not internal:
handler_ = self._wrap_external_method(handler, name)
else:
handler_ = handler
result = await handler_(*args, **kwargs)
return wrap_result(result)
self._server.register_method(name, method, False, sensitive_params)
def _register_notification(self, name, handler, internal=False, sensitive_params=False):
self._server.register_notification(name, handler, internal, sensitive_params)
def _register_notification(self, name, handler, internal=False, immediate=False, sensitive_params=False):
if not internal and not immediate:
handler = self._wrap_external_method(handler, name)
self._server.register_notification(name, handler, immediate, sensitive_params)
def _wrap_external_method(self, handler, name: str):
async def wrapper(*args, **kwargs):
return await self._external_task_manager.create_task(handler(*args, **kwargs), name, False)
return wrapper
async def run(self):
"""Plugin's main coroutine."""
await self._server.run()
if self._pass_control_task is not None:
await self._pass_control_task
await self._external_task_manager.wait()
def close(self) -> None:
if not self._active:
return
logging.info("Closing plugin")
self._server.close()
self._external_task_manager.cancel()
self._internal_task_manager.create_task(self.shutdown(), "shutdown")
self._active = False
async def wait_closed(self) -> None:
await self._external_task_manager.wait()
await self._internal_task_manager.wait()
await self._server.wait_closed()
await self._notification_client.close()
def create_task(self, coro, description):
"""Wrapper around asyncio.create_task - takes care of canceling tasks on shutdown"""
async def task_wrapper(task_id):
try:
return await coro
except asyncio.CancelledError:
logging.debug("Canceled task %d (%s)", task_id, description)
except Exception:
logging.exception("Exception raised in task %d (%s)", task_id, description)
finally:
del self._tasks[task_id]
task_id = next(self._task_counter)
logging.debug("Creating task %d (%s)", task_id, description)
task = asyncio.create_task(task_wrapper(task_id))
self._tasks[task_id] = task
return task
return self._external_task_manager.create_task(coro, description)
async def _pass_control(self):
while self._active:
@@ -194,13 +204,11 @@ class Plugin:
logging.exception("Unexpected exception raised in plugin tick")
await asyncio.sleep(1)
def _shutdown(self):
async def _shutdown(self):
logging.info("Shutting down")
self._server.stop()
self._active = False
self.shutdown()
for task in self._tasks.values():
task.cancel()
self.close()
await self._external_task_manager.wait()
await self._internal_task_manager.wait()
def _get_capabilities(self):
return {
@@ -215,7 +223,7 @@ class Plugin:
self.handshake_complete()
except Exception:
logging.exception("Unhandled exception during `handshake_complete` step")
self._pass_control_task = asyncio.create_task(self._pass_control())
self._internal_task_manager.create_task(self._pass_control(), "tick")
@staticmethod
def _ping():
@@ -444,7 +452,7 @@ class Plugin:
"""
def shutdown(self) -> None:
async def shutdown(self) -> None:
"""This method is called on integration shutdown.
Override it to implement tear down.
This method is called by the GOG Galaxy Client."""
@@ -552,7 +560,11 @@ class Plugin:
self._achievements_import_in_progress = False
self.achievements_import_complete()
self.create_task(import_games_achievements(game_ids, context), "Games unlocked achievements import")
self._external_task_manager.create_task(
import_games_achievements(game_ids, context),
"unlocked achievements import",
handle_exceptions=False
)
self._achievements_import_in_progress = True
async def prepare_achievements_context(self, game_ids: List[str]) -> Any:
@@ -712,7 +724,11 @@ class Plugin:
self._game_times_import_in_progress = False
self.game_times_import_complete()
self.create_task(import_game_times(game_ids, context), "Game times import")
self._external_task_manager.create_task(
import_game_times(game_ids, context),
"game times import",
handle_exceptions=False
)
self._game_times_import_in_progress = True
async def prepare_game_times_context(self, game_ids: List[str]) -> Any:
@@ -783,8 +799,8 @@ def create_and_run_plugin(plugin_class, argv):
reader, writer = await asyncio.open_connection("127.0.0.1", port)
extra_info = writer.get_extra_info("sockname")
logging.info("Using local address: %s:%u", *extra_info)
plugin = plugin_class(reader, writer, token)
await plugin.run()
async with plugin_class(reader, writer, token) as plugin:
await plugin.run()
try:
if sys.platform == "win32":

View File

@@ -0,0 +1,49 @@
import asyncio
import logging
from collections import OrderedDict
from itertools import count
class TaskManager:
def __init__(self, name):
self._name = name
self._tasks = OrderedDict()
self._task_counter = count()
def create_task(self, coro, description, handle_exceptions=True):
"""Wrapper around asyncio.create_task - takes care of canceling tasks on shutdown"""
async def task_wrapper(task_id):
try:
result = await coro
logging.debug("Task manager %s: finished task %d (%s)", self._name, task_id, description)
return result
except asyncio.CancelledError:
if handle_exceptions:
logging.debug("Task manager %s: canceled task %d (%s)", self._name, task_id, description)
else:
raise
except Exception:
if handle_exceptions:
logging.exception("Task manager %s: exception raised in task %d (%s)", self._name, task_id, description)
else:
raise
finally:
del self._tasks[task_id]
task_id = next(self._task_counter)
logging.debug("Task manager %s: creating task %d (%s)", self._name, task_id, description)
task = asyncio.create_task(task_wrapper(task_id))
self._tasks[task_id] = task
return task
def cancel(self):
for task in self._tasks.values():
task.cancel()
async def wait(self):
# Tasks can spawn other tasks
while True:
tasks = self._tasks.values()
if not tasks:
return
await asyncio.gather(*tasks, return_exceptions=True)

View File

@@ -7,11 +7,8 @@ def create_message(request):
def get_messages(write_mock):
messages = []
print("call_args_list", write_mock.call_args_list)
for call_args in write_mock.call_args_list:
print("call_args", call_args)
data = call_args[0][0]
print("data", data)
for line in data.splitlines():
message = json.loads(line)
messages.append(message)

View File

@@ -6,6 +6,7 @@ import pytest
from galaxy.api.plugin import Plugin
from galaxy.api.consts import Platform
from galaxy.unittest.mock import async_return_value
@pytest.fixture()
def reader():
@@ -16,8 +17,7 @@ def reader():
@pytest.fixture()
async def writer():
stream = MagicMock(name="stream_writer")
stream.write = MagicMock()
stream.drain = MagicMock()
stream.drain.side_effect = lambda: async_return_value(None)
yield stream
@pytest.fixture()
@@ -29,7 +29,7 @@ def write(writer):
yield writer.write
@pytest.fixture()
def plugin(reader, writer):
async def plugin(reader, writer):
"""Return plugin instance with all feature methods mocked"""
methods = (
"handshake_complete",
@@ -55,7 +55,10 @@ def plugin(reader, writer):
with ExitStack() as stack:
for method in methods:
stack.enter_context(patch.object(Plugin, method))
yield Plugin(Platform.Generic, "0.1", reader, writer, "token")
async with Plugin(Platform.Generic, "0.1", reader, writer, "token") as plugin:
plugin.shutdown.return_value = async_return_value(None)
yield plugin
@pytest.fixture(autouse=True)

View File

@@ -136,6 +136,7 @@ async def test_prepare_get_unlocked_achievements_context_error(plugin, read, wri
}
}
read.side_effect = [async_return_value(create_message(request)), async_return_value(b"")]
await plugin.run()
assert get_messages(write) == [
@@ -153,6 +154,7 @@ async def test_prepare_get_unlocked_achievements_context_error(plugin, read, wri
@pytest.mark.asyncio
async def test_import_in_progress(plugin, read, write):
plugin.prepare_achievements_context.return_value = async_return_value(None)
plugin.get_unlocked_achievements.return_value = async_return_value([])
requests = [
{
"jsonrpc": "2.0",
@@ -179,21 +181,20 @@ async def test_import_in_progress(plugin, read, write):
await plugin.run()
assert get_messages(write) == [
{
"jsonrpc": "2.0",
"id": "3",
"result": None
},
{
"jsonrpc": "2.0",
"id": "4",
"error": {
"code": 600,
"message": "Import already in progress"
}
messages = get_messages(write)
assert {
"jsonrpc": "2.0",
"id": "3",
"result": None
} in messages
assert {
"jsonrpc": "2.0",
"id": "4",
"error": {
"code": 600,
"message": "Import already in progress"
}
]
} in messages
@pytest.mark.asyncio

View File

@@ -179,21 +179,20 @@ async def test_import_in_progress(plugin, read, write):
await plugin.run()
assert get_messages(write) == [
{
"jsonrpc": "2.0",
"id": "3",
"result": None
},
{
"jsonrpc": "2.0",
"id": "4",
"error": {
"code": 600,
"message": "Import already in progress"
}
messages = get_messages(write)
assert {
"jsonrpc": "2.0",
"id": "3",
"result": None
} in messages
assert {
"jsonrpc": "2.0",
"id": "4",
"error": {
"code": 600,
"message": "Import already in progress"
}
]
} in messages
@pytest.mark.asyncio

View File

@@ -46,6 +46,7 @@ async def test_shutdown(plugin, read, write):
}
read.side_effect = [async_return_value(create_message(request))]
await plugin.run()
await plugin.wait_closed()
plugin.shutdown.assert_called_with()
assert get_messages(write) == [
{