fix(cli): simplify register code for start

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-05-26 01:44:33 -07:00
parent 4127961c5c
commit 85252f13c4

View File

@@ -224,20 +224,6 @@ class OpenLLMCommandGroup(click.Group):
raise click.exceptions.UsageError(error_msg, e.ctx)
class StartCommandGroup(OpenLLMCommandGroup):
"""A `start` factory that generate each models as its own click Command. See 'openllm start --help' for more details."""
def __init__(self, *args: t.Any, **kwargs: t.Any):
_serve_grpc = kwargs.pop("_serve_grpc", False)
super(StartCommandGroup, self).__init__(*args, **kwargs)
for name in openllm.CONFIG_MAPPING.keys():
if name not in self.commands:
self.add_command(start_model_command(name, self, _serve_grpc=_serve_grpc))
def list_commands(self, ctx: click.Context):
return openllm.CONFIG_MAPPING.keys()
def parse_serve_args(serve_grpc: bool) -> t.Callable[[F[P]], F[P]]:
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`"""
from bentoml_cli.cli import cli
@@ -274,8 +260,7 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[F[P]], F[P]]:
def start_model_command(
model_name: str,
factory: OpenLLMCommandGroup,
pretrained: str | None = None,
factory: click.Group,
_context_settings: dict[str, t.Any] | None = None,
_serve_grpc: bool = False,
) -> click.Command:
@@ -309,7 +294,13 @@ def start_model_command(
except openllm.exceptions.GpuNotAvailableError:
# NOTE: The model requires GPU, therefore we will return a dummy command
model_command_decr.update({"short_help": "(Disabled because there is no GPU available)"})
return factory.command(**model_command_decr)(lambda: None)
@factory.command(**model_command_decr)
def noop():
click.secho("No GPU available, therefore this command is disabled", fg="red")
return
return noop
@factory.command(**model_command_decr)
@config.to_click_options
@@ -431,7 +422,7 @@ def cli():
"""
@cli.group(cls=openllm.cli.StartCommandGroup, context_settings=_CONTEXT_SETTINGS, aliases=["start-http"], name="start")
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, aliases=["start-http"], name="start")
def start_cli():
"""
Start any LLM as a REST server.
@@ -440,7 +431,11 @@ def start_cli():
"""
@cli.group(cls=openllm.cli.StartCommandGroup, context_settings=_CONTEXT_SETTINGS, _serve_grpc=True, name="start-grpc")
for name in openllm.CONFIG_MAPPING:
start_cli.add_command(start_model_command(name, start_cli, _context_settings=_CONTEXT_SETTINGS))
@cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc")
def start_grpc_cli():
"""
Start any LLM as a gRPC server.
@@ -449,6 +444,12 @@ def start_grpc_cli():
"""
for name in openllm.CONFIG_MAPPING:
start_grpc_cli.add_command(
start_model_command(name, start_grpc_cli, _context_settings=_CONTEXT_SETTINGS, _serve_grpc=True)
)
@cli.command(aliases=["build"])
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]))
@click.option("--pretrained", default=None, help="Given pretrained model name for the given model name [Optional].")