diff --git a/src/openllm/cli.py b/src/openllm/cli.py
index 256b4141..51b46743 100644
--- a/src/openllm/cli.py
+++ b/src/openllm/cli.py
@@ -29,6 +29,7 @@ import typing as t
 
 import bentoml
 import click
+import click_option_group as cog
 import inflection
 import orjson
 import psutil
@@ -342,7 +343,7 @@ def start_model_command(
         command_attrs.update(
             {
                 "short_help": "(Disabled because there is no GPU available)",
-                "help": f"""{model_name} is currently not available to run on your 
+                "help": f"""{model_name} is currently not available to run on your
                 local machine because it requires GPU for faster inference.""",
             }
         )
@@ -603,7 +604,7 @@ def cli_factory() -> click.Group:
             _echo(bento.tag)
         return bento
 
-    @cli.command()
+    @cli.command(aliases=["list"])
     @output_option
     def models(output: OutputLiteral):
         """List all supported models."""
@@ -664,7 +665,7 @@ def cli_factory() -> click.Group:
 
         sys.exit(0)
 
-    @cli.command()
+    @cli.command(aliases=["save"])
     @click.argument(
         "model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
     )
@@ -717,8 +718,69 @@ def cli_factory() -> click.Group:
             _echo(m.tag)
         return m
 
+    @cli.command(name="query", aliases=["run", "ask"])
+    @cog.optgroup.group(
+        "Host options", cls=cog.RequiredMutuallyExclusiveOptionGroup, help="default host for the running LLM server"
+    )
+    @cog.optgroup.option("--endpoint", type=click.STRING, help="LLM Server endpoint, i.e: http://12.323.2.1")
+    @cog.optgroup.option("--local", type=click.BOOL, help="Whether the server is running locally.")
+    @click.option("--port", type=click.INT, default=3000, help="LLM Server port", show_default=True)
+    @click.option("--timeout", type=click.INT, default=30, help="Default server timeout", show_default=True)
+    @click.option(
+        "--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True
+    )
+    @output_option
+    @click.argument("query", type=click.STRING)
+    def query(
+        query: str,
+        endpoint: t.Optional[str],
+        port: int,
+        timeout: int,
+        local: t.Optional[bool],
+        server_type: t.Literal["http", "grpc"],
+        output: OutputLiteral,
+    ):
+        """Ask a LLM interactively, from a terminal.
+
+        $ openllm query --endpoint http://12.323.2.1 "What is the meaning of life?"
+        """
+        # `--local false` satisfies the required option group without providing
+        # an endpoint; fail with a usage error instead of querying "None:<port>".
+        if not local and not endpoint:
+            raise click.UsageError("'--endpoint' is required when '--local' is not enabled.")
+
+        target_url = f"http://0.0.0.0:{port}" if local else f"{endpoint}:{port}"
+
+        client = (
+            openllm.client.HTTPClient(target_url, timeout=timeout)
+            if server_type == "http"
+            else openllm.client.GrpcClient(target_url, timeout=timeout)
+        )
+
+        if output != "porcelain":
+            _echo(f"Processing query: {query}\n", fg="white")
+
+        res = client.query(query, return_raw_response=True)
+
+        if output == "pretty":
+            # Only the pretty printer needs a local LLM instance for
+            # post-processing, so it is not constructed for other output modes.
+            if client.framework == "flax":
+                model = openllm.AutoFlaxLLM.for_model(client.model_name)
+            elif client.framework == "tf":
+                model = openllm.AutoTFLLM.for_model(client.model_name)
+            else:
+                model = openllm.AutoLLM.for_model(client.model_name)
+            formatted = model.postprocess_generate(query, res["responses"])
+            _echo("Responses: ", fg="white", nl=False)
+            _echo(formatted, fg="cyan")
+        elif output == "json":
+            _echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white")
+        else:
+            _echo(res["responses"], fg="white")
+
     if t.TYPE_CHECKING:
-        assert download_models and bundle and models and version and start and start_grpc
+        assert download_models and bundle and models and version and start and start_grpc and query
 
     if psutil.WINDOWS:
         sys.stdout.reconfigure(encoding="utf-8")  # type: ignore