from __future__ import annotations

import asyncio
import contextlib
import functools
import logging
import sys
import time
import typing as t
from abc import ABC, abstractmethod

import attr
import docker
import docker.errors
import docker.types
import orjson
import pytest
from syrupy.extensions.json import JSONSnapshotExtension

import openllm
from openllm._llm import normalise_model_name
from openllm_core._typing_compat import DictStrAny, ListAny

logger = logging.getLogger(__name__)

if t.TYPE_CHECKING:
  import subprocess

  from syrupy.assertion import SnapshotAssertion
  from syrupy.types import PropertyFilter, PropertyMatcher, SerializableData, SerializedData

  from openllm._configuration import GenerationConfig
  from openllm.client import BaseAsyncClient


class ResponseComparator(JSONSnapshotExtension):
  def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None) -> SerializedData:
    # Unmarshal GenerationOutput objects (or lists of them) into plain dicts
    # before handing them to syrupy's filter machinery.
    if openllm.utils.LazyType(ListAny).isinstance(data):
      data = [d.unmarshaled for d in data]
    else:
      data = data.unmarshaled
    data = self._filter(data=data, depth=0, path=(), exclude=exclude, matcher=matcher)
    return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()

  def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
    def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
      try:
        data = orjson.loads(data)
      except orjson.JSONDecodeError as err:
        raise ValueError(f"Failed to decode JSON data: {data}") from err
      if openllm.utils.LazyType(DictStrAny).isinstance(data):
        return openllm.GenerationOutput(**data)
      elif openllm.utils.LazyType(ListAny).isinstance(data):
        return [openllm.GenerationOutput(**d) for d in data]
      raise NotImplementedError(f"Data {data} has unsupported type.")

    serialized_data = convert_data(serialized_data)
    snapshot_data = convert_data(snapshot_data)
    # Normalise both sides to lists so single outputs and batches compare the same way.
    if not openllm.utils.LazyType(ListAny).isinstance(serialized_data):
      serialized_data = [serialized_data]
    if not openllm.utils.LazyType(ListAny).isinstance(snapshot_data):
      snapshot_data = [snapshot_data]

    def eq_config(lhs: GenerationConfig, rhs: GenerationConfig) -> bool:
      return lhs == rhs

    def eq_output(lhs: openllm.GenerationOutput, rhs: openllm.GenerationOutput) -> bool:
      return (
        len(lhs.responses) == len(rhs.responses)
        and all(_s == _t for _s, _t in zip(lhs.responses, rhs.responses))
        and eq_config(lhs.marshaled_config, rhs.marshaled_config)
      )

    return len(serialized_data) == len(snapshot_data) and all(eq_output(lhs, rhs) for lhs, rhs in zip(serialized_data, snapshot_data))


@pytest.fixture()
def response_snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion:
  return snapshot.use_extension(ResponseComparator)
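
# For reference, a stored snapshot is the orjson dump produced by
# `ResponseComparator.serialize` above: one unmarshaled GenerationOutput dict
# (or a list of them), indented and key-sorted. A minimal sketch of the shape,
# with hypothetical values; only `responses` and the marshaled generation
# config are actually inspected by `matches`:
#
#   [
#     {
#       "responses": ["<generated text>"],
#       "...": "generation config and other marshaled GenerationOutput fields"
#     }
#   ]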
@attr.define(init=False)
class _Handle(ABC):
  port: int
  deployment_mode: t.Literal["container", "local"]
  client: BaseAsyncClient[t.Any] = attr.field(init=False)

  if t.TYPE_CHECKING:

    def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: ...

  def __attrs_post_init__(self) -> None:
    self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}")

  @abstractmethod
  def status(self) -> bool:
    raise NotImplementedError

  async def health(self, timeout: int = 240) -> None:
    start_time = time.time()
    while time.time() - start_time < timeout:
      # Fail fast if the underlying process or container has already died.
      if not self.status():
        raise RuntimeError(f"Failed to initialise {self.__class__.__name__}")
      try:
        await self.client.health()
        await self.client.query("sanity")
        return
      except Exception:
        # Server not ready yet; yield to the event loop instead of blocking it.
        await asyncio.sleep(1)
    raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.")


@attr.define(init=False)
class LocalHandle(_Handle):
  process: subprocess.Popen[bytes]

  def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal["container", "local"]):
    self.__attrs_init__(port, deployment_mode, process)

  def status(self) -> bool:
    return self.process.poll() is None


class HandleProtocol(t.Protocol):
  @contextlib.contextmanager
  def __call__(self, *, model: str, model_id: str, image_tag: str, quantize: t.Literal["int8", "int4", "gptq"] | None = None) -> t.Generator[_Handle, None, None]: ...


@attr.define(init=False)
class DockerHandle(_Handle):
  container_name: str
  docker_client: docker.DockerClient

  def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal["container", "local"]):
    self.__attrs_init__(port, deployment_mode, container_name, docker_client)

  def status(self) -> bool:
    container = self.docker_client.containers.get(self.container_name)
    return container.status in ("running", "created")


@contextlib.contextmanager
def _local_handle(
  model: str,
  model_id: str,
  image_tag: str,
  deployment_mode: t.Literal["container", "local"],
  quantize: t.Literal["int8", "int4", "gptq"] | None = None,
  *,
  _serve_grpc: bool = False,
) -> t.Generator[_Handle, None, None]:
  # Reserve a free port, then release it immediately so the server can bind to it.
  with openllm.utils.reserve_free_port() as port:
    pass
  if not _serve_grpc:
    proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True)
  else:
    proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True)
  try:
    yield LocalHandle(proc, port, deployment_mode)
  finally:
    # Tear the server down even if the test body raised.
    proc.terminate()
    proc.wait(60)
    if proc.stdout:
      print(proc.stdout.read(), file=sys.stderr)
      proc.stdout.close()
    if proc.stderr:
      proc.stderr.close()


@contextlib.contextmanager
def _container_handle(
  model: str,
  model_id: str,
  image_tag: str,
  deployment_mode: t.Literal["container", "local"],
  quantize: t.Literal["int8", "int4", "gptq"] | None = None,
  *,
  _serve_grpc: bool = False,
) -> t.Generator[_Handle, None, None]:
  envvar = openllm.utils.EnvVarMixin(model)
  # Reserve two free ports (HTTP and Prometheus), then release them so Docker can bind them.
  with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
    pass
  container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_")
  client = docker.from_env()
  # Remove any stale container left over from a previous run.
  try:
    container = client.containers.get(container_name)
    container.stop()
    container.wait()
    container.remove()
  except docker.errors.NotFound:
    pass
  args = ["serve" if not _serve_grpc else "serve-grpc"]
  env: DictStrAny = {}
  if quantize is not None:
    env[envvar.quantize] = quantize
  gpus = openllm.utils.device_count() or -1
  devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None
  container = client.containers.run(
    image_tag,
    command=args,
    name=container_name,
    environment=env,
    auto_remove=False,
    detach=True,
    device_requests=devs,
    ports={"3000/tcp": port, "3001/tcp": prom_port},
  )
  try:
    yield DockerHandle(client, container.name, port, deployment_mode)
  finally:
    try:
      container.stop()
      container.wait()
    except docker.errors.NotFound:
      pass
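
# For reference, the `containers.run` call above corresponds roughly to the
# following CLI invocation (illustrative only; <image_tag>, ports, GPU count
# and the quantize environment variable are filled in at runtime):
#
#   docker run -d --name openllm_<model>_<model_id> \
#     --gpus <device_count> \
#     -e <envvar.quantize>=<quantize> \
#     -p <port>:3000 -p <prom_port>:3001 \
#     <image_tag> serve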
    # Dump the container logs before removal so failures remain debuggable.
    print(container.logs().decode("utf-8"), file=sys.stderr)
    container.remove()


@pytest.fixture(scope="session", autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
  stack = contextlib.ExitStack()
  yield stack
  stack.close()


@pytest.fixture(scope="module")
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
  # Create a dedicated loop per module; asyncio.get_event_loop() is deprecated
  # outside a running loop and may hand back a loop shared with other tests.
  loop = asyncio.new_event_loop()
  yield loop
  loop.close()


@pytest.fixture(params=["container", "local"], scope="session")
def deployment_mode(request: pytest.FixtureRequest) -> str:
  return request.param


@pytest.fixture(scope="module")
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]):
  if deployment_mode == "container":
    return functools.partial(_container_handle, deployment_mode=deployment_mode)
  elif deployment_mode == "local":
    return functools.partial(_local_handle, deployment_mode=deployment_mode)
  raise ValueError(f"Unknown deployment mode: {deployment_mode}")
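
# Illustrative sketch (underscore-prefixed, so pytest does not collect it): how
# a test module is expected to combine the fixtures above. The model name,
# model_id, image tag and prompt are hypothetical placeholders, not values
# used by this suite.
def _example_usage(handler: HandleProtocol, response_snapshot: SnapshotAssertion, el: asyncio.AbstractEventLoop) -> None:
  with handler(model="opt", model_id="facebook/opt-125m", image_tag="ghcr.io/bentoml/openllm:latest") as handle:
    # Block until the server answers health and sanity queries, then snapshot
    # the generation output via the ResponseComparator extension.
    el.run_until_complete(handle.health())
    output = el.run_until_complete(handle.client.query("What is Deep Learning?"))
    assert output == response_snapshot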