mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-05-18 21:54:11 -04:00
chore(style): enable yapf to match with style guidelines
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -5,7 +5,6 @@ import attr, docker, docker.errors, docker.types, orjson, pytest, openllm
|
||||
from syrupy.extensions.json import JSONSnapshotExtension
|
||||
from openllm._llm import normalise_model_name
|
||||
from openllm_core._typing_compat import DictStrAny, ListAny
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -14,7 +13,6 @@ if t.TYPE_CHECKING:
|
||||
from syrupy.types import PropertyFilter, PropertyMatcher, SerializableData, SerializedData
|
||||
from openllm._configuration import GenerationConfig
|
||||
from openllm.client import BaseAsyncClient
|
||||
|
||||
class ResponseComparator(JSONSnapshotExtension):
|
||||
def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData:
|
||||
if openllm.utils.LazyType(ListAny).isinstance(data):
|
||||
@@ -52,11 +50,9 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
|
||||
|
||||
return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
|
||||
|
||||
@pytest.fixture()
def response_snapshot(snapshot: SnapshotAssertion):
  """Provide a snapshot assertion that uses ``ResponseComparator`` as its extension.

  The incoming ``snapshot`` fixture is re-bound via ``use_extension`` so that
  comparisons in tests go through ``ResponseComparator``.
  """
  comparator_snapshot = snapshot.use_extension(ResponseComparator)
  return comparator_snapshot
|
||||
|
||||
@attr.define(init=False)
|
||||
class _Handle(ABC):
|
||||
port: int
|
||||
@@ -88,7 +84,6 @@ class _Handle(ABC):
|
||||
except Exception:
|
||||
time.sleep(1)
|
||||
raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.")
|
||||
|
||||
@attr.define(init=False)
|
||||
class LocalHandle(_Handle):
|
||||
process: subprocess.Popen[bytes]
|
||||
@@ -98,12 +93,10 @@ class LocalHandle(_Handle):
|
||||
|
||||
def status(self) -> bool:
|
||||
return self.process.poll() is None
|
||||
|
||||
class HandleProtocol(t.Protocol):
  """Structural type of the handle factories consumed by the model fixtures.

  An object satisfying this protocol is called with keyword-only arguments and
  used as a context manager that yields a ``_Handle`` (see the ``*_handle``
  fixtures, which do ``with handler(model=..., model_id=..., image_tag=...) as handle``).
  """
  @contextlib.contextmanager
  def __call__(self, *, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None,) -> t.Generator[_Handle, None, None]:
    # ``self`` is required here: without it the bound instance would be passed
    # positionally into a keyword-only signature, so the protocol would not
    # describe a callable object at all.
    # NOTE(review): the concrete handle factories annotate ``quantize`` as
    # ``t.Literal["int8", "int4", "gptq"] | None``; ``t.AnyStr`` is a TypeVar and
    # is probably not what was intended here -- confirm before tightening.
    ...
|
||||
|
||||
@attr.define(init=False)
|
||||
class DockerHandle(_Handle):
|
||||
container_name: str
|
||||
@@ -115,7 +108,6 @@ class DockerHandle(_Handle):
|
||||
def status(self) -> bool:
|
||||
container = self.docker_client.containers.get(self.container_name)
|
||||
return container.status in ["running", "created"]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _local_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,):
|
||||
with openllm.utils.reserve_free_port() as port:
|
||||
@@ -136,7 +128,6 @@ def _local_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.
|
||||
proc.stdout.close()
|
||||
if proc.stderr:
|
||||
proc.stderr.close()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,):
|
||||
envvar = openllm.utils.EnvVarMixin(model)
|
||||
@@ -177,23 +168,19 @@ def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode
|
||||
print(container_output, file=sys.stderr)
|
||||
|
||||
container.remove()
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
  """Session-wide, auto-used ``ExitStack`` shared with the tests.

  Yields:
    The ``contextlib.ExitStack`` instance. It is closed once the session-scoped
    fixture is torn down, i.e. after the ``yield`` resumes.
  """
  exit_stack = contextlib.ExitStack()
  yield exit_stack
  # Teardown: close the stack, running whatever was registered on it.
  exit_stack.close()
