refactor: cleanup typing to expose correct API (#576)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-08 01:24:03 -05:00
committed by GitHub
parent c40d4c1016
commit 97d7c38fea
83 changed files with 440 additions and 1992 deletions

View File

@@ -1,11 +1,11 @@
'''CLI entrypoint for OpenLLM.
"""CLI entrypoint for OpenLLM.
Usage:
openllm --help
To start any OpenLLM model:
openllm start <model_name> --options ...
'''
"""
from __future__ import annotations
if __name__ == '__main__':

View File

@@ -51,7 +51,7 @@ def is_sentence_complete(output: str) -> bool:
return output.endswith(('.', '?', '!', '...', '', '?', '!', '', '"', "'", ''))
def is_partial_stop(output: str, stop_str: str) -> bool:
  """Check whether the output contains a partial stop str.

  Args:
    output: The text generated so far.
    stop_str: The stop sequence to look for.

  Returns:
    True when ``stop_str`` starts with the whole of ``output`` or with one of its
    trailing substrings of length ``1 .. min(len(output), len(stop_str)) - 1``;
    False otherwise, including when either string is empty.
  """
  limit = min(len(output), len(stop_str))
  if limit == 0:
    return False
  # A suffix "length" of 0 selects the whole string (``output[-0:]`` is ``output[0:]``),
  # so the first candidate is ``output`` itself, followed by its proper suffixes.
  tails = [output] + [output[len(output) - size:] for size in range(1, limit)]
  return any(stop_str.startswith(tail) for tail in tails)

View File

@@ -80,13 +80,13 @@ def normalise_model_name(name: str) -> str:
return inflection.dasherize(name)
def resolve_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
'''Resolve the type of the PeftConfig given the adapter_map.
"""Resolve the type of the PeftConfig given the adapter_map.
This is similar to how PeftConfig resolve its config type.
Args:
adapter_map: The given mapping from either SDK or CLI. See CLI docs for more information.
'''
"""
resolved: AdapterMap = {}
_has_set_default = False
for path_or_adapter_id, name in adapter_map.items():
@@ -191,7 +191,7 @@ class LLM(t.Generic[M, T]):
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
def _make_tag_components(self, model_id: str, model_version: str | None, backend: LiteralBackend) -> tuple[str, str | None]:
'''Return a valid tag name (<backend>-<repo>--<model_id>) and its tag version.'''
"""Return a valid tag name (<backend>-<repo>--<model_id>) and its tag version."""
model_id, *maybe_revision = model_id.rsplit(':')
if len(maybe_revision) > 0:
if model_version is not None: logger.warning("revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.", maybe_revision[0], model_version)

View File

@@ -44,7 +44,7 @@ logger = logging.getLogger(__name__)
OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'
def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'openllm_client'] = 'openllm') -> str | None:
'''Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set.'''
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != 'true': return None
# We need to build the package in editable mode, so that we can import it
from build import ProjectBuilder

View File

@@ -1,5 +1,5 @@
# mypy: disable-error-code="misc"
'''OCI-related utilities for OpenLLM. This module is considered to be internal and API are subjected to change.'''
"""OCI-related utilities for OpenLLM. This module is considered to be internal and API are subjected to change."""
from __future__ import annotations
import functools
import importlib

View File

@@ -1,4 +1,4 @@
'''OpenLLM CLI.
"""OpenLLM CLI.
For more information see ``openllm -h``.
'''
"""

View File

@@ -290,7 +290,7 @@ def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tup
_IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}
def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
'''Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`.'''
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`."""
from bentoml_cli.cli import cli
command = 'serve' if not serve_grpc else 'serve-grpc'
@@ -320,11 +320,11 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
_http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True)
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
'''General ``@click`` decorator with some sauce.
"""General ``@click`` decorator with some sauce.
This decorator extends the default ``@click.option`` plus a factory option and factory attr to
provide type-safe click.option or click.argument wrapper for all compatible factory.
'''
"""
factory = attrs.pop('factory', click)
factory_attr = attrs.pop('attr', 'option')
if factory_attr != 'argument': attrs.setdefault('help', 'General option for OpenLLM CLI.')

View File

@@ -262,7 +262,7 @@ def _import_model(model_name: str,
return import_command.main(args=args, standalone_mode=False)
def _list_models() -> dict[str, t.Any]:
  """List all available models within the local store.

  Invokes the ``models`` CLI command in-process with JSON output and returns whatever
  that command returns.
  """
  # Imported lazily inside the function, as in the original.
  from .entrypoint import models_command
  cli_args = ['-o', 'json', '--show-available', '--machine']
  return models_command.main(args=cli_args, standalone_mode=False)

View File

@@ -151,7 +151,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
@staticmethod
def common_params(f: t.Callable[P, t.Any]) -> t.Callable[[FC], FC]:
# The following logic is similar to that of BentoMLCommandGroup
@cog.optgroup.group(name='Global options', help='Shared globals options for all OpenLLM CLI.')
@cog.optgroup.group(name='Global options', help='Shared globals options for all OpenLLM CLI.') # type: ignore[misc]
@cog.optgroup.option('-q', '--quiet', envvar=QUIET_ENV_VAR, is_flag=True, default=False, help='Suppress all output.', show_envvar=True)
@cog.optgroup.option('--debug', '--verbose', 'debug', envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help='Print out debug logs.', show_envvar=True)
@cog.optgroup.option('--do-not-track', is_flag=True, default=False, envvar=analytics.OPENLLM_DO_NOT_TRACK, help='Do not send usage info', show_envvar=True)
@@ -249,7 +249,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
return decorator
def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
'''Additional format methods that include extensions as well as the default cli command.'''
"""Additional format methods that include extensions as well as the default cli command."""
from gettext import gettext as _
commands: list[tuple[str, click.Command]] = []
extensions: list[tuple[str, click.Command]] = []
@@ -285,7 +285,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
'-v',
message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}")
def cli() -> None:
'''\b
"""\b
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
@@ -296,27 +296,27 @@ def cli() -> None:
\b
An open platform for operating large language models in production.
Fine-tune, serve, deploy, and monitor any LLMs with ease.
'''
"""
# Command group registered on `cli` as `start`, with `start-http` as an alias.
# NOTE(review): the docstring is presumably what click renders as this group's help text
# (the `\b` line being click's marker to leave the following paragraph unwrapped), so its
# wording is kept verbatim — confirm before editing it.
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start', aliases=['start-http'])
def start_command() -> None:
  """Start any LLM as a REST server.
  \b
  ```bash
  $ openllm <start|start-http> <model_name> --<options> ...
  ```
  """
# Command group registered on `cli` as `start-grpc`.
# NOTE(review): the docstring is presumably what click renders as this group's help text
# (the `\b` line being click's marker to leave the following paragraph unwrapped), so its
# wording is kept verbatim — confirm before editing it.
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start-grpc')
def start_grpc_command() -> None:
  """Start any LLM as a gRPC server.
  \b
  ```bash
  $ openllm start-grpc <model_name> --<options> ...
  ```
  """
_start_mapping = {
'start': {
@@ -424,7 +424,7 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the the model revision.')
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
@workers_per_resource_option(factory=click, build=True)
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Optimisation options')
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Optimisation options') # type: ignore[misc]
@quantize_option(factory=cog.optgroup, build=True)
@click.option('--enable-features',
multiple=True,
@@ -445,7 +445,7 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
type=click.Choice(['release', 'latest', 'nightly']),
default='release',
help="Default container version strategy for the image from '--container-registry'")
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options')
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options') # type: ignore[misc]
@cog.optgroup.option('--containerize',
default=False,
is_flag=True,
@@ -459,7 +459,7 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
system_message: str | None, prompt_template_file: t.IO[t.Any] | None, machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool,
push: bool, serialisation: LiteralSerialisation | None, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy,
force_push: bool, **attrs: t.Any) -> bentoml.Bento:
'''Package a given models into a Bento.
"""Package a given models into a Bento.
\b
```bash
@@ -475,7 +475,7 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
> [!IMPORTANT]
> To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
> target also use the same Python version and architecture as build machine.
'''
"""
if machine: output = 'porcelain'
if enable_features: enable_features = tuple(itertools.chain.from_iterable((s.split(',') for s in enable_features)))
@@ -679,11 +679,11 @@ def prune_command(model_name: str | None,
include_bentos: bool,
model_store: ModelStore = Provide[BentoMLContainer.model_store],
bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> None:
'''Remove all saved models, (and optionally bentos) built with OpenLLM locally.
"""Remove all saved models, (and optionally bentos) built with OpenLLM locally.
\b
If a model type is passed, then only prune models for that given model type.
'''
"""
available: list[tuple[bentoml.Model | bentoml.Bento,
ModelStore | BentoStore]] = [(m, model_store) for m in bentoml.models.list() if 'framework' in m.info.labels and m.info.labels['framework'] == 'openllm']
if model_name is not None: available = [(m, store) for m, store in available if 'model_name' in m.info.labels and m.info.labels['model_name'] == inflection.underscore(model_name)]
@@ -823,6 +823,6 @@ def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: in
# Hidden command group (`hidden=True`) registered on `cli` under the name `extension`,
# using the `Extensions` group class.
# NOTE(review): the docstring is presumably used by click as the group's help text, so it
# is kept verbatim — confirm before editing it.
@cli.group(cls=Extensions, hidden=True, name='extension')
def extension_command() -> None:
  """Extension for OpenLLM CLI."""
