Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-02-07 22:33:28 -05:00)
feat: 1.2 APIs (#821)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 openllm-python/src/_openllm_tiny/__init__.py Normal file
@@ -0,0 +1,3 @@
from __future__ import annotations

from ._llm import LLM as LLM
25 openllm-python/src/_openllm_tiny/__init__.pyi Normal file
@@ -0,0 +1,25 @@
import bentoml
from bentoml._internal.models.model import ModelInfo
from openllm_core._typing_compat import TypedDict, NotRequired
from ._llm import LLM as LLM

class _Metadata(TypedDict):
  model_id: str
  dtype: str
  _revision: str
  _local: bool
  serialisation: str
  architectures: str
  trust_remote_code: bool
  api_version: str
  llm_type: str
  openllm_version: str
  openllm_core_version: str
  openllm_client_version: str
  quantize: NotRequired[str]

class _Info(ModelInfo):
  metadata: _Metadata  # type: ignore[assignment]

class _Model(bentoml.Model):
  info: _Info
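The stub above only narrows types for the type checker; at runtime the metadata declared by `_Metadata` is an ordinary dict attached to a saved `bentoml.Model`. A minimal sketch of that contract in use (the model tag below is hypothetical):

```python
# Hypothetical tag; the keys are the ones declared by _Metadata above.
import bentoml

bentomodel = bentoml.models.get('mistralai--mistral-7b-instruct-v0.1')
meta = bentomodel.info.metadata
print(meta['model_id'], meta['dtype'], meta['serialisation'], meta.get('quantize'))
```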
499 openllm-python/src/_openllm_tiny/_entrypoint.py Normal file
@@ -0,0 +1,499 @@
from __future__ import annotations

import os, logging, traceback, pathlib, sys, fs, click, enum, inflection, bentoml, orjson, openllm, openllm_core, platform, typing as t
from ._helpers import recommended_instance_type
from openllm_core.utils import (
  DEBUG_ENV_VAR,
  QUIET_ENV_VAR,
  SHOW_CODEGEN,
  check_bool_env,
  compose,
  first_not_none,
  dantic,
  gen_random_uuid,
  get_debug_mode,
  get_quiet_mode,
  normalise_model_name,
)
from openllm_core._typing_compat import (
  LiteralQuantise,
  LiteralSerialisation,
  LiteralDtype,
  get_literal_args,
  TypedDict,
)
from openllm_cli import termui

logger = logging.getLogger(__name__)

OPENLLM_FIGLET = """
 ██████╗ ██████╗ ███████╗███╗   ██╗██╗     ██╗     ███╗   ███╗
██╔═══██╗██╔══██╗██╔════╝████╗  ██║██║     ██║     ████╗ ████║
██║   ██║██████╔╝█████╗  ██╔██╗ ██║██║     ██║     ██╔████╔██║
██║   ██║██╔═══╝ ██╔══╝  ██║╚██╗██║██║     ██║     ██║╚██╔╝██║
╚██████╔╝██║     ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
 ╚═════╝ ╚═╝     ╚══════╝╚═╝  ╚═══╝╚══════╝╚══════╝╚═╝     ╚═╝.
"""
_PACKAGE_NAME = 'openllm'
_SERVICE_FILE = pathlib.Path(os.path.abspath(__file__)).parent / '_service.py'
_SERVICE_VARS = '''\
# fmt: off
# GENERATED BY 'openllm build {__model_id__}'. DO NOT EDIT
import orjson,openllm_core.utils as coreutils
model_id='{__model_id__}'
model_name='{__model_name__}'
quantise=coreutils.getenv('quantize',default='{__model_quantise__}',var=['QUANTISE'])
serialisation=coreutils.getenv('serialization',default='{__model_serialization__}',var=['SERIALISATION'])
dtype=coreutils.getenv('dtype', default='{__model_dtype__}', var=['TORCH_DTYPE'])
trust_remote_code=coreutils.check_bool_env("TRUST_REMOTE_CODE",{__model_trust_remote_code__})
max_model_len={__max_model_len__}
gpu_memory_utilization={__gpu_memory_utilization__}
services_config=orjson.loads(coreutils.getenv('services_config',"""{__services_config__}"""))
'''


class ItemState(enum.Enum):
  NOT_FOUND = 'NOT_FOUND'
  ADDED = 'ADDED'
  EXISTS = 'EXISTS'
  OVERWRITE = 'OVERWRITE'


def parse_device_callback(
  _: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
) -> t.Tuple[str, ...] | None:
  if value is None:
    return value
  el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
  # NOTE: --device all is a special case
  if len(el) == 1 and el[0] == 'all':
    return tuple(map(str, openllm.utils.available_devices()))
  return el


@click.group(context_settings=termui.CONTEXT_SETTINGS, name='openllm')
@click.version_option(
  None,
  '--version',
  '-v',
  package_name=_PACKAGE_NAME,
  message=f'{_PACKAGE_NAME}, %(version)s (compiled: {openllm.COMPILED})\nPython ({platform.python_implementation()}) {platform.python_version()}',
)
def cli() -> None:
  """\b
   ██████╗ ██████╗ ███████╗███╗   ██╗██╗     ██╗     ███╗   ███╗
  ██╔═══██╗██╔══██╗██╔════╝████╗  ██║██║     ██║     ████╗ ████║
  ██║   ██║██████╔╝█████╗  ██╔██╗ ██║██║     ██║     ██╔████╔██║
  ██║   ██║██╔═══╝ ██╔══╝  ██║╚██╗██║██║     ██║     ██║╚██╔╝██║
  ╚██████╔╝██║     ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
   ╚═════╝ ╚═╝     ╚══════╝╚═╝  ╚═══╝╚══════╝╚══════╝╚═╝     ╚═╝.

  \b
  An open platform for operating large language models in production.
  Fine-tune, serve, deploy, and monitor any LLMs with ease.
  """


def optimization_decorator(fn):
  # NOTE: return device, quantize, serialisation, dtype, max_model_len, gpu_memory_utilization
  optimization = [
    click.option(
      '--device',
      type=dantic.CUDA,
      multiple=True,
      envvar='CUDA_VISIBLE_DEVICES',
      callback=parse_device_callback,
      help='Assign GPU devices (if available)',
      show_envvar=True,
    ),
    click.option(
      '--dtype',
      type=str,
      envvar='DTYPE',
      default='auto',
      help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']",
    ),
    click.option(
      '--quantise',
      '--quantize',
      'quantize',
      type=str,
      default=None,
      envvar='QUANTIZE',
      show_envvar=True,
      help="""Dynamic quantization for running this LLM.

      The following quantization strategies are supported:

      - ``int8``: ``LLM.int8`` for [8-bit](https://arxiv.org/abs/2208.07339) quantization.

      - ``int4``: ``SpQR`` for [4-bit](https://arxiv.org/abs/2306.03078) quantization.

      - ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)

      - ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978)

      - ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)

      > [!NOTE] The model can also be served with quantized weights.
      """,
    ),
    click.option(
      '--serialisation',
      '--serialization',
      'serialisation',
      type=click.Choice(get_literal_args(LiteralSerialisation)),
      default=None,
      show_default=True,
      show_envvar=True,
      envvar='OPENLLM_SERIALIZATION',
      help="""Serialisation format for saving/loading the LLM.

      Currently, the following strategies are supported:

      - ``safetensors``: This will use the safetensors format, which is synonymous with ``safe_serialization=True``.

      > [!NOTE] Safetensors might not work in every case, and you can always fall back to ``legacy`` if needed.

      - ``legacy``: This will use the PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
      """,
    ),
    click.option(
      '--max-model-len',
      '--max_model_len',
      'max_model_len',
      type=int,
      default=None,
      help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.',
    ),
    click.option(
      '--gpu-memory-utilization',
      '--gpu_memory_utilization',
      'gpu_memory_utilization',
      default=0.9,
      help='The percentage of GPU memory to be used for the model executor',
    ),
  ]
  return compose(*optimization)(fn)


def shared_decorator(fn):
  shared = [
    click.argument(
      'model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True
    ),
    click.option(
      '--revision',
      '--bentomodel-version',
      '--model-version',
      'model_version',
      type=click.STRING,
      default=None,
      help='Optional model revision to save for this model. It will be inferred automatically from model-id.',
    ),
    click.option(
      '--model-tag',
      '--bentomodel-tag',
      'model_tag',
      type=click.STRING,
      default=None,
      help='Optional bentomodel tag to save for this model. It will be generated automatically based on model_id and model_version if not specified.',
    ),
  ]
  return compose(*shared)(fn)


@cli.command(name='start')
@shared_decorator
@click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds')
@optimization_decorator
def start_command(
  model_id: str,
  model_version: str | None,
  model_tag: str | None,
  timeout: int,
  device: t.Tuple[str, ...],
  quantize: LiteralQuantise | None,
  serialisation: LiteralSerialisation | None,
  dtype: LiteralDtype | t.Literal['auto', 'float'],
  max_model_len: int | None,
  gpu_memory_utilization: float,
):
  """Start any LLM as a REST server.

  \b
  ```bash
  $ openllm <start|start-http> <model_id> --<options> ...
  ```
  """
  import transformers

  from _bentoml_impl.server import serve_http
  from bentoml._internal.service.loader import load
  from bentoml._internal.log import configure_server_logging

  configure_server_logging()

  trust_remote_code = check_bool_env('TRUST_REMOTE_CODE', False)
  try:
    # if the given model_id is a private model, then we can use it directly
    bentomodel = bentoml.models.get(model_id.lower())
    model_id = bentomodel.path
  except (ValueError, bentoml.exceptions.NotFound):
    pass
  config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
  for arch in config.architectures:
    if arch in openllm_core.AutoConfig._architecture_mappings:
      model_name = openllm_core.AutoConfig._architecture_mappings[arch]
      break
  else:
    raise RuntimeError(f'Failed to determine config class for {model_id}')

  llm_config = openllm_core.AutoConfig.for_model(model_name).model_construct_env()

  # TODO: support LoRA adapters
  os.environ.update({
    QUIET_ENV_VAR: str(openllm.utils.get_quiet_mode()),
    DEBUG_ENV_VAR: str(openllm.utils.get_debug_mode()),
    'MODEL_ID': model_id,
    'MODEL_NAME': model_name,
    'SERIALIZATION': serialisation,
    'OPENLLM_CONFIG': llm_config.model_dump_json(),
    'DTYPE': dtype,
    'TRUST_REMOTE_CODE': str(trust_remote_code),
    'MAX_MODEL_LEN': orjson.dumps(max_model_len).decode(),
    'GPU_MEMORY_UTILIZATION': orjson.dumps(gpu_memory_utilization).decode(),
    'SERVICES_CONFIG': orjson.dumps(
      dict(
        resources={'gpu' if device else 'cpu': len(device) if device else 'cpu_count'}, traffic=dict(timeout=timeout)
      )
    ).decode(),
  })
  if quantize:
    os.environ['QUANTIZE'] = str(quantize)

  working_dir = os.path.abspath(os.path.dirname(__file__))
  if sys.path[0] != working_dir:
    sys.path.insert(0, working_dir)
  load('.', working_dir=working_dir).inject_config()
  serve_http('.', working_dir=working_dir)


def construct_python_options(llm_config, llm_fs):
  from openllm.bundle import RefResolver
  from bentoml._internal.bento.build_config import PythonOptions
  from openllm.bundle._package import build_editable

  packages = ['scipy', 'bentoml[tracing]>=1.2', f'openllm[vllm]>={RefResolver.from_strategy("release").version}']
  if llm_config['requirements'] is not None:
    packages.extend(llm_config['requirements'])
  built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')]
  return PythonOptions(
    packages=packages,
    wheels=[llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels] if all(i for i in built_wheels) else None,
    lock_packages=False,
  )


class EnvironmentEntry(TypedDict):
  name: str
  value: str


@cli.command(name='build', context_settings={'token_normalize_func': inflection.underscore})
@shared_decorator
@click.option(
  '--bento-version',
  type=str,
  default=None,
  help='Optional bento version for this BentoLLM. Default is the model revision.',
)
@click.option(
  '--bento-tag',
  type=str,
  default=None,
  help='Optional bento tag for this BentoLLM. Defaults to a tag derived from the model name and revision.',
)
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
@click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds')
@optimization_decorator
@click.option(
  '-o',
  '--output',
  type=click.Choice(['tag', 'default']),
  default='default',
  show_default=True,
  help="Output log format. '-o tag' to display only the bento tag.",
)
@click.pass_context
def build_command(
  ctx: click.Context,
  /,
  model_id: str,
  model_version: str | None,
  model_tag: str | None,
  bento_version: str | None,
  bento_tag: str | None,
  overwrite: bool,
  device: t.Tuple[str, ...],
  timeout: int,
  quantize: LiteralQuantise | None,
  serialisation: LiteralSerialisation | None,
  dtype: LiteralDtype | t.Literal['auto', 'float'],
  max_model_len: int | None,
  gpu_memory_utilization: float,
  output: t.Literal['default', 'tag'],
):
  """Package a given model into a BentoLLM.

  \b
  ```bash
  $ openllm build google/flan-t5-large
  ```

  \b
  > [!NOTE]
  > To run a container built from this Bento with GPU support, make sure
  > to have https://github.com/NVIDIA/nvidia-container-toolkit installed locally.

  \b
  > [!IMPORTANT]
  > To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
  > target also uses the same Python version and architecture as the build machine.
  """
  import transformers
  from bentoml._internal.configuration.containers import BentoMLContainer
  from bentoml._internal.configuration import set_quiet_mode
  from bentoml._internal.log import configure_logging
  from bentoml._internal.bento.build_config import BentoBuildConfig
  from bentoml._internal.bento.build_config import DockerOptions
  from bentoml._internal.bento.build_config import ModelSpec

  if output == 'tag':
    set_quiet_mode(True)
  configure_logging()

  trust_remote_code = check_bool_env('TRUST_REMOTE_CODE', False)
  try:
    # if the given model_id is a private model, then we can use it directly
    bentomodel = bentoml.models.get(model_id.lower())
    model_id = bentomodel.path
    _revision = bentomodel.tag.version
  except (ValueError, bentoml.exceptions.NotFound):
    bentomodel = None
    _revision = None

  config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
  for arch in config.architectures:
    if arch in openllm_core.AutoConfig._architecture_mappings:
      model_name = openllm_core.AutoConfig._architecture_mappings[arch]
      break
  else:
    raise RuntimeError(f'Failed to determine config class for {model_id}')

  llm_config: openllm_core.LLMConfig = openllm_core.AutoConfig.for_model(model_name).model_construct_env()

  _revision = first_not_none(_revision, getattr(config, '_commit_hash', None), default=gen_random_uuid())
  if serialisation is None:
    termui.warning(
      f"Serialisation format is not specified. Defaulting to '{llm_config['serialisation']}'. Your model might not work with this format. Make sure to explicitly specify the serialisation format."
    )
    serialisation = llm_config['serialisation']

  if bento_tag is None:
    _bento_version = first_not_none(bento_version, default=_revision)
    bento_tag = bentoml.Tag.from_taglike(f'{normalise_model_name(model_id)}-service:{_bento_version}'.lower().strip())
  else:
    bento_tag = bentoml.Tag.from_taglike(bento_tag)

  state = ItemState.NOT_FOUND
  try:
    bento = bentoml.get(bento_tag)
    if overwrite:
      bentoml.delete(bento_tag)
      state = ItemState.OVERWRITE
      raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {bento_tag}') from None
    state = ItemState.EXISTS
  except bentoml.exceptions.NotFound:
    if state != ItemState.OVERWRITE:
      state = ItemState.ADDED

    labels = {'library': 'vllm'}
    service_config = dict(
      resources={
        'gpu' if device else 'cpu': len(device) if device else 'cpu_count',
        'gpu_type': recommended_instance_type(model_id, bentomodel),
      },
      traffic=dict(timeout=timeout),
    )
    with fs.open_fs(f'temp://llm_{gen_random_uuid()}') as llm_fs:
      logger.debug('Generating service vars %s (dir=%s)', model_id, llm_fs.getsyspath('/'))
      script = _SERVICE_VARS.format(
        __model_id__=model_id,
        __model_name__=model_name,
        __model_quantise__=quantize,
        __model_dtype__=dtype,
        __model_serialization__=serialisation,
        __model_trust_remote_code__=trust_remote_code,
        __max_model_len__=max_model_len,
        __gpu_memory_utilization__=gpu_memory_utilization,
        __services_config__=orjson.dumps(service_config).decode(),
      )
      models = []
      if bentomodel is not None:
        models.append(ModelSpec.from_item({'tag': str(bentomodel.tag), 'alias': bentomodel.tag.name}))
      if SHOW_CODEGEN:
        logger.info('Generated _service_vars.py:\n%s', script)
      llm_fs.writetext('_service_vars.py', script)
      with _SERVICE_FILE.open('r') as f:
        service_src = f.read()
      llm_fs.writetext(llm_config['service_name'], service_src)
      bento = bentoml.Bento.create(
        version=bento_tag.version,
        build_ctx=llm_fs.getsyspath('/'),
        build_config=BentoBuildConfig(
          service=f"{llm_config['service_name']}:LLMService",
          name=bento_tag.name,
          labels=labels,
          models=models,
          envs=[
            EnvironmentEntry(name=QUIET_ENV_VAR, value=str(openllm.utils.get_quiet_mode())),
            EnvironmentEntry(name=DEBUG_ENV_VAR, value=str(openllm.utils.get_debug_mode())),
            EnvironmentEntry(name='OPENLLM_CONFIG', value=llm_config.model_dump_json()),
            EnvironmentEntry(name='NVIDIA_DRIVER_CAPABILITIES', value='compute,utility'),
          ],
          description=f"OpenLLM service for {llm_config['start_name']}",
          include=list(llm_fs.walk.files()),
          exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
          python=construct_python_options(llm_config, llm_fs),
          docker=DockerOptions(python_version='3.11'),
        ),
      ).save(bento_store=BentoMLContainer.bento_store.get(), model_store=BentoMLContainer.model_store.get())
  except Exception as err:
    traceback.print_exc()
    raise click.ClickException('Exception caught while building BentoLLM:\n' + str(err)) from err

  if output == 'tag':
    termui.echo(f'__tag__:{bento.tag}')
    return

  if not get_quiet_mode():
    if state != ItemState.EXISTS:
      termui.info(f"Successfully built Bento '{bento.tag}'.\n")
    elif not overwrite:
      termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n")
    if not get_debug_mode():
      termui.echo(OPENLLM_FIGLET)
      termui.echo('📖 Next steps:\n', nl=False)
      termui.echo(f'☁️ Deploy to BentoCloud:\n $ bentoml deploy {bento.tag} -n ${{DEPLOYMENT_NAME}}\n', nl=False)
      termui.echo(
        f'☁️ Update existing deployment on BentoCloud:\n $ bentoml deployment update --bento {bento.tag} ${{DEPLOYMENT_NAME}}\n',
        nl=False,
      )
      termui.echo(f'🐳 Containerize BentoLLM:\n $ bentoml containerize {bento.tag} --opt progress=plain\n', nl=False)

  return bento


if __name__ == '__main__':
  cli()
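For context on the template above: `openllm build` renders `_SERVICE_VARS` into a `_service_vars.py` module inside the temporary build filesystem, next to a copy of `_service.py`. A hypothetical rendering (the model id and values are illustrative, not taken from a real build) would look roughly like:

```python
# fmt: off
# GENERATED BY 'openllm build mistralai/Mistral-7B-Instruct-v0.1'. DO NOT EDIT
import orjson,openllm_core.utils as coreutils
model_id='mistralai/Mistral-7B-Instruct-v0.1'
model_name='mistral'
quantise=coreutils.getenv('quantize',default='None',var=['QUANTISE'])
serialisation=coreutils.getenv('serialization',default='safetensors',var=['SERIALISATION'])
dtype=coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE'])
trust_remote_code=coreutils.check_bool_env("TRUST_REMOTE_CODE",False)
max_model_len=None
gpu_memory_utilization=0.9
services_config=orjson.loads(coreutils.getenv('services_config',"""{"resources":{"cpu":"cpu_count","gpu_type":"nvidia-l4"},"traffic":{"timeout":360000}}"""))
```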
340 openllm-python/src/_openllm_tiny/_helpers.py Normal file
@@ -0,0 +1,340 @@
from __future__ import annotations

import openllm, traceback, logging, time, pathlib, pydantic, typing as t
from openllm_core.exceptions import ModelNotFound, OpenLLMException, ValidationError
from openllm_core.utils import gen_random_uuid, resolve_filepath
from openllm_core.protocol.openai import (
  ChatCompletionRequest,
  ChatCompletionResponseChoice,
  ChatCompletionResponseStreamChoice,
  ChatCompletionStreamResponse,
  ChatMessage,
  CompletionRequest,
  CompletionResponse,
  ChatCompletionResponse,
  Delta,
  ErrorResponse,
  NotSupportedError,
  LogProbs,
  UsageInfo,
)
from starlette.requests import Request
from huggingface_hub import scan_cache_dir

if t.TYPE_CHECKING:
  from vllm import RequestOutput


logger = logging.getLogger(__name__)


class Error(pydantic.BaseModel):
  error: ErrorResponse


def error_response(exception: type[OpenLLMException], message: str) -> ErrorResponse:
  return Error(
    error=ErrorResponse(message=message, type=str(exception.__qualname__), code=str(exception.error_code.value))
  )


class OpenAI:
  def __init__(self, llm: openllm.LLM):
    self.llm = llm

  async def chat_completions(self, request: ChatCompletionRequest, raw_request: Request):
    if request.logit_bias is not None and len(request.logit_bias) > 0:
      return self.create_error_response("'logit_bias' is not supported.", NotSupportedError)

    error = await self._check_model(request)
    if error is not None:
      return error

    try:
      prompt: str = self.llm._tokenizer.apply_chat_template(
        conversation=request.messages,
        tokenize=False,
        add_generation_prompt=request.add_generation_prompt,
        chat_template=request.chat_template if request.chat_template != 'None' else None,
      )
    except Exception as e:
      traceback.print_exc()
      return self.create_error_response(f'Failed to apply chat template: {e}')

    model_name, request_id = request.model, gen_random_uuid('chatcmpl')
    created_time = int(time.monotonic())

    generator = self.llm.generate_iterator(prompt, request_id=request_id, **request.model_dump())
    if request.stream:
      return self.chat_completion_stream_generator(request, model_name, request_id, created_time, generator)

    try:
      return await self.chat_completion_full_generator(
        request, raw_request, model_name, request_id, created_time, generator
      )
    except Exception as err:
      traceback.print_exc()
      return self.create_error_response(str(err))

  async def chat_completion_stream_generator(
    self,
    request: ChatCompletionRequest,
    model_name: str,
    request_id: str,
    created_time: int,
    generator: t.AsyncIterator[RequestOutput],
  ):
    first_iteration = True

    previous_texts = [''] * self.llm.config['n']
    previous_num_tokens = [0] * self.llm.config['n']
    finish_reason_sent = [False] * self.llm.config['n']
    try:
      async for request_output in generator:
        if first_iteration:
          role = self.get_chat_role(request)
          for i in range(self.llm.config['n']):
            choice_data = ChatCompletionResponseStreamChoice(
              index=i, delta=Delta(role=role), logprobs=None, finish_reason=None
            )
            chunk = ChatCompletionStreamResponse(
              id=request_id, created=created_time, choices=[choice_data], model=model_name
            )
            data = chunk.model_dump_json(exclude_unset=True)
            yield f'data: {data}\n\n'

          if request.echo:
            last_msg_content = ''
            if (
              request.messages
              and isinstance(request.messages, list)
              and request.messages[-1].get('content')
              and request.messages[-1].get('role') == role
            ):
              last_msg_content = request.messages[-1]['content']

            if last_msg_content:
              for i in range(request.n):
                choice_data = ChatCompletionResponseStreamChoice(
                  index=i, delta=Delta(content=last_msg_content), finish_reason=None
                )
                chunk = ChatCompletionStreamResponse(
                  id=request_id, created=created_time, choices=[choice_data], logprobs=None, model=model_name
                )
                data = chunk.model_dump_json(exclude_unset=True)
                yield f'data: {data}\n\n'
          first_iteration = False

        for output in request_output.outputs:
          i = output.index

          if finish_reason_sent[i]:
            continue

          delta_token_ids = output.token_ids[previous_num_tokens[i] :]
          top_logprobs = output.logprobs[previous_num_tokens[i] :] if output.logprobs else None
          logprobs = None

          if request.logprobs:
            logprobs = self._create_logprobs(
              token_ids=delta_token_ids,
              top_logprobs=top_logprobs,
              num_output_top_logprobs=request.logprobs,
              initial_text_offset=len(previous_texts[i]),
            )

          delta_text = output.text[len(previous_texts[i]) :]
          previous_texts[i] = output.text
          previous_num_tokens[i] = len(output.token_ids)
          if output.finish_reason is None:
            # Send token-by-token response for each request.n
            choice_data = ChatCompletionResponseStreamChoice(
              index=i, delta=Delta(content=delta_text), logprobs=logprobs, finish_reason=None
            )
            chunk = ChatCompletionStreamResponse(
              id=request_id, created=created_time, choices=[choice_data], model=model_name
            )
            data = chunk.model_dump_json(exclude_unset=True)
            yield f'data: {data}\n\n'
          else:
            # Send the finish response for each request.n only once
            prompt_tokens = len(request_output.prompt_token_ids)
            final_usage = UsageInfo(
              prompt_tokens=prompt_tokens,
              completion_tokens=previous_num_tokens[i],
              total_tokens=prompt_tokens + previous_num_tokens[i],
            )
            choice_data = ChatCompletionResponseStreamChoice(
              index=i, delta=Delta(content=delta_text), logprobs=logprobs, finish_reason=output.finish_reason
            )
            chunk = ChatCompletionStreamResponse(
              id=request_id, created=created_time, choices=[choice_data], model=model_name
            )
            if final_usage is not None:
              chunk.usage = final_usage
            data = chunk.model_dump_json(exclude_unset=True, exclude_none=True)
            yield f'data: {data}\n\n'
            finish_reason_sent[i] = True
    except ValueError as e:
      data = self.create_error_response(str(e)).model_dump_json()
      yield f'data: {data}\n\n'
    # Send the final done message after all response.n are finished
    yield 'data: [DONE]\n\n'

  async def chat_completion_full_generator(
    self,
    request: ChatCompletionRequest,
    raw_request: Request,
    model_name: str,
    request_id: str,
    created_time: int,
    generator: t.AsyncIterator[RequestOutput],
  ):
    final_result: RequestOutput = None

    try:
      async for request_output in generator:
        if await raw_request.is_disconnected():
          await self.llm._model.abort(request_id)
          return self.create_error_response('Client disconnected.')
        final_result = request_output
    except ValueError as e:
      await self.llm._model.abort(request_id)
      return self.create_error_response(str(e))

    if final_result is None:
      return self.create_error_response('No result is returned.')

    choices = []
    role = self.get_chat_role(request)

    for output in final_result.outputs:
      token_ids = output.token_ids
      top_logprobs = output.logprobs

      logprobs = None
      if request.logprobs:
        logprobs = self._create_logprobs(
          token_ids=token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs
        )

      choice_data = ChatCompletionResponseChoice(
        index=output.index,
        message=ChatMessage(role=role, content=output.text),
        logprobs=logprobs,
        finish_reason=output.finish_reason,
      )
      choices.append(choice_data)

    if request.echo:
      last_msg_content = ''
      if (
        request.messages
        and isinstance(request.messages, list)
        and request.messages[-1].get('content')
        and request.messages[-1].get('role') == role
      ):
        last_msg_content = request.messages[-1]['content']

      for choice in choices:
        full_message = last_msg_content + choice.message.content
        choice.message.content = full_message

    num_prompt_tokens = len(final_result.prompt_token_ids)
    num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
    usage = UsageInfo(
      prompt_tokens=num_prompt_tokens,
      completion_tokens=num_generated_tokens,
      total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = ChatCompletionResponse(
      id=request_id, created=created_time, model=model_name, choices=choices, usage=usage
    )

    return response

  # --- chat specific ---
  def get_chat_role(self, request: ChatCompletionRequest) -> str:
    return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant'

  @staticmethod
  async def completions(self, request) -> t.AsyncGenerator[CompletionResponse, None]: ...

  # ---
  def create_error_response(self, message: str, exception: type[OpenLLMException] = ValidationError) -> ErrorResponse:
    return error_response(exception, message)

  async def _check_model(self, request: ChatCompletionRequest | CompletionRequest) -> Error | None:
    if request.model != self.llm.llm_type:
      return error_response(
        ModelNotFound,
        f"Model '{request.model}' does not exist. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
      )
    # TODO: support lora

  def _validate_prompt(
    self,
    request: ChatCompletionRequest | CompletionRequest,
    prompt: str | None = None,
    prompt_ids: list[int] | None = None,
  ) -> Error | None:
    if not (prompt or prompt_ids):
      raise ValueError("'prompt' or 'prompt_ids' must be provided.")
    if prompt and prompt_ids:
      raise ValueError("'prompt' and 'prompt_ids' are mutually exclusive.")
    # TODO: validate max context length based on requested max_tokens
    #
    # input_ids = prompt_ids if prompt_ids is not None else self._tokenizer(prompt).input_ids
    # token_num = len(input_ids)
    # if request.max_tokens is None: request.max_tokens = self.engine_args['max_model_len']

  def _create_logprobs(
    self,
    token_ids: list[int],
    top_logprobs: list[dict[int, float] | None] | None = None,  #
    num_output_top_logprobs: int | None = None,
    initial_text_offset: int = 0,
  ) -> LogProbs:
    """Create OpenAI-style logprobs."""
    logprobs = LogProbs()
    last_token_len = 0
    if num_output_top_logprobs:
      logprobs.top_logprobs = []
    for i, token_id in enumerate(token_ids):
      step_top_logprobs = top_logprobs[i]
      if step_top_logprobs is not None:
        token_logprob = step_top_logprobs[token_id].logprob
      else:
        token_logprob = None
      token = step_top_logprobs[token_id].decoded_token
      logprobs.tokens.append(token)
      logprobs.token_logprobs.append(token_logprob)
      if len(logprobs.text_offset) == 0:
        logprobs.text_offset.append(initial_text_offset)
      else:
        logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
      last_token_len = len(token)

      if num_output_top_logprobs:
        logprobs.top_logprobs.append(
          {p.decoded_token: p.logprob for i, p in step_top_logprobs.items()} if step_top_logprobs else None
        )
    return logprobs


# NOTE: Check the following supported mapping from here:
# python -c "from openllm_core._typing_compat import get_type_hints, get_literal_args; from _bentoml_sdk.service.config import ResourceSchema; print(get_literal_args(get_type_hints(ResourceSchema)['gpu_type']))"
RECOMMENDED_MAPPING = {'nvidia-l4': 24e9, 'nvidia-a10g': 24e9, 'nvidia-tesla-a100': 40e9, 'nvidia-a100-80gb': 80e9}


def recommended_instance_type(model_id, bentomodel=None):
  if bentomodel is not None:
    size = sum(f.stat().st_size for f in pathlib.Path(resolve_filepath(model_id)).glob('**/*') if f.is_file())
  else:
    info = next(filter(lambda repo: repo.repo_id == model_id, scan_cache_dir().repos))
    size = info.size_on_disk

  # find the first occurrence of a gpu_type in the recommended mapping such that "size" is less than or equal to 70% of that GPU's memory
  for gpu, max_size in RECOMMENDED_MAPPING.items():
    if size <= max_size * 0.7:
      return gpu
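To make the 70% rule above concrete: a GPU is recommended only if the on-disk model size fits within 70% of its memory, walking `RECOMMENDED_MAPPING` in insertion order. A small standalone sketch with illustrative (not measured) checkpoint sizes:

```python
from __future__ import annotations

# Mirrors the selection rule above: take the first GPU whose memory, scaled by 0.7,
# still covers the on-disk model size.
RECOMMENDED_MAPPING = {'nvidia-l4': 24e9, 'nvidia-a10g': 24e9, 'nvidia-tesla-a100': 40e9, 'nvidia-a100-80gb': 80e9}

def pick_gpu(size_bytes: float) -> str | None:
  for gpu, gpu_memory in RECOMMENDED_MAPPING.items():
    if size_bytes <= gpu_memory * 0.7:
      return gpu
  return None  # larger than 70% of an 80GB A100: no recommendation

print(pick_gpu(14e9))  # roughly a 7B model in fp16  -> 'nvidia-l4' (14e9 <= 24e9 * 0.7 = 16.8e9)
print(pick_gpu(26e9))  # roughly a 13B model in fp16 -> 'nvidia-tesla-a100' (26e9 <= 40e9 * 0.7 = 28e9)
```

Like the function above, the sketch yields nothing when no GPU in the mapping is large enough, which callers should be prepared to handle.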
227 openllm-python/src/_openllm_tiny/_llm.py Normal file
@@ -0,0 +1,227 @@
from __future__ import annotations

import inspect, orjson, dataclasses, bentoml, functools, attr, openllm_core, traceback, openllm, typing as t

from openllm_core.utils import (
  get_debug_mode,
  is_vllm_available,
  normalise_model_name,
  gen_random_uuid,
  dict_filter_none,
)
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
from openllm_core._schemas import GenerationOutput, GenerationInput
from _bentoml_sdk.service import ServiceConfig

Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]

if t.TYPE_CHECKING:
  from vllm import RequestOutput


def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.Any]) -> dict[str, t.Any]:
  from vllm import AsyncEngineArgs

  fields = dataclasses.fields(AsyncEngineArgs)
  invalid_args = {k: v for k, v in v.items() if k not in {f.name for f in fields}}
  if len(invalid_args) > 0:
    raise ValueError(f'Invalid engine args: {list(invalid_args)}')
  return v


@attr.define(init=False)
class LLM:
  model_id: str
  bentomodel: bentoml.Model
  serialisation: LiteralSerialisation
  config: openllm_core.LLMConfig
  dtype: Dtype
  quantise: t.Optional[LiteralQuantise] = attr.field(default=None)
  trust_remote_code: bool = attr.field(default=False)
  engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
  service_config: t.Optional[ServiceConfig] = attr.field(factory=dict)

  _path: str = attr.field(
    init=False,
    default=attr.Factory(
      lambda self: self.bentomodel.path if self.bentomodel is not None else self.model_id, takes_self=True
    ),
  )

  def __init__(self, _: str = '', /, _internal: bool = False, **kwargs: t.Any) -> None:
    if not _internal:
      raise RuntimeError(
        'Cannot instantiate LLM directly in the new API. Make sure to use `openllm.LLM.from_model` instead.'
      )
    self.__attrs_init__(**kwargs)

  def __attrs_post_init__(self):
    assert self._model  # Ensure the model is initialised.

  @functools.cached_property
  def _model(self):
    if is_vllm_available():
      num_gpus, dev = 1, openllm.utils.device_count()
      if dev >= 2:
        num_gpus = min(dev // 2 * 2, dev)
      quantise = self.quantise if self.quantise and self.quantise in {'gptq', 'awq', 'squeezellm'} else None
      dtype = 'float16' if quantise == 'gptq' else self.dtype  # NOTE: GPTQ quantisation doesn't support bfloat16 yet.

      self.engine_args.setdefault('gpu_memory_utilization', 0.9)
      self.engine_args.update({
        'worker_use_ray': False,
        'engine_use_ray': False,
        'tokenizer_mode': 'auto',
        'tensor_parallel_size': num_gpus,
        'model': self._path,
        'tokenizer': self._path,
        'trust_remote_code': self.trust_remote_code,
        'dtype': dtype,
        'quantization': quantise,
      })
      if 'disable_log_stats' not in self.engine_args:
        self.engine_args['disable_log_stats'] = not get_debug_mode()
      if 'disable_log_requests' not in self.engine_args:
        self.engine_args['disable_log_requests'] = not get_debug_mode()

      try:
        from vllm import AsyncEngineArgs, AsyncLLMEngine

        return AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.engine_args))
      except Exception as err:
        traceback.print_exc()
        raise openllm.exceptions.OpenLLMException(
          f'Failed to initialise vLLMEngine due to the following error:\n{err}'
        ) from err
    else:
      raise RuntimeError('OpenLLM currently only supports running with a GPU and the vLLM backend.')

  @classmethod
  def from_model(
    cls,
    model_id: str,
    dtype: str,
    bentomodel: bentoml.Model | None = None,
    serialisation: LiteralSerialisation = 'safetensors',
    quantise: LiteralQuantise | None = None,
    trust_remote_code: bool = False,
    llm_config: openllm_core.LLMConfig | None = None,
    service_config: ServiceConfig | None = None,
    **engine_args: t.Any,
  ) -> LLM:
    return cls(
      _internal=True,
      model_id=model_id,
      bentomodel=bentomodel,
      quantise=quantise,
      serialisation=serialisation,
      config=llm_config,
      dtype=dtype,
      engine_args=engine_args,
      trust_remote_code=trust_remote_code,
      service_config=service_config,
    )

  @property
  def llm_type(self) -> str:
    return normalise_model_name(self.model_id)

  @property
  def identifying_params(self) -> dict[str, str]:
    return {
      'configuration': self.config.model_dump_json(),
      'model_ids': orjson.dumps(self.config['model_ids']).decode(),
      'model_id': self.model_id,
    }

  @functools.cached_property
  def _tokenizer(self):
    import transformers

    return transformers.AutoTokenizer.from_pretrained(self._path, trust_remote_code=self.trust_remote_code)

  async def generate_iterator(
    self,
    prompt: str | None,
    prompt_token_ids: list[int] | None = None,
    stop: str | t.Iterable[str] | None = None,
    stop_token_ids: list[int] | None = None,
    request_id: str | None = None,
    adapter_name: str | None = None,
    **attrs: t.Any,
  ) -> t.AsyncGenerator[RequestOutput, None]:
    from vllm import SamplingParams

    config = self.config.generation_config.model_copy(update=dict_filter_none(attrs))

    stop_token_ids = stop_token_ids or []
    eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
    if eos_token_id and not isinstance(eos_token_id, list):
      eos_token_id = [eos_token_id]
    stop_token_ids.extend(eos_token_id or [])
    if (config_eos := config['eos_token_id']) and config_eos not in stop_token_ids:
      stop_token_ids.append(config_eos)
    if self._tokenizer.eos_token_id not in stop_token_ids:
      stop_token_ids.append(self._tokenizer.eos_token_id)
    if stop is None:
      stop = set()
    elif isinstance(stop, str):
      stop = {stop}
    else:
      stop = set(stop)

    request_id = gen_random_uuid() if request_id is None else request_id

    top_p = 1.0 if config['temperature'] <= 1e-5 else config['top_p']
    config = config.model_copy(update=dict(stop=list(stop), stop_token_ids=stop_token_ids, top_p=top_p))
    sampling_params = SamplingParams(**{
      k: getattr(config, k, None) for k in set(inspect.signature(SamplingParams).parameters.keys())
    })

    try:
      async for it in self._model.generate(
        prompt, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_token_ids
      ):
        yield it
    except Exception as err:
      raise RuntimeError(f'Failed to start generation task: {err}') from err

  async def generate(
    self,
    prompt: str | None,
    prompt_token_ids: list[int] | None = None,
    stop: str | t.Iterable[str] | None = None,
    stop_token_ids: list[int] | None = None,
    request_id: str | None = None,
    adapter_name: str | None = None,
    *,
    _generated: GenerationInput | None = None,
    **attrs: t.Any,
  ) -> GenerationOutput:
    if stop is not None:
      attrs.update({'stop': stop})
    if stop_token_ids is not None:
      attrs.update({'stop_token_ids': stop_token_ids})
    config = self.config.model_copy(update=attrs)
    texts, token_ids = [[]] * config['n'], [[]] * config['n']
    async for result in self.generate_iterator(
      prompt,
      prompt_token_ids=prompt_token_ids,
      request_id=request_id,
      adapter_name=adapter_name,
      **config.model_dump(),
    ):
      for output in result.outputs:
        texts[output.index].append(output.text)
        token_ids[output.index].extend(output.token_ids)
    if (final_result := result) is None:
      raise RuntimeError('No result is returned.')
    return GenerationOutput.from_vllm(final_result).model_copy(
      update=dict(
        prompt=prompt,
        outputs=[
          output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
          for output in final_result.outputs
        ],
      )
    )
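A rough usage sketch of the `LLM` class outside the generated service. This is not a documented entrypoint: it assumes a CUDA GPU with vLLM installed, that `'mistral'` is a config name known to `openllm_core.AutoConfig`, and the model id is illustrative.

```python
import asyncio

import openllm, openllm_core

async def main() -> None:
  llm = openllm.LLM.from_model(
    'mistralai/Mistral-7B-Instruct-v0.1',  # hypothetical model id
    dtype='auto',
    llm_config=openllm_core.AutoConfig.for_model('mistral').model_construct_env(),
    trust_remote_code=False,
    max_model_len=4096,  # forwarded into vLLM AsyncEngineArgs via **engine_args
  )
  result = await llm.generate('Explain tensor parallelism in one sentence.')
  print(result.outputs[0].text)

asyncio.run(main())
```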
139 openllm-python/src/_openllm_tiny/_service.py Normal file
@@ -0,0 +1,139 @@
from __future__ import annotations

from http import HTTPStatus
from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse
import openllm, bentoml, logging, openllm_core as core
import _service_vars as svars, typing as t
from openllm_core._typing_compat import Annotated, Unpack
from openllm_core._schemas import MessageParam, MessagesConverterInput
from openllm_core.protocol.openai import ModelCard, ModelList, ChatCompletionRequest
from _openllm_tiny._helpers import OpenAI, Error

try:
  from fastapi import FastAPI
except ImportError:
  raise ImportError("Make sure to install openllm with 'pip install openllm[openai]'") from None

logger = logging.getLogger(__name__)

try:
  bentomodel = bentoml.models.get(svars.model_id.lower())
  model_id = bentomodel.path
except Exception:
  bentomodel = None
  model_id = svars.model_id
llm_config = core.AutoConfig.for_model(svars.model_name)
GenerationInput = core.GenerationInput.from_config(llm_config)

app_v1 = FastAPI(debug=True, description='OpenAI Compatible API support')


@bentoml.mount_asgi_app(app_v1)
@bentoml.service(name=f"llm-{llm_config['start_name']}-service", **svars.services_config)
class LLMService:
  bentomodel = bentomodel

  def __init__(self):
    self.llm = openllm.LLM.from_model(
      model_id,
      dtype=svars.dtype,
      bentomodel=bentomodel,
      serialisation=svars.serialisation,
      quantise=svars.quantise,
      llm_config=llm_config,
      trust_remote_code=svars.trust_remote_code,
      services_config=svars.services_config,
      max_model_len=svars.max_model_len,
      gpu_memory_utilization=svars.gpu_memory_utilization,
    )
    self.openai = OpenAI(self.llm)

  @core.utils.api(route='/v1/generate')
  async def generate_v1(self, **parameters: Unpack[core.GenerationInputDict]) -> core.GenerationOutput:
    return await self.llm.generate(**GenerationInput.from_dict(parameters).model_dump())

  @core.utils.api(route='/v1/generate_stream')
  async def generate_stream_v1(self, **parameters: Unpack[core.GenerationInputDict]) -> t.AsyncGenerator[str, None]:
    async for generated in self.llm.generate_iterator(**GenerationInput.from_dict(parameters).model_dump()):
      yield f'data: {generated.model_dump_json()}\n\n'
    yield 'data: [DONE]\n\n'

  @core.utils.api(route='/v1/metadata')
  async def metadata_v1(self) -> core.MetadataOutput:
    return core.MetadataOutput(
      timeout=self.llm.config['timeout'],
      model_name=self.llm.config['model_name'],
      backend='vllm',  # deprecated
      model_id=self.llm.model_id,
      configuration=self.llm.config.model_dump_json(),
    )

  @core.utils.api(route='/v1/helpers/messages')
  def helpers_messages_v1(
    self,
    message: Annotated[t.Dict[str, t.Any], MessagesConverterInput] = MessagesConverterInput(
      add_generation_prompt=False,
      messages=[
        MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
        MessageParam(role='user', content='Hi there!'),
        MessageParam(role='assistant', content='Yes?'),
      ],
    ),
  ) -> str:
    return self.llm._tokenizer.apply_chat_template(
      message['messages'], add_generation_prompt=message['add_generation_prompt'], tokenize=False
    )

  @app_v1.post(
    '/v1/chat/completions',
    tags=['OpenAI'],
    status_code=HTTPStatus.OK,
    summary='Given a list of messages comprising a conversation, the model will return a response.',
    operation_id='openai__chat_completions',
  )
  async def chat_completions_v1(
    self,
    raw_request: Request,
    request: ChatCompletionRequest = ChatCompletionRequest(
      messages=[
        MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
        MessageParam(role='user', content='Hi there!'),
        MessageParam(role='assistant', content='Yes?'),
      ],
      model=core.utils.normalise_model_name(model_id),
      n=1,
      stream=True,
    ),
  ):
    generator = await self.openai.chat_completions(request, raw_request)
    if isinstance(generator, Error):
      # NOTE: safe to cast here as we know it's an error
      return JSONResponse(content=generator.model_dump(), status_code=int(t.cast(str, generator.error.code)))
    if request.stream is True:
      return StreamingResponse(generator, media_type='text/event-stream')
    return JSONResponse(content=generator.model_dump())

  # GET /v1/models
  @app_v1.get(
    '/v1/models',
    tags=['OpenAI'],
    status_code=HTTPStatus.OK,
    summary='Describes a model offering that can be used with the API.',
    operation_id='openai__list_models',
  )
  def list_models(self) -> ModelList:
    """
    List and describe the various models available in the API.

    You can refer to the available supported models with `openllm models` for more information.
    """
    return ModelList(
      data=[ModelCard(root=core.utils.normalise_model_name(model_id), id=core.utils.normalise_model_name(model_id))]
    )


LLMService.mount_asgi_app(app_v1)

if __name__ == '__main__':
  LLMService.serve_http(reload=core.utils.check_bool_env('RELOAD', False))
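Once this service is running (for example via `openllm start`, which serves on BentoML's default HTTP port 3000), the mounted FastAPI routes can be exercised with the official `openai` client. The base URL, port, and model id below are assumptions; the model id should be whatever `GET /v1/models` reports for your deployment.

```python
from openai import OpenAI

# The routes above do not implement authentication, so the api_key is a placeholder.
client = OpenAI(base_url='http://localhost:3000/v1', api_key='na')
print([m.id for m in client.models.list().data])

stream = client.chat.completions.create(
  model='mistralai--mistral-7b-instruct-v0.1',  # hypothetical; use the id returned by /v1/models
  messages=[{'role': 'user', 'content': 'Hi there!'}],
  stream=True,
)
for chunk in stream:
  if chunk.choices and chunk.choices[0].delta.content:
    print(chunk.choices[0].delta.content, end='', flush=True)
```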
30 openllm-python/src/_openllm_tiny/_service_vars.py Normal file
@@ -0,0 +1,30 @@
import orjson, typing as t, openllm_core.utils as coreutils
from openllm_core._typing_compat import LiteralSerialisation, LiteralQuantise
from _openllm_tiny._llm import Dtype

(
  model_id,
  model_name,
  quantise,
  serialisation,
  dtype,
  trust_remote_code,
  max_model_len,
  gpu_memory_utilization,
  services_config,
) = (
  coreutils.getenv('model_id', var=['MODEL_ID'], return_type=str),
  coreutils.getenv('model_name', return_type=str),
  coreutils.getenv('quantize', var=['QUANTISE'], return_type=LiteralQuantise),
  coreutils.getenv('serialization', default='safetensors', var=['SERIALISATION'], return_type=LiteralSerialisation),
  coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE'], return_type=Dtype),
  coreutils.check_bool_env('TRUST_REMOTE_CODE', False),
  t.cast(t.Optional[int], orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps(None)))),
  t.cast(
    float,
    orjson.loads(
      coreutils.getenv('gpu_memory_utilization', default=orjson.dumps(0.9), var=['GPU_MEMORY_UTILISATION'])
    ),
  ),
  t.cast(t.Dict[str, t.Any], orjson.loads(coreutils.getenv('services_config', orjson.dumps({})))),
)
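These lookups read the environment prepared by `start_command` in `_entrypoint.py` above (`MODEL_ID`, `MODEL_NAME`, `SERIALIZATION`, `DTYPE`, `TRUST_REMOTE_CODE`, `MAX_MODEL_LEN`, `GPU_MEMORY_UTILIZATION`, `SERVICES_CONFIG`, ...), while a built Bento instead ships the generated `_service_vars.py` with values baked in. A hedged sketch of a manually prepared environment with illustrative values; the JSON-valued entries must stay valid JSON because they are round-tripped through `orjson` here:

```python
import os

os.environ.update({
  'MODEL_ID': 'mistralai/Mistral-7B-Instruct-v0.1',  # hypothetical model id
  'MODEL_NAME': 'mistral',
  'SERIALIZATION': 'safetensors',
  'DTYPE': 'auto',
  'TRUST_REMOTE_CODE': 'False',
  'MAX_MODEL_LEN': 'null',  # JSON literal for "not set"
  'GPU_MEMORY_UTILIZATION': '0.9',
  'SERVICES_CONFIG': '{"resources": {"cpu": "cpu_count"}, "traffic": {"timeout": 360000}}',
})
```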
1 openllm-python/src/_openllm_tiny/bentofile.yaml Normal file
@@ -0,0 +1 @@
service: '_service:LLMService'