fix(client): implement per client framework and model_name getters

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-05-27 05:20:07 -07:00
parent aa72c16651
commit cf4e55c36f
3 changed files with 49 additions and 12 deletions

View File

@@ -15,6 +15,7 @@
from __future__ import annotations
import typing as t
from abc import abstractmethod
import bentoml
@@ -64,18 +65,14 @@ class ClientMixin:
cls._api_version = api_version
@property
@abstractmethod
def model_name(self) -> str:
try:
return self._metadata["model_name"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise NotImplementedError
@property
@abstractmethod
def framework(self) -> t.Literal["pt", "flax", "tf"]:
try:
return self._metadata["framework"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
raise NotImplementedError
@property
def llm(self) -> openllm.LLM:

View File

@@ -22,11 +22,33 @@ from .base import BaseAsyncClient, BaseClient
if t.TYPE_CHECKING:
import grpc_health.v1.health_pb2 as health_pb2
from bentoml.grpc.v1.service_pb2 import Response
logger = logging.getLogger(__name__)
class GrpcClient(BaseClient, client_type="grpc"):
class GrpcClientMixin:
_metadata: Response
@property
def model_name(self) -> str:
try:
return self._metadata.json.struct_value.fields["model_name"].string_value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
@property
def framework(self) -> t.Literal["pt", "flax", "tf"]:
try:
value = self._metadata.json.struct_value.fields["framework"].string_value
if value not in ("pt", "flax", "tf"):
raise KeyError
return value
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"):
def __init__(self, address: str, timeout: int = 30):
self._host, self._port = address.split(":")
super().__init__(address, timeout)
@@ -35,7 +57,7 @@ class GrpcClient(BaseClient, client_type="grpc"):
return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
class AsyncGrpcClient(BaseAsyncClient, client_type="grpc"):
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient, client_type="grpc"):
def __init__(self, address: str, timeout: int = 30):
self._host, self._port = address.split(":")
super().__init__(address, timeout)

View File

@@ -23,7 +23,25 @@ from .base import BaseAsyncClient, BaseClient
logger = logging.getLogger(__name__)
class HTTPClient(BaseClient):
class HTTPClientMixin:
_metadata: dict[str, t.Any]
@property
def model_name(self) -> str:
try:
return self._metadata["model_name"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
@property
def framework(self) -> t.Literal["pt", "flax", "tf"]:
try:
return self._metadata["framework"]
except KeyError:
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
class HTTPClient(HTTPClientMixin, BaseClient):
def __init__(self, address: str, timeout: int = 30):
address = address if "://" in address else "http://" + address
self._host, self._port = urlparse(address).netloc.split(":")
@@ -33,7 +51,7 @@ class HTTPClient(BaseClient):
return self._cached.health()
class AsyncHTTPClient(BaseAsyncClient):
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient):
def __init__(self, address: str, timeout: int = 30):
address = address if "://" in address else "http://" + address
self._host, self._port = urlparse(address).netloc.split(":")