fix(infra): conform ruff to 150 LL (#781)

Generally correctly format it with ruff format and manual style

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-12-14 17:27:32 -05:00
committed by GitHub
parent 8d989767e8
commit c8c9663d06
90 changed files with 1832 additions and 1893 deletions

View File

@@ -1,4 +1,4 @@
'''OpenLLM CLI.
"""OpenLLM CLI.
For more information see ``openllm -h``.
'''
"""

View File

@@ -5,24 +5,12 @@ from bentoml_cli.utils import BentoMLCommandGroup
from click import shell_completion as sc
from openllm_core._configuration import LLMConfig
from openllm_core._typing_compat import (
Concatenate,
DictStrAny,
LiteralBackend,
LiteralSerialisation,
ParamSpec,
AnyCallable,
get_literal_args,
)
from openllm_core._typing_compat import Concatenate, DictStrAny, LiteralBackend, LiteralSerialisation, ParamSpec, AnyCallable, get_literal_args
from openllm_core.utils import DEBUG, compose, dantic, resolve_user_filepath
class _OpenLLM_GenericInternalConfig(LLMConfig):
__config__ = {
'name_type': 'lowercase',
'default_id': 'openllm/generic',
'model_ids': ['openllm/generic'],
'architecture': 'PreTrainedModel',
}
__config__ = {'name_type': 'lowercase', 'default_id': 'openllm/generic', 'model_ids': ['openllm/generic'], 'architecture': 'PreTrainedModel'}
class GenerationConfig:
top_k: int = 15
@@ -30,6 +18,7 @@ class _OpenLLM_GenericInternalConfig(LLMConfig):
temperature: float = 0.75
max_new_tokens: int = 128
logger = logging.getLogger(__name__)
P = ParamSpec('P')
@@ -38,6 +27,7 @@ LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
_AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
return [
sc.CompletionItem(str(it.tag), help='Bento')
@@ -45,20 +35,13 @@ def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete
if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})
]
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
return [
sc.CompletionItem(inflection.dasherize(it), help='Model')
for it in openllm.CONFIG_MAPPING
if it.startswith(incomplete)
]
return [sc.CompletionItem(inflection.dasherize(it), help='Model') for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
def parse_config_options(
config: LLMConfig,
server_timeout: int,
workers_per_resource: float,
device: t.Tuple[str, ...] | None,
cors: bool,
environ: DictStrAny,
config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny
) -> DictStrAny:
# TODO: Support amd.com/gpu on k8s
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
@@ -72,26 +55,16 @@ def parse_config_options(
]
if device:
if len(device) > 1:
_bentoml_config_options_opts.extend(
[
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}'
for idx, dev in enumerate(device)
]
)
_bentoml_config_options_opts.extend([
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)
])
else:
_bentoml_config_options_opts.append(
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]'
)
_bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
if cors:
_bentoml_config_options_opts.extend(
['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"']
)
_bentoml_config_options_opts.extend(
[
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"'
for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
]
)
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
_bentoml_config_options_opts.extend([
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
])
_bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts)
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
if DEBUG:
@@ -119,22 +92,27 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
ctx.params[_adapter_mapping_key][adapter_id] = name
return None
def optimization_decorator(fn: FC, *, factory=click, _eager=True) -> FC | list[AnyCallable]:
shared = [
dtype_option(factory=factory), model_version_option(factory=factory), #
backend_option(factory=factory), quantize_option(factory=factory), #
dtype_option(factory=factory),
model_version_option(factory=factory), #
backend_option(factory=factory),
quantize_option(factory=factory), #
serialisation_option(factory=factory),
]
if not _eager: return shared
if not _eager:
return shared
return compose(*shared)(fn)
def start_decorator(fn: FC) -> FC:
composed = compose(
_OpenLLM_GenericInternalConfig.parse,
parse_serve_args(),
cog.optgroup.group(
'LLM Options',
help='''The following options are related to running LLM Server as well as optimization options.
help="""The following options are related to running LLM Server as well as optimization options.
OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
@@ -142,7 +120,7 @@ def start_decorator(fn: FC) -> FC:
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
''',
""",
),
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
workers_per_resource_option(factory=cog.optgroup),
@@ -163,12 +141,14 @@ def start_decorator(fn: FC) -> FC:
return composed(fn)
def parse_device_callback(_: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
if value is None:
return value
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
# NOTE: --device all is a special case
if len(el) == 1 and el[0] == 'all': return tuple(map(str, openllm.utils.available_devices()))
if len(el) == 1 and el[0] == 'all':
return tuple(map(str, openllm.utils.available_devices()))
return el
@@ -182,15 +162,12 @@ def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[F
from bentoml_cli.cli import cli
group = cog.optgroup.group('Start a HTTP server options', help='Related to serving the model [synonymous to `bentoml serve-http`]')
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
serve_command = cli.commands['serve']
# The first variable is the argument bento
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
serve_options = [
p
for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS]
if p.name not in _IGNORED_OPTIONS
]
serve_options = [p for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
for options in reversed(serve_options):
attrs = options.to_info_dict()
# we don't need param_type_name, since it should all be options
@@ -202,14 +179,16 @@ def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[F
param_decls = (*attrs.pop('opts'), *attrs.pop('secondary_opts'))
f = cog.optgroup.option(*param_decls, **attrs)(f)
return group(f)
return decorator
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
'''General ``@click`` decorator with some sauce.
"""General ``@click`` decorator with some sauce.
This decorator extends the default ``@click.option`` plus a factory option and factory attr to
provide type-safe click.option or click.argument wrapper for all compatible factory.
'''
"""
factory = attrs.pop('factory', click)
factory_attr = attrs.pop('attr', 'option')
if factory_attr != 'argument':
@@ -242,18 +221,14 @@ def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callab
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--cors/--no-cors',
show_default=True,
default=False,
envvar='OPENLLM_CORS',
show_envvar=True,
help='Enable CORS for the server.',
**attrs,
'--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs
)(f)
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--dtype',
@@ -264,6 +239,7 @@ def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[F
**attrs,
)(f)
def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--model-id',
@@ -294,16 +270,14 @@ def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
envvar='OPENLLM_BACKEND',
show_envvar=True,
help='Runtime to use for both serialisation/inference engine.',
**attrs)(f)
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_argument(
'model_name',
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]),
required=required,
**attrs,
)(f)
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--quantise',
@@ -313,7 +287,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
default=None,
envvar='OPENLLM_QUANTIZE',
show_envvar=True,
help='''Dynamic quantization for running this LLM.
help="""Dynamic quantization for running this LLM.
The following quantization strategies are supported:
@@ -328,23 +302,25 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
> [!NOTE] that the model can also be served with quantized weights.
'''
"""
+ (
'''
> [!NOTE] that this will set the mode for serving within deployment.''' if build else ''
"""
> [!NOTE] that this will set the mode for serving within deployment."""
if build
else ''
),
**attrs)(f)
**attrs,
)(f)
def workers_per_resource_option(
f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any
) -> t.Callable[[FC], FC]:
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--workers-per-resource',
default=None,
callback=workers_per_resource_callback,
type=str,
required=False,
help='''Number of workers per resource assigned.
help="""Number of workers per resource assigned.
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
for more information. By default, this is set to 1.
@@ -354,7 +330,7 @@ def workers_per_resource_option(
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
- ``conserved``: This will determine the number of available GPU resources. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
'''
"""
+ (
"""\n
> [!NOTE] The workers value passed into 'build' will determine how the LLM can
@@ -366,6 +342,7 @@ def workers_per_resource_option(
**attrs,
)(f)
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--serialisation',
@@ -376,7 +353,7 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
show_default=True,
show_envvar=True,
envvar='OPENLLM_SERIALIZATION',
help='''Serialisation format for save/load LLM.
help="""Serialisation format for save/load LLM.
Currently the following strategies are supported:
@@ -385,12 +362,14 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
> [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed.
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
''',
""",
**attrs,
)(f)
_wpr_strategies = {'round_robin', 'conserved'}
def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
if value is None:
return value
@@ -402,9 +381,7 @@ def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, va
float(value) # type: ignore[arg-type]
except ValueError:
raise click.BadParameter(
f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.",
ctx,
param,
f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param
) from None
else:
return value

