Compare commits

...

18 Commits
0.58 ... 0.63

Author SHA1 Message Date
Aleksej Pawlowskij
26102dd832 Increment version 2019-12-17 15:56:37 +01:00
Aleksej Pawlowskij
cdcebda529 SDK-3136: Relax install requirements 2019-12-17 15:43:47 +01:00
Romuald Bierbasz
a83f348d7d Increment version 2019-12-10 16:02:40 +01:00
Romuald Bierbasz
1c196d60d5 SDK-3199: Log response json 2019-12-10 16:00:46 +01:00
Aleksej Pawlowskij
deb125ec48 Add missing psutil setup requirement 2019-12-05 16:22:26 +01:00
Rafal Makagon
4cc0055119 Increment version 2019-12-05 13:58:04 +01:00
Romuald Bierbasz
00164fab67 Correctly set _import_in_progress 2019-12-05 11:39:09 +01:00
Romuald Juchnowicz-Bierbasz
453cd1cc70 Do not send notificaitons when import is cancelled 2019-12-03 14:06:55 +01:00
Romuald Juchnowicz-Bierbasz
1f55253fd7 Wait until writer is closed 2019-12-03 14:04:19 +01:00
Romuald Juchnowicz-Bierbasz
7aa3b01abd Add Importer class (reuse code for importers) 2019-12-03 14:03:53 +01:00
Rafal Makagon
bd14d58bad Increment version 2019-11-28 14:37:46 +01:00
Romuald Juchnowicz-Bierbasz
274b9a2c18 Do not wait for drain 2019-11-28 13:10:58 +01:00
Rafal Makagon
75e5a66fbe Increment version 2019-11-27 13:14:11 +01:00
Mieszko Banczerowski
2a9ec3067d Fix sending Exceptions with custom data 2019-11-27 13:12:20 +01:00
Rafal Makagon
69532a5ba9 fix richpresence parameter name 2019-11-27 13:10:43 +01:00
Romuald Juchnowicz-Bierbasz
f5d47b0167 Add timeout to shutdown 2019-11-22 13:11:08 +01:00
Romuald Juchnowicz-Bierbasz
02f4faa432 Do not use root logger 2019-11-22 13:07:33 +01:00
Romuald Juchnowicz-Bierbasz
3d3922c965 Add async_raise 2019-11-20 17:57:17 +01:00
8 changed files with 219 additions and 229 deletions

View File

@@ -2,14 +2,15 @@ from setuptools import setup, find_packages
setup(
name="galaxy.plugin.api",
version="0.58",
version="0.63",
description="GOG Galaxy Integrations Python API",
author='Galaxy team',
author_email='galaxy@gog.com',
packages=find_packages("src"),
package_dir={'': 'src'},
install_requires=[
"aiohttp==3.5.4",
"certifi==2019.3.9"
"aiohttp>=3.5.4",
"certifi>=2019.3.9",
"psutil>=5.6.3; sys_platform == 'darwin'"
]
)

View File

