From 727361ced761c82351ff539fcafa7af62fb5e2f0 Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Fri, 15 Mar 2024 05:35:24 -0400 Subject: [PATCH] chore: running updated ruff formatter [skip ci] Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- .../src/openllm_client/__init__.pyi | 24 +++-- openllm-client/src/openllm_client/_http.py | 28 ++++-- openllm-client/src/openllm_client/_schemas.py | 23 +++-- openllm-client/src/openllm_client/_shim.py | 87 +++++++++++++++---- openllm-client/src/openllm_client/_stream.py | 8 +- .../src/openllm_client/_typing_compat.py | 4 +- openllm-python/src/openllm/client.pyi | 4 +- openllm-python/src/openllm_cli/_sdk.py | 4 +- .../src/openllm_cli/extension/dive_bentos.py | 4 +- .../extension/get_containerfile.py | 4 +- .../src/openllm_cli/extension/get_prompt.py | 32 +++++-- .../src/openllm_cli/extension/list_models.py | 9 +- .../src/openllm_cli/extension/playground.py | 17 +++- openllm-python/tests/configuration_test.py | 11 ++- openllm-python/tests/conftest.py | 14 ++- openllm-python/tests/strategies_test.py | 8 +- 16 files changed, 216 insertions(+), 65 deletions(-) diff --git a/openllm-client/src/openllm_client/__init__.pyi b/openllm-client/src/openllm_client/__init__.pyi index 3b5ecfb3..bec8fc11 100644 --- a/openllm-client/src/openllm_client/__init__.pyi +++ b/openllm-client/src/openllm_client/__init__.pyi @@ -15,11 +15,17 @@ class HTTPClient: address: str helpers: _Helpers @overload - def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @overload - def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @overload - def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @property def is_ready(self) -> bool: ... def health(self) -> bool: ... @@ -60,11 +66,17 @@ class AsyncHTTPClient: address: str helpers: _AsyncHelpers @overload - def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @overload - def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @overload - def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... + def __init__( + self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... + ) -> None: ... @property def is_ready(self) -> bool: ... async def health(self) -> bool: ... 
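For context on the stubs reformatted above, here is a minimal usage sketch of the `HTTPClient` declared in `openllm_client/__init__.pyi`. It is not part of the patch; the address, timeout, and retry values are illustrative assumptions, and a running OpenLLM server at that address is assumed.

```python
# Sketch only: constructing and probing the client whose stubs are reformatted above.
# The address and option values are illustrative; a reachable OpenLLM server is assumed.
from openllm_client import HTTPClient

client = HTTPClient('http://localhost:3000', timeout=30, max_retries=2)

if client.is_ready:          # property declared in __init__.pyi
    print(client.health())   # GET /readyz; True when the server responds with 200

# generate() returns a Response whose outputs carry the generated text
res = client.generate('What is the difference between gather and scatter?')
print(res.outputs[0].text)
```

The overloads only differ in whether `address` is given explicitly, defaulted, or `None`; all other options (`timeout`, `verify`, `max_retries`, `api_version`) are keyword defaults.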
diff --git a/openllm-client/src/openllm_client/_http.py b/openllm-client/src/openllm_client/_http.py index 9923468e..ed802a67 100644 --- a/openllm-client/src/openllm_client/_http.py +++ b/openllm-client/src/openllm_client/_http.py @@ -70,10 +70,14 @@ class HTTPClient(Client): return self.generate(prompt, **attrs) def health(self): - response = self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}) + response = self._get( + '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} + ) return response.status_code == 200 - def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response: + def generate( + self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs + ) -> Response: if timeout is None: timeout = self._timeout if verify is None: @@ -96,7 +100,9 @@ class HTTPClient(Client): for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs): yield StreamingResponse.from_response_chunk(response_chunk) - def generate_iterator(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> t.Iterator[Response]: + def generate_iterator( + self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs + ) -> t.Iterator[Response]: if timeout is None: timeout = self._timeout if verify is None: @@ -146,7 +152,9 @@ class AsyncHTTPClient(AsyncClient): @property async def _metadata(self) -> t.Awaitable[Metadata]: if self.__metadata is None: - self.__metadata = await self._post(f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries}) + self.__metadata = await self._post( + f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries} + ) return self.__metadata @property @@ -159,10 +167,14 @@ class AsyncHTTPClient(AsyncClient): return await self.generate(prompt, **attrs) async def health(self): - response = await self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}) + response = await self._get( + '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} + ) return response.status_code == 200 - async def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response: + async def generate( + self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs + ) -> Response: if timeout is None: timeout = self._timeout if verify is None: @@ -183,7 +195,9 @@ class AsyncHTTPClient(AsyncClient): async def generate_stream( self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> t.AsyncGenerator[StreamingResponse, t.Any]: - async for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs): + async for response_chunk in self.generate_iterator( + prompt, llm_config, stop, adapter_name, timeout, verify, **attrs + ): yield StreamingResponse.from_response_chunk(response_chunk) async def generate_iterator( diff --git a/openllm-client/src/openllm_client/_schemas.py b/openllm-client/src/openllm_client/_schemas.py index 037c92b3..1723d583 100644 --- a/openllm-client/src/openllm_client/_schemas.py +++ 
b/openllm-client/src/openllm_client/_schemas.py @@ -17,7 +17,7 @@ if t.TYPE_CHECKING: from ._shim import AsyncClient, Client -__all__ = ['Response', 'CompletionChunk', 'Metadata', 'StreamingResponse', 'Helpers'] +__all__ = ['CompletionChunk', 'Helpers', 'Metadata', 'Response', 'StreamingResponse'] @attr.define @@ -42,7 +42,11 @@ def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metada raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None try: return cls( - model_id=data['model_id'], timeout=data['timeout'], model_name=data['model_name'], backend=data['backend'], configuration=configuration + model_id=data['model_id'], + timeout=data['timeout'], + model_name=data['model_name'], + backend=data['backend'], + configuration=configuration, ) except Exception as e: raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None @@ -61,7 +65,10 @@ class StreamingResponse(_SchemaMixin): @classmethod def from_response_chunk(cls, response: Response) -> StreamingResponse: return cls( - request_id=response.request_id, index=response.outputs[0].index, text=response.outputs[0].text, token_ids=response.outputs[0].token_ids[0] + request_id=response.request_id, + index=response.outputs[0].index, + text=response.outputs[0].text, + token_ids=response.outputs[0].token_ids[0], ) @@ -88,11 +95,17 @@ class Helpers: return self._async_client def messages(self, messages, add_generation_prompt=False): - return self.client._post('/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt)) + return self.client._post( + '/v1/helpers/messages', + response_cls=str, + json=dict(messages=messages, add_generation_prompt=add_generation_prompt), + ) async def async_messages(self, messages, add_generation_prompt=False): return await self.async_client._post( - '/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt) + '/v1/helpers/messages', + response_cls=str, + json=dict(messages=messages, add_generation_prompt=add_generation_prompt), ) @classmethod diff --git a/openllm-client/src/openllm_client/_shim.py b/openllm-client/src/openllm_client/_shim.py index 04a9a730..c0841060 100644 --- a/openllm-client/src/openllm_client/_shim.py +++ b/openllm-client/src/openllm_client/_shim.py @@ -140,7 +140,9 @@ class APIResponse(t.Generic[Response]): data = self._raw_response.json() try: - return self._client._process_response_data(data=data, response_cls=self._response_cls, raw_response=self._raw_response) + return self._client._process_response_data( + data=data, response_cls=self._response_cls, raw_response=self._raw_response + ) except Exception as exc: raise ValueError(exc) from None # validation error here @@ -271,10 +273,16 @@ class BaseClient(t.Generic[InnerClient, StreamType]): def _build_request(self, options: RequestOptions) -> httpx.Request: return self._inner.build_request( - method=options.method, headers=self._build_headers(options), url=self._prepare_url(options.url), json=options.json, params=options.params + method=options.method, + headers=self._build_headers(options), + url=self._prepare_url(options.url), + json=options.json, + params=options.params, ) - def _calculate_retry_timeout(self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None) -> float: + def _calculate_retry_timeout( + self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None + ) -> float: 
max_retries = options.get_max_retries(self._max_retries) # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After try: @@ -315,7 +323,9 @@ class BaseClient(t.Generic[InnerClient, StreamType]): return True return False - def _process_response_data(self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response) -> Response: + def _process_response_data( + self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response + ) -> Response: return converter.structure(data, response_cls) def _process_response( @@ -328,13 +338,24 @@ class BaseClient(t.Generic[InnerClient, StreamType]): stream_cls: type[_Stream] | type[_AsyncStream] | None, ) -> Response: return APIResponse( - raw_response=raw_response, client=self, response_cls=response_cls, stream=stream, stream_cls=stream_cls, options=options + raw_response=raw_response, + client=self, + response_cls=response_cls, + stream=stream, + stream_cls=stream_cls, + options=options, ).parse() @attr.define(init=False) class Client(BaseClient[httpx.Client, Stream[t.Any]]): - def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES): + def __init__( + self, + base_url: str | httpx.URL, + version: str, + timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, + max_retries: int = MAX_RETRIES, + ): super().__init__( base_url=base_url, version=version, @@ -366,7 +387,13 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): stream: bool = False, stream_cls: type[_Stream] | None = None, ) -> Response | _Stream: - return self._request(response_cls=response_cls, options=options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls) + return self._request( + response_cls=response_cls, + options=options, + remaining_retries=remaining_retries, + stream=stream, + stream_cls=stream_cls, + ) def _request( self, @@ -385,7 +412,9 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): response.raise_for_status() except httpx.HTTPStatusError as exc: if retries > 0 and self._should_retry(exc.response): - return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls) + return self._retry_request( + response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls + ) # If the response is streamed then we need to explicitly read the completed response exc.response.read() raise ValueError(exc.message) from None @@ -398,7 +427,9 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) raise ValueError(request) from None # connection error - return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls) + return self._process_response( + response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls + ) def _retry_request( self, @@ -428,7 +459,9 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): ) -> Response | _Stream: if options is None: options = {} - return self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls) + return self.request( + response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls + ) def _post( self, @@ -442,12 +475,20 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): ) -> Response | 
_Stream: if options is None: options = {} - return self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls) + return self.request( + response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls + ) @attr.define(init=False) class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): - def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES): + def __init__( + self, + base_url: str | httpx.URL, + version: str, + timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, + max_retries: int = MAX_RETRIES, + ): super().__init__( base_url=base_url, version=version, @@ -471,7 +512,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): try: loop = asyncio.get_event_loop() if loop.is_running(): - loop.create_task(self.close()) + loop.create_task(self.close()) # noqa else: loop.run_until_complete(self.close()) except Exception: @@ -486,7 +527,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): stream: bool = False, stream_cls: type[_AsyncStream] | None = None, ) -> Response | _AsyncStream: - return await self._request(response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls) + return await self._request( + response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls + ) async def _request( self, @@ -506,7 +549,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): response.raise_for_status() except httpx.HTTPStatusError as exc: if retries > 0 and self._should_retry(exc.response): - return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls) + return self._retry_request( + response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls + ) # If the response is streamed then we need to explicitly read the completed response await exc.response.aread() raise ValueError(exc.message) from None @@ -526,7 +571,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) raise ValueError(request) from err # connection error - return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls) + return self._process_response( + response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls + ) async def _retry_request( self, @@ -555,7 +602,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): ) -> Response | _AsyncStream: if options is None: options = {} - return await self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls) + return await self.request( + response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls + ) async def _post( self, @@ -569,4 +618,6 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): ) -> Response | _AsyncStream: if options is None: options = {} - return await self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls) + return await self.request( + response_cls, RequestOptions(method='POST', url=path, json=json, **options), 
stream=stream, stream_cls=stream_cls + ) diff --git a/openllm-client/src/openllm_client/_stream.py b/openllm-client/src/openllm_client/_stream.py index a5103207..e81a7fb0 100644 --- a/openllm-client/src/openllm_client/_stream.py +++ b/openllm-client/src/openllm_client/_stream.py @@ -38,7 +38,9 @@ class Stream(t.Generic[Response]): if sse.data.startswith('[DONE]'): break if sse.event is None: - yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response) + yield self._client._process_response_data( + data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response + ) @attr.define(auto_attribs=True) @@ -69,7 +71,9 @@ class AsyncStream(t.Generic[Response]): if sse.data.startswith('[DONE]'): break if sse.event is None: - yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response) + yield self._client._process_response_data( + data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response + ) @attr.define diff --git a/openllm-client/src/openllm_client/_typing_compat.py b/openllm-client/src/openllm_client/_typing_compat.py index 48bd0a85..15d86f8a 100644 --- a/openllm-client/src/openllm_client/_typing_compat.py +++ b/openllm-client/src/openllm_client/_typing_compat.py @@ -11,5 +11,7 @@ from openllm_core._typing_compat import ( overload as overload, ) -Platform = Annotated[LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str] +Platform = Annotated[ + LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str +] Architecture = Annotated[LiteralString, Literal['arm', 'arm64', 'x86', 'x86_64', 'Unknown'], str] diff --git a/openllm-python/src/openllm/client.pyi b/openllm-python/src/openllm/client.pyi index f26d4fa4..2ab9656e 100644 --- a/openllm-python/src/openllm/client.pyi +++ b/openllm-python/src/openllm/client.pyi @@ -1,8 +1,8 @@ """OpenLLM Python client. 
```python -client = openllm.client.HTTPClient("http://localhost:8080") -client.query("What is the difference between gather and scatter?") +client = openllm.client.HTTPClient('http://localhost:8080') +client.query('What is the difference between gather and scatter?') ``` """ diff --git a/openllm-python/src/openllm_cli/_sdk.py b/openllm-python/src/openllm_cli/_sdk.py index 1c9c04de..a5167cbe 100644 --- a/openllm-python/src/openllm_cli/_sdk.py +++ b/openllm-python/src/openllm_cli/_sdk.py @@ -81,7 +81,7 @@ def _start( if adapter_map: args.extend( list( - itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()]) + itertools.chain.from_iterable([['--adapter-id', f"{k}{':' + v if v else ''}"] for k, v in adapter_map.items()]) ) ) if additional_args: @@ -173,7 +173,7 @@ def _build( if overwrite: args.append('--overwrite') if adapter_map: - args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) + args.extend([f"--adapter-id={k}{':' + v if v is not None else ''}" for k, v in adapter_map.items()]) if model_version: args.extend(['--model-version', model_version]) if bento_version: diff --git a/openllm-python/src/openllm_cli/extension/dive_bentos.py b/openllm-python/src/openllm_cli/extension/dive_bentos.py index db488004..541d07bf 100644 --- a/openllm-python/src/openllm_cli/extension/dive_bentos.py +++ b/openllm-python/src/openllm_cli/extension/dive_bentos.py @@ -21,7 +21,9 @@ if t.TYPE_CHECKING: @machine_option @click.pass_context @inject -def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None: +def cli( + ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store] +) -> str | None: """Dive into a BentoLLM. This is synonymous to cd $(b get : -o path).""" try: bentomodel = _bento_store.get(bento) diff --git a/openllm-python/src/openllm_cli/extension/get_containerfile.py b/openllm-python/src/openllm_cli/extension/get_containerfile.py index 88605414..50798829 100644 --- a/openllm-python/src/openllm_cli/extension/get_containerfile.py +++ b/openllm-python/src/openllm_cli/extension/get_containerfile.py @@ -17,7 +17,9 @@ if t.TYPE_CHECKING: from bentoml._internal.bento import BentoStore -@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.') +@click.command( + 'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.' 
+) @click.argument('bento', type=str, shell_complete=bento_complete_envvar) @click.pass_context @inject diff --git a/openllm-python/src/openllm_cli/extension/get_prompt.py b/openllm-python/src/openllm_cli/extension/get_prompt.py index b679577f..0e64c230 100644 --- a/openllm-python/src/openllm_cli/extension/get_prompt.py +++ b/openllm-python/src/openllm_cli/extension/get_prompt.py @@ -22,7 +22,9 @@ class PromptFormatter(string.Formatter): raise ValueError('Positional arguments are not supported') return super().vformat(format_string, args, kwargs) - def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None: + def check_unused_args( + self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any] + ) -> None: extras = set(kwargs).difference(used_args) if extras: raise KeyError(f'Extra params passed: {extras}') @@ -56,7 +58,9 @@ class PromptTemplate: try: return self.template.format(**prompt_variables) except KeyError as e: - raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template.") from None + raise RuntimeError( + f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template." + ) from None @click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS) @@ -124,15 +128,21 @@ def cli( if prompt_template_file and chat_template_file: ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.') - acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()) + acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set( + inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys() + ) if model_id in acceptable: - logger.warning('Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n') + logger.warning( + 'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n' + ) config = openllm.AutoConfig.for_model(model_id) template = prompt_template_file.read() if prompt_template_file is not None else config.template system_message = system_message or config.system_message try: - formatted = PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized) + formatted = ( + PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized) + ) except RuntimeError as err: logger.debug('Exception caught while formatting prompt: %s', err) ctx.fail(str(err)) @@ -149,15 +159,21 @@ def cli( for architecture in config.architectures: if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE(): system_message = ( - openllm.AutoConfig.infer_class_from_name(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]) + openllm.AutoConfig.infer_class_from_name( + openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture] + ) .model_construct_env() .system_message ) break else: - ctx.fail(f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message') + ctx.fail( + f'Failed to infer system message from model architecture: {config.architectures}. 
Please pass in --system-message' + ) messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}] - formatted = tokenizer.apply_chat_template(messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False) + formatted = tokenizer.apply_chat_template( + messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False + ) termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white') ctx.exit(0) diff --git a/openllm-python/src/openllm_cli/extension/list_models.py b/openllm-python/src/openllm_cli/extension/list_models.py index eb18ce0d..6eb49e07 100644 --- a/openllm-python/src/openllm_cli/extension/list_models.py +++ b/openllm-python/src/openllm_cli/extension/list_models.py @@ -33,12 +33,17 @@ def cli(model_name: str | None) -> DictStrAny: } if model_name is not None: ids_in_local_store = { - k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)] + k: [ + i + for i in v + if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name) + ] for k, v in ids_in_local_store.items() } ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v} local_models = { - k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items() + k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] + for k, val in ids_in_local_store.items() } termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white') return local_models diff --git a/openllm-python/src/openllm_cli/extension/playground.py b/openllm-python/src/openllm_cli/extension/playground.py index f8e5b4da..fcbc128b 100644 --- a/openllm-python/src/openllm_cli/extension/playground.py +++ b/openllm-python/src/openllm_cli/extension/playground.py @@ -32,7 +32,14 @@ def load_notebook_metadata() -> DictStrAny: @click.command('playground', context_settings=termui.CONTEXT_SETTINGS) @click.argument('output-dir', default=None, required=False) -@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server') +@click.option( + '--port', + envvar='JUPYTER_PORT', + show_envvar=True, + show_default=True, + default=8888, + help='Default port for Jupyter server', +) @click.pass_context def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: """OpenLLM Playground. @@ -53,7 +60,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: > This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"' """ if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available(): - raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'") + raise RuntimeError( + "Playground requires 'jupyter', 'jupytext', and 'notebook'. 
Install it with 'pip install \"openllm[playground]\"'" + ) metadata = load_notebook_metadata() _temp_dir = False if output_dir is None: @@ -65,7 +74,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue') for module in pkgutil.iter_modules(playground.__path__): if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')): - logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module') + logger.debug( + 'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module' + ) continue if not isinstance(module.module_finder, importlib.machinery.FileFinder): continue diff --git a/openllm-python/tests/configuration_test.py b/openllm-python/tests/configuration_test.py index 90069b79..fafa4098 100644 --- a/openllm-python/tests/configuration_test.py +++ b/openllm-python/tests/configuration_test.py @@ -66,8 +66,15 @@ def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings): st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), ) -def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float): - cl_ = make_llm_config('ComplexLLM', gen_settings, fields=(('field1', 'float', field1),), generation_fields=(('temperature', temperature),)) +def test_complex_struct_dump( + gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float +): + cl_ = make_llm_config( + 'ComplexLLM', + gen_settings, + fields=(('field1', 'float', field1),), + generation_fields=(('temperature', temperature),), + ) sent = cl_() assert sent.model_dump()['field1'] == field1 assert sent.model_dump()['generation_config']['temperature'] == temperature diff --git a/openllm-python/tests/conftest.py b/openllm-python/tests/conftest.py index 1efd9e4d..e49b2656 100644 --- a/openllm-python/tests/conftest.py +++ b/openllm-python/tests/conftest.py @@ -10,8 +10,14 @@ import openllm if t.TYPE_CHECKING: from openllm_core._typing_compat import LiteralBackend -_MODELING_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B'} -_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'} +_MODELING_MAPPING = { + 'flan_t5': 'google/flan-t5-small', + 'opt': 'facebook/opt-125m', + 'baichuan': 'baichuan-inc/Baichuan-7B', +} +_PROMPT_MAPPING = { + 'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?' 
+} def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.Any, t.Any]], None, None]: @@ -25,7 +31,9 @@ def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.An def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: if os.getenv('GITHUB_ACTIONS') is None: if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames: - metafunc.parametrize('prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]) + metafunc.parametrize( + 'prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])] + ) def pytest_sessionfinish(session: pytest.Session, exitstatus: int): diff --git a/openllm-python/tests/strategies_test.py b/openllm-python/tests/strategies_test.py index f801ed81..6b95ac0d 100644 --- a/openllm-python/tests/strategies_test.py +++ b/openllm-python/tests/strategies_test.py @@ -73,9 +73,13 @@ def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch): mcls.setenv('CUDA_VISIBLE_DEVICES', '') assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests - assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match('Input list should be all string type.') + assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match( + 'Input list should be all string type.' + ) assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.') - assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID') + assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match( + 'Failed to parse available GPUs UUID' + ) def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
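The final hunk reformats several long `pytest.raises(...).match(...)` assertions onto multiple lines. As a point of comparison (not part of this patch), the same checks can be expressed with `pytest.raises` as a context manager and its `match=` argument, which sidesteps the long single-expression form. This is a sketch that assumes `NvidiaGpuResource` is imported exactly as in the existing `strategies_test.py`:

```python
# Sketch only: an equivalent formulation of the reformatted assertions above.
# NvidiaGpuResource is assumed to be in scope via the same import used in strategies_test.py.
import pytest


def test_nvidia_gpu_validate_alt():
    # non-string entries are rejected
    with pytest.raises(ValueError, match='Input list should be all string type.'):
        NvidiaGpuResource.validate([-2])
    # malformed GPU UUIDs are rejected
    with pytest.raises(ValueError, match='Failed to parse available GPUs UUID'):
        NvidiaGpuResource.validate(['GPU-5ebe9f43', 'GPU-ac33420d4628'])
```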