mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-09 18:48:09 -04:00
refactor: focus (#730)
* perf: remove based images Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: update changelog Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: move dockerifle to run on release only Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: cleanup unused types Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,31 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import click
|
||||
import click_option_group as cog
|
||||
import fs
|
||||
import fs.copy
|
||||
import fs.errors
|
||||
import inflection
|
||||
import orjson
|
||||
import enum, functools, inspect, itertools, logging, os, platform, random, subprocess, threading, time, traceback, typing as t
|
||||
import attr, click, fs, inflection, bentoml, openllm, orjson, fs.copy, fs.errors, click_option_group as cog
|
||||
from bentoml_cli.utils import BentoMLCommandGroup, opt_callback
|
||||
from simple_di import Provide, inject
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.cloud.config import CloudClientConfig
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
@@ -70,7 +47,6 @@ from ._factory import (
|
||||
FC,
|
||||
_AnyCallable,
|
||||
backend_option,
|
||||
container_registry_option,
|
||||
dtype_option,
|
||||
machine_option,
|
||||
model_name_argument,
|
||||
@@ -88,7 +64,6 @@ if t.TYPE_CHECKING:
|
||||
from bentoml._internal.container import DefaultBuilder
|
||||
from openllm_client._schemas import StreamingResponse
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
else:
|
||||
torch = LazyLoader('torch', globals(), 'torch')
|
||||
|
||||
@@ -248,7 +223,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
return super().get_command(ctx, cmd_name)
|
||||
|
||||
def list_commands(self, ctx: click.Context) -> list[str]:
|
||||
return super().list_commands(ctx) + t.cast('Extensions', extension_command).list_commands(ctx)
|
||||
return super().list_commands(ctx) + extension_command.list_commands(ctx)
|
||||
|
||||
def command(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Callable[..., t.Any]], click.Command]:
|
||||
"""Override the default 'cli.command' with supports for aliases for given command, and it wraps the implementation with common parameters."""
|
||||
@@ -371,7 +346,7 @@ def cli() -> None:
|
||||
default=None,
|
||||
help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.',
|
||||
)
|
||||
@start_decorator(serve_grpc=False)
|
||||
@start_decorator
|
||||
def start_command(
|
||||
model_id: str,
|
||||
server_timeout: int,
|
||||
@@ -396,26 +371,21 @@ def start_command(
|
||||
$ openllm <start|start-http> <model_id> --<options> ...
|
||||
```
|
||||
'''
|
||||
if backend == 'pt': logger.warning('PyTorch backend is deprecated and will be removed in future releases. Make sure to use vLLM instead.')
|
||||
if model_id in openllm.CONFIG_MAPPING:
|
||||
_model_name = model_id
|
||||
if deprecated_model_id is not None:
|
||||
model_id = deprecated_model_id
|
||||
else:
|
||||
model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
|
||||
termui.warning(
|
||||
f"Passing 'openllm start {_model_name}{'' if deprecated_model_id is None else ' --model-id ' + deprecated_model_id}' is deprecated and will be remove in a future version. Use 'openllm start {model_id}' instead."
|
||||
)
|
||||
logger.warning("Passing 'openllm start %s%s' is deprecated and will be remove in a future version. Use 'openllm start %s' instead.", _model_name, '' if deprecated_model_id is None else f' --model-id {deprecated_model_id}', model_id)
|
||||
|
||||
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
|
||||
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
serialisation = t.cast(
|
||||
LiteralSerialisation,
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
)
|
||||
serialisation = first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy')
|
||||
|
||||
if serialisation == 'safetensors' and quantize is not None:
|
||||
logger.warning("'--quantize=%s' might not work with 'safetensors' serialisation format.", quantize)
|
||||
logger.warning(
|
||||
@@ -449,8 +419,7 @@ def start_command(
|
||||
config, server_attrs = llm.config.model_validate_click(**attrs)
|
||||
server_timeout = first_not_none(server_timeout, default=config['timeout'])
|
||||
server_attrs.update({'working_dir': pkg.source_locations('openllm'), 'timeout': server_timeout})
|
||||
# XXX: currently, theres no development args in bentoml.Server. To be fixed upstream.
|
||||
development = server_attrs.pop('development')
|
||||
development = server_attrs.pop('development') # XXX: currently, theres no development args in bentoml.Server. To be fixed upstream.
|
||||
server_attrs.setdefault('production', not development)
|
||||
|
||||
start_env = process_environ(
|
||||
@@ -479,145 +448,8 @@ def start_command(
|
||||
# NOTE: Return the configuration for telemetry purposes.
|
||||
return config
|
||||
|
||||
|
||||
@cli.command(
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
name='start-grpc',
|
||||
short_help='Start a gRPC LLMServer for any supported LLM.',
|
||||
)
|
||||
@click.argument('model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True)
|
||||
@click.option(
|
||||
'--model-id',
|
||||
'deprecated_model_id',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
hidden=True,
|
||||
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
|
||||
help='Deprecated. Use positional argument instead.',
|
||||
)
|
||||
@start_decorator(serve_grpc=True)
|
||||
@click.option(
|
||||
'--max-model-len',
|
||||
'--max_model_len',
|
||||
'max_model_len',
|
||||
default=None,
|
||||
help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.',
|
||||
)
|
||||
def start_grpc_command(
|
||||
model_id: str,
|
||||
server_timeout: int,
|
||||
model_version: str | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: LiteralQuantise | None,
|
||||
backend: LiteralBackend | None,
|
||||
serialisation: LiteralSerialisation | None,
|
||||
cors: bool,
|
||||
dtype: LiteralDtype,
|
||||
adapter_id: str | None,
|
||||
return_process: bool,
|
||||
deprecated_model_id: str | None,
|
||||
max_model_len: int | None,
|
||||
**attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
'''Start any LLM as a gRPC server.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm start-grpc <model_id> --<options> ...
|
||||
```
|
||||
'''
|
||||
termui.warning(
|
||||
'Continuous batching is currently not yet supported with gPRC. If you want to use continuous batching with gRPC, feel free to open a GitHub issue about your usecase.\n'
|
||||
)
|
||||
if model_id in openllm.CONFIG_MAPPING:
|
||||
_model_name = model_id
|
||||
if deprecated_model_id is not None:
|
||||
model_id = deprecated_model_id
|
||||
else:
|
||||
model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
|
||||
termui.warning(
|
||||
f"Passing 'openllm start-grpc {_model_name}{'' if deprecated_model_id is None else ' --model-id ' + deprecated_model_id}' is deprecated and will be remove in a future version. Use 'openllm start-grpc {model_id}' instead."
|
||||
)
|
||||
|
||||
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
|
||||
|
||||
from openllm.serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
serialisation = first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
)
|
||||
if serialisation == 'safetensors' and quantize is not None:
|
||||
logger.warning("'--quantize=%s' might not work with 'safetensors' serialisation format.", quantize)
|
||||
logger.warning(
|
||||
"Make sure to check out '%s' repository to see if the weights is in '%s' format if unsure.",
|
||||
model_id,
|
||||
serialisation,
|
||||
)
|
||||
logger.info("Tip: You can always fallback to '--serialisation legacy' when running quantisation.")
|
||||
|
||||
import torch
|
||||
|
||||
if backend == 'pt' and not torch.cuda.is_available():
|
||||
if dtype == 'auto':
|
||||
dtype = 'float'
|
||||
elif dtype not in {'float', 'float32'}:
|
||||
logger.warning('"bfloat16" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
|
||||
dtype = 'float' # we need to cast back to full precision if cuda is not available
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
backend=backend,
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
serialisation=serialisation,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE', False),
|
||||
)
|
||||
backend_warning(llm.__llm_backend__)
|
||||
|
||||
config, server_attrs = llm.config.model_validate_click(**attrs)
|
||||
server_timeout = first_not_none(server_timeout, default=config['timeout'])
|
||||
server_attrs.update({'working_dir': pkg.source_locations('openllm'), 'timeout': server_timeout})
|
||||
server_attrs['grpc_protocol_version'] = 'v1'
|
||||
# XXX: currently, theres no development args in bentoml.Server. To be fixed upstream.
|
||||
development = server_attrs.pop('development')
|
||||
server_attrs.setdefault('production', not development)
|
||||
|
||||
start_env = process_environ(
|
||||
config,
|
||||
server_timeout,
|
||||
process_workers_per_resource(first_not_none(workers_per_resource, default=config['workers_per_resource']), device),
|
||||
device,
|
||||
cors,
|
||||
model_id,
|
||||
adapter_map,
|
||||
serialisation,
|
||||
llm,
|
||||
)
|
||||
|
||||
server = bentoml.GrpcServer('_service:svc', **server_attrs)
|
||||
openllm.utils.analytics.track_start_init(llm.config)
|
||||
|
||||
try:
|
||||
build_bento_instruction(llm, model_id, serialisation, adapter_map)
|
||||
it = run_server(server.args, start_env, return_process=return_process)
|
||||
if return_process:
|
||||
return it
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
# NOTE: Return the configuration for telemetry purposes.
|
||||
return config
|
||||
|
||||
|
||||
def process_environ(
|
||||
config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True
|
||||
) -> t.Dict[str, t.Any]:
|
||||
environ = parse_config_options(
|
||||
config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {}
|
||||
)
|
||||
def process_environ(config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True):
|
||||
environ = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {})
|
||||
environ.update(
|
||||
{
|
||||
'OPENLLM_MODEL_ID': model_id,
|
||||
@@ -631,11 +463,9 @@ def process_environ(
|
||||
'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
|
||||
}
|
||||
)
|
||||
if llm.quantise:
|
||||
environ['QUANTIZE'] = str(llm.quantise)
|
||||
if llm.quantise: environ['QUANTIZE'] = str(llm.quantise)
|
||||
return environ
|
||||
|
||||
|
||||
def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]) -> TypeGuard[float]:
|
||||
if isinstance(wpr, str):
|
||||
if wpr == 'round_robin':
|
||||
@@ -653,7 +483,6 @@ def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]
|
||||
wpr = float(wpr)
|
||||
return wpr
|
||||
|
||||
|
||||
def build_bento_instruction(llm, model_id, serialisation, adapter_map):
|
||||
cmd_name = f'openllm build {model_id} --backend {llm.__llm_backend__}'
|
||||
if llm.quantise:
|
||||
@@ -907,13 +736,6 @@ class BuildBentoOutput(t.TypedDict):
|
||||
help='Optional custom dockerfile template to be used with this BentoLLM.',
|
||||
)
|
||||
@serialisation_option
|
||||
@container_registry_option
|
||||
@click.option(
|
||||
'--container-version-strategy',
|
||||
type=click.Choice(['release', 'latest', 'nightly']),
|
||||
default='release',
|
||||
help="Default container version strategy for the image from '--container-registry'",
|
||||
)
|
||||
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options') # type: ignore[misc]
|
||||
@cog.optgroup.option(
|
||||
'--containerize',
|
||||
@@ -951,8 +773,6 @@ def build_command(
|
||||
containerize: bool,
|
||||
push: bool,
|
||||
serialisation: LiteralSerialisation | None,
|
||||
container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy,
|
||||
force_push: bool,
|
||||
**_: t.Any,
|
||||
) -> BuildBentoOutput:
|
||||
@@ -991,6 +811,10 @@ def build_command(
|
||||
|
||||
state = ItemState.NOT_FOUND
|
||||
|
||||
if backend == 'pt':
|
||||
logger.warning("PyTorch backend is deprecated and will be removed from the next releases. Will set default backend to 'vllm' instead.")
|
||||
backend = 'vllm'
|
||||
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
@@ -1069,8 +893,6 @@ def build_command(
|
||||
quantize=quantize,
|
||||
extra_dependencies=enable_features,
|
||||
dockerfile_template=dockerfile_template_path,
|
||||
container_registry=container_registry,
|
||||
container_version_strategy=container_version_strategy,
|
||||
)
|
||||
if state != ItemState.OVERWRITE:
|
||||
state = ItemState.ADDED
|
||||
|
||||
Reference in New Issue
Block a user