chore(cli): redirect download models into subcontext

utilise click's subcontext (ctx.invoke) for nicer CLI interaction
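
The gist: instead of shelling out to `python -m openllm download-models ...`, the start command now re-enters the download-models command in the same process through click's context. A minimal sketch of the pattern, using toy command names rather than openllm's real CLI:

import click

@click.group()
def cli() -> None:
    """Toy group, only to illustrate ctx.invoke."""

@cli.command()
@click.option("--name", default="world")
def fetch(name: str) -> None:
    """A sub-step that another command wants to reuse."""
    click.echo(f"fetching {name}")

@cli.command()
@click.pass_context
def start(ctx: click.Context) -> None:
    # Run another command's callback in-process instead of forking a
    # subprocess; parameters are forwarded as keyword arguments.
    ctx.invoke(fetch, name="demo")
    click.echo("started")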

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Author: aarnphm-ec2-dev
Date: 2023-06-14 11:44:39 +00:00
parent d7e92ae525
commit dfe71d7867


@@ -23,7 +23,7 @@ import inspect
 import logging
 import os
 import re
-import subprocess
+import contextlib
 import sys
 import time
 import traceback
@@ -163,7 +163,9 @@ Available model_id(s): {llm_config.__openllm_model_ids__} [default: {llm_config.
         show_envvar=True,
     )
     @workers_per_resource_option(cog.optgroup)
+    @click.pass_context
     def model_start(
+        ctx: click.Context,
        server_timeout: int | None,
        model_id: str | None,
        workers_per_resource: float | None,
@@ -187,19 +189,14 @@ Available model_id(s): {llm_config.__openllm_model_ids__} [default: {llm_config.
         # NOTE: We need to initialize llm here first to check if the model is already downloaded to
         # avoid deadlock before the subprocess forking.
-        subprocess.check_output(
-            [
-                sys.executable,
-                "-m",
-                "openllm",
-                "download-models",
-                model_name,
-                "--model-id",
-                llm.model_id,
-                "--output",
-                "porcelain",
-            ]
-        )
+        with open(os.devnull, "w") as devnull:
+            with contextlib.redirect_stderr(devnull), contextlib.redirect_stdout(devnull):
+                ctx.invoke(
+                    download_models,
+                    model_name=model_name,
+                    model_id=llm.model_id,
+                    output="porcelain",
+                )
         workers_per_resource = first_not_none(workers_per_resource, default=llm.config.__openllm_workers_per_resource__)
         server_timeout = first_not_none(server_timeout, default=llm.config.__openllm_timeout__)
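
The child command's porcelain output is irrelevant to the parent here, so both streams are routed to the null device for the duration of the in-process call. The same silencing pattern in isolation, a sketch not tied to openllm:

import contextlib
import os

def quietly(fn, *args, **kwargs):
    # Send stdout and stderr to os.devnull while fn runs, then restore them.
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            return fn(*args, **kwargs)

quietly(print, "this never reaches the terminal")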
@@ -424,6 +421,9 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
         # Wrap into OpenLLM tracking
         wrapped = self.usage_tracking(wrapped, self, **attrs)
         # Wrap into exception handling
+        if "do_not_track" in attrs:
+            # We hit this branch when the command is invoked via ctx.invoke
+            attrs.pop("do_not_track")
         wrapped = self.exception_handling(wrapped, self, **attrs)
         # move common parameters to end of the parameters list
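
Why the guard (as the inline comment notes): when a command is re-entered via ctx.invoke, a do_not_track entry may still be sitting in attrs, and popping it keeps the exception-handling wrapper from receiving an unexpected keyword. The check-then-pop is equivalent to the one-liner below (sketch only):

# pop with a default never raises KeyError, so no membership check is needed
attrs.pop("do_not_track", None)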
@@ -544,14 +544,17 @@ class NargsOptions(cog.GroupedOption):
 def parse_device_callback(
-    _: click.Context, params: click.Parameter, value: tuple[str, ...] | t.Literal["all"] | None
+    _: click.Context, params: click.Parameter, value: tuple[str, ...] | tuple[t.Literal["all"]] | None
 ) -> t.Any:
     if value is None:
         return value
+    if not LazyType(TupleStrAny).isinstance(value):
+        raise RuntimeError(f"{params} only accept multiple values.")
     # NOTE: --device all is a special case
-    if isinstance(value, str):
-        if value != "all":
+    if len(value) == 1:
+        if value[0] != "all":
             raise RuntimeError(f"{params} parameter only accept 'all' as a string value.")
     import pynvml  # transitive dependencies of BentoML
@@ -567,8 +570,6 @@ def parse_device_callback(
     except Exception:
         pass
-    if not LazyType(TupleStrAny).isinstance(value):
-        raise RuntimeError(f"{params} only accept multiple values.")
     parsed: tuple[str, ...] = tuple()
     for v in value:
         if v == ",":
@@ -890,55 +891,6 @@ def cli_factory() -> click.Group:
         sys.exit(0)
-    @cli.command()
-    @click.argument(
-        "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
-    )
-    @model_id_option(click)
-    @output_option
-    def download_models(model_name: str, model_id: str | None, output: OutputLiteral):
-        """Setup LLM interactively.
-        Note: This is useful for development and setup for fine-tune.
-        """
-        config = openllm.AutoConfig.for_model(model_name)
-        env = config.__openllm_env__.get_framework_env()
-        if env == "flax":
-            model = openllm.AutoFlaxLLM.for_model(model_name, model_id=model_id, llm_config=config)
-        elif env == "tf":
-            model = openllm.AutoTFLLM.for_model(model_name, model_id=model_id, llm_config=config)
-        else:
-            model = openllm.AutoLLM.for_model(model_name, model_id=model_id, llm_config=config)
-        if len(bentoml.models.list(model.tag)) == 0:
-            if output == "pretty":
-                _echo(f"{model.tag} does not exists yet!. Downloading...", fg="yellow", nl=True)
-            m = model.ensure_model_id_exists()
-            if output == "pretty":
-                _echo(f"Saved model: {m.tag}")
-            elif output == "json":
-                _echo(
-                    orjson.dumps(
-                        {"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
-                    ).decode()
-                )
-            else:
-                _echo(model.tag)
-        else:
-            m = bentoml.transformers.get(model.tag)
-            if output == "pretty":
-                _echo(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True, fg="yellow")
-            elif output == "json":
-                _echo(
-                    orjson.dumps(
-                        {"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
-                    ).decode(),
-                    fg="white",
-                )
-            else:
-                _echo(m.tag, fg="white")
-        return m
     @cli.command()
     @click.option(
         "-y",
@@ -1021,7 +973,7 @@ def cli_factory() -> click.Group:
             _echo(res["responses"], fg="white")
     if t.TYPE_CHECKING:
-        assert download_models and build and models and start and start_grpc and query_ and prune
+        assert build and models and start and start_grpc and query_ and prune
     if psutil.WINDOWS:
         sys.stdout.reconfigure(encoding="utf-8")  # type: ignore
@@ -1031,5 +983,57 @@ def cli_factory() -> click.Group:
 cli = cli_factory()
+@cli.command()
+@click.argument(
+    "model_name",
+    type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
+)
+@model_id_option(click)
+@output_option
+def download_models(model_name: str, model_id: str | None, output: OutputLiteral):
+    """Set up an LLM interactively.
+    Note: This is useful for development and for setting up fine-tuning.
+    """
+    config = openllm.AutoConfig.for_model(model_name)
+    env = config.__openllm_env__.get_framework_env()
+    if env == "flax":
+        model = openllm.AutoFlaxLLM.for_model(model_name, model_id=model_id, llm_config=config)
+    elif env == "tf":
+        model = openllm.AutoTFLLM.for_model(model_name, model_id=model_id, llm_config=config)
+    else:
+        model = openllm.AutoLLM.for_model(model_name, model_id=model_id, llm_config=config)
+    if len(bentoml.models.list(model.tag)) == 0:
+        if output == "pretty":
+            _echo(f"{model.tag} does not exist yet! Downloading...", fg="yellow", nl=True)
+        m = model.ensure_model_id_exists()
+        if output == "pretty":
+            _echo(f"Saved model: {m.tag}")
+        elif output == "json":
+            _echo(
+                orjson.dumps(
+                    {"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
+                ).decode()
+            )
+        else:
+            _echo(model.tag)
+    else:
+        m = bentoml.transformers.get(model.tag)
+        if output == "pretty":
+            _echo(f"{model_name} is already set up for framework '{env}': {m.tag}", nl=True, fg="yellow")
+        elif output == "json":
+            _echo(
+                orjson.dumps(
+                    {"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
+                ).decode(),
+                fg="white",
+            )
+        else:
+            _echo(m.tag, fg="white")
+    return m
 if __name__ == "__main__":
     cli()
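
Since download_models now hangs off the top-level group, other commands can reuse it in-process (as model_start does above) and it can still be exercised directly, e.g. through click's test runner. A sketch; the model name is illustrative:

from click.testing import CliRunner

# Hypothetical smoke test for the relocated command.
runner = CliRunner()
result = runner.invoke(cli, ["download-models", "flan-t5", "--output", "porcelain"])
assert result.exit_code == 0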