From 1585bab203cfb798792c2472e782ba58b7f28e01 Mon Sep 17 00:00:00 2001 From: Romuald Bierbasz Date: Tue, 22 Oct 2019 11:30:01 +0200 Subject: [PATCH] Wait for drain before writing --- src/galaxy/api/jsonrpc.py | 18 ++++++++++++++---- tests/test_achievements.py | 3 ++- tests/test_authenticate.py | 4 +++- tests/test_friends.py | 4 +++- tests/test_game_times.py | 3 ++- tests/test_local_games.py | 3 ++- tests/test_owned_games.py | 5 ++++- tests/test_persistent_cache.py | 4 +++- 8 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/galaxy/api/jsonrpc.py b/src/galaxy/api/jsonrpc.py index 4087f58..be5a9c3 100644 --- a/src/galaxy/api/jsonrpc.py +++ b/src/galaxy/api/jsonrpc.py @@ -88,6 +88,7 @@ class Server(): self._methods = {} self._notifications = {} self._task_manager = TaskManager("jsonrpc server") + self._write_lock = asyncio.Lock() def register_method(self, name, callback, immediate, sensitive_params=False): """ @@ -223,12 +224,16 @@ class Server(): raise InvalidRequest() def _send(self, data): + async def send_task(data_): + async with self._write_lock: + self._writer.write(data_) + await self._writer.drain() + try: line = self._encoder.encode(data) logging.debug("Sending data: %s", line) data = (line + "\n").encode("utf-8") - self._writer.write(data) - self._task_manager.create_task(self._writer.drain(), "drain") + self._task_manager.create_task(send_task(data), "send") except TypeError as error: logging.error(str(error)) @@ -263,6 +268,7 @@ class NotificationClient(): self._encoder = encoder self._methods = {} self._task_manager = TaskManager("notification client") + self._write_lock = asyncio.Lock() def notify(self, method, params, sensitive_params=False): """ @@ -286,12 +292,16 @@ class NotificationClient(): await self._task_manager.wait() def _send(self, data): + async def send_task(data_): + async with self._write_lock: + self._writer.write(data_) + await self._writer.drain() + try: line = self._encoder.encode(data) data = (line + "\n").encode("utf-8") logging.debug("Sending %d byte of data", len(data)) - self._writer.write(data) - self._task_manager.create_task(self._writer.drain(), "drain") + self._task_manager.create_task(send_task(data), "send") except TypeError as error: logging.error("Failed to parse outgoing message: %s", str(error)) diff --git a/tests/test_achievements.py b/tests/test_achievements.py index 598a51a..db21ca5 100644 --- a/tests/test_achievements.py +++ b/tests/test_achievements.py @@ -5,7 +5,7 @@ from pytest import raises from galaxy.api.types import Achievement from galaxy.api.errors import BackendError -from galaxy.unittest.mock import async_return_value +from galaxy.unittest.mock import async_return_value, skip_loop from tests import create_message, get_messages @@ -201,6 +201,7 @@ async def test_import_in_progress(plugin, read, write): async def test_unlock_achievement(plugin, write): achievement = Achievement(achievement_id="lvl20", unlock_time=1548422395) plugin.unlock_achievement("14", achievement) + await skip_loop() response = json.loads(write.call_args[0][0]) assert response == { diff --git a/tests/test_authenticate.py b/tests/test_authenticate.py index 1c29ad2..9bbb6e8 100644 --- a/tests/test_authenticate.py +++ b/tests/test_authenticate.py @@ -5,7 +5,7 @@ from galaxy.api.errors import ( UnknownError, InvalidCredentials, NetworkError, LoggedInElsewhere, ProtocolError, BackendNotAvailable, BackendTimeout, BackendError, TemporaryBlocked, Banned, AccessDenied ) -from galaxy.unittest.mock import async_return_value +from galaxy.unittest.mock import async_return_value, skip_loop from tests import create_message, get_messages @@ -97,6 +97,7 @@ async def test_store_credentials(plugin, write): "token": "ABC" } plugin.store_credentials(credentials) + await skip_loop() assert get_messages(write) == [ { @@ -110,6 +111,7 @@ async def test_store_credentials(plugin, write): @pytest.mark.asyncio async def test_lost_authentication(plugin, write): plugin.lost_authentication() + await skip_loop() assert get_messages(write) == [ { diff --git a/tests/test_friends.py b/tests/test_friends.py index 820f824..8b124e9 100644 --- a/tests/test_friends.py +++ b/tests/test_friends.py @@ -1,6 +1,6 @@ from galaxy.api.types import FriendInfo from galaxy.api.errors import UnknownError -from galaxy.unittest.mock import async_return_value +from galaxy.unittest.mock import async_return_value, skip_loop import pytest @@ -67,6 +67,7 @@ async def test_add_friend(plugin, write): friend = FriendInfo("7", "Kuba") plugin.add_friend(friend) + await skip_loop() assert get_messages(write) == [ { @@ -82,6 +83,7 @@ async def test_add_friend(plugin, write): @pytest.mark.asyncio async def test_remove_friend(plugin, write): plugin.remove_friend("5") + await skip_loop() assert get_messages(write) == [ { diff --git a/tests/test_game_times.py b/tests/test_game_times.py index 7f34e58..fc46a66 100644 --- a/tests/test_game_times.py +++ b/tests/test_game_times.py @@ -3,7 +3,7 @@ from unittest.mock import call import pytest from galaxy.api.types import GameTime from galaxy.api.errors import BackendError -from galaxy.unittest.mock import async_return_value +from galaxy.unittest.mock import async_return_value, skip_loop from tests import create_message, get_messages @@ -199,6 +199,7 @@ async def test_import_in_progress(plugin, read, write): async def test_update_game(plugin, write): game_time = GameTime("3", 60, 1549550504) plugin.update_game_time(game_time) + await skip_loop() assert get_messages(write) == [ { diff --git a/tests/test_local_games.py b/tests/test_local_games.py index 9899ccb..326057f 100644 --- a/tests/test_local_games.py +++ b/tests/test_local_games.py @@ -3,7 +3,7 @@ import pytest from galaxy.api.types import LocalGame from galaxy.api.consts import LocalGameState from galaxy.api.errors import UnknownError, FailedParsingManifest -from galaxy.unittest.mock import async_return_value +from galaxy.unittest.mock import async_return_value, skip_loop from tests import create_message, get_messages @@ -83,6 +83,7 @@ async def test_failure(plugin, read, write, error, code, message): async def test_local_game_state_update(plugin, write): game = LocalGame("1", LocalGameState.Running) plugin.update_local_game_status(game) + await skip_loop() assert get_messages(write) == [ { diff --git a/tests/test_owned_games.py b/tests/test_owned_games.py index 0f2752d..73f3308 100644 --- a/tests/test_owned_games.py +++ b/tests/test_owned_games.py @@ -3,7 +3,7 @@ import pytest from galaxy.api.types import Game, Dlc, LicenseInfo from galaxy.api.consts import LicenseType from galaxy.api.errors import UnknownError -from galaxy.unittest.mock import async_return_value +from galaxy.unittest.mock import async_return_value, skip_loop from tests import create_message, get_messages @@ -100,6 +100,7 @@ async def test_failure(plugin, read, write): async def test_add_game(plugin, write): game = Game("3", "Doom", None, LicenseInfo(LicenseType.SinglePurchase, None)) plugin.add_game(game) + await skip_loop() assert get_messages(write) == [ { "jsonrpc": "2.0", @@ -120,6 +121,7 @@ async def test_add_game(plugin, write): @pytest.mark.asyncio async def test_remove_game(plugin, write): plugin.remove_game("5") + await skip_loop() assert get_messages(write) == [ { "jsonrpc": "2.0", @@ -135,6 +137,7 @@ async def test_remove_game(plugin, write): async def test_update_game(plugin, write): game = Game("3", "Doom", None, LicenseInfo(LicenseType.SinglePurchase, None)) plugin.update_game(game) + await skip_loop() assert get_messages(write) == [ { "jsonrpc": "2.0", diff --git a/tests/test_persistent_cache.py b/tests/test_persistent_cache.py index 9aaa0f7..8e9834f 100644 --- a/tests/test_persistent_cache.py +++ b/tests/test_persistent_cache.py @@ -1,6 +1,6 @@ import pytest -from galaxy.unittest.mock import async_return_value +from galaxy.unittest.mock import async_return_value, skip_loop from tests import create_message, get_messages @@ -57,6 +57,7 @@ async def test_set_cache(plugin, write, cache_data): plugin.persistent_cache.update(cache_data) plugin.push_cache() + await skip_loop() assert_rpc_request(write, "push_cache", cache_data) assert cache_data == plugin.persistent_cache @@ -68,6 +69,7 @@ async def test_clear_cache(plugin, write, cache_data): plugin.persistent_cache.clear() plugin.push_cache() + await skip_loop() assert_rpc_request(write, "push_cache", {}) assert {} == plugin.persistent_cache