From 5e1445218be525eb44520ac72ea28d3d563f5fc2 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Thu, 15 Jun 2023 02:32:46 -0400 Subject: [PATCH] refactor: toplevel CLI (#26) Move up CLI outside of the factory function to simplify workflow --- src/openllm/cli.py | 611 ++++++++++++++++++++++----------------------- 1 file changed, 304 insertions(+), 307 deletions(-) diff --git a/src/openllm/cli.py b/src/openllm/cli.py index b4c8913d..e2c329c6 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -18,12 +18,12 @@ This extends clidantic and BentoML's internal CLI CommandGroup. """ from __future__ import annotations +import contextlib import functools import inspect import logging import os import re -import contextlib import sys import time import traceback @@ -39,12 +39,16 @@ from bentoml._internal.configuration import get_debug_mode, get_quiet_mode, set_ from bentoml._internal.configuration.containers import BentoMLContainer from bentoml._internal.log import configure_logging, configure_server_logging from bentoml_cli.utils import BentoMLCommandGroup +from simple_di import Provide, inject import openllm +from .__about__ import __version__ from .utils import DEBUG, LazyType, ModelEnv, analytics, bentoml_cattr, first_not_none if t.TYPE_CHECKING: + from bentoml._internal.models import ModelStore + from ._types import ClickFunctionWrapper, F, P ServeCommand = t.Literal["serve", "serve-grpc"] @@ -55,6 +59,8 @@ else: TupleStrAny = tuple +configure_logging() + logger = logging.getLogger(__name__) COLUMNS = int(os.environ.get("COLUMNS", 120)) @@ -647,341 +653,328 @@ def workers_per_resource_option(factory: t.Any, build: bool = False): ) -def cli_factory() -> click.Group: - from .__about__ import __version__ +@click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm") +@click.version_option(__version__, "--version", "-v") +def cli(): + """ + \b + ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ + ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ + ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║ + ██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║ + ╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║ + ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝ - configure_logging() + \b + An open platform for operating large language models in production. + Fine-tune, serve, deploy, and monitor any LLMs with ease. + """ - model_store = BentoMLContainer.model_store.get() - @click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm") - @click.version_option(__version__, "--version", "-v") - def cli(): - """ - \b - ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ - ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ - ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║ - ██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║ - ╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║ - ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝ +@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS) +def start(): + """ + Start any LLM as a REST server. - \b - An open platform for operating large language models in production. - Fine-tune, serve, deploy, and monitor any LLMs with ease. - """ + $ openllm start -- ... + """ - @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS) - def start(): - """ - Start any LLM as a REST server. - $ openllm start -- ... 
- """ +@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS) +def start_grpc(): + """ + Start any LLM as a gRPC server. - @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS) - def start_grpc(): - """ - Start any LLM as a gRPC server. + $ openllm start-grpc -- ... + """ - $ openllm start-grpc -- ... - """ - @cli.command() - @click.argument( - "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]) +@cli.command() +@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])) +@model_id_option(click) +@output_option +@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.") +@workers_per_resource_option(click, build=True) +def build( + model_name: str, + model_id: str | None, + overwrite: bool, + output: OutputLiteral, + workers_per_resource: float | None, +): + """Package a given models into a Bento. + + $ openllm build flan-t5 + + \b + NOTE: To run a container built from this Bento with GPU support, make sure + to have https://github.com/NVIDIA/nvidia-container-toolkit install locally. + """ + if output == "porcelain": + set_quiet_mode(True) + configure_server_logging() + + if output == "pretty": + if overwrite: + _echo(f"Overwriting existing Bento for {model_name}.", fg="yellow") + + bento, _previously_built = openllm.build( + model_name, + __cli__=True, + model_id=model_id, + _workers_per_resource=workers_per_resource, + _overwrite_existing_bento=overwrite, ) - @model_id_option(click) - @output_option - @click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.") - @workers_per_resource_option(click, build=True) - def build( - model_name: str, - model_id: str | None, - overwrite: bool, - output: OutputLiteral, - workers_per_resource: float | None, - ): - """Package a given models into a Bento. - $ openllm build flan-t5 - - \b - NOTE: To run a container built from this Bento with GPU support, make sure - to have https://github.com/NVIDIA/nvidia-container-toolkit install locally. - """ - if output == "porcelain": - set_quiet_mode(True) - configure_server_logging() - - if output == "pretty": - if overwrite: - _echo(f"Overwriting existing Bento for {model_name}.", fg="yellow") - - bento, _previously_built = openllm.build( - model_name, - __cli__=True, - model_id=model_id, - _workers_per_resource=workers_per_resource, - _overwrite_existing_bento=overwrite, - ) - - if output == "pretty": - if not get_quiet_mode(): - _echo("\n" + OPENLLM_FIGLET, fg="white") - if not _previously_built: - _echo(f"Successfully built {bento}.", fg="green") - else: - _echo( - f"'{model_name}' already has a Bento built [{bento}]. 
To overwrite it pass '--overwrite'.", - fg="yellow", - ) - - _echo( - "\nPossible next steps:\n\n" - + "* Push to BentoCloud with `bentoml push`:\n" - + f" $ bentoml push {bento.tag}\n" - + "* Containerize your Bento with `bentoml containerize`:\n" - + f" $ bentoml containerize {bento.tag}\n" - + " Tip: To enable additional BentoML feature for 'containerize', " - + "use '--enable-features=FEATURE[,FEATURE]' " - + "[see 'bentoml containerize -h' for more advanced usage]\n", - fg="blue", - ) - elif output == "json": - _echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode()) - else: - _echo(bento.tag) - return bento - - @cli.command() - @output_option - @click.option( - "--show-available", - is_flag=True, - default=False, - help="Show available models in local store (mutually exclusive with '-o porcelain').", - ) - def models(output: OutputLiteral, show_available: bool): - """List all supported models. - - NOTE: '--show-available' and '-o porcelain' are mutually exclusive. - """ - from ._llm import convert_transformers_model_name - - models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys()) - if output == "porcelain": - if show_available: - raise click.BadOptionUsage( - "--show-available", "Cannot use '--show-available' with '-o porcelain' (mutually exclusive)." - ) - _echo("\n".join(models), fg="white") - else: - failed_initialized: list[tuple[str, Exception]] = [] - - json_data: dict[ - str, dict[t.Literal["model_id", "url", "installation", "requires_gpu", "runtime_impl"], t.Any] - ] = {} - - # NOTE: Keep a sync list with ./tools/update-optional-dependencies.py - extras = ["chatglm", "falcon", "flan-t5", "starcoder"] - - converted: list[str] = [] - for m in models: - config = openllm.AutoConfig.for_model(m) - runtime_impl: tuple[t.Literal["pt", "flax", "tf"], ...] 
= tuple() - if config.__openllm_model_name__ in openllm.MODEL_MAPPING_NAMES: - runtime_impl += ("pt",) - if config.__openllm_model_name__ in openllm.MODEL_FLAX_MAPPING_NAMES: - runtime_impl += ("flax",) - if config.__openllm_model_name__ in openllm.MODEL_TF_MAPPING_NAMES: - runtime_impl += ("tf",) - json_data[m] = { - "model_id": config.__openllm_model_ids__, - "url": config.__openllm_url__, - "requires_gpu": config.__openllm_requires_gpu__, - "runtime_impl": runtime_impl, - "installation": "pip install openllm" if m not in extras else f'pip install "openllm[{m}]"', - } - converted.extend([convert_transformers_model_name(i) for i in config.__openllm_model_ids__]) - if DEBUG: - try: - openllm.AutoLLM.for_model(m, llm_config=config) - except Exception as err: - failed_initialized.append((m, err)) - - ids_in_local_store = None - if show_available: - ids_in_local_store = [i for i in bentoml.models.list() if any(n in i.tag.name for n in converted)] - - if output == "pretty": - import tabulate - - tabulate.PRESERVE_WHITESPACE = True - - data: list[ - str | tuple[str, str, list[str], str, t.LiteralString, tuple[t.Literal["pt", "flax", "tf"], ...]] - ] = [] - for m, v in json_data.items(): - data.extend( - [ - ( - m, - v["url"], - v["model_id"], - v["installation"], - "✅" if v["requires_gpu"] else "❌", - v["runtime_impl"], - ) - ] - ) - column_widths = [ - int(COLUMNS / 6), - int(COLUMNS / 6), - int(COLUMNS / 3), - int(COLUMNS / 6), - int(COLUMNS / 6), - int(COLUMNS / 9), - ] - - if len(data) == 0 and len(failed_initialized) > 0: - _echo("Exception found while parsing models:\n", fg="yellow") - for m, err in failed_initialized: - _echo(f"- {m}: ", fg="yellow", nl=False) - _echo(traceback.print_exception(err, limit=3), fg="red") - sys.exit(1) - - table = tabulate.tabulate( - data, - tablefmt="fancy_grid", - headers=["LLM", "URL", "Models Id", "Installation", "GPU Only", "Runtime"], - maxcolwidths=column_widths, - ) - - formatted_table = "" - for line in table.split("\n"): - formatted_table += ( - "".join(f"{cell:{width}}" for cell, width in zip(line.split("\t"), column_widths)) + "\n" - ) - _echo(formatted_table, fg="white") - - if DEBUG and len(failed_initialized) > 0: - _echo("\nThe following models are supported but failed to initialize:\n") - for m, err in failed_initialized: - _echo(f"- {m}: ", fg="blue", nl=False) - _echo(err, fg="red") - - if show_available: - assert ids_in_local_store - _echo("The following models are available in local store:\n", fg="white") - for i in ids_in_local_store: - _echo(f"- {i}", fg="white") + if output == "pretty": + if not get_quiet_mode(): + _echo("\n" + OPENLLM_FIGLET, fg="white") + if not _previously_built: + _echo(f"Successfully built {bento}.", fg="green") else: - dumped: dict[str, t.Any] = json_data - if show_available: - assert ids_in_local_store - dumped["local"] = [bentoml_cattr.unstructure(i.tag) for i in ids_in_local_store] _echo( - orjson.dumps( - dumped, - option=orjson.OPT_INDENT_2, - ).decode(), - fg="white", + f"'{model_name}' already has a Bento built [{bento}]. 
To overwrite it pass '--overwrite'.", + fg="yellow", ) - sys.exit(0) + _echo( + "\nPossible next steps:\n\n" + + "* Push to BentoCloud with `bentoml push`:\n" + + f" $ bentoml push {bento.tag}\n" + + "* Containerize your Bento with `bentoml containerize`:\n" + + f" $ bentoml containerize {bento.tag}\n" + + " Tip: To enable additional BentoML feature for 'containerize', " + + "use '--enable-features=FEATURE[,FEATURE]' " + + "[see 'bentoml containerize -h' for more advanced usage]\n", + fg="blue", + ) + elif output == "json": + _echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode()) + else: + _echo(bento.tag) + return bento - @cli.command() - @click.option( - "-y", - "--yes", - "--assume-yes", - is_flag=True, - help="Skip confirmation when deleting a specific model", - ) - def prune(yes: bool): - """Remove all saved models locally.""" - available = [ - m - for t in map(inflection.dasherize, openllm.CONFIG_MAPPING.keys()) - for m in bentoml.models.list() - if t in m.tag.name - ] - for model in available: - if yes: - delete_confirmed = True - else: - delete_confirmed = click.confirm(f"delete model {model.tag}?") +@cli.command() +@output_option +@click.option( + "--show-available", + is_flag=True, + default=False, + help="Show available models in local store (mutually exclusive with '-o porcelain').", +) +def models(output: OutputLiteral, show_available: bool): + """List all supported models. - if delete_confirmed: - model_store.delete(model.tag) - click.echo(f"{model} deleted.") + NOTE: '--show-available' and '-o porcelain' are mutually exclusive. + """ + from ._llm import convert_transformers_model_name - @cli.command(name="query") - @click.option( - "--endpoint", - type=click.STRING, - help="OpenLLM Server endpoint, i.e: http://localhost:3000", - envvar="OPENLLM_ENDPOINT", - default="http://localhost:3000", - ) - @click.option("--timeout", type=click.INT, default=30, help="Default server timeout", show_default=True) - @click.option( - "--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True - ) - @output_option - @click.argument("query", type=click.STRING) - def query_( - query: str, - endpoint: str, - timeout: int, - server_type: t.Literal["http", "grpc"], - output: OutputLiteral, - ): - """Ask a LLM interactively, from a terminal. + models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys()) + if output == "porcelain": + if show_available: + raise click.BadOptionUsage( + "--show-available", "Cannot use '--show-available' with '-o porcelain' (mutually exclusive)." + ) + _echo("\n".join(models), fg="white") + else: + failed_initialized: list[tuple[str, Exception]] = [] - $ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?" 
- """ - if server_type == "grpc": - endpoint = re.sub(r"http://", "", endpoint) - client = ( - openllm.client.HTTPClient(endpoint, timeout=timeout) - if server_type == "http" - else openllm.client.GrpcClient(endpoint, timeout=timeout) - ) + json_data: dict[ + str, dict[t.Literal["model_id", "url", "installation", "requires_gpu", "runtime_impl"], t.Any] + ] = {} - if client.framework == "flax": - model = openllm.AutoFlaxLLM.for_model(client.model_name) - elif client.framework == "tf": - model = openllm.AutoTFLLM.for_model(client.model_name) - else: - model = openllm.AutoLLM.for_model(client.model_name) + # NOTE: Keep a sync list with ./tools/update-optional-dependencies.py + extras = ["chatglm", "falcon", "flan-t5", "starcoder"] - if output != "porcelain": - _echo(f"Processing query: {query}\n", fg="white") + converted: list[str] = [] + for m in models: + config = openllm.AutoConfig.for_model(m) + runtime_impl: tuple[t.Literal["pt", "flax", "tf"], ...] = tuple() + if config.__openllm_model_name__ in openllm.MODEL_MAPPING_NAMES: + runtime_impl += ("pt",) + if config.__openllm_model_name__ in openllm.MODEL_FLAX_MAPPING_NAMES: + runtime_impl += ("flax",) + if config.__openllm_model_name__ in openllm.MODEL_TF_MAPPING_NAMES: + runtime_impl += ("tf",) + json_data[m] = { + "model_id": config.__openllm_model_ids__, + "url": config.__openllm_url__, + "requires_gpu": config.__openllm_requires_gpu__, + "runtime_impl": runtime_impl, + "installation": "pip install openllm" if m not in extras else f'pip install "openllm[{m}]"', + } + converted.extend([convert_transformers_model_name(i) for i in config.__openllm_model_ids__]) + if DEBUG: + try: + openllm.AutoLLM.for_model(m, llm_config=config) + except Exception as err: + failed_initialized.append((m, err)) - res = client.query(query, return_raw_response=True) + ids_in_local_store = None + if show_available: + ids_in_local_store = [i for i in bentoml.models.list() if any(n in i.tag.name for n in converted)] if output == "pretty": - formatted = model.postprocess_generate(query, res["responses"]) - _echo("Responses: ", fg="white", nl=False) - _echo(formatted, fg="cyan") - elif output == "json": - _echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white") + import tabulate + + tabulate.PRESERVE_WHITESPACE = True + + data: list[ + str | tuple[str, str, list[str], str, t.LiteralString, tuple[t.Literal["pt", "flax", "tf"], ...]] + ] = [] + for m, v in json_data.items(): + data.extend( + [ + ( + m, + v["url"], + v["model_id"], + v["installation"], + "✅" if v["requires_gpu"] else "❌", + v["runtime_impl"], + ) + ] + ) + column_widths = [ + int(COLUMNS / 6), + int(COLUMNS / 6), + int(COLUMNS / 3), + int(COLUMNS / 6), + int(COLUMNS / 6), + int(COLUMNS / 9), + ] + + if len(data) == 0 and len(failed_initialized) > 0: + _echo("Exception found while parsing models:\n", fg="yellow") + for m, err in failed_initialized: + _echo(f"- {m}: ", fg="yellow", nl=False) + _echo(traceback.print_exception(err, limit=3), fg="red") + sys.exit(1) + + table = tabulate.tabulate( + data, + tablefmt="fancy_grid", + headers=["LLM", "URL", "Models Id", "Installation", "GPU Only", "Runtime"], + maxcolwidths=column_widths, + ) + + formatted_table = "" + for line in table.split("\n"): + formatted_table += ( + "".join(f"{cell:{width}}" for cell, width in zip(line.split("\t"), column_widths)) + "\n" + ) + _echo(formatted_table, fg="white") + + if DEBUG and len(failed_initialized) > 0: + _echo("\nThe following models are supported but failed to initialize:\n") + for m, err in 
failed_initialized: + _echo(f"- {m}: ", fg="blue", nl=False) + _echo(err, fg="red") + + if show_available: + assert ids_in_local_store + _echo("The following models are available in local store:\n", fg="white") + for i in ids_in_local_store: + _echo(f"- {i}", fg="white") else: - _echo(res["responses"], fg="white") + dumped: dict[str, t.Any] = json_data + if show_available: + assert ids_in_local_store + dumped["local"] = [bentoml_cattr.unstructure(i.tag) for i in ids_in_local_store] + _echo( + orjson.dumps( + dumped, + option=orjson.OPT_INDENT_2, + ).decode(), + fg="white", + ) - if t.TYPE_CHECKING: - assert build and models and start and start_grpc and query_ and prune - - if psutil.WINDOWS: - sys.stdout.reconfigure(encoding="utf-8") # type: ignore - - return cli + sys.exit(0) -cli = cli_factory() +@cli.command() +@click.option( + "-y", + "--yes", + "--assume-yes", + is_flag=True, + help="Skip confirmation when deleting a specific model", +) +@inject +def prune(yes: bool, model_store: ModelStore = Provide[BentoMLContainer.model_store]): + """Remove all saved models locally.""" + available = [ + m + for t in map(inflection.dasherize, openllm.CONFIG_MAPPING.keys()) + for m in bentoml.models.list() + if t in m.tag.name + ] + + for model in available: + if yes: + delete_confirmed = True + else: + delete_confirmed = click.confirm(f"delete model {model.tag}?") + + if delete_confirmed: + model_store.delete(model.tag) + click.echo(f"{model} deleted.") + + +@cli.command(name="query") +@click.option( + "--endpoint", + type=click.STRING, + help="OpenLLM Server endpoint, i.e: http://localhost:3000", + envvar="OPENLLM_ENDPOINT", + default="http://localhost:3000", +) +@click.option("--timeout", type=click.INT, default=30, help="Default server timeout", show_default=True) +@click.option( + "--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True +) +@output_option +@click.argument("query", type=click.STRING) +def query_( + query: str, + endpoint: str, + timeout: int, + server_type: t.Literal["http", "grpc"], + output: OutputLiteral, +): + """Ask a LLM interactively, from a terminal. + + $ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?" + """ + if server_type == "grpc": + endpoint = re.sub(r"http://", "", endpoint) + client = ( + openllm.client.HTTPClient(endpoint, timeout=timeout) + if server_type == "http" + else openllm.client.GrpcClient(endpoint, timeout=timeout) + ) + + if client.framework == "flax": + model = openllm.AutoFlaxLLM.for_model(client.model_name) + elif client.framework == "tf": + model = openllm.AutoTFLLM.for_model(client.model_name) + else: + model = openllm.AutoLLM.for_model(client.model_name) + + if output != "porcelain": + _echo(f"Processing query: {query}\n", fg="white") + + res = client.query(query, return_raw_response=True) + + if output == "pretty": + formatted = model.postprocess_generate(query, res["responses"]) + _echo("Responses: ", fg="white", nl=False) + _echo(formatted, fg="cyan") + elif output == "json": + _echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white") + else: + _echo(res["responses"], fg="white") @cli.command() @@ -1035,5 +1028,9 @@ def download_models(model_name: str, model_id: str | None, output: OutputLiteral return m +if psutil.WINDOWS: + sys.stdout.reconfigure(encoding="utf-8") # type: ignore + + if __name__ == "__main__": cli()
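
With the Click group now created at import time (note the module-level `configure_logging()` call and the `@click.group(...)` decorator above), downstream code can import and invoke the command directly rather than building it through `cli_factory()`. A minimal, illustrative sketch of that usage, assuming the import path `openllm.cli` taken from the diff header:

    # `cli` is defined at module scope by this patch; no factory call is needed.
    from openllm.cli import cli

    if __name__ == "__main__":
        cli()  # dispatches to `openllm start`, `openllm build`, `openllm models`, ...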