chore(style): add one blank line

to conform with Google style

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-08-26 11:36:57 +00:00
parent 938fd362bb
commit 806a663e4a
111 changed files with 601 additions and 87 deletions

View File

@@ -29,6 +29,7 @@ if t.TYPE_CHECKING:
from openllm._configuration import GenerationConfig
from openllm.client import BaseAsyncClient
class ResponseComparator(JSONSnapshotExtension):
def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData:
if openllm.utils.LazyType(ListAny).isinstance(data):
@@ -66,9 +67,11 @@ class ResponseComparator(JSONSnapshotExtension):
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
@pytest.fixture()
def response_snapshot(snapshot: SnapshotAssertion):
return snapshot.use_extension(ResponseComparator)
@attr.define(init=False)
class _Handle(ABC):
port: int
@@ -100,6 +103,7 @@ class _Handle(ABC):
except Exception:
time.sleep(1)
raise RuntimeError(f'Handle failed to initialise within {timeout} seconds.')
@attr.define(init=False)
class LocalHandle(_Handle):
process: subprocess.Popen[bytes]
@@ -109,10 +113,12 @@ class LocalHandle(_Handle):
def status(self) -> bool:
return self.process.poll() is None
class HandleProtocol(t.Protocol):
@contextlib.contextmanager
def __call__(*, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None,) -> t.Generator[_Handle, None, None]:
...
@attr.define(init=False)
class DockerHandle(_Handle):
container_name: str
@@ -124,6 +130,7 @@ class DockerHandle(_Handle):
def status(self) -> bool:
container = self.docker_client.containers.get(self.container_name)
return container.status in ['running', 'created']
@contextlib.contextmanager
def _local_handle(
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
@@ -146,6 +153,7 @@ def _local_handle(
proc.stdout.close()
if proc.stderr:
proc.stderr.close()
@contextlib.contextmanager
def _container_handle(
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
@@ -192,19 +200,23 @@ def _container_handle(
print(container_output, file=sys.stderr)
container.remove()
@pytest.fixture(scope='session', autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
stack = contextlib.ExitStack()
yield stack
stack.close()
@pytest.fixture(scope='module')
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(params=['container', 'local'], scope='session')
def deployment_mode(request: pytest.FixtureRequest) -> str:
return request.param
@pytest.fixture(scope='module')
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal['container', 'local']):
if deployment_mode == 'container':

View File

@@ -10,15 +10,18 @@ if t.TYPE_CHECKING:
from .conftest import HandleProtocol, ResponseComparator, _Handle
model = 'flan_t5'
model_id = 'google/flan-t5-small'
@pytest.fixture(scope='module')
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
yield handle
@pytest.fixture(scope='module')
async def flan_t5(flan_t5_handle: _Handle):
await flan_t5_handle.health(240)
return flan_t5_handle.client
@pytest.mark.asyncio()
async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
client = await flan_t5

View File

@@ -10,15 +10,18 @@ if t.TYPE_CHECKING:
from .conftest import HandleProtocol, ResponseComparator, _Handle
model = 'opt'
model_id = 'facebook/opt-125m'
@pytest.fixture(scope='module')
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
yield handle
@pytest.fixture(scope='module')
async def opt_125m(opt_125m_handle: _Handle):
await opt_125m_handle.health(240)
return opt_125m_handle.client
@pytest.mark.asyncio()
async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
client = await opt_125m