|
||||
|
||||
@pytest.fixture(scope="module")
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
  """Provide one event loop per test module and close it on teardown.

  A fresh loop is created and installed as the current loop for each module.
  Fetching the shared default loop with ``asyncio.get_event_loop()`` and then
  closing it would leave a closed loop installed, so the next module using
  this fixture would receive an unusable loop. ``get_event_loop()`` without a
  running loop is also deprecated in recent Python versions.

  Yields:
    The newly created event loop, which is also set as the current loop.
  """
  loop = asyncio.new_event_loop()
  # Keep the loop discoverable through asyncio.get_event_loop() for code that
  # looks it up instead of receiving it from this fixture.
  asyncio.set_event_loop(loop)
  yield loop
  loop.close()
|
||||
|
||||
@pytest.fixture(params=["container", "local"], scope="session")
def deployment_mode(request: pytest.FixtureRequest) -> str:
  """Parametrised, session-scoped fixture yielding each deployment mode in turn.

  Returns:
    The current parameter value, one of ``"container"`` or ``"local"``.
  """
  mode: str = request.param
  return mode
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]):
|
||||
if deployment_mode == "container":
|
||||
|
||||
@@ -9,21 +9,17 @@ if t.TYPE_CHECKING:
|
||||
import contextlib
|
||||
|
||||
from .conftest import HandleProtocol, ResponseComparator, _Handle
|
||||
|
||||
model = "flan_t5"
|
||||
model_id = "google/flan-t5-small"
|
||||
|
||||
@pytest.fixture(scope="module")
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
  """Yield a handle for the module-level ``model`` / ``model_id`` pair.

  ``openllm.testing.prepare`` is entered first and yields the image tag, which
  is then passed to ``handler``; the handle produced by ``handler`` is yielded
  to the tests. Both context managers are exited (inner first) on teardown.
  """
  # A single ``with`` enters and exits the managers in the same order as the
  # equivalent nested form.
  with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag, handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
    yield handle
|
||||
|
||||
@pytest.fixture(scope="module")
async def flan_t5(flan_t5_handle: _Handle):
  """Await the handle's health check, then return the handle's client.

  This is a coroutine function, and the consuming test awaits the fixture
  value itself (``client = await flan_t5``).
  """
  # Presumably a timeout in seconds for the health check -- confirm against _Handle.health.
  health_wait = 240
  await flan_t5_handle.health(health_wait)
  return flan_t5_handle.client
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
|
||||
client = await flan_t5
|
||||
|
||||
@@ -9,21 +9,17 @@ if t.TYPE_CHECKING:
|
||||
import contextlib
|
||||
|
||||
from .conftest import HandleProtocol, ResponseComparator, _Handle
|
||||
|
||||
model = "opt"
|
||||
model_id = "facebook/opt-125m"
|
||||
|
||||
@pytest.fixture(scope="module")
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
  """Yield a handle for the module-level ``model`` / ``model_id`` pair.

  ``openllm.testing.prepare`` is entered first and yields the image tag, which
  is then passed to ``handler``; the handle produced by ``handler`` is yielded
  to the tests. Both context managers are exited (inner first) on teardown.
  """
  # A single ``with`` enters and exits the managers in the same order as the
  # equivalent nested form.
  with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag, handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
    yield handle
|
||||
|
||||
@pytest.fixture(scope="module")
async def opt_125m(opt_125m_handle: _Handle):
  """Await the handle's health check, then return the handle's client.

  This is a coroutine function, and the consuming test awaits the fixture
  value itself (``client = await opt_125m``).
  """
  # Presumably a timeout in seconds for the health check -- confirm against _Handle.health.
  health_wait = 240
  await opt_125m_handle.health(health_wait)
  return opt_125m_handle.client
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
|
||||
client = await opt_125m
|
||||
|
||||
Reference in New Issue
Block a user