chore(style): synchronized style across packages [skip ci]

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Aaron
2023-08-23 08:46:22 -04:00
parent bbd9aa7646
commit 787ce1b3b6
124 changed files with 2775 additions and 2771 deletions
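The change itself is mechanical: string literals across the touched packages move from double to single quotes, with no behavioural difference. A minimal before/after sketch of the pattern (hypothetical Python snippet, not taken from this diff):

# before the style sync: double-quoted literals
greeting = "sanity"
print(f"response: {greeting}")
# after the style sync: single-quoted literals, same runtime behaviour
greeting = 'sanity'
print(f'response: {greeting}')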

View File

@@ -27,13 +27,13 @@ class ResponseComparator(JSONSnapshotExtension):
try:
data = orjson.loads(data)
except orjson.JSONDecodeError as err:
raise ValueError(f"Failed to decode JSON data: {data}") from err
raise ValueError(f'Failed to decode JSON data: {data}') from err
if openllm.utils.LazyType(DictStrAny).isinstance(data):
return openllm.GenerationOutput(**data)
elif openllm.utils.LazyType(ListAny).isinstance(data):
return [openllm.GenerationOutput(**d) for d in data]
else:
raise NotImplementedError(f"Data {data} has unsupported type.")
raise NotImplementedError(f'Data {data} has unsupported type.')
serialized_data = convert_data(serialized_data)
snapshot_data = convert_data(snapshot_data)
@@ -56,7 +56,7 @@ def response_snapshot(snapshot: SnapshotAssertion):
@attr.define(init=False)
class _Handle(ABC):
port: int
deployment_mode: t.Literal["container", "local"]
deployment_mode: t.Literal['container', 'local']
client: BaseAsyncClient[t.Any] = attr.field(init=False)
@@ -66,7 +66,7 @@ class _Handle(ABC):
...
def __attrs_post_init__(self):
-self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}")
+self.client = openllm.client.AsyncHTTPClient(f'http://localhost:{self.port}')
@abstractmethod
def status(self) -> bool:
@@ -76,19 +76,19 @@ class _Handle(ABC):
start_time = time.time()
while time.time() - start_time < timeout:
if not self.status():
raise RuntimeError(f"Failed to initialise {self.__class__.__name__}")
raise RuntimeError(f'Failed to initialise {self.__class__.__name__}')
await self.client.health()
try:
await self.client.query("sanity")
await self.client.query('sanity')
return
except Exception:
time.sleep(1)
raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.")
raise RuntimeError(f'Handle failed to initialise within {timeout} seconds.')
@attr.define(init=False)
class LocalHandle(_Handle):
process: subprocess.Popen[bytes]
-def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal["container", "local"],):
+def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal['container', 'local'],):
self.__attrs_init__(port, deployment_mode, process)
def status(self) -> bool:
@@ -102,23 +102,23 @@ class DockerHandle(_Handle):
container_name: str
docker_client: docker.DockerClient
-def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal["container", "local"],):
+def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal['container', 'local'],):
self.__attrs_init__(port, deployment_mode, container_name, docker_client)
def status(self) -> bool:
container = self.docker_client.containers.get(self.container_name)
return container.status in ["running", "created"]
return container.status in ['running', 'created']
@contextlib.contextmanager
def _local_handle(
-model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,
+model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
):
with openllm.utils.reserve_free_port() as port:
pass
if not _serve_grpc:
-proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True)
+proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
else:
-proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True)
+proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
yield LocalHandle(proc, port, deployment_mode)
proc.terminate()
@@ -132,13 +132,13 @@ def _local_handle(
proc.stderr.close()
@contextlib.contextmanager
def _container_handle(
-model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,
+model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
):
envvar = openllm.utils.EnvVarMixin(model)
with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
pass
container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_")
container_name = f'openllm-{model}-{normalise_model_name(model_id)}'.replace('-', '_')
client = docker.from_env()
try:
container = client.containers.get(container_name)
@@ -148,7 +148,7 @@ def _container_handle(
except docker.errors.NotFound:
pass
args = ["serve" if not _serve_grpc else "serve-grpc"]
args = ['serve' if not _serve_grpc else 'serve-grpc']
env: DictStrAny = {}
@@ -156,11 +156,11 @@ def _container_handle(
env[envvar.quantize] = quantize
gpus = openllm.utils.device_count() or -1
-devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None
+devs = [docker.types.DeviceRequest(count=gpus, capabilities=[['gpu']])] if gpus > 0 else None
container = client.containers.run(
image_tag, command=args, name=container_name, environment=env, auto_remove=False, detach=True, device_requests=devs, ports={
"3000/tcp": port, "3001/tcp": prom_port
'3000/tcp': port, '3001/tcp': prom_port
},
)
@@ -172,28 +172,28 @@ def _container_handle(
except docker.errors.NotFound:
pass
-container_output = container.logs().decode("utf-8")
+container_output = container.logs().decode('utf-8')
print(container_output, file=sys.stderr)
container.remove()
@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(scope='session', autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
stack = contextlib.ExitStack()
yield stack
stack.close()
@pytest.fixture(scope="module")
@pytest.fixture(scope='module')
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(params=["container", "local"], scope="session")
@pytest.fixture(params=['container', 'local'], scope='session')
def deployment_mode(request: pytest.FixtureRequest) -> str:
return request.param
@pytest.fixture(scope="module")
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]):
if deployment_mode == "container":
@pytest.fixture(scope='module')
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal['container', 'local']):
if deployment_mode == 'container':
return functools.partial(_container_handle, deployment_mode=deployment_mode)
elif deployment_mode == "local":
elif deployment_mode == 'local':
return functools.partial(_local_handle, deployment_mode=deployment_mode)
else:
raise ValueError(f"Unknown deployment mode: {deployment_mode}")
raise ValueError(f'Unknown deployment mode: {deployment_mode}')

View File

@@ -9,21 +9,21 @@ if t.TYPE_CHECKING:
import contextlib
from .conftest import HandleProtocol, ResponseComparator, _Handle
model = "flan_t5"
model_id = "google/flan-t5-small"
@pytest.fixture(scope="module")
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
model = 'flan_t5'
model_id = 'google/flan-t5-small'
@pytest.fixture(scope='module')
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
yield handle
@pytest.fixture(scope="module")
@pytest.fixture(scope='module')
async def flan_t5(flan_t5_handle: _Handle):
await flan_t5_handle.health(240)
return flan_t5_handle.client
@pytest.mark.asyncio()
async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
client = await flan_t5
response = await client.query("What is the meaning of life?", max_new_tokens=10, top_p=0.9, return_response="attrs")
response = await client.query('What is the meaning of life?', max_new_tokens=10, top_p=0.9, return_response='attrs')
assert response.configuration["generation_config"]["max_new_tokens"] == 10
assert response.configuration['generation_config']['max_new_tokens'] == 10
assert response == response_snapshot

View File

@@ -9,21 +9,21 @@ if t.TYPE_CHECKING:
import contextlib
from .conftest import HandleProtocol, ResponseComparator, _Handle
model = "opt"
model_id = "facebook/opt-125m"
@pytest.fixture(scope="module")
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
model = 'opt'
model_id = 'facebook/opt-125m'
@pytest.fixture(scope='module')
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
yield handle
@pytest.fixture(scope="module")
@pytest.fixture(scope='module')
async def opt_125m(opt_125m_handle: _Handle):
await opt_125m_handle.health(240)
return opt_125m_handle.client
@pytest.mark.asyncio()
async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
client = await opt_125m
response = await client.query("What is Deep learning?", max_new_tokens=20, return_response="attrs")
response = await client.query('What is Deep learning?', max_new_tokens=20, return_response='attrs')
assert response.configuration["generation_config"]["max_new_tokens"] == 20
assert response.configuration['generation_config']['max_new_tokens'] == 20
assert response == response_snapshot