Mirror of https://github.com/bentoml/OpenLLM.git, synced 2025-12-23 15:47:49 -05:00
chore: running updated ruff formatter [skip ci]
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
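Most of the changes below are mechanical line-length fixes: a call, signature, or literal that no longer fits on one line is split across several lines, usually with a trailing comma, and spacing inside a few f-string expressions is normalized. A minimal sketch of the pattern, with illustrative names only (the actual line-length limit comes from the repository's ruff configuration, which is not part of this diff):

    # Before: one call that exceeds the configured line length
    client.request(response_cls=Response, options=options, remaining_retries=retries, stream=stream, stream_cls=stream_cls)

    # After running `ruff format`: the call is split, one argument per line, with a trailing comma
    client.request(
        response_cls=Response,
        options=options,
        remaining_retries=retries,
        stream=stream,
        stream_cls=stream_cls,
    )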
@@ -15,11 +15,17 @@ class HTTPClient:
address: str
helpers: _Helpers
@overload
def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@property
def is_ready(self) -> bool: ...
def health(self) -> bool: ...
@@ -60,11 +66,17 @@ class AsyncHTTPClient:
address: str
helpers: _AsyncHelpers
@overload
def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@property
def is_ready(self) -> bool: ...
async def health(self) -> bool: ...

@@ -70,10 +70,14 @@ class HTTPClient(Client):
return self.generate(prompt, **attrs)

def health(self):
response = self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries})
response = self._get(
'/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
)
return response.status_code == 200

def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response:
def generate(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> Response:
if timeout is None:
timeout = self._timeout
if verify is None:
@@ -96,7 +100,9 @@ class HTTPClient(Client):
for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs):
yield StreamingResponse.from_response_chunk(response_chunk)

def generate_iterator(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> t.Iterator[Response]:
def generate_iterator(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.Iterator[Response]:
if timeout is None:
timeout = self._timeout
if verify is None:
@@ -146,7 +152,9 @@ class AsyncHTTPClient(AsyncClient):
@property
async def _metadata(self) -> t.Awaitable[Metadata]:
if self.__metadata is None:
self.__metadata = await self._post(f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries})
self.__metadata = await self._post(
f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries}
)
return self.__metadata

@property
@@ -159,10 +167,14 @@ class AsyncHTTPClient(AsyncClient):
return await self.generate(prompt, **attrs)

async def health(self):
response = await self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries})
response = await self._get(
'/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
)
return response.status_code == 200

async def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response:
async def generate(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> Response:
if timeout is None:
timeout = self._timeout
if verify is None:
@@ -183,7 +195,9 @@ class AsyncHTTPClient(AsyncClient):
async def generate_stream(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.AsyncGenerator[StreamingResponse, t.Any]:
async for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs):
async for response_chunk in self.generate_iterator(
prompt, llm_config, stop, adapter_name, timeout, verify, **attrs
):
yield StreamingResponse.from_response_chunk(response_chunk)

async def generate_iterator(

@@ -17,7 +17,7 @@ if t.TYPE_CHECKING:
from ._shim import AsyncClient, Client

__all__ = ['Response', 'CompletionChunk', 'Metadata', 'StreamingResponse', 'Helpers']
__all__ = ['CompletionChunk', 'Helpers', 'Metadata', 'Response', 'StreamingResponse']

@attr.define
@@ -42,7 +42,11 @@ def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metada
raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None
try:
return cls(
model_id=data['model_id'], timeout=data['timeout'], model_name=data['model_name'], backend=data['backend'], configuration=configuration
model_id=data['model_id'],
timeout=data['timeout'],
model_name=data['model_name'],
backend=data['backend'],
configuration=configuration,
)
except Exception as e:
raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None
@@ -61,7 +65,10 @@ class StreamingResponse(_SchemaMixin):
@classmethod
def from_response_chunk(cls, response: Response) -> StreamingResponse:
return cls(
request_id=response.request_id, index=response.outputs[0].index, text=response.outputs[0].text, token_ids=response.outputs[0].token_ids[0]
request_id=response.request_id,
index=response.outputs[0].index,
text=response.outputs[0].text,
token_ids=response.outputs[0].token_ids[0],
)

@@ -88,11 +95,17 @@ class Helpers:
return self._async_client

def messages(self, messages, add_generation_prompt=False):
return self.client._post('/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt))
return self.client._post(
'/v1/helpers/messages',
response_cls=str,
json=dict(messages=messages, add_generation_prompt=add_generation_prompt),
)

async def async_messages(self, messages, add_generation_prompt=False):
return await self.async_client._post(
'/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt)
'/v1/helpers/messages',
response_cls=str,
json=dict(messages=messages, add_generation_prompt=add_generation_prompt),
)

@classmethod

@@ -140,7 +140,9 @@ class APIResponse(t.Generic[Response]):

data = self._raw_response.json()
try:
return self._client._process_response_data(data=data, response_cls=self._response_cls, raw_response=self._raw_response)
return self._client._process_response_data(
data=data, response_cls=self._response_cls, raw_response=self._raw_response
)
except Exception as exc:
raise ValueError(exc) from None # validation error here

@@ -271,10 +273,16 @@ class BaseClient(t.Generic[InnerClient, StreamType]):

def _build_request(self, options: RequestOptions) -> httpx.Request:
return self._inner.build_request(
method=options.method, headers=self._build_headers(options), url=self._prepare_url(options.url), json=options.json, params=options.params
method=options.method,
headers=self._build_headers(options),
url=self._prepare_url(options.url),
json=options.json,
params=options.params,
)

def _calculate_retry_timeout(self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None) -> float:
def _calculate_retry_timeout(
self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None
) -> float:
max_retries = options.get_max_retries(self._max_retries)
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
try:
@@ -315,7 +323,9 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
return True
return False

def _process_response_data(self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response) -> Response:
def _process_response_data(
self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response
) -> Response:
return converter.structure(data, response_cls)

def _process_response(
@@ -328,13 +338,24 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
stream_cls: type[_Stream] | type[_AsyncStream] | None,
) -> Response:
return APIResponse(
raw_response=raw_response, client=self, response_cls=response_cls, stream=stream, stream_cls=stream_cls, options=options
raw_response=raw_response,
client=self,
response_cls=response_cls,
stream=stream,
stream_cls=stream_cls,
options=options,
).parse()

@attr.define(init=False)
class Client(BaseClient[httpx.Client, Stream[t.Any]]):
def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES):
def __init__(
self,
base_url: str | httpx.URL,
version: str,
timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
max_retries: int = MAX_RETRIES,
):
super().__init__(
base_url=base_url,
version=version,
@@ -366,7 +387,13 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
stream: bool = False,
stream_cls: type[_Stream] | None = None,
) -> Response | _Stream:
return self._request(response_cls=response_cls, options=options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls)
return self._request(
response_cls=response_cls,
options=options,
remaining_retries=remaining_retries,
stream=stream,
stream_cls=stream_cls,
)

def _request(
self,
@@ -385,7 +412,9 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
response.raise_for_status()
except httpx.HTTPStatusError as exc:
if retries > 0 and self._should_retry(exc.response):
return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls)
return self._retry_request(
response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
)
# If the response is streamed then we need to explicitly read the completed response
exc.response.read()
raise ValueError(exc.message) from None
@@ -398,7 +427,9 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
raise ValueError(request) from None # connection error

return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls)
return self._process_response(
response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
)

def _retry_request(
self,
@@ -428,7 +459,9 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
) -> Response | _Stream:
if options is None:
options = {}
return self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls)
return self.request(
response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
)

def _post(
self,
@@ -442,12 +475,20 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
) -> Response | _Stream:
if options is None:
options = {}
return self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls)
return self.request(
response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
)

@attr.define(init=False)
class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES):
def __init__(
self,
base_url: str | httpx.URL,
version: str,
timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
max_retries: int = MAX_RETRIES,
):
super().__init__(
base_url=base_url,
version=version,
@@ -471,7 +512,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self.close())
loop.create_task(self.close()) # noqa
else:
loop.run_until_complete(self.close())
except Exception:
@@ -486,7 +527,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
stream: bool = False,
stream_cls: type[_AsyncStream] | None = None,
) -> Response | _AsyncStream:
return await self._request(response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls)
return await self._request(
response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls
)

async def _request(
self,
@@ -506,7 +549,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
response.raise_for_status()
except httpx.HTTPStatusError as exc:
if retries > 0 and self._should_retry(exc.response):
return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls)
return self._retry_request(
response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
)
# If the response is streamed then we need to explicitly read the completed response
await exc.response.aread()
raise ValueError(exc.message) from None
@@ -526,7 +571,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
raise ValueError(request) from err # connection error

return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls)
return self._process_response(
response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
)

async def _retry_request(
self,
@@ -555,7 +602,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
) -> Response | _AsyncStream:
if options is None:
options = {}
return await self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls)
return await self.request(
response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
)

async def _post(
self,
@@ -569,4 +618,6 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
) -> Response | _AsyncStream:
if options is None:
options = {}
return await self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls)
return await self.request(
response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
)

@@ -38,7 +38,9 @@ class Stream(t.Generic[Response]):
if sse.data.startswith('[DONE]'):
break
if sse.event is None:
yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response)
yield self._client._process_response_data(
data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
)

@attr.define(auto_attribs=True)
@@ -69,7 +71,9 @@ class AsyncStream(t.Generic[Response]):
if sse.data.startswith('[DONE]'):
break
if sse.event is None:
yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response)
yield self._client._process_response_data(
data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
)

@attr.define

@@ -11,5 +11,7 @@ from openllm_core._typing_compat import (
overload as overload,
)

Platform = Annotated[LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str]
Platform = Annotated[
LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str
]
Architecture = Annotated[LiteralString, Literal['arm', 'arm64', 'x86', 'x86_64', 'Unknown'], str]

@@ -1,8 +1,8 @@
"""OpenLLM Python client.

```python
client = openllm.client.HTTPClient("http://localhost:8080")
client.query("What is the difference between gather and scatter?")
client = openllm.client.HTTPClient('http://localhost:8080')
client.query('What is the difference between gather and scatter?')
```
"""

@@ -81,7 +81,7 @@ def _start(
if adapter_map:
args.extend(
list(
itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])
itertools.chain.from_iterable([['--adapter-id', f"{k}{':' + v if v else ''}"] for k, v in adapter_map.items()])
)
)
if additional_args:
@@ -173,7 +173,7 @@ def _build(
if overwrite:
args.append('--overwrite')
if adapter_map:
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
args.extend([f"--adapter-id={k}{':' + v if v is not None else ''}" for k, v in adapter_map.items()])
if model_version:
args.extend(['--model-version', model_version])
if bento_version:

@@ -21,7 +21,9 @@ if t.TYPE_CHECKING:
@machine_option
@click.pass_context
@inject
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
def cli(
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
) -> str | None:
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
try:
bentomodel = _bento_store.get(bento)

@@ -17,7 +17,9 @@ if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore

@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.')
@click.command(
'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.'
)
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
@click.pass_context
@inject

@@ -22,7 +22,9 @@ class PromptFormatter(string.Formatter):
raise ValueError('Positional arguments are not supported')
return super().vformat(format_string, args, kwargs)

def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
def check_unused_args(
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> None:
extras = set(kwargs).difference(used_args)
if extras:
raise KeyError(f'Extra params passed: {extras}')
@@ -56,7 +58,9 @@ class PromptTemplate:
try:
return self.template.format(**prompt_variables)
except KeyError as e:
raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template.") from None
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
) from None

@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
@@ -124,15 +128,21 @@ def cli(
if prompt_template_file and chat_template_file:
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')

acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys())
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
)
if model_id in acceptable:
logger.warning('Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n')
logger.warning(
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
)
config = openllm.AutoConfig.for_model(model_id)
template = prompt_template_file.read() if prompt_template_file is not None else config.template
system_message = system_message or config.system_message

try:
formatted = PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
formatted = (
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
)
except RuntimeError as err:
logger.debug('Exception caught while formatting prompt: %s', err)
ctx.fail(str(err))
@@ -149,15 +159,21 @@ def cli(
for architecture in config.architectures:
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
system_message = (
openllm.AutoConfig.infer_class_from_name(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture])
openllm.AutoConfig.infer_class_from_name(
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
)
.model_construct_env()
.system_message
)
break
else:
ctx.fail(f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message')
ctx.fail(
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
)
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
formatted = tokenizer.apply_chat_template(messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False)
formatted = tokenizer.apply_chat_template(
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
)

termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
ctx.exit(0)

@@ -33,12 +33,17 @@ def cli(model_name: str | None) -> DictStrAny:
}
if model_name is not None:
ids_in_local_store = {
k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)]
k: [
i
for i in v
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
]
for k, v in ids_in_local_store.items()
}
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
local_models = {
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
for k, val in ids_in_local_store.items()
}
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
return local_models

@@ -32,7 +32,14 @@ def load_notebook_metadata() -> DictStrAny:

@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
@click.argument('output-dir', default=None, required=False)
@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server')
@click.option(
'--port',
envvar='JUPYTER_PORT',
show_envvar=True,
show_default=True,
default=8888,
help='Default port for Jupyter server',
)
@click.pass_context
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
"""OpenLLM Playground.
@@ -53,7 +60,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
"""
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'")
raise RuntimeError(
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
)
metadata = load_notebook_metadata()
_temp_dir = False
if output_dir is None:
@@ -65,7 +74,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
for module in pkgutil.iter_modules(playground.__path__):
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module')
logger.debug(
'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module'
)
continue
if not isinstance(module.module_finder, importlib.machinery.FileFinder):
continue

@@ -66,8 +66,15 @@ def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
st.integers(max_value=283473),
st.floats(min_value=0.0, max_value=1.0),
)
def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float):
cl_ = make_llm_config('ComplexLLM', gen_settings, fields=(('field1', 'float', field1),), generation_fields=(('temperature', temperature),))
def test_complex_struct_dump(
gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float
):
cl_ = make_llm_config(
'ComplexLLM',
gen_settings,
fields=(('field1', 'float', field1),),
generation_fields=(('temperature', temperature),),
)
sent = cl_()
assert sent.model_dump()['field1'] == field1
assert sent.model_dump()['generation_config']['temperature'] == temperature

@@ -10,8 +10,14 @@ import openllm
if t.TYPE_CHECKING:
from openllm_core._typing_compat import LiteralBackend

_MODELING_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B'}
_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'}
_MODELING_MAPPING = {
'flan_t5': 'google/flan-t5-small',
'opt': 'facebook/opt-125m',
'baichuan': 'baichuan-inc/Baichuan-7B',
}
_PROMPT_MAPPING = {
'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'
}

def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.Any, t.Any]], None, None]:
@@ -25,7 +31,9 @@ def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.An
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
if os.getenv('GITHUB_ACTIONS') is None:
if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames:
metafunc.parametrize('prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
metafunc.parametrize(
'prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]
)

def pytest_sessionfinish(session: pytest.Session, exitstatus: int):

@@ -73,9 +73,13 @@ def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
mcls.setenv('CUDA_VISIBLE_DEVICES', '')
assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests

assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match('Input list should be all string type.')
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match(
'Input list should be all string type.'
)
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.')
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID')
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match(
'Failed to parse available GPUs UUID'
)

def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):