@@ -8,6 +8,10 @@ import json
from galaxy.reader import StreamLineReader
from galaxy.task_manager import TaskManager
logger = logging.getLogger(__name__)
class JsonRpcError(Exception):
def __init__(self, code, message, data=None):
self.code = code
@@ -25,7 +29,7 @@ class JsonRpcError(Exception):
}
if self.data is not None:
obj["error"]["data"] = self.data
obj["data"] = self.data
return obj
@@ -89,7 +93,6 @@ class Connection():
self._methods = {}
self._notifications = {}
self._task_manager = TaskManager("jsonrpc server")
self._write_lock = asyncio.Lock()
self._last_request_id = 0
self._requests_futures = {}
@@ -133,7 +136,7 @@ class Connection():
future = loop.create_future()
self._requests_futures[self._last_request_id] = (future, sensitive_params)
logging.info(
logger.info(
"Sending request: id=%s, method=%s, params=%s",
request_id, method, anonymise_sensitive_params(params, sensitive_params)
)
@@ -151,7 +154,7 @@ class Connection():
if False - no params are considered sensitive, if True - all params are considered sensitive
"""
logging.info(
logger.info(
"Sending notification: method=%s, params=%s",
method, anonymise_sensitive_params(params, sensitive_params)
)
@@ -169,20 +172,20 @@ class Connection():
self._eof()
continue
data = data.strip()
logging.debug("Received %d bytes of data", len(data))
logger.debug("Received %d bytes of data", len(data))
self._handle_input(data)
await asyncio.sleep(0) # To not starve task queue
def close(self):
if self._active:
logging.info("Closing JSON-RPC server - not more messages will be read")
logger.info("Closing JSON-RPC server - not more messages will be read")
self._active = False
async def wait_closed(self):
await self._task_manager.wait()
def _eof(self):
logging.info("Received EOF")
logger.info("Received EOF")
self.close()
def _handle_input(self, data):
@@ -204,7 +207,7 @@ class Connection():
request_future = self._requests_futures.get(int(response.id))
if request_future is None:
response_type = "response" if response.result is not None else "error"
logging.warning("Received %s for unknown request: %s", response_type, response.id)
logger.warning("Received %s for unknown request: %s", response_type, response.id)
return
future, sensitive_params = request_future
@@ -225,7 +228,7 @@ class Connection():
def _handle_notification(self, request):
method = self._notifications.get(request.method)
if not method:
logging.error("Received unknown notification: %s", request.method)
logger.error("Received unknown notification: %s", request.method)
return
callback, signature, immediate, sensitive_params = method
@@ -242,12 +245,12 @@ class Connection():
try:
self._task_manager.create_task(callback(*bound_args.args, **bound_args.kwargs), request.method)
except Exception:
logging.exception("Unexpected exception raised in notification handler")
logger.exception("Unexpected exception raised in notification handler")
def _handle_request(self, request):
method = self._methods.get(request.method)
if not method:
logging.error("Received unknown request: %s", request.method)
logger.error("Received unknown request: %s", request.method)
self._send_error(request.id, MethodNotFound())
return
@@ -274,7 +277,7 @@ class Connection():
except asyncio.CancelledError:
self._send_error(request.id, Aborted())
except Exception as e: #pylint: disable=broad-except
logging.exception("Unexpected exception raised in plugin handler")
logger.exception("Unexpected exception raised in plugin handler")
self._send_error(request.id, UnknownError(str(e)))
self._task_manager.create_task(handle(), request.method)
@@ -296,19 +299,17 @@ class Connection():
except TypeError:
raise InvalidRequest()
def _send(self, data):
async def send_task(data_):
async with self._write_lock:
self._writer.write(data_)
await self._writer.drain()
def _send(self, data, sensitive=True):
try:
line = self._encoder.encode(data)
data = (line + "\n").encode("utf-8")
logging.debug("Sending %d byte of data", len(data))
self._task_manager.create_task(send_task(data), "send")
if sensitive:
logger.debug("Sending %d bytes of data", len(data))
else:
logging.debug("Sending data: %s", line)
self._writer.write(data)
except TypeError as error:
logging.error(str(error))
logger.error(str(error))
def _send_response(self, request_id, result):
response = {
@@ -316,7 +317,7 @@ class Connection():
"id": request_id,
"result": result
}
self._send(response)
self._send(response, sensitive=False)
def _send_error(self, request_id, error):
response = {
@@ -325,7 +326,7 @@ class Connection():
"error": error.json()
}
self._send(response)
self._send(response, sensitive=False)
def _send_request(self, request_id, method, params):
request = {
@@ -334,7 +335,7 @@ class Connection():
"id": request_id,
"params": params
}
self._send(request)
self._send(request, sensitive=True)
def _send_notification(self, method, params):
notification = {
@@ -342,24 +343,24 @@ class Connection():
"method": method,
"params": params
}
self._send(notification)
self._send(notification, sensitive=True)
@staticmethod
def _log_request(request, sensitive_params):
params = anonymise_sensitive_params(request.params, sensitive_params)
if request.id is not None:
logging.info("Handling request: id=%s, method=%s, params=%s", request.id, request.method, params)
logger.info("Handling request: id=%s, method=%s, params=%s", request.id, request.method, params)
else:
logging.info("Handling notification: method=%s, params=%s", request.method, params)
logger.info("Handling notification: method=%s, params=%s", request.method, params)
@staticmethod
def _log_response(response, sensitive_params):
result = anonymise_sensitive_params(response.result, sensitive_params)
logging.info("Handling response: id=%s, result=%s", response.id, result)
logger.info("Handling response: id=%s, result=%s", response.id, result)
@staticmethod
def _log_error(response, error, sensitive_params):
data = anonymise_sensitive_params(error.data, sensitive_params)
logging.info("Handling error: id=%s, code=%s, description=%s, data=%s",
logger.info("Handling error: id=%s, code=%s, description=%s, data=%s",
response.id, error.code, error.message, data
)

View File

@@ -2,7 +2,6 @@ import asyncio
import dataclasses
import json
import logging
import logging.handlers
import sys
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
@@ -16,6 +15,9 @@ from galaxy.api.types import (
from galaxy.task_manager import TaskManager
logger = logging.getLogger(__name__)
class JSONEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=method-hidden
if dataclasses.is_dataclass(o):
@@ -29,11 +31,74 @@ class JSONEncoder(json.JSONEncoder):
return super().default(o)
class Importer:
def __init__(
self,
task_manger,
name,
get,
prepare_context,
notification_success,
notification_failure,
notification_finished,
complete
):
self._task_manager = task_manger
self._name = name
self._get = get
self._prepare_context = prepare_context
self._notification_success = notification_success
self._notification_failure = notification_failure
self._notification_finished = notification_finished
self._complete = complete
self._import_in_progress = False
async def start(self, ids):
if self._import_in_progress:
raise ImportInProgress()
async def import_element(id_, context_):
try:
element = await self._get(id_, context_)
self._notification_success(id_, element)
except ApplicationError as error:
self._notification_failure(id_, error)
except asyncio.CancelledError:
pass
except Exception:
logger.exception("Unexpected exception raised in %s importer", self._name)
self._notification_failure(id_, UnknownError())
async def import_elements(ids_, context_):
try:
imports = [import_element(id_, context_) for id_ in ids_]
await asyncio.gather(*imports)
self._notification_finished()
self._complete()
except asyncio.CancelledError:
logger.debug("Importing %s cancelled", self._name)
finally:
self._import_in_progress = False
self._import_in_progress = True
try:
context = await self._prepare_context(ids)
self._task_manager.create_task(
import_elements(ids, context),
"{} import".format(self._name),
handle_exceptions=False
)
except:
self._import_in_progress = False
raise
class Plugin:
"""Use and override methods of this class to create a new platform integration."""
def __init__(self, platform, version, reader, writer, handshake_token):
logging.info("Creating plugin for platform %s, version %s", platform.value, version)
logger.info("Creating plugin for platform %s, version %s", platform.value, version)
self._platform = platform
self._version = version
@@ -46,17 +111,62 @@ class Plugin:
encoder = JSONEncoder()
self._connection = Connection(self._reader, self._writer, encoder)
self._achievements_import_in_progress = False
self._game_times_import_in_progress = False
self._game_library_settings_import_in_progress = False
self._os_compatibility_import_in_progress = False
self._user_presence_import_in_progress = False
self._persistent_cache = dict()
self._internal_task_manager = TaskManager("plugin internal")
self._external_task_manager = TaskManager("plugin external")
self._achievements_importer = Importer(
self._external_task_manager,
"achievements",
self.get_unlocked_achievements,
self.prepare_achievements_context,
self._game_achievements_import_success,
self._game_achievements_import_failure,
self._achievements_import_finished,
self.achievements_import_complete
)
self._game_time_importer = Importer(
self._external_task_manager,
"game times",
self.get_game_time,
self.prepare_game_times_context,
self._game_time_import_success,
self._game_time_import_failure,
self._game_times_import_finished,
self.game_times_import_complete
)
self._game_library_settings_importer = Importer(
self._external_task_manager,
"game library settings",
self.get_game_library_settings,
self.prepare_game_library_settings_context,
self._game_library_settings_import_success,
self._game_library_settings_import_failure,
self._game_library_settings_import_finished,
self.game_library_settings_import_complete
)
self._os_compatibility_importer = Importer(
self._external_task_manager,
"os compatibility",
self.get_os_compatibility,
self.prepare_os_compatibility_context,
self._os_compatibility_import_success,
self._os_compatibility_import_failure,
self._os_compatibility_import_finished,
self.os_compatibility_import_complete
)
self._user_presence_importer = Importer(
self._external_task_manager,
"users presence",
self.get_user_presence,
self.prepare_user_presence_context,
self._user_presence_import_success,
self._user_presence_import_failure,
self._user_presence_import_finished,
self.user_presence_import_complete
)
# internal
self._register_method("shutdown", self._shutdown, internal=True)
self._register_method("get_capabilities", self._get_capabilities, internal=True, immediate=True)
@@ -189,24 +299,31 @@ class Plugin:
async def run(self):
"""Plugin's main coroutine."""
await self._connection.run()
logging.debug("Plugin run loop finished")
logger.debug("Plugin run loop finished")
def close(self) -> None:
if not self._active:
return
logging.info("Closing plugin")
logger.info("Closing plugin")
self._connection.close()
self._external_task_manager.cancel()
self._internal_task_manager.create_task(self.shutdown(), "shutdown")
async def shutdown():
try:
await asyncio.wait_for(self.shutdown(), 30)
except asyncio.TimeoutError:
logging.warning("Plugin shutdown timed out")
self._internal_task_manager.create_task(shutdown(), "shutdown")
self._active = False
async def wait_closed(self) -> None:
logging.debug("Waiting for plugin to close")
logger.debug("Waiting for plugin to close")
await self._external_task_manager.wait()
await self._internal_task_manager.wait()
await self._connection.wait_closed()
logging.debug("Plugin closed")
logger.debug("Plugin closed")
def create_task(self, coro, description):
"""Wrapper around asyncio.create_task - takes care of canceling tasks on shutdown"""
@@ -217,11 +334,11 @@ class Plugin:
try:
self.tick()
except Exception:
logging.exception("Unexpected exception raised in plugin tick")
logger.exception("Unexpected exception raised in plugin tick")
await asyncio.sleep(1)
async def _shutdown(self):
logging.info("Shutting down")
logger.info("Shutting down")
self.close()
await self._external_task_manager.wait()
await self._internal_task_manager.wait()
@@ -238,7 +355,7 @@ class Plugin:
try:
self.handshake_complete()
except Exception:
logging.exception("Unhandled exception during `handshake_complete` step")
logger.exception("Unhandled exception during `handshake_complete` step")
self._internal_task_manager.create_task(self._pass_control(), "tick")
@staticmethod
@@ -426,7 +543,7 @@ class Plugin:
}
)
def _game_time_import_success(self, game_time: GameTime) -> None:
def _game_time_import_success(self, game_id: str, game_time: GameTime) -> None:
params = {"game_time": game_time}
self._connection.send_notification("game_time_import_success", params)
@@ -440,7 +557,7 @@ class Plugin:
def _game_times_import_finished(self) -> None:
self._connection.send_notification("game_times_import_finished", None)
def _game_library_settings_import_success(self, game_library_settings: GameLibrarySettings) -> None:
def _game_library_settings_import_success(self, game_id: str, game_library_settings: GameLibrarySettings) -> None:
params = {"game_library_settings": game_library_settings}
self._connection.send_notification("game_library_settings_import_success", params)
@@ -627,36 +744,7 @@ class Plugin:
raise NotImplementedError()
async def _start_achievements_import(self, game_ids: List[str]) -> None:
if self._achievements_import_in_progress:
raise ImportInProgress()
context = await self.prepare_achievements_context(game_ids)
async def import_game_achievements(game_id, context_):
try:
achievements = await self.get_unlocked_achievements(game_id, context_)
self._game_achievements_import_success(game_id, achievements)
except ApplicationError as error:
self._game_achievements_import_failure(game_id, error)
except Exception:
logging.exception("Unexpected exception raised in import_game_achievements")
self._game_achievements_import_failure(game_id, UnknownError())
async def import_games_achievements(game_ids_, context_):
try:
imports = [import_game_achievements(game_id, context_) for game_id in game_ids_]
await asyncio.gather(*imports)
finally:
self._achievements_import_finished()
self._achievements_import_in_progress = False
self.achievements_import_complete()
self._external_task_manager.create_task(
import_games_achievements(game_ids, context),
"unlocked achievements import",
handle_exceptions=False
)
self._achievements_import_in_progress = True
await self._achievements_importer.start(game_ids)
async def prepare_achievements_context(self, game_ids: List[str]) -> Any:
"""Override this method to prepare context for get_unlocked_achievements.
@@ -791,36 +879,7 @@ class Plugin:
raise NotImplementedError()
async def _start_game_times_import(self, game_ids: List[str]) -> None:
if self._game_times_import_in_progress:
raise ImportInProgress()
context = await self.prepare_game_times_context(game_ids)
async def import_game_time(game_id, context_):
try:
game_time = await self.get_game_time(game_id, context_)
self._game_time_import_success(game_time)
except ApplicationError as error:
self._game_time_import_failure(game_id, error)
except Exception:
logging.exception("Unexpected exception raised in import_game_time")
self._game_time_import_failure(game_id, UnknownError())
async def import_game_times(game_ids_, context_):
try:
imports = [import_game_time(game_id, context_) for game_id in game_ids_]
await asyncio.gather(*imports)
finally:
self._game_times_import_finished()
self._game_times_import_in_progress = False
self.game_times_import_complete()
self._external_task_manager.create_task(
import_game_times(game_ids, context),
"game times import",
handle_exceptions=False
)
self._game_times_import_in_progress = True
await self._game_time_importer.start(game_ids)
async def prepare_game_times_context(self, game_ids: List[str]) -> Any:
"""Override this method to prepare context for get_game_time.
@@ -849,36 +908,7 @@ class Plugin:
"""
async def _start_game_library_settings_import(self, game_ids: List[str]) -> None:
if self._game_library_settings_import_in_progress:
raise ImportInProgress()
context = await self.prepare_game_library_settings_context(game_ids)
async def import_game_library_settings(game_id, context_):
try:
game_library_settings = await self.get_game_library_settings(game_id, context_)
self._game_library_settings_import_success(game_library_settings)
except ApplicationError as error:
self._game_library_settings_import_failure(game_id, error)
except Exception:
logging.exception("Unexpected exception raised in import_game_library_settings")
self._game_library_settings_import_failure(game_id, UnknownError())
async def import_game_library_settings_set(game_ids_, context_):
try:
imports = [import_game_library_settings(game_id, context_) for game_id in game_ids_]
await asyncio.gather(*imports)
finally:
self._game_library_settings_import_finished()
self._game_library_settings_import_in_progress = False
self.game_library_settings_import_complete()
self._external_task_manager.create_task(
import_game_library_settings_set(game_ids, context),
"game library settings import",
handle_exceptions=False
)
self._game_library_settings_import_in_progress = True
await self._game_library_settings_importer.start(game_ids)
async def prepare_game_library_settings_context(self, game_ids: List[str]) -> Any:
"""Override this method to prepare context for get_game_library_settings.
@@ -907,37 +937,7 @@ class Plugin:
"""
async def _start_os_compatibility_import(self, game_ids: List[str]) -> None:
if self._os_compatibility_import_in_progress:
raise ImportInProgress()
context = await self.prepare_os_compatibility_context(game_ids)
async def import_os_compatibility(game_id, context_):
try:
os_compatibility = await self.get_os_compatibility(game_id, context_)
self._os_compatibility_import_success(game_id, os_compatibility)
except ApplicationError as error:
self._os_compatibility_import_failure(game_id, error)
except Exception:
logging.exception("Unexpected exception raised in import_os_compatibility")
self._os_compatibility_import_failure(game_id, UnknownError())
async def import_os_compatibility_set(game_ids_, context_):
try:
await asyncio.gather(*[
import_os_compatibility(game_id, context_) for game_id in game_ids_
])
finally:
self._os_compatibility_import_finished()
self._os_compatibility_import_in_progress = False
self.os_compatibility_import_complete()
self._external_task_manager.create_task(
import_os_compatibility_set(game_ids, context),
"game OS compatibility import",
handle_exceptions=False
)
self._os_compatibility_import_in_progress = True
await self._os_compatibility_importer.start(game_ids)
async def prepare_os_compatibility_context(self, game_ids: List[str]) -> Any:
"""Override this method to prepare context for get_os_compatibility.
@@ -962,45 +962,15 @@ class Plugin:
def os_compatibility_import_complete(self) -> None:
"""Override this method to handle operations after OS compatibility import is finished (like updating cache)."""
async def _start_user_presence_import(self, user_ids: List[str]) -> None:
if self._user_presence_import_in_progress:
raise ImportInProgress()
async def _start_user_presence_import(self, user_id_list: List[str]) -> None:
await self._user_presence_importer.start(user_id_list)
context = await self.prepare_user_presence_context(user_ids)
async def import_user_presence(user_id, context_) -> None:
try:
self._user_presence_import_success(user_id, await self.get_user_presence(user_id, context_))
except ApplicationError as error:
self._user_presence_import_failure(user_id, error)
except Exception:
logging.exception("Unexpected exception raised in import_user_presence")
self._user_presence_import_failure(user_id, UnknownError())
async def import_user_presence_set(user_ids_, context_) -> None:
try:
await asyncio.gather(*[
import_user_presence(user_id, context_)
for user_id in user_ids_
])
finally:
self._user_presence_import_finished()
self._user_presence_import_in_progress = False
self.user_presence_import_complete()
self._external_task_manager.create_task(
import_user_presence_set(user_ids, context),
"user presence import",
handle_exceptions=False
)
self._user_presence_import_in_progress = True
async def prepare_user_presence_context(self, user_ids: List[str]) -> Any:
async def prepare_user_presence_context(self, user_id_list: List[str]) -> Any:
"""Override this method to prepare context for get_user_presence.
This allows for optimizations like batch requests to platform API.
Default implementation returns None.
:param user_ids: the ids of the users for whom presence information is imported
:param user_id_list: the ids of the users for whom presence information is imported
:return: context
"""
return None
@@ -1037,7 +1007,7 @@ def create_and_run_plugin(plugin_class, argv):
main()
"""
if len(argv) < 3:
logging.critical("Not enough parameters, required: token, port")
logger.critical("Not enough parameters, required: token, port")
sys.exit(1)
token = argv[1]
@@ -1045,23 +1015,28 @@ def create_and_run_plugin(plugin_class, argv):
try:
port = int(argv[2])
except ValueError:
logging.critical("Failed to parse port value: %s", argv[2])
logger.critical("Failed to parse port value: %s", argv[2])
sys.exit(2)
if not (1 <= port <= 65535):
logging.critical("Port value out of range (1, 65535)")
logger.critical("Port value out of range (1, 65535)")
sys.exit(3)
if not issubclass(plugin_class, Plugin):
logging.critical("plugin_class must be subclass of Plugin")
logger.critical("plugin_class must be subclass of Plugin")
sys.exit(4)
async def coroutine():
reader, writer = await asyncio.open_connection("127.0.0.1", port)
extra_info = writer.get_extra_info("sockname")
logging.info("Using local address: %s:%u", *extra_info)
async with plugin_class(reader, writer, token) as plugin:
await plugin.run()
try:
extra_info = writer.get_extra_info("sockname")
logger.info("Using local address: %s:%u", *extra_info)
async with plugin_class(reader, writer, token) as plugin:
await plugin.run()
finally:
writer.close()
await writer.wait_closed()
try:
if sys.platform == "win32":
@@ -1069,5 +1044,5 @@ def create_and_run_plugin(plugin_class, argv):
asyncio.run(coroutine())
except Exception:
logging.exception("Error while running plugin")
logger.exception("Error while running plugin")
sys.exit(5)

View File

@@ -44,6 +44,8 @@ from galaxy.api.errors import (
)
logger = logging.getLogger(__name__)
#: Default limit of the simultaneous connections for ssl connector.
DEFAULT_LIMIT = 20
#: Default timeout in seconds used for client session.
@@ -136,11 +138,11 @@ def handle_exception():
if error.status >= 500:
raise BackendError()
if error.status >= 400:
logging.warning(
logger.warning(
"Got status %d while performing %s request for %s",
error.status, error.request_info.method, str(error.request_info.url)
)
raise UnknownError()
except aiohttp.ClientError:
logging.exception("Caught exception while performing request")
logger.exception("Caught exception while performing request")
raise UnknownError()

View File

@@ -3,7 +3,6 @@ from dataclasses import dataclass
from typing import Iterable, NewType, Optional, List, cast
ProcessId = NewType("ProcessId", int)

View File

@@ -3,6 +3,10 @@ import logging
from collections import OrderedDict
from itertools import count
logger = logging.getLogger(__name__)
class TaskManager:
def __init__(self, name):
self._name = name
@@ -15,23 +19,23 @@ class TaskManager:
async def task_wrapper(task_id):
try:
result = await coro
logging.debug("Task manager %s: finished task %d (%s)", self._name, task_id, description)
logger.debug("Task manager %s: finished task %d (%s)", self._name, task_id, description)
return result
except asyncio.CancelledError:
if handle_exceptions:
logging.debug("Task manager %s: canceled task %d (%s)", self._name, task_id, description)
logger.debug("Task manager %s: canceled task %d (%s)", self._name, task_id, description)
else:
raise
except Exception:
if handle_exceptions:
logging.exception("Task manager %s: exception raised in task %d (%s)", self._name, task_id, description)
logger.exception("Task manager %s: exception raised in task %d (%s)", self._name, task_id, description)
else:
raise
finally:
del self._tasks[task_id]
task_id = next(self._task_counter)
logging.debug("Task manager %s: creating task %d (%s)", self._name, task_id, description)
logger.debug("Task manager %s: creating task %d (%s)", self._name, task_id, description)
task = asyncio.create_task(task_wrapper(task_id))
self._tasks[task_id] = task
return task

View File

@@ -21,11 +21,19 @@ def coroutine_mock():
corofunc.coro = coro
return corofunc
async def skip_loop(iterations=1):
for _ in range(iterations):
await asyncio.sleep(0)
async def async_return_value(return_value, loop_iterations_delay=0):
await skip_loop(loop_iterations_delay)
if loop_iterations_delay > 0:
await skip_loop(loop_iterations_delay)
return return_value
async def async_raise(error, loop_iterations_delay=0):
if loop_iterations_delay > 0:
await skip_loop(loop_iterations_delay)
raise error

View File

@@ -12,13 +12,13 @@ from tests import create_message, get_messages
@pytest.mark.asyncio
async def test_get_user_presence_success(plugin, read, write):
context = "abc"
user_ids = ["666", "13", "42", "69", "22"]
user_id_list = ["666", "13", "42", "69", "22"]
plugin.prepare_user_presence_context.return_value = async_return_value(context)
request = {
"jsonrpc": "2.0",
"id": "11",
"method": "start_user_presence_import",
"params": {"user_ids": user_ids}
"params": {"user_id_list": user_id_list}
}
read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)]
plugin.get_user_presence.side_effect = [
@@ -60,7 +60,7 @@ async def test_get_user_presence_success(plugin, read, write):
]
await plugin.run()
plugin.get_user_presence.assert_has_calls([
call(user_id, context) for user_id in user_ids
call(user_id, context) for user_id in user_id_list
])
plugin.user_presence_import_complete.assert_called_once_with()
@@ -151,7 +151,7 @@ async def test_get_user_presence_error(exception, code, message, plugin, read, w
"jsonrpc": "2.0",
"id": request_id,
"method": "start_user_presence_import",
"params": {"user_ids": [user_id]}
"params": {"user_id_list": [user_id]}
}
read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)]
plugin.get_user_presence.side_effect = exception
@@ -192,7 +192,7 @@ async def test_prepare_get_user_presence_context_error(plugin, read, write):
"jsonrpc": "2.0",
"id": request_id,
"method": "start_user_presence_import",
"params": {"user_ids": ["6"]}
"params": {"user_id_list": ["6"]}
}
read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)]
await plugin.run()
@@ -218,7 +218,7 @@ async def test_import_already_in_progress_error(plugin, read, write):
"id": "3",
"method": "start_user_presence_import",
"params": {
"user_ids": ["42"]
"user_id_list": ["42"]
}
},
{
@@ -226,7 +226,7 @@ async def test_import_already_in_progress_error(plugin, read, write):
"id": "4",
"method": "start_user_presence_import",
"params": {
"user_ids": ["666"]
"user_id_list": ["666"]
}
}
]