perf: unify LLM interface (#518)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-06 20:39:43 -05:00
committed by GitHub
parent f2639879af
commit e2029c934b
136 changed files with 9646 additions and 11244 deletions

View File

@@ -11,6 +11,7 @@ import inflection
import orjson
from bentoml_cli.utils import BentoMLCommandGroup
from click import ClickException
from click import shell_completion as sc
from click.shell_completion import CompletionItem
@@ -28,6 +29,9 @@ from openllm_core._typing_compat import LiteralString
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import get_literal_args
from openllm_core.utils import DEBUG
from openllm_core.utils import check_bool_env
from openllm_core.utils import first_not_none
from openllm_core.utils import is_vllm_available
from . import termui
@@ -62,7 +66,6 @@ def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_res
_bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
else:
_bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
_bentoml_config_options_opts.append(f'runners."llm-generic-embedding".resources.cpu={openllm.get_resource({"cpu":"system"},"cpu")}')
if cors:
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
_bentoml_config_options_opts.extend([f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])])
@@ -84,7 +87,8 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
except FileNotFoundError:
pass
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
if len(adapter_name) == 0: raise ClickException(f'Adapter name is required for {adapter_id}')
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0]
return None
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
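This callback change makes the adapter name mandatory whenever `--adapter-id` is given, instead of silently defaulting to `None`. A minimal sketch of the new behaviour, assuming the same `path:name` syntax (the `parse_adapter` helper below is illustrative, not part of the codebase):

```
from click import ClickException

def parse_adapter(value: str) -> tuple[str, str]:
  # Split '/path/to/adapter:name' into its id and name parts.
  adapter_id, *adapter_name = value.rsplit(':', maxsplit=1)
  if len(adapter_name) == 0:
    raise ClickException(f'Adapter name is required for {adapter_id}')
  return adapter_id, adapter_name[0]

# parse_adapter('/path/to/lora')         -> raises ClickException
# parse_adapter('/path/to/lora:my-lora') -> ('/path/to/lora', 'my-lora')
```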
@@ -117,24 +121,23 @@ Available official model_id(s): [default: {llm_config['default_id']}]
@start_decorator(llm_config, serve_grpc=_serve_grpc)
@click.pass_context
def start_cmd(ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, system_message: str | None, prompt_template_file: t.IO[t.Any] | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend,
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
if _serialisation == 'safetensors' and quantize is not None and openllm_core.utils.check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend | None,
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
if _serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.echo(
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.",
fg='yellow')
termui.echo(f"Make sure to check out '{model_id}' repository to see if the weights is in '{_serialisation}' format if unsure.")
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
adapter_map: dict[str, str] | None = attrs.pop(_adapter_mapping_key, None)
config, server_attrs = llm_config.model_validate_click(**attrs)
server_timeout = openllm.utils.first_not_none(server_timeout, default=config['timeout'])
server_timeout = first_not_none(server_timeout, default=config['timeout'])
server_attrs.update({'working_dir': os.path.dirname(os.path.dirname(__file__)), 'timeout': server_timeout})
if _serve_grpc: server_attrs['grpc_protocol_version'] = 'v1'
# NOTE: currently, there are no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development')
server_attrs.setdefault('production', not development)
wpr = openllm.utils.first_not_none(workers_per_resource, default=config['workers_per_resource'])
wpr = first_not_none(workers_per_resource, default=config['workers_per_resource'])
if isinstance(wpr, str):
if wpr == 'round_robin': wpr = 1.0
@@ -151,7 +154,10 @@ Available official model_id(s): [default: {llm_config['default_id']}]
wpr = float(wpr)
# Create a new model env to work with the envvar during CLI invocation
env = openllm.utils.EnvVarMixin(config['model_name'], backend, model_id=model_id or config['default_id'], quantize=quantize)
env = openllm.utils.EnvVarMixin(config['model_name'],
backend=openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'),
model_id=model_id or config['default_id'],
quantize=quantize)
requirements = llm_config['requirements']
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
@@ -176,16 +182,16 @@ Available official model_id(s): [default: {llm_config['default_id']}]
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = openllm.utils.infer_auto_class(env['backend_value']).for_model(model,
model_id=start_env[env.model_id],
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=config,
ensure_available=True,
adapter_map=adapter_map,
quantize=env['quantize_value'],
serialisation=_serialisation)
llm = openllm.LLM[t.Any, t.Any](model_id=start_env[env.model_id],
revision=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=config,
backend=env['backend_value'],
adapter_map=adapter_map,
quantize=env['quantize_value'],
serialisation=_serialisation)
llm.save_pretrained() # ensure_available = True
start_env.update({env.config: llm.config.model_dump_json().decode()})
server = bentoml.GrpcServer('_service:svc', **server_attrs) if _serve_grpc else bentoml.HTTPServer('_service:svc', **server_attrs)
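This is the core of the refactor: the backend-specific `infer_auto_class(...).for_model(...)` construction is replaced by a single generic `openllm.LLM` constructor, with availability handled by an explicit `save_pretrained()` call. A minimal sketch of the unified entry point, with illustrative values (the model id and keyword values below are assumptions, not defaults):

```
import typing as t
import openllm

llm = openllm.LLM[t.Any, t.Any](
  model_id='facebook/opt-1.3b',   # illustrative model id
  backend='pt',                   # the CLI resolves this to 'vllm' automatically when vLLM is installed
  quantize=None,
  serialisation='safetensors',
)
llm.save_pretrained()  # replaces the old ensure_available=True behaviour
```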
@@ -382,8 +388,8 @@ def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
# NOTE: LiteralBackend needs to drop the last two items, as ggml and mlc are WIP
# XXX: remove the check for __args__ once we have ggml and mlc support
return cli_option('--backend',
type=click.Choice(get_literal_args(LiteralBackend)[:-2]),
default='pt',
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
default=None,
envvar='OPENLLM_BACKEND',
show_envvar=True,
help='The implementation for saving this LLM.',
@@ -396,7 +402,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
return cli_option('--quantise',
'--quantize',
'quantize',
type=click.Choice(['int8', 'int4', 'gptq']),
type=click.Choice(get_literal_args(LiteralQuantise)),
default=None,
envvar='OPENLLM_QUANTIZE',
show_envvar=True,
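The `--quantize` choices are now derived from the `LiteralQuantise` type via `get_literal_args` rather than a hard-coded list, so new quantisation methods only need to be added in one place. Roughly, assuming `get_literal_args` simply unwraps a `typing.Literal` (the members shown match the previous hard-coded choices):

```
import typing as t

LiteralQuantise = t.Literal['int8', 'int4', 'gptq']  # members from the old hard-coded Choice

def get_literal_args(tp: t.Any) -> tuple[str, ...]:
  # Assumed behaviour: return the Literal's member values.
  return t.get_args(tp)

print(get_literal_args(LiteralQuantise))  # ('int8', 'int4', 'gptq')
```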

View File

@@ -16,6 +16,7 @@ import openllm_core
from bentoml._internal.configuration.containers import BentoMLContainer
from openllm.exceptions import OpenLLMException
from openllm_core.utils import is_vllm_available
from . import termui
from ._factory import start_command_factory
@@ -88,9 +89,7 @@ def _start(model_name: str,
"""
from .entrypoint import start_command
from .entrypoint import start_grpc_command
llm_config = openllm.AutoConfig.for_model(model_name)
_ModelEnv = openllm_core.utils.EnvVarMixin(model_name, backend=openllm_core.utils.first_not_none(backend, default=llm_config.default_backend()), model_id=model_id, quantize=quantize)
os.environ[_ModelEnv.backend] = _ModelEnv['backend_value']
os.environ['OPENLLM_BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
args: list[str] = []
if model_id: args.extend(['--model-id', model_id])
@@ -218,7 +217,7 @@ def _import_model(model_name: str,
*,
model_id: str | None = None,
model_version: str | None = None,
backend: LiteralBackend = 'pt',
backend: LiteralBackend | None = None,
quantize: LiteralQuantise | None = None,
serialisation: t.Literal['legacy', 'safetensors'] | None = None,
additional_args: t.Sequence[str] | None = None) -> bentoml.Model:
@@ -254,7 +253,8 @@ def _import_model(model_name: str,
from .entrypoint import import_command
config = openllm.AutoConfig.for_model(model_name)
_serialisation = openllm_core.utils.first_not_none(serialisation, default=config['serialisation'])
args = [model_name, '--backend', backend, '--machine', '--serialisation', _serialisation]
args = [model_name, '--machine', '--serialisation', _serialisation]
if backend is not None: args.extend(['--backend', backend])
if model_id is not None: args.append(model_id)
if model_version is not None: args.extend(['--model-version', str(model_version)])
if additional_args is not None: args.extend(additional_args)
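Because `backend` now defaults to `None`, `--backend` is only forwarded to the import command when a caller explicitly picked one; otherwise backend selection is deferred to the command itself. A rough illustration of the argument assembly with assumed inputs:

```
backend = None                  # caller did not request a backend
_serialisation = 'safetensors'  # assumed value resolved from the model's config

args = ['llama', '--machine', '--serialisation', _serialisation]
if backend is not None:
  args.extend(['--backend', backend])

print(args)  # ['llama', '--machine', '--serialisation', 'safetensors']
```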

View File

@@ -27,9 +27,7 @@ import logging
import os
import platform
import subprocess
import sys
import time
import traceback
import typing as t
import attr
@@ -48,20 +46,11 @@ from simple_di import inject
import bentoml
import openllm
import openllm_core
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelStore
from openllm import bundle
from openllm.exceptions import OpenLLMException
from openllm.models.auto import CONFIG_MAPPING
from openllm.models.auto import MODEL_FLAX_MAPPING_NAMES
from openllm.models.auto import MODEL_MAPPING_NAMES
from openllm.models.auto import MODEL_TF_MAPPING_NAMES
from openllm.models.auto import MODEL_VLLM_MAPPING_NAMES
from openllm.models.auto import AutoConfig
from openllm.models.auto import AutoLLM
from openllm.utils import infer_auto_class
from openllm_core._typing_compat import Concatenate
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import LiteralBackend
@@ -70,20 +59,21 @@ from openllm_core._typing_compat import LiteralSerialisation
from openllm_core._typing_compat import LiteralString
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import Self
from openllm_core.utils import DEBUG
from openllm_core.config import CONFIG_MAPPING
from openllm_core.utils import DEBUG_ENV_VAR
from openllm_core.utils import OPTIONAL_DEPENDENCIES
from openllm_core.utils import QUIET_ENV_VAR
from openllm_core.utils import EnvVarMixin
from openllm_core.utils import LazyLoader
from openllm_core.utils import analytics
from openllm_core.utils import bentoml_cattr
from openllm_core.utils import compose
from openllm_core.utils import configure_logging
from openllm_core.utils import converter
from openllm_core.utils import first_not_none
from openllm_core.utils import get_debug_mode
from openllm_core.utils import get_quiet_mode
from openllm_core.utils import is_torch_available
from openllm_core.utils import is_vllm_available
from openllm_core.utils import resolve_user_filepath
from openllm_core.utils import set_debug_mode
from openllm_core.utils import set_quiet_mode
@@ -112,7 +102,7 @@ if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
from bentoml._internal.container import DefaultBuilder
from openllm_client._schemas import Response
from openllm_client._schemas import StreamResponse
from openllm_client._schemas import StreamingResponse
from openllm_core._typing_compat import LiteralContainerRegistry
from openllm_core._typing_compat import LiteralContainerVersionStrategy
else:
@@ -347,9 +337,8 @@ _start_mapping = {
@machine_option
@backend_option
@serialisation_option
def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, machine: bool, backend: LiteralBackend,
quantize: LiteralQuantise | None, serialisation: LiteralSerialisation | None,
) -> bentoml.Model:
def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, machine: bool, backend: LiteralBackend | None,
quantize: LiteralQuantise | None, serialisation: LiteralSerialisation | None) -> bentoml.Model:
"""Setup LLM interactively.
It accepts two positional arguments: `model_name` and `model_id`. The first name determine
@@ -400,24 +389,19 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
$ CONVERTER=llama2-hf openllm import llama /path/to/llama-2
```
"""
llm_config = AutoConfig.for_model(model_name)
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
env = EnvVarMixin(model_name, backend=llm_config.default_backend(), model_id=model_id, quantize=quantize)
backend = first_not_none(backend, default=env['backend_value'])
llm = infer_auto_class(backend).for_model(model_name,
model_id=env['model_id_value'],
llm_config=llm_config,
model_version=model_version,
ensure_available=False,
quantize=env['quantize_value'],
serialisation=_serialisation)
llm_config = openllm.AutoConfig.for_model(model_name)
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
env = EnvVarMixin(model_name, model_id=model_id, quantize=quantize)
model_id = first_not_none(model_id, env['model_id_value'], default=llm_config['default_id'])
backend = first_not_none(backend, env['backend_value'], default='vllm' if is_vllm_available() else 'pt')
llm = openllm.LLM[t.Any, t.Any](model_id=model_id, llm_config=llm_config, revision=model_version, quantize=env['quantize_value'], serialisation=_serialisation, backend=backend)
_previously_saved = False
try:
_ref = openllm.serialisation.get(llm)
_previously_saved = True
except openllm.exceptions.OpenLLMException:
if not machine and output == 'pretty':
msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not None else ''} does not exists in local store for backend {llm.__llm_backend__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
msg = f"'{model_name}' with model_id='{model_id}' does not exists in local store for backend {llm.__llm_backend__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
termui.echo(msg, fg='yellow', nl=True)
_ref = openllm.serialisation.get(llm, auto_import=True)
if backend == 'pt' and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
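Backend resolution in `import_command` now follows one precedence chain: the `--backend` flag, then the environment, then `vllm` when it is importable, and finally `pt`. A sketch of that chain, assuming `first_not_none` returns its first non-`None` argument (its actual implementation lives in `openllm_core.utils`):

```
import importlib.util
import typing as t

def first_not_none(*args: t.Any, default: t.Any = None) -> t.Any:
  # Assumed behaviour: first argument that is not None, else the default.
  return next((a for a in args if a is not None), default)

def is_vllm_available() -> bool:
  return importlib.util.find_spec('vllm') is not None

cli_flag = None   # --backend not passed
env_value = None  # OPENLLM_BACKEND not set
backend = first_not_none(cli_flag, env_value, default='vllm' if is_vllm_available() else 'pt')
print(backend)    # 'vllm' if vLLM is installed, otherwise 'pt'
```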
@@ -471,11 +455,10 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
@click.option('--force-push', default=False, is_flag=True, type=click.BOOL, help='Whether to force push.')
@click.pass_context
def build_command(ctx: click.Context, /, model_name: str, model_id: str | None, bento_version: str | None, overwrite: bool, output: LiteralOutput, quantize: LiteralQuantise | None,
enable_features: tuple[str, ...] | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend,
enable_features: tuple[str, ...] | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend | None,
system_message: str | None, prompt_template_file: t.IO[t.Any] | None, machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool,
push: bool, serialisation: LiteralSerialisation | None, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy,
force_push: bool, **attrs: t.Any,
) -> bentoml.Bento:
force_push: bool, **attrs: t.Any) -> bentoml.Bento:
'''Package a given model into a Bento.
\b
@@ -498,9 +481,9 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
_previously_built = False
llm_config = AutoConfig.for_model(model_name)
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
env = EnvVarMixin(model_name, backend=backend, model_id=model_id, quantize=quantize)
llm_config = openllm.AutoConfig.for_model(model_name)
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
env = EnvVarMixin(model_name, backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'), model_id=model_id or llm_config['default_id'], quantize=quantize)
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
@@ -509,21 +492,25 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']})
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
if env['backend_value']: os.environ[env.backend] = str(env['backend_value'])
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = infer_auto_class(env['backend_value']).for_model(model_name,
model_id=env['model_id_value'],
prompt_template=prompt_template,
system_message=system_message,
llm_config=llm_config,
ensure_available=True,
model_version=model_version,
quantize=env['quantize_value'],
serialisation=_serialisation,
**attrs)
llm = openllm.LLM[t.Any, t.Any](model_id=env['model_id_value'] or llm_config['default_id'],
revision=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=llm_config,
backend=env['backend_value'],
quantize=env['quantize_value'],
serialisation=_serialisation,
**attrs)
llm.save_pretrained() # ensure_available = True
assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything
# FIX: This is a patch for _service_vars injection
if 'OPENLLM_MODEL_ID' not in os.environ: os.environ['OPENLLM_MODEL_ID'] = llm.model_id
if 'OPENLLM_ADAPTER_MAP' not in os.environ: os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(None).decode()
labels = dict(llm.identifying_params)
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
@@ -536,13 +523,13 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
llm_fs.writetext('Dockerfile.template', dockerfile_template.read())
dockerfile_template_path = llm_fs.getsyspath('/Dockerfile.template')
adapter_map: dict[str, str | None] | None = None
adapter_map: dict[str, str] | None = None
if adapter_id:
if not build_ctx: ctx.fail("'build_ctx' is required when '--adapter-id' is passed.")
adapter_map = {}
for v in adapter_id:
_adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
name = adapter_name[0] if len(adapter_name) > 0 else None
name = adapter_name[0] if len(adapter_name) > 0 else 'default'
try:
resolve_user_filepath(_adapter_id, build_ctx)
src_folder_name = os.path.basename(_adapter_id)
@@ -558,7 +545,7 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
adapter_map[_adapter_id] = name
os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(adapter_map).decode()
_bento_version = first_not_none(bento_version, default=llm.tag.version)
_bento_version = first_not_none(bento_version, default=llm.bentomodel.tag.version)
bento_tag = bentoml.Tag.from_taglike(f'{llm.llm_type}-service:{_bento_version}'.lower().strip())
try:
bento = bentoml.get(bento_tag)
@@ -633,29 +620,17 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
if show_available: raise click.BadOptionUsage('--show-available', "Cannot use '--show-available' with '-o porcelain' (mutually exclusive).")
termui.echo('\n'.join(models), fg='white')
else:
failed_initialized: list[tuple[str, Exception]] = []
json_data: dict[str, dict[t.Literal['architecture', 'model_id', 'url', 'installation', 'cpu', 'gpu', 'backend'], t.Any] | t.Any] = {}
converted: list[str] = []
for m in models:
config = AutoConfig.for_model(m)
backend: tuple[str, ...] = ()
if config['model_name'] in MODEL_MAPPING_NAMES: backend += ('pt',)
if config['model_name'] in MODEL_FLAX_MAPPING_NAMES: backend += ('flax',)
if config['model_name'] in MODEL_TF_MAPPING_NAMES: backend += ('tf',)
if config['model_name'] in MODEL_VLLM_MAPPING_NAMES: backend += ('vllm',)
config = openllm.AutoConfig.for_model(m)
json_data[m] = {
'architecture': config['architecture'],
'model_id': config['model_ids'],
'backend': backend,
'backend': config['backend'],
'installation': f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config['requirements'] else 'openllm',
}
converted.extend([normalise_model_name(i) for i in config['model_ids']])
if DEBUG:
try:
AutoLLM.for_model(m, llm_config=config)
except Exception as e:
failed_initialized.append((m, e))
ids_in_local_store = {
k: [
@@ -680,22 +655,9 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
data.extend([(m, v['architecture'], v['model_id'], v['installation'], v['backend'])])
column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4)]
if len(data) == 0 and len(failed_initialized) > 0:
termui.echo('Exception found while parsing models:\n', fg='yellow')
for m, err in failed_initialized:
termui.echo(f'- {m}: ', fg='yellow', nl=False)
termui.echo(traceback.print_exception(None, err, None, limit=5), fg='red') # type: ignore[func-returns-value]
sys.exit(1)
table = tabulate.tabulate(data, tablefmt='fancy_grid', headers=['LLM', 'Architecture', 'Models Id', 'Installation', 'Runtime'], maxcolwidths=column_widths)
termui.echo(table, fg='white')
if DEBUG and len(failed_initialized) > 0:
termui.echo('\nThe following models are supported but failed to initialize:\n')
for m, err in failed_initialized:
termui.echo(f'- {m}: ', fg='blue', nl=False)
termui.echo(err, fg='red')
if show_available:
if len(ids_in_local_store) == 0:
termui.echo('No models available locally.')
@@ -837,14 +799,14 @@ def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: in
termui.echo(f'{prompt}', fg=input_fg)
if stream:
stream_res: t.Iterator[StreamResponse] = client.generate_stream(prompt, **{**client._config(), **_memoized})
stream_res: t.Iterator[StreamingResponse] = client.generate_stream(prompt, **{**client._config(), **_memoized})
if output == 'pretty':
termui.echo('\n\n==Responses==\n', fg='white')
for it in stream_res:
termui.echo(it.text, fg=generated_fg, nl=False)
elif output == 'json':
for it in stream_res:
termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
termui.echo(orjson.dumps(converter.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
for it in stream_res:
termui.echo(it.text, fg=generated_fg, nl=False)
@@ -852,11 +814,11 @@ def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: in
res: Response = client.generate(prompt, **{**client._config(), **_memoized})
if output == 'pretty':
termui.echo('\n\n==Responses==\n', fg='white')
termui.echo(res.responses[0], fg=generated_fg)
termui.echo(res.outputs[0].text, fg=generated_fg)
elif output == 'json':
termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
termui.echo(orjson.dumps(converter.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
termui.echo(res.responses, fg='white')
termui.echo(res.outputs[0].text, fg='white')
ctx.exit(0)
@cli.group(cls=Extensions, hidden=True, name='extension')
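On the query path, streamed chunks are now typed as `StreamingResponse` (renamed from `StreamResponse`) and serialised with the shared cattrs `converter` instead of `bentoml_cattr`, and non-streaming results expose generated text through `res.outputs[0].text` rather than `res.responses`. A minimal client-side sketch; the `HTTPClient` import, endpoint, and prompt are assumptions for illustration:

```
import orjson
from openllm_client import HTTPClient          # assumed client import path
from openllm_core.utils import converter

client = HTTPClient('http://localhost:3000')   # hypothetical server address
for chunk in client.generate_stream('What is OpenLLM?'):
  print(orjson.dumps(converter.unstructure(chunk), option=orjson.OPT_INDENT_2).decode())
```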

View File

@@ -14,7 +14,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.container.generate import generate_containerfile
from openllm.cli import termui
from openllm.cli._factory import bento_complete_envvar
from openllm_core.utils import bentoml_cattr
from openllm_core.utils import converter
if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
@@ -35,7 +35,7 @@ def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[Bento
# Dockerfile inside bento, and it is not relevant to
# construct_containerfile. Hence it is safe to set it to None here.
# See https://github.com/bentoml/BentoML/issues/3399.
docker_attrs = bentoml_cattr.unstructure(options.docker)
docker_attrs = converter.unstructure(options.docker)
# NOTE: if users specify a dockerfile_template, we will
# save it to /env/docker/Dockerfile.template. This is necessary
# for the reconstruction of the Dockerfile.
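`converter` here is the same cattrs converter re-exported from `openllm_core.utils` (previously exposed as `bentoml_cattr`); unstructuring turns the attrs-based docker options into a plain dict for `generate_containerfile`. A toy illustration of the call, with a hypothetical stand-in for the real options class:

```
import attr
from openllm_core.utils import converter

@attr.define
class DockerOptions:  # hypothetical stand-in for the bento's docker options
  base_image: str = 'debian:bookworm-slim'
  python_version: str = '3.11'

print(converter.unstructure(DockerOptions()))
# {'base_image': 'debian:bookworm-slim', 'python_version': '3.11'}
```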