diff --git a/src/openllm/cli.py b/src/openllm/cli.py
index 22b4b272..b4c8913d 100644
--- a/src/openllm/cli.py
+++ b/src/openllm/cli.py
@@ -23,7 +23,7 @@ import inspect
 import logging
 import os
 import re
-import subprocess
+import contextlib
 import sys
 import time
 import traceback
@@ -163,7 +163,9 @@ Available model_id(s): {llm_config.__openllm_model_ids__} [default: {llm_config.
         show_envvar=True,
     )
     @workers_per_resource_option(cog.optgroup)
+    @click.pass_context
     def model_start(
+        ctx: click.Context,
         server_timeout: int | None,
         model_id: str | None,
         workers_per_resource: float | None,
@@ -187,19 +189,14 @@ Available model_id(s): {llm_config.__openllm_model_ids__} [default: {llm_config.
 
         # NOTE: We need to initialize llm here first to check if the model is already downloaded to
         # avoid deadlock before the subprocess forking.
-        subprocess.check_output(
-            [
-                sys.executable,
-                "-m",
-                "openllm",
-                "download-models",
-                model_name,
-                "--model-id",
-                llm.model_id,
-                "--output",
-                "porcelain",
-            ]
-        )
+        with open(os.devnull, "w") as devnull:
+            with contextlib.redirect_stderr(devnull), contextlib.redirect_stdout(devnull):
+                ctx.invoke(
+                    download_models,
+                    model_name=model_name,
+                    model_id=llm.model_id,
+                    output="porcelain",
+                )
 
         workers_per_resource = first_not_none(workers_per_resource, default=llm.config.__openllm_workers_per_resource__)
         server_timeout = first_not_none(server_timeout, default=llm.config.__openllm_timeout__)
@@ -424,6 +421,9 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
         # Wrap into OpenLLM tracking
         wrapped = self.usage_tracking(wrapped, self, **attrs)
         # Wrap into exception handling
+        if "do_not_track" in attrs:
+            # We hit this branch when the command is invoked via ctx.invoke
+            attrs.pop("do_not_track")
         wrapped = self.exception_handling(wrapped, self, **attrs)
 
         # move common parameters to end of the parameters list
@@ -544,14 +544,17 @@ class NargsOptions(cog.GroupedOption):
 
 
 def parse_device_callback(
-    _: click.Context, params: click.Parameter, value: tuple[str, ...] | t.Literal["all"] | None
+    _: click.Context, params: click.Parameter, value: tuple[str, ...] | tuple[t.Literal["all"]] | None
 ) -> t.Any:
     if value is None:
         return value
 
+    if not LazyType(TupleStrAny).isinstance(value):
+        raise RuntimeError(f"{params} only accepts multiple values.")
+
     # NOTE: --device all is a special case
-    if isinstance(value, str):
-        if value != "all":
+    if len(value) == 1:
+        if value[0] != "all":
             raise RuntimeError(f"{params} parameter only accept 'all' as a string value.")
 
         import pynvml  # transitive dependencies of BentoML
@@ -567,8 +570,6 @@ def parse_device_callback(
         except Exception:
             pass
 
-    if not LazyType(TupleStrAny).isinstance(value):
-        raise RuntimeError(f"{params} only accept multiple values.")
     parsed: tuple[str, ...] = tuple()
     for v in value:
         if v == ",":
@@ -890,55 +891,6 @@ def cli_factory() -> click.Group:
 
             sys.exit(0)
 
-    @cli.command()
-    @click.argument(
-        "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
-    )
-    @model_id_option(click)
-    @output_option
-    def download_models(model_name: str, model_id: str | None, output: OutputLiteral):
-        """Setup LLM interactively.
-
-        Note: This is useful for development and setup for fine-tune.
-        """
-        config = openllm.AutoConfig.for_model(model_name)
-        env = config.__openllm_env__.get_framework_env()
-        if env == "flax":
-            model = openllm.AutoFlaxLLM.for_model(model_name, model_id=model_id, llm_config=config)
-        elif env == "tf":
-            model = openllm.AutoTFLLM.for_model(model_name, model_id=model_id, llm_config=config)
-        else:
-            model = openllm.AutoLLM.for_model(model_name, model_id=model_id, llm_config=config)
-
-        if len(bentoml.models.list(model.tag)) == 0:
-            if output == "pretty":
-                _echo(f"{model.tag} does not exists yet!. Downloading...", fg="yellow", nl=True)
-            m = model.ensure_model_id_exists()
-            if output == "pretty":
-                _echo(f"Saved model: {m.tag}")
-            elif output == "json":
-                _echo(
-                    orjson.dumps(
-                        {"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
-                    ).decode()
-                )
-            else:
-                _echo(model.tag)
-        else:
-            m = bentoml.transformers.get(model.tag)
-            if output == "pretty":
-                _echo(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True, fg="yellow")
-            elif output == "json":
-                _echo(
-                    orjson.dumps(
-                        {"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
-                    ).decode(),
-                    fg="white",
-                )
-            else:
-                _echo(m.tag, fg="white")
-        return m
-
     @cli.command()
     @click.option(
         "-y",
@@ -1021,7 +973,7 @@ def cli_factory() -> click.Group:
             _echo(res["responses"], fg="white")
 
     if t.TYPE_CHECKING:
-        assert download_models and build and models and start and start_grpc and query_ and prune
+        assert build and models and start and start_grpc and query_ and prune
 
     if psutil.WINDOWS:
         sys.stdout.reconfigure(encoding="utf-8")  # type: ignore
@@ -1031,5 +983,57 @@ def cli_factory() -> click.Group:
 
 cli = cli_factory()
 
+
+@cli.command()
+@click.argument(
+    "model_name",
+    type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
+)
+@model_id_option(click)
+@output_option
+def download_models(model_name: str, model_id: str | None, output: OutputLiteral):
+    """Set up an LLM interactively.
+
+    Note: This is useful for development and fine-tuning setup.
+    """
+    config = openllm.AutoConfig.for_model(model_name)
+    env = config.__openllm_env__.get_framework_env()
+    if env == "flax":
+        model = openllm.AutoFlaxLLM.for_model(model_name, model_id=model_id, llm_config=config)
+    elif env == "tf":
+        model = openllm.AutoTFLLM.for_model(model_name, model_id=model_id, llm_config=config)
+    else:
+        model = openllm.AutoLLM.for_model(model_name, model_id=model_id, llm_config=config)
+
+    if len(bentoml.models.list(model.tag)) == 0:
+        if output == "pretty":
+            _echo(f"{model.tag} does not exist yet! Downloading...", fg="yellow", nl=True)
+        m = model.ensure_model_id_exists()
+        if output == "pretty":
+            _echo(f"Saved model: {m.tag}")
+        elif output == "json":
+            _echo(
+                orjson.dumps(
+                    {"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
+                ).decode()
+            )
+        else:
+            _echo(model.tag)
+    else:
+        m = bentoml.transformers.get(model.tag)
+        if output == "pretty":
+            _echo(f"{model_name} is already set up for framework '{env}': {str(m.tag)}", nl=True, fg="yellow")
+        elif output == "json":
+            _echo(
+                orjson.dumps(
+                    {"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
+                ).decode(),
+                fg="white",
+            )
+        else:
+            _echo(m.tag, fg="white")
+    return m
+
+
 if __name__ == "__main__":
     cli()