perf: unify LLM interface (#518)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-06 20:39:43 -05:00
committed by GitHub
parent f2639879af
commit e2029c934b
136 changed files with 9646 additions and 11244 deletions

View File

@@ -27,9 +27,7 @@ import logging
import os
import platform
import subprocess
import sys
import time
import traceback
import typing as t
import attr
@@ -48,20 +46,11 @@ from simple_di import inject
import bentoml
import openllm
import openllm_core
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelStore
from openllm import bundle
from openllm.exceptions import OpenLLMException
from openllm.models.auto import CONFIG_MAPPING
from openllm.models.auto import MODEL_FLAX_MAPPING_NAMES
from openllm.models.auto import MODEL_MAPPING_NAMES
from openllm.models.auto import MODEL_TF_MAPPING_NAMES
from openllm.models.auto import MODEL_VLLM_MAPPING_NAMES
from openllm.models.auto import AutoConfig
from openllm.models.auto import AutoLLM
from openllm.utils import infer_auto_class
from openllm_core._typing_compat import Concatenate
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import LiteralBackend
@@ -70,20 +59,21 @@ from openllm_core._typing_compat import LiteralSerialisation
from openllm_core._typing_compat import LiteralString
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import Self
from openllm_core.utils import DEBUG
from openllm_core.config import CONFIG_MAPPING
from openllm_core.utils import DEBUG_ENV_VAR
from openllm_core.utils import OPTIONAL_DEPENDENCIES
from openllm_core.utils import QUIET_ENV_VAR
from openllm_core.utils import EnvVarMixin
from openllm_core.utils import LazyLoader
from openllm_core.utils import analytics
from openllm_core.utils import bentoml_cattr
from openllm_core.utils import compose
from openllm_core.utils import configure_logging
from openllm_core.utils import converter
from openllm_core.utils import first_not_none
from openllm_core.utils import get_debug_mode
from openllm_core.utils import get_quiet_mode
from openllm_core.utils import is_torch_available
from openllm_core.utils import is_vllm_available
from openllm_core.utils import resolve_user_filepath
from openllm_core.utils import set_debug_mode
from openllm_core.utils import set_quiet_mode
@@ -112,7 +102,7 @@ if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
from bentoml._internal.container import DefaultBuilder
from openllm_client._schemas import Response
from openllm_client._schemas import StreamResponse
from openllm_client._schemas import StreamingResponse
from openllm_core._typing_compat import LiteralContainerRegistry
from openllm_core._typing_compat import LiteralContainerVersionStrategy
else:
@@ -347,9 +337,8 @@ _start_mapping = {
@machine_option
@backend_option
@serialisation_option
def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, machine: bool, backend: LiteralBackend,
quantize: LiteralQuantise | None, serialisation: LiteralSerialisation | None,
) -> bentoml.Model:
def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, machine: bool, backend: LiteralBackend | None,
quantize: LiteralQuantise | None, serialisation: LiteralSerialisation | None) -> bentoml.Model:
"""Setup LLM interactively.
It accepts two positional arguments: `model_name` and `model_id`. The first name determine
@@ -400,24 +389,19 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
$ CONVERTER=llama2-hf openllm import llama /path/to/llama-2
```
"""
llm_config = AutoConfig.for_model(model_name)
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
env = EnvVarMixin(model_name, backend=llm_config.default_backend(), model_id=model_id, quantize=quantize)
backend = first_not_none(backend, default=env['backend_value'])
llm = infer_auto_class(backend).for_model(model_name,
model_id=env['model_id_value'],
llm_config=llm_config,
model_version=model_version,
ensure_available=False,
quantize=env['quantize_value'],
serialisation=_serialisation)
llm_config = openllm.AutoConfig.for_model(model_name)
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
env = EnvVarMixin(model_name, model_id=model_id, quantize=quantize)
model_id = first_not_none(model_id, env['model_id_value'], default=llm_config['default_id'])
backend = first_not_none(backend, env['backend_value'], default='vllm' if is_vllm_available() else 'pt')
llm = openllm.LLM[t.Any, t.Any](model_id=model_id, llm_config=llm_config, revision=model_version, quantize=env['quantize_value'], serialisation=_serialisation, backend=backend)
_previously_saved = False
try:
_ref = openllm.serialisation.get(llm)
_previously_saved = True
except openllm.exceptions.OpenLLMException:
if not machine and output == 'pretty':
msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not None else ''} does not exists in local store for backend {llm.__llm_backend__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
msg = f"'{model_name}' with model_id='{model_id}' does not exists in local store for backend {llm.__llm_backend__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
termui.echo(msg, fg='yellow', nl=True)
_ref = openllm.serialisation.get(llm, auto_import=True)
if backend == 'pt' and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
@@ -471,11 +455,10 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
@click.option('--force-push', default=False, is_flag=True, type=click.BOOL, help='Whether to force push.')
@click.pass_context
def build_command(ctx: click.Context, /, model_name: str, model_id: str | None, bento_version: str | None, overwrite: bool, output: LiteralOutput, quantize: LiteralQuantise | None,
enable_features: tuple[str, ...] | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend,
enable_features: tuple[str, ...] | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend | None,
system_message: str | None, prompt_template_file: t.IO[t.Any] | None, machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool,
push: bool, serialisation: LiteralSerialisation | None, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy,
force_push: bool, **attrs: t.Any,
) -> bentoml.Bento:
force_push: bool, **attrs: t.Any) -> bentoml.Bento:
'''Package a given models into a Bento.
\b
@@ -498,9 +481,9 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
_previously_built = False
llm_config = AutoConfig.for_model(model_name)
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
env = EnvVarMixin(model_name, backend=backend, model_id=model_id, quantize=quantize)
llm_config = openllm.AutoConfig.for_model(model_name)
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
env = EnvVarMixin(model_name, backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'), model_id=model_id or llm_config['default_id'], quantize=quantize)
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
@@ -509,21 +492,25 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']})
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
if env['backend_value']: os.environ[env.backend] = str(env['backend_value'])
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = infer_auto_class(env['backend_value']).for_model(model_name,
model_id=env['model_id_value'],
prompt_template=prompt_template,
system_message=system_message,
llm_config=llm_config,
ensure_available=True,
model_version=model_version,
quantize=env['quantize_value'],
serialisation=_serialisation,
**attrs)
llm = openllm.LLM[t.Any, t.Any](model_id=env['model_id_value'] or llm_config['default_id'],
revision=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=llm_config,
backend=env['backend_value'],
quantize=env['quantize_value'],
serialisation=_serialisation,
**attrs)
llm.save_pretrained() # ensure_available = True
assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything
# FIX: This is a patch for _service_vars injection
if 'OPENLLM_MODEL_ID' not in os.environ: os.environ['OPENLLM_MODEL_ID'] = llm.model_id
if 'OPENLLM_ADAPTER_MAP' not in os.environ: os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(None).decode()
labels = dict(llm.identifying_params)
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
@@ -536,13 +523,13 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
llm_fs.writetext('Dockerfile.template', dockerfile_template.read())
dockerfile_template_path = llm_fs.getsyspath('/Dockerfile.template')
adapter_map: dict[str, str | None] | None = None
adapter_map: dict[str, str] | None = None
if adapter_id:
if not build_ctx: ctx.fail("'build_ctx' is required when '--adapter-id' is passsed.")
adapter_map = {}
for v in adapter_id:
_adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
name = adapter_name[0] if len(adapter_name) > 0 else None
name = adapter_name[0] if len(adapter_name) > 0 else 'default'
try:
resolve_user_filepath(_adapter_id, build_ctx)
src_folder_name = os.path.basename(_adapter_id)
@@ -558,7 +545,7 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
adapter_map[_adapter_id] = name
os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(adapter_map).decode()
_bento_version = first_not_none(bento_version, default=llm.tag.version)
_bento_version = first_not_none(bento_version, default=llm.bentomodel.tag.version)
bento_tag = bentoml.Tag.from_taglike(f'{llm.llm_type}-service:{_bento_version}'.lower().strip())
try:
bento = bentoml.get(bento_tag)
@@ -633,29 +620,17 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
if show_available: raise click.BadOptionUsage('--show-available', "Cannot use '--show-available' with '-o porcelain' (mutually exclusive).")
termui.echo('\n'.join(models), fg='white')
else:
failed_initialized: list[tuple[str, Exception]] = []
json_data: dict[str, dict[t.Literal['architecture', 'model_id', 'url', 'installation', 'cpu', 'gpu', 'backend'], t.Any] | t.Any] = {}
converted: list[str] = []
for m in models:
config = AutoConfig.for_model(m)
backend: tuple[str, ...] = ()
if config['model_name'] in MODEL_MAPPING_NAMES: backend += ('pt',)
if config['model_name'] in MODEL_FLAX_MAPPING_NAMES: backend += ('flax',)
if config['model_name'] in MODEL_TF_MAPPING_NAMES: backend += ('tf',)
if config['model_name'] in MODEL_VLLM_MAPPING_NAMES: backend += ('vllm',)
config = openllm.AutoConfig.for_model(m)
json_data[m] = {
'architecture': config['architecture'],
'model_id': config['model_ids'],
'backend': backend,
'backend': config['backend'],
'installation': f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config['requirements'] else 'openllm',
}
converted.extend([normalise_model_name(i) for i in config['model_ids']])
if DEBUG:
try:
AutoLLM.for_model(m, llm_config=config)
except Exception as e:
failed_initialized.append((m, e))
ids_in_local_store = {
k: [
@@ -680,22 +655,9 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
data.extend([(m, v['architecture'], v['model_id'], v['installation'], v['backend'])])
column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4)]
if len(data) == 0 and len(failed_initialized) > 0:
termui.echo('Exception found while parsing models:\n', fg='yellow')
for m, err in failed_initialized:
termui.echo(f'- {m}: ', fg='yellow', nl=False)
termui.echo(traceback.print_exception(None, err, None, limit=5), fg='red') # type: ignore[func-returns-value]
sys.exit(1)
table = tabulate.tabulate(data, tablefmt='fancy_grid', headers=['LLM', 'Architecture', 'Models Id', 'Installation', 'Runtime'], maxcolwidths=column_widths)
termui.echo(table, fg='white')
if DEBUG and len(failed_initialized) > 0:
termui.echo('\nThe following models are supported but failed to initialize:\n')
for m, err in failed_initialized:
termui.echo(f'- {m}: ', fg='blue', nl=False)
termui.echo(err, fg='red')
if show_available:
if len(ids_in_local_store) == 0:
termui.echo('No models available locally.')
@@ -837,14 +799,14 @@ def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: in
termui.echo(f'{prompt}', fg=input_fg)
if stream:
stream_res: t.Iterator[StreamResponse] = client.generate_stream(prompt, **{**client._config(), **_memoized})
stream_res: t.Iterator[StreamingResponse] = client.generate_stream(prompt, **{**client._config(), **_memoized})
if output == 'pretty':
termui.echo('\n\n==Responses==\n', fg='white')
for it in stream_res:
termui.echo(it.text, fg=generated_fg, nl=False)
elif output == 'json':
for it in stream_res:
termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
termui.echo(orjson.dumps(converter.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
for it in stream_res:
termui.echo(it.text, fg=generated_fg, nl=False)
@@ -852,11 +814,11 @@ def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: in
res: Response = client.generate(prompt, **{**client._config(), **_memoized})
if output == 'pretty':
termui.echo('\n\n==Responses==\n', fg='white')
termui.echo(res.responses[0], fg=generated_fg)
termui.echo(res.outputs[0].text, fg=generated_fg)
elif output == 'json':
termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
termui.echo(orjson.dumps(converter.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
termui.echo(res.responses, fg='white')
termui.echo(res.outputs[0].text, fg='white')
ctx.exit(0)
@cli.group(cls=Extensions, hidden=True, name='extension')