refactor: focus (#730)

* perf: remove base images

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update changelog

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: move dockerfile to run on release only

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: cleanup unused types

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Aaron Pham
2023-11-24 01:11:31 -05:00
committed by GitHub
parent 52a44b1bfa
commit aab173cd99
19 changed files with 168 additions and 679 deletions


@@ -1,31 +1,8 @@
from __future__ import annotations
import enum
import functools
import inspect
import itertools
import logging
import os
import platform
import random
import subprocess
import threading
import time
import traceback
import typing as t
import attr
import click
import click_option_group as cog
import fs
import fs.copy
import fs.errors
import inflection
import orjson
import enum, functools, inspect, itertools, logging, os, platform, random, subprocess, threading, time, traceback, typing as t
import attr, click, fs, inflection, bentoml, openllm, orjson, fs.copy, fs.errors, click_option_group as cog
from bentoml_cli.utils import BentoMLCommandGroup, opt_callback
from simple_di import Provide, inject
import bentoml
import openllm
from bentoml._internal.cloud.config import CloudClientConfig
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelStore
@@ -70,7 +47,6 @@ from ._factory import (
FC,
_AnyCallable,
backend_option,
container_registry_option,
dtype_option,
machine_option,
model_name_argument,
@@ -88,7 +64,6 @@ if t.TYPE_CHECKING:
from bentoml._internal.container import DefaultBuilder
from openllm_client._schemas import StreamingResponse
from openllm_core._configuration import LLMConfig
from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
else:
torch = LazyLoader('torch', globals(), 'torch')
@@ -248,7 +223,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
return super().get_command(ctx, cmd_name)
def list_commands(self, ctx: click.Context) -> list[str]:
return super().list_commands(ctx) + t.cast('Extensions', extension_command).list_commands(ctx)
return super().list_commands(ctx) + extension_command.list_commands(ctx)
def command(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Callable[..., t.Any]], click.Command]:
"""Override the default 'cli.command' with supports for aliases for given command, and it wraps the implementation with common parameters."""
@@ -371,7 +346,7 @@ def cli() -> None:
default=None,
help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.',
)
@start_decorator(serve_grpc=False)
@start_decorator
def start_command(
model_id: str,
server_timeout: int,
@@ -396,26 +371,21 @@ def start_command(
$ openllm <start|start-http> <model_id> --<options> ...
```
'''
if backend == 'pt': logger.warning('PyTorch backend is deprecated and will be removed in future releases. Make sure to use vLLM instead.')
if model_id in openllm.CONFIG_MAPPING:
_model_name = model_id
if deprecated_model_id is not None:
model_id = deprecated_model_id
else:
model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
termui.warning(
f"Passing 'openllm start {_model_name}{'' if deprecated_model_id is None else ' --model-id ' + deprecated_model_id}' is deprecated and will be remove in a future version. Use 'openllm start {model_id}' instead."
)
logger.warning("Passing 'openllm start %s%s' is deprecated and will be remove in a future version. Use 'openllm start %s' instead.", _model_name, '' if deprecated_model_id is None else f' --model-id {deprecated_model_id}', model_id)
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
from openllm.serialisation.transformers.weights import has_safetensors_weights
serialisation = t.cast(
LiteralSerialisation,
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
)
serialisation = first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy')
if serialisation == 'safetensors' and quantize is not None:
logger.warning("'--quantize=%s' might not work with 'safetensors' serialisation format.", quantize)
logger.warning(
@@ -449,8 +419,7 @@ def start_command(
config, server_attrs = llm.config.model_validate_click(**attrs)
server_timeout = first_not_none(server_timeout, default=config['timeout'])
server_attrs.update({'working_dir': pkg.source_locations('openllm'), 'timeout': server_timeout})
# XXX: currently, there's no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development')
development = server_attrs.pop('development') # XXX: currently, there's no development args in bentoml.Server. To be fixed upstream.
server_attrs.setdefault('production', not development)
start_env = process_environ(
@@ -479,145 +448,8 @@ def start_command(
# NOTE: Return the configuration for telemetry purposes.
return config
@cli.command(
context_settings=termui.CONTEXT_SETTINGS,
name='start-grpc',
short_help='Start a gRPC LLMServer for any supported LLM.',
)
@click.argument('model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True)
@click.option(
'--model-id',
'deprecated_model_id',
type=click.STRING,
default=None,
hidden=True,
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
help='Deprecated. Use positional argument instead.',
)
@start_decorator(serve_grpc=True)
@click.option(
'--max-model-len',
'--max_model_len',
'max_model_len',
default=None,
help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.',
)
def start_grpc_command(
model_id: str,
server_timeout: int,
model_version: str | None,
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
device: t.Tuple[str, ...],
quantize: LiteralQuantise | None,
backend: LiteralBackend | None,
serialisation: LiteralSerialisation | None,
cors: bool,
dtype: LiteralDtype,
adapter_id: str | None,
return_process: bool,
deprecated_model_id: str | None,
max_model_len: int | None,
**attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
'''Start any LLM as a gRPC server.
\b
```bash
$ openllm start-grpc <model_id> --<options> ...
```
'''
termui.warning(
'Continuous batching is not yet supported with gRPC. If you want to use continuous batching with gRPC, feel free to open a GitHub issue about your use case.\n'
)
if model_id in openllm.CONFIG_MAPPING:
_model_name = model_id
if deprecated_model_id is not None:
model_id = deprecated_model_id
else:
model_id = openllm.AutoConfig.for_model(_model_name)['default_id']
termui.warning(
f"Passing 'openllm start-grpc {_model_name}{'' if deprecated_model_id is None else ' --model-id ' + deprecated_model_id}' is deprecated and will be remove in a future version. Use 'openllm start-grpc {model_id}' instead."
)
adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None)
from openllm.serialisation.transformers.weights import has_safetensors_weights
serialisation = first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
)
if serialisation == 'safetensors' and quantize is not None:
logger.warning("'--quantize=%s' might not work with 'safetensors' serialisation format.", quantize)
logger.warning(
"Make sure to check out '%s' repository to see if the weights is in '%s' format if unsure.",
model_id,
serialisation,
)
logger.info("Tip: You can always fallback to '--serialisation legacy' when running quantisation.")
import torch
if backend == 'pt' and not torch.cuda.is_available():
if dtype == 'auto':
dtype = 'float'
elif dtype not in {'float', 'float32'}:
logger.warning('"bfloat16" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
dtype = 'float' # we need to cast back to full precision if cuda is not available
llm = openllm.LLM[t.Any, t.Any](
model_id=model_id,
model_version=model_version,
backend=backend,
adapter_map=adapter_map,
quantize=quantize,
serialisation=serialisation,
dtype=dtype,
max_model_len=max_model_len,
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE', False),
)
backend_warning(llm.__llm_backend__)
config, server_attrs = llm.config.model_validate_click(**attrs)
server_timeout = first_not_none(server_timeout, default=config['timeout'])
server_attrs.update({'working_dir': pkg.source_locations('openllm'), 'timeout': server_timeout})
server_attrs['grpc_protocol_version'] = 'v1'
# XXX: currently, there's no development args in bentoml.Server. To be fixed upstream.
development = server_attrs.pop('development')
server_attrs.setdefault('production', not development)
start_env = process_environ(
config,
server_timeout,
process_workers_per_resource(first_not_none(workers_per_resource, default=config['workers_per_resource']), device),
device,
cors,
model_id,
adapter_map,
serialisation,
llm,
)
server = bentoml.GrpcServer('_service:svc', **server_attrs)
openllm.utils.analytics.track_start_init(llm.config)
try:
build_bento_instruction(llm, model_id, serialisation, adapter_map)
it = run_server(server.args, start_env, return_process=return_process)
if return_process:
return it
except KeyboardInterrupt:
pass
# NOTE: Return the configuration for telemetry purposes.
return config
def process_environ(
config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True
) -> t.Dict[str, t.Any]:
environ = parse_config_options(
config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {}
)
def process_environ(config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True):
environ = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {})
environ.update(
{
'OPENLLM_MODEL_ID': model_id,
@@ -631,11 +463,9 @@ def process_environ(
'TRUST_REMOTE_CODE': str(llm.trust_remote_code),
}
)
if llm.quantise:
environ['QUANTIZE'] = str(llm.quantise)
if llm.quantise: environ['QUANTIZE'] = str(llm.quantise)
return environ
def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]) -> TypeGuard[float]:
if isinstance(wpr, str):
if wpr == 'round_robin':
@@ -653,7 +483,6 @@ def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]
wpr = float(wpr)
return wpr
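For reference, a self-contained sketch of the mapping this helper performs (hypothetical: the 'conserved' branch falls outside the visible hunk, so its body below is an assumption inferred from the type hint and the surrounding branches):

```python
def process_workers_per_resource(wpr, device):
  # Normalise a workers-per-resource strategy into the float BentoML expects.
  if isinstance(wpr, str):
    if wpr == 'round_robin':
      wpr = 1.0  # one worker per device, scheduled in turn
    elif wpr == 'conserved':
      # assumption: share all visible devices within a single worker
      wpr = 1.0 if not device else 1.0 / len(device)
    else:
      wpr = float(wpr)  # numeric strings pass through unchanged
  return wpr

# e.g. process_workers_per_resource('conserved', ('0', '1')) -> 0.5
```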
def build_bento_instruction(llm, model_id, serialisation, adapter_map):
cmd_name = f'openllm build {model_id} --backend {llm.__llm_backend__}'
if llm.quantise:
@@ -907,13 +736,6 @@ class BuildBentoOutput(t.TypedDict):
help='Optional custom dockerfile template to be used with this BentoLLM.',
)
@serialisation_option
@container_registry_option
@click.option(
'--container-version-strategy',
type=click.Choice(['release', 'latest', 'nightly']),
default='release',
help="Default container version strategy for the image from '--container-registry'",
)
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options') # type: ignore[misc]
@cog.optgroup.option(
'--containerize',
@@ -951,8 +773,6 @@ def build_command(
containerize: bool,
push: bool,
serialisation: LiteralSerialisation | None,
container_registry: LiteralContainerRegistry,
container_version_strategy: LiteralContainerVersionStrategy,
force_push: bool,
**_: t.Any,
) -> BuildBentoOutput:
@@ -991,6 +811,10 @@ def build_command(
state = ItemState.NOT_FOUND
if backend == 'pt':
logger.warning("PyTorch backend is deprecated and will be removed from the next releases. Will set default backend to 'vllm' instead.")
backend = 'vllm'
llm = openllm.LLM[t.Any, t.Any](
model_id=model_id,
model_version=model_version,
@@ -1069,8 +893,6 @@ def build_command(
quantize=quantize,
extra_dependencies=enable_features,
dockerfile_template=dockerfile_template_path,
container_registry=container_registry,
container_version_strategy=container_version_strategy,
)
if state != ItemState.OVERWRITE:
state = ItemState.ADDED
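Taken together, the user-facing effect of this refactor is that the positional model id replaces the deprecated '--model-id' flag and the gRPC entrypoint goes away. A minimal before/after sketch (the model names are illustrative, not taken from this diff):

```bash
# Deprecated spelling, still accepted with a warning:
openllm start opt --model-id facebook/opt-1.3b

# Preferred after this change:
openllm start facebook/opt-1.3b

# Removed by this commit, with no gRPC replacement:
openllm start-grpc facebook/opt-1.3b
```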