mirror of https://github.com/bentoml/OpenLLM.git

feat(client): support authentication token and shim implementation (#605)

* chore: sync generate_iterator to be the same as server
* --wip--
* wip
* feat: cleanup shim implementation
* ci: auto fixes from pre-commit.ci (for more information, see https://pre-commit.ci)
* chore: fix pre-commit
* chore: update changelog
* chore: update check with tuple

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -202,11 +202,8 @@ After you change or update any CI related under `.github`, run `bash tools/lock-
See the [INFRA docs](/.github/INFRA.md) for more information on OpenLLM's CI/CD workflow.

## UI

See [ClojureScript UI's README.md](/external/clojure/README.md) for more information.

See [Documentation's README.md](/docs/README.md) for more information on running locally.
## Typing

For all internal functions, it is recommended to provide type hints. For all public function definitions, it is recommended to create a stub file (`.pyi`) that separates the supported external API from the implementation, to improve code visibility. See [openllm-client's `__init__.pyi`](/openllm-client/src/openllm_client/__init__.pyi) for an example.
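As a minimal sketch of this convention (the `add` function and file names are hypothetical, not from the repo), the `.py` module holds the implementation while the `.pyi` stub exposes only the typed public surface:

```python
# _math.pyi — type stubs only; the matching implementation lives in _math.py
from typing import overload

@overload
def add(a: int, b: int) -> int: ...  # hypothetical example function
@overload
def add(a: float, b: float) -> float: ...
```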
## Install from git archive
changelog.d/605.change.md (new file, 1 line)

@@ -0,0 +1 @@
Update client implementation and support authentication through `OPENLLM_AUTH_TOKEN`
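As a minimal usage sketch (the endpoint URL and token values are illustrative, and a running OpenLLM server is assumed), the client reads both variables from the environment and sends the token as a `Bearer` header on every request:

```python
import os

from openllm_client import HTTPClient

# Illustrative values; both variables are read by the client at construction time.
os.environ['OPENLLM_ENDPOINT'] = 'http://localhost:3000'
os.environ['OPENLLM_AUTH_TOKEN'] = 'my-secret-token'  # sent as 'Authorization: Bearer <token>'

client = HTTPClient()  # address falls back to OPENLLM_ENDPOINT when omitted
print(client.generate('What is the meaning of life?'))
```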
@@ -1,6 +1,7 @@
# NOTE: Make sure to install openai>1
import os, openai
import os, openai, typing as t
from openai.types.chat import (
  ChatCompletionMessageParam,
  ChatCompletionSystemMessageParam,
  ChatCompletionUserMessageParam,
  ChatCompletionAssistantMessageParam,

@@ -14,7 +15,7 @@ model = models.data[0].id

# Chat completion API
stream = str(os.getenv('STREAM', False)).upper() in ['TRUE', '1', 'YES', 'Y', 'ON']
messages = [
messages: t.List[ChatCompletionMessageParam] = [
  ChatCompletionSystemMessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
  ChatCompletionUserMessageParam(role='user', content='Hi there!'),
  ChatCompletionAssistantMessageParam(role='assistant', content='Yes?'),
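For context, a minimal sketch of how this example reaches the `models.data[0].id` it indexes (the base URL and API key are illustrative; the example targets OpenLLM's OpenAI-compatible `/v1` endpoint):

```python
import openai

# Illustrative connection values; any non-empty api_key works against a local server.
client = openai.OpenAI(base_url='http://localhost:3000/v1', api_key='na')
models = client.models.list()
model = models.data[0].id  # the example above picks the first served model
print(model)
```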
@@ -57,7 +57,7 @@ keywords = [
  "PyTorch",
  "Transformers",
]
dependencies = ["orjson", "httpx", "attrs>=23.1.0", "cattrs>=23.1.0"]
dependencies = ["orjson", "httpx", "attrs>=23.1.0", "cattrs>=23.1.0", 'distro', 'anyio']
license = "Apache-2.0"
name = "openllm-client"
requires-python = ">=3.8"

@@ -71,8 +71,9 @@ Homepage = "https://bentoml.com"
Tracker = "https://github.com/bentoml/OpenLLM/issues"
Twitter = "https://twitter.com/bentomlai"
[project.optional-dependencies]
full = ["openllm-client[grpc,agents]"]
grpc = ["bentoml[grpc]>=1.1.6"]
full = ["openllm-client[grpc,agents,auth]"]
grpc = ["bentoml[grpc]>=1.1.9"]
auth = ['httpx_auth']
agents = ["transformers[agents]>=4.30", "diffusers", "soundfile"]

[tool.hatch.version]
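A small sketch of what the new `auth` extra implies at runtime (the probe below is an assumption for illustration, not repo code; the extra simply pins `httpx_auth`):

```python
import importlib.util

# Hypothetical availability check for the optional dependency.
if importlib.util.find_spec('httpx_auth') is None:
  raise SystemExit("install the extra first: pip install 'openllm-client[auth]'")
print('httpx_auth is available')
```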
@@ -15,26 +15,21 @@ from ._schemas import StreamingResponse as _StreamingResponse
@_attr.define
class HTTPClient:
  address: str
  client_args: Dict[str, Any]
  @staticmethod
  def wait_until_server_ready(
    addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any
  @overload
  def __init__(
    self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
  ) -> None: ...
  @overload
  def __init__(
    self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
    self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
  ) -> None: ...
  @overload
  def __init__(
    self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
  ) -> None: ...
  @overload
  def __init__(
    self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
    self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
  ) -> None: ...
  @property
  def is_ready(self) -> bool: ...
  def health(self) -> None: ...
  def health(self) -> bool: ...
  def query(self, prompt: str, **attrs: Any) -> _Response: ...
  def generate(
    self,

@@ -46,6 +41,16 @@ class HTTPClient:
    verify: Optional[bool] = ...,
    **attrs: Any,
  ) -> _Response: ...
  def generate_iterator(
    self,
    prompt: str,
    llm_config: Optional[Dict[str, Any]] = ...,
    stop: Optional[Union[str, List[str]]] = ...,
    adapter_name: Optional[str] = ...,
    timeout: Optional[int] = ...,
    verify: Optional[bool] = ...,
    **attrs: Any,
  ) -> Iterator[_Response]: ...
  def generate_stream(
    self,
    prompt: str,

@@ -60,26 +65,21 @@ class HTTPClient:
@_attr.define
class AsyncHTTPClient:
  address: str
  client_args: Dict[str, Any]
  @staticmethod
  async def wait_until_server_ready(
    addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any
  @overload
  def __init__(
    self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
  ) -> None: ...
  @overload
  def __init__(
    self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
    self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
  ) -> None: ...
  @overload
  def __init__(
    self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
  ) -> None: ...
  @overload
  def __init__(
    self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
    self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
  ) -> None: ...
  @property
  def is_ready(self) -> bool: ...
  async def health(self) -> None: ...
  async def health(self) -> bool: ...
  async def query(self, prompt: str, **attrs: Any) -> _Response: ...
  async def generate(
    self,

@@ -91,6 +91,16 @@ class AsyncHTTPClient:
    verify: Optional[bool] = ...,
    **attrs: Any,
  ) -> _Response: ...
  async def generate_iterator(
    self,
    prompt: str,
    llm_config: Optional[Dict[str, Any]] = ...,
    stop: Optional[Union[str, List[str]]] = ...,
    adapter_name: Optional[str] = ...,
    timeout: Optional[int] = ...,
    verify: Optional[bool] = ...,
    **attrs: Any,
  ) -> Iterator[_Response]: ...
  async def generate_stream(
    self,
    prompt: str,
@@ -1,349 +1,240 @@
from __future__ import annotations
import asyncio
import enum
import importlib.metadata
import logging
import os
import time
import typing as t
import urllib.error

from urllib.parse import urlparse

import attr
import httpx
import orjson

from ._schemas import Request
from ._schemas import MetadataOutput
from ._schemas import Response
from ._schemas import StreamingResponse
from ._shim import MAX_RETRIES
from ._shim import AsyncClient
from ._shim import Client

logger = logging.getLogger(__name__)

def _address_validator(_, attr, value):
  if not isinstance(value, str):
    raise TypeError(f'{attr.name} must be a string')
  if not urlparse(value).netloc:
    raise ValueError(f'{attr.name} must be a valid URL')
VERSION = importlib.metadata.version('openllm-client')

def _address_converter(addr: str):
  return addr if '://' in addr else 'http://' + addr

class ServerState(enum.Enum):
class ClientState(enum.Enum):
  CLOSED = 1  # CLOSED: The server is not yet ready or `wait_until_server_ready` has not been called/failed.
  READY = 2  # READY: The server is ready and `wait_until_server_ready` has been called.

_object_setattr = object.__setattr__
  DISCONNECTED = 3  # DISCONNECTED: The server is disconnected and `wait_until_server_ready` has been called.
@attr.define(init=False)
class HTTPClient:
  address: str = attr.field(validator=_address_validator, converter=_address_converter)
  client_args: t.Dict[str, t.Any]

  _inner: httpx.Client
  _timeout: int = 30
class HTTPClient(Client):
  _api_version: str = 'v1'
  _verify: bool = True
  _state: ServerState = ServerState.CLOSED

  __metadata: dict[str, t.Any] | None = None
  __metadata: MetadataOutput | None = None
  __config: dict[str, t.Any] | None = None

  def __repr__(self):
    return (
      f'<HTTPClient timeout={self._timeout} api_version={self._api_version} verify={self._verify} state={self._state}>'
    )
    return f'<HTTPClient address={self.address} timeout={self._timeout} api_version={self._api_version} verify={self._verify}>'

  @staticmethod
  def wait_until_server_ready(addr, timeout=30, verify=False, check_interval=1, **client_args):
    addr = _address_converter(addr)
    logger.debug('Wait for server @ %s to be ready', addr)
    start = time.monotonic()
    while time.monotonic() - start < timeout:
      try:
        with httpx.Client(base_url=addr, verify=verify, **client_args) as sess:
          status = sess.get('/readyz').status_code
        if status == 200:
          break
        else:
          time.sleep(check_interval)
      except (httpx.ConnectError, urllib.error.URLError, ConnectionError):
        logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval)
        time.sleep(check_interval)
    # Try once more and raise for exception
    try:
      with httpx.Client(base_url=addr, verify=verify, **client_args) as sess:
        status = sess.get('/readyz').status_code
    except httpx.HTTPStatusError as err:
      logger.error('Failed to wait until server ready: %s', addr)
      logger.error(err)
      raise

  def __init__(self, address=None, timeout=30, verify=False, api_version='v1', **client_args):
  def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'):
    if address is None:
      env = os.getenv('OPENLLM_ENDPOINT')
      if env is None:
        raise ValueError('address must be provided')
      address = env
    self.__attrs_init__(
      address,
      client_args,
      httpx.Client(base_url=address, timeout=timeout, verify=verify, **client_args),
      timeout,
      api_version,
      verify,
    )
      address = os.getenv('OPENLLM_ENDPOINT')
      if address is None:
        raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'")
    self._api_version, self._verify = api_version, verify
    super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries)

  def _metadata(self) -> dict[str, t.Any]:
    if self.__metadata is None:
      self.__metadata = self._inner.post(self._build_endpoint('metadata')).json()
    return self.__metadata

  def _config(self) -> dict[str, t.Any]:
    if self.__config is None:
      config = orjson.loads(self._metadata()['configuration'])
      generation_config = config.pop('generation_config')
      self.__config = {**config, **generation_config}
    return self.__config
  def _build_auth_headers(self) -> t.Dict[str, str]:
    env = os.getenv('OPENLLM_AUTH_TOKEN')
    if env is not None:
      return {'Authorization': f'Bearer {env}'}
    return super()._build_auth_headers()

  def __del__(self):
    self._inner.close()

  def _build_endpoint(self, endpoint):
    return ('/' if not self._api_version.startswith('/') else '') + f'{self._api_version}/{endpoint}'
    self.close()

  @property
  def is_ready(self):
    return self._state == ServerState.READY
  def _metadata(self):
    if self.__metadata is None:
      path = f'/{self._api_version}/metadata'
      self.__metadata = self._post(
        path, response_cls=MetadataOutput, json={}, options={'max_retries': self._max_retries}
      )
    return self.__metadata

  @property
  def _config(self) -> dict[str, t.Any]:
    if self.__config is None:
      self.__config = self._metadata.configuration
    return self.__config

  def query(self, prompt, **attrs):
    return self.generate(prompt, **attrs)

  def health(self):
    try:
      self.wait_until_server_ready(self.address, timeout=self._timeout, verify=self._verify, **self.client_args)
      _object_setattr(self, '_state', ServerState.READY)
    except Exception as e:
      logger.error('Server is not healthy (Scroll up for traceback)\n%s', e)
      _object_setattr(self, '_state', ServerState.CLOSED)
    response = self._get(
      '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
    )
    return response.status_code == 200

  def generate(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> Response:
    if not self.is_ready:
      self.health()
      if not self.is_ready:
        raise RuntimeError('Server is not ready. Check server logs for more information.')
    if timeout is None:
      timeout = self._timeout
    if verify is None:
      verify = self._verify
    _meta, _config = self._metadata(), self._config()
    verify = self._verify  # XXX: need to support this again
    if llm_config is not None:
      llm_config = {**_config, **llm_config, **attrs}
      llm_config = {**self._config, **llm_config, **attrs}
    else:
      llm_config = {**_config, **attrs}
    if _meta['prompt_template'] is not None:
      prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
      llm_config = {**self._config, **attrs}
    if self._metadata.prompt_template is not None:
      prompt = self._metadata.prompt_template.format(system_message=self._metadata.system_message, instruction=prompt)

    req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
    with httpx.Client(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
      r = client.post(self._build_endpoint('generate'), json=req.model_dump_json(), **self.client_args)
      if r.status_code != 200:
        raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.")
      return Response.model_construct(r.json())
    return self._post(
      f'/{self._api_version}/generate',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
    )

  def generate_stream(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.Iterator[StreamingResponse]:
    if not self.is_ready:
      self.health()
      if not self.is_ready:
        raise RuntimeError('Server is not ready. Check server logs for more information.')
    for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs):
      yield StreamingResponse.from_response_chunk(response_chunk)
  def generate_iterator(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.Iterator[Response]:
    if timeout is None:
      timeout = self._timeout
    if verify is None:
      verify = self._verify
    _meta, _config = self._metadata(), self._config()
    verify = self._verify  # XXX: need to support this again
    if llm_config is not None:
      llm_config = {**_config, **llm_config, **attrs}
      llm_config = {**self._config, **llm_config, **attrs}
    else:
      llm_config = {**_config, **attrs}
    if _meta['prompt_template'] is not None:
      prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)

    req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
    with httpx.Client(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
      with client.stream(
        'POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args
      ) as r:
        for payload in r.iter_bytes():
          if payload == b'data: [DONE]\n\n':
            break
          # Skip line
          if payload == b'\n':
            continue
          if payload.startswith(b'data: '):
            try:
              proc = payload.decode('utf-8').lstrip('data: ').rstrip('\n')
              data = orjson.loads(proc)
              yield StreamingResponse.from_response_chunk(Response.model_construct(data))
            except Exception:
              pass  # FIXME: Handle this
      llm_config = {**self._config, **attrs}
    if self._metadata.prompt_template is not None:
      prompt = self._metadata.prompt_template.format(system_message=self._metadata.system_message, instruction=prompt)
    return self._post(
      f'/{self._api_version}/generate_stream',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      stream=True,
    )
@attr.define(init=False)
class AsyncHTTPClient:
  address: str = attr.field(validator=_address_validator, converter=_address_converter)
  client_args: t.Dict[str, t.Any]

  _inner: httpx.AsyncClient
  _timeout: int = 30
class AsyncHTTPClient(AsyncClient):
  _api_version: str = 'v1'
  _verify: bool = True
  _state: ServerState = ServerState.CLOSED

  __metadata: dict[str, t.Any] | None = None
  __metadata: MetadataOutput | None = None
  __config: dict[str, t.Any] | None = None

  def __repr__(self):
    return f'<AsyncHTTPClient timeout={self._timeout} api_version={self._api_version} verify={self._verify} state={self._state}>'
    return f'<AsyncHTTPClient address={self.address} timeout={self._timeout} api_version={self._api_version} verify={self._verify}>'

  @staticmethod
  async def wait_until_server_ready(addr, timeout=30, verify=False, check_interval=1, **client_args):
    addr = _address_converter(addr)
    logger.debug('Wait for server @ %s to be ready', addr)
    start = time.monotonic()
    while time.monotonic() - start < timeout:
      try:
        async with httpx.AsyncClient(base_url=addr, verify=verify, **client_args) as sess:
          status = (await sess.get('/readyz')).status_code
        if status == 200:
          break
        else:
          await asyncio.sleep(check_interval)
      except (httpx.ConnectError, urllib.error.URLError, ConnectionError):
        logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval)
        await asyncio.sleep(check_interval)
    # Try once more and raise for exception
    try:
      async with httpx.AsyncClient(base_url=addr, verify=verify, **client_args) as sess:
        status = (await sess.get('/readyz')).status_code
    except httpx.HTTPStatusError as err:
      logger.error('Failed to wait until server ready: %s', addr)
      logger.error(err)
      raise

  def __init__(self, address=None, timeout=30, verify=False, api_version='v1', **client_args):
  def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'):
    if address is None:
      env = os.getenv('OPENLLM_ENDPOINT')
      if env is None:
        raise ValueError('address must be provided')
      address = env
    self.__attrs_init__(
      address,
      client_args,
      httpx.AsyncClient(base_url=address, timeout=timeout, verify=verify, **client_args),
      timeout,
      api_version,
      verify,
    )
      address = os.getenv('OPENLLM_ENDPOINT')
      if address is None:
        raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'")
    self._api_version, self._verify = api_version, verify
    super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries)

  async def _metadata(self) -> dict[str, t.Any]:
    if self.__metadata is None:
      self.__metadata = (await self._inner.post(self._build_endpoint('metadata'))).json()
    return self.__metadata

  async def _config(self) -> dict[str, t.Any]:
    if self.__config is None:
      config = orjson.loads((await self._metadata())['configuration'])
      generation_config = config.pop('generation_config')
      self.__config = {**config, **generation_config}
    return self.__config

  def _build_endpoint(self, endpoint):
    return '/' + f'{self._api_version}/{endpoint}'
  def _build_auth_headers(self) -> t.Dict[str, str]:
    env = os.getenv('OPENLLM_AUTH_TOKEN')
    if env is not None:
      return {'Authorization': f'Bearer {env}'}
    return super()._build_auth_headers()

  @property
  def is_ready(self):
    return self._state == ServerState.READY
  def _loop(self) -> asyncio.AbstractEventLoop:
    try:
      return asyncio.get_running_loop()
    except RuntimeError:
      return asyncio.get_event_loop()

  @property
  async def _metadata(self) -> t.Awaitable[MetadataOutput]:
    if self.__metadata is None:
      self.__metadata = await self._post(
        f'/{self._api_version}/metadata',
        response_cls=MetadataOutput,
        json={},
        options={'max_retries': self._max_retries},
      )
    return self.__metadata

  @property
  async def _config(self):
    if self.__config is None:
      self.__config = (await self._metadata).configuration
    return self.__config

  async def query(self, prompt, **attrs):
    return await self.generate(prompt, **attrs)

  async def health(self):
    try:
      await self.wait_until_server_ready(self.address, timeout=self._timeout, verify=self._verify, **self.client_args)
      _object_setattr(self, '_state', ServerState.READY)
    except Exception as e:
      logger.error('Server is not healthy (Scroll up for traceback)\n%s', e)
      _object_setattr(self, '_state', ServerState.CLOSED)
    response = await self._get(
      '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
    )
    return response.status_code == 200

  async def generate(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> Response:
    if not self.is_ready:
      await self.health()
      if not self.is_ready:
        raise RuntimeError('Server is not ready. Check server logs for more information.')
    if timeout is None:
      timeout = self._timeout
    if verify is None:
      verify = self._verify
    _meta, _config = await self._metadata(), await self._config()
    verify = self._verify  # XXX: need to support this again
    _metadata = await self._metadata
    _config = await self._config
    if llm_config is not None:
      llm_config = {**_config, **llm_config, **attrs}
    else:
      llm_config = {**_config, **attrs}
    if _meta['prompt_template'] is not None:
      prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)

    req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
    async with httpx.AsyncClient(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
      r = await client.post(self._build_endpoint('generate'), json=req.model_dump_json(), **self.client_args)
      if r.status_code != 200:
        raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.")
      return Response.model_construct(r.json())
    if _metadata.prompt_template is not None:
      prompt = _metadata.prompt_template.format(system_message=_metadata.system_message, instruction=prompt)
    return await self._post(
      f'/{self._api_version}/generate',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
    )

  async def generate_stream(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.AsyncGenerator[StreamingResponse, t.Any]:
    if not self.is_ready:
      await self.health()
      if not self.is_ready:
        raise RuntimeError('Server is not ready. Check server logs for more information.')
    async for response_chunk in self.generate_iterator(
      prompt, llm_config, stop, adapter_name, timeout, verify, **attrs
    ):
      yield StreamingResponse.from_response_chunk(response_chunk)

  async def generate_iterator(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.AsyncGenerator[Response, t.Any]:
    if timeout is None:
      timeout = self._timeout
    if verify is None:
      verify = self._verify
    _meta, _config = await self._metadata(), await self._config()
    verify = self._verify  # XXX: need to support this again
    _metadata = await self._metadata
    _config = await self._config
    if llm_config is not None:
      llm_config = {**_config, **llm_config, **attrs}
    else:
      llm_config = {**_config, **attrs}
    if _meta['prompt_template'] is not None:
      prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
    if _metadata.prompt_template is not None:
      prompt = _metadata.prompt_template.format(system_message=_metadata.system_message, instruction=prompt)

    req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
    async with httpx.AsyncClient(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
      async with client.stream(
        'POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args
      ) as r:
        async for payload in r.aiter_bytes():
          if payload == b'data: [DONE]\n\n':
            break
          # Skip line
          if payload == b'\n':
            continue
          if payload.startswith(b'data: '):
            try:
              proc = payload.decode('utf-8').lstrip('data: ').rstrip('\n')
              data = orjson.loads(proc)
              yield StreamingResponse.from_response_chunk(Response.model_construct(data))
            except Exception:
              pass  # FIXME: Handle this
    async for response_chunk in await self._post(
      f'/{self._api_version}/generate_stream',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      stream=True,
    ):
      yield response_chunk
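A minimal usage sketch of the reworked streaming API (the address is illustrative, and the chunk's `text` attribute is assumed from `StreamingResponse`):

```python
import asyncio

from openllm_client import AsyncHTTPClient

async def main() -> None:
  client = AsyncHTTPClient('http://localhost:3000')  # illustrative address
  # Chunks arrive as StreamingResponse objects as the server emits SSE events.
  async for chunk in client.generate_stream('Explain SSE in one sentence.'):
    print(chunk.text, end='', flush=True)

asyncio.run(main())
```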
@@ -3,9 +3,47 @@ import typing as t

import attr
import cattr
import orjson

from ._utils import converter

# XXX: sync with openllm-core/src/openllm_core/_schemas.py
@attr.define
class MetadataOutput:
  model_id: str
  timeout: int
  model_name: str
  backend: str
  configuration: t.Dict[str, t.Any]
  prompt_template: t.Optional[str]
  system_message: t.Optional[str]

def _structure_metadata(data: t.Dict[str, t.Any], cls: type[MetadataOutput]) -> MetadataOutput:
  try:
    configuration = orjson.loads(data['configuration'])
    generation_config = configuration.pop('generation_config')
    configuration = {**configuration, **generation_config}
  except orjson.JSONDecodeError as e:
    raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None
  try:
    return cls(
      model_id=data['model_id'],
      timeout=data['timeout'],
      model_name=data['model_name'],
      backend=data['backend'],
      configuration=configuration,
      prompt_template=data['prompt_template'],
      system_message=data['system_message'],
    )
  except Exception as e:
    raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None

converter.register_structure_hook(MetadataOutput, _structure_metadata)

@attr.define
class Request:
  prompt: str
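A small sketch of what the structure hook does to a raw `/v1/metadata` payload (all field values below are illustrative; the imports reach into internal modules):

```python
from openllm_client._schemas import MetadataOutput
from openllm_client._utils import converter

raw = {
  'model_id': 'facebook/opt-125m', 'timeout': 3600, 'model_name': 'opt', 'backend': 'pt',
  'configuration': '{"max_new_tokens": 128, "generation_config": {"temperature": 0.7}}',
  'prompt_template': None, 'system_message': None,
}
meta = converter.structure(raw, MetadataOutput)
# generation_config is parsed out of the JSON string and flattened into configuration:
print(meta.configuration)  # {'max_new_tokens': 128, 'temperature': 0.7}
```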
openllm-client/src/openllm_client/_shim.py (new file, 610 lines)

@@ -0,0 +1,610 @@
# This provides a base shim with httpx and acts as base request
from __future__ import annotations
import email.utils
import logging
import platform
import random
import time
import typing as t

import anyio
import attr
import distro
import httpx

from ._stream import AsyncStream
from ._stream import Response
from ._stream import Stream
from ._typing_compat import Architecture
from ._typing_compat import LiteralString
from ._typing_compat import Platform
from ._utils import converter

logger = logging.getLogger(__name__)

InnerClient = t.TypeVar('InnerClient', bound=t.Union[httpx.Client, httpx.AsyncClient])
StreamType = t.TypeVar('StreamType', bound=t.Union[Stream[t.Any], AsyncStream[t.Any]])
_Stream = t.TypeVar('_Stream', bound=Stream[t.Any])
_AsyncStream = t.TypeVar('_AsyncStream', bound=AsyncStream[t.Any])
LiteralVersion = t.Annotated[LiteralString, t.Literal['v1'], str]

def _address_converter(addr: str | httpx.URL) -> httpx.URL:
  if isinstance(addr, httpx.URL):
    url = addr
  else:
    url = httpx.URL(addr if '://' in addr else f'http://{addr}')
  if not url.raw_path.endswith(b'/'):
    url = url.copy_with(path=url.raw_path + b'/')
  return url

MAX_RETRIES = 2
DEFAULT_TIMEOUT = httpx.Timeout(5.0)  # Similar to httpx
_T_co = t.TypeVar('_T_co', covariant=True)
_T = t.TypeVar('_T')

def _merge_mapping(a: t.Mapping[_T_co, _T], b: t.Mapping[_T_co, _T]) -> t.Dict[_T_co, _T]:
  # does the merging and filter out None
  return {k: v for k, v in {**a, **b}.items() if v is not None}

def _platform() -> Platform:
  system = platform.system().lower()
  platform_name = platform.platform().lower()
  if system == 'darwin':
    return 'MacOS'
  elif system == 'windows':
    return 'Windows'
  elif system == 'linux':
    distro_id = distro.id()
    if distro_id == 'freebsd':
      return 'FreeBSD'
    elif distro_id == 'openbsd':
      return 'OpenBSD'
    else:
      return 'Linux'
  elif 'android' in platform_name:
    return 'Android'
  elif 'iphone' in platform_name:
    return 'iOS'
  elif 'ipad' in platform_name:
    return 'iPadOS'
  if platform_name:
    return f'Other:{platform_name}'
  return 'Unknown'

def _architecture() -> Architecture:
  machine = platform.machine().lower()
  if machine in {'arm64', 'aarch64'}:
    return 'arm64'
  elif machine in {'arm', 'aarch32'}:
    return 'arm'
  elif machine in {'x86_64', 'amd64'}:
    return 'x86_64'
  elif machine in {'x86', 'i386', 'i686'}:
    return 'x86'
  elif machine:
    return f'Other:{machine}'
  return 'Unknown'
@t.final
@attr.frozen(auto_attribs=True)
class RequestOptions:
  method: str = attr.field(converter=str.lower)
  url: str
  json: t.Optional[t.Dict[str, t.Any]] = attr.field(default=None)
  params: t.Optional[t.Mapping[str, t.Any]] = attr.field(default=None)
  headers: t.Optional[t.Dict[str, str]] = attr.field(default=None)
  max_retries: int = attr.field(default=MAX_RETRIES)
  return_raw_response: bool = attr.field(default=False)

  def get_max_retries(self, max_retries: int | None) -> int:
    return max_retries if max_retries is not None else self.max_retries

  @classmethod
  def model_construct(cls, **options: t.Any) -> RequestOptions:
    return cls(**options)

@t.final
@attr.frozen(auto_attribs=True)
class APIResponse(t.Generic[Response]):
  _raw_response: httpx.Response
  _client: BaseClient[t.Any, t.Any]
  _response_cls: type[Response]
  _stream: bool
  _stream_cls: t.Optional[t.Union[t.Type[Stream[t.Any]], t.Type[AsyncStream[t.Any]]]]
  _options: RequestOptions

  _parsed: t.Optional[Response] = attr.field(default=None, repr=False)

  def parse(self):
    if self._options.return_raw_response:
      return self._raw_response
    if self._parsed is not None:
      return self._parsed
    if self._stream:
      stream_cls = self._stream_cls or self._client._default_stream_cls
      return stream_cls(response_cls=self._response_cls, response=self._raw_response, client=self._client)

    content_type, *_ = self._raw_response.headers.get('content-type', '').split(';')
    if content_type != 'application/json':
      # If the server returned a different content type, return the raw text without any deserialisation
      return self._raw_response.text

    data = self._raw_response.json()
    try:
      return self._client._process_response_data(
        data=data, response_cls=self._response_cls, raw_response=self._raw_response
      )
    except Exception as exc:
      raise ValueError(exc) from None  # validation error here

  @property
  def headers(self):
    return self._raw_response.headers

  @property
  def status_code(self):
    return self._raw_response.status_code

  @property
  def request(self):
    return self._raw_response.request

  @property
  def url(self):
    return self._raw_response.url

  @property
  def content(self):
    return self._raw_response.content

  @property
  def text(self):
    return self._raw_response.text

  @property
  def http_version(self):
    return self._raw_response.http_version
@attr.define(init=False)
class BaseClient(t.Generic[InnerClient, StreamType]):
  _base_url: httpx.URL = attr.field(converter=_address_converter)
  _version: LiteralVersion
  _timeout: httpx.Timeout = attr.field(converter=httpx.Timeout)
  _max_retries: int
  _inner: InnerClient
  _default_stream_cls: t.Type[StreamType]
  _auth_headers: t.Dict[str, str] = attr.field(init=False)

  def __init__(
    self,
    *,
    base_url: str | httpx.URL,
    version: str,
    timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
    max_retries: int = MAX_RETRIES,
    client: InnerClient,
    _default_stream_cls: t.Type[StreamType],
    _internal: bool = False,
  ):
    if not _internal:
      raise RuntimeError('Client is reserved to be used internally only.')
    self.__attrs_init__(base_url, version, timeout, max_retries, client, _default_stream_cls)
    self._auth_headers = self._build_auth_headers()

  def _prepare_url(self, url: str) -> httpx.URL:
    # copied from httpx._merge_url
    merge_url = httpx.URL(url)
    if merge_url.is_relative_url:
      merge_raw = self._base_url.raw_path + merge_url.raw_path.lstrip(b'/')
      return self._base_url.copy_with(raw_path=merge_raw)
    return merge_url

  @property
  def is_closed(self):
    return self._inner.is_closed

  @property
  def is_ready(self):
    return not self.is_closed  # backward compat

  @property
  def base_url(self):
    return self._base_url

  @property
  def address(self):
    return str(self.base_url)

  @base_url.setter
  def base_url(self, url):
    self._base_url = url if isinstance(url, httpx.URL) else httpx.URL(url)

  def _build_auth_headers(self) -> t.Dict[str, str]:
    return {}  # can be overridden for subclasses for auth support

  @property
  def auth(self) -> httpx.Auth | None:
    return None

  @property
  def user_agent(self):
    return f'{self.__class__.__name__}/Python {self._version}'

  @property
  def auth_headers(self):
    return self._auth_headers

  @property
  def _default_headers(self) -> t.Dict[str, str]:
    return {
      'Content-Type': 'application/json',
      'Accept': 'application/json',
      'User-Agent': self.user_agent,
      **self.platform_headers,
      **self.auth_headers,
    }

  @property
  def platform_headers(self):
    return {
      'X-OpenLLM-Client-Package-Version': self._version,
      'X-OpenLLM-Client-Language': 'Python',
      'X-OpenLLM-Client-Runtime': platform.python_implementation(),
      'X-OpenLLM-Client-Runtime-Version': platform.python_version(),
      'X-OpenLLM-Client-Arch': str(_architecture()),
      'X-OpenLLM-Client-OS': str(_platform()),
    }

  def _remaining_retries(self, remaining_retries: int | None, options: RequestOptions) -> int:
    return remaining_retries if remaining_retries is not None else options.get_max_retries(self._max_retries)

  def _build_headers(self, options: RequestOptions) -> httpx.Headers:
    return httpx.Headers(_merge_mapping(self._default_headers, options.headers or {}))

  def _build_request(self, options: RequestOptions) -> httpx.Request:
    return self._inner.build_request(
      method=options.method,
      headers=self._build_headers(options),
      url=self._prepare_url(options.url),
      json=options.json,
      params=options.params,
    )

  def _calculate_retry_timeout(
    self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None
  ) -> float:
    max_retries = options.get_max_retries(self._max_retries)
    # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
    try:
      if headers is not None:
        retry_header = headers.get('retry-after')
        try:
          retry_after = int(retry_header)
        except ValueError:
          tup = email.utils.parsedate_tz(retry_header)
          if tup is None:
            retry_after = -1
          else:
            retry_after = int(email.utils.mktime_tz(tup) - time.time())
      else:
        retry_after = -1
    except Exception:
      # omit everything
      retry_after = -1
    if 0 < retry_after <= 60:
      return retry_after  # this is reasonable from users
    initial_delay, max_delay = 0.5, 8.0
    num_retries = max_retries - remaining_retries

    sleep = min(initial_delay * pow(2.0, num_retries), max_delay)  # apply exponential backoff here
    timeout = sleep * (1 - 0.25 * random.random())
    return timeout if timeout >= 0 else 0

  def _should_retry(self, response: httpx.Response) -> bool:
    should_retry_header = response.headers.get('x-should-retry')
    if should_retry_header.lower() == 'true':
      return True
    if should_retry_header.lower() == 'false':
      return False
    if response.status_code in {408, 409, 429}:
      return True
    if response.status_code >= 500:
      return True
    return False

  def _process_response_data(
    self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response
  ) -> Response:
    return converter.structure(data, response_cls)

  def _process_response(
    self,
    *,
    response_cls: type[Response],
    options: RequestOptions,
    raw_response: httpx.Response,
    stream: bool,
    stream_cls: type[_Stream] | type[_AsyncStream] | None,
  ) -> Response:
    return APIResponse(
      raw_response=raw_response,
      client=self,
      response_cls=response_cls,
      stream=stream,
      stream_cls=stream_cls,
      options=options,
    ).parse()
@attr.define(init=False)
class Client(BaseClient[httpx.Client, Stream[t.Any]]):
  def __init__(
    self,
    base_url: str | httpx.URL,
    version: str,
    timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
    max_retries: int = MAX_RETRIES,
  ):
    super().__init__(
      base_url=base_url,
      version=version,
      timeout=timeout,
      max_retries=max_retries,
      client=httpx.Client(base_url=base_url, timeout=timeout),
      _default_stream_cls=Stream,
      _internal=True,
    )

  def close(self):
    self._inner.close()

  def __enter__(self):
    return self

  def __exit__(self, *args) -> None:
    self.close()

  def request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: t.Optional[int] = None,
    *,
    stream: bool = False,
    stream_cls: type[_Stream] | None = None,
  ) -> Response | _Stream:
    return self._request(
      response_cls=response_cls,
      options=options,
      remaining_retries=remaining_retries,
      stream=stream,
      stream_cls=stream_cls,
    )

  def _request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: int | None = None,
    *,
    stream: bool = False,
    stream_cls: type[_Stream] | None = None,
  ) -> Response | _Stream:
    retries = self._remaining_retries(remaining_retries, options)
    request = self._build_request(options)
    try:
      response = self._inner.send(request, auth=self.auth, stream=stream)
      logger.debug('HTTP [%s, %s]: %s [%i]', request.method, request.url, response.status_code, response.reason_phrase)
      response.raise_for_status()
    except httpx.HTTPStatusError as exc:
      if retries > 0 and self._should_retry(exc.response):
        return self._retry_request(
          response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
        )
      # If the response is streamed then we need to explicitly read the completed response
      exc.response.read()
      raise ValueError(exc.message) from None
    except httpx.TimeoutException:
      if retries > 0:
        return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from None  # timeout
    except Exception:
      if retries > 0:
        return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from None  # connection error

    return self._process_response(
      response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
    )

  def _retry_request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: int,
    response_headers: httpx.Headers | None = None,
    *,
    stream: bool = False,
    stream_cls: type[_Stream] | None,
  ) -> Response | _Stream:
    remaining = remaining_retries - 1
    timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
    logger.info('Retrying request to %s in %f seconds', options.url, timeout)
    # In synchronous thread we are blocking the thread. Depends on how users want to do this downstream.
    time.sleep(timeout)
    return self._request(response_cls, options, remaining, stream=stream, stream_cls=stream_cls)

  def _get(
    self,
    path: str,
    *,
    response_cls: type[Response],
    options: dict[str, t.Any] | None = None,
    stream: bool = False,
    stream_cls: type[_Stream] | None = None,
  ) -> Response | _Stream:
    if options is None:
      options = {}
    return self.request(
      response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
    )

  def _post(
    self,
    path: str,
    *,
    response_cls: type[Response],
    json: dict[str, t.Any],
    options: dict[str, t.Any] | None = None,
    stream: bool = False,
    stream_cls: type[_Stream] | None = None,
  ) -> Response | _Stream:
    if options is None:
      options = {}
    return self.request(
      response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
    )
@attr.define(init=False)
class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
  def __init__(
    self,
    base_url: str | httpx.URL,
    version: str,
    timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
    max_retries: int = MAX_RETRIES,
  ):
    super().__init__(
      base_url=base_url,
      version=version,
      timeout=timeout,
      max_retries=max_retries,
      client=httpx.AsyncClient(base_url=base_url, timeout=timeout),
      _default_stream_cls=AsyncStream,
      _internal=True,
    )

  async def close(self):
    await self._inner.aclose()

  async def __aenter__(self):
    return self

  async def __aexit__(self, *args) -> None:
    await self.close()

  async def request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: t.Optional[int] = None,
    *,
    stream: bool = False,
    stream_cls: type[_AsyncStream] | None = None,
  ) -> Response | _AsyncStream:
    return await self._request(
      response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls
    )

  async def _request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: int | None = None,
    *,
    stream: bool = False,
    stream_cls: type[_AsyncStream] | None = None,
  ) -> Response | _AsyncStream:
    retries = self._remaining_retries(remaining_retries, options)
    request = self._build_request(options)

    try:
      response = await self._inner.send(request, auth=self.auth, stream=stream)
      logger.debug('HTTP [%s, %s]: %s [%i]', request.method, request.url, response.status_code, response.reason_phrase)
      response.raise_for_status()
    except httpx.HTTPStatusError as exc:
      if retries > 0 and self._should_retry(exc.response):
        return self._retry_request(
          response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
        )
      # If the response is streamed then we need to explicitly read the completed response
      await exc.response.aread()
      raise ValueError(exc.message) from None
    except httpx.ConnectTimeout as err:
      if retries > 0:
        return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from err  # timeout
    except httpx.ReadTimeout:
      # We don't retry on ReadTimeout error, so something might happen on server-side
      raise
    except httpx.TimeoutException as err:
      if retries > 0:
        return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from err  # timeout
    except Exception as err:
      if retries > 0:
        return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from err  # connection error

    return self._process_response(
      response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
    )

  async def _retry_request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: int,
    response_headers: httpx.Headers | None = None,
    *,
    stream: bool,
    stream_cls: type[_AsyncStream] | None,
  ):
    remaining = remaining_retries - 1
    timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
    logger.info('Retrying request to %s in %f seconds', options.url, timeout)
    await anyio.sleep(timeout)
    return await self._request(response_cls, options, remaining, stream=stream, stream_cls=stream_cls)

  async def _get(
    self,
    path: str,
    *,
    response_cls: type[Response],
    options: dict[str, t.Any] | None = None,
    stream: bool = False,
    stream_cls: type[_AsyncStream] | None = None,
  ) -> Response | _AsyncStream:
    if options is None:
      options = {}
    return await self.request(
      response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
    )

  async def _post(
    self,
    path: str,
    *,
    response_cls: type[Response],
    json: dict[str, t.Any],
    options: dict[str, t.Any] | None = None,
    stream: bool = False,
    stream_cls: type[_AsyncStream] | None = None,
  ) -> Response | _AsyncStream:
    if options is None:
      options = {}
    return await self.request(
      response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
    )
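A worked sketch of the backoff schedule `_calculate_retry_timeout` produces before jitter (the loop below just replays the formula from the shim; the retry counts are illustrative):

```python
# sleep = min(0.5 * 2**num_retries, 8.0), then jittered down by up to 25%.
initial_delay, max_delay = 0.5, 8.0
for num_retries in range(5):
  sleep = min(initial_delay * pow(2.0, num_retries), max_delay)
  print(f'attempt {num_retries}: base delay {sleep:.1f}s')
# attempt 0: 0.5s, 1: 1.0s, 2: 2.0s, 3: 4.0s, 4: 8.0s (capped at max_delay)
```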
openllm-client/src/openllm_client/_stream.py (new file, 143 lines)

@@ -0,0 +1,143 @@
from __future__ import annotations
import typing as t

import attr
import httpx
import orjson

if t.TYPE_CHECKING:
  from ._shim import AsyncClient
  from ._shim import Client

Response = t.TypeVar('Response', bound=attr.AttrsInstance)

@attr.define(auto_attribs=True)
class Stream(t.Generic[Response]):
  _response_cls: t.Type[Response]
  _response: httpx.Response
  _client: Client
  _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder())
  _iterator: t.Iterator[Response] = attr.field(init=False)

  def __init__(self, response_cls, response, client):
    self.__attrs_init__(response_cls, response, client)
    self._iterator = self._stream()

  def __next__(self):
    return self._iterator.__next__()

  def __iter__(self) -> t.Iterator[Response]:
    for item in self._iterator:
      yield item

  def _iter_events(self) -> t.Iterator[SSE]:
    yield from self._decoder.iter(self._response.iter_lines())

  def _stream(self) -> t.Iterator[Response]:
    for sse in self._iter_events():
      if sse.data.startswith('[DONE]'):
        break
      if sse.event is None:
        yield self._client._process_response_data(
          data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
        )

@attr.define(auto_attribs=True)
class AsyncStream(t.Generic[Response]):
  _response_cls: t.Type[Response]
  _response: httpx.Response
  _client: AsyncClient
  _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder())
  _iterator: t.Iterator[Response] = attr.field(init=False)

  def __init__(self, response_cls, response, client):
    self.__attrs_init__(response_cls, response, client)
    self._iterator = self._stream()

  async def __anext__(self):
    return await self._iterator.__anext__()

  async def __aiter__(self):
    async for item in self._iterator:
      yield item

  async def _iter_events(self):
    async for sse in self._decoder.aiter(self._response.aiter_lines()):
      yield sse

  async def _stream(self) -> t.AsyncGenerator[Response, None]:
    async for sse in self._iter_events():
      if sse.data.startswith('[DONE]'):
        break
      if sse.event is None:
        yield self._client._process_response_data(
          data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
        )

@attr.define
class SSE:
  data: str = attr.field(default='')
  id: t.Optional[str] = attr.field(default=None)
  event: t.Optional[str] = attr.field(default=None)
  retry: t.Optional[int] = attr.field(default=None)

  def model_dump(self) -> t.Dict[str, t.Any]:
    try:
      return orjson.loads(self.data)
    except orjson.JSONDecodeError:
      raise

@attr.define(auto_attribs=True)
class SSEDecoder:
  _data: t.List[str] = attr.field(factory=list)
  _event: t.Optional[str] = None
  _retry: t.Optional[int] = None
  _last_event_id: t.Optional[str] = None

  def iter(self, iterator: t.Iterator[str]) -> t.Iterator[SSE]:
    for line in iterator:
      sse = self.decode(line.rstrip('\n'))
      if sse:
        yield sse

  async def aiter(self, iterator: t.AsyncIterator[str]) -> t.AsyncIterator[SSE]:
    async for line in iterator:
      sse = self.decode(line.rstrip('\n'))
      if sse:
        yield sse

  def decode(self, line: str) -> SSE | None:
    # NOTE: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
    if not line:
      if all(not a for a in [self._event, self._data, self._retry, self._last_event_id]):
        return None
      sse = SSE(data='\n'.join(self._data), event=self._event, retry=self._retry, id=self._last_event_id)
      self._event, self._data, self._retry = None, [], None
      return sse
    if line.startswith(':'):
      return None
    field, _, value = line.partition(':')
    if value.startswith(' '):
      value = value[1:]
    if field == 'event':
      self._event = value
    elif field == 'data':
      self._data.append(value)
    elif field == 'id':
      if '\0' in value:
        pass
      else:
        self._last_event_id = value
    elif field == 'retry':
      try:
        self._retry = int(value)
      except (TypeError, ValueError):
        pass
    else:
      pass  # Ignore unknown fields
    return None
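A minimal sketch of feeding raw SSE lines through the decoder (the lines below are illustrative, and the import reaches into an internal module):

```python
from openllm_client._stream import SSEDecoder

decoder = SSEDecoder()
# Blank lines flush the accumulated fields into an SSE event, per the spec.
lines = ['data: {"text": "Hello"}', '', 'data: [DONE]', '']
for sse in decoder.iter(iter(lines)):
  print(sse.data)  # -> {"text": "Hello"}, then [DONE]
```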
openllm-client/src/openllm_client/_typing_compat.py (new file, 29 lines)

@@ -0,0 +1,29 @@
import sys
|
||||
|
||||
from typing import Literal
|
||||
|
||||
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import LiteralString as LiteralString
|
||||
from typing import NotRequired as NotRequired
|
||||
from typing import Required as Required
|
||||
from typing import Self as Self
|
||||
from typing import dataclass_transform as dataclass_transform
|
||||
from typing import overload as overload
|
||||
else:
|
||||
from typing_extensions import LiteralString as LiteralString
|
||||
from typing_extensions import NotRequired as NotRequired
|
||||
from typing_extensions import Required as Required
|
||||
from typing_extensions import Self as Self
|
||||
from typing_extensions import dataclass_transform as dataclass_transform
|
||||
from typing_extensions import overload as overload
|
||||
|
||||
if sys.version_info[:2] >= (3, 9):
|
||||
from typing import Annotated as Annotated
|
||||
else:
|
||||
from typing_extensions import Annotated as Annotated
|
||||
|
||||
Platform = Annotated[
|
||||
LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str
|
||||
]
|
||||
Architecture = Annotated[LiteralString, Literal['arm', 'arm64', 'x86', 'x86_64', 'Unknown'], str]
|
||||
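Platform and Architecture are annotation aliases, not enums, so the shim presumably normalises stdlib values into these literals when stamping client metadata. A hedged sketch of such a mapping (the helper name and mapping are assumptions, not from this diff):

import platform

def current_platform() -> str:
  # platform.system() returns e.g. 'Darwin', 'Linux', 'Windows'
  return {'Darwin': 'MacOS', 'Linux': 'Linux', 'Windows': 'Windows'}.get(platform.system(), 'Unknown')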
openllm-client/src/openllm_client/_utils.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from __future__ import annotations

from cattr import Converter

converter = Converter(omit_if_default=True)
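With omit_if_default=True, fields still at their default are dropped when unstructuring, which keeps serialised payloads minimal. An illustration (the class here is assumed, not from the diff):

import typing as t
import attr

@attr.define
class _Example:
  prompt: str
  stop: t.Optional[str] = None

converter.unstructure(_Example(prompt='hi'))  # -> {'prompt': 'hi'}; 'stop' is omitted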
@@ -108,7 +108,7 @@ class Conversation:
     elif self.sep_style == SeparatorStyle.LLAMA:
       seps = [self.sep, self.sep2]
       ret = system_prompt if self.system_message else '<s>[INST] '
-      for i, (role, message) in enumerate(self.messages):
+      for i, (_, message) in enumerate(self.messages):
         tag = self.roles[i % 2]
         if message:
           if i == 0:
@@ -139,12 +139,12 @@ class Conversation:
       return ret
     elif self.sep_style == SeparatorStyle.MPT:
       ret = f'<|im_start|>system\n{system_prompt}<|im_end|>{self.sep}' if system_prompt else ''
-      for i, (role, message) in enumerate(self.messages):
+      for _, (role, message) in enumerate(self.messages):
         ret += f'<|im_start|>{role}\n{message}<|im_end|>{self.sep}' if message else f'{role}:'
       return ret
     elif self.sep_style == SeparatorStyle.STARCODER:
       ret = f'<|system|>\n{system_prompt}<|end|>{self.sep}' if system_prompt else ''
-      for i, (role, message) in enumerate(self.messages):
+      for _, (role, message) in enumerate(self.messages):
         ret += f'{role}\n{message}<|end|>{self.sep}' if message else f'{role}:'
     else:
       raise ValueError(f'Invalid style: {self.sep_style}')
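Renaming the unused loop variables is behaviour-preserving; the MPT branch still renders messages the same way. A standalone trace of that loop (messages invented):

sep = '\n'
messages = [('user', 'Hi'), ('assistant', 'Hello!')]
ret = ''
for _, (role, message) in enumerate(messages):
  ret += f'<|im_start|>{role}\n{message}<|im_end|>{sep}' if message else f'{role}:'
# ret == '<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n'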
@@ -16,26 +16,45 @@ from .utils import gen_random_uuid
 if t.TYPE_CHECKING:
   import vllm
+
+  from ._typing_compat import Self

-@attr.frozen(slots=True)
-class MetadataOutput:
-  model_id: str
-  timeout: int
-  model_name: str
-  backend: str
-  configuration: str
-  prompt_template: str
-  system_message: str
+@attr.define
+class _SchemaMixin:
+  def model_dump(self) -> dict[str, t.Any]:
+    return converter.unstructure(self)
+
+  def model_dump_json(self) -> str:
+    return orjson.dumps(self.model_dump(), option=orjson.OPT_INDENT_2).decode('utf-8')
+
+  def with_options(self, **options: t.Any) -> Self:
+    return attr.evolve(self, **options)


-@attr.define(slots=True, frozen=True)
-class GenerationInput:
+@attr.define
+class MetadataOutput(_SchemaMixin):
+  model_id: str
+  timeout: int
+  model_name: str
+  backend: str
+  configuration: str
+  prompt_template: t.Optional[str]
+  system_message: t.Optional[str]
+
+  def model_dump(self) -> dict[str, t.Any]:
+    return {
+      'model_id': self.model_id,
+      'timeout': self.timeout,
+      'model_name': self.model_name,
+      'backend': self.backend,
+      'configuration': self.configuration,
+      'prompt_template': orjson.dumps(self.prompt_template).decode(),
+      'system_message': orjson.dumps(self.system_message).decode(),
+    }
+
+
+@attr.define
+class GenerationInput(_SchemaMixin):
+  prompt: str
+  llm_config: LLMConfig
+  stop: list[str] | None = attr.field(default=None)
@@ -53,9 +72,6 @@ class GenerationInput:
       'adapter_name': self.adapter_name,
     }

-  def model_dump_json(self) -> str:
-    return orjson.dumps(self.model_dump(), option=orjson.OPT_INDENT_2).decode('utf-8')
-
   @classmethod
   def from_llm_config(cls, llm_config: LLMConfig) -> type[GenerationInput]:
     def init(self: GenerationInput, prompt: str, stop: list[str] | None, adapter_name: str | None) -> None:
@@ -80,7 +96,7 @@ class GenerationInput:
     def examples(_: type[GenerationInput]) -> dict[str, t.Any]:
       return klass(prompt='What is the meaning of life?', llm_config=llm_config, stop=['\n']).model_dump()

-    setattr(klass, 'examples', classmethod(examples))
+    klass.examples = classmethod(examples)

     try:
       klass.__module__ = cls.__module__
@@ -98,7 +114,7 @@ FinishReason = t.Literal['length', 'stop']

 @attr.define
-class CompletionChunk:
+class CompletionChunk(_SchemaMixin):
   index: int
   text: str
   token_ids: t.List[int]
@@ -106,15 +122,12 @@ class CompletionChunk:
   logprobs: t.Optional[SampleLogprobs] = None
   finish_reason: t.Optional[FinishReason] = None

-  # yapf: disable
-  def with_options(self,**options: t.Any)->CompletionChunk: return attr.evolve(self, **options)
-  def model_dump(self)->dict[str, t.Any]:return converter.unstructure(self)
-  def model_dump_json(self)->str:return orjson.dumps(self.model_dump(),option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
-  # yapf: enable
+  def model_dump_json(self) -> str:
+    return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8')


 @attr.define
-class GenerationOutput:
+class GenerationOutput(_SchemaMixin):
   prompt: str
   finished: bool
   outputs: t.List[CompletionChunk]
@@ -182,11 +195,5 @@ class GenerationOutput:
       prompt_logprobs=request_output.prompt_logprobs,
     )

-  def with_options(self, **options: t.Any) -> GenerationOutput:
-    return attr.evolve(self, **options)
-
-  def model_dump(self) -> dict[str, t.Any]:
-    return converter.unstructure(self)
-
   def model_dump_json(self) -> str:
     return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
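The mixin gives every schema the same pydantic-style surface: model_dump through the shared cattrs converter, model_dump_json through orjson, and with_options as a thin attr.evolve wrapper. A hedged usage sketch (field values invented):

meta = MetadataOutput(model_id='facebook/opt-125m', timeout=360, model_name='opt',
                      backend='pt', configuration='{}', prompt_template=None, system_message=None)
meta.with_options(timeout=3600).model_dump_json()  # evolve a copy, then serialise with orjson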
@@ -70,6 +70,11 @@ else:
   from typing_extensions import TypeAlias as TypeAlias
   from typing_extensions import TypeGuard as TypeGuard

+if sys.version_info[:2] >= (3, 9):
+  from typing import Annotated as Annotated
+else:
+  from typing_extensions import Annotated as Annotated
+

 class AdapterTuple(TupleAny):
   adapter_id: str
@@ -158,11 +158,11 @@ class AutoConfig:
     else:
       try:
         config_file = llm.bentomodel.path_of(CONFIG_FILE_NAME)
-      except OpenLLMException:
+      except OpenLLMException as err:
         if not is_transformers_available():
           raise MissingDependencyError(
             "'infer_class_from_llm' requires 'transformers' to be available. Make sure to install it with 'pip install transformers'"
-          )
+          ) from err
         from transformers.utils import cached_file

         try:
@@ -101,7 +101,7 @@ class FineTuneConfig:
       from peft.mapping import get_peft_config
       from peft.utils.peft_types import TaskType
     except ImportError:
-      raise ImportError('PEFT is not installed. Please install it via `pip install "openllm[fine-tune]"`.')
+      raise ImportError('PEFT is not installed. Please install it via `pip install "openllm[fine-tune]"`.') from None
     adapter_config = self.adapter_config.copy()
     # no need for peft_type
     if 'peft_type' in adapter_config:
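These two hunks apply the standard exception-chaining idiom: `from err` keeps the original traceback attached as __cause__, while `from None` suppresses the noisy "during handling of the above exception" context for an expected failure. In isolation:

try:
  {}['missing']
except KeyError as err:
  raise RuntimeError('lookup failed') from err  # traceback shows the KeyError as the cause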
@@ -66,7 +66,7 @@ else:

 _import_structure: dict[str, list[str]] = {
   'exceptions': [],
-  'client': [],
+  'client': ['HTTPClient', 'AsyncHTTPClient'],
   'bundle': [],
   'playground': [],
   'testing': [],
@@ -98,6 +98,8 @@ if _t.TYPE_CHECKING:
   from . import serialisation as serialisation
   from . import testing as testing
   from . import utils as utils
+  from .client import HTTPClient as HTTPClient
+  from .client import AsyncHTTPClient as AsyncHTTPClient
   from ._deprecated import Runner as Runner
   from ._generation import LogitsProcessorList as LogitsProcessorList
   from ._generation import StopOnTokens as StopOnTokens
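With HTTPClient and AsyncHTTPClient now re-exported at the top level, and the changelog noting authentication through OPENLLM_AUTH_TOKEN, usage would look roughly like this (the constructor argument is an assumption; only the env var name comes from this PR):

import os
os.environ['OPENLLM_AUTH_TOKEN'] = '<my-token>'  # picked up by the client shim

import openllm
client = openllm.HTTPClient('http://localhost:3000')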
@@ -7,9 +7,6 @@ To start any OpenLLM model:
 openllm start <model_name> --options ...
 """

-from __future__ import annotations
-

 if __name__ == '__main__':
   from openllm.cli.entrypoint import cli
@@ -71,7 +71,7 @@ def is_sentence_complete(output: str) -> bool:

 def is_partial_stop(output: str, stop_str: str) -> bool:
   """Check whether the output contains a partial stop str."""
-  for i in range(0, min(len(output), len(stop_str))):
+  for i in range(min(len(output), len(stop_str))):
     if stop_str.startswith(output[-i:]):
       return True
   return False
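range(0, n) and range(n) are identical, so the change is purely cosmetic. For reference, what the helper detects (example strings invented):

is_partial_stop('Hello <|', '<|end|>')  # True: the tail '<|' is a prefix of the stop string
is_partial_stop('Hello', '<|end|>')     # False: no suffix of the output starts the stop string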
@@ -235,7 +235,7 @@ def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
     raise RuntimeError('Failed to initialise CUDA runtime binding.')
   # correctly parse handle
   for el in val:
-    if el.startswith('GPU-') or el.startswith('MIG-'):
+    if el.startswith(('GPU-', 'MIG-')):
       uuids = _raw_device_uuid_nvml()
       if uuids is None:
         raise ValueError('Failed to parse available GPUs UUID')
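str.startswith accepts a tuple of prefixes natively, which is what this tuple form relies on:

'MIG-abc'.startswith(('GPU-', 'MIG-'))  # True, no `or` chain needed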
@@ -580,6 +580,6 @@ def append_schemas(
   def mk_generate_spec(svc:bentoml.Service,openapi_version:str=OPENAPI_VERSION)->MKSchema:return MKSchema(svc_schema)
   def mk_asdict(self:OpenAPISpecification)->dict[str,t.Any]:return svc_schema
   openapi.generate_spec=mk_generate_spec
-  setattr(OpenAPISpecification, 'asdict', mk_asdict)
+  OpenAPISpecification.asdict = mk_asdict
   # yapf: disable
   return svc
@@ -1,30 +1,13 @@
 """Serialisation related implementation for GGML-based implementation.

 This requires ctransformers to be installed.
 """
 from __future__ import annotations
-import typing as t
-
-
-if t.TYPE_CHECKING:
-  import bentoml
-  import openllm
-
-  from openllm_core._typing_compat import M
-
-_conversion_strategy = {'pt': 'ggml'}
-
-
-def import_model(
-  llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any
-) -> bentoml.Model:
+def import_model(llm, *decls, trust_remote_code=True, **attrs):
   raise NotImplementedError('Currently work in progress.')


-def get(llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Model:
+def get(llm):
   raise NotImplementedError('Currently work in progress.')


-def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
+def load_model(llm, *decls, **attrs):
   raise NotImplementedError('Currently work in progress.')
@@ -1,5 +1,3 @@
 """Serialisation related implementation for Transformers-based implementation."""
-
 from __future__ import annotations
 import importlib
 import logging
@@ -6,6 +6,7 @@ we won't ensure backward compatibility for these functions. So use with caution.

 from __future__ import annotations
 import functools
+import importlib.metadata
 import typing as t

 import openllm_core
@@ -72,6 +73,7 @@ def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]:
     'model_name': llm.config['model_name'],
     'architecture': llm.config['architecture'],
     'serialisation': llm._serialisation,
+    **{package: importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}},
   }
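importlib.metadata.version reads the installed distribution's version string, so the added comprehension stamps the labels with the exact openllm package versions present at build time. Shape of the contribution (version numbers illustrative only):

import importlib.metadata
{pkg: importlib.metadata.version(pkg) for pkg in {'openllm', 'openllm-core', 'openllm-client'}}
# e.g. {'openllm': '0.4.1', 'openllm-core': '0.4.1', 'openllm-client': '0.4.1'}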
ruff.toml (27 lines changed)
@@ -9,6 +9,10 @@ extend-exclude = [
 ]
 extend-include = ["*.ipynb"]
 extend-select = [
+  "E",
+  "F",
+  "B",
+  "PIE",
   "I", # isort
   "G", # flake8-logging-format
   "W", # pycodestyle
@@ -40,11 +44,19 @@ ignore = [
 line-length = 119
 indent-width = 2
 target-version = "py38"
-typing-modules = ["openllm_core._typing_compat"]
+typing-modules = [
+  "openllm_core._typing_compat",
+  "openllm_client._typing_compat",
+]
 unfixable = ["TCH004"]

 [lint.flake8-type-checking]
-exempt-modules = ["typing", "typing_extensions", "openllm_core._typing_compat"]
+exempt-modules = [
+  "typing",
+  "typing_extensions",
+  "openllm_core._typing_compat",
+  "openllm_client._typing_compat",
+]
 runtime-evaluated-base-classes = [
   "openllm_core._configuration.LLMConfig",
   "openllm_core._configuration.GenerationConfig",
@@ -52,7 +64,15 @@ runtime-evaluated-base-classes = [
   "openllm_core._configuration.ModelSettings",
   "openllm.LLMConfig",
 ]
-runtime-evaluated-decorators = ["attrs.define", "attrs.frozen", "trait"]
+runtime-evaluated-decorators = [
+  "attrs.define",
+  "attrs.frozen",
+  "trait",
+  "attr.attrs",
+  'attr.define',
+  '_attr.define',
+  'attr.frozen',
+]

 [format]
 preview = true
@@ -65,6 +85,7 @@ convention = "google"

 [lint.pycodestyle]
 ignore-overlong-task-comments = true
+max-line-length = 119

 [lint.isort]
 combine-as-imports = true