View File

@@ -6,6 +6,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
from openllm_core._typing_compat import LiteralSerialisation
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import WARNING_ENV_VAR, codegen, first_not_none, get_disable_warnings, is_vllm_available
if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
from openllm_core._configuration import LLMConfig
@@ -13,6 +14,7 @@ if t.TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _start(
model_id: str,
timeout: int = 30,
@@ -61,27 +63,25 @@ def _start(
additional_args: Additional arguments to pass to ``openllm start``.
"""
from .entrypoint import start_command
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
args: list[str] = [model_id]
if timeout: args.extend(['--server-timeout', str(timeout)])
if timeout:
args.extend(['--server-timeout', str(timeout)])
if workers_per_resource:
args.extend(
[
'--workers-per-resource',
str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource,
]
)
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'): args.extend(['--device', ','.join(device)])
if quantize: args.extend(['--quantize', str(quantize)])
if cors: args.append('--cors')
args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'):
args.extend(['--device', ','.join(device)])
if quantize:
args.extend(['--quantize', str(quantize)])
if cors:
args.append('--cors')
if adapter_map:
args.extend(
list(
itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])
)
)
if additional_args: args.extend(additional_args)
if __test__: args.append('--return-process')
args.extend(list(itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])))
if additional_args:
args.extend(additional_args)
if __test__:
args.append('--return-process')
cmd = start_command
return cmd.main(args=args, standalone_mode=False)
@@ -138,6 +138,7 @@ def _build(
``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
"""
from openllm.serialisation.transformers.weights import has_safetensors_weights
args: list[str] = [
sys.executable,
'-m',
@@ -147,23 +148,34 @@ def _build(
'--machine',
'--quiet',
'--serialisation',
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'),
]
if quantize: args.extend(['--quantize', quantize])
if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
if push: args.extend(['--push'])
if containerize: args.extend(['--containerize'])
if build_ctx: args.extend(['--build-ctx', build_ctx])
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
if overwrite: args.append('--overwrite')
if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
if model_version: args.extend(['--model-version', model_version])
if bento_version: args.extend(['--bento-version', bento_version])
if dockerfile_template: args.extend(['--dockerfile-template', dockerfile_template])
if additional_args: args.extend(additional_args)
if force_push: args.append('--force-push')
if quantize:
args.extend(['--quantize', quantize])
if containerize and push:
raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
if push:
args.extend(['--push'])
if containerize:
args.extend(['--containerize'])
if build_ctx:
args.extend(['--build-ctx', build_ctx])
if enable_features:
args.extend([f'--enable-features={f}' for f in enable_features])
if overwrite:
args.append('--overwrite')
if adapter_map:
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
if model_version:
args.extend(['--model-version', model_version])
if bento_version:
args.extend(['--bento-version', bento_version])
if dockerfile_template:
args.extend(['--dockerfile-template', dockerfile_template])
if additional_args:
args.extend(additional_args)
if force_push:
args.append('--force-push')
current_disable_warning = get_disable_warnings()
os.environ[WARNING_ENV_VAR] = str(True)
@@ -171,17 +183,24 @@ def _build(
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
except subprocess.CalledProcessError as e:
logger.error("Exception caught while building Bento for '%s'", model_id, exc_info=e)
if e.stderr: raise OpenLLMException(e.stderr.decode('utf-8')) from None
if e.stderr:
raise OpenLLMException(e.stderr.decode('utf-8')) from None
raise OpenLLMException(str(e)) from None
matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip())
if matched is None:
raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
raise ValueError(
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
)
os.environ[WARNING_ENV_VAR] = str(current_disable_warning)
try:
result = orjson.loads(matched.group(1))
except orjson.JSONDecodeError as e:
raise ValueError(f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.") from e
raise ValueError(
f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
) from e
return bentoml.get(result['tag'], _bento_store=bento_store)
def _import_model(
model_id: str,
model_version: str | None = None,
@@ -218,15 +237,32 @@ def _import_model(
``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
"""
from .entrypoint import import_command
args = [model_id, '--quiet']
if backend is not None: args.extend(['--backend', backend])
if model_version is not None: args.extend(['--model-version', str(model_version)])
if quantize is not None: args.extend(['--quantize', quantize])
if serialisation is not None: args.extend(['--serialisation', serialisation])
if additional_args is not None: args.extend(additional_args)
if backend is not None:
args.extend(['--backend', backend])
if model_version is not None:
args.extend(['--model-version', str(model_version)])
if quantize is not None:
args.extend(['--quantize', quantize])
if serialisation is not None:
args.extend(['--serialisation', serialisation])
if additional_args is not None:
args.extend(additional_args)
return import_command.main(args=args, standalone_mode=False)
def _list_models() -> dict[str, t.Any]:
'''List all available models within the local store.'''
from .entrypoint import models_command; return models_command.main(args=['--quiet'], standalone_mode=False)
start, build, import_model, list_models = codegen.gen_sdk(_start), codegen.gen_sdk(_build), codegen.gen_sdk(_import_model), codegen.gen_sdk(_list_models)
"""List all available models within the local store."""
from .entrypoint import models_command
return models_command.main(args=['--quiet'], standalone_mode=False)
start, build, import_model, list_models = (
codegen.gen_sdk(_start),
codegen.gen_sdk(_build),
codegen.gen_sdk(_import_model),
codegen.gen_sdk(_list_models),
)
__all__ = ['start', 'build', 'import_model', 'list_models']

View File

@@ -43,15 +43,7 @@ from openllm_core.utils import (
)
from . import termui
from ._factory import (
FC,
_AnyCallable,
machine_option,
model_name_argument,
parse_config_options,
start_decorator,
optimization_decorator,
)
from ._factory import FC, _AnyCallable, machine_option, model_name_argument, parse_config_options, start_decorator, optimization_decorator
if t.TYPE_CHECKING:
import torch
@@ -65,14 +57,14 @@ else:
P = ParamSpec('P')
logger = logging.getLogger('openllm')
OPENLLM_FIGLET = '''\
OPENLLM_FIGLET = """\
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝
'''
"""
ServeCommand = t.Literal['serve', 'serve-grpc']
@@ -103,20 +95,12 @@ def backend_warning(backend: LiteralBackend, build: bool = False) -> None:
'vLLM is not available. Note that PyTorch backend is not as performant as vLLM and you should always consider using vLLM for production.'
)
if build:
logger.info(
"Tip: You can set '--backend vllm' to package your Bento with vLLM backend regardless if vLLM is available locally."
)
logger.info("Tip: You can set '--backend vllm' to package your Bento with vLLM backend regardless if vLLM is available locally.")
class Extensions(click.MultiCommand):
def list_commands(self, ctx: click.Context) -> list[str]:
return sorted(
[
filename[:-3]
for filename in os.listdir(_EXT_FOLDER)
if filename.endswith('.py') and not filename.startswith('__')
]
)
return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith('.py') and not filename.startswith('__')])
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
try:
@@ -133,41 +117,19 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
def common_params(f: t.Callable[P, t.Any]) -> t.Callable[[FC], FC]:
# The following logics is similar to one of BentoMLCommandGroup
@cog.optgroup.group(name='Global options', help='Shared globals options for all OpenLLM CLI.') # type: ignore[misc]
@cog.optgroup.option('-q', '--quiet', envvar=QUIET_ENV_VAR, is_flag=True, default=False, help='Suppress all output.', show_envvar=True)
@cog.optgroup.option(
'-q', '--quiet', envvar=QUIET_ENV_VAR, is_flag=True, default=False, help='Suppress all output.', show_envvar=True
'--debug', '--verbose', 'debug', envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help='Print out debug logs.', show_envvar=True
)
@cog.optgroup.option(
'--debug',
'--verbose',
'debug',
envvar=DEBUG_ENV_VAR,
is_flag=True,
default=False,
help='Print out debug logs.',
show_envvar=True,
'--do-not-track', is_flag=True, default=False, envvar=analytics.OPENLLM_DO_NOT_TRACK, help='Do not send usage info', show_envvar=True
)
@cog.optgroup.option(
'--do-not-track',
is_flag=True,
default=False,
envvar=analytics.OPENLLM_DO_NOT_TRACK,
help='Do not send usage info',
show_envvar=True,
)
@cog.optgroup.option(
'--context',
'cloud_context',
envvar='BENTOCLOUD_CONTEXT',
type=click.STRING,
default=None,
help='BentoCloud context name.',
show_envvar=True,
'--context', 'cloud_context', envvar='BENTOCLOUD_CONTEXT', type=click.STRING, default=None, help='BentoCloud context name.', show_envvar=True
)
@click.pass_context
@functools.wraps(f)
def wrapper(
ctx: click.Context, quiet: bool, debug: bool, cloud_context: str | None, *args: P.args, **attrs: P.kwargs
) -> t.Any:
def wrapper(ctx: click.Context, quiet: bool, debug: bool, cloud_context: str | None, *args: P.args, **attrs: P.kwargs) -> t.Any:
ctx.obj = GlobalOptions(cloud_context=cloud_context)
if quiet:
set_quiet_mode(True)
@@ -181,9 +143,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
return wrapper
@staticmethod
def usage_tracking(
func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any
) -> t.Callable[Concatenate[bool, P], t.Any]:
def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[Concatenate[bool, P], t.Any]:
command_name = attrs.get('name', func.__name__)
@functools.wraps(func)
@@ -242,9 +202,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
_memo = getattr(wrapped, '__click_params__', None)
if _memo is None:
raise ValueError('Click command not register correctly.')
_object_setattr(
wrapped, '__click_params__', _memo[-self.NUMBER_OF_COMMON_PARAMS :] + _memo[: -self.NUMBER_OF_COMMON_PARAMS]
)
_object_setattr(wrapped, '__click_params__', _memo[-self.NUMBER_OF_COMMON_PARAMS :] + _memo[: -self.NUMBER_OF_COMMON_PARAMS])
# NOTE: we need to call super of super to avoid conflict with BentoMLCommandGroup command setup
cmd = super(BentoMLCommandGroup, self).command(*args, **kwargs)(wrapped)
# NOTE: add aliases to a given commands if it is specified.
@@ -258,7 +216,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
return decorator
def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
'''Additional format methods that include extensions as well as the default cli command.'''
"""Additional format methods that include extensions as well as the default cli command."""
from gettext import gettext as _
commands: list[tuple[str, click.Command]] = []
@@ -305,7 +263,7 @@ _PACKAGE_NAME = 'openllm'
message=f'{_PACKAGE_NAME}, %(version)s (compiled: {openllm.COMPILED})\nPython ({platform.python_implementation()}) {platform.python_version()}',
)
def cli() -> None:
'''\b
"""\b
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
@@ -316,15 +274,10 @@ def cli() -> None:
\b
An open platform for operating large language models in production.
Fine-tune, serve, deploy, and monitor any LLMs with ease.
'''
"""
@cli.command(
context_settings=termui.CONTEXT_SETTINGS,
name='start',
aliases=['start-http'],
short_help='Start a LLMServer for any supported LLM.',
)
@cli.command(context_settings=termui.CONTEXT_SETTINGS, name='start', aliases=['start-http'], short_help='Start a LLMServer for any supported LLM.')
@click.argument('model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True)
@click.option(
'--model-id',
@@ -366,24 +319,30 @@ def start_command(
dtype: LiteralDtype,
deprecated_model_id: str | None,
max_model_len: int | None,
gpu_memory_utilization:float,
gpu_memory_utilization: float,
**attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
'''Start any LLM as a REST server.
"""Start any LLM as a REST server.
\b
```bash
$ openllm <start|start-http> <model_id> --<options> ...
```
'''
if backend == 'pt': logger.warning('PyTorch backend is deprecated and will be removed in future releases. Make sure to use vLLM instead.')
"""
if backend == 'pt':
logger.warning('PyTorch backend is deprecated and will be removed in future releases. Make sure to use vLLM instead.')
if model_id in openllm.CONFIG_MAPPING:
_model_name = model_id
if deprecated_model_id is not None:
model_id = deprecated_model_id
else:
model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
logger.warning("Passing 'openllm start %s%s' is deprecated and will be remove in a future version. Use 'openllm start %s' instead.", _model_name, '' if deprecated_model_id is None else f' --model-id {deprecated_model_id}', model_id)
logger.warning(
"Passing 'openllm start %s%s' is deprecated and will be remove in a future version. Use 'openllm start %s' instead.",
_model_name,
'' if deprecated_model_id is None else f' --model-id {deprecated_model_id}',
model_id,
)
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
@@ -393,11 +352,7 @@ def start_command(
if serialisation == 'safetensors' and quantize is not None:
logger.warning("'--quantize=%s' might not work with 'safetensors' serialisation format.", quantize)
logger.warning(
"Make sure to check out '%s' repository to see if the weights is in '%s' format if unsure.",
model_id,
serialisation,
)
logger.warning("Make sure to check out '%s' repository to see if the weights is in '%s' format if unsure.", model_id, serialisation)
logger.info("Tip: You can always fallback to '--serialisation legacy' when running quantisation.")
import torch
@@ -425,7 +380,7 @@ def start_command(
config, server_attrs = llm.config.model_validate_click(**attrs)
server_timeout = first_not_none(server_timeout, default=config['timeout'])
server_attrs.update({'working_dir': pkg.source_locations('openllm'), 'timeout': server_timeout})
development = server_attrs.pop('development') # XXX: currently, theres no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development') # XXX: currently, theres no development args in bentoml.Server. To be fixed upstream.
server_attrs.setdefault('production', not development)
start_env = process_environ(
@@ -454,26 +409,27 @@ def start_command(
# NOTE: Return the configuration for telemetry purposes.
return config
def process_environ(config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True):
environ = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {})
environ.update(
{
'OPENLLM_MODEL_ID': model_id,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': serialisation,
'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
'BACKEND': llm.__llm_backend__,
'DTYPE': str(llm._torch_dtype).split('.')[-1],
'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
'MAX_MODEL_LEN': orjson.dumps(llm._max_model_len).decode(),
'GPU_MEMORY_UTILIZATION': orjson.dumps(llm._gpu_memory_utilization).decode(),
}
)
if llm.quantise: environ['QUANTIZE'] = str(llm.quantise)
environ.update({
'OPENLLM_MODEL_ID': model_id,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': serialisation,
'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(),
'BACKEND': llm.__llm_backend__,
'DTYPE': str(llm._torch_dtype).split('.')[-1],
'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
'MAX_MODEL_LEN': orjson.dumps(llm._max_model_len).decode(),
'GPU_MEMORY_UTILIZATION': orjson.dumps(llm._gpu_memory_utilization).decode(),
})
if llm.quantise:
environ['QUANTIZE'] = str(llm.quantise)
return environ
def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]) -> TypeGuard[float]:
if isinstance(wpr, str):
if wpr == 'round_robin':
@@ -491,6 +447,7 @@ def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]
wpr = float(wpr)
return wpr
def build_bento_instruction(llm, model_id, serialisation, adapter_map):
cmd_name = f'openllm build {model_id} --backend {llm.__llm_backend__}'
if llm.quantise:
@@ -498,12 +455,9 @@ def build_bento_instruction(llm, model_id, serialisation, adapter_map):
if llm.__llm_backend__ in {'pt', 'vllm'}:
cmd_name += f' --serialization {serialisation}'
if adapter_map is not None:
cmd_name += ' ' + ' '.join(
[
f'--adapter-id {s}'
for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
]
)
cmd_name += ' ' + ' '.join([
f'--adapter-id {s}' for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
])
if not openllm.utils.get_quiet_mode():
termui.info(f"🚀Tip: run '{cmd_name}' to create a BentoLLM for '{model_id}'")
@@ -537,9 +491,8 @@ def run_server(args, env, return_process=False) -> subprocess.Popen[bytes] | int
if return_process:
return process
stop_event = threading.Event()
# yapf: disable
stdout, stderr = threading.Thread(target=handle, args=(process.stdout, stop_event)), threading.Thread(target=handle, args=(process.stderr, stop_event))
stdout.start(); stderr.start()
stdout.start(); stderr.start() # noqa: E702
try:
process.wait()
@@ -554,10 +507,9 @@ def run_server(args, env, return_process=False) -> subprocess.Popen[bytes] | int
raise
finally:
stop_event.set()
stdout.join(); stderr.join()
stdout.join(); stderr.join() # noqa: E702
if process.poll() is not None: process.kill()
stdout.join(); stderr.join()
# yapf: disable
stdout.join(); stderr.join() # noqa: E702
return process.returncode
@@ -645,10 +597,7 @@ def import_command(
backend=backend,
dtype=dtype,
serialisation=t.cast(
LiteralSerialisation,
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy')
),
)
backend_warning(llm.__llm_backend__)
@@ -707,21 +656,14 @@ class BuildBentoOutput(t.TypedDict):
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
help='Deprecated. Use positional argument instead.',
)
@click.option(
'--bento-version',
type=str,
default=None,
help='Optional bento version for this BentoLLM. Default is the the model revision.',
)
@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the the model revision.')
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
@click.option(
'--enable-features',
multiple=True,
nargs=1,
metavar='FEATURE[,FEATURE]',
help='Enable additional features for building this LLM Bento. Available: {}'.format(
', '.join(OPTIONAL_DEPENDENCIES)
),
help='Enable additional features for building this LLM Bento. Available: {}'.format(', '.join(OPTIONAL_DEPENDENCIES)),
)
@optimization_decorator
@click.option(
@@ -732,12 +674,7 @@ class BuildBentoOutput(t.TypedDict):
help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed.",
)
@click.option('--build-ctx', help='Build context. This is required if --adapter-id uses relative path', default=None)
@click.option(
'--dockerfile-template',
default=None,
type=click.File(),
help='Optional custom dockerfile template to be used with this BentoLLM.',
)
@click.option('--dockerfile-template', default=None, type=click.File(), help='Optional custom dockerfile template to be used with this BentoLLM.')
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options') # type: ignore[misc]
@cog.optgroup.option(
'--containerize',
@@ -788,13 +725,13 @@ def build_command(
build_ctx: str | None,
dockerfile_template: t.TextIO | None,
max_model_len: int | None,
gpu_memory_utilization:float,
gpu_memory_utilization: float,
containerize: bool,
push: bool,
force_push: bool,
**_: t.Any,
) -> BuildBentoOutput:
'''Package a given models into a BentoLLM.
"""Package a given models into a BentoLLM.
\b
```bash
@@ -810,7 +747,7 @@ def build_command(
> [!IMPORTANT]
> To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
> target also use the same Python version and architecture as build machine.
'''
"""
from openllm.serialisation.transformers.weights import has_safetensors_weights
if model_id in openllm.CONFIG_MAPPING:
@@ -840,9 +777,7 @@ def build_command(
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
serialisation=first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
serialisation=first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'),
_eager=False,
)
if llm.__llm_backend__ not in llm.config['backend']:
@@ -854,9 +789,7 @@ def build_command(
model = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code)
llm._tag = model.tag
os.environ.update(
**process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm)
)
os.environ.update(**process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm))
try:
assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything
@@ -923,11 +856,7 @@ def build_command(
def get_current_bentocloud_context() -> str | None:
try:
context = (
cloud_config.get_context(ctx.obj.cloud_context)
if ctx.obj.cloud_context
else cloud_config.get_current_context()
)
context = cloud_config.get_context(ctx.obj.cloud_context) if ctx.obj.cloud_context else cloud_config.get_current_context()
return context.name
except Exception:
return None
@@ -951,9 +880,7 @@ def build_command(
tag=str(bento_tag),
backend=llm.__llm_backend__,
instructions=[
DeploymentInstruction.from_content(
type='bentocloud', instr="☁️ Push to BentoCloud with 'bentoml push':\n $ {cmd}", cmd=push_cmd
),
DeploymentInstruction.from_content(type='bentocloud', instr="☁️ Push to BentoCloud with 'bentoml push':\n $ {cmd}", cmd=push_cmd),
DeploymentInstruction.from_content(
type='container',
instr="🐳 Container BentoLLM with 'bentoml containerize':\n $ {cmd}",
@@ -979,9 +906,7 @@ def build_command(
termui.echo(f" * {instruction['content']}\n", nl=False)
if push:
BentoMLContainer.bentocloud_client.get().push_bento(
bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push
)
BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push)
elif containerize:
container_backend = t.cast('DefaultBuilder', os.environ.get('BENTOML_CONTAINERIZE_BACKEND', 'docker'))
try:
@@ -1009,20 +934,19 @@ class ModelItem(t.TypedDict):
@cli.command()
@click.option('--show-available', is_flag=True, default=True, hidden=True)
def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]:
'''List all supported models.
"""List all supported models.
\b
```bash
openllm models
```
'''
"""
result: dict[t.LiteralString, ModelItem] = {
m: ModelItem(
architecture=config.__openllm_architecture__,
example_id=random.choice(config.__openllm_model_ids__),
supported_backends=config.__openllm_backend__,
installation='pip install '
+ (f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config.__openllm_requirements__ else 'openllm'),
installation='pip install ' + (f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config.__openllm_requirements__ else 'openllm'),
items=[
str(md.tag)
for md in bentoml.models.list()
@@ -1041,13 +965,7 @@ def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]:
@cli.command()
@model_name_argument(required=False)
@click.option('-y', '--yes', '--assume-yes', is_flag=True, help='Skip confirmation when deleting a specific model')
@click.option(
'--include-bentos/--no-include-bentos',
is_flag=True,
hidden=True,
default=True,
help='Whether to also include pruning bentos.',
)
@click.option('--include-bentos/--no-include-bentos', is_flag=True, hidden=True, default=True, help='Whether to also include pruning bentos.')
@inject
@click.pass_context
def prune_command(
@@ -1058,38 +976,30 @@ def prune_command(
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
**_: t.Any,
) -> None:
'''Remove all saved models, and bentos built with OpenLLM locally.
"""Remove all saved models, and bentos built with OpenLLM locally.
\b
If a model type is passed, then only prune models for that given model type.
'''
"""
available: list[tuple[bentoml.Model | bentoml.Bento, ModelStore | BentoStore]] = [
(m, model_store)
for m in bentoml.models.list()
if 'framework' in m.info.labels and m.info.labels['framework'] == 'openllm'
(m, model_store) for m in bentoml.models.list() if 'framework' in m.info.labels and m.info.labels['framework'] == 'openllm'
]
if model_name is not None:
available = [
(m, store)
for m, store in available
if 'model_name' in m.info.labels and m.info.labels['model_name'] == inflection.underscore(model_name)
(m, store) for m, store in available if 'model_name' in m.info.labels and m.info.labels['model_name'] == inflection.underscore(model_name)
] + [
(b, bento_store)
for b in bentoml.bentos.list()
if 'start_name' in b.info.labels and b.info.labels['start_name'] == inflection.underscore(model_name)
]
else:
available += [
(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels
]
available += [(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels]
for store_item, store in available:
if yes:
delete_confirmed = True
else:
delete_confirmed = click.confirm(
f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?"
)
delete_confirmed = click.confirm(f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?")
if delete_confirmed:
store.delete(store_item.tag)
termui.warning(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.")
@@ -1136,17 +1046,8 @@ def shared_client_options(f: _AnyCallable | None = None) -> t.Callable[[FC], FC]
@cli.command()
@shared_client_options
@click.option(
'--server-type',
type=click.Choice(['grpc', 'http']),
help='Server type',
default='http',
show_default=True,
hidden=True,
)
@click.option(
'--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.'
)
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True, hidden=True)
@click.option('--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.')
@click.argument('prompt', type=click.STRING)
@click.option(
'--sampling-params',
@@ -1168,14 +1069,15 @@ def query_command(
_memoized: DictStrAny,
**_: t.Any,
) -> None:
'''Query a LLM interactively, from a terminal.
"""Query a LLM interactively, from a terminal.
\b
```bash
$ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?"
```
'''
if server_type == 'grpc': raise click.ClickException("'grpc' is currently disabled.")
"""
if server_type == 'grpc':
raise click.ClickException("'grpc' is currently disabled.")
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
# TODO: grpc support
client = openllm.HTTPClient(address=endpoint, timeout=timeout)
@@ -1194,7 +1096,7 @@ def query_command(
@cli.group(cls=Extensions, hidden=True, name='extension')
def extension_command() -> None:
'''Extension for OpenLLM CLI.'''
"""Extension for OpenLLM CLI."""
if __name__ == '__main__':

View File

@@ -21,10 +21,8 @@ if t.TYPE_CHECKING:
@machine_option
@click.pass_context
@inject
def cli(
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
) -> str | None:
'''Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path).'''
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
try:
bentomodel = _bento_store.get(bento)
except bentoml.exceptions.NotFound:

View File

@@ -17,9 +17,7 @@ if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
@click.command(
'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.'
)
@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.')
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
@click.pass_context
@inject

View File

@@ -22,9 +22,7 @@ class PromptFormatter(string.Formatter):
raise ValueError('Positional arguments are not supported')
return super().vformat(format_string, args, kwargs)
def check_unused_args(
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> None:
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
extras = set(kwargs).difference(used_args)
if extras:
raise KeyError(f'Extra params passed: {extras}')
@@ -58,9 +56,7 @@ class PromptTemplate:
try:
return self.template.format(**prompt_variables)
except KeyError as e:
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
) from None
raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template.") from None
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
@@ -128,21 +124,15 @@ def cli(
if prompt_template_file and chat_template_file:
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
)
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys())
if model_id in acceptable:
logger.warning(
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
)
logger.warning('Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n')
config = openllm.AutoConfig.for_model(model_id)
template = prompt_template_file.read() if prompt_template_file is not None else config.template
system_message = system_message or config.system_message
try:
formatted = (
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
)
formatted = PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
except RuntimeError as err:
logger.debug('Exception caught while formatting prompt: %s', err)
ctx.fail(str(err))
@@ -159,21 +149,15 @@ def cli(
for architecture in config.architectures:
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
system_message = (
openllm.AutoConfig.infer_class_from_name(
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
)
openllm.AutoConfig.infer_class_from_name(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture])
.model_construct_env()
.system_message
)
break
else:
ctx.fail(
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
)
ctx.fail(f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message')
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
formatted = tokenizer.apply_chat_template(
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
)
formatted = tokenizer.apply_chat_template(messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False)
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
ctx.exit(0)

View File

@@ -13,7 +13,7 @@ from openllm_cli import termui
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
@click.pass_context
def cli(ctx: click.Context) -> None:
'''List available bentos built by OpenLLM.'''
"""List available bentos built by OpenLLM."""
mapping = {
k: [
{

View File

@@ -18,7 +18,7 @@ if t.TYPE_CHECKING:
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
@model_name_argument(required=False, shell_complete=model_complete_envvar)
def cli(model_name: str | None) -> DictStrAny:
'''List available models in lcoal store to be used wit OpenLLM.'''
"""List available models in lcoal store to be used wit OpenLLM."""
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
ids_in_local_store = {
k: [
@@ -33,17 +33,12 @@ def cli(model_name: str | None) -> DictStrAny:
}
if model_name is not None:
ids_in_local_store = {
k: [
i
for i in v
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
]
k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)]
for k, v in ids_in_local_store.items()
}
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
local_models = {
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
for k, val in ids_in_local_store.items()
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()
}
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
return local_models

View File

@@ -32,14 +32,7 @@ def load_notebook_metadata() -> DictStrAny:
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
@click.argument('output-dir', default=None, required=False)
@click.option(
'--port',
envvar='JUPYTER_PORT',
show_envvar=True,
show_default=True,
default=8888,
help='Default port for Jupyter server',
)
@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server')
@click.pass_context
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
"""OpenLLM Playground.
@@ -60,9 +53,7 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
"""
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
raise RuntimeError(
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
)
raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'")
metadata = load_notebook_metadata()
_temp_dir = False
if output_dir is None:
@@ -74,9 +65,7 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
for module in pkgutil.iter_modules(playground.__path__):
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
logger.debug(
'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module'
)
logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module')
continue
if not isinstance(module.module_finder, importlib.machinery.FileFinder):
continue
@@ -86,20 +75,18 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
f.cells.insert(0, markdown_cell)
jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook')
try:
subprocess.check_output(
[
sys.executable,
'-m',
'jupyter',
'notebook',
'--notebook-dir',
output_dir,
'--port',
str(port),
'--no-browser',
'--debug',
]
)
subprocess.check_output([
sys.executable,
'-m',
'jupyter',
'notebook',
'--notebook-dir',
output_dir,
'--port',
str(port),
'--no-browser',
'--debug',
])
except subprocess.CalledProcessError as e:
termui.echo(e.output, fg='red')
raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None

View File

@@ -25,14 +25,7 @@ class Level(enum.IntEnum):
@property
def color(self) -> str | None:
return {
Level.NOTSET: None,
Level.DEBUG: 'cyan',
Level.INFO: 'green',
Level.WARNING: 'yellow',
Level.ERROR: 'red',
Level.CRITICAL: 'red',
}[self]
return {Level.NOTSET: None, Level.DEBUG: 'cyan', Level.INFO: 'green', Level.WARNING: 'yellow', Level.ERROR: 'red', Level.CRITICAL: 'red'}[self]
@classmethod
def from_logging_level(cls, level: int) -> Level:
@@ -82,9 +75,5 @@ def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json:
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
CONTEXT_SETTINGS: DictStrAny = {
'help_option_names': ['-h', '--help'],
'max_content_width': COLUMNS,
'token_normalize_func': inflection.underscore,
}
CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore}
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS', 'log', 'warning', 'error', 'critical', 'debug', 'info', 'Level']