diff --git a/.python-version-default b/.python-version-default
index bd28b9c5..2c073331 100644
--- a/.python-version-default
+++ b/.python-version-default
@@ -1 +1 @@
-3.9
+3.11
diff --git a/openllm-core/src/openllm_core/_schemas.py b/openllm-core/src/openllm_core/_schemas.py
index fa6202d2..33a09080 100644
--- a/openllm-core/src/openllm_core/_schemas.py
+++ b/openllm-core/src/openllm_core/_schemas.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-import pydantic, inflection, orjson, typing as t
+import pydantic, orjson, typing as t
 from ._configuration import LLMConfig
 from .utils import gen_random_uuid
 from ._typing_compat import Required, TypedDict, LiteralString
@@ -47,14 +47,10 @@ class GenerationInput(pydantic.BaseModel):
   request_id: t.Optional[str] = pydantic.Field(default=None)
   adapter_name: t.Optional[str] = pydantic.Field(default=None)

-  _class_ref: t.ClassVar[type[LLMConfig]] = pydantic.PrivateAttr()
-
-  @pydantic.field_validator('llm_config')
-  @classmethod
-  def llm_config_validator(cls, v: LLMConfig | dict[str, t.Any]) -> LLMConfig:
-    if isinstance(v, dict):
-      return cls._class_ref.model_construct_env(**v)
-    return v
+  def __init__(self, *, _internal=False, **data: t.Any):
+    if not _internal:
+      raise RuntimeError('This class is not meant to be used directly. Use "from_config" instead')
+    super().__init__(**data)

   @pydantic.field_validator('stop')
   @classmethod
@@ -81,35 +77,9 @@ class GenerationInput(pydantic.BaseModel):
       flattened['stop_token_ids'] = self.stop_token_ids
     return flattened

-  def __init__(self, /, *, _internal: bool = False, **data: t.Any) -> None:
-    if not _internal:
-      raise RuntimeError(
-        f'Cannot instantiate GenerationInput directly. Use "{self.__class__.__qualname__}.from_dict" instead.'
-      )
-    super().__init__(**data)
-
-  @classmethod
-  def from_dict(cls, structured: GenerationInputDict) -> GenerationInput:
-    if not hasattr(cls, '_class_ref'):
-      raise ValueError(
-        'Cannot use "from_dict" from a raw GenerationInput class. Currently only supports class created from "from_config".'
-      )
-    filtered: dict[str, t.Any] = {k: v for k, v in structured.items() if v is not None}
-    llm_config: dict[str, t.Any] | None = filtered.pop('llm_config', None)
-    if llm_config is not None:
-      filtered['llm_config'] = cls._class_ref.model_construct_env(**llm_config)
-
-    return cls(_internal=True, **filtered)
-
   @classmethod
   def from_config(cls, llm_config: LLMConfig) -> type[GenerationInput]:
-    klass = pydantic.create_model(
-      inflection.camelize(llm_config['start_name']) + 'GenerationInput',
-      __base__=cls,
-      llm_config=(type(llm_config), llm_config),
-      _class_ref=(llm_config.__class__, pydantic.PrivateAttr(default=llm_config.__class__)),
-    )
-    return klass
+    return cls(_internal=True, llm_config=llm_config)


 # NOTE: parameters from vllm.RequestOutput and vllm.CompletionOutput since vllm is not available on CPU.
diff --git a/openllm-core/src/openllm_core/config/configuration_auto.py b/openllm-core/src/openllm_core/config/configuration_auto.py
index e9aa7f27..69b69c9b 100644
--- a/openllm-core/src/openllm_core/config/configuration_auto.py
+++ b/openllm-core/src/openllm_core/config/configuration_auto.py
@@ -218,3 +218,16 @@ class AutoConfig:
     raise ValueError(
       f"Failed to determine config class for '{bentomodel.name}'. Make sure {bentomodel.name} is saved with openllm."
     )
+
+  @classmethod
+  def from_id(cls, model_id: str, *, trust_remote_code: bool = False, **attrs: t.Any) -> openllm_core.LLMConfig:
+    import transformers
+
+    config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
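+    # match the model's declared architectures against OpenLLM's registered config mappings;
+    # the for/else below only raises when no supported architecture is found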
+    for arch in config.architectures:
+      if arch in cls._architecture_mappings:
+        return cls.for_model(cls._architecture_mappings[arch]).model_construct_env(**attrs)
+    else:
+      raise RuntimeError(
+        f'Failed to determine config class for {model_id}. Got {config.architectures}, which is not yet supported (Supported: {list(cls._architecture_mappings.keys())})'
+      )
diff --git a/openllm-core/src/openllm_core/utils/__init__.py b/openllm-core/src/openllm_core/utils/__init__.py
index 5e352aaf..e163590f 100644
--- a/openllm-core/src/openllm_core/utils/__init__.py
+++ b/openllm-core/src/openllm_core/utils/__init__.py
@@ -14,6 +14,7 @@ from ._constants import (
   DEBUG as DEBUG,
   SHOW_CODEGEN as SHOW_CODEGEN,
   MYPY as MYPY,
+  OPENLLM_DEV_BUILD as OPENLLM_DEV_BUILD,
 )

 if t.TYPE_CHECKING:
diff --git a/openllm-core/src/openllm_core/utils/__init__.pyi b/openllm-core/src/openllm_core/utils/__init__.pyi
index 3aa5a75a..ffc06bd3 100644
--- a/openllm-core/src/openllm_core/utils/__init__.pyi
+++ b/openllm-core/src/openllm_core/utils/__init__.pyi
@@ -44,6 +44,7 @@ DEBUG_ENV_VAR: str = ...
 QUIET_ENV_VAR: str = ...
 DEV_DEBUG_VAR: str = ...
 WARNING_ENV_VAR: str = ...
+OPENLLM_DEV_BUILD: str = ...

 _T = TypeVar('_T')
 R = TypeVar('R')
diff --git a/openllm-core/src/openllm_core/utils/_constants.py b/openllm-core/src/openllm_core/utils/_constants.py
index 729e057a..a5811c55 100644
--- a/openllm-core/src/openllm_core/utils/_constants.py
+++ b/openllm-core/src/openllm_core/utils/_constants.py
@@ -9,6 +9,7 @@ WARNING_ENV_VAR = 'OPENLLM_DISABLE_WARNING'
 DEV_DEBUG_VAR = 'DEBUG'

 ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
+OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'


 def check_bool_env(env: str, default: bool = True):
diff --git a/openllm-python/pyproject.toml b/openllm-python/pyproject.toml
index 5b75cf35..038ac8cd 100644
--- a/openllm-python/pyproject.toml
+++ b/openllm-python/pyproject.toml
@@ -38,7 +38,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
 dependencies = [
-  "bentoml[io]>=1.2",
+  "bentoml[io]>=1.2.16",
   "transformers[torch,tokenizers]>=4.36.0",
   "openllm-client>=0.5.0-alpha.14",
   "openllm-core>=0.5.0-alpha.14",
@@ -112,7 +112,7 @@ gemma = ["xformers"]
 ggml = ["ctransformers"]
 gpt-neox = ["xformers"]
 gptq = ["auto-gptq[triton]>=0.4.2"]
-grpc = ["bentoml[grpc]>=1.2"]
+grpc = ["bentoml[grpc]>=1.2.16"]
 llama = ["xformers"]
 mistral = ["xformers"]
 mixtral = ["xformers"]
diff --git a/openllm-python/src/_openllm_tiny/_entrypoint.py b/openllm-python/src/_openllm_tiny/_entrypoint.py
index 4b3faaa3..bfeca822 100644
--- a/openllm-python/src/_openllm_tiny/_entrypoint.py
+++ b/openllm-python/src/_openllm_tiny/_entrypoint.py
@@ -1,16 +1,17 @@
 from __future__ import annotations

-import os, logging, traceback, pathlib, sys, fs, click, enum, inflection, bentoml, orjson, openllm, openllm_core, platform, typing as t
+import os, traceback, io, pathlib, sys, fs, click, enum, importlib, importlib.metadata, inflection, bentoml, orjson, openllm, openllm_core as core, platform, tarfile, typing as t
 from ._helpers import recommended_instance_type
 from openllm_core.utils import (
   DEBUG,
   DEBUG_ENV_VAR,
   QUIET_ENV_VAR,
+  OPENLLM_DEV_BUILD,
   SHOW_CODEGEN,
   check_bool_env,
   compose,
   first_not_none,
-  dantic,
+  pkg,
   gen_random_uuid,
   get_debug_mode,
   get_quiet_mode,
@@ -25,7 +26,10 @@ from openllm_core._typing_compat import (
 )
 from openllm_cli import termui

-logger = logging.getLogger(__name__)
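+# NOTE: tomllib is part of the standard library from Python 3.11; older interpreters fall back to the tomli backport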
+if sys.version_info >= (3, 11):
+  import tomllib
+else:
+  import tomli as tomllib

 OPENLLM_FIGLET = """
 ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
@@ -38,20 +42,21 @@ OPENLLM_FIGLET = """
 _PACKAGE_NAME = 'openllm'
 _TINY_PATH = pathlib.Path(os.path.abspath(__file__)).parent
 _SERVICE_FILE = _TINY_PATH / '_service.py'
-_SERVICE_VARS = '''\
+_SERVICE_README = _TINY_PATH / 'service.md'
+_SERVICE_VARS = """\
 # fmt: off
-# GENERATED BY 'openllm build {__model_id__}'. DO NOT EDIT
+# GENERATED BY '{__command__}'. DO NOT EDIT
 import orjson,openllm_core.utils as coreutils
 model_id='{__model_id__}'
-model_name='{__model_name__}'
+revision=orjson.loads(coreutils.getenv('revision',default={__model_revision__}))
 quantise=coreutils.getenv('quantize',default='{__model_quantise__}',var=['QUANTISE'])
 serialisation=coreutils.getenv('serialization',default='{__model_serialization__}',var=['SERIALISATION'])
-dtype=coreutils.getenv('dtype', default='{__model_dtype__}', var=['TORCH_DTYPE'])
+dtype=coreutils.getenv('dtype', default='{__model_dtype__}',var=['TORCH_DTYPE'])
 trust_remote_code=coreutils.check_bool_env("TRUST_REMOTE_CODE",{__model_trust_remote_code__})
-max_model_len=orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps({__max_model_len__})))
-gpu_memory_utilization=orjson.loads(coreutils.getenv('gpu_memory_utilization', default=orjson.dumps({__gpu_memory_utilization__}), var=['GPU_MEMORY_UTILISATION']))
-services_config=orjson.loads(coreutils.getenv('services_config',"""{__services_config__}"""))
-'''
+max_model_len=orjson.loads(coreutils.getenv('max_model_len',default=orjson.dumps({__max_model_len__})))
+gpu_memory_utilization=orjson.loads(coreutils.getenv('gpu_memory_utilization',default=orjson.dumps({__gpu_memory_utilization__}),var=['GPU_MEMORY_UTILISATION']))
+services_config=orjson.loads(coreutils.getenv('services_config',default={__services_config__}))
+"""

 HF_HUB_DISABLE_PROGRESS_BARS = 'HF_HUB_DISABLE_PROGRESS_BARS'
@@ -62,20 +67,6 @@ class ItemState(enum.Enum):
   OVERWRITE = 'OVERWRITE'


-def parse_device_callback(
-  _: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
-) -> t.Tuple[str, ...] | None:
-  if value is None:
-    return value
-  el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
-  # NOTE: --device all is a special case
-  if len(el) == 1 and el[0] == 'all':
-    return tuple(map(str, openllm.utils.available_devices()))
-  if len(el) == 1 and el[0] == 'gpu':
-    return ('0',)
-  return el
-
-
 @click.group(context_settings=termui.CONTEXT_SETTINGS, name='openllm')
 @click.version_option(
   None,
@@ -94,53 +85,47 @@ def cli() -> None:
  ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝.

  \b
-  An open platform for operating large language models in production.
-  Fine-tune, serve, deploy, and monitor any LLMs with ease.
+  Self-Hosting LLMs Made Easy
  """


-def optimization_decorator(fn):
-  # NOTE: return device, quantize, serialisation, dtype, max_model_len, gpu_memory_utilization
+def optimization_decorator(fn: t.Callable[..., t.Any]):
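+  # NOTE: shared inference flags (concurrency, timeout, dtype, quantise, serialisation, max-model-len, gpu-memory-utilization, trust-remote-code) applied to both 'start' and 'build'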
   optimization = [
     click.option(
-      '--device',
-      type=dantic.CUDA,
-      multiple=True,
-      envvar='CUDA_VISIBLE_DEVICES',
-      callback=parse_device_callback,
-      help='Assign GPU devices (if available)',
+      '--concurrency',
+      type=int,
+      envvar='CONCURRENCY',
+      help='See https://docs.bentoml.com/en/latest/guides/concurrency.html#concurrency for more information.',
       show_envvar=True,
+      default=None,
     ),
+    click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds'),
     click.option(
       '--dtype',
      type=str,
       envvar='DTYPE',
       default='auto',
-      help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']",
+      help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16']. Defaults to 'auto', which infers the dtype from the available accelerator.",
     ),
     click.option(
       '--quantise',
       '--quantize',
-      'quantize',
+      'quantise',
       type=str,
       default=None,
       envvar='QUANTIZE',
       show_envvar=True,
-      help="""Dynamic quantization for running this LLM.
+      help="""Quantisation options for this LLM.

-      The following quantization strategies are supported:
+      The following quantisation strategies are supported:

-      - ``int8``: ``LLM.int8`` for [8-bit](https://arxiv.org/abs/2208.07339) quantization.
+      - ``gptq``: ``GPTQ`` [quantisation](https://arxiv.org/abs/2210.17323)

-      - ``int4``: ``SpQR`` for [4-bit](https://arxiv.org/abs/2306.03078) quantization.
+      - ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantisation](https://arxiv.org/abs/2306.00978)

-      - ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
+      - ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantisation](https://arxiv.org/abs/2306.07629)

-      - ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978)
-
-      - ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
-
-      > [!NOTE] that the model can also be served with quantized weights.
+      > [!NOTE] The model must be pre-quantised to load correctly, since all of the aforementioned strategies are post-training quantisation schemes.
       """,
     ),
     click.option(
@@ -151,14 +136,14 @@ def optimization_decorator(fn):
       default=None,
       show_default=True,
       show_envvar=True,
-      envvar='OPENLLM_SERIALIZATION',
-      help="""Serialisation format for save/load LLM.
+      envvar='SERIALIZATION',
+      help="""Serialisation format for loading the LLM.

       Make sure to check HF repository for the correct format. Currently the following strategies are supported:

       - ``safetensors``: This will use safetensors format, which is synonymous to ``safe_serialization=True``.

-      > [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed.
+      > [!NOTE] Safetensors might not work for older models, and you can always fall back to ``legacy`` if needed.

       - ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
""", @@ -178,14 +163,28 @@ def optimization_decorator(fn): default=0.9, help='The percentage of GPU memory to be used for the model executor', ), + click.option( + '--trust-remote-code', + '--trust_remote_code', + 'trust_remote_code', + type=bool, + is_flag=True, + default=False, + show_envvar=True, + envvar='TRUST_REMOTE_CODE', + help='If model from HuggingFace requires custom code, pass this to enable remote code execution. If the model is a private model, make sure to also pass this argument such that OpenLLM can determine model architecture to load.', + ), ] return compose(*optimization)(fn) -def shared_decorator(fn): +def shared_decorator(fn: t.Callable[..., t.Any]): shared = [ click.argument( - 'model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True + 'model_id', + type=click.STRING, + metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model | bentomodel_tag]', + required=True, ), click.option( '--revision', @@ -194,15 +193,17 @@ def shared_decorator(fn): 'model_version', type=click.STRING, default=None, - help='Optional model revision to save for this model. It will be inferred automatically from model-id.', + help='Optional model revision to for this LLM. If this is a private model, specify this alongside model_id will be used as a bentomodel tag. If using in conjunction with a HF model id, this will be a specific revision, code branch, or a commit id on HF repo.', ), click.option( - '--model-tag', - '--bentomodel-tag', - 'model_tag', - type=click.STRING, - default=None, - help='Optional bentomodel tag to save for this model. It will be generated automatically based on model_id and model_version if not specified.', + '--debug', + '--verbose', + type=bool, + default=False, + show_envvar=True, + is_flag=True, + envvar='DEBUG', + help='whether to enable verbose logging (For more fine-grained control, set DEBUG to number instead of this flag.).', ), ] return compose(*shared)(fn) @@ -210,76 +211,77 @@ def shared_decorator(fn): @cli.command(name='start') @shared_decorator -@click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds') @click.option('--port', type=int, default=3000, help='Port to serve the LLM. Default to 3000.') @optimization_decorator def start_command( model_id: str, model_version: str | None, - model_tag: str | None, timeout: int, + concurrency: int | None, port: int, - device: t.Tuple[str, ...], - quantize: LiteralQuantise | None, + quantise: LiteralQuantise | None, serialisation: LiteralSerialisation | None, dtype: LiteralDtype | t.Literal['auto', 'float'], max_model_len: int | None, gpu_memory_utilization: float, + trust_remote_code: bool, + debug: bool, ): """Start any LLM as a REST server. \b ```bash - $ openllm -- ... 
+  $ openllm start microsoft/Phi-3-mini-4k-instruct --trust-remote-code
   ```
   """
-  import transformers
-
   from _bentoml_impl.server import serve_http
   from bentoml._internal.service.loader import load
   from bentoml._internal.log import configure_server_logging

   configure_server_logging()
-  trust_remote_code = check_bool_env('TRUST_REMOTE_CODE', False)

   try:
-    # if given model_id is a private model, then we can use it directly
+    # NOTE: if the given model_id is a private model (assuming it is packaged into a bentomodel), then we can use it directly
     bentomodel = bentoml.models.get(model_id.lower())
     model_id = bentomodel.path
+    if not trust_remote_code:
+      trust_remote_code = True
   except (ValueError, bentoml.exceptions.NotFound):
     bentomodel = None

-  config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
-  for arch in config.architectures:
-    if arch in openllm_core.AutoConfig._architecture_mappings:
-      model_name = openllm_core.AutoConfig._architecture_mappings[arch]
-      break
-  else:
-    raise RuntimeError(f'Failed to determine config class for {model_id}')
-  llm_config = openllm_core.AutoConfig.for_model(model_name).model_construct_env()
+  llm_config = core.AutoConfig.from_id(model_id, trust_remote_code=trust_remote_code)

   if serialisation is None:
+    termui.warning(
+      f"Serialisation format is not specified. Defaulting to '{llm_config['serialisation']}'. Your model might not work with this format. Make sure to explicitly specify the serialisation format."
+    )
     serialisation = llm_config['serialisation']

   # TODO: support LoRA adapters
   os.environ.update({
     QUIET_ENV_VAR: str(openllm.utils.get_quiet_mode()),
-    DEBUG_ENV_VAR: str(openllm.utils.get_debug_mode()),
+    DEBUG_ENV_VAR: str(debug or openllm.utils.get_debug_mode()),
     HF_HUB_DISABLE_PROGRESS_BARS: str(not openllm.utils.get_debug_mode()),
     'MODEL_ID': model_id,
-    'MODEL_NAME': model_name,
+    # handle a custom revision when the user specifies --revision alongside model_id;
+    # this only applies when bentomodel is None
+    'REVISION': orjson.dumps(first_not_none(model_version, default=None)).decode(),
     'SERIALIZATION': serialisation,
     'OPENLLM_CONFIG': llm_config.model_dump_json(),
     'DTYPE': dtype,
     'TRUST_REMOTE_CODE': str(trust_remote_code),
     'GPU_MEMORY_UTILIZATION': orjson.dumps(gpu_memory_utilization).decode(),
     'SERVICES_CONFIG': orjson.dumps(
-      dict(resources={'gpu' if device else 'cpu': len(device) if device else '1'}, traffic=dict(timeout=timeout))
+      # XXX: right now we just enable GPU by default. will revisit this if we decide to support TPU later.
+      dict(
+        resources={'gpu': len(openllm.utils.available_devices())},
+        traffic=dict(timeout=timeout, concurrency=concurrency),
+      )
     ).decode(),
   })
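+  # NOTE: the generated service reads these values back (mostly through _service_vars.py) when it starts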

   if max_model_len is not None:
-    os.environ['MAX_MODEL_LEN'] = orjson.dumps(max_model_len)
-  if quantize:
-    os.environ['QUANTIZE'] = str(quantize)
+    os.environ['MAX_MODEL_LEN'] = orjson.dumps(max_model_len).decode()
+  if quantise:
+    os.environ['QUANTIZE'] = str(quantise)

   working_dir = os.path.abspath(os.path.dirname(__file__))
   if sys.path[0] != working_dir:
@@ -290,22 +292,69 @@ def start_command(
   )


-def construct_python_options(llm_config, llm_fs):
-  from bentoml._internal.bento.build_config import PythonOptions
-  from openllm.bundle._package import build_editable
+def get_package_version(_package: str) -> str:
+  try:
+    return importlib.import_module('._version', _package).__version__
+  except ModuleNotFoundError:
+    return importlib.metadata.version(_package)

-  # TODO: Add this line back once 0.5 is out, for now depends on OPENLLM_DEV_BUILD
-  packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'openllm[vllm]>0.4', 'vllm>=0.3']
-  packages = ['scipy', 'bentoml[tracing]>=1.2.8', 'openllm']
-  if llm_config['requirements'] is not None:
-    packages.extend(llm_config['requirements'])
-  built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')]
-  return PythonOptions(
-    packages=packages,
-    wheels=[llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels] if all(i for i in built_wheels) else None,
-    lock_packages=False,
+
+def build_sdist(target_path: str, package: t.Literal['openllm', 'openllm-core', 'openllm-client']):
+  import tomli_w
+
+  if not check_bool_env(OPENLLM_DEV_BUILD, default=False):
+    return None
+
+  module_location = pkg.source_locations(inflection.underscore(package))
+  if not module_location:
+    raise RuntimeError(
+      f'Could not find the source location of {package}. Make sure to set "{OPENLLM_DEV_BUILD}=False" if you are not using a development build.'
+    )
+
+  package_path = pathlib.Path(module_location)
+  package_version = get_package_version(package)
+  project_path = package_path.parent.parent
+
+  if not (project_path / 'pyproject.toml').exists():
+    termui.warning(
+      f'Custom "{package}" is detected. For a Bento to use the same build at serving time, add your custom "{package}" build to the pip packages list under your "bentofile.yaml", i.e. "packages=[\'git+https://github.com/bentoml/openllm.git@bc0be03\']"'
+    )
+    return
+  termui.debug(
+    f'"{package}" is installed in "editable" mode; building "{package}" distribution with the local code base. The built tarball will be included in the generated Bento.'
   )
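+  # tar filter: drop __pycache__ directories and compiled bytecode from the sdist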
+  def exclude_pycache(tarinfo: tarfile.TarInfo):
+    if '__pycache__' in tarinfo.name or tarinfo.name.endswith(('.pyc', '.pyo')):
+      return None
+    return tarinfo
+
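+  # pin the resolved version in pyproject.toml so the sdist carries a concrete version even when the project uses dynamic versioning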
+  with open(project_path / 'pyproject.toml', 'rb') as f:
+    pyproject_toml = tomllib.load(f)
+    pyproject_toml['project']['version'] = package_version
+    if 'dynamic' in pyproject_toml['project'] and 'version' in pyproject_toml['project']['dynamic']:
+      pyproject_toml['project']['dynamic'].remove('version')
+
+  pyproject_io = io.BytesIO()
+  tomli_w.dump(pyproject_toml, pyproject_io)
+  # make a tarball of this package-version
+  base_name = f'{package}-{package_version}'
+  sdist_filename = f'{base_name}.tar.gz'
+  files_to_include = ['src', 'LICENSE.md', 'README.md']
+
+  if not pathlib.Path(target_path).exists():
+    pathlib.Path(target_path).mkdir(parents=True, exist_ok=True)
+  with tarfile.open(pathlib.Path(target_path, sdist_filename), 'w:gz') as tar:
+    for file in files_to_include:
+      tar.add(project_path / file, arcname=f'{base_name}/{file}', filter=exclude_pycache)
+    if package == 'openllm':
+      tar.add(project_path / 'CHANGELOG.md', arcname=f'{base_name}/CHANGELOG.md', filter=exclude_pycache)
+    tarinfo = tar.gettarinfo(project_path / 'pyproject.toml', arcname=f'{base_name}/pyproject.toml')
+    tarinfo.size = pyproject_io.tell()
+    pyproject_io.seek(0)
+    tar.addfile(tarinfo, pyproject_io)
+  return sdist_filename
+

 class EnvironmentEntry(TypedDict):
   name: str
@@ -327,8 +376,6 @@ class EnvironmentEntry(TypedDict):
   help='Optional bento version for this BentoLLM. Default is the the model revision.',
 )
 @click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
-@click.option('--nightly', is_flag=True, default=False, help='Package with OpenLLM nightly version.')
-@click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds')
 @optimization_decorator
 @click.option(
   '-o',
@@ -338,31 +385,28 @@ class EnvironmentEntry(TypedDict):
   show_default=True,
   help="Output log format. '-o tag' to display only bento tag.",
 )
-@click.pass_context
 def build_command(
-  ctx: click.Context,
-  /,
   model_id: str,
   model_version: str | None,
-  model_tag: str | None,
   bento_version: str | None,
   bento_tag: str | None,
   overwrite: bool,
-  device: t.Tuple[str, ...],
-  nightly: bool,
   timeout: int,
-  quantize: LiteralQuantise | None,
+  concurrency: int | None,
+  quantise: LiteralQuantise | None,
   serialisation: LiteralSerialisation | None,
   dtype: LiteralDtype | t.Literal['auto', 'float'],
   max_model_len: int | None,
   gpu_memory_utilization: float,
   output: t.Literal['default', 'tag'],
+  trust_remote_code: bool,
+  debug: bool,
 ):
-  """Package a given models into a BentoLLM.
+  """Package a given LLM into a servable artefact.

   \b
   ```bash
-  $ openllm build google/flan-t5-large
+  $ openllm build microsoft/Phi-3-mini-4k-instruct --trust-remote-code
   ```

   \b
@@ -371,43 +415,42 @@ ...
   > to have https://github.com/NVIDIA/nvidia-container-toolkit install locally.

   \b
-  > [!IMPORTANT]
-  > To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
-  > target also use the same Python version and architecture as build machine.
+  > [!NOTE]
+  > For a private model, make sure to save it to the bentomodel store first. See https://docs.bentoml.com/en/latest/guides/model-store.html#model-store for more information
   """
   import transformers

-  from bentoml._internal.configuration.containers import BentoMLContainer
-  from bentoml._internal.configuration import set_quiet_mode
   from bentoml._internal.log import configure_logging
-  from bentoml._internal.bento.build_config import BentoBuildConfig
-  from bentoml._internal.bento.build_config import DockerOptions
-  from bentoml._internal.bento.build_config import ModelSpec
+  from bentoml._internal.configuration import set_quiet_mode
+  from bentoml._internal.configuration.containers import BentoMLContainer
+  from bentoml._internal.bento.build_config import BentoBuildConfig, DockerOptions, ModelSpec, PythonOptions

   if output == 'tag':
     set_quiet_mode(True)
   configure_logging()
-  trust_remote_code = check_bool_env('TRUST_REMOTE_CODE', False)

   try:
-    # if given model_id is a private model, then we can use it directly
+    # NOTE: if the given model_id is a private model (assuming the model is packaged into a bentomodel), then we can use it directly
     bentomodel = bentoml.models.get(model_id.lower())
     model_id = bentomodel.path
     _revision = bentomodel.tag.version
+    if not trust_remote_code:
+      trust_remote_code = True
   except (ValueError, bentoml.exceptions.NotFound):
-    bentomodel = None
-    _revision = None
+    bentomodel, _revision = None, None

-  config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
-  for arch in config.architectures:
-    if arch in openllm_core.AutoConfig._architecture_mappings:
-      model_name = openllm_core.AutoConfig._architecture_mappings[arch]
-      break
-  else:
-    raise RuntimeError(f'Failed to determine config class for {model_id}')
+  llm_config = core.AutoConfig.from_id(model_id, trust_remote_code=trust_remote_code)
+  transformers_config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+  commit_hash = getattr(transformers_config, '_commit_hash', None)

-  llm_config: openllm_core.LLMConfig = openllm_core.AutoConfig.for_model(model_name).model_construct_env()
+  # resolve the revision: bentomodel version, then --revision, then the HF commit hash, then a random uuid
+  generated_uuid = gen_random_uuid()
+  _revision = first_not_none(_revision, model_version, commit_hash, default=generated_uuid)
+
+  model_revision = None
+  if bentomodel is None and model_version is not None:
+    # when --revision|--model-version is specified alongside a HF model id, persist it in the generated _service_vars.py and let users manage it themselves
+    model_revision = model_version

-  _revision = first_not_none(_revision, getattr(config, '_commit_hash', None), default=gen_random_uuid())
   if serialisation is None:
     termui.warning(
       f"Serialisation format is not specified. Defaulting to '{llm_config['serialisation']}'. Your model might not work with this format. Make sure to explicitly specify the serialisation format."
@@ -416,74 +459,93 @@
   if bento_tag is None:
     _bento_version = first_not_none(bento_version, default=_revision)
-    bento_tag = bentoml.Tag.from_taglike(f'{normalise_model_name(model_id)}-service:{_bento_version}'.lower().strip())
+    generated_tag = bentoml.Tag.from_taglike(
+      f'{normalise_model_name(model_id)}-service:{_bento_version}'.lower().strip()
+    )
   else:
-    bento_tag = bentoml.Tag.from_taglike(bento_tag)
+    generated_tag = bentoml.Tag.from_taglike(bento_tag)

   state = ItemState.NOT_FOUND
   try:
-    bento = bentoml.get(bento_tag)
+    bento = bentoml.get(generated_tag)
     if overwrite:
-      bentoml.delete(bento_tag)
+      bentoml.delete(generated_tag)
       state = ItemState.OVERWRITE
-      raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {bento_tag}') from None
+      raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {generated_tag}') from None
     state = ItemState.EXISTS
   except bentoml.exceptions.NotFound:
     if state != ItemState.OVERWRITE:
       state = ItemState.ADDED

-    labels = {'library': 'vllm'}
+    labels = {'runtime': 'vllm'}
+    # XXX: right now we just enable GPU by default. will revisit this if we decide to support TPU later.
     service_config = dict(
-      resources={
-        'gpu' if device else 'cpu': len(device) if device else '1',
-        'gpu_type': recommended_instance_type(model_id, bentomodel, serialisation),
-      },
-      traffic=dict(timeout=timeout),
+      resources=dict(
+        gpu=len(openllm.utils.available_devices()),
+        gpu_type=recommended_instance_type(model_id, bentomodel, serialisation),
+      ),
+      traffic=dict(timeout=timeout, concurrency=concurrency),
     )

-    with fs.open_fs(f'temp://llm_{gen_random_uuid()}') as llm_fs:
-      logger.debug('Generating service vars %s (dir=%s)', model_id, llm_fs.getsyspath('/'))
+    with fs.open_fs(f'temp://{gen_random_uuid()}') as llm_fs, fs.open_fs(
+      f'temp://wheels_{gen_random_uuid()}'
+    ) as wheel_fs:
+      termui.debug(f'Generating service vars {model_id} (dir={llm_fs.getsyspath("/")})')
       script = _SERVICE_VARS.format(
+        __command__=' '.join(['openllm', *sys.argv[1:]]),
         __model_id__=model_id,
-        __model_name__=model_name,
-        __model_quantise__=quantize,
+        __model_revision__=orjson.dumps(model_revision),
+        __model_quantise__=quantise,
         __model_dtype__=dtype,
         __model_serialization__=serialisation,
         __model_trust_remote_code__=trust_remote_code,
         __max_model_len__=max_model_len,
         __gpu_memory_utilization__=gpu_memory_utilization,
-        __services_config__=orjson.dumps(service_config).decode(),
+        __services_config__=orjson.dumps(service_config),
       )
-      models = []
-      if bentomodel is not None:
-        models.append(ModelSpec.from_item({'tag': str(bentomodel.tag), 'alias': bentomodel.tag.name}))
       if SHOW_CODEGEN:
-        logger.info('Generated _service_vars.py:\n%s', script)
+        termui.info(f'\n{"=" * 27}\nGenerated _service_vars.py:\n\n{script}\n{"=" * 27}\n')
       llm_fs.writetext('_service_vars.py', script)
+      with _SERVICE_README.open('r') as f:
+        service_readme = f.read()
+        service_readme = service_readme.format(model_id=model_id)
       with _SERVICE_FILE.open('r') as f:
         service_src = f.read()
         llm_fs.writetext(llm_config['service_name'], service_src)
+
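+      # when OPENLLM_DEV_BUILD is enabled, bundle locally built sdists of the OpenLLM packages into the Bento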
+      built_wheels = [build_sdist(wheel_fs.getsyspath('/'), p) for p in ('openllm-core', 'openllm-client', 'openllm')]
+
       bento = bentoml.Bento.create(
-        version=bento_tag.version,
+        version=generated_tag.version,
         build_ctx=llm_fs.getsyspath('/'),
         build_config=BentoBuildConfig(
           service=f"{llm_config['service_name']}:LLMService",
-          name=bento_tag.name,
+          name=generated_tag.name,
           labels=labels,
-          models=models,
+          models=[ModelSpec.from_item({'tag': str(bentomodel.tag), 'alias': bentomodel.tag.name})]
+          if bentomodel is not None
+          else [],
           envs=[
             EnvironmentEntry(name='NVIDIA_DRIVER_CAPABILITIES', value='compute,utility'),
             EnvironmentEntry(name='VLLM_VERSION', value='0.4.2'),
             EnvironmentEntry(name=HF_HUB_DISABLE_PROGRESS_BARS, value='TRUE'),
           ],
-          description=f"OpenLLM service for {llm_config['start_name']}",
+          description=service_readme,
           include=list(llm_fs.walk.files()),
           exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
-          python=construct_python_options(llm_config, llm_fs),
+          python=PythonOptions(
+            packages=['scipy', 'bentoml[tracing]>=1.2.16', 'openllm[vllm]'],
+            pip_args='--no-color --progress-bar off',
+            wheels=[wheel_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels]
+            if all(i for i in built_wheels)
+            else None,
+            lock_packages=False,
+          ),
           docker=DockerOptions(
             python_version='3.11',
-            setup_script=str(_TINY_PATH / 'setup.sh') if nightly else None,
+            setup_script=str(_TINY_PATH / 'setup.sh'),
             dockerfile_template=str(_TINY_PATH / 'Dockerfile.j2'),
+            system_packages=['git'],
           ),
         ),
       ).save(bento_store=BentoMLContainer.bento_store.get(), model_store=BentoMLContainer.model_store.get())
@@ -500,7 +562,7 @@ def build_command(
     termui.info(f"Successfully built Bento '{bento.tag}'.\n")
   elif not overwrite:
     termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n")
-  if not get_debug_mode():
+  if not (debug or get_debug_mode()):
     termui.echo(OPENLLM_FIGLET)
     termui.echo('📖 Next steps:\n', nl=False)
     termui.echo(f'☁️ Deploy to BentoCloud:\n $ bentoml deploy {bento.tag} -n ${{DEPLOYMENT_NAME}}\n', nl=False)
diff --git a/openllm-python/src/_openllm_tiny/_service.py b/openllm-python/src/_openllm_tiny/_service.py
index 2856c051..b4574c8b 100644
--- a/openllm-python/src/_openllm_tiny/_service.py
+++ b/openllm-python/src/_openllm_tiny/_service.py
@@ -24,14 +24,14 @@ try:
 except Exception:
   bentomodel = None
 model_id = svars.model_id
-LLMConfig = core.AutoConfig.for_model(svars.model_name)
+LLMConfig = core.AutoConfig.from_id(model_id, trust_remote_code=svars.trust_remote_code)
 GenerationInput = core.GenerationInput.from_config(LLMConfig)

 ChatMessages = [
   MessageParam(
     role='system',
     content='You are acting as Ernest Hemmingway. All of your response will follow his and his writing style ONLY.',
   ),
-  MessageParam(role='user', content='Write an essay about Nietzsche and absurdism.'),
+  MessageParam(role='user', content='Write an essay on absurdism and its impact in the 20th century.'),
 ]

 app_v1 = FastAPI(
@@ -42,6 +42,9 @@ app_v1 = FastAPI(
   contact={'name': 'BentoML Team', 'email': 'contact@bentoml.com'},
 )

+if core.utils.get_debug_mode():
+  logger.info('service_config: %s', svars.services_config)
+

 @bentoml.mount_asgi_app(app_v1, path='/v1')
 @bentoml.service(name=f"llm-{LLMConfig['start_name']}-service", **svars.services_config)
@@ -53,6 +56,7 @@ class LLMService:
       model_id,
       dtype=svars.dtype,
       bentomodel=bentomodel,
+      revision=svars.revision,
       serialisation=svars.serialisation,
       quantise=svars.quantise,
       llm_config=LLMConfig,
diff --git a/openllm-python/src/_openllm_tiny/_service_vars.py b/openllm-python/src/_openllm_tiny/_service_vars.py
index c0e69624..b523e9e2 100644
--- a/openllm-python/src/_openllm_tiny/_service_vars.py
+++ b/openllm-python/src/_openllm_tiny/_service_vars.py
@@ -4,7 +4,7 @@ from _openllm_tiny._llm import Dtype

 (
   model_id,
-  model_name,
+  revision,
   quantise,
   serialisation,
   dtype,
@@ -14,7 +14,7 @@ from _openllm_tiny._llm import Dtype
   services_config,
 ) = (
   coreutils.getenv('model_id', var=['MODEL_ID'], return_type=str),
-  coreutils.getenv('model_name', return_type=str),
+  orjson.loads(coreutils.getenv('revision', return_type=str)),
   coreutils.getenv('quantize', var=['QUANTISE'], return_type=LiteralQuantise),
   coreutils.getenv('serialization', default='safetensors', var=['SERIALISATION'], return_type=LiteralSerialisation),
   coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE'], return_type=Dtype),
diff --git a/openllm-python/src/_openllm_tiny/service.md b/openllm-python/src/_openllm_tiny/service.md
new file mode 100644
index 00000000..c94e7cbe
--- /dev/null
+++ b/openllm-python/src/_openllm_tiny/service.md
@@ -0,0 +1,10 @@
+## OpenLLM generated services for {model_id}
+
+Available endpoints:
+
+- `GET /v1/models`: compatible with OpenAI's models list.
+- `POST /v1/chat/completions`: compatible with OpenAI's chat completions client.
+- `POST /v1/generate_stream`: low-level SSE-formatted streams. Users are responsible for correct prompt format strings.
+- `POST /v1/generate`: low-level one-shot generation. Users are responsible for correct prompt format strings.
+- `POST /v1/helpers/messages`: helper endpoint that returns fully formatted chat messages, provided the model is a chat model.
+- `POST /v1/metadata (deprecated)`: returns compatible metadata for the OpenLLM server.
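+
+Example (assuming the service is reachable on the default port 3000):
+
+```bash
+curl http://localhost:3000/v1/models
+```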
diff --git a/openllm-python/src/_openllm_tiny/setup.sh b/openllm-python/src/_openllm_tiny/setup.sh
index 4d838048..342526e5 100755
--- a/openllm-python/src/_openllm_tiny/setup.sh
+++ b/openllm-python/src/_openllm_tiny/setup.sh
@@ -1,5 +1,14 @@
 #!/usr/bin/env bash
 set -Eexuo pipefail

-pip3 install --no-color --progress-bar off --pre openllm openllm-core openllm-client || true
-FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pip3 install --no-build-isolation --no-color --progress-bar off flash-attn==2.5.7 || true
+BASEDIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]:-$0}")" &>/dev/null && pwd 2>/dev/null)"
+PARENT_DIR="$(dirname -- "$BASEDIR")"
+WHEELS_DIR="${PARENT_DIR}/python/wheels"
+pushd "${PARENT_DIR}/python" &>/dev/null
+shopt -s nullglob
+targzs=($WHEELS_DIR/*.tar.gz)
+if [ ${#targzs[@]} -gt 0 ]; then
+  echo "Installing tar.gz packages bundled in the Bento.."
+  pip3 install --no-color --progress-bar off "${targzs[@]}"
+fi
+popd &>/dev/null
diff --git a/openllm-python/src/_service_vars.pyi b/openllm-python/src/_service_vars.pyi
index 45a0845e..f9b06945 100644
--- a/openllm-python/src/_service_vars.pyi
+++ b/openllm-python/src/_service_vars.pyi
@@ -1,9 +1,9 @@
 from typing import Dict, Optional, Any
-from openllm_core._typing_compat import LiteralSerialisation, LiteralQuantise, LiteralString
+from openllm_core._typing_compat import LiteralSerialisation, LiteralQuantise
 from _openllm_tiny._llm import Dtype

 model_id: str = ...
-model_name: LiteralString = ...
+revision: str = ...
 model_tag: Optional[str] = ...
 model_version: Optional[str] = ...
 quantise: LiteralQuantise = ...
diff --git a/tools/dependencies.py b/tools/dependencies.py
index ad47faaa..9ceb9e43 100755
--- a/tools/dependencies.py
+++ b/tools/dependencies.py
@@ -141,7 +141,7 @@ class Dependencies:
     return cls(*decls)


-_LOWER_BENTOML_CONSTRAINT = '1.2'
+_LOWER_BENTOML_CONSTRAINT = '1.2.16'
 _BENTOML_EXT = ['io']
 _TRANSFORMERS_EXT = ['torch', 'tokenizers']
 _TRANSFORMERS_CONSTRAINTS = '4.36.0'