diff --git a/src/galaxy/api/plugin.py b/src/galaxy/api/plugin.py index 5fd778e..b2704e8 100644 --- a/src/galaxy/api/plugin.py +++ b/src/galaxy/api/plugin.py @@ -4,7 +4,7 @@ import json import logging import sys from enum import Enum -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union, AsyncGenerator from galaxy.api.consts import Feature, OSCompatibility from galaxy.api.errors import ImportInProgress, UnknownError @@ -42,7 +42,8 @@ class Importer: notification_success, notification_failure, notification_finished, - complete + complete, + yielding=False ): self._task_manager = task_manger self._name = name @@ -54,39 +55,44 @@ class Importer: self._complete = complete self._import_in_progress = False + self._yielding = yielding + + async def _import_element(self, id_, context_): + try: + if self._yielding: + async for element in self._get(id_, context_): + self._notification_success(id_, element) + else: + element = await self._get(id_, context_) + self._notification_success(id_, element) + except ApplicationError as error: + self._notification_failure(id_, error) + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Unexpected exception raised in %s importer", self._name) + self._notification_failure(id_, UnknownError()) + + async def _import_elements(self, ids_, context_): + try: + imports = [self._import_element(id_, context_) for id_ in ids_] + await asyncio.gather(*imports) + self._notification_finished() + self._complete() + except asyncio.CancelledError: + logger.debug("Importing %s cancelled", self._name) + finally: + self._import_in_progress = False async def start(self, ids): if self._import_in_progress: raise ImportInProgress() - async def import_element(id_, context_): - try: - element = await self._get(id_, context_) - self._notification_success(id_, element) - except ApplicationError as error: - self._notification_failure(id_, error) - except asyncio.CancelledError: - pass - except Exception: - logger.exception("Unexpected exception raised in %s importer", self._name) - self._notification_failure(id_, UnknownError()) - - async def import_elements(ids_, context_): - try: - imports = [import_element(id_, context_) for id_ in ids_] - await asyncio.gather(*imports) - self._notification_finished() - self._complete() - except asyncio.CancelledError: - logger.debug("Importing %s cancelled", self._name) - finally: - self._import_in_progress = False - self._import_in_progress = True try: context = await self._prepare_context(ids) self._task_manager.create_task( - import_elements(ids, context), + self._import_elements(ids, context), "{} import".format(self._name), handle_exceptions=False ) @@ -185,7 +191,8 @@ class Plugin: self._subscription_games_import_success, self._subscription_games_import_failure, self._subscription_games_import_finished, - self.subscription_games_import_complete + self.subscription_games_import_complete, + yielding=True ) # internal @@ -1122,11 +1129,25 @@ class Plugin: """ return None - async def get_subscription_games(self, subscription_name: str, context: Any) -> Optional[List[SubscriptionGame]]: - """Override this method to return list of subscription games in a given subscription. + async def get_subscription_games(self, subscription_name: str, context: Any) -> AsyncGenerator[List[SubscriptionGame],None]: + """Override this method to return SubscriptionGames for a given subscription. + This method should `yield` results from a list of SubscriptionGames :param context: the value returned from :meth:`prepare_subscription_games_context` - :return: List of subscription games or `None` if list cannot be determined. + :return yield List of subscription games. + + .. code-block:: python + :linenos: + + async def get_sub_games(sub_name: str, context: Any): + for i in range(10): + try: + games_page = await _get_subs_from_backend(sub_name, i) + except KeyError: + print('no more chunk pages for', sub_name) + return + yield [SubGame(game['game_id'], game['game_title']) for game in games_page] + """ raise NotImplementedError() diff --git a/src/galaxy/unittest/mock.py b/src/galaxy/unittest/mock.py index da2e033..f3b9855 100644 --- a/src/galaxy/unittest/mock.py +++ b/src/galaxy/unittest/mock.py @@ -1,7 +1,6 @@ import asyncio from unittest.mock import MagicMock - class AsyncMock(MagicMock): """ .. deprecated:: 0.45