refactor(cli): cleanup API (#592)

* chore: remove unused imports

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* refactor(cli): update to only need model_id

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* feat: `openllm start model-id`

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: add changelog

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update changelog notice

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update correct config and running tools

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update backward compat options and treat JSON outputs correspondingly

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-09 11:40:17 -05:00
committed by GitHub
parent 86f7acafa9
commit b8a2e8cf91
48 changed files with 1096 additions and 1047 deletions
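For orientation before the diffs: the user-facing change is that the model id becomes the single positional argument, both on the CLI (`openllm start mistralai/Mistral-7B-v0.1` instead of `openllm start mistral --model-id mistralai/Mistral-7B-v0.1`, with `--model-id` kept as a hidden, deprecated flag) and in the Python SDK. A minimal before/after sketch, reusing the model ids that appear in the updated module docstring below:

```python
import openllm

# Before this change: a model *name* selected the config, and the actual
# HuggingFace model id was passed separately:
#   openllm.build('falcon')
#   openllm.import_model('falcon', model_id='tiiuae/falcon-7b-instruct')

# After this change: the model id alone drives config resolution.
bento = openllm.build('mistralai/Mistral-7B-v0.1')
bentomodel = openllm.import_model('mistralai/Mistral-7B-v0.1')
```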

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import functools
import importlib.util
import logging
import os
import typing as t
@@ -8,37 +7,32 @@ import typing as t
import click
import click_option_group as cog
import inflection
import orjson
from bentoml_cli.utils import BentoMLCommandGroup
from click import ClickException
from click import shell_completion as sc
from click.shell_completion import CompletionItem
import bentoml
import openllm
import openllm_core
from bentoml._internal.configuration.containers import BentoMLContainer
from openllm_core._configuration import LLMConfig
from openllm_core._typing_compat import Concatenate
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import LiteralBackend
from openllm_core._typing_compat import LiteralQuantise
from openllm_core._typing_compat import LiteralSerialisation
from openllm_core._typing_compat import LiteralString
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import get_literal_args
from openllm_core.utils import DEBUG
from openllm_core.utils import check_bool_env
from openllm_core.utils import first_not_none
from openllm_core.utils import is_vllm_available
from . import termui
class _OpenLLM_GenericInternalConfig(LLMConfig):
__config__ = {'name_type': 'lowercase', 'default_id': 'openllm/generic', 'model_ids': ['openllm/generic'], 'architecture': 'PreTrainedModel'}
if t.TYPE_CHECKING:
import subprocess
from openllm_core._configuration import LLMConfig
class GenerationConfig:
top_k: int = 15
top_p: float = 0.9
temperature: float = 0.75
max_new_tokens: int = 128
logger = logging.getLogger(__name__)
@@ -91,146 +85,12 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0]
return None
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
llm_config = openllm.AutoConfig.for_model(model)
command_attrs: DictStrAny = dict(name=llm_config['model_name'],
context_settings=_context_settings or termui.CONTEXT_SETTINGS,
short_help=f"Start a LLMServer for '{model}'",
aliases=[llm_config['start_name']] if llm_config['name_type'] == 'dasherize' else None,
help=f'''\
{llm_config['env'].start_docstring}
\b
Note: ``{llm_config['start_name']}`` can also be run with any other models available on HuggingFace
or fine-tuned variants as long as it belongs to the architecture generation ``{llm_config['architecture']}`` (trust_remote_code={llm_config['trust_remote_code']}).
\b
For example: One can start [Fastchat-T5](https://huggingface.co/lmsys/fastchat-t5-3b-v1.0) with ``openllm start flan-t5``:
\b
$ openllm start flan-t5 --model-id lmsys/fastchat-t5-3b-v1.0
\b
Available official model_id(s): [default: {llm_config['default_id']}]
\b
{orjson.dumps(llm_config['model_ids'], option=orjson.OPT_INDENT_2).decode()}
''')
@group.command(**command_attrs)
@start_decorator(llm_config, serve_grpc=_serve_grpc)
@click.pass_context
def start_cmd(ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, system_message: str | None, prompt_template_file: t.IO[t.Any] | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend | None,
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
if _serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.echo(
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.",
fg='yellow')
termui.echo(f"Make sure to check out '{model_id}' repository to see if the weights is in '{_serialisation}' format if unsure.")
adapter_map: dict[str, str] | None = attrs.pop(_adapter_mapping_key, None)
config, server_attrs = llm_config.model_validate_click(**attrs)
server_timeout = first_not_none(server_timeout, default=config['timeout'])
server_attrs.update({'working_dir': os.path.dirname(os.path.dirname(__file__)), 'timeout': server_timeout})
if _serve_grpc: server_attrs['grpc_protocol_version'] = 'v1'
# NOTE: currently, there are no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development')
server_attrs.setdefault('production', not development)
wpr = first_not_none(workers_per_resource, default=config['workers_per_resource'])
if isinstance(wpr, str):
if wpr == 'round_robin': wpr = 1.0
elif wpr == 'conserved':
if device and openllm.utils.device_count() == 0:
termui.echo('--device will have no effect as there are no GPUs available', fg='yellow')
wpr = 1.0
else:
available_gpu = len(device) if device else openllm.utils.device_count()
wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu)
else:
wpr = float(wpr)
elif isinstance(wpr, int):
wpr = float(wpr)
# Create a new model env to work with the envvar during CLI invocation
env = openllm.utils.EnvVarMixin(config['model_name'],
backend=openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'),
model_id=model_id or config['default_id'],
quantize=quantize)
requirements = llm_config['requirements']
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
if len(missing_requirements) > 0:
termui.echo(f'Make sure to have the following dependencies available: {missing_requirements}', fg='yellow')
# NOTE: This is to set current configuration
start_env = os.environ.copy()
start_env = parse_config_options(config, server_timeout, wpr, device, cors, start_env)
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
start_env.update({
'OPENLLM_MODEL': model,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': _serialisation,
env.backend: env['backend_value'],
})
if env['model_id_value']: start_env[env.model_id] = str(env['model_id_value'])
if env['quantize_value']: start_env[env.quantize] = str(env['quantize_value'])
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = openllm.LLM[t.Any, t.Any](model_id=start_env[env.model_id],
revision=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=config,
backend=env['backend_value'],
adapter_map=adapter_map,
quantize=env['quantize_value'],
serialisation=_serialisation)
llm.save_pretrained() # ensure_available = True
start_env.update({env.config: llm.config.model_dump_json().decode()})
server = bentoml.GrpcServer('_service:svc', **server_attrs) if _serve_grpc else bentoml.HTTPServer('_service:svc', **server_attrs)
openllm.utils.analytics.track_start_init(llm.config)
def next_step(model_name: str, adapter_map: DictStrAny | None) -> None:
cmd_name = f'openllm build {model_name}'
if not llm._local: cmd_name += f' --model-id {llm.model_id}'
if llm._quantise: cmd_name += f' --quantize {llm._quantise}'
cmd_name += f' --serialization {_serialisation}'
if adapter_map is not None:
cmd_name += ' ' + ' '.join([f'--adapter-id {s}' for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]])
if not openllm.utils.get_quiet_mode():
termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg='blue')
if return_process:
server.start(env=start_env, text=True)
if server.process is None: raise click.ClickException('Failed to start the server.')
return server.process
else:
try:
server.start(env=start_env, text=True, blocking=True)
except Exception as err:
termui.echo(f'Error caught while running LLM Server:\n{err}', fg='red')
raise
else:
next_step(model, adapter_map)
# NOTE: Return the configuration for telemetry purposes.
return config
return start_cmd
def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
composed = openllm.utils.compose(
llm_config.to_click_options, _http_server_args if not serve_grpc else _grpc_server_args,
cog.optgroup.group('General LLM Options', help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."), model_id_option(factory=cog.optgroup),
model_version_option(factory=cog.optgroup), system_message_option(factory=cog.optgroup), prompt_template_file_option(factory=cog.optgroup),
_OpenLLM_GenericInternalConfig().to_click_options, _http_server_args if not serve_grpc else _grpc_server_args,
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'), model_version_option(factory=cog.optgroup),
system_message_option(factory=cog.optgroup), prompt_template_file_option(factory=cog.optgroup),
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'), workers_per_resource_option(factory=cog.optgroup), cors_option(factory=cog.optgroup),
backend_option(factory=cog.optgroup),
cog.optgroup.group('LLM Optimization Options',
@@ -248,7 +108,7 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
multiple=True,
envvar='CUDA_VISIBLE_DEVICES',
callback=parse_device_callback,
help=f"Assign GPU devices (if available) for {llm_config['model_name']}.",
help='Assign GPU devices (if available)',
show_envvar=True),
cog.optgroup.group('Fine-tuning related options',
help='''\
@@ -268,7 +128,7 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
'''),
cog.optgroup.option('--adapter-id',
default=None,
help='Optional name or path for given LoRA adapter' + f" to wrap '{llm_config['model_name']}'",
help='Optional name or path for given LoRA adapter',
multiple=True,
callback=_id_callback,
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]'), click.option('--return-process', is_flag=True, default=False, help='Internal use only.',
@@ -341,24 +201,6 @@ def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC |
cli_option = functools.partial(_click_factory_type, attr='option')
cli_argument = functools.partial(_click_factory_type, attr='argument')
def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = 'pretty', **attrs: t.Any) -> t.Callable[[FC], FC]:
output = ['json', 'pretty', 'porcelain']
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]:
return [CompletionItem(it) for it in output]
return cli_option('-o',
'--output',
'output',
type=click.Choice(output),
default=default_value,
help='Showing output type.',
show_default=True,
envvar='OPENLLM_OUTPUT',
show_envvar=True,
shell_complete=complete_output_var,
**attrs)(f)
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs)(f)
@@ -450,37 +292,23 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
return cli_option('--serialisation',
'--serialization',
'serialisation',
type=str,
type=click.Choice(get_literal_args(LiteralSerialisation)),
default=None,
show_default=True,
show_envvar=True,
envvar='OPENLLM_SERIALIZATION',
callback=serialisation_callback,
help='''Serialisation format for save/load LLM.
Currently the following strategies are supported:
- ``safetensors``: This will use safetensors format, which is synonymous to
- ``safetensors``: This will use safetensors format, which is synonymous to ``safe_serialization=True``.
\b
``safe_serialization=True``.
\b
> [!NOTE] that this format might not work for every cases, and
you can always fallback to ``legacy`` if needed.
> [!NOTE] Safetensors might not work for every case, and you can always fall back to ``legacy`` if needed.
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
> [!NOTE] GGML format support is a work in progress.
''',
**attrs)(f)
def serialisation_callback(ctx: click.Context, param: click.Parameter, value: LiteralSerialisation | None) -> LiteralSerialisation | None:
if value is None: return value
if value not in {'safetensors', 'legacy'}:
raise click.BadParameter(f"'serialisation' only accept 'safetensors', 'legacy' as serialisation format. got {value} instead.", ctx, param) from None
return value
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--container-registry',
'container_registry',
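One detail from the hunk above worth calling out: `--serialisation` now takes its accepted values from the `LiteralSerialisation` type via `click.Choice(get_literal_args(LiteralSerialisation))` instead of a bare `type=str` checked only in `serialisation_callback`. A standalone sketch of that pattern (the literal alias and helper here are simplified stand-ins for the `openllm_core._typing_compat` versions, not the project's actual code):

```python
from __future__ import annotations
import typing as t
import click

# Simplified stand-ins for openllm_core._typing_compat.LiteralSerialisation
# and get_literal_args.
LiteralSerialisation = t.Literal['safetensors', 'legacy']

def get_literal_args(literal: t.Any) -> tuple[str, ...]:
  # Pull the allowed strings out of a typing.Literal alias.
  return t.get_args(literal)

@click.command()
@click.option('--serialisation', '--serialization', 'serialisation',
              type=click.Choice(get_literal_args(LiteralSerialisation)),
              default=None, envvar='OPENLLM_SERIALIZATION', show_envvar=True,
              help='Serialisation format for save/load LLM.')
def main(serialisation: str | None) -> None:
  # click.Choice rejects anything outside the literal up front, so a manual
  # BadParameter check in a callback is no longer strictly needed.
  click.echo(f'serialisation={serialisation}')

if __name__ == '__main__':
  main()
```

Deriving the choices from the literal keeps the CLI, the type hints, and shell completion in sync from a single definition.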

View File

@@ -7,20 +7,21 @@ import subprocess
import sys
import typing as t
import orjson
from simple_di import Provide
from simple_di import inject
import bentoml
import openllm
import openllm_core
from bentoml._internal.configuration.containers import BentoMLContainer
from openllm.exceptions import OpenLLMException
from openllm_core._typing_compat import LiteralSerialisation
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import codegen
from openllm_core.utils import first_not_none
from openllm_core.utils import is_vllm_available
from . import termui
from ._factory import start_command_factory
if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
from openllm_core._configuration import LLMConfig
@@ -28,15 +29,11 @@ if t.TYPE_CHECKING:
from openllm_core._typing_compat import LiteralContainerRegistry
from openllm_core._typing_compat import LiteralContainerVersionStrategy
from openllm_core._typing_compat import LiteralQuantise
from openllm_core._typing_compat import LiteralSerialisation
from openllm_core._typing_compat import LiteralString
logger = logging.getLogger(__name__)
def _start(model_name: str,
/,
*,
model_id: str | None = None,
def _start(model_id: str,
timeout: int = 30,
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
device: tuple[str, ...] | t.Literal['all'] | None = None,
@@ -61,8 +58,7 @@ def _start(model_name: str,
``openllm.start`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI interaction.
Args:
model_name: The model name to start this LLM
model_id: Optional model id for this given LLM
model_id: The model id to start this LLMServer
timeout: The server timeout
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
prompt_template_file: Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.
@@ -91,8 +87,7 @@ def _start(model_name: str,
from .entrypoint import start_grpc_command
os.environ['OPENLLM_BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
args: list[str] = []
if model_id: args.extend(['--model-id', model_id])
args: list[str] = [model_id]
if system_message: args.extend(['--system-message', system_message])
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
if timeout: args.extend(['--server-timeout', str(timeout)])
@@ -106,14 +101,11 @@ def _start(model_name: str,
if additional_args: args.extend(additional_args)
if __test__: args.append('--return-process')
return start_command_factory(start_command if not _serve_grpc else start_grpc_command, model_name, _context_settings=termui.CONTEXT_SETTINGS,
_serve_grpc=_serve_grpc).main(args=args if len(args) > 0 else None, standalone_mode=False)
cmd = start_command if not _serve_grpc else start_grpc_command
return cmd.main(args=args, standalone_mode=False)
@inject
def _build(model_name: str,
/,
*,
model_id: str | None = None,
def _build(model_id: str,
model_version: str | None = None,
bento_version: str | None = None,
quantize: LiteralQuantise | None = None,
@@ -122,17 +114,17 @@ def _build(model_name: str,
prompt_template_file: str | None = None,
build_ctx: str | None = None,
enable_features: tuple[str, ...] | None = None,
workers_per_resource: float | None = None,
dockerfile_template: str | None = None,
overwrite: bool = False,
container_registry: LiteralContainerRegistry | None = None,
container_version_strategy: LiteralContainerVersionStrategy | None = None,
push: bool = False,
force_push: bool = False,
containerize: bool = False,
serialisation: LiteralSerialisation | None = None,
additional_args: list[str] | None = None,
bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> bentoml.Bento:
"""Package a LLM into a Bento.
"""Package a LLM into a BentoLLM.
The LLM will be built into a BentoService with the following structure:
if ``quantize`` is passed, it will instruct the model to be quantized dynamically during serving time.
@@ -140,8 +132,7 @@ def _build(model_name: str,
``openllm.build`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as ``openllm build`` CLI.
Args:
model_name: The model name to start this LLM
model_id: Optional model id for this given LLM
model_id: The model id to build this BentoLLM
model_version: Optional model version for this given LLM
bento_version: Optional bento version for this given BentoLLM
system_message: Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.
@@ -154,15 +145,6 @@ def _build(model_name: str,
adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``.
build_ctx: The build context to use for building BentoLLM. By default, it sets to current directory.
enable_features: Additional OpenLLM features to be included with this BentoLLM.
workers_per_resource: Number of workers per resource assigned.
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
for more information. By default, this is set to 1.
> [!NOTE] ``--workers-per-resource`` will also accept the following strategies:
> - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
> - ``conserved``: This will determine the number of available GPU resources, and only assign
> one worker for the LLMRunner. For example, if there are 4 GPUs available, then ``conserved`` is
> equivalent to ``--workers-per-resource 0.25``.
dockerfile_template: The dockerfile template to use for building BentoLLM. See https://docs.bentoml.com/en/latest/guides/containerization.html#dockerfile-template.
overwrite: Whether to overwrite the existing BentoLLM. By default, this is set to ``False``.
push: Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.
@@ -178,17 +160,17 @@ def _build(model_name: str,
Returns:
``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
"""
config = openllm.AutoConfig.for_model(model_name)
_serialisation = openllm_core.utils.first_not_none(serialisation, default=config['serialisation'])
args: list[str] = [sys.executable, '-m', 'openllm', 'build', model_name, '--machine', '--serialisation', _serialisation]
from ..serialisation.transformers.weights import has_safetensors_weights
args: list[str] = [
sys.executable, '-m', 'openllm', 'build', model_id, '--machine', '--serialisation',
t.cast(LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'))
]
if quantize: args.extend(['--quantize', quantize])
if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
if push: args.extend(['--push'])
if containerize: args.extend(['--containerize'])
if model_id: args.extend(['--model-id', model_id])
if build_ctx: args.extend(['--build-ctx', build_ctx])
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource)])
if overwrite: args.append('--overwrite')
if system_message: args.extend(['--system-message', system_message])
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
@@ -204,23 +186,24 @@ def _build(model_name: str,
try:
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
except subprocess.CalledProcessError as e:
logger.error('Exception caught while building %s', model_name, exc_info=e)
logger.error("Exception caught while building Bento for '%s'", model_id, exc_info=e)
if e.stderr: raise OpenLLMException(e.stderr.decode('utf-8')) from None
raise OpenLLMException(str(e)) from None
matched = re.match(r'__tag__:([^:\n]+:[^:\n]+)$', output.decode('utf-8').strip())
matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip())
if matched is None:
raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
return bentoml.get(matched.group(1), _bento_store=bento_store)
try:
result = orjson.loads(matched.group(1))
except orjson.JSONDecodeError as e:
raise ValueError(f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.") from e
return bentoml.get(result['tag'], _bento_store=bento_store)
def _import_model(model_name: str,
/,
*,
model_id: str | None = None,
def _import_model(model_id: str,
model_version: str | None = None,
backend: LiteralBackend | None = None,
quantize: LiteralQuantise | None = None,
serialisation: t.Literal['legacy', 'safetensors'] | None = None,
additional_args: t.Sequence[str] | None = None) -> bentoml.Model:
serialisation: LiteralSerialisation | None = None,
additional_args: t.Sequence[str] | None = None) -> dict[str, t.Any]:
"""Import a LLM into local store.
> [!NOTE]
@@ -228,14 +211,13 @@ def _import_model(model_name: str,
> only use this option if you want the weight to be quantized by default. Note that OpenLLM also
> supports on-demand quantisation during initial startup.
``openllm.download`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI ``openllm import``.
``openllm.import_model`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI ``openllm import``.
> [!NOTE]
> ``openllm.start`` will automatically invoke ``openllm.download`` under the hood.
> ``openllm.start`` will automatically invoke ``openllm.import_model`` under the hood.
Args:
model_name: The model name to start this LLM
model_id: Optional model id for this given LLM
model_id: Required model id for this given LLM
model_version: Optional model version for this given LLM
backend: The backend to use for this LLM. By default, this is set to ``pt``.
quantize: Quantize the model weights. This is only applicable for PyTorch models.
@@ -243,29 +225,26 @@ def _import_model(model_name: str,
- int8: Quantize the model with 8bit (bitsandbytes required)
- int4: Quantize the model with 4bit (bitsandbytes required)
- gptq: Quantize the model with GPTQ (auto-gptq required)
serialisation: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors.
Default behaviour is similar to ``safe_serialization=False``.
serialisation: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save model using safetensors. Default behaviour is similar to ``safe_serialization=False``.
additional_args: Additional arguments to pass to ``openllm import``.
Returns:
``bentoml.Model``: BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
"""
from .entrypoint import import_command
config = openllm.AutoConfig.for_model(model_name)
_serialisation = openllm_core.utils.first_not_none(serialisation, default=config['serialisation'])
args = [model_name, '--machine', '--serialisation', _serialisation]
args = [model_id, '--quiet']
if backend is not None: args.extend(['--backend', backend])
if model_id is not None: args.append(model_id)
if model_version is not None: args.extend(['--model-version', str(model_version)])
if additional_args is not None: args.extend(additional_args)
if quantize is not None: args.extend(['--quantize', quantize])
if serialisation is not None: args.extend(['--serialisation', serialisation])
if additional_args is not None: args.extend(additional_args)
return import_command.main(args=args, standalone_mode=False)
def _list_models() -> dict[str, t.Any]:
"""List all available models within the local store."""
from .entrypoint import models_command
return models_command.main(args=['-o', 'json', '--show-available', '--machine'], standalone_mode=False)
return models_command.main(args=['--show-available', '--quiet'], standalone_mode=False)
start, start_grpc, build, import_model, list_models = openllm_core.utils.codegen.gen_sdk(_start, _serve_grpc=False), openllm_core.utils.codegen.gen_sdk(
_start, _serve_grpc=True), openllm_core.utils.codegen.gen_sdk(_build), openllm_core.utils.codegen.gen_sdk(_import_model), openllm_core.utils.codegen.gen_sdk(_list_models)
start, start_grpc = codegen.gen_sdk(_start, _serve_grpc=False), codegen.gen_sdk(_start, _serve_grpc=True)
build, import_model, list_models = codegen.gen_sdk(_build), codegen.gen_sdk(_import_model), codegen.gen_sdk(_list_models)
__all__ = ['start', 'start_grpc', 'build', 'import_model', 'list_models']
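A small sketch of the new machine-readable handshake that `_build` above relies on: the build subprocess now prints a JSON object prefixed with `__object__:` (instead of the earlier `__tag__:<tag>` line), and the SDK decodes it to recover the Bento tag. The sample payload is invented for illustration; only the prefix, the regex, and the `tag` key come from the diff:

```python
from __future__ import annotations
import re
import orjson

def parse_build_output(output: bytes) -> str:
  # 'openllm build --machine' now prints a JSON object prefixed with '__object__:'.
  matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip())
  if matched is None:
    raise ValueError('Failed to find JSON payload in build output.')
  result = orjson.loads(matched.group(1))
  return result['tag']  # the Bento tag, later resolved via bentoml.get(...)

# Hypothetical payload, for illustration only:
print(parse_build_output(b'__object__:{"tag": "mistral-7b-service:abc123"}'))
```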

View File

@@ -5,29 +5,33 @@ This module also contains the SDK to call ``start`` and ``build`` from SDK
Start any LLM:
```python
openllm.start("falcon", model_id='tiiuae/falcon-7b-instruct')
openllm.start('mistral', model_id='mistralai/Mistral-7B-v0.1')
```
Build a BentoLLM
```python
bento = openllm.build("falcon")
bento = openllm.build('mistralai/Mistral-7B-v0.1')
```
Import any LLM into local store
```python
bentomodel = openllm.import_model("falcon", model_id='tiiuae/falcon-7b-instruct')
bentomodel = openllm.import_model('mistralai/Mistral-7B-v0.1')
```
"""
from __future__ import annotations
import enum
import functools
import importlib.util
import inspect
import itertools
import logging
import os
import platform
import random
import subprocess
import time
import traceback
import typing as t
import attr
@@ -57,23 +61,22 @@ from openllm_core._typing_compat import LiteralBackend
from openllm_core._typing_compat import LiteralQuantise
from openllm_core._typing_compat import LiteralSerialisation
from openllm_core._typing_compat import LiteralString
from openllm_core._typing_compat import NotRequired
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import Self
from openllm_core.config import CONFIG_MAPPING
from openllm_core.utils import DEBUG_ENV_VAR
from openllm_core.utils import OPTIONAL_DEPENDENCIES
from openllm_core.utils import QUIET_ENV_VAR
from openllm_core.utils import EnvVarMixin
from openllm_core.utils import LazyLoader
from openllm_core.utils import analytics
from openllm_core.utils import check_bool_env
from openllm_core.utils import compose
from openllm_core.utils import configure_logging
from openllm_core.utils import converter
from openllm_core.utils import first_not_none
from openllm_core.utils import get_debug_mode
from openllm_core.utils import get_quiet_mode
from openllm_core.utils import is_torch_available
from openllm_core.utils import is_vllm_available
from openllm_core.utils import resolve_user_filepath
from openllm_core.utils import set_debug_mode
from openllm_core.utils import set_quiet_mode
@@ -85,24 +88,22 @@ from ._factory import _AnyCallable
from ._factory import backend_option
from ._factory import container_registry_option
from ._factory import machine_option
from ._factory import model_id_option
from ._factory import model_name_argument
from ._factory import model_version_option
from ._factory import output_option
from ._factory import parse_config_options
from ._factory import prompt_template_file_option
from ._factory import quantize_option
from ._factory import serialisation_option
from ._factory import start_command_factory
from ._factory import start_decorator
from ._factory import system_message_option
from ._factory import workers_per_resource_option
if t.TYPE_CHECKING:
import torch
from bentoml._internal.bento import BentoStore
from bentoml._internal.container import DefaultBuilder
from openllm_client._schemas import Response
from openllm_client._schemas import StreamingResponse
from openllm_core._configuration import LLMConfig
from openllm_core._typing_compat import LiteralContainerRegistry
from openllm_core._typing_compat import LiteralContainerVersionStrategy
else:
@@ -134,6 +135,16 @@ _object_setattr = object.__setattr__
_EXT_FOLDER = os.path.abspath(os.path.join(os.path.dirname(__file__), 'extension'))
def backend_warning(backend: LiteralBackend):
if backend == 'pt' and check_bool_env('OPENLLM_BACKEND_WARNING') and not get_quiet_mode():
if openllm.utils.is_vllm_available():
termui.warning(
'\nvLLM is available, but using PyTorch backend instead. Note that vLLM is a lot more performant and should always be used in production (by explicitly setting --backend vllm).')
else:
termui.warning('\nvLLM is not available. Note that PyTorch backend is not as performant as vLLM and you should always consider using vLLM for production.')
termui.debug(
content="\nTip: if you are running 'openllm build' you can set '--backend vllm' to package your Bento with vLLM backend. To hide these messages, set 'OPENLLM_BACKEND_WARNING=False'\n")
class Extensions(click.MultiCommand):
def list_commands(self, ctx: click.Context) -> list[str]:
return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith('.py') and not filename.startswith('__')])
@@ -162,7 +173,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
ctx.obj = GlobalOptions(cloud_context=cloud_context)
if quiet:
set_quiet_mode(True)
if debug: logger.warning("'--quiet' passed; ignoring '--verbose/--debug'")
if debug: termui.warning("'--quiet' passed; ignoring '--verbose/--debug'")
elif debug: set_debug_mode(True)
configure_logging()
return f(*args, **attrs)
@@ -202,21 +213,9 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
if cmd_name in t.cast('Extensions', extension_command).list_commands(ctx):
return t.cast('Extensions', extension_command).get_command(ctx, cmd_name)
cmd_name = self.resolve_alias(cmd_name)
if ctx.command.name in _start_mapping:
try:
return _start_mapping[ctx.command.name][cmd_name]
except KeyError:
# TODO: support start from a bento
try:
bentoml.get(cmd_name)
raise click.ClickException(f"'openllm start {cmd_name}' is currently disabled for the time being. Please let us know if you need this feature by opening an issue on GitHub.")
except bentoml.exceptions.NotFound:
pass
raise click.BadArgumentUsage(f'{cmd_name} is not a valid model identifier supported by OpenLLM.') from None
return super().get_command(ctx, cmd_name)
def list_commands(self, ctx: click.Context) -> list[str]:
if ctx.command.name in {'start', 'start-grpc'}: return list(CONFIG_MAPPING.keys())
return super().list_commands(ctx) + t.cast('Extensions', extension_command).list_commands(ctx)
def command(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Callable[..., t.Any]], click.Command]: # type: ignore[override] # XXX: fix decorator on BentoMLCommandGroup
@@ -280,10 +279,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
formatter.write_dl(rows)
@click.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='openllm')
@click.version_option(None,
'--version',
'-v',
message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}")
@click.version_option(None, '--version', '-v', message=f'%(prog)s, %(version)s (compiled: {openllm.COMPILED})\nPython ({platform.python_implementation()}) {platform.python_version()}')
def cli() -> None:
"""\b
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
@@ -298,52 +294,285 @@ def cli() -> None:
Fine-tune, serve, deploy, and monitor any LLMs with ease.
"""
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start', aliases=['start-http'])
def start_command() -> None:
@cli.command(context_settings=termui.CONTEXT_SETTINGS, name='start', aliases=['start-http'], short_help='Start a LLMServer for any supported LLM.')
@click.argument('model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True)
@click.option('--model-id',
'deprecated_model_id',
type=click.STRING,
default=None,
hidden=True,
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
help='Deprecated. Use positional argument instead.')
@start_decorator(serve_grpc=False)
def start_command(model_id: str, server_timeout: int, model_version: str | None, system_message: str | None, prompt_template_file: t.IO[t.Any] | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend | None,
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, deprecated_model_id: str | None,
**attrs: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
"""Start any LLM as a REST server.
\b
```bash
$ openllm <start|start-http> <model_name> --<options> ...
$ openllm <start|start-http> <model_id> --<options> ...
```
"""
if model_id in openllm.CONFIG_MAPPING:
_model_name = model_id
if deprecated_model_id is not None: model_id = deprecated_model_id
else: model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
termui.warning(
f"Passing 'openllm start {_model_name}{'' if deprecated_model_id is None else ' --model-id ' + deprecated_model_id}' is deprecated and will be remove in a future version. Use 'openllm start {model_id}' instead."
)
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start-grpc')
def start_grpc_command() -> None:
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
from ..serialisation.transformers.weights import has_safetensors_weights
serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'))
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.warning(
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation."
)
termui.warning(f"Make sure to check out '{model_id}' repository to see if the weights is in '{serialisation}' format if unsure.")
llm = openllm.LLM[t.Any, t.Any](model_id=model_id,
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
backend=backend,
adapter_map=adapter_map,
quantize=quantize,
serialisation=serialisation,
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'))
backend_warning(llm.__llm_backend__)
config, server_attrs = llm.config.model_validate_click(**attrs)
server_timeout = first_not_none(server_timeout, default=config['timeout'])
server_attrs.update({'working_dir': os.path.dirname(os.path.dirname(__file__)), 'timeout': server_timeout})
# XXX: currently, there are no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development')
server_attrs.setdefault('production', not development)
wpr = first_not_none(workers_per_resource, default=config['workers_per_resource'])
if isinstance(wpr, str):
if wpr == 'round_robin': wpr = 1.0
elif wpr == 'conserved':
if device and openllm.utils.device_count() == 0:
termui.echo('--device will have no effect as there are no GPUs available', fg='yellow')
wpr = 1.0
else:
available_gpu = len(device) if device else openllm.utils.device_count()
wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu)
else:
wpr = float(wpr)
elif isinstance(wpr, int):
wpr = float(wpr)
requirements = llm.config['requirements']
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
if len(missing_requirements) > 0:
termui.echo(f'Make sure to have the following dependencies available: {missing_requirements}', fg='yellow')
start_env = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy())
start_env.update({
'OPENLLM_MODEL_ID': model_id,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': serialisation,
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_CONFIG': llm.config.model_dump_json(flatten=True).decode(),
})
if llm._quantise: start_env['OPENLLM_QUANTIZE'] = str(llm._quantise)
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
server = bentoml.HTTPServer('_service:svc', **server_attrs)
openllm.utils.analytics.track_start_init(llm.config)
def next_step(adapter_map: DictStrAny | None, caught_exception: bool = False) -> None:
if caught_exception: return
cmd_name = f'openllm build {model_id}'
if llm._quantise: cmd_name += f' --quantize {llm._quantise}'
cmd_name += f' --serialization {serialisation}'
if adapter_map is not None:
cmd_name += ' ' + ' '.join([f'--adapter-id {s}' for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]])
if not openllm.utils.get_quiet_mode():
termui.info(f"\n\n🚀 Next step: run '{cmd_name}' to create a BentoLLM for '{model_id}'")
_exception = False
if return_process:
server.start(env=start_env, text=True)
if server.process is None: raise click.ClickException('Failed to start the server.')
return server.process
else:
try:
server.start(env=start_env, text=True, blocking=True)
except KeyboardInterrupt:
_exception = True
except Exception as err:
termui.error(f'Error caught while running LLM Server:\n{err}')
_exception = True
raise
else:
next_step(adapter_map, _exception)
# NOTE: Return the configuration for telemetry purposes.
return config
@cli.command(context_settings=termui.CONTEXT_SETTINGS, name='start-grpc', short_help='Start a gRPC LLMServer for any supported LLM.')
@click.argument('model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True)
@click.option('--model-id',
'deprecated_model_id',
type=click.STRING,
default=None,
hidden=True,
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
help='Deprecated. Use positional argument instead.')
@start_decorator(serve_grpc=True)
def start_grpc_command(model_id: str, server_timeout: int, model_version: str | None, system_message: str | None, prompt_template_file: t.IO[t.Any] | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...], quantize: LiteralQuantise | None, backend: LiteralBackend | None,
serialisation: LiteralSerialisation | None, cors: bool, adapter_id: str | None, return_process: bool, deprecated_model_id: str | None,
**attrs: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
"""Start any LLM as a gRPC server.
\b
```bash
$ openllm start-grpc <model_name> --<options> ...
$ openllm start-grpc <model_id> --<options> ...
```
"""
termui.warning('Continuous batching is not yet supported with gRPC. If you want to use continuous batching with gRPC, feel free to open a GitHub issue about your use case.\n')
if model_id in openllm.CONFIG_MAPPING:
_model_name = model_id
if deprecated_model_id is not None: model_id = deprecated_model_id
else: model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
termui.warning(
f"Passing 'openllm start-grpc {_model_name}{'' if deprecated_model_id is None else ' --model-id ' + deprecated_model_id}' is deprecated and will be remove in a future version. Use 'openllm start-grpc {model_id}' instead."
)
_start_mapping = {
'start': {
key: start_command_factory(start_command, key, _context_settings=termui.CONTEXT_SETTINGS) for key in CONFIG_MAPPING
},
'start-grpc': {
key: start_command_factory(start_grpc_command, key, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=True) for key in CONFIG_MAPPING
}
}
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
from ..serialisation.transformers.weights import has_safetensors_weights
serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'))
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.warning(
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation."
)
termui.warning(f"Make sure to check out '{model_id}' repository to see if the weights is in '{serialisation}' format if unsure.")
llm = openllm.LLM[t.Any, t.Any](model_id=model_id,
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
backend=backend,
adapter_map=adapter_map,
quantize=quantize,
serialisation=serialisation,
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'))
backend_warning(llm.__llm_backend__)
config, server_attrs = llm.config.model_validate_click(**attrs)
server_timeout = first_not_none(server_timeout, default=config['timeout'])
server_attrs.update({'working_dir': os.path.dirname(os.path.dirname(__file__)), 'timeout': server_timeout})
server_attrs['grpc_protocol_version'] = 'v1'
# XXX: currently, there are no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development')
server_attrs.setdefault('production', not development)
wpr = first_not_none(workers_per_resource, default=config['workers_per_resource'])
if isinstance(wpr, str):
if wpr == 'round_robin': wpr = 1.0
elif wpr == 'conserved':
if device and openllm.utils.device_count() == 0:
termui.echo('--device will have no effect as there are no GPUs available', fg='yellow')
wpr = 1.0
else:
available_gpu = len(device) if device else openllm.utils.device_count()
wpr = 1.0 if available_gpu == 0 else float(1 / available_gpu)
else:
wpr = float(wpr)
elif isinstance(wpr, int):
wpr = float(wpr)
requirements = llm.config['requirements']
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
if len(missing_requirements) > 0:
termui.warning(f'Make sure to have the following dependencies available: {missing_requirements}')
start_env = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy())
start_env.update({
'OPENLLM_MODEL_ID': model_id,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': serialisation,
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_CONFIG': llm.config.model_dump_json().decode(),
})
if llm._quantise: start_env['OPENLLM_QUANTIZE'] = str(llm._quantise)
if system_message: start_env['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: start_env['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
server = bentoml.GrpcServer('_service:svc', **server_attrs)
openllm.utils.analytics.track_start_init(llm.config)
def next_step(adapter_map: DictStrAny | None, caught_exception: bool = False) -> None:
if caught_exception: return
cmd_name = f'openllm build {model_id}'
if llm._quantise: cmd_name += f' --quantize {llm._quantise}'
cmd_name += f' --serialization {serialisation}'
if adapter_map is not None:
cmd_name += ' ' + ' '.join([f'--adapter-id {s}' for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]])
if not openllm.utils.get_quiet_mode():
termui.info(f"\n🚀 Next step: run '{cmd_name}' to create a BentoLLM for '{model_id}'")
_exception = False
if return_process:
server.start(env=start_env, text=True)
if server.process is None: raise click.ClickException('Failed to start the server.')
return server.process
else:
try:
server.start(env=start_env, text=True, blocking=True)
except KeyboardInterrupt:
_exception = True
except Exception as err:
termui.error(f'Error caught while running LLM Server:\n{err}')
_exception = True
raise
else:
next_step(adapter_map, _exception)
# NOTE: Return the configuration for telemetry purposes.
return config
class ItemState(enum.Enum):
NOT_FOUND = 'NOT_FOUND'
EXISTS = 'EXISTS'
OVERWRITE = 'OVERWRITE'
class ImportModelOutput(t.TypedDict):
state: ItemState
backend: LiteralBackend
tag: str
@cli.command(name='import', aliases=['download'])
@model_name_argument
@click.argument('model_id', type=click.STRING, default=None, metavar='Optional[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=False)
@click.argument('converter', envvar='CONVERTER', type=click.STRING, default=None, required=False, metavar=None)
@click.argument('model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True)
@click.option('--model-id',
'deprecated_model_id',
type=click.STRING,
default=None,
hidden=True,
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
help='Deprecated. Use positional argument instead.')
@model_version_option
@output_option
@quantize_option
@machine_option
@backend_option
@quantize_option
@serialisation_option
def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, machine: bool, backend: LiteralBackend | None,
quantize: LiteralQuantise | None, serialisation: LiteralSerialisation | None) -> bentoml.Model:
def import_command(model_id: str, deprecated_model_id: str | None, model_version: str | None, backend: LiteralBackend | None, quantize: LiteralQuantise | None,
serialisation: LiteralSerialisation | None) -> ImportModelOutput:
"""Setup LLM interactively.
It accepts two positional arguments: `model_name` and `model_id`. The first name determines
the model type to download, and the second one is the optional model id to download.
\b
This `model_id` can be either pretrained model id that you can get from HuggingFace Hub, or
a custom model path from your custom pretrained model. Note that the custom model path should
@@ -351,8 +580,9 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
and `transformers.PreTrainedTokenizer` objects.
\b
Note: This is useful for development and setup for fine-tune.
This will be automatically called when `ensure_available=True` in `openllm.LLM.for_model`
Note that if `--serialisation` is not defined, then we will try to infer serialisation from HuggingFace Hub.
If the model id contains safetensors weights, then we will use `safetensors` serialisation. Otherwise, we will
fall back to `legacy` '.bin' (otherwise known as pickle) serialisation.
\b
``--model-version`` is an optional option to save the model. Note that
@@ -362,7 +592,7 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
\b
```bash
$ openllm import opt facebook/opt-2.7b
$ openllm import mistralai/Mistral-7B-v0.1
```
\b
@@ -370,41 +600,58 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
> only use this option if you want the weight to be quantized by default. Note that OpenLLM also
> supports on-demand quantisation during initial startup.
"""
llm_config = openllm.AutoConfig.for_model(model_name)
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
env = EnvVarMixin(model_name, model_id=model_id, quantize=quantize)
model_id = first_not_none(model_id, env['model_id_value'], default=llm_config['default_id'])
backend = first_not_none(backend, env['backend_value'], default='vllm' if is_vllm_available() else 'pt')
llm = openllm.LLM[t.Any, t.Any](model_id=model_id, llm_config=llm_config, revision=model_version, quantize=env['quantize_value'], serialisation=_serialisation, backend=backend)
_previously_saved = False
from ..serialisation.transformers.weights import has_safetensors_weights
if model_id in openllm.CONFIG_MAPPING:
_model_name = model_id
if deprecated_model_id is not None: model_id = deprecated_model_id
else: model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
termui.echo(
f"Passing 'openllm import {_model_name}{'' if deprecated_model_id is None else ' --model-id ' + deprecated_model_id}' is deprecated and will be remove in a future version. Use 'openllm import {model_id}' instead.",
fg='yellow')
llm = openllm.LLM[t.Any, t.Any](model_id=model_id,
model_version=model_version,
quantize=quantize,
backend=backend,
serialisation=t.cast(LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy')))
backend_warning(llm.__llm_backend__)
state = ItemState.NOT_FOUND
try:
_ref = openllm.serialisation.get(llm)
_previously_saved = True
except openllm.exceptions.OpenLLMException:
if not machine and output == 'pretty':
msg = f"'{model_name}' with model_id='{model_id}' does not exists in local store for backend {llm.__llm_backend__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
termui.echo(msg, fg='yellow', nl=True)
_ref = openllm.serialisation.get(llm, auto_import=True)
if backend == 'pt' and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
if machine: return _ref
elif output == 'pretty':
if _previously_saved: termui.echo(f"{model_name} with 'model_id={model_id}' is already setup for backend '{backend}': {_ref.tag!s}", nl=True, fg='yellow')
else: termui.echo(f'Saved model: {_ref.tag}')
elif output == 'json': termui.echo(orjson.dumps({'previously_setup': _previously_saved, 'backend': backend, 'tag': str(_ref.tag)}, option=orjson.OPT_INDENT_2).decode())
else: termui.echo(_ref.tag)
return _ref
model = bentoml.models.get(llm.tag)
state = ItemState.EXISTS
except bentoml.exceptions.NotFound:
model = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code)
if llm.__llm_backend__ == 'pt' and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
response = ImportModelOutput(state=state, backend=llm.__llm_backend__, tag=str(model.tag))
termui.echo(orjson.dumps(response).decode(), fg='white')
return response
class DeploymentInstruction(t.TypedDict):
type: t.Literal['container', 'bentocloud']
content: str
class BuildBentoOutput(t.TypedDict):
state: ItemState
tag: str
backend: LiteralBackend
instructions: t.List[DeploymentInstruction]
@cli.command(context_settings={'token_normalize_func': inflection.underscore})
@model_name_argument
@model_id_option
@output_option
@machine_option
@click.argument('model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True)
@click.option('--model-id',
'deprecated_model_id',
type=click.STRING,
default=None,
hidden=True,
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
help='Deprecated. Use positional argument instead.')
@backend_option
@system_message_option
@prompt_template_file_option
@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the model revision.')
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
@workers_per_resource_option(factory=click, build=True)
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Optimisation options') # type: ignore[misc]
@quantize_option(factory=cog.optgroup, build=True)
@click.option('--enable-features',
@@ -434,17 +681,18 @@ def import_command(model_name: str, model_id: str | None, converter: str | None,
help="Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'.")
@cog.optgroup.option('--push', default=False, is_flag=True, type=click.BOOL, help="Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.")
@click.option('--force-push', default=False, is_flag=True, type=click.BOOL, help='Whether to force push.')
@machine_option
@click.pass_context
def build_command(ctx: click.Context, /, model_name: str, model_id: str | None, bento_version: str | None, overwrite: bool, output: LiteralOutput, quantize: LiteralQuantise | None,
enable_features: tuple[str, ...] | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend | None,
system_message: str | None, prompt_template_file: t.IO[t.Any] | None, machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool,
push: bool, serialisation: LiteralSerialisation | None, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy,
force_push: bool, **attrs: t.Any) -> bentoml.Bento:
"""Package a given models into a Bento.
def build_command(ctx: click.Context, /, model_id: str, deprecated_model_id: str | None, bento_version: str | None, overwrite: bool, quantize: LiteralQuantise | None, machine: bool,
enable_features: tuple[str, ...] | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend | None, system_message: str | None,
prompt_template_file: t.IO[t.Any] | None, model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool, push: bool,
serialisation: LiteralSerialisation | None, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy, force_push: bool,
**_: t.Any) -> BuildBentoOutput:
"""Package a given models into a BentoLLM.
\b
```bash
$ openllm build flan-t5 --model-id google/flan-t5-large
$ openllm build google/flan-t5-large
```
\b
@@ -457,47 +705,47 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
> To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
> target also use the same Python version and architecture as build machine.
"""
if machine: output = 'porcelain'
from .._llm import normalise_model_name
from ..serialisation.transformers.weights import has_safetensors_weights
if model_id in openllm.CONFIG_MAPPING:
_model_name = model_id
if deprecated_model_id is not None: model_id = deprecated_model_id
else: model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
termui.echo(
f"Passing 'openllm build {_model_name}{'' if deprecated_model_id is None else ' --model-id ' + deprecated_model_id}' is deprecated and will be remove in a future version. Use 'openllm build {model_id}' instead.",
fg='yellow')
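    # Illustrative mapping (model names are assumptions): 'openllm build opt --model-id facebook/opt-1.3b'
    # now resolves to 'openllm build facebook/opt-1.3b', while a bare 'openllm build opt' falls back to
    # that architecture's default_id.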
if enable_features: enable_features = tuple(itertools.chain.from_iterable((s.split(',') for s in enable_features)))
_previously_built = False
state = ItemState.NOT_FOUND
llm_config = openllm.AutoConfig.for_model(model_name)
_serialisation = t.cast(LiteralSerialisation, first_not_none(serialisation, default=llm_config['serialisation']))
env = EnvVarMixin(model_name, backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'), model_id=model_id or llm_config['default_id'], quantize=quantize)
prompt_template: str | None = prompt_template_file.read() if prompt_template_file is not None else None
prompt_template = prompt_template_file.read() if prompt_template_file is not None else None
llm = openllm.LLM[t.Any, t.Any](model_id=model_id,
model_version=model_version,
prompt_template=prompt_template,
system_message=system_message,
backend=backend,
quantize=quantize,
serialisation=t.cast(LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy')))
backend_warning(llm.__llm_backend__)
os.environ.update({'OPENLLM_BACKEND': llm.__llm_backend__, 'OPENLLM_SERIALIZATION': llm._serialisation, 'OPENLLM_MODEL_ID': llm.model_id})
if llm._quantise: os.environ['OPENLLM_QUANTIZE'] = str(llm._quantise)
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
try:
os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), 'OPENLLM_SERIALIZATION': _serialisation, env.backend: env['backend_value']})
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
if env['backend_value']: os.environ[env.backend] = str(env['backend_value'])
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
llm = openllm.LLM[t.Any, t.Any](model_id=env['model_id_value'] or llm_config['default_id'],
revision=model_version,
prompt_template=prompt_template,
system_message=system_message,
llm_config=llm_config,
backend=env['backend_value'],
quantize=env['quantize_value'],
serialisation=_serialisation,
**attrs)
llm.save_pretrained() # ensure_available = True
assert llm.bentomodel # HACK: call it here to patch correct tag with revision and everything
# FIX: This is a patch for _service_vars injection
if 'OPENLLM_MODEL_ID' not in os.environ: os.environ['OPENLLM_MODEL_ID'] = llm.model_id
if 'OPENLLM_ADAPTER_MAP' not in os.environ: os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(None).decode()
labels = dict(llm.identifying_params)
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
workers_per_resource = first_not_none(workers_per_resource, default=llm_config['workers_per_resource'])
labels.update({'_type': llm.llm_type, '_framework': llm.__llm_backend__})
with fs.open_fs(f"temp://llm_{llm_config['model_name']}") as llm_fs:
with fs.open_fs(f'temp://llm_{normalise_model_name(model_id)}') as llm_fs:
dockerfile_template_path = None
if dockerfile_template:
with dockerfile_template:
@@ -505,8 +753,8 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
dockerfile_template_path = llm_fs.getsyspath('/Dockerfile.template')
adapter_map: dict[str, str] | None = None
  if adapter_id and not build_ctx: ctx.fail("'build_ctx' is required when '--adapter-id' is passed.")
if adapter_id:
    if not build_ctx: ctx.fail("'build_ctx' is required when '--adapter-id' is passed.")
adapter_map = {}
for v in adapter_id:
_adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
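      # e.g. (illustrative): 'tloen/alpaca-lora-7b:french' splits into _adapter_id='tloen/alpaca-lora-7b'
      # and adapter_name=['french']; a value without ':' yields an empty adapter_name list.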
@@ -531,15 +779,14 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
try:
bento = bentoml.get(bento_tag)
if overwrite:
if output == 'pretty': termui.echo(f'Overwriting existing Bento {bento_tag}', fg='yellow')
bentoml.delete(bento_tag)
state = ItemState.OVERWRITE
raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {bento_tag}') from None
_previously_built = True
state = ItemState.EXISTS
except bentoml.exceptions.NotFound:
bento = bundle.create_bento(bento_tag,
llm_fs,
llm,
workers_per_resource=workers_per_resource,
adapter_map=adapter_map,
quantize=quantize,
extra_dependencies=enable_features,
@@ -547,23 +794,26 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
container_registry=container_registry,
container_version_strategy=container_version_strategy)
except Exception as err:
raise err from None
traceback.print_exc()
raise click.ClickException('Exception caught while building BentoLLM:\n' + str(err)) from err
if machine: termui.echo(f'__tag__:{bento.tag}', fg='white')
elif output == 'pretty':
if not get_quiet_mode() and (not push or not containerize):
termui.echo('\n' + OPENLLM_FIGLET, fg='white')
if not _previously_built: termui.echo(f'Successfully built {bento}.', fg='green')
elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", fg='yellow')
termui.echo('📖 Next steps:\n\n' + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" +
f"* Containerize your Bento with 'bentoml containerize':\n\t$ bentoml containerize {bento.tag} --opt progress=plain\n\n" +
"\tTip: To enable additional BentoML features for 'containerize', use '--enable-features=FEATURE[,FEATURE]' [see 'bentoml containerize -h' for more advanced usage]\n",
fg='blue',
)
elif output == 'json':
termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
else:
termui.echo(bento.tag)
response = BuildBentoOutput(state=state,
tag=str(bento_tag),
backend=llm.__llm_backend__,
instructions=[
DeploymentInstruction(type='bentocloud', content=f"Push to BentoCloud with 'bentoml push': `bentoml push {bento_tag}`"),
                                  DeploymentInstruction(type='container', content=f"Containerize BentoLLM with 'bentoml containerize': `bentoml containerize {bento_tag} --opt progress=plain`")
])
if machine: termui.echo(f'__object__:{orjson.dumps(response).decode()}\n\n', fg='white')
elif not get_quiet_mode() and (not push or not containerize):
if not overwrite: termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n")
elif state != ItemState.EXISTS: termui.info(f"Successfully built Bento '{bento.tag}'.\n")
if not get_debug_mode():
termui.echo(OPENLLM_FIGLET)
termui.echo('\n📖 Next steps:\n\n', nl=False)
for instruction in response['instructions']:
termui.echo(f"* {instruction['content']}\n", nl=False)
if push: BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push)
elif containerize:
@@ -576,86 +826,50 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
bentoml.container.build(bento.tag, backend=container_backend, features=('grpc', 'io'))
except Exception as err:
raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err
return bento
response.pop('instructions')
if get_debug_mode(): termui.echo('\n' + orjson.dumps(response).decode(), fg=None)
return response
class ModelItem(t.TypedDict):
architecture: str
example_id: str
supported_backends: t.Tuple[LiteralBackend, ...]
installation: str
items: NotRequired[t.List[str]]
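# Illustrative entry (field values are hypothetical, not taken from this diff):
#   ModelItem(architecture='LlamaForCausalLM', example_id='meta-llama/Llama-2-7b-hf',
#             supported_backends=('pt', 'vllm'), installation='pip install "openllm[llama]"', items=[])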
@cli.command()
@output_option
@click.option('--show-available', is_flag=True, default=False, help="Show available models in local store (mutually exclusive with '-o porcelain').")
@machine_option
@click.pass_context
def models_command(ctx: click.Context, output: LiteralOutput, show_available: bool, machine: bool) -> DictStrAny | None:
def models_command(show_available: bool) -> dict[t.LiteralString, ModelItem]:
"""List all supported models.
\b
> NOTE: '--show-available' and '-o porcelain' are mutually exclusive.
\b
```bash
openllm models --show-available
```
"""
from .._llm import normalise_model_name
models = tuple(inflection.dasherize(key) for key in CONFIG_MAPPING.keys())
if output == 'porcelain':
if show_available: raise click.BadOptionUsage('--show-available', "Cannot use '--show-available' with '-o porcelain' (mutually exclusive).")
termui.echo('\n'.join(models), fg='white')
else:
json_data: dict[str, dict[t.Literal['architecture', 'model_id', 'url', 'installation', 'cpu', 'gpu', 'backend'], t.Any] | t.Any] = {}
converted: list[str] = []
for m in models:
config = openllm.AutoConfig.for_model(m)
json_data[m] = {
'architecture': config['architecture'],
'model_id': config['model_ids'],
'backend': config['backend'],
'installation': f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config['requirements'] else 'openllm',
}
converted.extend([normalise_model_name(i) for i in config['model_ids']])
ids_in_local_store = {
k: [
i for i in bentoml.models.list() if 'framework' in i.info.labels and i.info.labels['framework'] == 'openllm' and 'model_name' in i.info.labels and i.info.labels['model_name'] == k
] for k in json_data.keys()
}
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
local_models: DictStrAny | None = None
if show_available:
local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()}
if machine:
if show_available: json_data['local'] = local_models
return json_data
elif output == 'pretty':
import tabulate
tabulate.PRESERVE_WHITESPACE = True
# llm, architecture, url, model_id, installation, backend
data: list[str | tuple[str, str, list[str], str, tuple[LiteralBackend, ...]]] = []
for m, v in json_data.items():
data.extend([(m, v['architecture'], v['model_id'], v['installation'], v['backend'])])
column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4)]
table = tabulate.tabulate(data, tablefmt='fancy_grid', headers=['LLM', 'Architecture', 'Models Id', 'Installation', 'Runtime'], maxcolwidths=column_widths)
termui.echo(table, fg='white')
if show_available:
if len(ids_in_local_store) == 0:
termui.echo('No models available locally.')
ctx.exit(0)
termui.echo('The following are available in local store:', fg='magenta')
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
if show_available: json_data['local'] = local_models
termui.echo(orjson.dumps(json_data, option=orjson.OPT_INDENT_2,).decode(), fg='white')
ctx.exit(0)
result: dict[t.LiteralString, ModelItem] = {
m: ModelItem(architecture=config.__openllm_architecture__,
example_id=random.choice(config.__openllm_model_ids__),
supported_backends=config.__openllm_backend__,
installation='pip install ' + (f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config.__openllm_requirements__ else 'openllm'),
items=[] if not show_available else [
str(md.tag)
for md in bentoml.models.list()
if 'framework' in md.info.labels and md.info.labels['framework'] == 'openllm' and 'model_name' in md.info.labels and md.info.labels['model_name'] == m
]) for m, config in CONFIG_MAPPING.items()
}
termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg=None)
return result
@cli.command()
@model_name_argument(required=False)
@click.option('-y', '--yes', '--assume-yes', is_flag=True, help='Skip confirmation when deleting a specific model')
@click.option('--include-bentos/--no-include-bentos', is_flag=True, default=False, help='Whether to also include pruning bentos.')
@inject
def prune_command(model_name: str | None,
@click.pass_context
def prune_command(ctx: click.Context,
model_name: str | None,
yes: bool,
include_bentos: bool,
model_store: ModelStore = Provide[BentoMLContainer.model_store],
@@ -679,7 +893,8 @@ def prune_command(model_name: str | None,
else: delete_confirmed = click.confirm(f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?")
if delete_confirmed:
store.delete(store_item.tag)
termui.echo(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.", fg='yellow')
termui.warning(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.")
ctx.exit(0)
def parsing_instruction_callback(ctx: click.Context, param: click.Parameter, value: list[str] | str | None) -> tuple[str, bool | str] | list[str] | str | None:
if value is None:
@@ -700,23 +915,26 @@ def parsing_instruction_callback(ctx: click.Context, param: click.Parameter, val
else:
raise click.BadParameter(f'Invalid option format: {value}')
def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal['json', 'porcelain', 'pretty'] = 'pretty') -> t.Callable[[FC], FC]:
def shared_client_options(f: _AnyCallable | None = None) -> t.Callable[[FC], FC]:
options = [
      click.option('--endpoint', type=click.STRING, help='OpenLLM Server endpoint, i.e. http://localhost:3000', envvar='OPENLLM_ENDPOINT', default='http://localhost:3000',
),
click.option('--endpoint',
type=click.STRING,
                   help='OpenLLM Server endpoint, i.e. http://localhost:3000',
envvar='OPENLLM_ENDPOINT',
show_envvar=True,
show_default=True,
default='http://localhost:3000'),
click.option('--timeout', type=click.INT, default=30, help='Default server timeout', show_default=True),
output_option(default_value=output_value),
]
return compose(*options)(f) if f is not None else compose(*options)
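  # Illustrative: both `@shared_client_options` (bare) and `shared_client_options()` as a factory work,
  # since a composed decorator is returned whenever `f` is None.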
@cli.command()
@cli.command(hidden=True)
@click.argument('task', type=click.STRING, metavar='TASK')
@shared_client_options
@click.option('--agent', type=click.Choice(['hf']), default='hf', help='Whether to interact with Agents from the given server endpoint.', show_default=True)
@click.option('--remote', is_flag=True, default=False, help='Whether or not to use remote tools (inference endpoints) instead of local ones.', show_default=True)
@click.option('--opt',
help="Define prompt options. "
"(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)",
help="Define prompt options. (format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)",
required=False,
multiple=True,
callback=opt_callback,
@@ -750,8 +968,8 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
# raise click.BadOptionUsage('agent', f'Unknown agent type {agent}')
@cli.command()
@shared_client_options(output_value='porcelain')
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True)
@shared_client_options
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True, hidden=True)
@click.option('--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.')
@click.argument('prompt', type=click.STRING)
@click.option('--sampling-params',
@@ -761,45 +979,28 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
callback=opt_callback,
metavar='ARG=VALUE[,ARG=VALUE]')
@click.pass_context
def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, stream: bool, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny,
**attrs: t.Any) -> None:
  '''Ask an LLM interactively from a terminal.
def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, stream: bool, server_type: t.Literal['http', 'grpc'], _memoized: DictStrAny, **_: t.Any) -> None:
  '''Query an LLM interactively from a terminal.
\b
```bash
$ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?"
```
'''
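  # Illustrative invocations (flags assumed from the options declared above):
  #   openllm query --endpoint http://localhost:3000 "What is the meaning of life?"
  #   openllm query --no-stream --sampling-params temperature=0.75 "Summarise this repository"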
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
if server_type == 'grpc': raise click.ClickException("'grpc' is currently disabled.")
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
# TODO: grpc support
client = openllm.client.HTTPClient(address=endpoint, timeout=timeout)
input_fg, generated_fg = 'magenta', 'cyan'
if output != 'porcelain':
termui.echo('==Input==\n', fg='white')
termui.echo(f'{prompt}', fg=input_fg)
if stream:
stream_res: t.Iterator[StreamingResponse] = client.generate_stream(prompt, **{**client._config(), **_memoized})
if output == 'pretty':
termui.echo('\n\n==Responses==\n', fg='white')
for it in stream_res:
termui.echo(it.text, fg=generated_fg, nl=False)
elif output == 'json':
for it in stream_res:
termui.echo(orjson.dumps(converter.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
for it in stream_res:
termui.echo(it.text, fg=generated_fg, nl=False)
stream_res: t.Iterator[StreamingResponse] = client.generate_stream(prompt, **_memoized)
termui.echo(prompt, fg=input_fg, nl=False)
for it in stream_res:
termui.echo(it.text, fg=generated_fg, nl=False)
else:
res: Response = client.generate(prompt, **{**client._config(), **_memoized})
if output == 'pretty':
termui.echo('\n\n==Responses==\n', fg='white')
termui.echo(res.outputs[0].text, fg=generated_fg)
elif output == 'json':
termui.echo(orjson.dumps(converter.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
termui.echo(res.outputs[0].text, fg='white')
termui.echo(prompt, fg=input_fg, nl=False)
termui.echo(client.generate(prompt, **_memoized).outputs[0].text, fg=generated_fg, nl=False)
ctx.exit(0)
@cli.group(cls=Extensions, hidden=True, name='extension')

View File

@@ -1,4 +1,6 @@
from __future__ import annotations
import logging
import traceback
import typing as t
import click
@@ -8,21 +10,18 @@ import orjson
from bentoml_cli.utils import opt_callback
import openllm
import openllm_core
from openllm.cli import termui
from openllm.cli._factory import machine_option
from openllm.cli._factory import model_complete_envvar
from openllm.cli._factory import output_option
from openllm_core.prompts import process_prompt
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
logger = logging.getLogger(__name__)
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
@click.argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), shell_complete=model_complete_envvar)
@click.argument('prompt', type=click.STRING)
@output_option
@click.option('--format', type=click.STRING, default=None)
@machine_option
@click.option('--opt',
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
required=False,
@@ -30,9 +29,9 @@ LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
callback=opt_callback,
metavar='ARG=VALUE[,ARG=VALUE]')
@click.pass_context
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, output: LiteralOutput, machine: bool, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
"""Get the default prompt used by OpenLLM."""
module = openllm.utils.EnvVarMixin(model_name).module
module = getattr(openllm_core.config, f'configuration_{model_name}')
_memoized = {k: v[0] for k, v in _memoized.items() if v}
try:
template = getattr(module, 'DEFAULT_PROMPT_TEMPLATE', None)
@@ -54,15 +53,11 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None,
try:
# backward-compatible. TO BE REMOVED once every model has default system message and prompt template.
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
except RuntimeError:
except RuntimeError as err:
logger.debug('Exception caught while formatting prompt: %s', err)
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(prompt, prompt_template=_prompt_template)[0]
if machine: return repr(fully_formatted)
elif output == 'porcelain': termui.echo(repr(fully_formatted), fg='white')
elif output == 'json':
termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
else:
termui.echo(f'== Prompt for {model_name} ==\n', fg='magenta')
termui.echo(fully_formatted, fg='white')
except AttributeError:
raise click.ClickException(f'Failed to determine a default prompt template for {model_name}.') from None
termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
except Exception as err:
traceback.print_exc()
raise click.ClickException(f'Failed to determine a default prompt template for {model_name}.') from err
ctx.exit(0)

View File

@@ -9,13 +9,10 @@ import openllm
from bentoml._internal.utils import human_readable_size
from openllm.cli import termui
from openllm.cli._factory import LiteralOutput
from openllm.cli._factory import output_option
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
@output_option(default_value='json')
@click.pass_context
def cli(ctx: click.Context, output: LiteralOutput) -> None:
def cli(ctx: click.Context) -> None:
"""List available bentos built by OpenLLM."""
mapping = {
k: [{
@@ -29,13 +26,5 @@ def cli(ctx: click.Context, output: LiteralOutput) -> None:
k in i.info.labels for k in {'start_name', 'bundler'})) if b.info.labels['start_name'] == k] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
}
mapping = {k: v for k, v in mapping.items() if v}
if output == 'pretty':
import tabulate
tabulate.PRESERVE_WHITESPACE = True
termui.echo(tabulate.tabulate([(k, i['tag'], i['size'], [_['tag'] for _ in i['models']]) for k, v in mapping.items() for i in v],
tablefmt='fancy_grid',
headers=['LLM', 'Tag', 'Size', 'Models']),
fg='white')
else:
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
ctx.exit(0)

View File

@@ -10,18 +10,15 @@ import openllm
from bentoml._internal.utils import human_readable_size
from openllm.cli import termui
from openllm.cli._factory import LiteralOutput
from openllm.cli._factory import model_complete_envvar
from openllm.cli._factory import model_name_argument
from openllm.cli._factory import output_option
if t.TYPE_CHECKING:
from openllm_core._typing_compat import DictStrAny
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
@model_name_argument(required=False, shell_complete=model_complete_envvar)
@output_option(default_value='json')
def cli(model_name: str | None, output: LiteralOutput) -> DictStrAny:
def cli(model_name: str | None) -> DictStrAny:
"""This is equivalent to openllm models --show-available less the nice table."""
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
ids_in_local_store = {
@@ -32,10 +29,5 @@ def cli(model_name: str | None, output: LiteralOutput) -> DictStrAny:
ids_in_local_store = {k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)] for k, v in ids_in_local_store.items()}
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
local_models = {k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()}
if output == 'pretty':
import tabulate
tabulate.PRESERVE_WHITESPACE = True
termui.echo(tabulate.tabulate([(k, i['tag'], i['size']) for k, v in local_models.items() for i in v], tablefmt='fancy_grid', headers=['LLM', 'Tag', 'Size']), fg='white')
else:
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
return local_models

View File

@@ -1,20 +1,65 @@
from __future__ import annotations
import enum
import functools
import logging
import os
import typing as t
import click
import inflection
import orjson
import openllm
from openllm_core._typing_compat import DictStrAny
from openllm_core.utils import get_debug_mode
from openllm_core.utils import get_quiet_mode
if t.TYPE_CHECKING:
from openllm_core._typing_compat import DictStrAny
logger = logging.getLogger('openllm')
def echo(text: t.Any, fg: str = 'green', _with_style: bool = True, **attrs: t.Any) -> None:
attrs['fg'] = fg if not openllm.utils.get_debug_mode() else None
if not openllm.utils.get_quiet_mode():
t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(text, **attrs)
class Level(enum.IntEnum):
NOTSET = logging.DEBUG
DEBUG = logging.DEBUG
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR
CRITICAL = logging.CRITICAL
@property
def color(self) -> str | None:
return {Level.NOTSET: None, Level.DEBUG: 'cyan', Level.INFO: 'green', Level.WARNING: 'yellow', Level.ERROR: 'red', Level.CRITICAL: 'red'}[self]
class JsonLog(t.TypedDict):
log_level: Level
content: str
def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
def caller(text: str) -> None:
if get_debug_mode(): logger.log(level.value, text)
else: echo(JsonLog(log_level=level, content=content), json=True, fg=fg)
caller(orjson.dumps(JsonLog(log_level=level, content=content)).decode())
warning = functools.partial(log, level=Level.WARNING)
error = functools.partial(log, level=Level.ERROR)
critical = functools.partial(log, level=Level.CRITICAL)
debug = functools.partial(log, level=Level.DEBUG)
info = functools.partial(log, level=Level.INFO)
notset = functools.partial(log, level=Level.NOTSET)
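# Illustrative behaviour (not part of this diff): warning('Bento already exists') routes through log();
# in normal mode echo() prints just the content coloured per Level (yellow for WARNING), while in debug
# mode the serialised JSON payload is handed to the stdlib logger instead.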
def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
if json and not isinstance(text, dict): raise TypeError('text must be a dict')
if json:
if 'content' in text and 'log_level' in text:
content = t.cast(DictStrAny, text)['content']
fg = t.cast(Level, text['log_level']).color
else:
content = orjson.dumps(text).decode()
fg = Level.INFO.color if not get_debug_mode() else Level.DEBUG.color
else:
content = t.cast(str, text)
attrs['fg'] = fg if not get_debug_mode() else None
if not get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(content, **attrs)
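# Illustrative: echo({'tag': 'bento:abc123'}, json=True) dumps the dict as JSON (green outside debug
# mode, uncoloured inside it), while echo('plain text', fg='magenta') behaves like click.secho.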
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore}
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS']
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS', 'log', 'warning', 'error', 'critical', 'debug', 'info', 'Level']