"""Synchronous and asynchronous HTTP clients for querying an OpenLLM server."""

from __future__ import annotations

import importlib.metadata
import logging
import os
import typing as t

import attr

from ._schemas import Helpers, Metadata, Response, StreamingResponse
from ._shim import MAX_RETRIES, AsyncClient, Client

logger = logging.getLogger(__name__)

VERSION = importlib.metadata.version('openllm-client')


def _address_converter(addr: str) -> str:
  # Default to HTTP when the address carries no explicit scheme.
  return addr if '://' in addr else 'http://' + addr
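
# Illustrative behavior of the converter above:
#   _address_converter('localhost:3000')     -> 'http://localhost:3000'
#   _address_converter('https://host:8443')  -> 'https://host:8443'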


@attr.define(init=False)
class HTTPClient(Client):
  helpers: Helpers = attr.field(init=False)
  _api_version: str = 'v1'
  _verify: bool = True
  __metadata: Metadata | None = None
  __config: dict[str, t.Any] | None = None

  def __repr__(self):
    return f'<HTTPClient address={self.address} timeout={self._timeout} api_version={self._api_version} verify={self._verify}>'

  def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'):
    if address is None:
      address = os.getenv('OPENLLM_ENDPOINT')
      if address is None:
        raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'")
    self._api_version, self._verify = api_version, verify

    self.helpers = Helpers(client=self)

    super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries)

  def _build_auth_headers(self) -> t.Dict[str, str]:
    # A bearer token from the environment takes precedence over the default headers.
    env = os.getenv('OPENLLM_AUTH_TOKEN')
    if env is not None:
      return {'Authorization': f'Bearer {env}'}
    return super()._build_auth_headers()

  def __getitem__(self, item):
    # Look up the item on the server metadata first, then fall back to the model configuration.
    if hasattr(self._metadata, item):
      return getattr(self._metadata, item)
    elif item in self._config:
      return self._config[item]
    raise KeyError(f'No attributes: {item}')

  @property
  def _metadata(self):
    # Lazily fetch and cache the server metadata.
    if self.__metadata is None:
      path = f'/{self._api_version}/metadata'
      self.__metadata = self._post(path, response_cls=Metadata, json={}, options={'max_retries': self._max_retries})
    return self.__metadata

  @property
  def _config(self) -> dict[str, t.Any]:
    if self.__config is None:
      self.__config = self._metadata.configuration
    return self.__config

  def query(self, prompt, **attrs):
    # Alias for `generate`.
    return self.generate(prompt, **attrs)

  def health(self):
    response = self._get(
      '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
    )
    return response.status_code == 200

  def generate(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> Response:
    if timeout is None:
      timeout = self._timeout
    if verify is None:
      verify = self._verify  # XXX: need to support this again
    # Merge server-side defaults with any per-call overrides.
    if llm_config is not None:
      llm_config = {**self._config, **llm_config, **attrs}
    else:
      llm_config = {**self._config, **attrs}

    return self._post(
      f'/{self._api_version}/generate',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      options={'max_retries': self._max_retries},
    )

  def generate_stream(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.Iterator[StreamingResponse]:
    # Wrap each raw chunk from `generate_iterator` in a StreamingResponse.
    for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs):
      yield StreamingResponse.from_response_chunk(response_chunk)

  def generate_iterator(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.Iterator[Response]:
    if timeout is None:
      timeout = self._timeout
    if verify is None:
      verify = self._verify  # XXX: need to support this again
    if llm_config is not None:
      llm_config = {**self._config, **llm_config, **attrs}
    else:
      llm_config = {**self._config, **attrs}
    return self._post(
      f'/{self._api_version}/generate_stream',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      options={'max_retries': self._max_retries},
      stream=True,
    )
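
# Minimal usage sketch (kept as a comment so the module stays import-only). It
# assumes a running OpenLLM server; the address 'localhost:3000' and the
# 'max_new_tokens' override are illustrative, not guaranteed by this module:
#
#   client = HTTPClient('localhost:3000')
#   if client.health():
#     res = client.generate('What is a bento box?', max_new_tokens=128)
#     for chunk in client.generate_stream('What is a bento box?'):
#       ...  # each chunk is a StreamingResponse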


@attr.define(init=False)
class AsyncHTTPClient(AsyncClient):
  helpers: Helpers = attr.field(init=False)
  _api_version: str = 'v1'
  _verify: bool = True
  __metadata: Metadata | None = None
  __config: dict[str, t.Any] | None = None

  def __repr__(self):
    return f'<AsyncHTTPClient address={self.address} timeout={self._timeout} api_version={self._api_version} verify={self._verify}>'

  def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'):
    if address is None:
      address = os.getenv('OPENLLM_ENDPOINT')
      if address is None:
        raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'")
    self._api_version, self._verify = api_version, verify

    # Make the message helpers async here.
    self.helpers = Helpers.permute(messages=Helpers.async_messages)(async_client=self)

    super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries)

  def _build_auth_headers(self) -> t.Dict[str, str]:
    # A bearer token from the environment takes precedence over the default headers.
    env = os.getenv('OPENLLM_AUTH_TOKEN')
    if env is not None:
      return {'Authorization': f'Bearer {env}'}
    return super()._build_auth_headers()

  @property
  async def _metadata(self) -> Metadata:
    # Lazily fetch and cache the server metadata.
    if self.__metadata is None:
      self.__metadata = await self._post(
        f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries}
      )
    return self.__metadata

  @property
  async def _config(self) -> dict[str, t.Any]:
    if self.__config is None:
      self.__config = (await self._metadata).configuration
    return self.__config

  async def query(self, prompt, **attrs):
    # Alias for `generate`.
    return await self.generate(prompt, **attrs)

  async def health(self):
    response = await self._get(
      '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
    )
    return response.status_code == 200

  async def generate(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> Response:
    if timeout is None:
      timeout = self._timeout
    if verify is None:
      verify = self._verify  # XXX: need to support this again
    _config = await self._config  # awaiting the config also populates the metadata cache
    # Merge server-side defaults with any per-call overrides.
    if llm_config is not None:
      llm_config = {**_config, **llm_config, **attrs}
    else:
      llm_config = {**_config, **attrs}
    return await self._post(
      f'/{self._api_version}/generate',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      options={'max_retries': self._max_retries},
    )

  async def generate_stream(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.AsyncGenerator[StreamingResponse, t.Any]:
    # Wrap each raw chunk from `generate_iterator` in a StreamingResponse.
    async for response_chunk in self.generate_iterator(
      prompt, llm_config, stop, adapter_name, timeout, verify, **attrs
    ):
      yield StreamingResponse.from_response_chunk(response_chunk)

  async def generate_iterator(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.AsyncGenerator[Response, t.Any]:
    if timeout is None:
      timeout = self._timeout
    if verify is None:
      verify = self._verify  # XXX: need to support this again
    _config = await self._config  # awaiting the config also populates the metadata cache
    if llm_config is not None:
      llm_config = {**_config, **llm_config, **attrs}
    else:
      llm_config = {**_config, **attrs}

    async for response_chunk in await self._post(
      f'/{self._api_version}/generate_stream',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      options={'max_retries': self._max_retries},
      stream=True,
    ):
      yield response_chunk
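
# Async usage sketch (comment only), under the same assumptions as the
# synchronous example above:
#
#   import asyncio
#
#   async def main():
#     client = AsyncHTTPClient('localhost:3000')
#     if await client.health():
#       res = await client.generate('What is a bento box?')
#       async for chunk in client.generate_stream('What is a bento box?'):
#         ...
#
#   asyncio.run(main())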