feat(client): support authentication token and shim implementation (#605)

* chore: sync generate_iterator to be the same as server

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* --wip--

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* wip

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* feat: cleanup shim implementation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

* chore: fix pre-commit

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update changelog

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update check with tuple

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-10 17:44:31 -05:00
committed by GitHub
parent af0b1b9a7f
commit c41828f68f
25 changed files with 1086 additions and 344 deletions

View File

@@ -202,11 +202,8 @@ After you change or update any CI related under `.github`, run `bash tools/lock-
See this [docs](/.github/INFRA.md) for more information on OpenLLM's CI/CD workflow.
## UI
See [ClojureScript UI's README.md](/external/clojure/README.md) for more information.
See [Documentation's README.md](/docs/README.md) for more information on running locally.
## Typing
For all internal functions, it is recommended to provide type hints. For all public function definitions, it is recommended to create a stubs file `.pyi` to separate supported external API to increase code visibility. See [openllm-client's `__init__.pyi`](/openllm-client/src/openllm_client/__init__.pyi) for example.
## Install from git archive install

View File

@@ -0,0 +1 @@
Update client implementation and support Authentication through `OPENLLM_AUTH_TOKEN`

View File

@@ -1,6 +1,7 @@
# NOTE: Make sure to install openai>1
import os, openai
import os, openai, typing as t
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
ChatCompletionAssistantMessageParam,
@@ -14,7 +15,7 @@ model = models.data[0].id
# Chat completion API
stream = str(os.getenv('STREAM', False)).upper() in ['TRUE', '1', 'YES', 'Y', 'ON']
messages = [
messages: t.List[ChatCompletionMessageParam]= [
ChatCompletionSystemMessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
ChatCompletionUserMessageParam(role='user', content='Hi there!'),
ChatCompletionAssistantMessageParam(role='assistant', content='Yes?'),

View File

@@ -57,7 +57,7 @@ keywords = [
"PyTorch",
"Transformers",
]
dependencies = ["orjson", "httpx", "attrs>=23.1.0", "cattrs>=23.1.0"]
dependencies = ["orjson", "httpx", "attrs>=23.1.0", "cattrs>=23.1.0", 'distro', 'anyio']
license = "Apache-2.0"
name = "openllm-client"
requires-python = ">=3.8"
@@ -71,8 +71,9 @@ Homepage = "https://bentoml.com"
Tracker = "https://github.com/bentoml/OpenLLM/issues"
Twitter = "https://twitter.com/bentomlai"
[project.optional-dependencies]
full = ["openllm-client[grpc,agents]"]
grpc = ["bentoml[grpc]>=1.1.6"]
full = ["openllm-client[grpc,agents,auth]"]
grpc = ["bentoml[grpc]>=1.1.9"]
auth = ['httpx_auth']
agents = ["transformers[agents]>=4.30", "diffusers", "soundfile"]
[tool.hatch.version]

View File

@@ -15,26 +15,21 @@ from ._schemas import StreamingResponse as _StreamingResponse
@_attr.define
class HTTPClient:
address: str
client_args: Dict[str, Any]
@staticmethod
def wait_until_server_ready(
addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any
@overload
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(
self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
) -> None: ...
@overload
def __init__(
self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@property
def is_ready(self) -> bool: ...
def health(self) -> None: ...
def health(self) -> bool: ...
def query(self, prompt: str, **attrs: Any) -> _Response: ...
def generate(
self,
@@ -46,6 +41,16 @@ class HTTPClient:
verify: Optional[bool] = ...,
**attrs: Any,
) -> _Response: ...
def generate_iterator(
self,
prompt: str,
llm_config: Optional[Dict[str, Any]] = ...,
stop: Optional[Union[str, List[str]]] = ...,
adapter_name: Optional[str] = ...,
timeout: Optional[int] = ...,
verify: Optional[bool] = ...,
**attrs: Any,
) -> Iterator[_Response]: ...
def generate_stream(
self,
prompt: str,
@@ -60,26 +65,21 @@ class HTTPClient:
@_attr.define
class AsyncHTTPClient:
address: str
client_args: Dict[str, Any]
@staticmethod
async def wait_until_server_ready(
addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any
@overload
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(
self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
) -> None: ...
@overload
def __init__(
self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@property
def is_ready(self) -> bool: ...
async def health(self) -> None: ...
async def health(self) -> bool: ...
async def query(self, prompt: str, **attrs: Any) -> _Response: ...
async def generate(
self,
@@ -91,6 +91,16 @@ class AsyncHTTPClient:
verify: Optional[bool] = ...,
**attrs: Any,
) -> _Response: ...
async def generate_iterator(
self,
prompt: str,
llm_config: Optional[Dict[str, Any]] = ...,
stop: Optional[Union[str, List[str]]] = ...,
adapter_name: Optional[str] = ...,
timeout: Optional[int] = ...,
verify: Optional[bool] = ...,
**attrs: Any,
) -> Iterator[_Response]: ...
async def generate_stream(
self,
prompt: str,

View File

@@ -1,349 +1,240 @@
from __future__ import annotations
import asyncio
import enum
import importlib.metadata
import logging
import os
import time
import typing as t
import urllib.error
from urllib.parse import urlparse
import attr
import httpx
import orjson
from ._schemas import Request
from ._schemas import MetadataOutput
from ._schemas import Response
from ._schemas import StreamingResponse
from ._shim import MAX_RETRIES
from ._shim import AsyncClient
from ._shim import Client
logger = logging.getLogger(__name__)
def _address_validator(_, attr, value):
if not isinstance(value, str):
raise TypeError(f'{attr.name} must be a string')
if not urlparse(value).netloc:
raise ValueError(f'{attr.name} must be a valid URL')
VERSION = importlib.metadata.version('openllm-client')
def _address_converter(addr: str):
return addr if '://' in addr else 'http://' + addr
class ServerState(enum.Enum):
class ClientState(enum.Enum):
CLOSED = 1 # CLOSED: The server is not yet ready or `wait_until_server_ready` has not been called/failed.
READY = 2 # READY: The server is ready and `wait_until_server_ready` has been called.
_object_setattr = object.__setattr__
DISCONNECTED = 3 # DISCONNECTED: The server is disconnected and `wait_until_server_ready` has been called.
@attr.define(init=False)
class HTTPClient:
address: str = attr.field(validator=_address_validator, converter=_address_converter)
client_args: t.Dict[str, t.Any]
_inner: httpx.Client
_timeout: int = 30
class HTTPClient(Client):
_api_version: str = 'v1'
_verify: bool = True
_state: ServerState = ServerState.CLOSED
__metadata: dict[str, t.Any] | None = None
__metadata: MetadataOutput | None = None
__config: dict[str, t.Any] | None = None
def __repr__(self):
return (
f'<HTTPClient timeout={self._timeout} api_version={self._api_version} verify={self._verify} state={self._state}>'
)
return f'<HTTPClient address={self.address} timeout={self._timeout} api_version={self._api_version} verify={self._verify}>'
@staticmethod
def wait_until_server_ready(addr, timeout=30, verify=False, check_interval=1, **client_args):
addr = _address_converter(addr)
logger.debug('Wait for server @ %s to be ready', addr)
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
with httpx.Client(base_url=addr, verify=verify, **client_args) as sess:
status = sess.get('/readyz').status_code
if status == 200:
break
else:
time.sleep(check_interval)
except (httpx.ConnectError, urllib.error.URLError, ConnectionError):
logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval)
time.sleep(check_interval)
# Try once more and raise for exception
try:
with httpx.Client(base_url=addr, verify=verify, **client_args) as sess:
status = sess.get('/readyz').status_code
except httpx.HTTPStatusError as err:
logger.error('Failed to wait until server ready: %s', addr)
logger.error(err)
raise
def __init__(self, address=None, timeout=30, verify=False, api_version='v1', **client_args):
def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'):
if address is None:
env = os.getenv('OPENLLM_ENDPOINT')
if env is None:
raise ValueError('address must be provided')
address = env
self.__attrs_init__(
address,
client_args,
httpx.Client(base_url=address, timeout=timeout, verify=verify, **client_args),
timeout,
api_version,
verify,
)
address = os.getenv('OPENLLM_ENDPOINT')
if address is None:
raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'")
self._api_version, self._verify = api_version, verify
super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries)
def _metadata(self) -> dict[str, t.Any]:
if self.__metadata is None:
self.__metadata = self._inner.post(self._build_endpoint('metadata')).json()
return self.__metadata
def _config(self) -> dict[str, t.Any]:
if self.__config is None:
config = orjson.loads(self._metadata()['configuration'])
generation_config = config.pop('generation_config')
self.__config = {**config, **generation_config}
return self.__config
def _build_auth_headers(self) -> t.Dict[str, str]:
env = os.getenv('OPENLLM_AUTH_TOKEN')
if env is not None:
return {'Authorization': f'Bearer {env}'}
return super()._build_auth_headers()
def __del__(self):
self._inner.close()
def _build_endpoint(self, endpoint):
return ('/' if not self._api_version.startswith('/') else '') + f'{self._api_version}/{endpoint}'
self.close()
@property
def is_ready(self):
return self._state == ServerState.READY
def _metadata(self):
if self.__metadata is None:
path = f'/{self._api_version}/metadata'
self.__metadata = self._post(
path, response_cls=MetadataOutput, json={}, options={'max_retries': self._max_retries}
)
return self.__metadata
@property
def _config(self) -> dict[str, t.Any]:
if self.__config is None:
self.__config = self._metadata.configuration
return self.__config
def query(self, prompt, **attrs):
return self.generate(prompt, **attrs)
def health(self):
try:
self.wait_until_server_ready(self.address, timeout=self._timeout, verify=self._verify, **self.client_args)
_object_setattr(self, '_state', ServerState.READY)
except Exception as e:
logger.error('Server is not healthy (Scroll up for traceback)\n%s', e)
_object_setattr(self, '_state', ServerState.CLOSED)
response = self._get(
'/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
)
return response.status_code == 200
def generate(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> Response:
if not self.is_ready:
self.health()
if not self.is_ready:
raise RuntimeError('Server is not ready. Check server logs for more information.')
if timeout is None:
timeout = self._timeout
if verify is None:
verify = self._verify
_meta, _config = self._metadata(), self._config()
verify = self._verify # XXX: need to support this again
if llm_config is not None:
llm_config = {**_config, **llm_config, **attrs}
llm_config = {**self._config, **llm_config, **attrs}
else:
llm_config = {**_config, **attrs}
if _meta['prompt_template'] is not None:
prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
llm_config = {**self._config, **attrs}
if self._metadata.prompt_template is not None:
prompt = self._metadata.prompt_template.format(system_message=self._metadata.system_message, instruction=prompt)
req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
with httpx.Client(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
r = client.post(self._build_endpoint('generate'), json=req.model_dump_json(), **self.client_args)
if r.status_code != 200:
raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.")
return Response.model_construct(r.json())
return self._post(
f'/{self._api_version}/generate',
response_cls=Response,
json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
)
def generate_stream(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.Iterator[StreamingResponse]:
if not self.is_ready:
self.health()
if not self.is_ready:
raise RuntimeError('Server is not ready. Check server logs for more information.')
for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs):
yield StreamingResponse.from_response_chunk(response_chunk)
def generate_iterator(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.Iterator[Response]:
if timeout is None:
timeout = self._timeout
if verify is None:
verify = self._verify
_meta, _config = self._metadata(), self._config()
verify = self._verify # XXX: need to support this again
if llm_config is not None:
llm_config = {**_config, **llm_config, **attrs}
llm_config = {**self._config, **llm_config, **attrs}
else:
llm_config = {**_config, **attrs}
if _meta['prompt_template'] is not None:
prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
with httpx.Client(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
with client.stream(
'POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args
) as r:
for payload in r.iter_bytes():
if payload == b'data: [DONE]\n\n':
break
# Skip line
if payload == b'\n':
continue
if payload.startswith(b'data: '):
try:
proc = payload.decode('utf-8').lstrip('data: ').rstrip('\n')
data = orjson.loads(proc)
yield StreamingResponse.from_response_chunk(Response.model_construct(data))
except Exception:
pass # FIXME: Handle this
llm_config = {**self._config, **attrs}
if self._metadata.prompt_template is not None:
prompt = self._metadata.prompt_template.format(system_message=self._metadata.system_message, instruction=prompt)
return self._post(
f'/{self._api_version}/generate_stream',
response_cls=Response,
json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
stream=True,
)
@attr.define(init=False)
class AsyncHTTPClient:
address: str = attr.field(validator=_address_validator, converter=_address_converter)
client_args: t.Dict[str, t.Any]
_inner: httpx.AsyncClient
_timeout: int = 30
class AsyncHTTPClient(AsyncClient):
_api_version: str = 'v1'
_verify: bool = True
_state: ServerState = ServerState.CLOSED
__metadata: dict[str, t.Any] | None = None
__metadata: MetadataOutput | None = None
__config: dict[str, t.Any] | None = None
def __repr__(self):
return f'<AsyncHTTPClient timeout={self._timeout} api_version={self._api_version} verify={self._verify} state={self._state}>'
return f'<AsyncHTTPClient address={self.address} timeout={self._timeout} api_version={self._api_version} verify={self._verify}>'
@staticmethod
async def wait_until_server_ready(addr, timeout=30, verify=False, check_interval=1, **client_args):
addr = _address_converter(addr)
logger.debug('Wait for server @ %s to be ready', addr)
start = time.monotonic()
while time.monotonic() - start < timeout:
try:
async with httpx.AsyncClient(base_url=addr, verify=verify, **client_args) as sess:
status = (await sess.get('/readyz')).status_code
if status == 200:
break
else:
await asyncio.sleep(check_interval)
except (httpx.ConnectError, urllib.error.URLError, ConnectionError):
logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval)
await asyncio.sleep(check_interval)
# Try once more and raise for exception
try:
async with httpx.AsyncClient(base_url=addr, verify=verify, **client_args) as sess:
status = (await sess.get('/readyz')).status_code
except httpx.HTTPStatusError as err:
logger.error('Failed to wait until server ready: %s', addr)
logger.error(err)
raise
def __init__(self, address=None, timeout=30, verify=False, api_version='v1', **client_args):
def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'):
if address is None:
env = os.getenv('OPENLLM_ENDPOINT')
if env is None:
raise ValueError('address must be provided')
address = env
self.__attrs_init__(
address,
client_args,
httpx.AsyncClient(base_url=address, timeout=timeout, verify=verify, **client_args),
timeout,
api_version,
verify,
)
address = os.getenv('OPENLLM_ENDPOINT')
if address is None:
raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'")
self._api_version, self._verify = api_version, verify
super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries)
async def _metadata(self) -> dict[str, t.Any]:
if self.__metadata is None:
self.__metadata = (await self._inner.post(self._build_endpoint('metadata'))).json()
return self.__metadata
async def _config(self) -> dict[str, t.Any]:
if self.__config is None:
config = orjson.loads((await self._metadata())['configuration'])
generation_config = config.pop('generation_config')
self.__config = {**config, **generation_config}
return self.__config
def _build_endpoint(self, endpoint):
return '/' + f'{self._api_version}/{endpoint}'
def _build_auth_headers(self) -> t.Dict[str, str]:
env = os.getenv('OPENLLM_AUTH_TOKEN')
if env is not None:
return {'Authorization': f'Bearer {env}'}
return super()._build_auth_headers()
@property
def is_ready(self):
return self._state == ServerState.READY
def _loop(self) -> asyncio.AbstractEventLoop:
try:
return asyncio.get_running_loop()
except RuntimeError:
return asyncio.get_event_loop()
@property
async def _metadata(self) -> t.Awaitable[MetadataOutput]:
if self.__metadata is None:
self.__metadata = await self._post(
f'/{self._api_version}/metadata',
response_cls=MetadataOutput,
json={},
options={'max_retries': self._max_retries},
)
return self.__metadata
@property
async def _config(self):
if self.__config is None:
self.__config = (await self._metadata).configuration
return self.__config
async def query(self, prompt, **attrs):
return await self.generate(prompt, **attrs)
async def health(self):
try:
await self.wait_until_server_ready(self.address, timeout=self._timeout, verify=self._verify, **self.client_args)
_object_setattr(self, '_state', ServerState.READY)
except Exception as e:
logger.error('Server is not healthy (Scroll up for traceback)\n%s', e)
_object_setattr(self, '_state', ServerState.CLOSED)
response = await self._get(
'/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
)
return response.status_code == 200
async def generate(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> Response:
if not self.is_ready:
await self.health()
if not self.is_ready:
raise RuntimeError('Server is not ready. Check server logs for more information.')
if timeout is None:
timeout = self._timeout
if verify is None:
verify = self._verify
_meta, _config = await self._metadata(), await self._config()
verify = self._verify # XXX: need to support this again
_metadata = await self._metadata
_config = await self._config
if llm_config is not None:
llm_config = {**_config, **llm_config, **attrs}
else:
llm_config = {**_config, **attrs}
if _meta['prompt_template'] is not None:
prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
async with httpx.AsyncClient(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
r = await client.post(self._build_endpoint('generate'), json=req.model_dump_json(), **self.client_args)
if r.status_code != 200:
raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.")
return Response.model_construct(r.json())
if _metadata.prompt_template is not None:
prompt = _metadata.prompt_template.format(system_message=_metadata.system_message, instruction=prompt)
return await self._post(
f'/{self._api_version}/generate',
response_cls=Response,
json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
)
async def generate_stream(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.AsyncGenerator[StreamingResponse, t.Any]:
if not self.is_ready:
await self.health()
if not self.is_ready:
raise RuntimeError('Server is not ready. Check server logs for more information.')
async for response_chunk in self.generate_iterator(
prompt, llm_config, stop, adapter_name, timeout, verify, **attrs
):
yield StreamingResponse.from_response_chunk(response_chunk)
async def generate_iterator(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.AsyncGenerator[Response, t.Any]:
if timeout is None:
timeout = self._timeout
if verify is None:
verify = self._verify
_meta, _config = await self._metadata(), await self._config()
verify = self._verify # XXX: need to support this again
_metadata = await self._metadata
_config = await self._config
if llm_config is not None:
llm_config = {**_config, **llm_config, **attrs}
else:
llm_config = {**_config, **attrs}
if _meta['prompt_template'] is not None:
prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
if _metadata.prompt_template is not None:
prompt = _metadata.prompt_template.format(system_message=_metadata.system_message, instruction=prompt)
req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
async with httpx.AsyncClient(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
async with client.stream(
'POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args
) as r:
async for payload in r.aiter_bytes():
if payload == b'data: [DONE]\n\n':
break
# Skip line
if payload == b'\n':
continue
if payload.startswith(b'data: '):
try:
proc = payload.decode('utf-8').lstrip('data: ').rstrip('\n')
data = orjson.loads(proc)
yield StreamingResponse.from_response_chunk(Response.model_construct(data))
except Exception:
pass # FIXME: Handle this
async for response_chunk in await self._post(
f'/{self._api_version}/generate_stream',
response_cls=Response,
json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
stream=True,
):
yield response_chunk

View File

@@ -3,9 +3,47 @@ import typing as t
import attr
import cattr
import orjson
from ._utils import converter
# XXX: sync with openllm-core/src/openllm_core/_schemas.py
@attr.define
class MetadataOutput:
  """Typed view of the server's ``/v1/metadata`` response payload."""

  model_id: str  # model identifier reported by the server
  timeout: int  # server-side timeout, in seconds — TODO confirm unit against server schema
  model_name: str  # human-readable model name
  backend: str  # inference backend the server is running (e.g. the runtime implementation)
  configuration: t.Dict[str, t.Any]  # flattened configuration dict (decoded from a JSON string by the structure hook)
  prompt_template: t.Optional[str]  # optional template with {system_message}/{instruction} placeholders used to wrap prompts
  system_message: t.Optional[str]  # optional default system message substituted into prompt_template
def _structure_metadata(data: t.Dict[str, t.Any], cls: type[MetadataOutput]) -> MetadataOutput:
  """cattrs structure hook turning the raw ``/v1/metadata`` payload into ``MetadataOutput``.

  The server sends ``configuration`` as a JSON string whose ``generation_config``
  sub-object is flattened into the top-level configuration dict.

  Args:
    data: raw JSON payload from the metadata endpoint.
    cls: target class (``MetadataOutput``), supplied by cattrs.

  Raises:
    RuntimeError: if the payload is malformed in any way (server-side issue).
  """
  try:
    configuration = orjson.loads(data['configuration'])
    generation_config = configuration.pop('generation_config')
    configuration = {**configuration, **generation_config}
  except (orjson.JSONDecodeError, KeyError) as e:
    # KeyError covers a payload missing 'configuration' or a configuration blob
    # missing 'generation_config' — previously those escaped as raw KeyError
    # instead of the intended RuntimeError.
    raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None
  try:
    return cls(
      model_id=data['model_id'],
      timeout=data['timeout'],
      model_name=data['model_name'],
      backend=data['backend'],
      configuration=configuration,
      prompt_template=data['prompt_template'],
      system_message=data['system_message'],
    )
  except Exception as e:
    raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None
converter.register_structure_hook(MetadataOutput, _structure_metadata)
@attr.define
class Request:
prompt: str

View File

@@ -0,0 +1,610 @@
# This provides a base shim with httpx and acts as base request
from __future__ import annotations
import email.utils
import logging
import platform
import random
import time
import typing as t
import anyio
import attr
import distro
import httpx
from ._stream import AsyncStream
from ._stream import Response
from ._stream import Stream
from ._typing_compat import Architecture
from ._typing_compat import LiteralString
from ._typing_compat import Platform
from ._utils import converter
logger = logging.getLogger(__name__)
InnerClient = t.TypeVar('InnerClient', bound=t.Union[httpx.Client, httpx.AsyncClient])
StreamType = t.TypeVar('StreamType', bound=t.Union[Stream[t.Any], AsyncStream[t.Any]])
_Stream = t.TypeVar('_Stream', bound=Stream[t.Any])
_AsyncStream = t.TypeVar('_AsyncStream', bound=AsyncStream[t.Any])
LiteralVersion = t.Annotated[LiteralString, t.Literal['v1'], str]
def _address_converter(addr: str | httpx.URL) -> httpx.URL:
  """Normalise ``addr`` into an ``httpx.URL`` suitable as a client base URL.

  Bare ``host:port`` strings are assumed to be plain HTTP, and the path is given
  a trailing slash so relative request paths merge under it correctly.
  """
  if isinstance(addr, httpx.URL):
    url = addr
  else:
    url = httpx.URL(addr if '://' in addr else f'http://{addr}')
  if not url.raw_path.endswith(b'/'):
    # BUG FIX: raw_path is bytes, so it must go back through the bytes-typed
    # 'raw_path' component; 'path=' expects str and rejects a bytes value.
    url = url.copy_with(raw_path=url.raw_path + b'/')
  return url
MAX_RETRIES = 2
DEFAULT_TIMEOUT = httpx.Timeout(5.0) # Similar to httpx
_T_co = t.TypeVar('_T_co', covariant=True)
_T = t.TypeVar('_T')
def _merge_mapping(a: t.Mapping[_T_co, _T], b: t.Mapping[_T_co, _T]) -> t.Dict[_T_co, _T]:
# does the merging and filter out None
return {k: v for k, v in {**a, **b}.items() if v is not None}
def _platform() -> Platform:
  """Classify the host operating system for the client telemetry headers."""
  system = platform.system().lower()
  platform_name = platform.platform().lower()
  if system == 'darwin':
    return 'MacOS'
  if system == 'windows':
    return 'Windows'
  if system == 'linux':
    # distro narrows down BSD flavours on this branch — NOTE(review): verify
    # these ever report 'linux' via platform.system().
    distro_id = distro.id()
    if distro_id == 'freebsd':
      return 'FreeBSD'
    if distro_id == 'openbsd':
      return 'OpenBSD'
    return 'Linux'
  if 'android' in platform_name:
    return 'Android'
  if 'iphone' in platform_name:
    return 'iOS'
  if 'ipad' in platform_name:
    return 'iPadOS'
  return f'Other:{platform_name}' if platform_name else 'Unknown'
def _architecture() -> Architecture:
machine = platform.machine().lower()
if machine in {'arm64', 'aarch64'}:
return 'arm64'
elif machine in {'arm', 'aarch32'}:
return 'arm'
elif machine in {'x86_64', 'amd64'}:
return 'x86_64'
elif machine in {'x86', 'i386', 'i686'}:
return 'x86'
elif machine:
return f'Other:{machine}'
return 'Unknown'
@t.final
@attr.frozen(auto_attribs=True)
class RequestOptions:
  """Immutable description of a single outgoing HTTP request."""

  method: str = attr.field(converter=str.lower)
  url: str
  json: t.Optional[t.Dict[str, t.Any]] = attr.field(default=None)
  params: t.Optional[t.Mapping[str, t.Any]] = attr.field(default=None)
  headers: t.Optional[t.Dict[str, str]] = attr.field(default=None)
  max_retries: int = attr.field(default=MAX_RETRIES)
  return_raw_response: bool = attr.field(default=False)

  def get_max_retries(self, max_retries: int | None) -> int:
    """Return the caller-supplied retry budget, or this request's default when None."""
    if max_retries is None:
      return self.max_retries
    return max_retries

  @classmethod
  def model_construct(cls, **options: t.Any) -> RequestOptions:
    """Alternate constructor mirroring the pydantic-style ``model_construct`` API."""
    return cls(**options)
@t.final
@attr.frozen(auto_attribs=True)
class APIResponse(t.Generic[Response]):
  """Wrapper around an ``httpx.Response`` that can parse it into a ``Response`` object."""

  _raw_response: httpx.Response
  _client: BaseClient[t.Any, t.Any]  # client that issued the request; supplies parsing hooks
  _response_cls: type[Response]  # target type for deserialisation
  _stream: bool  # whether the body should be wrapped in a stream rather than parsed eagerly
  _stream_cls: t.Optional[t.Union[t.Type[Stream[t.Any]], t.Type[AsyncStream[t.Any]]]]  # overrides the client's default stream wrapper
  _options: RequestOptions
  _parsed: t.Optional[Response] = attr.field(default=None, repr=False)  # memoised parse result, if already available

  def parse(self):
    """Materialise this response.

    Returns the raw ``httpx.Response`` when ``return_raw_response`` was requested,
    a stream wrapper when streaming, the plain text for non-JSON bodies, or the
    structured ``Response`` object otherwise.

    Raises:
      ValueError: if the client's response-processing hook fails (validation error).
    """
    if self._options.return_raw_response:
      return self._raw_response
    if self._parsed is not None:
      return self._parsed
    if self._stream:
      # Fall back to the client's default stream class when no override was given.
      stream_cls = self._stream_cls or self._client._default_stream_cls
      return stream_cls(response_cls=self._response_cls, response=self._raw_response, client=self._client)
    # Strip any parameters (e.g. '; charset=utf-8') off the content-type header.
    content_type, *_ = self._raw_response.headers.get('content-type', '').split(';')
    if content_type != 'application/json':
      # Non-JSON content type means the caller asked for a different
      # representation, so return the raw text without any deserialisation.
      return self._raw_response.text
    data = self._raw_response.json()
    try:
      return self._client._process_response_data(
        data=data, response_cls=self._response_cls, raw_response=self._raw_response
      )
    except Exception as exc:
      raise ValueError(exc) from None  # validation error here

  # The properties below simply delegate to the underlying httpx.Response.
  @property
  def headers(self):
    return self._raw_response.headers
  @property
  def status_code(self):
    return self._raw_response.status_code
  @property
  def request(self):
    return self._raw_response.request
  @property
  def url(self):
    return self._raw_response.url
  @property
  def content(self):
    return self._raw_response.content
  @property
  def text(self):
    return self._raw_response.text
  @property
  def http_version(self):
    return self._raw_response.http_version
@attr.define(init=False)
class BaseClient(t.Generic[InnerClient, StreamType]):
  """Shared plumbing for the synchronous and asynchronous OpenLLM HTTP clients.

  Owns the base URL, API version, timeout and retry policy together with the
  underlying ``httpx`` client instance, and provides request construction,
  default/auth headers and retry backoff calculation for the concrete
  ``Client``/``AsyncClient`` subclasses.
  """

  _base_url: httpx.URL = attr.field(converter=_address_converter)
  _version: LiteralVersion
  _timeout: httpx.Timeout = attr.field(converter=httpx.Timeout)
  _max_retries: int
  _inner: InnerClient  # the wrapped httpx.Client / httpx.AsyncClient
  _default_stream_cls: t.Type[StreamType]
  _auth_headers: t.Dict[str, str] = attr.field(init=False)

  def __init__(
    self,
    *,
    base_url: str | httpx.URL,
    version: str,
    timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
    max_retries: int = MAX_RETRIES,
    client: InnerClient,
    _default_stream_cls: t.Type[StreamType],
    _internal: bool = False,
  ):
    if not _internal:
      raise RuntimeError('Client is reserved to be used internally only.')
    self.__attrs_init__(base_url, version, timeout, max_retries, client, _default_stream_cls)
    # _auth_headers is declared with init=False, so populate it after __attrs_init__.
    self._auth_headers = self._build_auth_headers()

  def _prepare_url(self, url: str) -> httpx.URL:
    """Resolve *url* against the configured base URL (copied from httpx._merge_url)."""
    merge_url = httpx.URL(url)
    if merge_url.is_relative_url:
      merge_raw = self._base_url.raw_path + merge_url.raw_path.lstrip(b'/')
      return self._base_url.copy_with(raw_path=merge_raw)
    return merge_url

  @property
  def is_closed(self):
    return self._inner.is_closed

  @property
  def is_ready(self):
    return not self.is_closed  # backward compat

  @property
  def base_url(self):
    return self._base_url

  @property
  def address(self):
    return str(self.base_url)

  @base_url.setter
  def base_url(self, url):
    self._base_url = url if isinstance(url, httpx.URL) else httpx.URL(url)

  def _build_auth_headers(self) -> t.Dict[str, str]:
    """Authentication headers; subclasses override this to provide auth support."""
    return {}

  @property
  def auth(self) -> httpx.Auth | None:
    return None

  @property
  def user_agent(self):
    return f'{self.__class__.__name__}/Python {self._version}'

  @property
  def auth_headers(self):
    return self._auth_headers

  @property
  def _default_headers(self) -> t.Dict[str, str]:
    return {
      'Content-Type': 'application/json',
      'Accept': 'application/json',
      'User-Agent': self.user_agent,
      **self.platform_headers,
      **self.auth_headers,
    }

  @property
  def platform_headers(self):
    """Telemetry headers describing the client runtime and host platform."""
    return {
      'X-OpenLLM-Client-Package-Version': self._version,
      'X-OpenLLM-Client-Language': 'Python',
      'X-OpenLLM-Client-Runtime': platform.python_implementation(),
      'X-OpenLLM-Client-Runtime-Version': platform.python_version(),
      'X-OpenLLM-Client-Arch': str(_architecture()),
      'X-OpenLLM-Client-OS': str(_platform()),
    }

  def _remaining_retries(self, remaining_retries: int | None, options: RequestOptions) -> int:
    """Number of retries left: an explicit count wins over the per-request/client default."""
    return remaining_retries if remaining_retries is not None else options.get_max_retries(self._max_retries)

  def _build_headers(self, options: RequestOptions) -> httpx.Headers:
    # Per-request headers take precedence over the client defaults.
    return httpx.Headers(_merge_mapping(self._default_headers, options.headers or {}))

  def _build_request(self, options: RequestOptions) -> httpx.Request:
    """Materialise an ``httpx.Request`` from the given request options."""
    return self._inner.build_request(
      method=options.method,
      headers=self._build_headers(options),
      url=self._prepare_url(options.url),
      json=options.json,
      params=options.params,
    )

  def _calculate_retry_timeout(
    self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None
  ) -> float:
    """Seconds to sleep before the next retry attempt.

    Honours a reasonable server-provided ``Retry-After`` header (either integer
    seconds or an HTTP-date), otherwise falls back to jittered exponential backoff.
    """
    max_retries = options.get_max_retries(self._max_retries)
    # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
    try:
      if headers is not None:
        retry_header = headers.get('retry-after')
        try:
          retry_after = int(retry_header)
        except ValueError:
          tup = email.utils.parsedate_tz(retry_header)
          if tup is None:
            retry_after = -1
          else:
            retry_after = int(email.utils.mktime_tz(tup) - time.time())
      else:
        retry_after = -1
    except Exception:
      # omit everything -- a malformed/missing header must never break retry logic
      retry_after = -1
    if 0 < retry_after <= 60:
      return retry_after  # this is reasonable from users
    initial_delay, max_delay = 0.5, 8.0
    num_retries = max_retries - remaining_retries
    sleep = min(initial_delay * pow(2.0, num_retries), max_delay)  # apply exponential backoff here
    timeout = sleep * (1 - 0.25 * random.random())  # jitter to avoid synchronized retries
    return timeout if timeout >= 0 else 0

  def _should_retry(self, response: httpx.Response) -> bool:
    """Whether a failed request may be retried.

    An explicit ``x-should-retry`` header from the server always wins; otherwise
    retry on 408 (timeout), 409 (conflict), 429 (rate limited) and any 5xx status.
    """
    should_retry_header = response.headers.get('x-should-retry')
    # FIX: the header is absent in the common case and ``headers.get`` then
    # returns None; the previous code unconditionally called ``.lower()`` and
    # raised AttributeError instead of falling through to the status-code checks.
    if should_retry_header is not None:
      if should_retry_header.lower() == 'true':
        return True
      if should_retry_header.lower() == 'false':
        return False
    if response.status_code in {408, 409, 429}:
      return True
    if response.status_code >= 500:
      return True
    return False

  def _process_response_data(
    self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response
  ) -> Response:
    """Structure a decoded JSON payload into the attrs response class."""
    return converter.structure(data, response_cls)

  def _process_response(
    self,
    *,
    response_cls: type[Response],
    options: RequestOptions,
    raw_response: httpx.Response,
    stream: bool,
    stream_cls: type[_Stream] | type[_AsyncStream] | None,
  ) -> Response:
    """Wrap a raw httpx response into an ``APIResponse`` and parse it."""
    return APIResponse(
      raw_response=raw_response,
      client=self,
      response_cls=response_cls,
      stream=stream,
      stream_cls=stream_cls,
      options=options,
    ).parse()
@attr.define(init=False)
class Client(BaseClient[httpx.Client, Stream[t.Any]]):
  """Synchronous OpenLLM HTTP client backed by ``httpx.Client``.

  Supports use as a context manager; ``close()`` releases the underlying
  connection pool.
  """

  def __init__(
    self,
    base_url: str | httpx.URL,
    version: str,
    timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
    max_retries: int = MAX_RETRIES,
  ):
    super().__init__(
      base_url=base_url,
      version=version,
      timeout=timeout,
      max_retries=max_retries,
      client=httpx.Client(base_url=base_url, timeout=timeout),
      _default_stream_cls=Stream,
      _internal=True,
    )

  def close(self):
    """Close the underlying httpx client and its connection pool."""
    self._inner.close()

  def __enter__(self):
    return self

  def __exit__(self, *args) -> None:
    self.close()

  def request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: t.Optional[int] = None,
    *,
    stream: bool = False,
    stream_cls: type[_Stream] | None = None,
  ) -> Response | _Stream:
    """Public entrypoint: send a request, retrying transparently on transient failures."""
    return self._request(
      response_cls=response_cls,
      options=options,
      remaining_retries=remaining_retries,
      stream=stream,
      stream_cls=stream_cls,
    )

  def _request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: int | None = None,
    *,
    stream: bool = False,
    stream_cls: type[_Stream] | None = None,
  ) -> Response | _Stream:
    retries = self._remaining_retries(remaining_retries, options)
    request = self._build_request(options)
    try:
      response = self._inner.send(request, auth=self.auth, stream=stream)
      logger.debug('HTTP [%s, %s]: %s [%i]', request.method, request.url, response.status_code, response.reason_phrase)
      response.raise_for_status()
    except httpx.HTTPStatusError as exc:
      if retries > 0 and self._should_retry(exc.response):
        return self._retry_request(
          response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
        )
      # If the response is streamed then we need to explicitly read the completed response
      exc.response.read()
      # FIX: ``httpx.HTTPStatusError`` has no ``.message`` attribute (the message
      # only lives in ``args``), so ``exc.message`` raised AttributeError here.
      raise ValueError(str(exc)) from None
    except httpx.TimeoutException:
      if retries > 0:
        return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from None  # timeout
    except Exception:
      if retries > 0:
        return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from None  # connection error
    return self._process_response(
      response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
    )

  def _retry_request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: int,
    response_headers: httpx.Headers | None = None,
    *,
    stream: bool = False,
    stream_cls: type[_Stream] | None,
  ) -> Response | _Stream:
    """Sleep for the computed backoff, then re-issue the request with one fewer retry."""
    remaining = remaining_retries - 1
    timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
    logger.info('Retrying request to %s in %f seconds', options.url, timeout)
    # In synchronous thread we are blocking the thread. Depends on how users want to do this downstream.
    time.sleep(timeout)
    return self._request(response_cls, options, remaining, stream=stream, stream_cls=stream_cls)

  def _get(
    self,
    path: str,
    *,
    response_cls: type[Response],
    options: dict[str, t.Any] | None = None,
    stream: bool = False,
    stream_cls: type[_Stream] | None = None,
  ) -> Response | _Stream:
    """Convenience wrapper for GET requests."""
    if options is None:
      options = {}
    return self.request(
      response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
    )

  def _post(
    self,
    path: str,
    *,
    response_cls: type[Response],
    json: dict[str, t.Any],
    options: dict[str, t.Any] | None = None,
    stream: bool = False,
    stream_cls: type[_Stream] | None = None,
  ) -> Response | _Stream:
    """Convenience wrapper for POST requests carrying a JSON body."""
    if options is None:
      options = {}
    return self.request(
      response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
    )
@attr.define(init=False)
class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
  """Asynchronous OpenLLM HTTP client backed by ``httpx.AsyncClient``.

  Supports use as an async context manager; ``close()`` releases the underlying
  connection pool.
  """

  def __init__(
    self,
    base_url: str | httpx.URL,
    version: str,
    timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
    max_retries: int = MAX_RETRIES,
  ):
    super().__init__(
      base_url=base_url,
      version=version,
      timeout=timeout,
      max_retries=max_retries,
      client=httpx.AsyncClient(base_url=base_url, timeout=timeout),
      _default_stream_cls=AsyncStream,
      _internal=True,
    )

  async def close(self):
    """Close the underlying httpx client and its connection pool."""
    await self._inner.aclose()

  async def __aenter__(self):
    return self

  async def __aexit__(self, *args) -> None:
    await self.close()

  async def request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: t.Optional[int] = None,
    *,
    stream: bool = False,
    stream_cls: type[_AsyncStream] | None = None,
  ) -> Response | _AsyncStream:
    """Public entrypoint: send a request, retrying transparently on transient failures."""
    return await self._request(
      response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls
    )

  async def _request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: int | None = None,
    *,
    stream: bool = False,
    stream_cls: type[_AsyncStream] | None = None,
  ) -> Response | _AsyncStream:
    retries = self._remaining_retries(remaining_retries, options)
    request = self._build_request(options)
    try:
      response = await self._inner.send(request, auth=self.auth, stream=stream)
      logger.debug('HTTP [%s, %s]: %s [%i]', request.method, request.url, response.status_code, response.reason_phrase)
      response.raise_for_status()
    except httpx.HTTPStatusError as exc:
      if retries > 0 and self._should_retry(exc.response):
        # FIX: this call previously lacked ``await``, so callers received a bare
        # coroutine object instead of the retried response.
        return await self._retry_request(
          response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
        )
      # If the response is streamed then we need to explicitly read the completed response
      await exc.response.aread()
      # FIX: ``httpx.HTTPStatusError`` has no ``.message`` attribute (the message
      # only lives in ``args``), so ``exc.message`` raised AttributeError here.
      raise ValueError(str(exc)) from None
    except httpx.ConnectTimeout as err:
      if retries > 0:
        return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from err  # timeout
    except httpx.ReadTimeout:
      # We don't retry on ReadTimeout error, so something might happen on server-side
      raise
    except httpx.TimeoutException as err:
      if retries > 0:
        return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from err  # timeout
    except Exception as err:
      if retries > 0:
        return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
      raise ValueError(request) from err  # connection error
    return self._process_response(
      response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
    )

  async def _retry_request(
    self,
    response_cls: type[Response],
    options: RequestOptions,
    remaining_retries: int,
    response_headers: httpx.Headers | None = None,
    *,
    stream: bool,
    stream_cls: type[_AsyncStream] | None,
  ):
    """Sleep for the computed backoff (without blocking the loop), then re-issue."""
    remaining = remaining_retries - 1
    timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
    logger.info('Retrying request to %s in %f seconds', options.url, timeout)
    await anyio.sleep(timeout)
    return await self._request(response_cls, options, remaining, stream=stream, stream_cls=stream_cls)

  async def _get(
    self,
    path: str,
    *,
    response_cls: type[Response],
    options: dict[str, t.Any] | None = None,
    stream: bool = False,
    stream_cls: type[_AsyncStream] | None = None,
  ) -> Response | _AsyncStream:
    """Convenience wrapper for GET requests."""
    if options is None:
      options = {}
    return await self.request(
      response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
    )

  async def _post(
    self,
    path: str,
    *,
    response_cls: type[Response],
    json: dict[str, t.Any],
    options: dict[str, t.Any] | None = None,
    stream: bool = False,
    stream_cls: type[_AsyncStream] | None = None,
  ) -> Response | _AsyncStream:
    """Convenience wrapper for POST requests carrying a JSON body."""
    if options is None:
      options = {}
    return await self.request(
      response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
    )

View File

@@ -0,0 +1,143 @@
from __future__ import annotations
import typing as t
import attr
import httpx
import orjson
if t.TYPE_CHECKING:
from ._shim import AsyncClient
from ._shim import Client
Response = t.TypeVar('Response', bound=attr.AttrsInstance)
@attr.define(auto_attribs=True)
class Stream(t.Generic[Response]):
  """Synchronous iterator over a server-sent-events response.

  Wraps a streaming ``httpx.Response``: each SSE ``data`` payload is decoded as
  JSON and structured into ``response_cls`` via the owning client. Iteration
  ends when the server sends the ``[DONE]`` sentinel.
  """

  _response_cls: t.Type[Response]  # attrs class each event payload is structured into
  _response: httpx.Response  # the live streaming HTTP response
  _client: Client  # owning client, provides _process_response_data
  _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder())
  _iterator: t.Iterator[Response] = attr.field(init=False)

  def __init__(self, response_cls: t.Type[Response], response: httpx.Response, client: Client) -> None:
    self.__attrs_init__(response_cls, response, client)
    # _iterator is init=False: build the lazy generator after attrs initialisation.
    self._iterator = self._stream()

  def __next__(self) -> Response:
    return self._iterator.__next__()

  def __iter__(self) -> t.Iterator[Response]:
    for item in self._iterator:
      yield item

  def _iter_events(self) -> t.Iterator[SSE]:
    """Decode the raw response lines into SSE events."""
    yield from self._decoder.iter(self._response.iter_lines())

  def _stream(self) -> t.Iterator[Response]:
    """Yield structured responses until the '[DONE]' sentinel is seen."""
    for sse in self._iter_events():
      if sse.data.startswith('[DONE]'):
        break
      # Only un-typed data events (no explicit `event:` field) carry payloads.
      if sse.event is None:
        yield self._client._process_response_data(
          data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
        )
@attr.define(auto_attribs=True)
class AsyncStream(t.Generic[Response]):
  """Asynchronous counterpart of ``Stream`` for server-sent-events responses.

  Wraps a streaming ``httpx.Response``: each SSE ``data`` payload is decoded as
  JSON and structured into ``response_cls`` via the owning client. Iteration
  ends when the server sends the ``[DONE]`` sentinel.
  """

  _response_cls: t.Type[Response]  # attrs class each event payload is structured into
  _response: httpx.Response  # the live streaming HTTP response
  _client: AsyncClient  # owning client, provides _process_response_data
  _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder())
  _iterator: t.AsyncIterator[Response] = attr.field(init=False)

  def __init__(self, response_cls: t.Type[Response], response: httpx.Response, client: AsyncClient) -> None:
    self.__attrs_init__(response_cls, response, client)
    # _iterator is init=False: build the lazy async generator after attrs initialisation.
    self._iterator = self._stream()

  async def __anext__(self) -> Response:
    return await self._iterator.__anext__()

  async def __aiter__(self):
    # NOTE: defined as an async generator; `async for` over this object works
    # because calling __aiter__() returns an async iterator.
    async for item in self._iterator:
      yield item

  async def _iter_events(self):
    """Decode the raw response lines into SSE events."""
    async for sse in self._decoder.aiter(self._response.aiter_lines()):
      yield sse

  async def _stream(self) -> t.AsyncGenerator[Response, None]:
    """Yield structured responses until the '[DONE]' sentinel is seen."""
    async for sse in self._iter_events():
      if sse.data.startswith('[DONE]'):
        break
      # Only un-typed data events (no explicit `event:` field) carry payloads.
      if sse.event is None:
        yield self._client._process_response_data(
          data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
        )
@attr.define
class SSE:
  """A single parsed server-sent event.

  Fields mirror the SSE wire format: the accumulated ``data`` payload, the
  optional ``id``, ``event`` type and ``retry`` interval.
  """

  data: str = attr.field(default='')
  id: t.Optional[str] = attr.field(default=None)
  event: t.Optional[str] = attr.field(default=None)
  retry: t.Optional[int] = attr.field(default=None)

  def model_dump(self) -> t.Dict[str, t.Any]:
    """Deserialise the JSON payload carried in ``data``.

    Raises ``orjson.JSONDecodeError`` if ``data`` is not valid JSON. The
    previous try/except only re-raised that error, so it was removed.
    """
    return orjson.loads(self.data)
@attr.define(auto_attribs=True)
class SSEDecoder:
  """Incremental decoder for the server-sent-events wire format.

  Feed lines one at a time via ``decode``; an event is dispatched when a blank
  line is seen, per the WHATWG event-stream interpretation rules.
  """

  _data: t.List[str] = attr.field(factory=list)  # accumulated `data:` lines of the current event
  _event: t.Optional[str] = None  # pending `event:` field, if any
  _retry: t.Optional[int] = None  # pending `retry:` field, if any
  _last_event_id: t.Optional[str] = None  # persists across events per the SSE spec

  def iter(self, iterator: t.Iterator[str]) -> t.Iterator[SSE]:
    """Decode a synchronous line iterator, yielding completed events."""
    for line in iterator:
      sse = self.decode(line.rstrip('\n'))
      if sse:
        yield sse

  async def aiter(self, iterator: t.AsyncIterator[str]) -> t.AsyncIterator[SSE]:
    """Decode an asynchronous line iterator, yielding completed events."""
    async for line in iterator:
      sse = self.decode(line.rstrip('\n'))
      if sse:
        yield sse

  def decode(self, line: str) -> SSE | None:
    """Consume one line; return a completed SSE on the blank-line dispatch, else None."""
    # NOTE: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
    if not line:
      # Blank line: dispatch the accumulated event (unless nothing accumulated).
      if all(not a for a in [self._event, self._data, self._retry, self._last_event_id]):
        return None
      sse = SSE(data='\n'.join(self._data), event=self._event, retry=self._retry, id=self._last_event_id)
      # _last_event_id intentionally survives the reset, per the spec.
      self._event, self._data, self._retry = None, [], None
      return sse
    if line.startswith(':'):
      # Comment line -- ignored by the protocol.
      return None
    field, _, value = line.partition(':')
    # A single leading space after the colon is stripped, per the spec.
    if value.startswith(' '):
      value = value[1:]
    if field == 'event':
      self._event = value
    elif field == 'data':
      self._data.append(value)
    elif field == 'id':
      # IDs containing NUL are ignored, per the spec.
      if '\0' in value:
        pass
      else:
        self._last_event_id = value
    elif field == 'retry':
      try:
        self._retry = int(value)
      except (TypeError, ValueError):
        pass
    else:
      pass  # Ignore unknown fields
    return None

View File

@@ -0,0 +1,29 @@
import sys
from typing import Literal
if sys.version_info[:2] >= (3, 11):
from typing import LiteralString as LiteralString
from typing import NotRequired as NotRequired
from typing import Required as Required
from typing import Self as Self
from typing import dataclass_transform as dataclass_transform
from typing import overload as overload
else:
from typing_extensions import LiteralString as LiteralString
from typing_extensions import NotRequired as NotRequired
from typing_extensions import Required as Required
from typing_extensions import Self as Self
from typing_extensions import dataclass_transform as dataclass_transform
from typing_extensions import overload as overload
if sys.version_info[:2] >= (3, 9):
from typing import Annotated as Annotated
else:
from typing_extensions import Annotated as Annotated
# Annotated[LiteralString, Literal[...], str] documents the known values while
# still accepting any string -- detected values may fall outside the enumerated set.
Platform = Annotated[
  LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str
]
Architecture = Annotated[LiteralString, Literal['arm', 'arm64', 'x86', 'x86_64', 'Unknown'], str]

View File

@@ -0,0 +1,6 @@
from __future__ import annotations
from cattr import Converter
converter = Converter(omit_if_default=True)

View File

@@ -108,7 +108,7 @@ class Conversation:
elif self.sep_style == SeparatorStyle.LLAMA:
seps = [self.sep, self.sep2]
ret = system_prompt if self.system_message else '<s>[INST] '
for i, (role, message) in enumerate(self.messages):
for i, (_, message) in enumerate(self.messages):
tag = self.roles[i % 2]
if message:
if i == 0:
@@ -139,12 +139,12 @@ class Conversation:
return ret
elif self.sep_style == SeparatorStyle.MPT:
ret = f'<|im_start|>system\n{system_prompt}<|im_end|>{self.sep}' if system_prompt else ''
for i, (role, message) in enumerate(self.messages):
for _, (role, message) in enumerate(self.messages):
ret += f'<|im_start|>{role}\n{message}<|im_end|>{self.sep}' if message else f'{role}:'
return ret
elif self.sep_style == SeparatorStyle.STARCODER:
ret = f'<|system|>\n{system_prompt}<|end|>{self.sep}' if system_prompt else ''
for i, (role, message) in enumerate(self.messages):
for _, (role, message) in enumerate(self.messages):
ret += f'{role}\n{message}<|end|>{self.sep}' if message else f'{role}:'
else:
raise ValueError(f'Invalid style: {self.sep_style}')

View File

@@ -16,26 +16,45 @@ from .utils import gen_random_uuid
if t.TYPE_CHECKING:
import vllm
from ._typing_compat import Self
@attr.frozen(slots=True)
class MetadataOutput:
model_id: str
timeout: int
model_name: str
backend: str
configuration: str
prompt_template: str
system_message: str
@attr.define
class _SchemaMixin:
def model_dump(self) -> dict[str, t.Any]:
return converter.unstructure(self)
def model_dump_json(self) -> str:
return orjson.dumps(self.model_dump(), option=orjson.OPT_INDENT_2).decode('utf-8')
def with_options(self, **options: t.Any) -> Self:
return attr.evolve(self, **options)
@attr.define(slots=True, frozen=True)
class GenerationInput:
@attr.define
class MetadataOutput(_SchemaMixin):
model_id: str
timeout: int
model_name: str
backend: str
configuration: str
prompt_template: t.Optional[str]
system_message: t.Optional[str]
def model_dump(self) -> dict[str, t.Any]:
return {
'model_id': self.model_id,
'timeout': self.timeout,
'model_name': self.model_name,
'backend': self.backend,
'configuration': self.configuration,
'prompt_template': orjson.dumps(self.prompt_template).decode(),
'system_message': orjson.dumps(self.system_message).decode(),
}
@attr.define
class GenerationInput(_SchemaMixin):
prompt: str
llm_config: LLMConfig
stop: list[str] | None = attr.field(default=None)
@@ -53,9 +72,6 @@ class GenerationInput:
'adapter_name': self.adapter_name,
}
def model_dump_json(self) -> str:
return orjson.dumps(self.model_dump(), option=orjson.OPT_INDENT_2).decode('utf-8')
@classmethod
def from_llm_config(cls, llm_config: LLMConfig) -> type[GenerationInput]:
def init(self: GenerationInput, prompt: str, stop: list[str] | None, adapter_name: str | None) -> None:
@@ -80,7 +96,7 @@ class GenerationInput:
def examples(_: type[GenerationInput]) -> dict[str, t.Any]:
return klass(prompt='What is the meaning of life?', llm_config=llm_config, stop=['\n']).model_dump()
setattr(klass, 'examples', classmethod(examples))
klass.examples = classmethod(examples)
try:
klass.__module__ = cls.__module__
@@ -98,7 +114,7 @@ FinishReason = t.Literal['length', 'stop']
@attr.define
class CompletionChunk:
class CompletionChunk(_SchemaMixin):
index: int
text: str
token_ids: t.List[int]
@@ -106,15 +122,12 @@ class CompletionChunk:
logprobs: t.Optional[SampleLogprobs] = None
finish_reason: t.Optional[FinishReason] = None
# yapf: disable
def with_options(self,**options: t.Any)->CompletionChunk: return attr.evolve(self, **options)
def model_dump(self)->dict[str, t.Any]:return converter.unstructure(self)
def model_dump_json(self)->str:return orjson.dumps(self.model_dump(),option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
# yapf: enable
def model_dump_json(self) -> str:
return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
@attr.define
class GenerationOutput:
class GenerationOutput(_SchemaMixin):
prompt: str
finished: bool
outputs: t.List[CompletionChunk]
@@ -182,11 +195,5 @@ class GenerationOutput:
prompt_logprobs=request_output.prompt_logprobs,
)
def with_options(self, **options: t.Any) -> GenerationOutput:
return attr.evolve(self, **options)
def model_dump(self) -> dict[str, t.Any]:
return converter.unstructure(self)
def model_dump_json(self) -> str:
return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8')

View File

@@ -70,6 +70,11 @@ else:
from typing_extensions import TypeAlias as TypeAlias
from typing_extensions import TypeGuard as TypeGuard
if sys.version_info[:2] >= (3, 9):
from typing import Annotated as Annotated
else:
from typing_extensions import Annotated as Annotated
class AdapterTuple(TupleAny):
adapter_id: str

View File

@@ -158,11 +158,11 @@ class AutoConfig:
else:
try:
config_file = llm.bentomodel.path_of(CONFIG_FILE_NAME)
except OpenLLMException:
except OpenLLMException as err:
if not is_transformers_available():
raise MissingDependencyError(
"'infer_class_from_llm' requires 'transformers' to be available. Make sure to install it with 'pip install transformers'"
)
) from err
from transformers.utils import cached_file
try:

View File

@@ -101,7 +101,7 @@ class FineTuneConfig:
from peft.mapping import get_peft_config
from peft.utils.peft_types import TaskType
except ImportError:
raise ImportError('PEFT is not installed. Please install it via `pip install "openllm[fine-tune]"`.')
raise ImportError('PEFT is not installed. Please install it via `pip install "openllm[fine-tune]"`.') from None
adapter_config = self.adapter_config.copy()
# no need for peft_type
if 'peft_type' in adapter_config:

View File

@@ -66,7 +66,7 @@ else:
_import_structure: dict[str, list[str]] = {
'exceptions': [],
'client': [],
'client': ['HTTPClient', 'AsyncHTTPClient'],
'bundle': [],
'playground': [],
'testing': [],
@@ -98,6 +98,8 @@ if _t.TYPE_CHECKING:
from . import serialisation as serialisation
from . import testing as testing
from . import utils as utils
from .client import HTTPClient as HTTPClient
from .client import AsyncHTTPClient as AsyncHTTPClient
from ._deprecated import Runner as Runner
from ._generation import LogitsProcessorList as LogitsProcessorList
from ._generation import StopOnTokens as StopOnTokens

View File

@@ -7,9 +7,6 @@ To start any OpenLLM model:
openllm start <model_name> --options ...
"""
from __future__ import annotations
if __name__ == '__main__':
from openllm.cli.entrypoint import cli

View File

@@ -71,7 +71,7 @@ def is_sentence_complete(output: str) -> bool:
def is_partial_stop(output: str, stop_str: str) -> bool:
"""Check whether the output contains a partial stop str."""
for i in range(0, min(len(output), len(stop_str))):
for i in range(min(len(output), len(stop_str))):
if stop_str.startswith(output[-i:]):
return True
return False

View File

@@ -235,7 +235,7 @@ def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
raise RuntimeError('Failed to initialise CUDA runtime binding.')
# correctly parse handle
for el in val:
if el.startswith('GPU-') or el.startswith('MIG-'):
if el.startswith(('GPU-', 'MIG-')):
uuids = _raw_device_uuid_nvml()
if uuids is None:
raise ValueError('Failed to parse available GPUs UUID')

View File

@@ -580,6 +580,6 @@ def append_schemas(
def mk_generate_spec(svc:bentoml.Service,openapi_version:str=OPENAPI_VERSION)->MKSchema:return MKSchema(svc_schema)
def mk_asdict(self:OpenAPISpecification)->dict[str,t.Any]:return svc_schema
openapi.generate_spec=mk_generate_spec
setattr(OpenAPISpecification, 'asdict', mk_asdict)
OpenAPISpecification.asdict = mk_asdict
# yapf: disable
return svc

View File

@@ -1,30 +1,13 @@
"""Serialisation related implementation for GGML-based implementation.
This requires ctransformers to be installed.
"""
from __future__ import annotations
import typing as t
if t.TYPE_CHECKING:
import bentoml
import openllm
from openllm_core._typing_compat import M
_conversion_strategy = {'pt': 'ggml'}
def import_model(
llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any
) -> bentoml.Model:
def import_model(llm, *decls, trust_remote_code=True, **attrs):
raise NotImplementedError('Currently work in progress.')
def get(llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Model:
def get(llm):
raise NotImplementedError('Currently work in progress.')
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
def load_model(llm, *decls, **attrs):
raise NotImplementedError('Currently work in progress.')

View File

@@ -1,5 +1,3 @@
"""Serialisation related implementation for Transformers-based implementation."""
from __future__ import annotations
import importlib
import logging

View File

@@ -6,6 +6,7 @@ we won't ensure backward compatibility for these functions. So use with caution.
from __future__ import annotations
import functools
import importlib.metadata
import typing as t
import openllm_core
@@ -72,6 +73,7 @@ def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]:
'model_name': llm.config['model_name'],
'architecture': llm.config['architecture'],
'serialisation': llm._serialisation,
**{package: importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}},
}

View File

@@ -9,6 +9,10 @@ extend-exclude = [
]
extend-include = ["*.ipynb"]
extend-select = [
"E",
"F",
"B",
"PIE",
"I", # isort
"G", # flake8-logging-format
"W", # pycodestyle
@@ -40,11 +44,19 @@ ignore = [
line-length = 119
indent-width = 2
target-version = "py38"
typing-modules = ["openllm_core._typing_compat"]
typing-modules = [
"openllm_core._typing_compat",
"openllm_client._typing_compat",
]
unfixable = ["TCH004"]
[lint.flake8-type-checking]
exempt-modules = ["typing", "typing_extensions", "openllm_core._typing_compat"]
exempt-modules = [
"typing",
"typing_extensions",
"openllm_core._typing_compat",
"openllm_client._typing_compat",
]
runtime-evaluated-base-classes = [
"openllm_core._configuration.LLMConfig",
"openllm_core._configuration.GenerationConfig",
@@ -52,7 +64,15 @@ runtime-evaluated-base-classes = [
"openllm_core._configuration.ModelSettings",
"openllm.LLMConfig",
]
runtime-evaluated-decorators = ["attrs.define", "attrs.frozen", "trait"]
runtime-evaluated-decorators = [
"attrs.define",
"attrs.frozen",
"trait",
"attr.attrs",
'attr.define',
'_attr.define',
'attr.frozen',
]
[format]
preview = true
@@ -65,6 +85,7 @@ convention = "google"
[lint.pycodestyle]
ignore-overlong-task-comments = true
max-line-length = 119
[lint.isort]
combine-as-imports = true