chore(cli): update namespace and show better traceback

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-06-03 06:39:01 -07:00
parent ced6faf3c9
commit 64d783107d

View File

@@ -24,6 +24,7 @@ import logging
import os
import sys
import time
import traceback
import typing as t
import bentoml
@@ -496,7 +497,11 @@ output_option = click.option(
)
def cli_factory():
def cli_factory() -> click.Group:
from bentoml._internal.log import configure_logging
configure_logging()
@click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm")
def cli():
"""
@@ -517,52 +522,51 @@ def cli_factory():
- Powered by BentoML 🍱 + HuggingFace 🤗
"""
@cli.command(name="version")
@cli.command()
@output_option
@click.pass_context
def _(ctx: click.Context, output: OutputLiteral):
def version(ctx: click.Context, output: OutputLiteral):
"""🚀 OpenLLM version."""
from gettext import gettext
from .__about__ import __version__
message = gettext("%(prog)s, version %(version)s")
version = __version__
prog_name = ctx.find_root().info_name
if output == "pretty":
_echo(message % {"prog": prog_name, "version": version}, color=ctx.color)
_echo(message % {"prog": prog_name, "version": __version__}, color=ctx.color)
elif output == "json":
_echo(orjson.dumps({"version": version}, option=orjson.OPT_INDENT_2).decode())
_echo(orjson.dumps({"version": __version__}, option=orjson.OPT_INDENT_2).decode())
else:
_echo(version)
_echo(__version__)
ctx.exit()
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start")
def _():
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS)
def start():
"""
Start any LLM as a REST server.
$ openllm start <model_name> --<options> ...
"""
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc")
def _():
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS)
def start_grpc():
"""
Start any LLM as a gRPC server.
$ openllm start-grpc <model_name> --<options> ...
"""
@cli.command(name="bundle", aliases=["build"])
@cli.command(aliases=["build"])
@click.argument(
"model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
)
@click.option("--pretrained", default=None, help="Given pretrained model name for the given model name [Optional].")
@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.")
@output_option
def _(model_name: str, pretrained: str | None, overwrite: bool, output: OutputLiteral):
def bundle(model_name: str, pretrained: str | None, overwrite: bool, output: OutputLiteral):
"""Package a given models into a Bento.
$ openllm bundle flan-t5
@@ -599,13 +603,13 @@ def cli_factory():
_echo(bento.tag)
return bento
@cli.command(name="models")
@cli.command()
@output_option
def _(output: OutputLiteral):
def models(output: OutputLiteral):
"""List all supported models."""
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
if output == "porcelain":
_echo("\n".join(models))
_echo("\n".join(models), fg="white")
else:
failed_initialized: list[tuple[str, Exception]] = []
@@ -628,6 +632,14 @@ def cli_factory():
for m, v in json_data.items():
data.extend([(m, v["description"], v["variants"])])
column_widths = [int(COLUMNS / 6), int(COLUMNS / 3 * 2), int(COLUMNS / 6)]
if len(data) == 0 and len(failed_initialized) > 0:
_echo("Exception found while parsing models:\n", fg="yellow")
for m, err in failed_initialized:
_echo(f"- {m}: ", fg="yellow", nl=False)
_echo(traceback.print_exception(err, limit=3), fg="red")
sys.exit(1)
table = tabulate.tabulate(
data,
tablefmt="fancy_grid",
@@ -645,13 +657,14 @@ def cli_factory():
if len(failed_initialized) > 0:
_echo("\nThe following models are supported but failed to initialize:\n")
for m, err in failed_initialized:
_echo(f" - {m}: {err}", fg="yellow")
_echo(f"- {m}: ", fg="blue", nl=False)
_echo(err, fg="red")
else:
_echo(orjson.dumps(json_data, option=orjson.OPT_INDENT_2).decode())
sys.exit(0)
@cli.command(name="download_models")
@cli.command()
@click.argument(
"model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
)
@@ -659,7 +672,7 @@ def cli_factory():
"--pretrained", type=click.STRING, default=None, help="Optional pretrained name or path to fine-tune weight."
)
@output_option
def _(model_name: str, pretrained: str | None, output: OutputLiteral):
def download_models(model_name: str, pretrained: str | None, output: OutputLiteral):
"""Setup LLM interactively.
Note: This is useful for development and setup for fine-tune.
@@ -704,6 +717,9 @@ def cli_factory():
_echo(m.tag)
return m
if t.TYPE_CHECKING:
assert download_models and bundle and models and version and start and start_grpc
if psutil.WINDOWS:
sys.stdout.reconfigure(encoding="utf-8") # type: ignore