style: google

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-08-30 13:52:00 -04:00
parent e2ba6a92a6
commit b545ad2ad1
98 changed files with 3514 additions and 2094 deletions

View File

@@ -40,7 +40,13 @@ if t.TYPE_CHECKING:
from openllm.client import BaseAsyncClient
class ResponseComparator(JSONSnapshotExtension):
def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData:
def serialize(self,
data: SerializableData,
*,
exclude: PropertyFilter | None = None,
matcher: PropertyMatcher | None = None,
) -> SerializedData:
if openllm.utils.LazyType(ListAny).isinstance(data):
data = [d.unmarshaled for d in data]
else:
@@ -49,6 +55,7 @@ class ResponseComparator(JSONSnapshotExtension):
return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()
def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
try:
data = orjson.loads(data)
@@ -73,9 +80,11 @@ class ResponseComparator(JSONSnapshotExtension):
return s == t
def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and
eq_config(s.marshaled_config, t.marshaled_config))
return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
return len(serialized_data) == len(snapshot_data) and all(
[eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
@pytest.fixture()
def response_snapshot(snapshot: SnapshotAssertion):
@@ -124,8 +133,14 @@ class LocalHandle(_Handle):
return self.process.poll() is None
class HandleProtocol(t.Protocol):
@contextlib.contextmanager
def __call__(*, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None,) -> t.Generator[_Handle, None, None]:
def __call__(*,
model: str,
model_id: str,
image_tag: str,
quantize: t.AnyStr | None = None,
) -> t.Generator[_Handle, None, None]:
...
@attr.define(init=False)
@@ -133,7 +148,9 @@ class DockerHandle(_Handle):
container_name: str
docker_client: docker.DockerClient
def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal['container', 'local'],):
def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int,
deployment_mode: t.Literal['container', 'local'],
):
self.__attrs_init__(port, deployment_mode, container_name, docker_client)
def status(self) -> bool:
@@ -141,16 +158,29 @@ class DockerHandle(_Handle):
return container.status in ['running', 'created']
@contextlib.contextmanager
def _local_handle(
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
):
def _local_handle(model: str,
model_id: str,
image_tag: str,
deployment_mode: t.Literal['container', 'local'],
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
*,
_serve_grpc: bool = False,
):
with openllm.utils.reserve_free_port() as port:
pass
if not _serve_grpc:
proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
proc = openllm.start(model,
model_id=model_id,
quantize=quantize,
additional_args=['--port', str(port)],
__test__=True)
else:
proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
proc = openllm.start_grpc(model,
model_id=model_id,
quantize=quantize,
additional_args=['--port', str(port)],
__test__=True)
yield LocalHandle(proc, port, deployment_mode)
proc.terminate()
@@ -164,9 +194,14 @@ def _local_handle(
proc.stderr.close()
@contextlib.contextmanager
def _container_handle(
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
):
def _container_handle(model: str,
model_id: str,
image_tag: str,
deployment_mode: t.Literal['container', 'local'],
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
*,
_serve_grpc: bool = False,
):
envvar = openllm.utils.EnvVarMixin(model)
with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
@@ -191,11 +226,18 @@ def _container_handle(
gpus = openllm.utils.device_count() or -1
devs = [docker.types.DeviceRequest(count=gpus, capabilities=[['gpu']])] if gpus > 0 else None
container = client.containers.run(
image_tag, command=args, name=container_name, environment=env, auto_remove=False, detach=True, device_requests=devs, ports={
'3000/tcp': port, '3001/tcp': prom_port
},
)
container = client.containers.run(image_tag,
command=args,
name=container_name,
environment=env,
auto_remove=False,
detach=True,
device_requests=devs,
ports={
'3000/tcp': port,
'3001/tcp': prom_port
},
)
yield DockerHandle(client, container.name, port, deployment_mode)

View File

@@ -16,8 +16,11 @@ model = 'flan_t5'
model_id = 'google/flan-t5-small'
@pytest.fixture(scope='module')
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'],
clean_context: contextlib.ExitStack,
):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode,
clean_context=clean_context) as image_tag:
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
yield handle

View File

@@ -16,8 +16,11 @@ model = 'opt'
model_id = 'facebook/opt-125m'
@pytest.fixture(scope='module')
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'],
clean_context: contextlib.ExitStack,
):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode,
clean_context=clean_context) as image_tag:
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
yield handle