Move the `model_name` / `framework` properties out of `ClientMixin` (where
they become abstract) into transport-specific mixins: `GrpcClientMixin`
reads them from the protobuf `Response`, `HTTPClientMixin` from the
metadata dict. All four client classes pick up the matching mixin.

Review note: `GrpcClientMixin.model_name` checks key presence explicitly,
because indexing a protobuf map with a missing key inserts a default entry
rather than raising KeyError. The grpc.py hunk counts reflect this change;
the post-image blob id on its `index` line does not.

diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py
index e36b0b09..d32066f4 100644
--- a/src/openllm_client/runtimes/base.py
+++ b/src/openllm_client/runtimes/base.py
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 import typing as t
+from abc import abstractmethod
 
 import bentoml
 
@@ -64,18 +65,14 @@ class ClientMixin:
         cls._api_version = api_version
 
     @property
+    @abstractmethod
     def model_name(self) -> str:
-        try:
-            return self._metadata["model_name"]
-        except KeyError:
-            raise RuntimeError("Malformed service endpoint. (Possible malicious)")
+        raise NotImplementedError
 
     @property
+    @abstractmethod
     def framework(self) -> t.Literal["pt", "flax", "tf"]:
-        try:
-            return self._metadata["framework"]
-        except KeyError:
-            raise RuntimeError("Malformed service endpoint. (Possible malicious)")
+        raise NotImplementedError
 
     @property
     def llm(self) -> openllm.LLM:
diff --git a/src/openllm_client/runtimes/grpc.py b/src/openllm_client/runtimes/grpc.py
index e3411486..dfd49b13 100644
--- a/src/openllm_client/runtimes/grpc.py
+++ b/src/openllm_client/runtimes/grpc.py
@@ -22,11 +22,35 @@ from .base import BaseAsyncClient, BaseClient
 
 if t.TYPE_CHECKING:
     import grpc_health.v1.health_pb2 as health_pb2
+    from bentoml.grpc.v1.service_pb2 import Response
 
 logger = logging.getLogger(__name__)
 
 
-class GrpcClient(BaseClient, client_type="grpc"):
+class GrpcClientMixin:
+    _metadata: Response
+
+    @property
+    def model_name(self) -> str:
+        # Protobuf map fields never raise KeyError: indexing a missing key
+        # inserts a default entry, so presence must be checked explicitly.
+        fields = self._metadata.json.struct_value.fields
+        if "model_name" not in fields:
+            raise RuntimeError("Malformed service endpoint. (Possible malicious)")
+        return fields["model_name"].string_value
+
+    @property
+    def framework(self) -> t.Literal["pt", "flax", "tf"]:
+        try:
+            value = self._metadata.json.struct_value.fields["framework"].string_value
+            if value not in ("pt", "flax", "tf"):
+                raise KeyError
+            return value
+        except KeyError:
+            raise RuntimeError("Malformed service endpoint. (Possible malicious)")
+
+
+class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"):
     def __init__(self, address: str, timeout: int = 30):
         self._host, self._port = address.split(":")
         super().__init__(address, timeout)
@@ -35,7 +59,7 @@ class GrpcClient(BaseClient, client_type="grpc"):
         return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
 
 
-class AsyncGrpcClient(BaseAsyncClient, client_type="grpc"):
+class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient, client_type="grpc"):
     def __init__(self, address: str, timeout: int = 30):
         self._host, self._port = address.split(":")
         super().__init__(address, timeout)
diff --git a/src/openllm_client/runtimes/http.py b/src/openllm_client/runtimes/http.py
index d2cbe8dc..f2c47caf 100644
--- a/src/openllm_client/runtimes/http.py
+++ b/src/openllm_client/runtimes/http.py
@@ -23,7 +23,25 @@ from .base import BaseAsyncClient, BaseClient
 logger = logging.getLogger(__name__)
 
 
-class HTTPClient(BaseClient):
+class HTTPClientMixin:
+    _metadata: dict[str, t.Any]
+
+    @property
+    def model_name(self) -> str:
+        try:
+            return self._metadata["model_name"]
+        except KeyError:
+            raise RuntimeError("Malformed service endpoint. (Possible malicious)")
+
+    @property
+    def framework(self) -> t.Literal["pt", "flax", "tf"]:
+        try:
+            return self._metadata["framework"]
+        except KeyError:
+            raise RuntimeError("Malformed service endpoint. (Possible malicious)")
+
+
+class HTTPClient(HTTPClientMixin, BaseClient):
     def __init__(self, address: str, timeout: int = 30):
         address = address if "://" in address else "http://" + address
         self._host, self._port = urlparse(address).netloc.split(":")
@@ -33,7 +51,7 @@ class HTTPClient(BaseClient):
         return self._cached.health()
 
 
-class AsyncHTTPClient(BaseAsyncClient):
+class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient):
     def __init__(self, address: str, timeout: int = 30):
         address = address if "://" in address else "http://" + address
         self._host, self._port = urlparse(address).netloc.split(":")