From cf4e55c36fa69d4e0a05095f747d93d6f9655d4d Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Sat, 27 May 2023 05:20:07 -0700 Subject: [PATCH] fix(client): implement per client framework and model_name getters Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- src/openllm_client/runtimes/base.py | 13 +++++-------- src/openllm_client/runtimes/grpc.py | 26 ++++++++++++++++++++++++-- src/openllm_client/runtimes/http.py | 22 ++++++++++++++++++++-- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py index e36b0b09..d32066f4 100644 --- a/src/openllm_client/runtimes/base.py +++ b/src/openllm_client/runtimes/base.py @@ -15,6 +15,7 @@ from __future__ import annotations import typing as t +from abc import abstractmethod import bentoml @@ -64,18 +65,14 @@ class ClientMixin: cls._api_version = api_version @property + @abstractmethod def model_name(self) -> str: - try: - return self._metadata["model_name"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise NotImplementedError @property + @abstractmethod def framework(self) -> t.Literal["pt", "flax", "tf"]: - try: - return self._metadata["framework"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise NotImplementedError @property def llm(self) -> openllm.LLM: diff --git a/src/openllm_client/runtimes/grpc.py b/src/openllm_client/runtimes/grpc.py index e3411486..dfd49b13 100644 --- a/src/openllm_client/runtimes/grpc.py +++ b/src/openllm_client/runtimes/grpc.py @@ -22,11 +22,33 @@ from .base import BaseAsyncClient, BaseClient if t.TYPE_CHECKING: import grpc_health.v1.health_pb2 as health_pb2 + from bentoml.grpc.v1.service_pb2 import Response logger = logging.getLogger(__name__) -class GrpcClient(BaseClient, client_type="grpc"): +class GrpcClientMixin: + _metadata: Response + + @property + def model_name(self) -> str: + try: + return self._metadata.json.struct_value.fields["model_name"].string_value + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") + + @property + def framework(self) -> t.Literal["pt", "flax", "tf"]: + try: + value = self._metadata.json.struct_value.fields["framework"].string_value + if value not in ("pt", "flax", "tf"): + raise KeyError + return value + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") + + +class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"): def __init__(self, address: str, timeout: int = 30): self._host, self._port = address.split(":") super().__init__(address, timeout) @@ -35,7 +57,7 @@ class GrpcClient(BaseClient, client_type="grpc"): return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService")) -class AsyncGrpcClient(BaseAsyncClient, client_type="grpc"): +class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient, client_type="grpc"): def __init__(self, address: str, timeout: int = 30): self._host, self._port = address.split(":") super().__init__(address, timeout) diff --git a/src/openllm_client/runtimes/http.py b/src/openllm_client/runtimes/http.py index d2cbe8dc..f2c47caf 100644 --- a/src/openllm_client/runtimes/http.py +++ b/src/openllm_client/runtimes/http.py @@ -23,7 +23,25 @@ from .base import BaseAsyncClient, BaseClient logger = logging.getLogger(__name__) -class HTTPClient(BaseClient): +class HTTPClientMixin: + _metadata: dict[str, t.Any] + + @property + def model_name(self) -> str: + try: + return self._metadata["model_name"] + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") + + @property + def framework(self) -> t.Literal["pt", "flax", "tf"]: + try: + return self._metadata["framework"] + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") + + +class HTTPClient(HTTPClientMixin, BaseClient): def __init__(self, address: str, timeout: int = 30): address = address if "://" in address else "http://" + address self._host, self._port = urlparse(address).netloc.split(":") @@ -33,7 +51,7 @@ class HTTPClient(BaseClient): return self._cached.health() -class AsyncHTTPClient(BaseAsyncClient): +class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient): def __init__(self, address: str, timeout: int = 30): address = address if "://" in address else "http://" + address self._host, self._port = urlparse(address).netloc.split(":")