feat(client): simple implementation and streaming (#256)

This commit is contained in:
Aaron Pham
2023-10-12 17:21:54 -04:00
committed by GitHub
parent 59cee71b0c
commit 1539c3f7dc
42 changed files with 2581 additions and 997 deletions

View File

@@ -27,7 +27,6 @@ import itertools
import logging
import os
import platform
import re
import subprocess
import sys
import time
@@ -795,6 +794,7 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
--text "¡Este es un API muy agradable!"
```
'''
raise click.ClickException("'instruct' is currently disabled")
client = openllm.client.HTTPClient(endpoint, timeout=timeout)
try:
@@ -844,15 +844,16 @@ def embed_command(
termui.echo(gen_embed.embeddings, fg='white')
ctx.exit(0)
@cli.command()
@shared_client_options
@shared_client_options(output_value='porcelain')
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True)
@click.option('--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.')
@click.argument('prompt', type=click.STRING)
@click.option(
'--sampling-params', help='Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)', required=False, multiple=True, callback=opt_callback, metavar='ARG=VALUE[,ARG=VALUE]'
)
@click.pass_context
def query_command(
ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any
ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, stream: bool, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any
) -> None:
'''Ask a LLM interactively, from a terminal.
@@ -862,23 +863,30 @@ def query_command(
```
'''
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
if server_type == 'grpc': endpoint = re.sub(r'http://', '', endpoint)
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == 'http' else openllm.client.GrpcClient(endpoint, timeout=timeout)
if server_type == 'grpc': raise click.ClickException("'grpc' is currently disabled.")
# TODO: grpc support
client = openllm.client.HTTPClient(address=endpoint, timeout=timeout)
input_fg, generated_fg = 'magenta', 'cyan'
if output != 'porcelain':
termui.echo('==Input==\n', fg='white')
termui.echo(f'{prompt}', fg=input_fg)
res = client.query(prompt, return_response='raw', **{**client.configuration, **_memoized})
fn = client.generate_stream if stream else client.generate
res = fn(prompt, **{**client._config(), **_memoized})
if output == 'pretty':
response = client.config.postprocess_generate(prompt, res['responses'])
if isinstance(response, dict) and 'text' in response: response = response['text']
termui.echo('\n\n==Responses==\n', fg='white')
termui.echo(response, fg=generated_fg)
if stream:
for it in res: termui.echo(it.text, fg=generated_fg, nl=False)
else: termui.echo(res.responses[0], fg=generated_fg)
elif output == 'json':
termui.echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
termui.echo(res['responses'], fg='white')
if stream:
for it in res: termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
else: termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
else: # noqa: PLR5501
if stream:
for it in res: termui.echo(it.text, fg=generated_fg, nl=False)
else: termui.echo(res.responses, fg='white')
ctx.exit(0)
@cli.group(cls=Extensions, hidden=True, name='extension')
def extension_command() -> None:
'''Extension for OpenLLM CLI.'''