refactor: toplevel CLI (#26)

Move the CLI out of the factory function to the top level to simplify the workflow
Aaron Pham
2023-06-15 02:32:46 -04:00
committed by GitHub
parent 9a6a976ce1
commit 5e1445218b
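
For context, the sketch below illustrates the pattern this commit applies, using placeholder names (demo, hello) rather than the real OpenLLM commands: the Click group moves out of a factory function and onto the module top level, so subcommands attach with plain @cli.command() decorators and the group can be imported and invoked directly.

# Illustrative sketch only: placeholder names, not the actual openllm CLI.
import click


# Before: the group and its commands were built inside a factory and returned.
def cli_factory() -> click.Group:
    @click.group(name="demo")
    def cli() -> None:
        """Demo CLI defined inside a factory."""

    @cli.command()
    def hello() -> None:
        click.echo("hello from the factory-built CLI")

    return cli


# After: the group lives at module level; commands attach to it directly,
# and other modules can simply import `cli`.
@click.group(name="demo")
def cli() -> None:
    """Demo CLI defined at the top level."""


@cli.command()
def hello() -> None:
    click.echo("hello from the top-level CLI")


if __name__ == "__main__":
    cli()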


@@ -18,12 +18,12 @@ This extends clidantic and BentoML's internal CLI CommandGroup.
"""
from __future__ import annotations
import contextlib
import functools
import inspect
import logging
import os
import re
import sys
import time
import traceback
@@ -39,12 +39,16 @@ from bentoml._internal.configuration import get_debug_mode, get_quiet_mode, set_
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.log import configure_logging, configure_server_logging
from bentoml_cli.utils import BentoMLCommandGroup
from simple_di import Provide, inject
import openllm
from .__about__ import __version__
from .utils import DEBUG, LazyType, ModelEnv, analytics, bentoml_cattr, first_not_none
if t.TYPE_CHECKING:
    from bentoml._internal.models import ModelStore

    from ._types import ClickFunctionWrapper, F, P

    ServeCommand = t.Literal["serve", "serve-grpc"]
@@ -55,6 +59,8 @@ else:
    TupleStrAny = tuple
configure_logging()
logger = logging.getLogger(__name__)
COLUMNS = int(os.environ.get("COLUMNS", 120))
@@ -647,341 +653,328 @@ def workers_per_resource_option(factory: t.Any, build: bool = False):
)

@click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm")
@click.version_option(__version__, "--version", "-v")
def cli():
    """
    \b
    ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
    ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
    ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
    ██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║
    ╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
    ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝

    \b
    An open platform for operating large language models in production.
    Fine-tune, serve, deploy, and monitor any LLMs with ease.
    """


@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS)
def start():
    """
    Start any LLM as a REST server.

    $ openllm start <model_name> --<options> ...
    """


@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS)
def start_grpc():
    """
    Start any LLM as a gRPC server.

    $ openllm start-grpc <model_name> --<options> ...
    """


@cli.command()
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))
@model_id_option(click)
@output_option
@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.")
@workers_per_resource_option(click, build=True)
def build(
    model_name: str,
    model_id: str | None,
    overwrite: bool,
    output: OutputLiteral,
    workers_per_resource: float | None,
):
    """Package a given models into a Bento.

    $ openllm build flan-t5

    \b
    NOTE: To run a container built from this Bento with GPU support, make sure
    to have https://github.com/NVIDIA/nvidia-container-toolkit install locally.
    """
    if output == "porcelain":
        set_quiet_mode(True)
        configure_server_logging()

    if output == "pretty":
        if overwrite:
            _echo(f"Overwriting existing Bento for {model_name}.", fg="yellow")

    bento, _previously_built = openllm.build(
        model_name,
        __cli__=True,
        model_id=model_id,
        _workers_per_resource=workers_per_resource,
        _overwrite_existing_bento=overwrite,
    )

    if output == "pretty":
        if not get_quiet_mode():
            _echo("\n" + OPENLLM_FIGLET, fg="white")
            if not _previously_built:
                _echo(f"Successfully built {bento}.", fg="green")
            else:
                _echo(
                    f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.",
                    fg="yellow",
                )
            _echo(
                "\nPossible next steps:\n\n"
                + "* Push to BentoCloud with `bentoml push`:\n"
                + f" $ bentoml push {bento.tag}\n"
                + "* Containerize your Bento with `bentoml containerize`:\n"
                + f" $ bentoml containerize {bento.tag}\n"
                + " Tip: To enable additional BentoML feature for 'containerize', "
                + "use '--enable-features=FEATURE[,FEATURE]' "
                + "[see 'bentoml containerize -h' for more advanced usage]\n",
                fg="blue",
            )
    elif output == "json":
        _echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
    else:
        _echo(bento.tag)
    return bento


@cli.command()
@output_option
@click.option(
    "--show-available",
    is_flag=True,
    default=False,
    help="Show available models in local store (mutually exclusive with '-o porcelain').",
)
def models(output: OutputLiteral, show_available: bool):
    """List all supported models.

    NOTE: '--show-available' and '-o porcelain' are mutually exclusive.
    """
    from ._llm import convert_transformers_model_name

    models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
    if output == "porcelain":
        if show_available:
            raise click.BadOptionUsage(
                "--show-available", "Cannot use '--show-available' with '-o porcelain' (mutually exclusive)."
            )
        _echo("\n".join(models), fg="white")
    else:
        failed_initialized: list[tuple[str, Exception]] = []

        json_data: dict[
            str, dict[t.Literal["model_id", "url", "installation", "requires_gpu", "runtime_impl"], t.Any]
        ] = {}

        # NOTE: Keep a sync list with ./tools/update-optional-dependencies.py
        extras = ["chatglm", "falcon", "flan-t5", "starcoder"]

        converted: list[str] = []
        for m in models:
            config = openllm.AutoConfig.for_model(m)
            runtime_impl: tuple[t.Literal["pt", "flax", "tf"], ...] = tuple()
            if config.__openllm_model_name__ in openllm.MODEL_MAPPING_NAMES:
                runtime_impl += ("pt",)
            if config.__openllm_model_name__ in openllm.MODEL_FLAX_MAPPING_NAMES:
                runtime_impl += ("flax",)
            if config.__openllm_model_name__ in openllm.MODEL_TF_MAPPING_NAMES:
                runtime_impl += ("tf",)
            json_data[m] = {
                "model_id": config.__openllm_model_ids__,
                "url": config.__openllm_url__,
                "requires_gpu": config.__openllm_requires_gpu__,
                "runtime_impl": runtime_impl,
                "installation": "pip install openllm" if m not in extras else f'pip install "openllm[{m}]"',
            }
            converted.extend([convert_transformers_model_name(i) for i in config.__openllm_model_ids__])
            if DEBUG:
                try:
                    openllm.AutoLLM.for_model(m, llm_config=config)
                except Exception as err:
                    failed_initialized.append((m, err))

        ids_in_local_store = None
        if show_available:
            ids_in_local_store = [i for i in bentoml.models.list() if any(n in i.tag.name for n in converted)]

        if output == "pretty":
            import tabulate

            tabulate.PRESERVE_WHITESPACE = True

            data: list[
                str | tuple[str, str, list[str], str, t.LiteralString, tuple[t.Literal["pt", "flax", "tf"], ...]]
            ] = []
            for m, v in json_data.items():
                data.extend(
                    [
                        (
                            m,
                            v["url"],
                            v["model_id"],
                            v["installation"],
                            "" if v["requires_gpu"] else "",
                            v["runtime_impl"],
                        )
                    ]
                )
            column_widths = [
                int(COLUMNS / 6),
                int(COLUMNS / 6),
                int(COLUMNS / 3),
                int(COLUMNS / 6),
                int(COLUMNS / 6),
                int(COLUMNS / 9),
            ]

            if len(data) == 0 and len(failed_initialized) > 0:
                _echo("Exception found while parsing models:\n", fg="yellow")
                for m, err in failed_initialized:
                    _echo(f"- {m}: ", fg="yellow", nl=False)
                    _echo(traceback.print_exception(err, limit=3), fg="red")
                sys.exit(1)

            table = tabulate.tabulate(
                data,
                tablefmt="fancy_grid",
                headers=["LLM", "URL", "Models Id", "Installation", "GPU Only", "Runtime"],
                maxcolwidths=column_widths,
            )

            formatted_table = ""
            for line in table.split("\n"):
                formatted_table += (
                    "".join(f"{cell:{width}}" for cell, width in zip(line.split("\t"), column_widths)) + "\n"
                )
            _echo(formatted_table, fg="white")

            if DEBUG and len(failed_initialized) > 0:
                _echo("\nThe following models are supported but failed to initialize:\n")
                for m, err in failed_initialized:
                    _echo(f"- {m}: ", fg="blue", nl=False)
                    _echo(err, fg="red")

            if show_available:
                assert ids_in_local_store
                _echo("The following models are available in local store:\n", fg="white")
                for i in ids_in_local_store:
                    _echo(f"- {i}", fg="white")
        else:
            dumped: dict[str, t.Any] = json_data
            if show_available:
                assert ids_in_local_store
                dumped["local"] = [bentoml_cattr.unstructure(i.tag) for i in ids_in_local_store]
            _echo(
                orjson.dumps(
                    dumped,
                    option=orjson.OPT_INDENT_2,
                ).decode(),
                fg="white",
            )

    sys.exit(0)


@cli.command()
@click.option(
    "-y",
    "--yes",
    "--assume-yes",
    is_flag=True,
    help="Skip confirmation when deleting a specific model",
)
@inject
def prune(yes: bool, model_store: ModelStore = Provide[BentoMLContainer.model_store]):
    """Remove all saved models locally."""
    available = [
        m
        for t in map(inflection.dasherize, openllm.CONFIG_MAPPING.keys())
        for m in bentoml.models.list()
        if t in m.tag.name
    ]

    for model in available:
        if yes:
            delete_confirmed = True
        else:
            delete_confirmed = click.confirm(f"delete model {model.tag}?")

        if delete_confirmed:
            model_store.delete(model.tag)
            click.echo(f"{model} deleted.")


@cli.command(name="query")
@click.option(
    "--endpoint",
    type=click.STRING,
    help="OpenLLM Server endpoint, i.e: http://localhost:3000",
    envvar="OPENLLM_ENDPOINT",
    default="http://localhost:3000",
)
@click.option("--timeout", type=click.INT, default=30, help="Default server timeout", show_default=True)
@click.option(
    "--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True
)
@output_option
@click.argument("query", type=click.STRING)
def query_(
    query: str,
    endpoint: str,
    timeout: int,
    server_type: t.Literal["http", "grpc"],
    output: OutputLiteral,
):
    """Ask a LLM interactively, from a terminal.

    $ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?"
    """
    if server_type == "grpc":
        endpoint = re.sub(r"http://", "", endpoint)
    client = (
        openllm.client.HTTPClient(endpoint, timeout=timeout)
        if server_type == "http"
        else openllm.client.GrpcClient(endpoint, timeout=timeout)
    )

    if client.framework == "flax":
        model = openllm.AutoFlaxLLM.for_model(client.model_name)
    elif client.framework == "tf":
        model = openllm.AutoTFLLM.for_model(client.model_name)
    else:
        model = openllm.AutoLLM.for_model(client.model_name)

    if output != "porcelain":
        _echo(f"Processing query: {query}\n", fg="white")

    res = client.query(query, return_raw_response=True)

    if output == "pretty":
        formatted = model.postprocess_generate(query, res["responses"])
        _echo("Responses: ", fg="white", nl=False)
        _echo(formatted, fg="cyan")
    elif output == "json":
        _echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white")
    else:
        _echo(res["responses"], fg="white")

@cli.command()
@@ -1035,5 +1028,9 @@ def download_models(model_name: str, model_id: str | None, output: OutputLiteral
    return m


if psutil.WINDOWS:
    sys.stdout.reconfigure(encoding="utf-8")  # type: ignore

if __name__ == "__main__":
    cli()