diff --git a/tests/test_subscriptions.py b/tests/test_subscriptions.py index dd5d5ab..e5d8cf0 100644 --- a/tests/test_subscriptions.py +++ b/tests/test_subscriptions.py @@ -6,8 +6,13 @@ from galaxy.unittest.mock import async_return_value from tests import create_message, get_messages +class AsyncIter: + def __init__(self, items): + self.items = items + + async def __aiter__(self): + yield self.items -@pytest.mark.asyncio @pytest.mark.asyncio async def test_get_subscriptions_success(plugin, read, write): request = { @@ -97,16 +102,14 @@ async def test_get_subscription_games_success(plugin, read, write): } read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] - async def sub_games(): - games = [ - SubscriptionGame(game_title="game A", game_id="game_A"), - SubscriptionGame(game_title="game B", game_id="game_B", start_time=1548495632), - SubscriptionGame(game_title="game C", game_id="game_C", end_time=1548495633), - SubscriptionGame(game_title="game D", game_id="game_D", start_time=1548495632, end_time=1548495633), + games = [ + SubscriptionGame(game_title="game A", game_id="game_A"), + SubscriptionGame(game_title="game B", game_id="game_B", start_time=1548495632), + SubscriptionGame(game_title="game C", game_id="game_C", end_time=1548495633), + SubscriptionGame(game_title="game D", game_id="game_D", start_time=1548495632, end_time=1548495633), ] - yield [game for game in games] - plugin.get_subscription_games.return_value = sub_games() + plugin.get_subscription_games.return_value = AsyncIter(games) await plugin.run() plugin.prepare_subscription_games_context.assert_called_with(["sub_a"]) plugin.get_subscription_games.assert_called_with("sub_a", 5) @@ -167,10 +170,7 @@ async def test_get_subscription_games_success_none_yield(plugin, read, write): } read.side_effect = [async_return_value(create_message(request)), async_return_value(b"", 10)] - async def sub_games(): - yield None - - plugin.get_subscription_games.return_value = sub_games() + plugin.get_subscription_games.return_value = AsyncIter(None) await plugin.run() plugin.prepare_subscription_games_context.assert_called_with(["sub_a"]) plugin.get_subscription_games.assert_called_with("sub_a", 5)