mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-05-19 14:16:22 -04:00
feat(client): simple implementation and streaming (#256)
This commit is contained in:
@@ -39,7 +39,6 @@ from openllm_core._typing_compat import T
|
||||
from openllm_core._typing_compat import TupleAny
|
||||
from openllm_core._typing_compat import overload
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.prompts import process_prompt
|
||||
from openllm_core.utils import DEBUG
|
||||
from openllm_core.utils import MYPY
|
||||
from openllm_core.utils import EnvVarMixin
|
||||
@@ -620,7 +619,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
# set default tokenizer kwargs
|
||||
tokenizer_kwds.update({'padding_side': 'left', 'truncation_side': 'left'})
|
||||
|
||||
# parsing tokenizer and model kwargs, as the hierachy is param pass > default
|
||||
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
|
||||
normalized_model_kwds, normalized_tokenizer_kwds = normalize_attrs_to_model_tokenizer_pair(**attrs)
|
||||
# NOTE: Save the args and kwargs for latter load
|
||||
self.__attrs_init__(llm_config, quantization_config, _quantize, model_id, args, {
|
||||
@@ -1211,6 +1210,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
request_id: str | None = attrs.pop('request_id', None)
|
||||
if request_id is None: raise ValueError('request_id must not be None.')
|
||||
prompt, *_ = self.sanitize_parameters(prompt, **attrs)
|
||||
if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
|
||||
|
||||
if stop_token_ids is None: stop_token_ids = []
|
||||
stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
@@ -1237,7 +1238,6 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
|
||||
async def vllm_generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[str, None]:
|
||||
# TODO: System prompt support
|
||||
pre = 0
|
||||
prompt = process_prompt(prompt, None, False)
|
||||
echo = attrs.pop('echo', False)
|
||||
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
|
||||
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
|
||||
@@ -1247,6 +1247,8 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
request_id: str | None = attrs.pop('request_id', None)
|
||||
if request_id is None: raise ValueError('request_id must not be None.')
|
||||
prompt, *_ = self.sanitize_parameters(prompt, **attrs)
|
||||
if openllm_core.utils.DEBUG: logger.debug('Prompt:\n%s', prompt)
|
||||
|
||||
if stop_token_ids is None: stop_token_ids = []
|
||||
stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
@@ -1342,7 +1344,9 @@ def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]:
|
||||
'__repr_args__': _wrapped_repr_args,
|
||||
'supports_embeddings': self['supports_embeddings'],
|
||||
'supports_hf_agent': self['supports_generate_one'],
|
||||
'has_adapters': self._adapters_mapping is not None
|
||||
'has_adapters': self._adapters_mapping is not None,
|
||||
'prompt_template': self._prompt_template.to_string() if self._prompt_template else self.config.default_prompt_template,
|
||||
'system_message': self._system_message if self._system_message else self.config.default_system_message,
|
||||
}))
|
||||
|
||||
__all__ = ['LLMRunner', 'LLMRunnable', 'Runner', 'LLM', 'llm_runner_class', 'llm_runnable_class', 'EmbeddingsOutput']
|
||||
|
||||
@@ -32,13 +32,7 @@ model = svars.model
|
||||
model_id = svars.model_id
|
||||
adapter_map = svars.adapter_map
|
||||
llm_config = openllm.AutoConfig.for_model(model)
|
||||
runner = openllm.Runner(
|
||||
model,
|
||||
llm_config=llm_config,
|
||||
model_id=model_id,
|
||||
ensure_available=False,
|
||||
adapter_map=orjson.loads(adapter_map)
|
||||
)
|
||||
runner = openllm.Runner(model, llm_config=llm_config, model_id=model_id, ensure_available=False, adapter_map=orjson.loads(adapter_map))
|
||||
generic_embedding_runner = bentoml.Runner(openllm.GenericEmbeddingRunnable, # XXX: remove arg-type once bentoml.Runner is correct set with type
|
||||
name='llm-generic-embedding',
|
||||
scheduling_strategy=openllm_core.CascadingResourceStrategy,
|
||||
@@ -189,9 +183,11 @@ async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context)
|
||||
'timeout': 3600,
|
||||
'model_name': llm_config['model_name'],
|
||||
'backend': runner.backend,
|
||||
'configuration': '',
|
||||
'configuration': llm_config.model_dump(flatten=True),
|
||||
'supports_embeddings': runner.supports_embeddings,
|
||||
'supports_hf_agent': runner.supports_hf_agent
|
||||
'supports_hf_agent': runner.supports_hf_agent,
|
||||
'prompt_template': runner.prompt_template,
|
||||
'system_message': runner.system_message,
|
||||
}))
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return openllm.MetadataOutput(timeout=llm_config['timeout'],
|
||||
@@ -200,7 +196,10 @@ def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
model_id=runner.llm.model_id,
|
||||
configuration=llm_config.model_dump_json().decode(),
|
||||
supports_embeddings=runner.supports_embeddings,
|
||||
supports_hf_agent=runner.supports_hf_agent)
|
||||
supports_hf_agent=runner.supports_hf_agent,
|
||||
prompt_template=runner.prompt_template,
|
||||
system_message=runner.system_message,
|
||||
)
|
||||
|
||||
@svc.api(route='/v1/embeddings',
|
||||
input=bentoml.io.JSON.from_sample(['Hey Jude, welcome to the jungle!', 'What is the meaning of life?']),
|
||||
|
||||
@@ -27,7 +27,6 @@ import itertools
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
@@ -795,6 +794,7 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
|
||||
--text "¡Este es un API muy agradable!"
|
||||
```
|
||||
'''
|
||||
raise click.ClickException("'instruct' is currently disabled")
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout)
|
||||
|
||||
try:
|
||||
@@ -844,15 +844,16 @@ def embed_command(
|
||||
termui.echo(gen_embed.embeddings, fg='white')
|
||||
ctx.exit(0)
|
||||
@cli.command()
|
||||
@shared_client_options
|
||||
@shared_client_options(output_value='porcelain')
|
||||
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True)
|
||||
@click.option('--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.')
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option(
|
||||
'--sampling-params', help='Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)', required=False, multiple=True, callback=opt_callback, metavar='ARG=VALUE[,ARG=VALUE]'
|
||||
)
|
||||
@click.pass_context
|
||||
def query_command(
|
||||
ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any
|
||||
ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, stream: bool, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any
|
||||
) -> None:
|
||||
'''Ask a LLM interactively, from a terminal.
|
||||
|
||||
@@ -862,23 +863,30 @@ def query_command(
|
||||
```
|
||||
'''
|
||||
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
|
||||
if server_type == 'grpc': endpoint = re.sub(r'http://', '', endpoint)
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == 'http' else openllm.client.GrpcClient(endpoint, timeout=timeout)
|
||||
if server_type == 'grpc': raise click.ClickException("'grpc' is currently disabled.")
|
||||
# TODO: grpc support
|
||||
client = openllm.client.HTTPClient(address=endpoint, timeout=timeout)
|
||||
input_fg, generated_fg = 'magenta', 'cyan'
|
||||
if output != 'porcelain':
|
||||
termui.echo('==Input==\n', fg='white')
|
||||
termui.echo(f'{prompt}', fg=input_fg)
|
||||
res = client.query(prompt, return_response='raw', **{**client.configuration, **_memoized})
|
||||
fn = client.generate_stream if stream else client.generate
|
||||
res = fn(prompt, **{**client._config(), **_memoized})
|
||||
if output == 'pretty':
|
||||
response = client.config.postprocess_generate(prompt, res['responses'])
|
||||
if isinstance(response, dict) and 'text' in response: response = response['text']
|
||||
termui.echo('\n\n==Responses==\n', fg='white')
|
||||
termui.echo(response, fg=generated_fg)
|
||||
if stream:
|
||||
for it in res: termui.echo(it.text, fg=generated_fg, nl=False)
|
||||
else: termui.echo(res.responses[0], fg=generated_fg)
|
||||
elif output == 'json':
|
||||
termui.echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
termui.echo(res['responses'], fg='white')
|
||||
if stream:
|
||||
for it in res: termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else: termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else: # noqa: PLR5501
|
||||
if stream:
|
||||
for it in res: termui.echo(it.text, fg=generated_fg, nl=False)
|
||||
else: termui.echo(res.responses, fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
@cli.group(cls=Extensions, hidden=True, name='extension')
|
||||
def extension_command() -> None:
|
||||
'''Extension for OpenLLM CLI.'''
|
||||
|
||||
@@ -16,12 +16,10 @@ import typing as t
|
||||
import openllm_client
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_client import AsyncGrpcClient as AsyncGrpcClient
|
||||
from openllm_client import AsyncHTTPClient as AsyncHTTPClient
|
||||
from openllm_client import BaseAsyncClient as BaseAsyncClient
|
||||
from openllm_client import BaseClient as BaseClient
|
||||
from openllm_client import GrpcClient as GrpcClient
|
||||
from openllm_client import HTTPClient as HTTPClient
|
||||
# from openllm_client import AsyncGrpcClient as AsyncGrpcClient
|
||||
# from openllm_client import GrpcClient as GrpcClient
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
return sorted(dir(openllm_client))
|
||||
|
||||
Reference in New Issue
Block a user