|
|
|
@@ -1,16 +1,17 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import os, logging, traceback, pathlib, sys, fs, click, enum, inflection, bentoml, orjson, openllm, openllm_core, platform, typing as t
|
|
|
|
|
import os, traceback, io, pathlib, sys, fs, click, enum, importlib, importlib.metadata, inflection, bentoml, orjson, openllm, openllm_core as core, platform, tarfile, typing as t
|
|
|
|
|
from ._helpers import recommended_instance_type
|
|
|
|
|
from openllm_core.utils import (
|
|
|
|
|
DEBUG,
|
|
|
|
|
DEBUG_ENV_VAR,
|
|
|
|
|
QUIET_ENV_VAR,
|
|
|
|
|
OPENLLM_DEV_BUILD,
|
|
|
|
|
SHOW_CODEGEN,
|
|
|
|
|
check_bool_env,
|
|
|
|
|
compose,
|
|
|
|
|
first_not_none,
|
|
|
|
|
dantic,
|
|
|
|
|
pkg,
|
|
|
|
|
gen_random_uuid,
|
|
|
|
|
get_debug_mode,
|
|
|
|
|
get_quiet_mode,
|
|
|
|
@@ -25,7 +26,10 @@ from openllm_core._typing_compat import (
|
|
|
|
|
)
|
|
|
|
|
from openllm_cli import termui
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
if sys.version_info >= (3, 11):
|
|
|
|
|
import tomllib
|
|
|
|
|
else:
|
|
|
|
|
import tomli as tomllib
|
|
|
|
|
|
|
|
|
|
OPENLLM_FIGLET = """
|
|
|
|
|
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
|
|
|
@@ -38,20 +42,21 @@ OPENLLM_FIGLET = """
|
|
|
|
|
_PACKAGE_NAME = 'openllm'
|
|
|
|
|
_TINY_PATH = pathlib.Path(os.path.abspath(__file__)).parent
|
|
|
|
|
_SERVICE_FILE = _TINY_PATH / '_service.py'
|
|
|
|
|
_SERVICE_VARS = '''\
|
|
|
|
|
_SERVICE_README = _TINY_PATH / 'service.md'
|
|
|
|
|
_SERVICE_VARS = """\
|
|
|
|
|
# fmt: off
|
|
|
|
|
# GENERATED BY 'openllm build {__model_id__}'. DO NOT EDIT
|
|
|
|
|
# GENERATED BY '{__command__}'. DO NOT EDIT
|
|
|
|
|
import orjson,openllm_core.utils as coreutils
|
|
|
|
|
model_id='{__model_id__}'
|
|
|
|
|
model_name='{__model_name__}'
|
|
|
|
|
revision=orjson.loads(coreutils.getenv('revision',default={__model_revision__}))
|
|
|
|
|
quantise=coreutils.getenv('quantize',default='{__model_quantise__}',var=['QUANTISE'])
|
|
|
|
|
serialisation=coreutils.getenv('serialization',default='{__model_serialization__}',var=['SERIALISATION'])
|
|
|
|
|
dtype=coreutils.getenv('dtype', default='{__model_dtype__}', var=['TORCH_DTYPE'])
|
|
|
|
|
dtype=coreutils.getenv('dtype', default='{__model_dtype__}',var=['TORCH_DTYPE'])
|
|
|
|
|
trust_remote_code=coreutils.check_bool_env("TRUST_REMOTE_CODE",{__model_trust_remote_code__})
|
|
|
|
|
max_model_len=orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps({__max_model_len__})))
|
|
|
|
|
gpu_memory_utilization=orjson.loads(coreutils.getenv('gpu_memory_utilization', default=orjson.dumps({__gpu_memory_utilization__}), var=['GPU_MEMORY_UTILISATION']))
|
|
|
|
|
services_config=orjson.loads(coreutils.getenv('services_config',"""{__services_config__}"""))
|
|
|
|
|
'''
|
|
|
|
|
max_model_len=orjson.loads(coreutils.getenv('max_model_len',default=orjson.dumps({__max_model_len__})))
|
|
|
|
|
gpu_memory_utilization=orjson.loads(coreutils.getenv('gpu_memory_utilization',default=orjson.dumps({__gpu_memory_utilization__}),var=['GPU_MEMORY_UTILISATION']))
|
|
|
|
|
services_config=orjson.loads(coreutils.getenv('services_config',default={__services_config__}))
|
|
|
|
|
"""
|
|
|
|
|
HF_HUB_DISABLE_PROGRESS_BARS = 'HF_HUB_DISABLE_PROGRESS_BARS'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -62,20 +67,6 @@ class ItemState(enum.Enum):
|
|
|
|
|
OVERWRITE = 'OVERWRITE'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_device_callback(
|
|
|
|
|
_: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
|
|
|
|
|
) -> t.Tuple[str, ...] | None:
|
|
|
|
|
if value is None:
|
|
|
|
|
return value
|
|
|
|
|
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
|
|
|
|
# NOTE: --device all is a special case
|
|
|
|
|
if len(el) == 1 and el[0] == 'all':
|
|
|
|
|
return tuple(map(str, openllm.utils.available_devices()))
|
|
|
|
|
if len(el) == 1 and el[0] == 'gpu':
|
|
|
|
|
return ('0',)
|
|
|
|
|
return el
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@click.group(context_settings=termui.CONTEXT_SETTINGS, name='openllm')
|
|
|
|
|
@click.version_option(
|
|
|
|
|
None,
|
|
|
|
@@ -94,53 +85,47 @@ def cli() -> None:
|
|
|
|
|
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝.
|
|
|
|
|
|
|
|
|
|
\b
|
|
|
|
|
An open platform for operating large language models in production.
|
|
|
|
|
Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
|
|
|
|
Self-Hosting LLMs Made Easy
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def optimization_decorator(fn):
|
|
|
|
|
# NOTE: return device, quantize, serialisation, dtype, max_model_len, gpu_memory_utilization
|
|
|
|
|
def optimization_decorator(fn: t.Callable[..., t.Any]):
|
|
|
|
|
optimization = [
|
|
|
|
|
click.option(
|
|
|
|
|
'--device',
|
|
|
|
|
type=dantic.CUDA,
|
|
|
|
|
multiple=True,
|
|
|
|
|
envvar='CUDA_VISIBLE_DEVICES',
|
|
|
|
|
callback=parse_device_callback,
|
|
|
|
|
help='Assign GPU devices (if available)',
|
|
|
|
|
'--concurrency',
|
|
|
|
|
type=int,
|
|
|
|
|
envvar='CONCURRENCY',
|
|
|
|
|
help='See https://docs.bentoml.com/en/latest/guides/concurrency.html#concurrency for more information.',
|
|
|
|
|
show_envvar=True,
|
|
|
|
|
default=None,
|
|
|
|
|
),
|
|
|
|
|
click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds'),
|
|
|
|
|
click.option(
|
|
|
|
|
'--dtype',
|
|
|
|
|
type=str,
|
|
|
|
|
envvar='DTYPE',
|
|
|
|
|
default='auto',
|
|
|
|
|
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']",
|
|
|
|
|
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16']. Default to auto for infering dtype based on available accelerator.",
|
|
|
|
|
),
|
|
|
|
|
click.option(
|
|
|
|
|
'--quantise',
|
|
|
|
|
'--quantize',
|
|
|
|
|
'quantize',
|
|
|
|
|
'quantise',
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
envvar='QUANTIZE',
|
|
|
|
|
show_envvar=True,
|
|
|
|
|
help="""Dynamic quantization for running this LLM.
|
|
|
|
|
help="""Quantisation options for this LLM.
|
|
|
|
|
|
|
|
|
|
The following quantization strategies are supported:
|
|
|
|
|
The following quantisation strategies are supported:
|
|
|
|
|
|
|
|
|
|
- ``int8``: ``LLM.int8`` for [8-bit](https://arxiv.org/abs/2208.07339) quantization.
|
|
|
|
|
- ``gptq``: ``GPTQ`` [quantisation](https://arxiv.org/abs/2210.17323)
|
|
|
|
|
|
|
|
|
|
- ``int4``: ``SpQR`` for [4-bit](https://arxiv.org/abs/2306.03078) quantization.
|
|
|
|
|
- ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantisation](https://arxiv.org/abs/2306.00978)
|
|
|
|
|
|
|
|
|
|
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
|
|
|
|
|
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantisation](https://arxiv.org/abs/2306.07629)
|
|
|
|
|
|
|
|
|
|
- ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978)
|
|
|
|
|
|
|
|
|
|
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
|
|
|
|
|
|
|
|
|
|
> [!NOTE] that the model can also be served with quantized weights.
|
|
|
|
|
> [!NOTE] that the model must be pre-quantised to ensure correct loading, as all aforementioned quantization scheme are post-training quantization.
|
|
|
|
|
""",
|
|
|
|
|
),
|
|
|
|
|
click.option(
|
|
|
|
@@ -151,14 +136,14 @@ def optimization_decorator(fn):
|
|
|
|
|
default=None,
|
|
|
|
|
show_default=True,
|
|
|
|
|
show_envvar=True,
|
|
|
|
|
envvar='OPENLLM_SERIALIZATION',
|
|
|
|
|
help="""Serialisation format for save/load LLM.
|
|
|
|
|
envvar='SERIALIZATION',
|
|
|
|
|
help="""Serialisation format for loading LLM. Make sure to check HF repository for the correct format.
|
|
|
|
|
|
|
|
|
|
Currently the following strategies are supported:
|
|
|
|
|
|
|
|
|
|
- ``safetensors``: This will use safetensors format, which is synonymous to ``safe_serialization=True``.
|
|
|
|
|
|
|
|
|
|
> [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed.
|
|
|
|
|
> [!NOTE] Safetensors might not work for older models, and you can always fallback to ``legacy`` if needed.
|
|
|
|
|
|
|
|
|
|
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
|
|
|
|
|
""",
|
|
|
|
@@ -178,14 +163,28 @@ def optimization_decorator(fn):
|
|
|
|
|
default=0.9,
|
|
|
|
|
help='The percentage of GPU memory to be used for the model executor',
|
|
|
|
|
),
|
|
|
|
|
click.option(
|
|
|
|
|
'--trust-remote-code',
|
|
|
|
|
'--trust_remote_code',
|
|
|
|
|
'trust_remote_code',
|
|
|
|
|
type=bool,
|
|
|
|
|
is_flag=True,
|
|
|
|
|
default=False,
|
|
|
|
|
show_envvar=True,
|
|
|
|
|
envvar='TRUST_REMOTE_CODE',
|
|
|
|
|
help='If model from HuggingFace requires custom code, pass this to enable remote code execution. If the model is a private model, make sure to also pass this argument such that OpenLLM can determine model architecture to load.',
|
|
|
|
|
),
|
|
|
|
|
]
|
|
|
|
|
return compose(*optimization)(fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def shared_decorator(fn):
|
|
|
|
|
def shared_decorator(fn: t.Callable[..., t.Any]):
|
|
|
|
|
shared = [
|
|
|
|
|
click.argument(
|
|
|
|
|
'model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True
|
|
|
|
|
'model_id',
|
|
|
|
|
type=click.STRING,
|
|
|
|
|
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model | bentomodel_tag]',
|
|
|
|
|
required=True,
|
|
|
|
|
),
|
|
|
|
|
click.option(
|
|
|
|
|
'--revision',
|
|
|
|
@@ -194,15 +193,17 @@ def shared_decorator(fn):
|
|
|
|
|
'model_version',
|
|
|
|
|
type=click.STRING,
|
|
|
|
|
default=None,
|
|
|
|
|
help='Optional model revision to save for this model. It will be inferred automatically from model-id.',
|
|
|
|
|
help='Optional model revision to for this LLM. If this is a private model, specify this alongside model_id will be used as a bentomodel tag. If using in conjunction with a HF model id, this will be a specific revision, code branch, or a commit id on HF repo.',
|
|
|
|
|
),
|
|
|
|
|
click.option(
|
|
|
|
|
'--model-tag',
|
|
|
|
|
'--bentomodel-tag',
|
|
|
|
|
'model_tag',
|
|
|
|
|
type=click.STRING,
|
|
|
|
|
default=None,
|
|
|
|
|
help='Optional bentomodel tag to save for this model. It will be generated automatically based on model_id and model_version if not specified.',
|
|
|
|
|
'--debug',
|
|
|
|
|
'--verbose',
|
|
|
|
|
type=bool,
|
|
|
|
|
default=False,
|
|
|
|
|
show_envvar=True,
|
|
|
|
|
is_flag=True,
|
|
|
|
|
envvar='DEBUG',
|
|
|
|
|
help='whether to enable verbose logging (For more fine-grained control, set DEBUG to number instead of this flag.).',
|
|
|
|
|
),
|
|
|
|
|
]
|
|
|
|
|
return compose(*shared)(fn)
|
|
|
|
@@ -210,76 +211,77 @@ def shared_decorator(fn):
|
|
|
|
|
|
|
|
|
|
@cli.command(name='start')
|
|
|
|
|
@shared_decorator
|
|
|
|
|
@click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds')
|
|
|
|
|
@click.option('--port', type=int, default=3000, help='Port to serve the LLM. Default to 3000.')
|
|
|
|
|
@optimization_decorator
|
|
|
|
|
def start_command(
|
|
|
|
|
model_id: str,
|
|
|
|
|
model_version: str | None,
|
|
|
|
|
model_tag: str | None,
|
|
|
|
|
timeout: int,
|
|
|
|
|
concurrency: int | None,
|
|
|
|
|
port: int,
|
|
|
|
|
device: t.Tuple[str, ...],
|
|
|
|
|
quantize: LiteralQuantise | None,
|
|
|
|
|
quantise: LiteralQuantise | None,
|
|
|
|
|
serialisation: LiteralSerialisation | None,
|
|
|
|
|
dtype: LiteralDtype | t.Literal['auto', 'float'],
|
|
|
|
|
max_model_len: int | None,
|
|
|
|
|
gpu_memory_utilization: float,
|
|
|
|
|
trust_remote_code: bool,
|
|
|
|
|
debug: bool,
|
|
|
|
|
):
|
|
|
|
|
"""Start any LLM as a REST server.
|
|
|
|
|
|
|
|
|
|
\b
|
|
|
|
|
```bash
|
|
|
|
|
$ openllm <start|start-http> <model_id> --<options> ...
|
|
|
|
|
$ openllm start microsoft/Phi-3-mini-4k-instruct --trust-remote-code
|
|
|
|
|
```
|
|
|
|
|
"""
|
|
|
|
|
import transformers
|
|
|
|
|
|
|
|
|
|
from _bentoml_impl.server import serve_http
|
|
|
|
|
from bentoml._internal.service.loader import load
|
|
|
|
|
from bentoml._internal.log import configure_server_logging
|
|
|
|
|
|
|
|
|
|
configure_server_logging()
|
|
|
|
|
|
|
|
|
|
trust_remote_code = check_bool_env('TRUST_REMOTE_CODE', False)
|
|
|
|
|
try:
|
|
|
|
|
# if given model_id is a private model, then we can use it directly
|
|
|
|
|
# NOTE: if given model_id is a private model (assuming this is packaged into a bentomodel), then we can use it directly
|
|
|
|
|
bentomodel = bentoml.models.get(model_id.lower())
|
|
|
|
|
model_id = bentomodel.path
|
|
|
|
|
if not trust_remote_code:
|
|
|
|
|
trust_remote_code = True
|
|
|
|
|
except (ValueError, bentoml.exceptions.NotFound):
|
|
|
|
|
bentomodel = None
|
|
|
|
|
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
|
|
|
|
for arch in config.architectures:
|
|
|
|
|
if arch in openllm_core.AutoConfig._architecture_mappings:
|
|
|
|
|
model_name = openllm_core.AutoConfig._architecture_mappings[arch]
|
|
|
|
|
break
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError(f'Failed to determine config class for {model_id}')
|
|
|
|
|
|
|
|
|
|
llm_config = openllm_core.AutoConfig.for_model(model_name).model_construct_env()
|
|
|
|
|
llm_config = core.AutoConfig.from_id(model_id, trust_remote_code=trust_remote_code)
|
|
|
|
|
if serialisation is None:
|
|
|
|
|
termui.warning(
|
|
|
|
|
f"Serialisation format is not specified. Defaulting to '{llm_config['serialisation']}'. Your model might not work with this format. Make sure to explicitly specify the serialisation format."
|
|
|
|
|
)
|
|
|
|
|
serialisation = llm_config['serialisation']
|
|
|
|
|
|
|
|
|
|
# TODO: support LoRA adapters
|
|
|
|
|
os.environ.update({
|
|
|
|
|
QUIET_ENV_VAR: str(openllm.utils.get_quiet_mode()),
|
|
|
|
|
DEBUG_ENV_VAR: str(openllm.utils.get_debug_mode()),
|
|
|
|
|
DEBUG_ENV_VAR: str(debug) or str(openllm.utils.get_debug_mode()),
|
|
|
|
|
HF_HUB_DISABLE_PROGRESS_BARS: str(not openllm.utils.get_debug_mode()),
|
|
|
|
|
'MODEL_ID': model_id,
|
|
|
|
|
'MODEL_NAME': model_name,
|
|
|
|
|
# handling custom revision if users specify --revision alongside with model_id
|
|
|
|
|
# this should work only if bentomodel is None
|
|
|
|
|
'REVISION': orjson.dumps(first_not_none(model_version, default=None)).decode(),
|
|
|
|
|
'SERIALIZATION': serialisation,
|
|
|
|
|
'OPENLLM_CONFIG': llm_config.model_dump_json(),
|
|
|
|
|
'DTYPE': dtype,
|
|
|
|
|
'TRUST_REMOTE_CODE': str(trust_remote_code),
|
|
|
|
|
'GPU_MEMORY_UTILIZATION': orjson.dumps(gpu_memory_utilization).decode(),
|
|
|
|
|
'SERVICES_CONFIG': orjson.dumps(
|
|
|
|
|
dict(resources={'gpu' if device else 'cpu': len(device) if device else '1'}, traffic=dict(timeout=timeout))
|
|
|
|
|
# XXX: right now we just enable GPU by default. will revisit this if we decide to support TPU later.
|
|
|
|
|
dict(
|
|
|
|
|
resources={'gpu': len(openllm.utils.available_devices())},
|
|
|
|
|
traffic=dict(timeout=timeout, concurrency=concurrency),
|
|
|
|
|
)
|
|
|
|
|
).decode(),
|
|
|
|
|
})
|
|
|
|
|
if max_model_len is not None:
|
|
|
|
|
os.environ['MAX_MODEL_LEN'] = orjson.dumps(max_model_len)
|
|
|
|
|
if quantize:
|
|
|
|
|
os.environ['QUANTIZE'] = str(quantize)
|
|
|
|
|
os.environ['MAX_MODEL_LEN'] = orjson.dumps(max_model_len).decode()
|
|
|
|
|
if quantise:
|
|
|
|
|
os.environ['QUANTIZE'] = str(quantise)
|
|
|
|
|
|
|
|
|
|
working_dir = os.path.abspath(os.path.dirname(__file__))
|
|
|
|
|
if sys.path[0] != working_dir:
|
|
|
|
@@ -290,22 +292,69 @@ def start_command(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def construct_python_options(llm_config, llm_fs):
|
|
|
|
|
from bentoml._internal.bento.build_config import PythonOptions
|
|
|
|
|
from openllm.bundle._package import build_editable
|
|
|
|
|
def get_package_version(_package: str) -> str:
|
|
|
|
|
try:
|
|
|
|
|
return importlib.import_module('._version', _package).__version__
|
|
|
|
|
except ModuleNotFoundError:
|
|
|
|
|
return importlib.metadata.version(_package)
|
|
|
|
|
|
|
|
|
|
# TODO: Add this line back once 0.5 is out, for now depends on OPENLLM_DEV_BUILD
|
|
|
|
|
# packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'openllm[vllm]>0.4', 'vllm>=0.3']
|
|
|
|
|
packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'openllm']
|
|
|
|
|
if llm_config['requirements'] is not None:
|
|
|
|
|
packages.extend(llm_config['requirements'])
|
|
|
|
|
built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')]
|
|
|
|
|
return PythonOptions(
|
|
|
|
|
packages=packages,
|
|
|
|
|
wheels=[llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels] if all(i for i in built_wheels) else None,
|
|
|
|
|
lock_packages=False,
|
|
|
|
|
|
|
|
|
|
def build_sdist(target_path: str, package: t.Literal['openllm', 'openllm-core', 'openllm-client']):
|
|
|
|
|
import tomli_w
|
|
|
|
|
|
|
|
|
|
if not check_bool_env(OPENLLM_DEV_BUILD, default=False):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
module_location = pkg.source_locations(inflection.underscore(package))
|
|
|
|
|
if not module_location:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f'Could not find the source location of {package}. Make sure to set "{OPENLLM_DEV_BUILD}=False" if you are not using development build.'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
package_path = pathlib.Path(module_location)
|
|
|
|
|
package_version = get_package_version(package)
|
|
|
|
|
project_path = package_path.parent.parent
|
|
|
|
|
|
|
|
|
|
if not (project_path / 'pyproject.toml').exists():
|
|
|
|
|
termui.warning(
|
|
|
|
|
f'Custom "{package}" is detected. For a Bento to use the same build at serving time, at your custom "{package}" build to pip packages list under your "bentofile.yaml". i.e: "packages=[\'git+https://github.com/bentoml/openllm.git@bc0be03\']"'
|
|
|
|
|
)
|
|
|
|
|
return
|
|
|
|
|
termui.debug(
|
|
|
|
|
f'"{package}" is installed in "editable" mode; building "{package}" distribution with local code base. the built tar will be included in generated bento.'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def exclude_pycache(tarinfo: tarfile.TarInfo):
|
|
|
|
|
if '__pycache__' in tarinfo.name or tarinfo.name.endswith(('.pyc', '.pyo')):
|
|
|
|
|
return None
|
|
|
|
|
return tarinfo
|
|
|
|
|
|
|
|
|
|
with open(project_path / 'pyproject.toml', 'rb') as f:
|
|
|
|
|
pyproject_toml = tomllib.load(f)
|
|
|
|
|
pyproject_toml['project']['version'] = package_version
|
|
|
|
|
if 'dynamic' in pyproject_toml['project'] and 'version' in pyproject_toml['project']['dynamic']:
|
|
|
|
|
pyproject_toml['project']['dynamic'].remove('version')
|
|
|
|
|
|
|
|
|
|
pyproject_io = io.BytesIO()
|
|
|
|
|
tomli_w.dump(pyproject_toml, pyproject_io)
|
|
|
|
|
# make a tarball of this package-version
|
|
|
|
|
base_name = f'{package}-{package_version}'
|
|
|
|
|
sdist_filename = f'{base_name}.tar.gz'
|
|
|
|
|
files_to_include = ['src', 'LICENSE.md', 'README.md']
|
|
|
|
|
|
|
|
|
|
if not pathlib.Path(target_path).exists():
|
|
|
|
|
pathlib.Path(target_path).mkdir(parents=True, exist_ok=True)
|
|
|
|
|
with tarfile.open(pathlib.Path(target_path, sdist_filename), 'w:gz') as tar:
|
|
|
|
|
for file in files_to_include:
|
|
|
|
|
tar.add(project_path / file, arcname=f'{base_name}/{file}', filter=exclude_pycache)
|
|
|
|
|
if package == 'openllm':
|
|
|
|
|
tar.add(project_path / 'CHANGELOG.md', arcname=f'{base_name}/CHANGELOG.md', filter=exclude_pycache)
|
|
|
|
|
tarinfo = tar.gettarinfo(project_path / 'pyproject.toml', arcname=f'{base_name}/pyproject.toml')
|
|
|
|
|
tarinfo.size = pyproject_io.tell()
|
|
|
|
|
pyproject_io.seek(0)
|
|
|
|
|
tar.addfile(tarinfo, pyproject_io)
|
|
|
|
|
return sdist_filename
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EnvironmentEntry(TypedDict):
|
|
|
|
|
name: str
|
|
|
|
@@ -327,8 +376,6 @@ class EnvironmentEntry(TypedDict):
|
|
|
|
|
help='Optional bento version for this BentoLLM. Default is the the model revision.',
|
|
|
|
|
)
|
|
|
|
|
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
|
|
|
|
|
@click.option('--nightly', is_flag=True, default=False, help='Package with OpenLLM nightly version.')
|
|
|
|
|
@click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds')
|
|
|
|
|
@optimization_decorator
|
|
|
|
|
@click.option(
|
|
|
|
|
'-o',
|
|
|
|
@@ -338,31 +385,28 @@ class EnvironmentEntry(TypedDict):
|
|
|
|
|
show_default=True,
|
|
|
|
|
help="Output log format. '-o tag' to display only bento tag.",
|
|
|
|
|
)
|
|
|
|
|
@click.pass_context
|
|
|
|
|
def build_command(
|
|
|
|
|
ctx: click.Context,
|
|
|
|
|
/,
|
|
|
|
|
model_id: str,
|
|
|
|
|
model_version: str | None,
|
|
|
|
|
model_tag: str | None,
|
|
|
|
|
bento_version: str | None,
|
|
|
|
|
bento_tag: str | None,
|
|
|
|
|
overwrite: bool,
|
|
|
|
|
device: t.Tuple[str, ...],
|
|
|
|
|
nightly: bool,
|
|
|
|
|
timeout: int,
|
|
|
|
|
quantize: LiteralQuantise | None,
|
|
|
|
|
concurrency: int | None,
|
|
|
|
|
quantise: LiteralQuantise | None,
|
|
|
|
|
serialisation: LiteralSerialisation | None,
|
|
|
|
|
dtype: LiteralDtype | t.Literal['auto', 'float'],
|
|
|
|
|
max_model_len: int | None,
|
|
|
|
|
gpu_memory_utilization: float,
|
|
|
|
|
output: t.Literal['default', 'tag'],
|
|
|
|
|
trust_remote_code: bool,
|
|
|
|
|
debug: bool,
|
|
|
|
|
):
|
|
|
|
|
"""Package a given models into a BentoLLM.
|
|
|
|
|
"""Package a given LLM into a servable artefacts.
|
|
|
|
|
|
|
|
|
|
\b
|
|
|
|
|
```bash
|
|
|
|
|
$ openllm build google/flan-t5-large
|
|
|
|
|
$ openllm build microsoft/Phi-3-mini-4k-instruct --trust-remote-code
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
\b
|
|
|
|
@@ -371,43 +415,42 @@ def build_command(
|
|
|
|
|
> to have https://github.com/NVIDIA/nvidia-container-toolkit install locally.
|
|
|
|
|
|
|
|
|
|
\b
|
|
|
|
|
> [!IMPORTANT]
|
|
|
|
|
> To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
|
|
|
|
|
> target also use the same Python version and architecture as build machine.
|
|
|
|
|
> [!NOTE]
|
|
|
|
|
> For private model, make sure to save it to the bentomodel store first. See https://docs.bentoml.com/en/latest/guides/model-store.html#model-store for more information
|
|
|
|
|
"""
|
|
|
|
|
import transformers
|
|
|
|
|
from bentoml._internal.configuration.containers import BentoMLContainer
|
|
|
|
|
from bentoml._internal.configuration import set_quiet_mode
|
|
|
|
|
from bentoml._internal.log import configure_logging
|
|
|
|
|
from bentoml._internal.bento.build_config import BentoBuildConfig
|
|
|
|
|
from bentoml._internal.bento.build_config import DockerOptions
|
|
|
|
|
from bentoml._internal.bento.build_config import ModelSpec
|
|
|
|
|
from bentoml._internal.configuration import set_quiet_mode
|
|
|
|
|
from bentoml._internal.configuration.containers import BentoMLContainer
|
|
|
|
|
from bentoml._internal.bento.build_config import BentoBuildConfig, DockerOptions, ModelSpec, PythonOptions
|
|
|
|
|
|
|
|
|
|
if output == 'tag':
|
|
|
|
|
set_quiet_mode(True)
|
|
|
|
|
configure_logging()
|
|
|
|
|
|
|
|
|
|
trust_remote_code = check_bool_env('TRUST_REMOTE_CODE', False)
|
|
|
|
|
try:
|
|
|
|
|
# if given model_id is a private model, then we can use it directly
|
|
|
|
|
# NOTE: if given model_id is a private model (assuming the model is packaged into a bentomodel), then we can use it directly
|
|
|
|
|
bentomodel = bentoml.models.get(model_id.lower())
|
|
|
|
|
model_id = bentomodel.path
|
|
|
|
|
_revision = bentomodel.tag.version
|
|
|
|
|
if not trust_remote_code:
|
|
|
|
|
trust_remote_code = True
|
|
|
|
|
except (ValueError, bentoml.exceptions.NotFound):
|
|
|
|
|
bentomodel = None
|
|
|
|
|
_revision = None
|
|
|
|
|
bentomodel, _revision = None, None
|
|
|
|
|
|
|
|
|
|
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
|
|
|
|
for arch in config.architectures:
|
|
|
|
|
if arch in openllm_core.AutoConfig._architecture_mappings:
|
|
|
|
|
model_name = openllm_core.AutoConfig._architecture_mappings[arch]
|
|
|
|
|
break
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError(f'Failed to determine config class for {model_id}')
|
|
|
|
|
llm_config = core.AutoConfig.from_id(model_id, trust_remote_code=trust_remote_code)
|
|
|
|
|
transformers_config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
|
|
|
|
commit_hash = getattr(transformers_config, '_commit_hash', None)
|
|
|
|
|
|
|
|
|
|
llm_config: openllm_core.LLMConfig = openllm_core.AutoConfig.for_model(model_name).model_construct_env()
|
|
|
|
|
# in case that user specify the revision here,
|
|
|
|
|
generated_uuid = gen_random_uuid()
|
|
|
|
|
_revision = first_not_none(_revision, model_version, commit_hash, default=generated_uuid)
|
|
|
|
|
|
|
|
|
|
model_revision = None
|
|
|
|
|
if bentomodel is None and model_version is not None:
|
|
|
|
|
# this is when --revision|--model-version is specified alongside with HF model-id, then we set it in the generated service_vars.py, then we let users manage this themselves
|
|
|
|
|
model_revision = model_version
|
|
|
|
|
|
|
|
|
|
_revision = first_not_none(_revision, getattr(config, '_commit_hash', None), default=gen_random_uuid())
|
|
|
|
|
if serialisation is None:
|
|
|
|
|
termui.warning(
|
|
|
|
|
f"Serialisation format is not specified. Defaulting to '{llm_config['serialisation']}'. Your model might not work with this format. Make sure to explicitly specify the serialisation format."
|
|
|
|
@@ -416,74 +459,93 @@ def build_command(
|
|
|
|
|
|
|
|
|
|
if bento_tag is None:
|
|
|
|
|
_bento_version = first_not_none(bento_version, default=_revision)
|
|
|
|
|
bento_tag = bentoml.Tag.from_taglike(f'{normalise_model_name(model_id)}-service:{_bento_version}'.lower().strip())
|
|
|
|
|
generated_tag = bentoml.Tag.from_taglike(
|
|
|
|
|
f'{normalise_model_name(model_id)}-service:{_bento_version}'.lower().strip()
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
bento_tag = bentoml.Tag.from_taglike(bento_tag)
|
|
|
|
|
generated_tag = bentoml.Tag.from_taglike(bento_tag)
|
|
|
|
|
|
|
|
|
|
state = ItemState.NOT_FOUND
|
|
|
|
|
try:
|
|
|
|
|
bento = bentoml.get(bento_tag)
|
|
|
|
|
bento = bentoml.get(generated_tag)
|
|
|
|
|
if overwrite:
|
|
|
|
|
bentoml.delete(bento_tag)
|
|
|
|
|
bentoml.delete(generated_tag)
|
|
|
|
|
state = ItemState.OVERWRITE
|
|
|
|
|
raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {bento_tag}') from None
|
|
|
|
|
raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {generated_tag}') from None
|
|
|
|
|
state = ItemState.EXISTS
|
|
|
|
|
except bentoml.exceptions.NotFound:
|
|
|
|
|
if state != ItemState.OVERWRITE:
|
|
|
|
|
state = ItemState.ADDED
|
|
|
|
|
|
|
|
|
|
labels = {'library': 'vllm'}
|
|
|
|
|
labels = {'runtime': 'vllm'}
|
|
|
|
|
# XXX: right now we just enable GPU by default. will revisit this if we decide to support TPU later.
|
|
|
|
|
service_config = dict(
|
|
|
|
|
resources={
|
|
|
|
|
'gpu' if device else 'cpu': len(device) if device else '1',
|
|
|
|
|
'gpu_type': recommended_instance_type(model_id, bentomodel, serialisation),
|
|
|
|
|
},
|
|
|
|
|
traffic=dict(timeout=timeout),
|
|
|
|
|
resources=dict(
|
|
|
|
|
gpu=len(openllm.utils.available_devices()),
|
|
|
|
|
gpu_type=recommended_instance_type(model_id, bentomodel, serialisation),
|
|
|
|
|
),
|
|
|
|
|
traffic=dict(timeout=timeout, concurrency=concurrency),
|
|
|
|
|
)
|
|
|
|
|
with fs.open_fs(f'temp://llm_{gen_random_uuid()}') as llm_fs:
|
|
|
|
|
logger.debug('Generating service vars %s (dir=%s)', model_id, llm_fs.getsyspath('/'))
|
|
|
|
|
with fs.open_fs(f'temp://{gen_random_uuid()}') as llm_fs, fs.open_fs(
|
|
|
|
|
f'temp://wheels_{gen_random_uuid()}'
|
|
|
|
|
) as wheel_fs:
|
|
|
|
|
termui.debug(f'Generating service vars {model_id} (dir={llm_fs.getsyspath("/")})')
|
|
|
|
|
script = _SERVICE_VARS.format(
|
|
|
|
|
__command__=' '.join(['openllm', *sys.argv[1:]]),
|
|
|
|
|
__model_id__=model_id,
|
|
|
|
|
__model_name__=model_name,
|
|
|
|
|
__model_quantise__=quantize,
|
|
|
|
|
__model_revision__=orjson.dumps(model_revision),
|
|
|
|
|
__model_quantise__=quantise,
|
|
|
|
|
__model_dtype__=dtype,
|
|
|
|
|
__model_serialization__=serialisation,
|
|
|
|
|
__model_trust_remote_code__=trust_remote_code,
|
|
|
|
|
__max_model_len__=max_model_len,
|
|
|
|
|
__gpu_memory_utilization__=gpu_memory_utilization,
|
|
|
|
|
__services_config__=orjson.dumps(service_config).decode(),
|
|
|
|
|
__services_config__=orjson.dumps(service_config),
|
|
|
|
|
)
|
|
|
|
|
models = []
|
|
|
|
|
if bentomodel is not None:
|
|
|
|
|
models.append(ModelSpec.from_item({'tag': str(bentomodel.tag), 'alias': bentomodel.tag.name}))
|
|
|
|
|
if SHOW_CODEGEN:
|
|
|
|
|
logger.info('Generated _service_vars.py:\n%s', script)
|
|
|
|
|
termui.info(f'\n{"=" * 27}\nGenerated _service_vars.py:\n\n{script}\n{"=" * 27}\n')
|
|
|
|
|
llm_fs.writetext('_service_vars.py', script)
|
|
|
|
|
|
|
|
|
|
with _SERVICE_README.open('r') as f:
|
|
|
|
|
service_readme = f.read()
|
|
|
|
|
service_readme = service_readme.format(model_id=model_id)
|
|
|
|
|
with _SERVICE_FILE.open('r') as f:
|
|
|
|
|
service_src = f.read()
|
|
|
|
|
llm_fs.writetext(llm_config['service_name'], service_src)
|
|
|
|
|
|
|
|
|
|
built_wheels = [build_sdist(wheel_fs.getsyspath('/'), p) for p in ('openllm-core', 'openllm-client', 'openllm')]
|
|
|
|
|
|
|
|
|
|
bento = bentoml.Bento.create(
|
|
|
|
|
version=bento_tag.version,
|
|
|
|
|
version=generated_tag.version,
|
|
|
|
|
build_ctx=llm_fs.getsyspath('/'),
|
|
|
|
|
build_config=BentoBuildConfig(
|
|
|
|
|
service=f"{llm_config['service_name']}:LLMService",
|
|
|
|
|
name=bento_tag.name,
|
|
|
|
|
name=generated_tag.name,
|
|
|
|
|
labels=labels,
|
|
|
|
|
models=models,
|
|
|
|
|
models=[ModelSpec.from_item({'tag': str(bentomodel.tag), 'alias': bentomodel.tag.name})]
|
|
|
|
|
if bentomodel is not None
|
|
|
|
|
else [],
|
|
|
|
|
envs=[
|
|
|
|
|
EnvironmentEntry(name='NVIDIA_DRIVER_CAPABILITIES', value='compute,utility'),
|
|
|
|
|
EnvironmentEntry(name='VLLM_VERSION', value='0.4.2'),
|
|
|
|
|
EnvironmentEntry(name=HF_HUB_DISABLE_PROGRESS_BARS, value='TRUE'),
|
|
|
|
|
],
|
|
|
|
|
description=f"OpenLLM service for {llm_config['start_name']}",
|
|
|
|
|
description=service_readme,
|
|
|
|
|
include=list(llm_fs.walk.files()),
|
|
|
|
|
exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
|
|
|
|
|
python=construct_python_options(llm_config, llm_fs),
|
|
|
|
|
python=PythonOptions(
|
|
|
|
|
packages=['scipy', 'bentoml[tracing]>=1.2.16', 'openllm[vllm]'],
|
|
|
|
|
pip_args='--no-color --progress-bar off',
|
|
|
|
|
wheels=[wheel_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels]
|
|
|
|
|
if all(i for i in built_wheels)
|
|
|
|
|
else None,
|
|
|
|
|
lock_packages=False,
|
|
|
|
|
),
|
|
|
|
|
docker=DockerOptions(
|
|
|
|
|
python_version='3.11',
|
|
|
|
|
setup_script=str(_TINY_PATH / 'setup.sh') if nightly else None,
|
|
|
|
|
setup_script=str(_TINY_PATH / 'setup.sh'),
|
|
|
|
|
dockerfile_template=str(_TINY_PATH / 'Dockerfile.j2'),
|
|
|
|
|
system_packages=['git'],
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
).save(bento_store=BentoMLContainer.bento_store.get(), model_store=BentoMLContainer.model_store.get())
|
|
|
|
@@ -500,7 +562,7 @@ def build_command(
|
|
|
|
|
termui.info(f"Successfully built Bento '{bento.tag}'.\n")
|
|
|
|
|
elif not overwrite:
|
|
|
|
|
termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n")
|
|
|
|
|
if not get_debug_mode():
|
|
|
|
|
if not (debug or get_debug_mode()):
|
|
|
|
|
termui.echo(OPENLLM_FIGLET)
|
|
|
|
|
termui.echo('📖 Next steps:\n', nl=False)
|
|
|
|
|
termui.echo(f'☁️ Deploy to BentoCloud:\n $ bentoml deploy {bento.tag} -n ${{DEPLOYMENT_NAME}}\n', nl=False)
|
|
|
|
|