diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index 5939cc45..006cc33a 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -202,11 +202,8 @@ After you change or update any CI related under `.github`, run `bash tools/lock-
 See this [docs](/.github/INFRA.md) for more information on OpenLLM's CI/CD workflow.
 
-## UI
-
-See [ClojureScript UI's README.md](/external/clojure/README.md) for more information.
-
-See [Documentation's README.md](/docs/README.md) for more information on running locally.
+## Typing
+For all internal functions, it is recommended to provide type hints. For all public function definitions, it is recommended to create a `.pyi` stub file that separates the supported external API from the internal implementation and improves code visibility. See [openllm-client's `__init__.pyi`](/openllm-client/src/openllm_client/__init__.pyi) for an example.
 
 ## Install from git archive install
diff --git a/changelog.d/605.change.md b/changelog.d/605.change.md
new file mode 100644
index 00000000..3b920324
--- /dev/null
+++ b/changelog.d/605.change.md
@@ -0,0 +1 @@
+Update the client implementation and add support for authentication via `OPENLLM_AUTH_TOKEN`
diff --git a/examples/openai_chat_completion_client.py b/examples/openai_chat_completion_client.py
index b3ed4256..c4d0b4fe 100644
--- a/examples/openai_chat_completion_client.py
+++ b/examples/openai_chat_completion_client.py
@@ -1,6 +1,7 @@
 # NOTE: Make sure to install openai>1
-import os, openai
+import os, openai, typing as t
 from openai.types.chat import (
+  ChatCompletionMessageParam,
   ChatCompletionSystemMessageParam,
   ChatCompletionUserMessageParam,
   ChatCompletionAssistantMessageParam,
@@ -14,7 +15,7 @@ model = models.data[0].id
 
 # Chat completion API
 stream = str(os.getenv('STREAM', False)).upper() in ['TRUE', '1', 'YES', 'Y', 'ON']
-messages = [
+messages: t.List[ChatCompletionMessageParam] = [
   ChatCompletionSystemMessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
   ChatCompletionUserMessageParam(role='user', content='Hi there!'),
   ChatCompletionAssistantMessageParam(role='assistant', content='Yes?'),
diff --git a/openllm-client/pyproject.toml b/openllm-client/pyproject.toml
index eff8c7bc..fa7baf09 100644
--- a/openllm-client/pyproject.toml
+++ b/openllm-client/pyproject.toml
@@ -57,7 +57,7 @@ keywords = [
   "PyTorch",
   "Transformers",
 ]
-dependencies = ["orjson", "httpx", "attrs>=23.1.0", "cattrs>=23.1.0"]
+dependencies = ["orjson", "httpx", "attrs>=23.1.0", "cattrs>=23.1.0", 'distro', 'anyio']
 license = "Apache-2.0"
 name = "openllm-client"
 requires-python = ">=3.8"
@@ -71,8 +71,9 @@ Homepage = "https://bentoml.com"
 Tracker = "https://github.com/bentoml/OpenLLM/issues"
 Twitter = "https://twitter.com/bentomlai"
 [project.optional-dependencies]
-full = ["openllm-client[grpc,agents]"]
-grpc = ["bentoml[grpc]>=1.1.6"]
+full = ["openllm-client[grpc,agents,auth]"]
+grpc = ["bentoml[grpc]>=1.1.9"]
+auth = ['httpx_auth']
 agents = ["transformers[agents]>=4.30", "diffusers", "soundfile"]
 
 [tool.hatch.version]
diff --git a/openllm-client/src/openllm_client/__init__.pyi b/openllm-client/src/openllm_client/__init__.pyi
index 76292815..58929ebe 100644
--- a/openllm-client/src/openllm_client/__init__.pyi
+++ b/openllm-client/src/openllm_client/__init__.pyi
@@ -15,26 +15,21 @@ from ._schemas import StreamingResponse as _StreamingResponse
 @_attr.define
 class HTTPClient:
   address: str
-  client_args: Dict[str, Any]
-  @staticmethod
-  def wait_until_server_ready(
-    addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any
+  @overload
+  def 
__init__( + self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... ) -> None: ... @overload def __init__( - self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any + self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... ) -> None: ... @overload def __init__( - self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any - ) -> None: ... - @overload - def __init__( - self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any + self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... ) -> None: ... @property def is_ready(self) -> bool: ... - def health(self) -> None: ... + def health(self) -> bool: ... def query(self, prompt: str, **attrs: Any) -> _Response: ... def generate( self, @@ -46,6 +41,16 @@ class HTTPClient: verify: Optional[bool] = ..., **attrs: Any, ) -> _Response: ... + def generate_iterator( + self, + prompt: str, + llm_config: Optional[Dict[str, Any]] = ..., + stop: Optional[Union[str, List[str]]] = ..., + adapter_name: Optional[str] = ..., + timeout: Optional[int] = ..., + verify: Optional[bool] = ..., + **attrs: Any, + ) -> Iterator[_Response]: ... def generate_stream( self, prompt: str, @@ -60,26 +65,21 @@ class HTTPClient: @_attr.define class AsyncHTTPClient: address: str - client_args: Dict[str, Any] - @staticmethod - async def wait_until_server_ready( - addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any + @overload + def __init__( + self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... ) -> None: ... @overload def __init__( - self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any + self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... ) -> None: ... @overload def __init__( - self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any - ) -> None: ... - @overload - def __init__( - self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any + self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... ) -> None: ... @property def is_ready(self) -> bool: ... - async def health(self) -> None: ... + async def health(self) -> bool: ... async def query(self, prompt: str, **attrs: Any) -> _Response: ... async def generate( self, @@ -91,6 +91,16 @@ class AsyncHTTPClient: verify: Optional[bool] = ..., **attrs: Any, ) -> _Response: ... + async def generate_iterator( + self, + prompt: str, + llm_config: Optional[Dict[str, Any]] = ..., + stop: Optional[Union[str, List[str]]] = ..., + adapter_name: Optional[str] = ..., + timeout: Optional[int] = ..., + verify: Optional[bool] = ..., + **attrs: Any, + ) -> Iterator[_Response]: ... 
async def generate_stream( self, prompt: str, diff --git a/openllm-client/src/openllm_client/_http.py b/openllm-client/src/openllm_client/_http.py index 6f9b191e..45a6c7f9 100644 --- a/openllm-client/src/openllm_client/_http.py +++ b/openllm-client/src/openllm_client/_http.py @@ -1,349 +1,240 @@ from __future__ import annotations import asyncio import enum +import importlib.metadata import logging import os -import time import typing as t -import urllib.error - -from urllib.parse import urlparse import attr -import httpx -import orjson -from ._schemas import Request +from ._schemas import MetadataOutput from ._schemas import Response from ._schemas import StreamingResponse +from ._shim import MAX_RETRIES +from ._shim import AsyncClient +from ._shim import Client logger = logging.getLogger(__name__) - -def _address_validator(_, attr, value): - if not isinstance(value, str): - raise TypeError(f'{attr.name} must be a string') - if not urlparse(value).netloc: - raise ValueError(f'{attr.name} must be a valid URL') +VERSION = importlib.metadata.version('openllm-client') def _address_converter(addr: str): return addr if '://' in addr else 'http://' + addr -class ServerState(enum.Enum): +class ClientState(enum.Enum): CLOSED = 1 # CLOSED: The server is not yet ready or `wait_until_server_ready` has not been called/failed. READY = 2 # READY: The server is ready and `wait_until_server_ready` has been called. - - -_object_setattr = object.__setattr__ + DISCONNECTED = 3 # DISCONNECTED: The server is disconnected and `wait_until_server_ready` has been called. @attr.define(init=False) -class HTTPClient: - address: str = attr.field(validator=_address_validator, converter=_address_converter) - client_args: t.Dict[str, t.Any] - - _inner: httpx.Client - _timeout: int = 30 +class HTTPClient(Client): _api_version: str = 'v1' _verify: bool = True - _state: ServerState = ServerState.CLOSED - - __metadata: dict[str, t.Any] | None = None + __metadata: MetadataOutput | None = None __config: dict[str, t.Any] | None = None def __repr__(self): - return ( - f'' - ) + return f'' - @staticmethod - def wait_until_server_ready(addr, timeout=30, verify=False, check_interval=1, **client_args): - addr = _address_converter(addr) - logger.debug('Wait for server @ %s to be ready', addr) - start = time.monotonic() - while time.monotonic() - start < timeout: - try: - with httpx.Client(base_url=addr, verify=verify, **client_args) as sess: - status = sess.get('/readyz').status_code - if status == 200: - break - else: - time.sleep(check_interval) - except (httpx.ConnectError, urllib.error.URLError, ConnectionError): - logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval) - time.sleep(check_interval) - # Try once more and raise for exception - try: - with httpx.Client(base_url=addr, verify=verify, **client_args) as sess: - status = sess.get('/readyz').status_code - except httpx.HTTPStatusError as err: - logger.error('Failed to wait until server ready: %s', addr) - logger.error(err) - raise - - def __init__(self, address=None, timeout=30, verify=False, api_version='v1', **client_args): + def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'): if address is None: - env = os.getenv('OPENLLM_ENDPOINT') - if env is None: - raise ValueError('address must be provided') - address = env - self.__attrs_init__( - address, - client_args, - httpx.Client(base_url=address, timeout=timeout, verify=verify, **client_args), - timeout, - api_version, - verify, - ) + address = 
os.getenv('OPENLLM_ENDPOINT') + if address is None: + raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'") + self._api_version, self._verify = api_version, verify + super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries) - def _metadata(self) -> dict[str, t.Any]: - if self.__metadata is None: - self.__metadata = self._inner.post(self._build_endpoint('metadata')).json() - return self.__metadata - - def _config(self) -> dict[str, t.Any]: - if self.__config is None: - config = orjson.loads(self._metadata()['configuration']) - generation_config = config.pop('generation_config') - self.__config = {**config, **generation_config} - return self.__config + def _build_auth_headers(self) -> t.Dict[str, str]: + env = os.getenv('OPENLLM_AUTH_TOKEN') + if env is not None: + return {'Authorization': f'Bearer {env}'} + return super()._build_auth_headers() def __del__(self): - self._inner.close() - - def _build_endpoint(self, endpoint): - return ('/' if not self._api_version.startswith('/') else '') + f'{self._api_version}/{endpoint}' + self.close() @property - def is_ready(self): - return self._state == ServerState.READY + def _metadata(self): + if self.__metadata is None: + path = f'/{self._api_version}/metadata' + self.__metadata = self._post( + path, response_cls=MetadataOutput, json={}, options={'max_retries': self._max_retries} + ) + return self.__metadata + + @property + def _config(self) -> dict[str, t.Any]: + if self.__config is None: + self.__config = self._metadata.configuration + return self.__config def query(self, prompt, **attrs): return self.generate(prompt, **attrs) def health(self): - try: - self.wait_until_server_ready(self.address, timeout=self._timeout, verify=self._verify, **self.client_args) - _object_setattr(self, '_state', ServerState.READY) - except Exception as e: - logger.error('Server is not healthy (Scroll up for traceback)\n%s', e) - _object_setattr(self, '_state', ServerState.CLOSED) + response = self._get( + '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} + ) + return response.status_code == 200 def generate( self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> Response: - if not self.is_ready: - self.health() - if not self.is_ready: - raise RuntimeError('Server is not ready. 
Check server logs for more information.') if timeout is None: timeout = self._timeout if verify is None: - verify = self._verify - _meta, _config = self._metadata(), self._config() + verify = self._verify # XXX: need to support this again if llm_config is not None: - llm_config = {**_config, **llm_config, **attrs} + llm_config = {**self._config, **llm_config, **attrs} else: - llm_config = {**_config, **attrs} - if _meta['prompt_template'] is not None: - prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt) + llm_config = {**self._config, **attrs} + if self._metadata.prompt_template is not None: + prompt = self._metadata.prompt_template.format(system_message=self._metadata.system_message, instruction=prompt) - req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name) - with httpx.Client(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client: - r = client.post(self._build_endpoint('generate'), json=req.model_dump_json(), **self.client_args) - if r.status_code != 200: - raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.") - return Response.model_construct(r.json()) + return self._post( + f'/{self._api_version}/generate', + response_cls=Response, + json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name), + ) def generate_stream( self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> t.Iterator[StreamingResponse]: - if not self.is_ready: - self.health() - if not self.is_ready: - raise RuntimeError('Server is not ready. Check server logs for more information.') + for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs): + yield StreamingResponse.from_response_chunk(response_chunk) + + def generate_iterator( + self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs + ) -> t.Iterator[Response]: if timeout is None: timeout = self._timeout if verify is None: - verify = self._verify - _meta, _config = self._metadata(), self._config() + verify = self._verify # XXX: need to support this again if llm_config is not None: - llm_config = {**_config, **llm_config, **attrs} + llm_config = {**self._config, **llm_config, **attrs} else: - llm_config = {**_config, **attrs} - if _meta['prompt_template'] is not None: - prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt) - - req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name) - with httpx.Client(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client: - with client.stream( - 'POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args - ) as r: - for payload in r.iter_bytes(): - if payload == b'data: [DONE]\n\n': - break - # Skip line - if payload == b'\n': - continue - if payload.startswith(b'data: '): - try: - proc = payload.decode('utf-8').lstrip('data: ').rstrip('\n') - data = orjson.loads(proc) - yield StreamingResponse.from_response_chunk(Response.model_construct(data)) - except Exception: - pass # FIXME: Handle this + llm_config = {**self._config, **attrs} + if self._metadata.prompt_template is not None: + prompt = self._metadata.prompt_template.format(system_message=self._metadata.system_message, instruction=prompt) + return self._post( + f'/{self._api_version}/generate_stream', + 
response_cls=Response, + json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name), + stream=True, + ) @attr.define(init=False) -class AsyncHTTPClient: - address: str = attr.field(validator=_address_validator, converter=_address_converter) - client_args: t.Dict[str, t.Any] - - _inner: httpx.AsyncClient - _timeout: int = 30 +class AsyncHTTPClient(AsyncClient): _api_version: str = 'v1' _verify: bool = True - _state: ServerState = ServerState.CLOSED - - __metadata: dict[str, t.Any] | None = None + __metadata: MetadataOutput | None = None __config: dict[str, t.Any] | None = None def __repr__(self): - return f'' + return f'' - @staticmethod - async def wait_until_server_ready(addr, timeout=30, verify=False, check_interval=1, **client_args): - addr = _address_converter(addr) - logger.debug('Wait for server @ %s to be ready', addr) - start = time.monotonic() - while time.monotonic() - start < timeout: - try: - async with httpx.AsyncClient(base_url=addr, verify=verify, **client_args) as sess: - status = (await sess.get('/readyz')).status_code - if status == 200: - break - else: - await asyncio.sleep(check_interval) - except (httpx.ConnectError, urllib.error.URLError, ConnectionError): - logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval) - await asyncio.sleep(check_interval) - # Try once more and raise for exception - try: - async with httpx.AsyncClient(base_url=addr, verify=verify, **client_args) as sess: - status = (await sess.get('/readyz')).status_code - except httpx.HTTPStatusError as err: - logger.error('Failed to wait until server ready: %s', addr) - logger.error(err) - raise - - def __init__(self, address=None, timeout=30, verify=False, api_version='v1', **client_args): + def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'): if address is None: - env = os.getenv('OPENLLM_ENDPOINT') - if env is None: - raise ValueError('address must be provided') - address = env - self.__attrs_init__( - address, - client_args, - httpx.AsyncClient(base_url=address, timeout=timeout, verify=verify, **client_args), - timeout, - api_version, - verify, - ) + address = os.getenv('OPENLLM_ENDPOINT') + if address is None: + raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'") + self._api_version, self._verify = api_version, verify + super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries) - async def _metadata(self) -> dict[str, t.Any]: - if self.__metadata is None: - self.__metadata = (await self._inner.post(self._build_endpoint('metadata'))).json() - return self.__metadata - - async def _config(self) -> dict[str, t.Any]: - if self.__config is None: - config = orjson.loads((await self._metadata())['configuration']) - generation_config = config.pop('generation_config') - self.__config = {**config, **generation_config} - return self.__config - - def _build_endpoint(self, endpoint): - return '/' + f'{self._api_version}/{endpoint}' + def _build_auth_headers(self) -> t.Dict[str, str]: + env = os.getenv('OPENLLM_AUTH_TOKEN') + if env is not None: + return {'Authorization': f'Bearer {env}'} + return super()._build_auth_headers() @property - def is_ready(self): - return self._state == ServerState.READY + def _loop(self) -> asyncio.AbstractEventLoop: + try: + return asyncio.get_running_loop() + except RuntimeError: + return asyncio.get_event_loop() + + @property + async def _metadata(self) -> t.Awaitable[MetadataOutput]: + if self.__metadata is None: 
+ self.__metadata = await self._post( + f'/{self._api_version}/metadata', + response_cls=MetadataOutput, + json={}, + options={'max_retries': self._max_retries}, + ) + return self.__metadata + + @property + async def _config(self): + if self.__config is None: + self.__config = (await self._metadata).configuration + return self.__config async def query(self, prompt, **attrs): return await self.generate(prompt, **attrs) async def health(self): - try: - await self.wait_until_server_ready(self.address, timeout=self._timeout, verify=self._verify, **self.client_args) - _object_setattr(self, '_state', ServerState.READY) - except Exception as e: - logger.error('Server is not healthy (Scroll up for traceback)\n%s', e) - _object_setattr(self, '_state', ServerState.CLOSED) + response = await self._get( + '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} + ) + return response.status_code == 200 async def generate( self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> Response: - if not self.is_ready: - await self.health() - if not self.is_ready: - raise RuntimeError('Server is not ready. Check server logs for more information.') if timeout is None: timeout = self._timeout if verify is None: - verify = self._verify - _meta, _config = await self._metadata(), await self._config() + verify = self._verify # XXX: need to support this again + _metadata = await self._metadata + _config = await self._config if llm_config is not None: llm_config = {**_config, **llm_config, **attrs} else: llm_config = {**_config, **attrs} - if _meta['prompt_template'] is not None: - prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt) - - req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name) - async with httpx.AsyncClient(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client: - r = await client.post(self._build_endpoint('generate'), json=req.model_dump_json(), **self.client_args) - if r.status_code != 200: - raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.") - return Response.model_construct(r.json()) + if _metadata.prompt_template is not None: + prompt = _metadata.prompt_template.format(system_message=_metadata.system_message, instruction=prompt) + return await self._post( + f'/{self._api_version}/generate', + response_cls=Response, + json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name), + ) async def generate_stream( self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> t.AsyncGenerator[StreamingResponse, t.Any]: - if not self.is_ready: - await self.health() - if not self.is_ready: - raise RuntimeError('Server is not ready. 
Check server logs for more information.') + async for response_chunk in self.generate_iterator( + prompt, llm_config, stop, adapter_name, timeout, verify, **attrs + ): + yield StreamingResponse.from_response_chunk(response_chunk) + + async def generate_iterator( + self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs + ) -> t.AsyncGenerator[Response, t.Any]: if timeout is None: timeout = self._timeout if verify is None: - verify = self._verify - _meta, _config = await self._metadata(), await self._config() + verify = self._verify # XXX: need to support this again + _metadata = await self._metadata + _config = await self._config if llm_config is not None: llm_config = {**_config, **llm_config, **attrs} else: llm_config = {**_config, **attrs} - if _meta['prompt_template'] is not None: - prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt) + if _metadata.prompt_template is not None: + prompt = _metadata.prompt_template.format(system_message=_metadata.system_message, instruction=prompt) - req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name) - async with httpx.AsyncClient(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client: - async with client.stream( - 'POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args - ) as r: - async for payload in r.aiter_bytes(): - if payload == b'data: [DONE]\n\n': - break - # Skip line - if payload == b'\n': - continue - if payload.startswith(b'data: '): - try: - proc = payload.decode('utf-8').lstrip('data: ').rstrip('\n') - data = orjson.loads(proc) - yield StreamingResponse.from_response_chunk(Response.model_construct(data)) - except Exception: - pass # FIXME: Handle this + async for response_chunk in await self._post( + f'/{self._api_version}/generate_stream', + response_cls=Response, + json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name), + stream=True, + ): + yield response_chunk diff --git a/openllm-client/src/openllm_client/_schemas.py b/openllm-client/src/openllm_client/_schemas.py index d29fd46b..235e4896 100644 --- a/openllm-client/src/openllm_client/_schemas.py +++ b/openllm-client/src/openllm_client/_schemas.py @@ -3,9 +3,47 @@ import typing as t import attr import cattr +import orjson + +from ._utils import converter # XXX: sync with openllm-core/src/openllm_core/_schemas.py +@attr.define +class MetadataOutput: + model_id: str + timeout: int + model_name: str + backend: str + configuration: t.Dict[str, t.Any] + prompt_template: t.Optional[str] + system_message: t.Optional[str] + + +def _structure_metadata(data: t.Dict[str, t.Any], cls: type[MetadataOutput]) -> MetadataOutput: + try: + configuration = orjson.loads(data['configuration']) + generation_config = configuration.pop('generation_config') + configuration = {**configuration, **generation_config} + except orjson.JSONDecodeError as e: + raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None + try: + return cls( + model_id=data['model_id'], + timeout=data['timeout'], + model_name=data['model_name'], + backend=data['backend'], + configuration=configuration, + prompt_template=data['prompt_template'], + system_message=data['system_message'], + ) + except Exception as e: + raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None + + +converter.register_structure_hook(MetadataOutput, _structure_metadata) + + @attr.define 
class Request: prompt: str diff --git a/openllm-client/src/openllm_client/_shim.py b/openllm-client/src/openllm_client/_shim.py new file mode 100644 index 00000000..fcb8b7de --- /dev/null +++ b/openllm-client/src/openllm_client/_shim.py @@ -0,0 +1,610 @@ +# This provides a base shim with httpx and acts as base request +from __future__ import annotations +import email.utils +import logging +import platform +import random +import time +import typing as t + +import anyio +import attr +import distro +import httpx + +from ._stream import AsyncStream +from ._stream import Response +from ._stream import Stream +from ._typing_compat import Architecture +from ._typing_compat import LiteralString +from ._typing_compat import Platform +from ._utils import converter + + +logger = logging.getLogger(__name__) + +InnerClient = t.TypeVar('InnerClient', bound=t.Union[httpx.Client, httpx.AsyncClient]) +StreamType = t.TypeVar('StreamType', bound=t.Union[Stream[t.Any], AsyncStream[t.Any]]) +_Stream = t.TypeVar('_Stream', bound=Stream[t.Any]) +_AsyncStream = t.TypeVar('_AsyncStream', bound=AsyncStream[t.Any]) +LiteralVersion = t.Annotated[LiteralString, t.Literal['v1'], str] + + +def _address_converter(addr: str | httpx.URL) -> httpx.URL: + if isinstance(addr, httpx.URL): + url = addr + else: + url = httpx.URL(addr if '://' in addr else f'http://{addr}') + if not url.raw_path.endswith(b'/'): + url = url.copy_with(path=url.raw_path + b'/') + return url + + +MAX_RETRIES = 2 +DEFAULT_TIMEOUT = httpx.Timeout(5.0) # Similar to httpx + + +_T_co = t.TypeVar('_T_co', covariant=True) +_T = t.TypeVar('_T') + + +def _merge_mapping(a: t.Mapping[_T_co, _T], b: t.Mapping[_T_co, _T]) -> t.Dict[_T_co, _T]: + # does the merging and filter out None + return {k: v for k, v in {**a, **b}.items() if v is not None} + + +def _platform() -> Platform: + system = platform.system().lower() + platform_name = platform.platform().lower() + if system == 'darwin': + return 'MacOS' + elif system == 'windows': + return 'Windows' + elif system == 'linux': + distro_id = distro.id() + if distro_id == 'freebsd': + return 'FreeBSD' + elif distro_id == 'openbsd': + return 'OpenBSD' + else: + return 'Linux' + elif 'android' in platform_name: + return 'Android' + elif 'iphone' in platform_name: + return 'iOS' + elif 'ipad' in platform_name: + return 'iPadOS' + if platform_name: + return f'Other:{platform_name}' + return 'Unknown' + + +def _architecture() -> Architecture: + machine = platform.machine().lower() + if machine in {'arm64', 'aarch64'}: + return 'arm64' + elif machine in {'arm', 'aarch32'}: + return 'arm' + elif machine in {'x86_64', 'amd64'}: + return 'x86_64' + elif machine in {'x86', 'i386', 'i686'}: + return 'x86' + elif machine: + return f'Other:{machine}' + return 'Unknown' + + +@t.final +@attr.frozen(auto_attribs=True) +class RequestOptions: + method: str = attr.field(converter=str.lower) + url: str + json: t.Optional[t.Dict[str, t.Any]] = attr.field(default=None) + params: t.Optional[t.Mapping[str, t.Any]] = attr.field(default=None) + headers: t.Optional[t.Dict[str, str]] = attr.field(default=None) + max_retries: int = attr.field(default=MAX_RETRIES) + return_raw_response: bool = attr.field(default=False) + + def get_max_retries(self, max_retries: int | None) -> int: + return max_retries if max_retries is not None else self.max_retries + + @classmethod + def model_construct(cls, **options: t.Any) -> RequestOptions: + return cls(**options) + + +@t.final +@attr.frozen(auto_attribs=True) +class APIResponse(t.Generic[Response]): + 
_raw_response: httpx.Response + _client: BaseClient[t.Any, t.Any] + _response_cls: type[Response] + _stream: bool + _stream_cls: t.Optional[t.Union[t.Type[Stream[t.Any]], t.Type[AsyncStream[t.Any]]]] + _options: RequestOptions + + _parsed: t.Optional[Response] = attr.field(default=None, repr=False) + + def parse(self): + if self._options.return_raw_response: + return self._raw_response + if self._parsed is not None: + return self._parsed + if self._stream: + stream_cls = self._stream_cls or self._client._default_stream_cls + return stream_cls(response_cls=self._response_cls, response=self._raw_response, client=self._client) + + content_type, *_ = self._raw_response.headers.get('content-type', '').split(';') + if content_type != 'application/json': + # Since users specific different content_type, then we return the raw binary text without and deserialisation + return self._raw_response.text + + data = self._raw_response.json() + try: + return self._client._process_response_data( + data=data, response_cls=self._response_cls, raw_response=self._raw_response + ) + except Exception as exc: + raise ValueError(exc) from None # validation error here + + @property + def headers(self): + return self._raw_response.headers + + @property + def status_code(self): + return self._raw_response.status_code + + @property + def request(self): + return self._raw_response.request + + @property + def url(self): + return self._raw_response.url + + @property + def content(self): + return self._raw_response.content + + @property + def text(self): + return self._raw_response.text + + @property + def http_version(self): + return self._raw_response.http_version + + +@attr.define(init=False) +class BaseClient(t.Generic[InnerClient, StreamType]): + _base_url: httpx.URL = attr.field(converter=_address_converter) + _version: LiteralVersion + _timeout: httpx.Timeout = attr.field(converter=httpx.Timeout) + _max_retries: int + _inner: InnerClient + _default_stream_cls: t.Type[StreamType] + _auth_headers: t.Dict[str, str] = attr.field(init=False) + + def __init__( + self, + *, + base_url: str | httpx.URL, + version: str, + timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, + max_retries: int = MAX_RETRIES, + client: InnerClient, + _default_stream_cls: t.Type[StreamType], + _internal: bool = False, + ): + if not _internal: + raise RuntimeError('Client is reserved to be used internally only.') + self.__attrs_init__(base_url, version, timeout, max_retries, client, _default_stream_cls) + self._auth_headers = self._build_auth_headers() + + def _prepare_url(self, url: str) -> httpx.URL: + # copied from httpx._merge_url + merge_url = httpx.URL(url) + if merge_url.is_relative_url: + merge_raw = self._base_url.raw_path + merge_url.raw_path.lstrip(b'/') + return self._base_url.copy_with(raw_path=merge_raw) + return merge_url + + @property + def is_closed(self): + return self._inner.is_closed + + @property + def is_ready(self): + return not self.is_closed # backward compat + + @property + def base_url(self): + return self._base_url + + @property + def address(self): + return str(self.base_url) + + @base_url.setter + def base_url(self, url): + self._base_url = url if isinstance(url, httpx.URL) else httpx.URL(url) + + def _build_auth_headers(self) -> t.Dict[str, str]: + return {} # can be overridden for subclasses for auth support + + @property + def auth(self) -> httpx.Auth | None: + return None + + @property + def user_agent(self): + return f'{self.__class__.__name__}/Python {self._version}' + + @property + def auth_headers(self): + return 
self._auth_headers + + @property + def _default_headers(self) -> t.Dict[str, str]: + return { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + 'User-Agent': self.user_agent, + **self.platform_headers, + **self.auth_headers, + } + + @property + def platform_headers(self): + return { + 'X-OpenLLM-Client-Package-Version': self._version, + 'X-OpenLLM-Client-Language': 'Python', + 'X-OpenLLM-Client-Runtime': platform.python_implementation(), + 'X-OpenLLM-Client-Runtime-Version': platform.python_version(), + 'X-OpenLLM-Client-Arch': str(_architecture()), + 'X-OpenLLM-Client-OS': str(_platform()), + } + + def _remaining_retries(self, remaining_retries: int | None, options: RequestOptions) -> int: + return remaining_retries if remaining_retries is not None else options.get_max_retries(self._max_retries) + + def _build_headers(self, options: RequestOptions) -> httpx.Headers: + return httpx.Headers(_merge_mapping(self._default_headers, options.headers or {})) + + def _build_request(self, options: RequestOptions) -> httpx.Request: + return self._inner.build_request( + method=options.method, + headers=self._build_headers(options), + url=self._prepare_url(options.url), + json=options.json, + params=options.params, + ) + + def _calculate_retry_timeout( + self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None + ) -> float: + max_retries = options.get_max_retries(self._max_retries) + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + try: + if headers is not None: + retry_header = headers.get('retry-after') + try: + retry_after = int(retry_header) + except ValueError: + tup = email.utils.parsedate_tz(retry_header) + if tup is None: + retry_after = -1 + else: + retry_after = int(email.utils.mktime_tz(tup) - time.time()) + else: + retry_after = -1 + except Exception: + # omit everything + retry_after = -1 + if 0 < retry_after <= 60: + return retry_after # this is reasonable from users + initial_delay, max_delay = 0.5, 8.0 + num_retries = max_retries - remaining_retries + + sleep = min(initial_delay * pow(2.0, num_retries), max_delay) # apply exponential backoff here + timeout = sleep * (1 - 0.25 * random.random()) + return timeout if timeout >= 0 else 0 + + def _should_retry(self, response: httpx.Response) -> bool: + should_retry_header = response.headers.get('x-should-retry') + if should_retry_header.lower() == 'true': + return True + if should_retry_header.lower() == 'false': + return False + if response.status_code in {408, 409, 429}: + return True + if response.status_code >= 500: + return True + return False + + def _process_response_data( + self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response + ) -> Response: + return converter.structure(data, response_cls) + + def _process_response( + self, + *, + response_cls: type[Response], + options: RequestOptions, + raw_response: httpx.Response, + stream: bool, + stream_cls: type[_Stream] | type[_AsyncStream] | None, + ) -> Response: + return APIResponse( + raw_response=raw_response, + client=self, + response_cls=response_cls, + stream=stream, + stream_cls=stream_cls, + options=options, + ).parse() + + +@attr.define(init=False) +class Client(BaseClient[httpx.Client, Stream[t.Any]]): + def __init__( + self, + base_url: str | httpx.URL, + version: str, + timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, + max_retries: int = MAX_RETRIES, + ): + super().__init__( + base_url=base_url, + version=version, + timeout=timeout, + 
max_retries=max_retries, + client=httpx.Client(base_url=base_url, timeout=timeout), + _default_stream_cls=Stream, + _internal=True, + ) + + def close(self): + self._inner.close() + + def __enter__(self): + return self + + def __exit__(self, *args) -> None: + self.close() + + def request( + self, + response_cls: type[Response], + options: RequestOptions, + remaining_retries: t.Optional[int] = None, + *, + stream: bool = False, + stream_cls: type[_Stream] | None = None, + ) -> Response | _Stream: + return self._request( + response_cls=response_cls, + options=options, + remaining_retries=remaining_retries, + stream=stream, + stream_cls=stream_cls, + ) + + def _request( + self, + response_cls: type[Response], + options: RequestOptions, + remaining_retries: int | None = None, + *, + stream: bool = False, + stream_cls: type[_Stream] | None = None, + ) -> Response | _Stream: + retries = self._remaining_retries(remaining_retries, options) + request = self._build_request(options) + try: + response = self._inner.send(request, auth=self.auth, stream=stream) + logger.debug('HTTP [%s, %s]: %s [%i]', request.method, request.url, response.status_code, response.reason_phrase) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + if retries > 0 and self._should_retry(exc.response): + return self._retry_request( + response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls + ) + # If the response is streamed then we need to explicitly read the completed response + exc.response.read() + raise ValueError(exc.message) from None + except httpx.TimeoutException: + if retries > 0: + return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) + raise ValueError(request) from None # timeout + except Exception: + if retries > 0: + return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) + raise ValueError(request) from None # connection error + + return self._process_response( + response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls + ) + + def _retry_request( + self, + response_cls: type[Response], + options: RequestOptions, + remaining_retries: int, + response_headers: httpx.Headers | None = None, + *, + stream: bool = False, + stream_cls: type[_Stream] | None, + ) -> Response | _Stream: + remaining = remaining_retries - 1 + timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) + logger.info('Retrying request to %s in %f seconds', options.url, timeout) + # In synchronous thread we are blocking the thread. Depends on how users want to do this downstream. 
+ time.sleep(timeout) + return self._request(response_cls, options, remaining, stream=stream, stream_cls=stream_cls) + + def _get( + self, + path: str, + *, + response_cls: type[Response], + options: dict[str, t.Any] | None = None, + stream: bool = False, + stream_cls: type[_Stream] | None = None, + ) -> Response | _Stream: + if options is None: + options = {} + return self.request( + response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls + ) + + def _post( + self, + path: str, + *, + response_cls: type[Response], + json: dict[str, t.Any], + options: dict[str, t.Any] | None = None, + stream: bool = False, + stream_cls: type[_Stream] | None = None, + ) -> Response | _Stream: + if options is None: + options = {} + return self.request( + response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls + ) + + +@attr.define(init=False) +class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): + def __init__( + self, + base_url: str | httpx.URL, + version: str, + timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, + max_retries: int = MAX_RETRIES, + ): + super().__init__( + base_url=base_url, + version=version, + timeout=timeout, + max_retries=max_retries, + client=httpx.AsyncClient(base_url=base_url, timeout=timeout), + _default_stream_cls=AsyncStream, + _internal=True, + ) + + async def close(self): + await self._inner.aclose() + + async def __aenter__(self): + return self + + async def __aexit__(self, *args) -> None: + await self.close() + + async def request( + self, + response_cls: type[Response], + options: RequestOptions, + remaining_retries: t.Optional[int] = None, + *, + stream: bool = False, + stream_cls: type[_AsyncStream] | None = None, + ) -> Response | _AsyncStream: + return await self._request( + response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls + ) + + async def _request( + self, + response_cls: type[Response], + options: RequestOptions, + remaining_retries: int | None = None, + *, + stream: bool = False, + stream_cls: type[_AsyncStream] | None = None, + ) -> Response | _AsyncStream: + retries = self._remaining_retries(remaining_retries, options) + request = self._build_request(options) + + try: + response = await self._inner.send(request, auth=self.auth, stream=stream) + logger.debug('HTTP [%s, %s]: %s [%i]', request.method, request.url, response.status_code, response.reason_phrase) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + if retries > 0 and self._should_retry(exc.response): + return self._retry_request( + response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls + ) + # If the response is streamed then we need to explicitly read the completed response + await exc.response.aread() + raise ValueError(exc.message) from None + except httpx.ConnectTimeout as err: + if retries > 0: + return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) + raise ValueError(request) from err # timeout + except httpx.ReadTimeout: + # We don't retry on ReadTimeout error, so something might happen on server-side + raise + except httpx.TimeoutException as err: + if retries > 0: + return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) + raise ValueError(request) from err # timeout + except Exception as err: + if retries > 0: + return await self._retry_request(response_cls, options, retries, 
stream=stream, stream_cls=stream_cls) + raise ValueError(request) from err # connection error + + return self._process_response( + response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls + ) + + async def _retry_request( + self, + response_cls: type[Response], + options: RequestOptions, + remaining_retries: int, + response_headers: httpx.Headers | None = None, + *, + stream: bool, + stream_cls: type[_AsyncStream] | None, + ): + remaining = remaining_retries - 1 + timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) + logger.info('Retrying request to %s in %f seconds', options.url, timeout) + await anyio.sleep(timeout) + return await self._request(response_cls, options, remaining, stream=stream, stream_cls=stream_cls) + + async def _get( + self, + path: str, + *, + response_cls: type[Response], + options: dict[str, t.Any] | None = None, + stream: bool = False, + stream_cls: type[_AsyncStream] | None = None, + ) -> Response | _AsyncStream: + if options is None: + options = {} + return await self.request( + response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls + ) + + async def _post( + self, + path: str, + *, + response_cls: type[Response], + json: dict[str, t.Any], + options: dict[str, t.Any] | None = None, + stream: bool = False, + stream_cls: type[_AsyncStream] | None = None, + ) -> Response | _AsyncStream: + if options is None: + options = {} + return await self.request( + response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls + ) diff --git a/openllm-client/src/openllm_client/_stream.py b/openllm-client/src/openllm_client/_stream.py new file mode 100644 index 00000000..6a3d5da0 --- /dev/null +++ b/openllm-client/src/openllm_client/_stream.py @@ -0,0 +1,143 @@ +from __future__ import annotations +import typing as t + +import attr +import httpx +import orjson + + +if t.TYPE_CHECKING: + from ._shim import AsyncClient + from ._shim import Client + +Response = t.TypeVar('Response', bound=attr.AttrsInstance) + + +@attr.define(auto_attribs=True) +class Stream(t.Generic[Response]): + _response_cls: t.Type[Response] + _response: httpx.Response + _client: Client + _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder()) + _iterator: t.Iterator[Response] = attr.field(init=False) + + def __init__(self, response_cls, response, client): + self.__attrs_init__(response_cls, response, client) + self._iterator = self._stream() + + def __next__(self): + return self._iterator.__next__() + + def __iter__(self) -> t.Iterator[Response]: + for item in self._iterator: + yield item + + def _iter_events(self) -> t.Iterator[SSE]: + yield from self._decoder.iter(self._response.iter_lines()) + + def _stream(self) -> t.Iterator[Response]: + for sse in self._iter_events(): + if sse.data.startswith('[DONE]'): + break + if sse.event is None: + yield self._client._process_response_data( + data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response + ) + + +@attr.define(auto_attribs=True) +class AsyncStream(t.Generic[Response]): + _response_cls: t.Type[Response] + _response: httpx.Response + _client: AsyncClient + _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder()) + _iterator: t.Iterator[Response] = attr.field(init=False) + + def __init__(self, response_cls, response, client): + self.__attrs_init__(response_cls, response, client) + self._iterator = self._stream() + + async def __anext__(self): 
+ return await self._iterator.__anext__() + + async def __aiter__(self): + async for item in self._iterator: + yield item + + async def _iter_events(self): + async for sse in self._decoder.aiter(self._response.aiter_lines()): + yield sse + + async def _stream(self) -> t.AsyncGenerator[Response, None]: + async for sse in self._iter_events(): + if sse.data.startswith('[DONE]'): + break + if sse.event is None: + yield self._client._process_response_data( + data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response + ) + + +@attr.define +class SSE: + data: str = attr.field(default='') + id: t.Optional[str] = attr.field(default=None) + event: t.Optional[str] = attr.field(default=None) + retry: t.Optional[int] = attr.field(default=None) + + def model_dump(self) -> t.Dict[str, t.Any]: + try: + return orjson.loads(self.data) + except orjson.JSONDecodeError: + raise + + +@attr.define(auto_attribs=True) +class SSEDecoder: + _data: t.List[str] = attr.field(factory=list) + _event: t.Optional[str] = None + _retry: t.Optional[int] = None + _last_event_id: t.Optional[str] = None + + def iter(self, iterator: t.Iterator[str]) -> t.Iterator[SSE]: + for line in iterator: + sse = self.decode(line.rstrip('\n')) + if sse: + yield sse + + async def aiter(self, iterator: t.AsyncIterator[str]) -> t.AsyncIterator[SSE]: + async for line in iterator: + sse = self.decode(line.rstrip('\n')) + if sse: + yield sse + + def decode(self, line: str) -> SSE | None: + # NOTE: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + if not line: + if all(not a for a in [self._event, self._data, self._retry, self._last_event_id]): + return None + sse = SSE(data='\n'.join(self._data), event=self._event, retry=self._retry, id=self._last_event_id) + self._event, self._data, self._retry = None, [], None + return sse + if line.startswith(':'): + return None + field, _, value = line.partition(':') + if value.startswith(' '): + value = value[1:] + if field == 'event': + self._event = value + elif field == 'data': + self._data.append(value) + elif field == 'id': + if '\0' in value: + pass + else: + self._last_event_id = value + elif field == 'retry': + try: + self._retry = int(value) + except (TypeError, ValueError): + pass + else: + pass # Ignore unknown fields + return None diff --git a/openllm-client/src/openllm_client/_typing_compat.py b/openllm-client/src/openllm_client/_typing_compat.py new file mode 100644 index 00000000..ecdf6a86 --- /dev/null +++ b/openllm-client/src/openllm_client/_typing_compat.py @@ -0,0 +1,29 @@ +import sys + +from typing import Literal + + +if sys.version_info[:2] >= (3, 11): + from typing import LiteralString as LiteralString + from typing import NotRequired as NotRequired + from typing import Required as Required + from typing import Self as Self + from typing import dataclass_transform as dataclass_transform + from typing import overload as overload +else: + from typing_extensions import LiteralString as LiteralString + from typing_extensions import NotRequired as NotRequired + from typing_extensions import Required as Required + from typing_extensions import Self as Self + from typing_extensions import dataclass_transform as dataclass_transform + from typing_extensions import overload as overload + +if sys.version_info[:2] >= (3, 9): + from typing import Annotated as Annotated +else: + from typing_extensions import Annotated as Annotated + +Platform = Annotated[ + LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 
'iOS', 'iPadOS', 'Android', 'Unknown'], str +] +Architecture = Annotated[LiteralString, Literal['arm', 'arm64', 'x86', 'x86_64', 'Unknown'], str] diff --git a/openllm-client/src/openllm_client/_utils.py b/openllm-client/src/openllm_client/_utils.py new file mode 100644 index 00000000..d5ddbf30 --- /dev/null +++ b/openllm-client/src/openllm_client/_utils.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from cattr import Converter + + +converter = Converter(omit_if_default=True) diff --git a/openllm-core/src/openllm_core/_conversation.py b/openllm-core/src/openllm_core/_conversation.py index 71c4d149..4c1cbce1 100644 --- a/openllm-core/src/openllm_core/_conversation.py +++ b/openllm-core/src/openllm_core/_conversation.py @@ -108,7 +108,7 @@ class Conversation: elif self.sep_style == SeparatorStyle.LLAMA: seps = [self.sep, self.sep2] ret = system_prompt if self.system_message else '[INST] ' - for i, (role, message) in enumerate(self.messages): + for i, (_, message) in enumerate(self.messages): tag = self.roles[i % 2] if message: if i == 0: @@ -139,12 +139,12 @@ class Conversation: return ret elif self.sep_style == SeparatorStyle.MPT: ret = f'<|im_start|>system\n{system_prompt}<|im_end|>{self.sep}' if system_prompt else '' - for i, (role, message) in enumerate(self.messages): + for _, (role, message) in enumerate(self.messages): ret += f'<|im_start|>{role}\n{message}<|im_end|>{self.sep}' if message else f'{role}:' return ret elif self.sep_style == SeparatorStyle.STARCODER: ret = f'<|system|>\n{system_prompt}<|end|>{self.sep}' if system_prompt else '' - for i, (role, message) in enumerate(self.messages): + for _, (role, message) in enumerate(self.messages): ret += f'{role}\n{message}<|end|>{self.sep}' if message else f'{role}:' else: raise ValueError(f'Invalid style: {self.sep_style}') diff --git a/openllm-core/src/openllm_core/_schemas.py b/openllm-core/src/openllm_core/_schemas.py index e23f7b56..679e09d4 100644 --- a/openllm-core/src/openllm_core/_schemas.py +++ b/openllm-core/src/openllm_core/_schemas.py @@ -16,26 +16,45 @@ from .utils import gen_random_uuid if t.TYPE_CHECKING: import vllm + from ._typing_compat import Self -@attr.frozen(slots=True) -class MetadataOutput: - model_id: str - timeout: int - model_name: str - backend: str - configuration: str - prompt_template: str - system_message: str +@attr.define +class _SchemaMixin: def model_dump(self) -> dict[str, t.Any]: return converter.unstructure(self) def model_dump_json(self) -> str: return orjson.dumps(self.model_dump(), option=orjson.OPT_INDENT_2).decode('utf-8') + def with_options(self, **options: t.Any) -> Self: + return attr.evolve(self, **options) -@attr.define(slots=True, frozen=True) -class GenerationInput: + +@attr.define +class MetadataOutput(_SchemaMixin): + model_id: str + timeout: int + model_name: str + backend: str + configuration: str + prompt_template: t.Optional[str] + system_message: t.Optional[str] + + def model_dump(self) -> dict[str, t.Any]: + return { + 'model_id': self.model_id, + 'timeout': self.timeout, + 'model_name': self.model_name, + 'backend': self.backend, + 'configuration': self.configuration, + 'prompt_template': orjson.dumps(self.prompt_template).decode(), + 'system_message': orjson.dumps(self.system_message).decode(), + } + + +@attr.define +class GenerationInput(_SchemaMixin): prompt: str llm_config: LLMConfig stop: list[str] | None = attr.field(default=None) @@ -53,9 +72,6 @@ class GenerationInput: 'adapter_name': self.adapter_name, } - def model_dump_json(self) -> str: - return 
orjson.dumps(self.model_dump(), option=orjson.OPT_INDENT_2).decode('utf-8')
-
   @classmethod
   def from_llm_config(cls, llm_config: LLMConfig) -> type[GenerationInput]:
     def init(self: GenerationInput, prompt: str, stop: list[str] | None, adapter_name: str | None) -> None:
@@ -80,7 +96,7 @@ class GenerationInput:
     def examples(_: type[GenerationInput]) -> dict[str, t.Any]:
       return klass(prompt='What is the meaning of life?', llm_config=llm_config, stop=['\n']).model_dump()
 
-    setattr(klass, 'examples', classmethod(examples))
+    klass.examples = classmethod(examples)
 
     try:
       klass.__module__ = cls.__module__
@@ -98,7 +114,7 @@ FinishReason = t.Literal['length', 'stop']
 
 
 @attr.define
-class CompletionChunk:
+class CompletionChunk(_SchemaMixin):
   index: int
   text: str
   token_ids: t.List[int]
@@ -106,15 +122,12 @@ class CompletionChunk:
   logprobs: t.Optional[SampleLogprobs] = None
   finish_reason: t.Optional[FinishReason] = None
 
-  # yapf: disable
-  def with_options(self,**options: t.Any)->CompletionChunk: return attr.evolve(self, **options)
-  def model_dump(self)->dict[str, t.Any]:return converter.unstructure(self)
-  def model_dump_json(self)->str:return orjson.dumps(self.model_dump(),option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
-  # yapf: enable
+  def model_dump_json(self) -> str:
+    return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
 
 
 @attr.define
-class GenerationOutput:
+class GenerationOutput(_SchemaMixin):
   prompt: str
   finished: bool
   outputs: t.List[CompletionChunk]
@@ -182,11 +195,5 @@ class GenerationOutput:
       prompt_logprobs=request_output.prompt_logprobs,
     )
 
-  def with_options(self, **options: t.Any) -> GenerationOutput:
-    return attr.evolve(self, **options)
-
-  def model_dump(self) -> dict[str, t.Any]:
-    return converter.unstructure(self)
-
   def model_dump_json(self) -> str:
     return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
diff --git a/openllm-core/src/openllm_core/_typing_compat.py b/openllm-core/src/openllm_core/_typing_compat.py
index 8f1b6904..ecaeb7fb 100644
--- a/openllm-core/src/openllm_core/_typing_compat.py
+++ b/openllm-core/src/openllm_core/_typing_compat.py
@@ -70,6 +70,11 @@ else:
   from typing_extensions import TypeAlias as TypeAlias
   from typing_extensions import TypeGuard as TypeGuard
 
+if sys.version_info[:2] >= (3, 9):
+  from typing import Annotated as Annotated
+else:
+  from typing_extensions import Annotated as Annotated
+
 
 class AdapterTuple(TupleAny):
   adapter_id: str
diff --git a/openllm-core/src/openllm_core/config/configuration_auto.py b/openllm-core/src/openllm_core/config/configuration_auto.py
index fd5a0f5b..24339261 100644
--- a/openllm-core/src/openllm_core/config/configuration_auto.py
+++ b/openllm-core/src/openllm_core/config/configuration_auto.py
@@ -158,11 +158,11 @@ class AutoConfig:
     else:
       try:
         config_file = llm.bentomodel.path_of(CONFIG_FILE_NAME)
-      except OpenLLMException:
+      except OpenLLMException as err:
         if not is_transformers_available():
           raise MissingDependencyError(
             "'infer_class_from_llm' requires 'transformers' to be available. Make sure to install it with 'pip install transformers'"
-          )
+          ) from err
         from transformers.utils import cached_file
 
         try:
diff --git a/openllm-core/src/openllm_core/utils/peft.py b/openllm-core/src/openllm_core/utils/peft.py
index 0158ab61..5a2301d1 100644
--- a/openllm-core/src/openllm_core/utils/peft.py
+++ b/openllm-core/src/openllm_core/utils/peft.py
@@ -101,7 +101,7 @@ class FineTuneConfig:
       from peft.mapping import get_peft_config
       from peft.utils.peft_types import TaskType
     except ImportError:
-      raise ImportError('PEFT is not installed. Please install it via `pip install "openllm[fine-tune]"`.')
+      raise ImportError('PEFT is not installed. Please install it via `pip install "openllm[fine-tune]"`.') from None
     adapter_config = self.adapter_config.copy()
     # no need for peft_type
     if 'peft_type' in adapter_config:
diff --git a/openllm-python/src/openllm/__init__.py b/openllm-python/src/openllm/__init__.py
index a83ccd19..7ae0ac06 100644
--- a/openllm-python/src/openllm/__init__.py
+++ b/openllm-python/src/openllm/__init__.py
@@ -66,7 +66,7 @@ else:
 
 _import_structure: dict[str, list[str]] = {
   'exceptions': [],
-  'client': [],
+  'client': ['HTTPClient', 'AsyncHTTPClient'],
   'bundle': [],
   'playground': [],
   'testing': [],
@@ -98,6 +98,8 @@ if _t.TYPE_CHECKING:
   from . import serialisation as serialisation
   from . import testing as testing
   from . import utils as utils
+  from .client import HTTPClient as HTTPClient
+  from .client import AsyncHTTPClient as AsyncHTTPClient
   from ._deprecated import Runner as Runner
   from ._generation import LogitsProcessorList as LogitsProcessorList
   from ._generation import StopOnTokens as StopOnTokens
diff --git a/openllm-python/src/openllm/__main__.py b/openllm-python/src/openllm/__main__.py
index 97e94bc2..6721cd10 100644
--- a/openllm-python/src/openllm/__main__.py
+++ b/openllm-python/src/openllm/__main__.py
@@ -7,9 +7,6 @@ To start any OpenLLM model:
 openllm start --options ...
 """
-from __future__ import annotations
-
-
 if __name__ == '__main__':
   from openllm.cli.entrypoint import cli
 
diff --git a/openllm-python/src/openllm/_generation.py b/openllm-python/src/openllm/_generation.py
index fd914c5c..3b4b8aea 100644
--- a/openllm-python/src/openllm/_generation.py
+++ b/openllm-python/src/openllm/_generation.py
@@ -71,7 +71,7 @@ def is_sentence_complete(output: str) -> bool:
 
 def is_partial_stop(output: str, stop_str: str) -> bool:
   """Check whether the output contains a partial stop str."""
-  for i in range(0, min(len(output), len(stop_str))):
+  for i in range(min(len(output), len(stop_str))):
     if stop_str.startswith(output[-i:]):
       return True
   return False
diff --git a/openllm-python/src/openllm/_strategies.py b/openllm-python/src/openllm/_strategies.py
index 454aac72..944513f2 100644
--- a/openllm-python/src/openllm/_strategies.py
+++ b/openllm-python/src/openllm/_strategies.py
@@ -235,7 +235,7 @@ def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
       raise RuntimeError('Failed to initialise CUDA runtime binding.')
     # correctly parse handle
     for el in val:
-      if el.startswith('GPU-') or el.startswith('MIG-'):
+      if el.startswith(('GPU-', 'MIG-')):
         uuids = _raw_device_uuid_nvml()
         if uuids is None:
           raise ValueError('Failed to parse available GPUs UUID')
diff --git a/openllm-python/src/openllm/entrypoints/_openapi.py b/openllm-python/src/openllm/entrypoints/_openapi.py
index 9f8f7485..882a6e42 100644
--- a/openllm-python/src/openllm/entrypoints/_openapi.py
+++ b/openllm-python/src/openllm/entrypoints/_openapi.py
@@ -580,6 +580,6 @@ def append_schemas(
   def mk_generate_spec(svc:bentoml.Service,openapi_version:str=OPENAPI_VERSION)->MKSchema:return MKSchema(svc_schema)
   def mk_asdict(self:OpenAPISpecification)->dict[str,t.Any]:return svc_schema
   openapi.generate_spec=mk_generate_spec
-  setattr(OpenAPISpecification, 'asdict', mk_asdict)
+  OpenAPISpecification.asdict = mk_asdict
   # yapf: disable
   return svc
diff --git a/openllm-python/src/openllm/serialisation/ggml.py b/openllm-python/src/openllm/serialisation/ggml.py
index 413050a9..9e15d5ac 100644
--- a/openllm-python/src/openllm/serialisation/ggml.py
+++ b/openllm-python/src/openllm/serialisation/ggml.py
@@ -1,30 +1,13 @@
-"""Serialisation related implementation for GGML-based implementation.
-
-This requires ctransformers to be installed.
-"""
-
 from __future__ import annotations
-import typing as t
-if t.TYPE_CHECKING:
-  import bentoml
-  import openllm
-
-  from openllm_core._typing_compat import M
-
-_conversion_strategy = {'pt': 'ggml'}
-
-
-def import_model(
-  llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any
-) -> bentoml.Model:
+def import_model(llm, *decls, trust_remote_code=True, **attrs):
   raise NotImplementedError('Currently work in progress.')
 
 
-def get(llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Model:
+def get(llm):
   raise NotImplementedError('Currently work in progress.')
 
 
-def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
+def load_model(llm, *decls, **attrs):
   raise NotImplementedError('Currently work in progress.')
diff --git a/openllm-python/src/openllm/serialisation/transformers/__init__.py b/openllm-python/src/openllm/serialisation/transformers/__init__.py
index e3c91cb8..599aa6af 100644
--- a/openllm-python/src/openllm/serialisation/transformers/__init__.py
+++ b/openllm-python/src/openllm/serialisation/transformers/__init__.py
@@ -1,5 +1,3 @@
-"""Serialisation related implementation for Transformers-based implementation."""
-
 from __future__ import annotations
 import importlib
 import logging
diff --git a/openllm-python/src/openllm/utils/__init__.py b/openllm-python/src/openllm/utils/__init__.py
index 5cb72b3a..2438da98 100644
--- a/openllm-python/src/openllm/utils/__init__.py
+++ b/openllm-python/src/openllm/utils/__init__.py
@@ -6,6 +6,7 @@ we won't ensure backward compatibility for these functions. So use with caution.
 """
 from __future__ import annotations
 import functools
+import importlib.metadata
 import typing as t
 
 import openllm_core
@@ -72,6 +73,7 @@ def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]:
     'model_name': llm.config['model_name'],
     'architecture': llm.config['architecture'],
     'serialisation': llm._serialisation,
+    **{package: importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}},
   }
 
 
diff --git a/ruff.toml b/ruff.toml
index daa84865..0283924b 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -9,6 +9,10 @@ extend-exclude = [
 ]
 extend-include = ["*.ipynb"]
 extend-select = [
+  "E",
+  "F",
+  "B",
+  "PIE",
   "I", # isort
   "G", # flake8-logging-format
   "W", # pycodestyle
@@ -40,11 +44,19 @@ ignore = [
 line-length = 119
 indent-width = 2
 target-version = "py38"
-typing-modules = ["openllm_core._typing_compat"]
+typing-modules = [
+  "openllm_core._typing_compat",
+  "openllm_client._typing_compat",
+]
 unfixable = ["TCH004"]
 
 [lint.flake8-type-checking]
-exempt-modules = ["typing", "typing_extensions", "openllm_core._typing_compat"]
+exempt-modules = [
+  "typing",
+  "typing_extensions",
+  "openllm_core._typing_compat",
+  "openllm_client._typing_compat",
+]
 runtime-evaluated-base-classes = [
   "openllm_core._configuration.LLMConfig",
   "openllm_core._configuration.GenerationConfig",
@@ -52,7 +64,15 @@ runtime-evaluated-base-classes = [
   "openllm_core._configuration.ModelSettings",
   "openllm.LLMConfig",
 ]
-runtime-evaluated-decorators = ["attrs.define", "attrs.frozen", "trait"]
+runtime-evaluated-decorators = [
+  "attrs.define",
+  "attrs.frozen",
+  "trait",
+  "attr.attrs",
+  'attr.define',
+  '_attr.define',
+  'attr.frozen',
+]
 
 [format]
 preview = true
@@ -65,6 +85,7 @@ convention = "google"
 
 [lint.pycodestyle]
 ignore-overlong-task-comments = true
+max-line-length = 119
 
 [lint.isort]
 combine-as-imports = true
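`CompletionChunk` and `GenerationOutput` now inherit from `_SchemaMixin`, whose definition is not part of these hunks. Judging only from the duplicated helpers this diff deletes, a mixin along the following lines would be sufficient; this is an illustrative sketch under that assumption, not the actual `openllm_core` implementation (note that both classes still define `model_dump_json` themselves):

```python
# Hypothetical sketch of a _SchemaMixin, inferred from the removed methods; not copied from openllm_core.
from __future__ import annotations
import typing as t

import attr
import cattrs

converter = cattrs.Converter()  # assumption: the real module shares a single converter instance


class _SchemaMixin:
  def with_options(self, **options: t.Any) -> t.Any:
    # Return a copy of the attrs instance with the given fields replaced.
    return attr.evolve(self, **options)

  def model_dump(self) -> dict[str, t.Any]:
    # Unstructure the attrs instance into plain dicts/lists ready for JSON serialisation.
    return converter.unstructure(self)
```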
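With `'client': ['HTTPClient', 'AsyncHTTPClient']` registered in `_import_structure` (plus the matching `TYPE_CHECKING` imports), both clients become importable straight from the top-level `openllm` package. A minimal usage sketch; the address and prompt are placeholders and assume an OpenLLM server is already running locally:

```python
import asyncio

import openllm

# Placeholder address/prompt, for illustration only.
client = openllm.HTTPClient('http://localhost:3000')
print(client.generate('What is the meaning of life?'))


async def main() -> None:
  async_client = openllm.AsyncHTTPClient('http://localhost:3000')
  print(await async_client.generate('What is the meaning of life?'))


asyncio.run(main())
```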
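The `generate_labels` hunk folds the installed distribution versions of `openllm`, `openllm-core`, and `openllm-client` into the returned labels via `importlib.metadata`. A standalone sketch of that lookup follows; the helper name is made up here, and unlike the one-liner in the hunk it guards against `PackageNotFoundError`, which `importlib.metadata.version` raises when a distribution is absent:

```python
import importlib.metadata
import typing as t


def distribution_versions(packages: t.Iterable[str] = ('openllm', 'openllm-core', 'openllm-client')) -> dict[str, str]:
  # Resolve installed versions; fall back to a marker instead of raising for missing distributions.
  versions: dict[str, str] = {}
  for package in packages:
    try:
      versions[package] = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
      versions[package] = 'not-installed'
  return versions


print(distribution_versions())
```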