feat(client): simple implementation and streaming (#256)

This commit is contained in:
Aaron Pham
2023-10-12 17:21:54 -04:00
committed by GitHub
parent 59cee71b0c
commit 1539c3f7dc
42 changed files with 2581 additions and 997 deletions

View File

@@ -39,7 +39,6 @@ from openllm_core._typing_compat import T
from openllm_core._typing_compat import TupleAny
from openllm_core._typing_compat import overload
from openllm_core.prompts import PromptTemplate
from openllm_core.prompts import process_prompt
from openllm_core.utils import DEBUG
from openllm_core.utils import MYPY
from openllm_core.utils import EnvVarMixin
@@ -620,7 +619,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
# set default tokenizer kwargs
tokenizer_kwds.update({'padding_side': 'left', 'truncation_side': 'left'})
# parsing tokenizer and model kwargs, as the hierachy is param pass > default
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
normalized_model_kwds, normalized_tokenizer_kwds = normalize_attrs_to_model_tokenizer_pair(**attrs)
# NOTE: Save the args and kwargs for latter load
self.__attrs_init__(llm_config, quantization_config, _quantize, model_id, args, {
@@ -1211,6 +1210,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
if adapter_name is not None: __self.set_adapter(adapter_name)
request_id: str | None = attrs.pop('request_id', None)
if request_id is None: raise ValueError('request_id must not be None.')
prompt, *_ = self.sanitize_parameters(prompt, **attrs)
if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
if stop_token_ids is None: stop_token_ids = []
stop_token_ids.append(self.tokenizer.eos_token_id)
@@ -1237,7 +1238,6 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
async def vllm_generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[str, None]:
# TODO: System prompt support
pre = 0
prompt = process_prompt(prompt, None, False)
echo = attrs.pop('echo', False)
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
@@ -1247,6 +1247,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
if adapter_name is not None: __self.set_adapter(adapter_name)
request_id: str | None = attrs.pop('request_id', None)
if request_id is None: raise ValueError('request_id must not be None.')
prompt, *_ = self.sanitize_parameters(prompt, **attrs)
if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
if stop_token_ids is None: stop_token_ids = []
stop_token_ids.append(self.tokenizer.eos_token_id)
@@ -1342,7 +1344,9 @@ def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]:
'__repr_args__': _wrapped_repr_args,
'supports_embeddings': self['supports_embeddings'],
'supports_hf_agent': self['supports_generate_one'],
'has_adapters': self._adapters_mapping is not None
'has_adapters': self._adapters_mapping is not None,
'prompt_template': self._prompt_template.to_string() if self._prompt_template else self.config.default_prompt_template,
'system_message': self._system_message if self._system_message else self.config.default_system_message,
}))
__all__ = ['LLMRunner', 'LLMRunnable', 'Runner', 'LLM', 'llm_runner_class', 'llm_runnable_class', 'EmbeddingsOutput']

View File

