Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-26 08:17:52 -05:00)
chore(cli): redirect download models into subcontext
Utilise click subcontext for nicer CLI interaction.

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
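For context, a minimal sketch of the click sub-context pattern this commit adopts: `@click.pass_context` hands a command its `Context`, and `ctx.invoke` runs another command's callback in-process instead of shelling out. The group and command names below are illustrative stand-ins, not OpenLLM's.

import click


@click.group()
def cli():
    """Toy CLI group (illustrative stand-in for the OpenLLM CLI)."""


@cli.command()
@click.option("--name", default="world")
def greet(name: str):
    """A command that another command wants to reuse."""
    click.echo(f"hello {name}")


@cli.command()
@click.pass_context
def wrapper(ctx: click.Context):
    # ctx.invoke calls greet's callback directly: no interpreter fork,
    # and exceptions propagate like ordinary Python calls.
    ctx.invoke(greet, name="openllm")


if __name__ == "__main__":
    cli()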
@@ -23,7 +23,7 @@ import inspect
 import logging
 import os
 import re
-import subprocess
+import contextlib
 import sys
 import time
 import traceback
@@ -163,7 +163,9 @@ Available model_id(s): {llm_config.__openllm_model_ids__} [default: {llm_config.
         show_envvar=True,
     )
     @workers_per_resource_option(cog.optgroup)
+    @click.pass_context
     def model_start(
+        ctx: click.Context,
         server_timeout: int | None,
         model_id: str | None,
         workers_per_resource: float | None,
@@ -187,19 +189,14 @@ Available model_id(s): {llm_config.__openllm_model_ids__} [default: {llm_config.
 
         # NOTE: We need to initialize llm here first to check if the model is already downloaded to
         # avoid deadlock before the subprocess forking.
-        subprocess.check_output(
-            [
-                sys.executable,
-                "-m",
-                "openllm",
-                "download-models",
-                model_name,
-                "--model-id",
-                llm.model_id,
-                "--output",
-                "porcelain",
-            ]
-        )
+        with open(os.devnull, "w") as devnull:
+            with contextlib.redirect_stderr(devnull), contextlib.redirect_stdout(devnull):
+                ctx.invoke(
+                    download_models,
+                    model_name=model_name,
+                    model_id=llm.model_id,
+                    output="porcelain",
+                )
 
         workers_per_resource = first_not_none(workers_per_resource, default=llm.config.__openllm_workers_per_resource__)
         server_timeout = first_not_none(server_timeout, default=llm.config.__openllm_timeout__)
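The hunk above also silences the invoked command by pointing both standard streams at os.devnull. A self-contained sketch of that idiom, using only the standard library (`noisy` is a hypothetical stand-in for the download command):

import contextlib
import os


def noisy() -> str:
    print("progress output we do not want during start-up")
    return "model-tag"


with open(os.devnull, "w") as devnull:
    # Both streams go to the null device inside the block;
    # the function's return value is still available afterwards.
    with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
        tag = noisy()

print(f"downloaded {tag} quietly")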
@@ -424,6 +421,9 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
             # Wrap into OpenLLM tracking
             wrapped = self.usage_tracking(wrapped, self, **attrs)
             # Wrap into exception handling
+            if "do_not_track" in attrs:
+                # We hit this branch when ctx.invoke the function
+                attrs.pop("do_not_track")
             wrapped = self.exception_handling(wrapped, self, **attrs)
 
             # move common parameters to end of the parameters list
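The guard exists because ctx.invoke re-enters the wrapped callback with the wrapper-only `do_not_track` key still sitting in `**attrs`. A hedged sketch of the underlying pattern, with illustrative names: consume the flag before forwarding the remaining attributes.

import typing as t


def usage_tracking(func: t.Callable[..., t.Any], **attrs: t.Any) -> t.Callable[..., t.Any]:
    # Pop the wrapper-only flag so downstream wrappers never see it.
    do_not_track = attrs.pop("do_not_track", False)

    def wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
        if not do_not_track:
            print("recording usage event")  # stand-in for real telemetry
        return func(*args, **kwargs)

    return wrapped


def run() -> str:
    return "ok"


run = usage_tracking(run, do_not_track=True)
print(run())  # prints "ok" with no tracking line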
@@ -544,14 +544,17 @@ class NargsOptions(cog.GroupedOption):
 
 
 def parse_device_callback(
-    _: click.Context, params: click.Parameter, value: tuple[str, ...] | t.Literal["all"] | None
+    _: click.Context, params: click.Parameter, value: tuple[str, ...] | tuple[t.Literal["all"]] | None
 ) -> t.Any:
     if value is None:
         return value
 
+    if not LazyType(TupleStrAny).isinstance(value):
+        raise RuntimeError(f"{params} only accept multiple values.")
+
     # NOTE: --device all is a special case
-    if isinstance(value, str):
-        if value != "all":
+    if len(value) == 1:
+        if value[0] != "all":
             raise RuntimeError(f"{params} parameter only accept 'all' as a string value.")
     import pynvml  # transitive dependencies of BentoML
 
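Why the `str` check became a length check: with `multiple=True`, click always delivers the callback a tuple, so `--device all` arrives as the 1-tuple `("all",)`, never the bare string. An illustrative callback in the same shape (the expansion here is hypothetical; the real code queries pynvml for visible GPUs):

import click


def parse_device(_: click.Context, param: click.Parameter, value: tuple[str, ...] | None):
    if value is None:
        return None
    if len(value) == 1 and value[0] == "all":
        return ("0", "1")  # hypothetical stand-in for every visible GPU
    return value


@click.command()
@click.option("--device", multiple=True, callback=parse_device)
def start(device: tuple[str, ...] | None):
    click.echo(f"devices: {device}")


if __name__ == "__main__":
    start()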
@@ -567,8 +570,6 @@ def parse_device_callback(
     except Exception:
         pass
 
-    if not LazyType(TupleStrAny).isinstance(value):
-        raise RuntimeError(f"{params} only accept multiple values.")
     parsed: tuple[str, ...] = tuple()
     for v in value:
         if v == ",":
@@ -890,55 +891,6 @@ def cli_factory() -> click.Group:
 
         sys.exit(0)
 
-    @cli.command()
-    @click.argument(
-        "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
-    )
-    @model_id_option(click)
-    @output_option
-    def download_models(model_name: str, model_id: str | None, output: OutputLiteral):
-        """Setup LLM interactively.
-
-        Note: This is useful for development and setup for fine-tune.
-        """
-        config = openllm.AutoConfig.for_model(model_name)
-        env = config.__openllm_env__.get_framework_env()
-        if env == "flax":
-            model = openllm.AutoFlaxLLM.for_model(model_name, model_id=model_id, llm_config=config)
-        elif env == "tf":
-            model = openllm.AutoTFLLM.for_model(model_name, model_id=model_id, llm_config=config)
-        else:
-            model = openllm.AutoLLM.for_model(model_name, model_id=model_id, llm_config=config)
-
-        if len(bentoml.models.list(model.tag)) == 0:
-            if output == "pretty":
-                _echo(f"{model.tag} does not exists yet!. Downloading...", fg="yellow", nl=True)
-            m = model.ensure_model_id_exists()
-            if output == "pretty":
-                _echo(f"Saved model: {m.tag}")
-            elif output == "json":
-                _echo(
-                    orjson.dumps(
-                        {"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
-                    ).decode()
-                )
-            else:
-                _echo(model.tag)
-        else:
-            m = bentoml.transformers.get(model.tag)
-            if output == "pretty":
-                _echo(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True, fg="yellow")
-            elif output == "json":
-                _echo(
-                    orjson.dumps(
-                        {"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
-                    ).decode(),
-                    fg="white",
-                )
-            else:
-                _echo(m.tag, fg="white")
-        return m
-
     @cli.command()
     @click.option(
         "-y",
@@ -1021,7 +973,7 @@ def cli_factory() -> click.Group:
             _echo(res["responses"], fg="white")
 
     if t.TYPE_CHECKING:
-        assert download_models and build and models and start and start_grpc and query_ and prune
+        assert build and models and start and start_grpc and query_ and prune
 
     if psutil.WINDOWS:
         sys.stdout.reconfigure(encoding="utf-8")  # type: ignore
@@ -1031,5 +983,57 @@
 
 cli = cli_factory()
 
+
+@cli.command()
+@click.argument(
+    "model_name",
+    type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
+)
+@model_id_option(click)
+@output_option
+def download_models(model_name: str, model_id: str | None, output: OutputLiteral):
+    """Setup LLM interactively.
+
+    Note: This is useful for development and setup for fine-tune.
+    """
+    config = openllm.AutoConfig.for_model(model_name)
+    env = config.__openllm_env__.get_framework_env()
+    if env == "flax":
+        model = openllm.AutoFlaxLLM.for_model(model_name, model_id=model_id, llm_config=config)
+    elif env == "tf":
+        model = openllm.AutoTFLLM.for_model(model_name, model_id=model_id, llm_config=config)
+    else:
+        model = openllm.AutoLLM.for_model(model_name, model_id=model_id, llm_config=config)
+
+    if len(bentoml.models.list(model.tag)) == 0:
+        if output == "pretty":
+            _echo(f"{model.tag} does not exists yet!. Downloading...", fg="yellow", nl=True)
+        m = model.ensure_model_id_exists()
+        if output == "pretty":
+            _echo(f"Saved model: {m.tag}")
+        elif output == "json":
+            _echo(
+                orjson.dumps(
+                    {"previously_setup": False, "framework": env, "tag": str(m.tag)}, option=orjson.OPT_INDENT_2
+                ).decode()
+            )
+        else:
+            _echo(model.tag)
+    else:
+        m = bentoml.transformers.get(model.tag)
+        if output == "pretty":
+            _echo(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True, fg="yellow")
+        elif output == "json":
+            _echo(
+                orjson.dumps(
+                    {"previously_setup": True, "framework": env, "model": str(m.tag)}, option=orjson.OPT_INDENT_2
+                ).decode(),
+                fg="white",
+            )
+        else:
+            _echo(m.tag, fg="white")
+    return m
+
+
 if __name__ == "__main__":
     cli()
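The relocation works because click lets commands attach to a group after the group object exists; defining download_models at module level also makes it importable, which is what lets model_start pass it to ctx.invoke. A minimal sketch of that registration pattern (names illustrative):

import click

cli = click.Group(name="demo")  # stand-in for cli = cli_factory()


@cli.command()
def hello():
    # Attached after the group was created, and importable by name.
    click.echo("registered post-construction")


if __name__ == "__main__":
    cli()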