From 4e1ea8056d66bb58ac5ca0e83cc6de7ce30aaf36 Mon Sep 17 00:00:00 2001 From: Romuald Juchnowicz-Bierbasz Date: Fri, 28 Jun 2019 14:00:44 +0200 Subject: [PATCH] Add StreamLineReader with unit tests --- src/galaxy/api/jsonrpc.py | 28 +++-------------- src/galaxy/reader.py | 28 +++++++++++++++++ tests/test_stream_line_reader.py | 52 ++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 24 deletions(-) create mode 100644 src/galaxy/reader.py create mode 100644 tests/test_stream_line_reader.py diff --git a/src/galaxy/api/jsonrpc.py b/src/galaxy/api/jsonrpc.py index eb40ab3..37d0e17 100644 --- a/src/galaxy/api/jsonrpc.py +++ b/src/galaxy/api/jsonrpc.py @@ -5,6 +5,8 @@ import logging import inspect import json +from galaxy.reader import StreamLineReader + class JsonRpcError(Exception): def __init__(self, code, message, data=None): self.code = code @@ -67,14 +69,12 @@ def anonymise_sensitive_params(params, sensitive_params): class Server(): def __init__(self, reader, writer, encoder=json.JSONEncoder()): self._active = True - self._reader = reader + self._reader = StreamLineReader(reader) self._writer = writer self._encoder = encoder self._methods = {} self._notifications = {} self._eof_listeners = [] - self._input_buffer = bytes() - self._processed_input_buffer_it = 0 def register_method(self, name, callback, internal, sensitive_params=False): """ @@ -106,7 +106,7 @@ class Server(): async def run(self): while self._active: try: - data = await self._readline() + data = await self._reader.readline() if not data: self._eof() continue @@ -117,26 +117,6 @@ class Server(): logging.debug("Received %d bytes of data", len(data)) self._handle_input(data) - async def _readline(self): - """Like StreamReader.readline but without limit""" - while True: - # check if there is no unprocessed data in the buffer - if not self._input_buffer or self._processed_input_buffer_it != 0: - chunk = await self._reader.read(1024) - if not chunk: - return bytes() # EOF - self._input_buffer += chunk - - it = self._input_buffer.find(b"\n", self._processed_input_buffer_it) - if it < 0: - self._processed_input_buffer_it = len(self._input_buffer) - continue - - line = self._input_buffer[:it] - self._input_buffer = self._input_buffer[it+1:] - self._processed_input_buffer_it = 0 - return line - def stop(self): self._active = False diff --git a/src/galaxy/reader.py b/src/galaxy/reader.py new file mode 100644 index 0000000..551f803 --- /dev/null +++ b/src/galaxy/reader.py @@ -0,0 +1,28 @@ +from asyncio import StreamReader + + +class StreamLineReader: + """Handles StreamReader readline without buffer limit""" + def __init__(self, reader: StreamReader): + self._reader = reader + self._buffer = bytes() + self._processed_buffer_it = 0 + + async def readline(self): + while True: + # check if there is no unprocessed data in the buffer + if not self._buffer or self._processed_buffer_it != 0: + chunk = await self._reader.read(1024) + if not chunk: + return bytes() # EOF + self._buffer += chunk + + it = self._buffer.find(b"\n", self._processed_buffer_it) + if it < 0: + self._processed_buffer_it = len(self._buffer) + continue + + line = self._buffer[:it] + self._buffer = self._buffer[it+1:] + self._processed_buffer_it = 0 + return line diff --git a/tests/test_stream_line_reader.py b/tests/test_stream_line_reader.py new file mode 100644 index 0000000..2f81e6c --- /dev/null +++ b/tests/test_stream_line_reader.py @@ -0,0 +1,52 @@ +from unittest.mock import MagicMock + +import pytest + +from galaxy.reader import StreamLineReader +from galaxy.unittest.mock import AsyncMock + +@pytest.fixture() +def stream_reader(): + reader = MagicMock() + reader.read = AsyncMock() + return reader + +@pytest.fixture() +def read(stream_reader): + return stream_reader.read + +@pytest.fixture() +def reader(stream_reader): + return StreamLineReader(stream_reader) + +@pytest.mark.asyncio +async def test_message(reader, read): + read.return_value = b"a\n" + assert await reader.readline() == b"a" + read.assert_called_once() + +@pytest.mark.asyncio +async def test_separate_messages(reader, read): + read.side_effect = [b"a\n", b"b\n"] + assert await reader.readline() == b"a" + assert await reader.readline() == b"b" + assert read.call_count == 2 + +@pytest.mark.asyncio +async def test_connected_messages(reader, read): + read.return_value = b"a\nb\n" + assert await reader.readline() == b"a" + assert await reader.readline() == b"b" + read.assert_called_once() + +@pytest.mark.asyncio +async def test_cut_message(reader, read): + read.side_effect = [b"a", b"b\n"] + assert await reader.readline() == b"ab" + assert read.call_count == 2 + +@pytest.mark.asyncio +async def test_half_message(reader, read): + read.side_effect = [b"a", b""] + assert await reader.readline() == b"" + assert read.call_count == 2