@@ -32,13 +32,7 @@ model = svars.model
model_id = svars.model_id
adapter_map = svars.adapter_map
llm_config = openllm.AutoConfig.for_model(model)
runner = openllm.Runner(
model,
llm_config=llm_config,
model_id=model_id,
ensure_available=False,
adapter_map=orjson.loads(adapter_map)
)
runner = openllm.Runner(model, llm_config=llm_config, model_id=model_id, ensure_available=False, adapter_map=orjson.loads(adapter_map))
generic_embedding_runner = bentoml.Runner(openllm.GenericEmbeddingRunnable, # XXX: remove arg-type once bentoml.Runner is correct set with type
name='llm-generic-embedding',
scheduling_strategy=openllm_core.CascadingResourceStrategy,
@@ -189,9 +183,11 @@ async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context)
'timeout': 3600,
'model_name': llm_config['model_name'],
'backend': runner.backend,
'configuration': '',
'configuration': llm_config.model_dump(flatten=True),
'supports_embeddings': runner.supports_embeddings,
'supports_hf_agent': runner.supports_hf_agent
'supports_hf_agent': runner.supports_hf_agent,
'prompt_template': runner.prompt_template,
'system_message': runner.system_message,
}))
def metadata_v1(_: str) -> openllm.MetadataOutput:
return openllm.MetadataOutput(timeout=llm_config['timeout'],
@@ -200,7 +196,10 @@ def metadata_v1(_: str) -> openllm.MetadataOutput:
model_id=runner.llm.model_id,
configuration=llm_config.model_dump_json().decode(),
supports_embeddings=runner.supports_embeddings,
supports_hf_agent=runner.supports_hf_agent)
supports_hf_agent=runner.supports_hf_agent,
prompt_template=runner.prompt_template,
system_message=runner.system_message,
)
@svc.api(route='/v1/embeddings',
input=bentoml.io.JSON.from_sample(['Hey Jude, welcome to the jungle!', 'What is the meaning of life?']),

View File

@@ -27,7 +27,6 @@ import itertools
import logging
import os
import platform
import re
import subprocess
import sys
import time
@@ -795,6 +794,7 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
--text "¡Este es un API muy agradable!"
```
'''
raise click.ClickException("'instruct' is currently disabled")
client = openllm.client.HTTPClient(endpoint, timeout=timeout)
try:
@@ -844,15 +844,16 @@ def embed_command(
termui.echo(gen_embed.embeddings, fg='white')
ctx.exit(0)
@cli.command()
@shared_client_options
@shared_client_options(output_value='porcelain')
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True)
@click.option('--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.')
@click.argument('prompt', type=click.STRING)
@click.option(
'--sampling-params', help='Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)', required=False, multiple=True, callback=opt_callback, metavar='ARG=VALUE[,ARG=VALUE]'
)
@click.pass_context
def query_command(
ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any
ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, stream: bool, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any
) -> None:
'''Ask a LLM interactively, from a terminal.
@@ -862,23 +863,30 @@ def query_command(
```
'''
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
if server_type == 'grpc': endpoint = re.sub(r'http://', '', endpoint)
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == 'http' else openllm.client.GrpcClient(endpoint, timeout=timeout)
if server_type == 'grpc': raise click.ClickException("'grpc' is currently disabled.")
# TODO: grpc support
client = openllm.client.HTTPClient(address=endpoint, timeout=timeout)
input_fg, generated_fg = 'magenta', 'cyan'
if output != 'porcelain':
termui.echo('==Input==\n', fg='white')
termui.echo(f'{prompt}', fg=input_fg)
res = client.query(prompt, return_response='raw', **{**client.configuration, **_memoized})
fn = client.generate_stream if stream else client.generate
res = fn(prompt, **{**client._config(), **_memoized})
if output == 'pretty':
response = client.config.postprocess_generate(prompt, res['responses'])
if isinstance(response, dict) and 'text' in response: response = response['text']
termui.echo('\n\n==Responses==\n', fg='white')
termui.echo(response, fg=generated_fg)
if stream:
for it in res: termui.echo(it.text, fg=generated_fg, nl=False)
else: termui.echo(res.responses[0], fg=generated_fg)
elif output == 'json':
termui.echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
termui.echo(res['responses'], fg='white')
if stream:
for it in res: termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
else: termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
else: # noqa: PLR5501
if stream:
for it in res: termui.echo(it.text, fg=generated_fg, nl=False)
else: termui.echo(res.responses, fg='white')
ctx.exit(0)
@cli.group(cls=Extensions, hidden=True, name='extension')
def extension_command() -> None:
'''Extension for OpenLLM CLI.'''

View File

@@ -16,12 +16,10 @@ import typing as t
import openllm_client
if t.TYPE_CHECKING:
from openllm_client import AsyncGrpcClient as AsyncGrpcClient
from openllm_client import AsyncHTTPClient as AsyncHTTPClient
from openllm_client import BaseAsyncClient as BaseAsyncClient
from openllm_client import BaseClient as BaseClient
from openllm_client import GrpcClient as GrpcClient
from openllm_client import HTTPClient as HTTPClient
# from openllm_client import AsyncGrpcClient as AsyncGrpcClient
# from openllm_client import GrpcClient as GrpcClient
def __dir__() -> t.Sequence[str]:
return sorted(dir(openllm_client))