if __name__ == '__main__': cli()

View File

@@ -25,7 +25,7 @@ if t.TYPE_CHECKING:
@click.pass_context
@inject
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
'''Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path).'''
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
try:
bentomodel = _bento_store.get(bento)
except bentoml.exceptions.NotFound:

View File

@@ -31,7 +31,7 @@ LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
metavar='ARG=VALUE[,ARG=VALUE]')
@click.pass_context
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, output: LiteralOutput, machine: bool, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
'''Get the default prompt used by OpenLLM.'''
"""Get the default prompt used by OpenLLM."""
module = openllm.utils.EnvVarMixin(model_name).module
_memoized = {k: v[0] for k, v in _memoized.items() if v}
try:

View File

@@ -16,7 +16,7 @@ from openllm.cli._factory import output_option
@output_option(default_value='json')
@click.pass_context
def cli(ctx: click.Context, output: LiteralOutput) -> None:
'''List available bentos built by OpenLLM.'''
"""List available bentos built by OpenLLM."""
mapping = {
k: [{
'tag': str(b.tag),

View File

@@ -22,7 +22,7 @@ if t.TYPE_CHECKING:
@model_name_argument(required=False, shell_complete=model_complete_envvar)
@output_option(default_value='json')
def cli(model_name: str | None, output: LiteralOutput) -> DictStrAny:
'''This is equivalent to openllm models --show-available less the nice table.'''
"""This is equivalent to openllm models --show-available less the nice table."""
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
ids_in_local_store = {
k: [i for i in bentoml.models.list() if 'framework' in i.info.labels and i.info.labels['framework'] == 'openllm' and 'model_name' in i.info.labels and i.info.labels['model_name'] == k]

View File

@@ -1,11 +1,11 @@
'''Entrypoint for all third-party apps.
"""Entrypoint for all third-party apps.
Currently support OpenAI compatible API.
Each module should implement the following API:
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
'''
"""
from __future__ import annotations
import typing as t

View File

@@ -480,7 +480,7 @@ def get_generator(title: str, components: list[type[AttrsInstance]] | None = Non
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
schema['description'] = first_not_none(getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}')
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc]
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var]
attr_type = field.type
origin_type = t.get_origin(attr_type)
args_type = t.get_args(attr_type)
@@ -495,21 +495,12 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
elif origin_type is dict:
schema_type = 'object'
# Assuming string keys for simplicity, and handling Any type for values
prop_schema = {
'type': 'object',
'additionalProperties':
True if args_type[1] is t.Any else {
'type': 'string'
} # Simplified
}
prop_schema = {'type': 'object', 'additionalProperties': True if args_type[1] is t.Any else {'type': 'string'}}
elif attr_type == t.Optional[str]:
schema_type = 'string'
elif origin_type is t.Union and t.Any in args_type:
schema_type = 'object'
prop_schema = {
'type': 'object',
'additionalProperties': True # Allows any type of values
}
prop_schema = {'type': 'object', 'additionalProperties': True}
else:
schema_type = 'string'

View File

@@ -1,4 +1,4 @@
'''Base exceptions for OpenLLM. This extends BentoML exceptions.'''
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
from __future__ import annotations
from openllm_core.exceptions import Error as Error

View File

@@ -1,7 +1,7 @@
'''Protocol-related packages for all library integrations.
"""Protocol-related packages for all library integrations.
Currently support OpenAI compatible API.
'''
"""
from __future__ import annotations
import os
import typing as t

View File

@@ -1,9 +1,9 @@
'''Serialisation utilities for OpenLLM.
"""Serialisation utilities for OpenLLM.
Currently supports transformers for PyTorch, and vLLM.
Currently, GGML format is working in progress.
'''
"""
from __future__ import annotations
import importlib
import typing as t
@@ -32,11 +32,11 @@ else:
P = ParamSpec('P')
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
'''Load the tokenizer from BentoML store.
"""Load the tokenizer from BentoML store.
By default, it will try to find the bentomodel in the store.
If the model is not found, it will raise a ``bentoml.exceptions.NotFound``.
'''
"""
from .transformers._helpers import process_config
config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code)

View File

@@ -1,7 +1,7 @@
'''Serialisation related implementation for GGML-based implementation.
"""Serialisation related implementation for GGML-based implementation.
This requires ctransformers to be installed.
'''
"""
from __future__ import annotations
import typing as t

View File

@@ -1,4 +1,4 @@
'''Serialisation related implementation for Transformers-based implementation.'''
"""Serialisation related implementation for Transformers-based implementation."""
from __future__ import annotations
import importlib
import logging
@@ -150,13 +150,13 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
return bentomodel
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
'''Return an instance of ``bentoml.Model`` from given LLM instance.
"""Return an instance of ``bentoml.Model`` from given LLM instance.
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
Otherwise, it will raise a ``bentoml.exceptions.NotFound``.
'''
"""
try:
model = bentoml.models.get(llm.tag)
backend = model.info.labels['backend']

View File

@@ -26,7 +26,7 @@ def get_hash(config: transformers.PretrainedConfig) -> str:
return _commit_hash
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
'''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
Args:
model_id: Model id to pass into ``transformers.AutoConfig``.
@@ -35,7 +35,7 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
Returns:
A tuple of ``transformers.PretrainedConfig``, all hub attributes, and remaining attributes that can be used by the Model class.
'''
"""
config = attrs.pop('config', None)
# this logic below is synonymous to handling `from_pretrained` attrs.
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}

View File

@@ -1,4 +1,4 @@
'''Tests utilities for OpenLLM.'''
"""Tests utilities for OpenLLM."""
from __future__ import annotations
import contextlib
import logging

View File

@@ -14,7 +14,7 @@ env_strats = st.sampled_from([openllm.utils.EnvVarMixin(model_name) for model_na
@st.composite
def model_settings(draw: st.DrawFn):
'''Strategy for generating ModelSettings objects.'''
"""Strategy for generating ModelSettings objects."""
kwargs: dict[str, t.Any] = {
'default_id': st.text(min_size=1),
'model_ids': st.lists(st.text(), min_size=1),