feat(cli): openllm query

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-06-05 22:39:04 -04:00
parent 41a6bd03a6
commit 1707beb7aa

View File

@@ -29,6 +29,7 @@ import typing as t
import bentoml
import click
import click_option_group as cog
import inflection
import orjson
import psutil
@@ -342,7 +343,7 @@ def start_model_command(
command_attrs.update(
{
"short_help": "(Disabled because there is no GPU available)",
"help": f"""{model_name} is currently not available to run on your
"help": f"""{model_name} is currently not available to run on your
local machine because it requires GPU for faster inference.""",
}
)
@@ -603,7 +604,7 @@ def cli_factory() -> click.Group:
_echo(bento.tag)
return bento
@cli.command()
@cli.command(aliases=["list"])
@output_option
def models(output: OutputLiteral):
"""List all supported models."""
@@ -664,7 +665,7 @@ def cli_factory() -> click.Group:
sys.exit(0)
@cli.command()
@cli.command(aliases=["save"])
@click.argument(
"model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()])
)
@@ -717,8 +718,63 @@ def cli_factory() -> click.Group:
_echo(m.tag)
return m
@cli.command(name="query", aliases=["run", "ask"])
@cog.optgroup.group(
"Host options", cls=cog.RequiredMutuallyExclusiveOptionGroup, help="default host for the running LLM server"
)
@cog.optgroup.option("--endpoint", type=click.STRING, help="LLM Server endpoint, i.e: http://12.323.2.1")
@cog.optgroup.option("--local", type=click.BOOL, help="Whether the server is running locally.")
@click.option("--port", type=click.INT, default=3000, help="LLM Server port", show_default=True)
@click.option("--timeout", type=click.INT, default=30, help="Default server timeout", show_default=True)
@click.option(
"--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True
)
@output_option
@click.argument("query", type=click.STRING)
def query(
query: str,
endpoint: str,
port: int,
timeout: int,
local: bool,
server_type: t.Literal["http", "grpc"],
output: OutputLiteral,
):
"""Ask a LLM interactively, from a terminal.
$ openllm query --endpoint http://12.323.2.1 "What is the meaning of life?"
"""
target_url = f"http://0.0.0.0:{port}" if local else f"{endpoint}:{port}"
client = (
openllm.client.HTTPClient(target_url, timeout=timeout)
if server_type == "http"
else openllm.client.GrpcClient(target_url, timeout=timeout)
)
if client.framework == "flax":
model = openllm.AutoFlaxLLM.for_model(client.model_name)
elif client.framework == "tf":
model = openllm.AutoTFLLM.for_model(client.model_name)
else:
model = openllm.AutoLLM.for_model(client.model_name)
if output != "porcelain":
_echo(f"Processing query: {query}\n", fg="white")
res = client.query(query, return_raw_response=True)
if output == "pretty":
formatted = model.postprocess_generate(query, res["responses"])
_echo("Responses: ", fg="white", nl=False)
_echo(formatted, fg="cyan")
elif output == "json":
_echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white")
else:
_echo(res["responses"], fg="white")
if t.TYPE_CHECKING:
assert download_models and bundle and models and version and start and start_grpc
assert download_models and bundle and models and version and start and start_grpc and query
if psutil.WINDOWS:
sys.stdout.reconfigure(encoding="utf-8") # type: ignore