mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-01-23 15:01:32 -05:00
fix(client): implement per client framework and model_name getters
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from abc import abstractmethod
|
||||
|
||||
import bentoml
|
||||
|
||||
@@ -64,18 +65,14 @@ class ClientMixin:
|
||||
cls._api_version = api_version
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_name(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_name"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def framework(self) -> t.Literal["pt", "flax", "tf"]:
|
||||
try:
|
||||
return self._metadata["framework"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def llm(self) -> openllm.LLM:
|
||||
|
||||
@@ -22,11 +22,33 @@ from .base import BaseAsyncClient, BaseClient
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import grpc_health.v1.health_pb2 as health_pb2
|
||||
from bentoml.grpc.v1.service_pb2 import Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GrpcClient(BaseClient, client_type="grpc"):
|
||||
class GrpcClientMixin:
|
||||
_metadata: Response
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try:
|
||||
return self._metadata.json.struct_value.fields["model_name"].string_value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
|
||||
@property
|
||||
def framework(self) -> t.Literal["pt", "flax", "tf"]:
|
||||
try:
|
||||
value = self._metadata.json.struct_value.fields["framework"].string_value
|
||||
if value not in ("pt", "flax", "tf"):
|
||||
raise KeyError
|
||||
return value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
|
||||
|
||||
class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(":")
|
||||
super().__init__(address, timeout)
|
||||
@@ -35,7 +57,7 @@ class GrpcClient(BaseClient, client_type="grpc"):
|
||||
return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
|
||||
|
||||
|
||||
class AsyncGrpcClient(BaseAsyncClient, client_type="grpc"):
|
||||
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient, client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(":")
|
||||
super().__init__(address, timeout)
|
||||
|
||||
@@ -23,7 +23,25 @@ from .base import BaseAsyncClient, BaseClient
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HTTPClient(BaseClient):
|
||||
class HTTPClientMixin:
|
||||
_metadata: dict[str, t.Any]
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_name"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
|
||||
@property
|
||||
def framework(self) -> t.Literal["pt", "flax", "tf"]:
|
||||
try:
|
||||
return self._metadata["framework"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
|
||||
|
||||
class HTTPClient(HTTPClientMixin, BaseClient):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
address = address if "://" in address else "http://" + address
|
||||
self._host, self._port = urlparse(address).netloc.split(":")
|
||||
@@ -33,7 +51,7 @@ class HTTPClient(BaseClient):
|
||||
return self._cached.health()
|
||||
|
||||
|
||||
class AsyncHTTPClient(BaseAsyncClient):
|
||||
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
address = address if "://" in address else "http://" + address
|
||||
self._host, self._port = urlparse(address).netloc.split(":")
|
||||
|
||||
Reference in New Issue
Block a user