chore: running updated ruff formatter [skip ci]

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Aaron
2024-03-15 05:35:24 -04:00
parent c34db550a6
commit 727361ced7
16 changed files with 216 additions and 65 deletions

View File

@@ -15,11 +15,17 @@ class HTTPClient:
address: str
helpers: _Helpers
@overload
def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@property
def is_ready(self) -> bool: ...
def health(self) -> bool: ...
@@ -60,11 +66,17 @@ class AsyncHTTPClient:
address: str
helpers: _AsyncHelpers
@overload
def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@overload
def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ...
def __init__(
self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...
) -> None: ...
@property
def is_ready(self) -> bool: ...
async def health(self) -> bool: ...
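
For reference, a minimal construction sketch implied by the stub overloads above; the address, timeout, and retry values are placeholders, not defaults taken from the library:

```python
# Minimal sketch based on the stub signatures above; the address and keyword
# values are placeholders, not library defaults.
import openllm

client = openllm.client.HTTPClient('http://localhost:3000', timeout=30, max_retries=2)
if client.is_ready:         # property declared in the stub
    print(client.health())  # True when the server's /readyz returns 200
# AsyncHTTPClient mirrors the same constructor, with an awaitable health().
```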

View File

@@ -70,10 +70,14 @@ class HTTPClient(Client):
return self.generate(prompt, **attrs)
def health(self):
response = self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries})
response = self._get(
'/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
)
return response.status_code == 200
def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response:
def generate(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> Response:
if timeout is None:
timeout = self._timeout
if verify is None:
@@ -96,7 +100,9 @@ class HTTPClient(Client):
for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs):
yield StreamingResponse.from_response_chunk(response_chunk)
def generate_iterator(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> t.Iterator[Response]:
def generate_iterator(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.Iterator[Response]:
if timeout is None:
timeout = self._timeout
if verify is None:
@@ -146,7 +152,9 @@ class AsyncHTTPClient(AsyncClient):
@property
async def _metadata(self) -> t.Awaitable[Metadata]:
if self.__metadata is None:
self.__metadata = await self._post(f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries})
self.__metadata = await self._post(
f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries}
)
return self.__metadata
@property
@@ -159,10 +167,14 @@ class AsyncHTTPClient(AsyncClient):
return await self.generate(prompt, **attrs)
async def health(self):
response = await self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries})
response = await self._get(
'/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
)
return response.status_code == 200
async def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response:
async def generate(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> Response:
if timeout is None:
timeout = self._timeout
if verify is None:
@@ -183,7 +195,9 @@ class AsyncHTTPClient(AsyncClient):
async def generate_stream(
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
) -> t.AsyncGenerator[StreamingResponse, t.Any]:
async for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs):
async for response_chunk in self.generate_iterator(
prompt, llm_config, stop, adapter_name, timeout, verify, **attrs
):
yield StreamingResponse.from_response_chunk(response_chunk)
async def generate_iterator(
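
A hedged usage sketch of the generate surface reformatted above; the prompt text and stop sequence are illustrative only:

```python
# Hedged sketch of the generate/generate_stream calls wrapped above; the
# prompt text and stop tokens are illustrative, not taken from the project.
import openllm

client = openllm.client.HTTPClient('http://localhost:3000')
response = client.generate('Explain gather vs. scatter.', stop=['\n\n'])
print(response.outputs[0].text)

for chunk in client.generate_stream('Write a haiku about autumn.'):
    print(chunk.text, end='', flush=True)  # StreamingResponse carries the delta text
```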

View File

@@ -17,7 +17,7 @@ if t.TYPE_CHECKING:
from ._shim import AsyncClient, Client
__all__ = ['Response', 'CompletionChunk', 'Metadata', 'StreamingResponse', 'Helpers']
__all__ = ['CompletionChunk', 'Helpers', 'Metadata', 'Response', 'StreamingResponse']
@attr.define
@@ -42,7 +42,11 @@ def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metada
raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None
try:
return cls(
model_id=data['model_id'], timeout=data['timeout'], model_name=data['model_name'], backend=data['backend'], configuration=configuration
model_id=data['model_id'],
timeout=data['timeout'],
model_name=data['model_name'],
backend=data['backend'],
configuration=configuration,
)
except Exception as e:
raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None
@@ -61,7 +65,10 @@ class StreamingResponse(_SchemaMixin):
@classmethod
def from_response_chunk(cls, response: Response) -> StreamingResponse:
return cls(
request_id=response.request_id, index=response.outputs[0].index, text=response.outputs[0].text, token_ids=response.outputs[0].token_ids[0]
request_id=response.request_id,
index=response.outputs[0].index,
text=response.outputs[0].text,
token_ids=response.outputs[0].token_ids[0],
)
@@ -88,11 +95,17 @@ class Helpers:
return self._async_client
def messages(self, messages, add_generation_prompt=False):
return self.client._post('/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt))
return self.client._post(
'/v1/helpers/messages',
response_cls=str,
json=dict(messages=messages, add_generation_prompt=add_generation_prompt),
)
async def async_messages(self, messages, add_generation_prompt=False):
return await self.async_client._post(
'/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt)
'/v1/helpers/messages',
response_cls=str,
json=dict(messages=messages, add_generation_prompt=add_generation_prompt),
)
@classmethod
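
A hedged sketch of the messages helper shown above; the chat content is made up and the client address is a placeholder:

```python
# Hedged sketch of Helpers.messages above; chat content and address are made up.
import openllm

client = openllm.client.HTTPClient('http://localhost:3000')
messages = [
    {'role': 'system', 'content': 'You are a concise assistant.'},
    {'role': 'user', 'content': 'Summarise the difference between gather and scatter.'},
]
prompt = client.helpers.messages(messages, add_generation_prompt=True)  # returns the rendered prompt string
```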

View File

@@ -140,7 +140,9 @@ class APIResponse(t.Generic[Response]):
data = self._raw_response.json()
try:
return self._client._process_response_data(data=data, response_cls=self._response_cls, raw_response=self._raw_response)
return self._client._process_response_data(
data=data, response_cls=self._response_cls, raw_response=self._raw_response
)
except Exception as exc:
raise ValueError(exc) from None # validation error here
@@ -271,10 +273,16 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
def _build_request(self, options: RequestOptions) -> httpx.Request:
return self._inner.build_request(
method=options.method, headers=self._build_headers(options), url=self._prepare_url(options.url), json=options.json, params=options.params
method=options.method,
headers=self._build_headers(options),
url=self._prepare_url(options.url),
json=options.json,
params=options.params,
)
def _calculate_retry_timeout(self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None) -> float:
def _calculate_retry_timeout(
self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None
) -> float:
max_retries = options.get_max_retries(self._max_retries)
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
try:
@@ -315,7 +323,9 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
return True
return False
def _process_response_data(self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response) -> Response:
def _process_response_data(
self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response
) -> Response:
return converter.structure(data, response_cls)
def _process_response(
@@ -328,13 +338,24 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
stream_cls: type[_Stream] | type[_AsyncStream] | None,
) -> Response:
return APIResponse(
raw_response=raw_response, client=self, response_cls=response_cls, stream=stream, stream_cls=stream_cls, options=options
raw_response=raw_response,
client=self,
response_cls=response_cls,
stream=stream,
stream_cls=stream_cls,
options=options,
).parse()
@attr.define(init=False)
class Client(BaseClient[httpx.Client, Stream[t.Any]]):
def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES):
def __init__(
self,
base_url: str | httpx.URL,
version: str,
timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
max_retries: int = MAX_RETRIES,
):
super().__init__(
base_url=base_url,
version=version,
@@ -366,7 +387,13 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
stream: bool = False,
stream_cls: type[_Stream] | None = None,
) -> Response | _Stream:
return self._request(response_cls=response_cls, options=options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls)
return self._request(
response_cls=response_cls,
options=options,
remaining_retries=remaining_retries,
stream=stream,
stream_cls=stream_cls,
)
def _request(
self,
@@ -385,7 +412,9 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
response.raise_for_status()
except httpx.HTTPStatusError as exc:
if retries > 0 and self._should_retry(exc.response):
return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls)
return self._retry_request(
response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
)
# If the response is streamed then we need to explicitly read the completed response
exc.response.read()
raise ValueError(exc.message) from None
@@ -398,7 +427,9 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
raise ValueError(request) from None # connection error
return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls)
return self._process_response(
response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
)
def _retry_request(
self,
@@ -428,7 +459,9 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
) -> Response | _Stream:
if options is None:
options = {}
return self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls)
return self.request(
response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
)
def _post(
self,
@@ -442,12 +475,20 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
) -> Response | _Stream:
if options is None:
options = {}
return self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls)
return self.request(
response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
)
@attr.define(init=False)
class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES):
def __init__(
self,
base_url: str | httpx.URL,
version: str,
timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
max_retries: int = MAX_RETRIES,
):
super().__init__(
base_url=base_url,
version=version,
@@ -471,7 +512,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(self.close())
loop.create_task(self.close()) # noqa
else:
loop.run_until_complete(self.close())
except Exception:
@@ -486,7 +527,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
stream: bool = False,
stream_cls: type[_AsyncStream] | None = None,
) -> Response | _AsyncStream:
return await self._request(response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls)
return await self._request(
response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls
)
async def _request(
self,
@@ -506,7 +549,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
response.raise_for_status()
except httpx.HTTPStatusError as exc:
if retries > 0 and self._should_retry(exc.response):
return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls)
return self._retry_request(
response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls
)
# If the response is streamed then we need to explicitly read the completed response
await exc.response.aread()
raise ValueError(exc.message) from None
@@ -526,7 +571,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
raise ValueError(request) from err # connection error
return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls)
return self._process_response(
response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls
)
async def _retry_request(
self,
@@ -555,7 +602,9 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
) -> Response | _AsyncStream:
if options is None:
options = {}
return await self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls)
return await self.request(
response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls
)
async def _post(
self,
@@ -569,4 +618,6 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
) -> Response | _AsyncStream:
if options is None:
options = {}
return await self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls)
return await self.request(
response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
)
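
The retry path above consults the Retry-After header (see the MDN link in the hunk). An illustrative, standalone sketch of that pattern follows; it is not the library's exact `_calculate_retry_timeout` body, and the base delay, cap, and jitter are assumptions:

```python
# Illustrative Retry-After-aware backoff; NOT the exact _calculate_retry_timeout
# implementation. Base delay, cap, and jitter factor are assumptions.
from __future__ import annotations

import email.utils
import random
import time

def retry_timeout(remaining_retries: int, max_retries: int, retry_after: str | None) -> float:
    if retry_after is not None:
        try:
            return float(retry_after)  # delta-seconds form, e.g. 'Retry-After: 120'
        except ValueError:
            try:
                dt = email.utils.parsedate_to_datetime(retry_after)  # HTTP-date form
                return max(0.0, dt.timestamp() - time.time())
            except (TypeError, ValueError):
                pass
    attempt = max_retries - remaining_retries
    return min(8.0, 0.5 * 2**attempt) * (1.0 + 0.25 * random.random())  # jittered exponential backoff
```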

View File

@@ -38,7 +38,9 @@ class Stream(t.Generic[Response]):
if sse.data.startswith('[DONE]'):
break
if sse.event is None:
yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response)
yield self._client._process_response_data(
data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
)
@attr.define(auto_attribs=True)
@@ -69,7 +71,9 @@ class AsyncStream(t.Generic[Response]):
if sse.data.startswith('[DONE]'):
break
if sse.event is None:
yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response)
yield self._client._process_response_data(
data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
)
@attr.define
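
Both stream classes implement the same server-sent-events decode loop with a '[DONE]' sentinel; a generic illustration of that pattern follows (the line source and payload shape are hypothetical stand-ins, not the project's SSE parser):

```python
# Generic SSE decode loop mirroring the '[DONE]' sentinel handling above; the
# raw_lines input and JSON payload shape are hypothetical stand-ins.
import json
import typing as t

def iter_chunks(raw_lines: t.Iterable[str]) -> t.Iterator[dict]:
    for line in raw_lines:
        if not line.startswith('data: '):
            continue                   # skip comments and other SSE fields
        data = line[len('data: '):]
        if data.startswith('[DONE]'):  # server signals end of stream
            break
        yield json.loads(data)         # structured into a Response downstream
```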

View File

@@ -11,5 +11,7 @@ from openllm_core._typing_compat import (
overload as overload,
)
Platform = Annotated[LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str]
Platform = Annotated[
LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str
]
Architecture = Annotated[LiteralString, Literal['arm', 'arm64', 'x86', 'x86_64', 'Unknown'], str]

View File

@@ -1,8 +1,8 @@
"""OpenLLM Python client.
```python
client = openllm.client.HTTPClient("http://localhost:8080")
client.query("What is the difference between gather and scatter?")
client = openllm.client.HTTPClient('http://localhost:8080')
client.query('What is the difference between gather and scatter?')
```
"""

View File

@@ -81,7 +81,7 @@ def _start(
if adapter_map:
args.extend(
list(
itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])
itertools.chain.from_iterable([['--adapter-id', f"{k}{':' + v if v else ''}"] for k, v in adapter_map.items()])
)
)
if additional_args:
@@ -173,7 +173,7 @@ def _build(
if overwrite:
args.append('--overwrite')
if adapter_map:
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
args.extend([f"--adapter-id={k}{':' + v if v is not None else ''}" for k, v in adapter_map.items()])
if model_version:
args.extend(['--model-version', model_version])
if bento_version:
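
For reference, the conditional f-string in the hunk above expands as follows (the adapter names and paths are made up):

```python
# Worked expansion of the '--adapter-id' f-string above; adapter names/paths are made up.
adapter_map = {'default': 'aarnphm/opt-lora', 'raw': None}
args = [f"--adapter-id={k}{':' + v if v is not None else ''}" for k, v in adapter_map.items()]
assert args == ['--adapter-id=default:aarnphm/opt-lora', '--adapter-id=raw']
```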

View File

@@ -21,7 +21,9 @@ if t.TYPE_CHECKING:
@machine_option
@click.pass_context
@inject
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
def cli(
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
) -> str | None:
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
try:
bentomodel = _bento_store.get(bento)

View File

@@ -17,7 +17,9 @@ if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.')
@click.command(
'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.'
)
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
@click.pass_context
@inject

View File

@@ -22,7 +22,9 @@ class PromptFormatter(string.Formatter):
raise ValueError('Positional arguments are not supported')
return super().vformat(format_string, args, kwargs)
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
def check_unused_args(
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> None:
extras = set(kwargs).difference(used_args)
if extras:
raise KeyError(f'Extra params passed: {extras}')
@@ -56,7 +58,9 @@ class PromptTemplate:
try:
return self.template.format(**prompt_variables)
except KeyError as e:
raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template.") from None
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
) from None
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
@@ -124,15 +128,21 @@ def cli(
if prompt_template_file and chat_template_file:
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys())
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
)
if model_id in acceptable:
logger.warning('Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n')
logger.warning(
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
)
config = openllm.AutoConfig.for_model(model_id)
template = prompt_template_file.read() if prompt_template_file is not None else config.template
system_message = system_message or config.system_message
try:
formatted = PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
formatted = (
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
)
except RuntimeError as err:
logger.debug('Exception caught while formatting prompt: %s', err)
ctx.fail(str(err))
@@ -149,15 +159,21 @@ def cli(
for architecture in config.architectures:
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
system_message = (
openllm.AutoConfig.infer_class_from_name(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture])
openllm.AutoConfig.infer_class_from_name(
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
)
.model_construct_env()
.system_message
)
break
else:
ctx.fail(f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message')
ctx.fail(
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
)
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
formatted = tokenizer.apply_chat_template(messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False)
formatted = tokenizer.apply_chat_template(
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
)
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
ctx.exit(0)
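
A hedged sketch of the PromptTemplate flow formatted above; the template string is invented, and the import is elided because the class is defined in the file shown in this hunk:

```python
# Hedged sketch of the PromptTemplate flow above. The template string is
# invented; PromptTemplate, with_options, and format come from this file.
template = '{system_message}\n\nUser: {instruction}\nAssistant:'
formatted = (
    PromptTemplate(template)
    .with_options(system_message='You are a helpful assistant.')
    .format(instruction='What is the difference between gather and scatter?')
)
# Leaving a template variable unbound raises, e.g.:
#   RuntimeError: Missing variable 'instruction' (required: [...]) in the prompt template.
```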

View File

@@ -33,12 +33,17 @@ def cli(model_name: str | None) -> DictStrAny:
}
if model_name is not None:
ids_in_local_store = {
k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)]
k: [
i
for i in v
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
]
for k, v in ids_in_local_store.items()
}
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
local_models = {
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
for k, val in ids_in_local_store.items()
}
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
return local_models

View File

@@ -32,7 +32,14 @@ def load_notebook_metadata() -> DictStrAny:
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
@click.argument('output-dir', default=None, required=False)
@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server')
@click.option(
'--port',
envvar='JUPYTER_PORT',
show_envvar=True,
show_default=True,
default=8888,
help='Default port for Jupyter server',
)
@click.pass_context
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
"""OpenLLM Playground.
@@ -53,7 +60,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
"""
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'")
raise RuntimeError(
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
)
metadata = load_notebook_metadata()
_temp_dir = False
if output_dir is None:
@@ -65,7 +74,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
for module in pkgutil.iter_modules(playground.__path__):
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module')
logger.debug(
'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module'
)
continue
if not isinstance(module.module_finder, importlib.machinery.FileFinder):
continue

View File

@@ -66,8 +66,15 @@ def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
st.integers(max_value=283473),
st.floats(min_value=0.0, max_value=1.0),
)
def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float):
cl_ = make_llm_config('ComplexLLM', gen_settings, fields=(('field1', 'float', field1),), generation_fields=(('temperature', temperature),))
def test_complex_struct_dump(
gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float
):
cl_ = make_llm_config(
'ComplexLLM',
gen_settings,
fields=(('field1', 'float', field1),),
generation_fields=(('temperature', temperature),),
)
sent = cl_()
assert sent.model_dump()['field1'] == field1
assert sent.model_dump()['generation_config']['temperature'] == temperature

View File

@@ -10,8 +10,14 @@ import openllm
if t.TYPE_CHECKING:
from openllm_core._typing_compat import LiteralBackend
_MODELING_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B'}
_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'}
_MODELING_MAPPING = {
'flan_t5': 'google/flan-t5-small',
'opt': 'facebook/opt-125m',
'baichuan': 'baichuan-inc/Baichuan-7B',
}
_PROMPT_MAPPING = {
'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'
}
def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.Any, t.Any]], None, None]:
@@ -25,7 +31,9 @@ def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.An
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
if os.getenv('GITHUB_ACTIONS') is None:
if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames:
metafunc.parametrize('prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
metafunc.parametrize(
'prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]
)
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):

View File

@@ -73,9 +73,13 @@ def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
mcls.setenv('CUDA_VISIBLE_DEVICES', '')
assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match('Input list should be all string type.')
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match(
'Input list should be all string type.'
)
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.')
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID')
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match(
'Failed to parse available GPUs UUID'
)
def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):