diff --git a/src/openllm/cli.py b/src/openllm/cli.py index fb7398d5..256b4141 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -24,6 +24,7 @@ import logging import os import sys import time +import traceback import typing as t import bentoml @@ -496,7 +497,11 @@ output_option = click.option( ) -def cli_factory(): +def cli_factory() -> click.Group: + from bentoml._internal.log import configure_logging + + configure_logging() + @click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm") def cli(): """ @@ -517,52 +522,51 @@ def cli_factory(): - Powered by BentoML 🍱 + HuggingFace 🤗 """ - @cli.command(name="version") + @cli.command() @output_option @click.pass_context - def _(ctx: click.Context, output: OutputLiteral): + def version(ctx: click.Context, output: OutputLiteral): """🚀 OpenLLM version.""" from gettext import gettext from .__about__ import __version__ message = gettext("%(prog)s, version %(version)s") - version = __version__ prog_name = ctx.find_root().info_name if output == "pretty": - _echo(message % {"prog": prog_name, "version": version}, color=ctx.color) + _echo(message % {"prog": prog_name, "version": __version__}, color=ctx.color) elif output == "json": - _echo(orjson.dumps({"version": version}, option=orjson.OPT_INDENT_2).decode()) + _echo(orjson.dumps({"version": __version__}, option=orjson.OPT_INDENT_2).decode()) else: - _echo(version) + _echo(__version__) ctx.exit() - @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start") - def _(): + @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS) + def start(): """ Start any LLM as a REST server. $ openllm start -- ... """ - @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc") - def _(): + @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS) + def start_grpc(): """ Start any LLM as a gRPC server. $ openllm start-grpc -- ... """ - @cli.command(name="bundle", aliases=["build"]) + @cli.command(aliases=["build"]) @click.argument( "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]) ) @click.option("--pretrained", default=None, help="Given pretrained model name for the given model name [Optional].") @click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.") @output_option - def _(model_name: str, pretrained: str | None, overwrite: bool, output: OutputLiteral): + def bundle(model_name: str, pretrained: str | None, overwrite: bool, output: OutputLiteral): """Package a given models into a Bento. $ openllm bundle flan-t5 @@ -599,13 +603,13 @@ def cli_factory(): _echo(bento.tag) return bento - @cli.command(name="models") + @cli.command() @output_option - def _(output: OutputLiteral): + def models(output: OutputLiteral): """List all supported models.""" models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys()) if output == "porcelain": - _echo("\n".join(models)) + _echo("\n".join(models), fg="white") else: failed_initialized: list[tuple[str, Exception]] = [] @@ -628,6 +632,14 @@ def cli_factory(): for m, v in json_data.items(): data.extend([(m, v["description"], v["variants"])]) column_widths = [int(COLUMNS / 6), int(COLUMNS / 3 * 2), int(COLUMNS / 6)] + + if len(data) == 0 and len(failed_initialized) > 0: + _echo("Exception found while parsing models:\n", fg="yellow") + for m, err in failed_initialized: + _echo(f"- {m}: ", fg="yellow", nl=False) + _echo(traceback.print_exception(err, limit=3), fg="red") + sys.exit(1) + table = tabulate.tabulate( data, tablefmt="fancy_grid", @@ -645,13 +657,14 @@ def cli_factory(): if len(failed_initialized) > 0: _echo("\nThe following models are supported but failed to initialize:\n") for m, err in failed_initialized: - _echo(f" - {m}: {err}", fg="yellow") + _echo(f"- {m}: ", fg="blue", nl=False) + _echo(err, fg="red") else: _echo(orjson.dumps(json_data, option=orjson.OPT_INDENT_2).decode()) sys.exit(0) - @cli.command(name="download_models") + @cli.command() @click.argument( "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]) ) @@ -659,7 +672,7 @@ def cli_factory(): "--pretrained", type=click.STRING, default=None, help="Optional pretrained name or path to fine-tune weight." ) @output_option - def _(model_name: str, pretrained: str | None, output: OutputLiteral): + def download_models(model_name: str, pretrained: str | None, output: OutputLiteral): """Setup LLM interactively. Note: This is useful for development and setup for fine-tune. @@ -704,6 +717,9 @@ def cli_factory(): _echo(m.tag) return m + if t.TYPE_CHECKING: + assert download_models and bundle and models and version and start and start_grpc + if psutil.WINDOWS: sys.stdout.reconfigure(encoding="utf-8") # type: ignore