mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-08 00:59:50 -05:00
feat: 1.2 APIs (#821)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
3
openllm-python/src/_openllm_tiny/__init__.py
Normal file
3
openllm-python/src/_openllm_tiny/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._llm import LLM as LLM
|
||||
25
openllm-python/src/_openllm_tiny/__init__.pyi
Normal file
25
openllm-python/src/_openllm_tiny/__init__.pyi
Normal file
@@ -0,0 +1,25 @@
|
||||
import bentoml
|
||||
from bentoml._internal.models.model import ModelInfo
|
||||
from openllm_core._typing_compat import TypedDict, NotRequired
|
||||
from ._llm import LLM as LLM
|
||||
|
||||
class _Metadata(TypedDict):
|
||||
model_id: str
|
||||
dtype: str
|
||||
_revision: str
|
||||
_local: bool
|
||||
serialisation: str
|
||||
architectures: str
|
||||
trust_remote_code: bool
|
||||
api_version: str
|
||||
llm_type: str
|
||||
openllm_version: str
|
||||
openllm_core_version: str
|
||||
openllm_client_version: str
|
||||
quantize: NotRequired[str]
|
||||
|
||||
class _Info(ModelInfo):
|
||||
metadata: _Metadata # type: ignore[assignment]
|
||||
|
||||
class _Model(bentoml.Model):
|
||||
info: _Info
|
||||
499
openllm-python/src/_openllm_tiny/_entrypoint.py
Normal file
499
openllm-python/src/_openllm_tiny/_entrypoint.py
Normal file
@@ -0,0 +1,499 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os, logging, traceback, pathlib, sys, fs, click, enum, inflection, bentoml, orjson, openllm, openllm_core, platform, typing as t
|
||||
from ._helpers import recommended_instance_type
|
||||
from openllm_core.utils import (
|
||||
DEBUG_ENV_VAR,
|
||||
QUIET_ENV_VAR,
|
||||
SHOW_CODEGEN,
|
||||
check_bool_env,
|
||||
compose,
|
||||
first_not_none,
|
||||
dantic,
|
||||
gen_random_uuid,
|
||||
get_debug_mode,
|
||||
get_quiet_mode,
|
||||
normalise_model_name,
|
||||
)
|
||||
from openllm_core._typing_compat import (
|
||||
LiteralQuantise,
|
||||
LiteralSerialisation,
|
||||
LiteralDtype,
|
||||
get_literal_args,
|
||||
TypedDict,
|
||||
)
|
||||
from openllm_cli import termui
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_FIGLET = """
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║
|
||||
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
|
||||
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝.
|
||||
"""
|
||||
_PACKAGE_NAME = 'openllm'
|
||||
_SERVICE_FILE = pathlib.Path(os.path.abspath(__file__)).parent / '_service.py'
|
||||
_SERVICE_VARS = '''\
|
||||
# fmt: off
|
||||
# GENERATED BY 'openllm build {__model_id__}'. DO NOT EDIT
|
||||
import orjson,openllm_core.utils as coreutils
|
||||
model_id='{__model_id__}'
|
||||
model_name='{__model_name__}'
|
||||
quantise=coreutils.getenv('quantize',default='{__model_quantise__}',var=['QUANTISE'])
|
||||
serialisation=coreutils.getenv('serialization',default='{__model_serialization__}',var=['SERIALISATION'])
|
||||
dtype=coreutils.getenv('dtype', default='{__model_dtype__}', var=['TORCH_DTYPE'])
|
||||
trust_remote_code=coreutils.check_bool_env("TRUST_REMOTE_CODE",{__model_trust_remote_code__})
|
||||
max_model_len={__max_model_len__}
|
||||
gpu_memory_utilization={__gpu_memory_utilization__}
|
||||
services_config=orjson.loads(coreutils.getenv('services_config',"""{__services_config__}"""))
|
||||
'''
|
||||
|
||||
|
||||
class ItemState(enum.Enum):
|
||||
NOT_FOUND = 'NOT_FOUND'
|
||||
ADDED = 'ADDED'
|
||||
EXISTS = 'EXISTS'
|
||||
OVERWRITE = 'OVERWRITE'
|
||||
|
||||
|
||||
def parse_device_callback(
|
||||
_: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
|
||||
) -> t.Tuple[str, ...] | None:
|
||||
if value is None:
|
||||
return value
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == 'all':
|
||||
return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
|
||||
|
||||
@click.group(context_settings=termui.CONTEXT_SETTINGS, name='openllm')
|
||||
@click.version_option(
|
||||
None,
|
||||
'--version',
|
||||
'-v',
|
||||
package_name=_PACKAGE_NAME,
|
||||
message=f'{_PACKAGE_NAME}, %(version)s (compiled: {openllm.COMPILED})\nPython ({platform.python_implementation()}) {platform.python_version()}',
|
||||
)
|
||||
def cli() -> None:
|
||||
"""\b
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║
|
||||
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
|
||||
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝.
|
||||
|
||||
\b
|
||||
An open platform for operating large language models in production.
|
||||
Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
"""
|
||||
|
||||
|
||||
def optimization_decorator(fn):
|
||||
# NOTE: return device, quantize, serialisation, dtype, max_model_len, gpu_memory_utilization
|
||||
optimization = [
|
||||
click.option(
|
||||
'--device',
|
||||
type=dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help='Assign GPU devices (if available)',
|
||||
show_envvar=True,
|
||||
),
|
||||
click.option(
|
||||
'--dtype',
|
||||
type=str,
|
||||
envvar='DTYPE',
|
||||
default='auto',
|
||||
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']",
|
||||
),
|
||||
click.option(
|
||||
'--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=str,
|
||||
default=None,
|
||||
envvar='QUANTIZE',
|
||||
show_envvar=True,
|
||||
help="""Dynamic quantization for running this LLM.
|
||||
|
||||
The following quantization strategies are supported:
|
||||
|
||||
- ``int8``: ``LLM.int8`` for [8-bit](https://arxiv.org/abs/2208.07339) quantization.
|
||||
|
||||
- ``int4``: ``SpQR`` for [4-bit](https://arxiv.org/abs/2306.03078) quantization.
|
||||
|
||||
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
|
||||
|
||||
- ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978)
|
||||
|
||||
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
|
||||
|
||||
> [!NOTE] that the model can also be served with quantized weights.
|
||||
""",
|
||||
),
|
||||
click.option(
|
||||
'--serialisation',
|
||||
'--serialization',
|
||||
'serialisation',
|
||||
type=click.Choice(get_literal_args(LiteralSerialisation)),
|
||||
default=None,
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help="""Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
- ``safetensors``: This will use safetensors format, which is synonymous to ``safe_serialization=True``.
|
||||
|
||||
> [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed.
|
||||
|
||||
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
|
||||
""",
|
||||
),
|
||||
click.option(
|
||||
'--max-model-len',
|
||||
'--max_model_len',
|
||||
'max_model_len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.',
|
||||
),
|
||||
click.option(
|
||||
'--gpu-memory-utilization',
|
||||
'--gpu_memory_utilization',
|
||||
'gpu_memory_utilization',
|
||||
default=0.9,
|
||||
help='The percentage of GPU memory to be used for the model executor',
|
||||
),
|
||||
]
|
||||
return compose(*optimization)(fn)
|
||||
|
||||
|
||||
def shared_decorator(fn):
|
||||
shared = [
|
||||
click.argument(
|
||||
'model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True
|
||||
),
|
||||
click.option(
|
||||
'--revision',
|
||||
'--bentomodel-version',
|
||||
'--model-version',
|
||||
'model_version',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
help='Optional model revision to save for this model. It will be inferred automatically from model-id.',
|
||||
),
|
||||
click.option(
|
||||
'--model-tag',
|
||||
'--bentomodel-tag',
|
||||
'model_tag',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
help='Optional bentomodel tag to save for this model. It will be generated automatically based on model_id and model_version if not specified.',
|
||||
),
|
||||
]
|
||||
return compose(*shared)(fn)
|
||||
|
||||
|
||||
@cli.command(name='start')
|
||||
@shared_decorator
|
||||
@click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds')
|
||||
@optimization_decorator
|
||||
def start_command(
|
||||
model_id: str,
|
||||
model_version: str | None,
|
||||
model_tag: str | None,
|
||||
timeout: int,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: LiteralQuantise | None,
|
||||
serialisation: LiteralSerialisation | None,
|
||||
dtype: LiteralDtype | t.Literal['auto', 'float'],
|
||||
max_model_len: int | None,
|
||||
gpu_memory_utilization: float,
|
||||
):
|
||||
"""Start any LLM as a REST server.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm <start|start-http> <model_id> --<options> ...
|
||||
```
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from _bentoml_impl.server import serve_http
|
||||
from bentoml._internal.service.loader import load
|
||||
from bentoml._internal.log import configure_server_logging
|
||||
|
||||
configure_server_logging()
|
||||
|
||||
trust_remote_code = check_bool_env('TRUST_REMOTE_CODE', False)
|
||||
try:
|
||||
# if given model_id is a private model, then we can use it directly
|
||||
bentomodel = bentoml.models.get(model_id.lower())
|
||||
model_id = bentomodel.path
|
||||
except (ValueError, bentoml.exceptions.NotFound):
|
||||
pass
|
||||
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
for arch in config.architectures:
|
||||
if arch in openllm_core.AutoConfig._architecture_mappings:
|
||||
model_name = openllm_core.AutoConfig._architecture_mappings[arch]
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(f'Failed to determine config class for {model_id}')
|
||||
|
||||
llm_config = openllm_core.AutoConfig.for_model(model_name).model_construct_env()
|
||||
|
||||
# TODO: support LoRA adapters
|
||||
os.environ.update({
|
||||
QUIET_ENV_VAR: str(openllm.utils.get_quiet_mode()),
|
||||
DEBUG_ENV_VAR: str(openllm.utils.get_debug_mode()),
|
||||
'MODEL_ID': model_id,
|
||||
'MODEL_NAME': model_name,
|
||||
'SERIALIZATION': serialisation,
|
||||
'OPENLLM_CONFIG': llm_config.model_dump_json(),
|
||||
'DTYPE': dtype,
|
||||
'TRUST_REMOTE_CODE': str(trust_remote_code),
|
||||
'MAX_MODEL_LEN': orjson.dumps(max_model_len).decode(),
|
||||
'GPU_MEMORY_UTILIZATION': orjson.dumps(gpu_memory_utilization).decode(),
|
||||
'SERVICES_CONFIG': orjson.dumps(
|
||||
dict(
|
||||
resources={'gpu' if device else 'cpu': len(device) if device else 'cpu_count'}, traffic=dict(timeout=timeout)
|
||||
)
|
||||
).decode(),
|
||||
})
|
||||
if quantize:
|
||||
os.environ['QUANTIZE'] = str(quantize)
|
||||
|
||||
working_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
if sys.path[0] != working_dir:
|
||||
sys.path.insert(0, working_dir)
|
||||
load('.', working_dir=working_dir).inject_config()
|
||||
serve_http('.', working_dir=working_dir)
|
||||
|
||||
|
||||
def construct_python_options(llm_config, llm_fs):
|
||||
from openllm.bundle import RefResolver
|
||||
from bentoml._internal.bento.build_config import PythonOptions
|
||||
from openllm.bundle._package import build_editable
|
||||
|
||||
packages = ['scipy', 'bentoml[tracing]>=1.2', f'openllm[vllm]>={RefResolver.from_strategy("release").version}']
|
||||
if llm_config['requirements'] is not None:
|
||||
packages.extend(llm_config['requirements'])
|
||||
built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')]
|
||||
return PythonOptions(
|
||||
packages=packages,
|
||||
wheels=[llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels] if all(i for i in built_wheels) else None,
|
||||
lock_packages=False,
|
||||
)
|
||||
|
||||
|
||||
class EnvironmentEntry(TypedDict):
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
@cli.command(name='build', context_settings={'token_normalize_func': inflection.underscore})
|
||||
@shared_decorator
|
||||
@click.option(
|
||||
'--bento-version',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Optional bento version for this BentoLLM. Default is the the model revision.',
|
||||
)
|
||||
@click.option(
|
||||
'--bento-tag',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Optional bento version for this BentoLLM. Default is the the model revision.',
|
||||
)
|
||||
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
|
||||
@click.option('--timeout', type=int, default=360000, help='Timeout for the model executor in seconds')
|
||||
@optimization_decorator
|
||||
@click.option(
|
||||
'-o',
|
||||
'--output',
|
||||
type=click.Choice(['tag', 'default']),
|
||||
default='default',
|
||||
show_default=True,
|
||||
help="Output log format. '-o tag' to display only bento tag.",
|
||||
)
|
||||
@click.pass_context
|
||||
def build_command(
|
||||
ctx: click.Context,
|
||||
/,
|
||||
model_id: str,
|
||||
model_version: str | None,
|
||||
model_tag: str | None,
|
||||
bento_version: str | None,
|
||||
bento_tag: str | None,
|
||||
overwrite: bool,
|
||||
device: t.Tuple[str, ...],
|
||||
timeout: int,
|
||||
quantize: LiteralQuantise | None,
|
||||
serialisation: LiteralSerialisation | None,
|
||||
dtype: LiteralDtype | t.Literal['auto', 'float'],
|
||||
max_model_len: int | None,
|
||||
gpu_memory_utilization: float,
|
||||
output: t.Literal['default', 'tag'],
|
||||
):
|
||||
"""Package a given models into a BentoLLM.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm build google/flan-t5-large
|
||||
```
|
||||
|
||||
\b
|
||||
> [!NOTE]
|
||||
> To run a container built from this Bento with GPU support, make sure
|
||||
> to have https://github.com/NVIDIA/nvidia-container-toolkit install locally.
|
||||
|
||||
\b
|
||||
> [!IMPORTANT]
|
||||
> To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
|
||||
> target also use the same Python version and architecture as build machine.
|
||||
"""
|
||||
import transformers
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.configuration import set_quiet_mode
|
||||
from bentoml._internal.log import configure_logging
|
||||
from bentoml._internal.bento.build_config import BentoBuildConfig
|
||||
from bentoml._internal.bento.build_config import DockerOptions
|
||||
from bentoml._internal.bento.build_config import ModelSpec
|
||||
|
||||
if output == 'tag':
|
||||
set_quiet_mode(True)
|
||||
configure_logging()
|
||||
|
||||
trust_remote_code = check_bool_env('TRUST_REMOTE_CODE', False)
|
||||
try:
|
||||
# if given model_id is a private model, then we can use it directly
|
||||
bentomodel = bentoml.models.get(model_id.lower())
|
||||
model_id = bentomodel.path
|
||||
_revision = bentomodel.tag.version
|
||||
except (ValueError, bentoml.exceptions.NotFound):
|
||||
bentomodel = None
|
||||
_revision = None
|
||||
|
||||
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
for arch in config.architectures:
|
||||
if arch in openllm_core.AutoConfig._architecture_mappings:
|
||||
model_name = openllm_core.AutoConfig._architecture_mappings[arch]
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(f'Failed to determine config class for {model_id}')
|
||||
|
||||
llm_config: openllm_core.LLMConfig = openllm_core.AutoConfig.for_model(model_name).model_construct_env()
|
||||
|
||||
_revision = first_not_none(_revision, getattr(config, '_commit_hash', None), default=gen_random_uuid())
|
||||
if serialisation is None:
|
||||
termui.warning(
|
||||
f"Serialisation format is not specified. Defaulting to '{llm_config['serialisation']}'. Your model might not work with this format. Make sure to explicitly specify the serialisation format."
|
||||
)
|
||||
serialisation = llm_config['serialisation']
|
||||
|
||||
if bento_tag is None:
|
||||
_bento_version = first_not_none(bento_version, default=_revision)
|
||||
bento_tag = bentoml.Tag.from_taglike(f'{normalise_model_name(model_id)}-service:{_bento_version}'.lower().strip())
|
||||
else:
|
||||
bento_tag = bentoml.Tag.from_taglike(bento_tag)
|
||||
|
||||
state = ItemState.NOT_FOUND
|
||||
try:
|
||||
bento = bentoml.get(bento_tag)
|
||||
if overwrite:
|
||||
bentoml.delete(bento_tag)
|
||||
state = ItemState.OVERWRITE
|
||||
raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {bento_tag}') from None
|
||||
state = ItemState.EXISTS
|
||||
except bentoml.exceptions.NotFound:
|
||||
if state != ItemState.OVERWRITE:
|
||||
state = ItemState.ADDED
|
||||
|
||||
labels = {'library': 'vllm'}
|
||||
service_config = dict(
|
||||
resources={
|
||||
'gpu' if device else 'cpu': len(device) if device else 'cpu_count',
|
||||
'gpu_type': recommended_instance_type(model_id, bentomodel),
|
||||
},
|
||||
traffic=dict(timeout=timeout),
|
||||
)
|
||||
with fs.open_fs(f'temp://llm_{gen_random_uuid()}') as llm_fs:
|
||||
logger.debug('Generating service vars %s (dir=%s)', model_id, llm_fs.getsyspath('/'))
|
||||
script = _SERVICE_VARS.format(
|
||||
__model_id__=model_id,
|
||||
__model_name__=model_name,
|
||||
__model_quantise__=quantize,
|
||||
__model_dtype__=dtype,
|
||||
__model_serialization__=serialisation,
|
||||
__model_trust_remote_code__=trust_remote_code,
|
||||
__max_model_len__=max_model_len,
|
||||
__gpu_memory_utilization__=gpu_memory_utilization,
|
||||
__services_config__=orjson.dumps(service_config).decode(),
|
||||
)
|
||||
models = []
|
||||
if bentomodel is not None:
|
||||
models.append(ModelSpec.from_item({'tag': str(bentomodel.tag), 'alias': bentomodel.tag.name}))
|
||||
if SHOW_CODEGEN:
|
||||
logger.info('Generated _service_vars.py:\n%s', script)
|
||||
llm_fs.writetext('_service_vars.py', script)
|
||||
with _SERVICE_FILE.open('r') as f:
|
||||
service_src = f.read()
|
||||
llm_fs.writetext(llm_config['service_name'], service_src)
|
||||
bento = bentoml.Bento.create(
|
||||
version=bento_tag.version,
|
||||
build_ctx=llm_fs.getsyspath('/'),
|
||||
build_config=BentoBuildConfig(
|
||||
service=f"{llm_config['service_name']}:LLMService",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
models=models,
|
||||
envs=[
|
||||
EnvironmentEntry(name=QUIET_ENV_VAR, value=str(openllm.utils.get_quiet_mode())),
|
||||
EnvironmentEntry(name=DEBUG_ENV_VAR, value=str(openllm.utils.get_debug_mode())),
|
||||
EnvironmentEntry(name='OPENLLM_CONFIG', value=llm_config.model_dump_json()),
|
||||
EnvironmentEntry(name='NVIDIA_DRIVER_CAPABILITIES', value='compute,utility'),
|
||||
],
|
||||
description=f"OpenLLM service for {llm_config['start_name']}",
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
|
||||
python=construct_python_options(llm_config, llm_fs),
|
||||
docker=DockerOptions(python_version='3.11'),
|
||||
),
|
||||
).save(bento_store=BentoMLContainer.bento_store.get(), model_store=BentoMLContainer.model_store.get())
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise click.ClickException('Exception caught while building BentoLLM:\n' + str(err)) from err
|
||||
|
||||
if output == 'tag':
|
||||
termui.echo(f'__tag__:{bento.tag}')
|
||||
return
|
||||
|
||||
if not get_quiet_mode():
|
||||
if state != ItemState.EXISTS:
|
||||
termui.info(f"Successfully built Bento '{bento.tag}'.\n")
|
||||
elif not overwrite:
|
||||
termui.warning(f"Bento for '{model_id}' already exists [{bento}]. To overwrite it pass '--overwrite'.\n")
|
||||
if not get_debug_mode():
|
||||
termui.echo(OPENLLM_FIGLET)
|
||||
termui.echo('📖 Next steps:\n', nl=False)
|
||||
termui.echo(f'☁️ Deploy to BentoCloud:\n $ bentoml deploy {bento.tag} -n ${{DEPLOYMENT_NAME}}\n', nl=False)
|
||||
termui.echo(
|
||||
f'☁️ Update existing deployment on BentoCloud:\n $ bentoml deployment update --bento {bento.tag} ${{DEPLOYMENT_NAME}}\n',
|
||||
nl=False,
|
||||
)
|
||||
termui.echo(f'🐳 Containerize BentoLLM:\n $ bentoml containerize {bento.tag} --opt progress=plain\n', nl=False)
|
||||
|
||||
return bento
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
340
openllm-python/src/_openllm_tiny/_helpers.py
Normal file
340
openllm-python/src/_openllm_tiny/_helpers.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm, traceback, logging, time, pathlib, pydantic, typing as t
|
||||
from openllm_core.exceptions import ModelNotFound, OpenLLMException, ValidationError
|
||||
from openllm_core.utils import gen_random_uuid, resolve_filepath
|
||||
from openllm_core.protocol.openai import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
ChatCompletionResponse,
|
||||
Delta,
|
||||
ErrorResponse,
|
||||
NotSupportedError,
|
||||
LogProbs,
|
||||
UsageInfo,
|
||||
)
|
||||
from starlette.requests import Request
|
||||
from huggingface_hub import scan_cache_dir
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from vllm import RequestOutput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Error(pydantic.BaseModel):
|
||||
error: ErrorResponse
|
||||
|
||||
|
||||
def error_response(exception: type[OpenLLMException], message: str) -> ErrorResponse:
|
||||
return Error(
|
||||
error=ErrorResponse(message=message, type=str(exception.__qualname__), code=str(exception.error_code.value))
|
||||
)
|
||||
|
||||
|
||||
class OpenAI:
|
||||
def __init__(self, llm: openllm.LLM):
|
||||
self.llm = llm
|
||||
|
||||
async def chat_completions(self, request: ChatCompletionRequest, raw_request: Request):
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
return self.create_error_response("'logit_bias' is not supported .", NotSupportedError)
|
||||
|
||||
error = await self._check_model(request)
|
||||
if error is not None:
|
||||
return error
|
||||
|
||||
try:
|
||||
prompt: str = self.llm._tokenizer.apply_chat_template(
|
||||
conversation=request.messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
chat_template=request.chat_template if request.chat_template != 'None' else None,
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return self.create_error_response(f'Failed to apply chat template: {e}')
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
|
||||
created_time = int(time.monotonic())
|
||||
|
||||
generator = self.llm.generate_iterator(prompt, request_id=request_id, **request.model_dump())
|
||||
if request.stream:
|
||||
return self.chat_completion_stream_generator(request, model_name, request_id, created_time, generator)
|
||||
|
||||
try:
|
||||
return await self.chat_completion_full_generator(
|
||||
request, raw_request, model_name, request_id, created_time, generator
|
||||
)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
return self.create_error_response(str(err))
|
||||
|
||||
async def chat_completion_stream_generator(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
model_name: str,
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
generator: t.AsyncIterator[RequestOutput],
|
||||
):
|
||||
first_iteration = True
|
||||
|
||||
previous_texts = [''] * self.llm.config['n']
|
||||
previous_num_tokens = [0] * self.llm.config['n']
|
||||
finish_reason_sent = [False] * self.llm.config['n']
|
||||
try:
|
||||
async for request_output in generator:
|
||||
if first_iteration:
|
||||
role = self.get_chat_role(request)
|
||||
for i in range(self.llm.config['n']):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=Delta(role=role), logprobs=None, finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id, created=created_time, choices=[choice_data], model=model_name
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f'data: {data}\n\n'
|
||||
|
||||
if request.echo:
|
||||
last_msg_content = ''
|
||||
if (
|
||||
request.messages
|
||||
and isinstance(request.messages, list)
|
||||
and request.messages[-1].get('content')
|
||||
and request.messages[-1].get('role') == role
|
||||
):
|
||||
last_msg_content = request.messages[-1]['content']
|
||||
|
||||
if last_msg_content:
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=Delta(content=last_msg_content), finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id, created=created_time, choices=[choice_data], logprobs=None, model=model_name
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f'data: {data}\n\n'
|
||||
first_iteration = False
|
||||
|
||||
for output in request_output.outputs:
|
||||
i = output.index
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
continue
|
||||
|
||||
delta_token_ids = output.token_ids[previous_num_tokens[i] :]
|
||||
top_logprobs = output.logprobs[previous_num_tokens[i] :] if output.logprobs else None
|
||||
logprobs = None
|
||||
|
||||
if request.logprobs:
|
||||
logprobs = self._create_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
initial_text_offset=len(previous_texts[i]),
|
||||
)
|
||||
|
||||
delta_text = output.text[len(previous_texts[i]) :]
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
if output.finish_reason is None:
|
||||
# Send token-by-token response for each request.n
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=Delta(content=delta_text), logprobs=logprobs, finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id, created=created_time, choices=[choice_data], model=model_name
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f'data: {data}\n\n'
|
||||
else:
|
||||
# Send the finish response for each request.n only once
|
||||
prompt_tokens = len(request_output.prompt_token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=previous_num_tokens[i],
|
||||
total_tokens=prompt_tokens + previous_num_tokens[i],
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=Delta(content=delta_text), logprobs=logprobs, finish_reason=output.finish_reason
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id, created=created_time, choices=[choice_data], model=model_name
|
||||
)
|
||||
if final_usage is not None:
|
||||
chunk.usage = final_usage
|
||||
data = chunk.model_dump_json(exclude_unset=True, exclude_none=True)
|
||||
yield f'data: {data}\n\n'
|
||||
finish_reason_sent[i] = True
|
||||
except ValueError as e:
|
||||
data = self.create_error_response(str(e)).model_dump_json()
|
||||
yield f'data: {data}\n\n'
|
||||
# Send the final done message after all response.n are finished
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
async def chat_completion_full_generator(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
raw_request: Request,
|
||||
model_name: str,
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
generator: t.AsyncIterator[RequestOutput],
|
||||
):
|
||||
final_result: RequestOutput = None
|
||||
|
||||
try:
|
||||
async for request_output in generator:
|
||||
if await raw_request.is_disconnected():
|
||||
await self.llm._model.abort(request_id)
|
||||
return self.create_error_response('Client disconnected.')
|
||||
final_result = request_output
|
||||
except ValueError as e:
|
||||
await self.llm._model.abort(request_id)
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if final_result is None:
|
||||
return self.create_error_response('No result is returned.')
|
||||
|
||||
choices = []
|
||||
role = self.get_chat_role(request)
|
||||
|
||||
for output in final_result.outputs:
|
||||
token_ids = output.token_ids
|
||||
top_logprobs = output.logprobs
|
||||
|
||||
logprobs = None
|
||||
if request.logprobs:
|
||||
logprobs = self._create_logprobs(
|
||||
token_ids=token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role=role, content=output.text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
if request.echo:
|
||||
last_msg_content = ''
|
||||
if (
|
||||
request.messages
|
||||
and isinstance(request.messages, list)
|
||||
and request.messages[-1].get('content')
|
||||
and request.messages[-1].get('role') == role
|
||||
):
|
||||
last_msg_content = request.messages[-1]['content']
|
||||
|
||||
for choice in choices:
|
||||
full_message = last_msg_content + choice.message.content
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = len(final_result.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id, created=created_time, model=model_name, choices=choices, usage=usage
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
# --- chat specific ---
|
||||
def get_chat_role(self, request: ChatCompletionRequest) -> str:
|
||||
return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant'
|
||||
|
||||
@staticmethod
|
||||
async def completions(self, request) -> t.AsyncGenerator[CompletionResponse, None]: ...
|
||||
|
||||
# ---
|
||||
def create_error_response(self, message: str, exception: type[OpenLLMException] = ValidationError) -> ErrorResponse:
|
||||
return error_response(exception, message)
|
||||
|
||||
async def _check_model(self, request: ChatCompletionRequest | CompletionRequest) -> Error | None:
|
||||
if request.model != self.llm.llm_type:
|
||||
return error_response(
|
||||
ModelNotFound,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
)
|
||||
# TODO: support lora
|
||||
|
||||
def _validate_prompt(
|
||||
self,
|
||||
request: ChatCompletionRequest | CompletionRequest,
|
||||
prompt: str | None = None,
|
||||
prompt_ids: list[int] | None = None,
|
||||
) -> Error | None:
|
||||
if not (prompt or prompt_ids):
|
||||
raise ValueError("'prompt' or 'prompt_ids' must be provided.")
|
||||
if prompt and prompt_ids:
|
||||
raise ValueError("'prompt' and 'prompt_ids' are mutually exclusive.")
|
||||
# TODO: valudate max context length based on requested max_tokens
|
||||
#
|
||||
# input_ids = prompt_ids if prompt_ids is not None else self._tokenizer(prompt).input_ids
|
||||
# token_num = len(input_ids)
|
||||
# if request.max_tokens is None: request.max_tokens = self.engine_args['max_model_len']
|
||||
|
||||
def _create_logprobs(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
top_logprobs: list[dict[int, float] | None] | None = None, #
|
||||
num_output_top_logprobs: int | None = None,
|
||||
initial_text_offset: int = 0,
|
||||
) -> LogProbs:
|
||||
"""Create OpenAI-style logprobs."""
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs = []
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is not None:
|
||||
token_logprob = step_top_logprobs[token_id].logprob
|
||||
else:
|
||||
token_logprob = None
|
||||
token = step_top_logprobs[token_id].decoded_token
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append(
|
||||
{p.decoded_token: p.logprob for i, p in step_top_logprobs.items()} if step_top_logprobs else None
|
||||
)
|
||||
return logprobs
|
||||
|
||||
|
||||
# NOTE: Check for following supported mapping from here:
|
||||
# python -c "from openllm_core._typing_compat import get_type_hints, get_literal_args; from _bentoml_sdk.service.config import ResourceSchema; print(get_literal_args(get_type_hints(ResourceSchema)['gpu_type']))"
|
||||
RECOMMENDED_MAPPING = {'nvidia-l4': 24e9, 'nvidia-a10g': 24e9, 'nvidia-tesla-a100': 40e9, 'nvidia-a100-80gb': 80e9}
|
||||
|
||||
|
||||
def recommended_instance_type(model_id, bentomodel=None):
|
||||
if bentomodel is not None:
|
||||
size = sum(f.stat().st_size for f in pathlib.Path(resolve_filepath(model_id)).glob('**/*') if f.is_file())
|
||||
else:
|
||||
info = next(filter(lambda repo: repo.repo_id == model_id, scan_cache_dir().repos))
|
||||
size = info.size_on_disk
|
||||
|
||||
# find the first occurence of the gpu_type in the recommended mapping such that "size" should be less than or equal to 70% of the recommended size
|
||||
for gpu, max_size in RECOMMENDED_MAPPING.items():
|
||||
if size <= max_size * 0.7:
|
||||
return gpu
|
||||
227
openllm-python/src/_openllm_tiny/_llm.py
Normal file
227
openllm-python/src/_openllm_tiny/_llm.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect, orjson, dataclasses, bentoml, functools, attr, openllm_core, traceback, openllm, typing as t
|
||||
|
||||
from openllm_core.utils import (
|
||||
get_debug_mode,
|
||||
is_vllm_available,
|
||||
normalise_model_name,
|
||||
gen_random_uuid,
|
||||
dict_filter_none,
|
||||
)
|
||||
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
|
||||
from openllm_core._schemas import GenerationOutput, GenerationInput
|
||||
from _bentoml_sdk.service import ServiceConfig
|
||||
|
||||
Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from vllm import RequestOutput
|
||||
|
||||
|
||||
def check_engine_args(_, attr: attr.Attribute[dict[str, t.Any]], v: dict[str, t.Any]) -> dict[str, t.Any]:
|
||||
from vllm import AsyncEngineArgs
|
||||
|
||||
fields = dataclasses.fields(AsyncEngineArgs)
|
||||
invalid_args = {k: v for k, v in v.items() if k not in {f.name for f in fields}}
|
||||
if len(invalid_args) > 0:
|
||||
raise ValueError(f'Invalid engine args: {list(invalid_args)}')
|
||||
return v
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class LLM:
|
||||
model_id: str
|
||||
bentomodel: bentoml.Model
|
||||
serialisation: LiteralSerialisation
|
||||
config: openllm_core.LLMConfig
|
||||
dtype: Dtype
|
||||
quantise: t.Optional[LiteralQuantise] = attr.field(default=None)
|
||||
trust_remote_code: bool = attr.field(default=False)
|
||||
engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
|
||||
service_config: t.Optional[ServiceConfig] = attr.field(factory=dict)
|
||||
|
||||
_path: str = attr.field(
|
||||
init=False,
|
||||
default=attr.Factory(
|
||||
lambda self: self.bentomodel.path if self.bentomodel is not None else self.model_id, takes_self=True
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self, _: str = '', /, _internal: bool = False, **kwargs: t.Any) -> None:
|
||||
if not _internal:
|
||||
raise RuntimeError(
|
||||
'Cannot instantiate LLM directly in the new API. Make sure to use `openllm.LLM.from_model` instead.'
|
||||
)
|
||||
self.__attrs_init__(**kwargs)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
assert self._model # Ensure the model is initialised.
|
||||
|
||||
@functools.cached_property
|
||||
def _model(self):
|
||||
if is_vllm_available():
|
||||
num_gpus, dev = 1, openllm.utils.device_count()
|
||||
if dev >= 2:
|
||||
num_gpus = min(dev // 2 * 2, dev)
|
||||
quantise = self.quantise if self.quantise and self.quantise in {'gptq', 'awq', 'squeezellm'} else None
|
||||
dtype = 'float16' if quantise == 'gptq' else self.dtype # NOTE: quantise GPTQ doesn't support bfloat16 yet.
|
||||
|
||||
self.engine_args.setdefault('gpu_memory_utilization', 0.9)
|
||||
self.engine_args.update({
|
||||
'worker_use_ray': False,
|
||||
'engine_use_ray': False,
|
||||
'tokenizer_mode': 'auto',
|
||||
'tensor_parallel_size': num_gpus,
|
||||
'model': self._path,
|
||||
'tokenizer': self._path,
|
||||
'trust_remote_code': self.trust_remote_code,
|
||||
'dtype': dtype,
|
||||
'quantization': quantise,
|
||||
})
|
||||
if 'disable_log_stats' not in self.engine_args:
|
||||
self.engine_args['disable_log_stats'] = not get_debug_mode()
|
||||
if 'disable_log_requests' not in self.engine_args:
|
||||
self.engine_args['disable_log_requests'] = not get_debug_mode()
|
||||
|
||||
try:
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine
|
||||
|
||||
return AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.engine_args))
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f'Failed to initialise vLLMEngine due to the following error:\n{err}'
|
||||
) from err
|
||||
else:
|
||||
raise RuntimeError('Currently, OpenLLM is only supported running with a GPU and vLLM backend.')
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls,
|
||||
model_id: str,
|
||||
dtype: str,
|
||||
bentomodel: bentoml.Model | None = None,
|
||||
serialisation: LiteralSerialisation = 'safetensors',
|
||||
quantise: LiteralQuantise | None = None,
|
||||
trust_remote_code: bool = False,
|
||||
llm_config: openllm_core.LLMConfig | None = None,
|
||||
service_config: ServiceConfig | None = None,
|
||||
**engine_args: t.Any,
|
||||
) -> LLM:
|
||||
return cls(
|
||||
_internal=True,
|
||||
model_id=model_id,
|
||||
bentomodel=bentomodel,
|
||||
quantise=quantise,
|
||||
serialisation=serialisation,
|
||||
config=llm_config,
|
||||
dtype=dtype,
|
||||
engine_args=engine_args,
|
||||
trust_remote_code=trust_remote_code,
|
||||
service_config=service_config,
|
||||
)
|
||||
|
||||
@property
|
||||
def llm_type(self) -> str:
|
||||
return normalise_model_name(self.model_id)
|
||||
|
||||
@property
|
||||
def identifying_params(self) -> dict[str, str]:
|
||||
return {
|
||||
'configuration': self.config.model_dump_json(),
|
||||
'model_ids': orjson.dumps(self.config['model_ids']).decode(),
|
||||
'model_id': self.model_id,
|
||||
}
|
||||
|
||||
@functools.cached_property
|
||||
def _tokenizer(self):
|
||||
import transformers
|
||||
|
||||
return transformers.AutoTokenizer.from_pretrained(self._path, trust_remote_code=self.trust_remote_code)
|
||||
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.AsyncGenerator[RequestOutput, None]:
|
||||
from vllm import SamplingParams
|
||||
|
||||
config = self.config.generation_config.model_copy(update=dict_filter_none(attrs))
|
||||
|
||||
stop_token_ids = stop_token_ids or []
|
||||
eos_token_id = attrs.get('eos_token_id', config['eos_token_id'])
|
||||
if eos_token_id and not isinstance(eos_token_id, list):
|
||||
eos_token_id = [eos_token_id]
|
||||
stop_token_ids.extend(eos_token_id or [])
|
||||
if (config_eos := config['eos_token_id']) and config_eos not in stop_token_ids:
|
||||
stop_token_ids.append(config_eos)
|
||||
if self._tokenizer.eos_token_id not in stop_token_ids:
|
||||
stop_token_ids.append(self._tokenizer.eos_token_id)
|
||||
if stop is None:
|
||||
stop = set()
|
||||
elif isinstance(stop, str):
|
||||
stop = {stop}
|
||||
else:
|
||||
stop = set(stop)
|
||||
|
||||
request_id = gen_random_uuid() if request_id is None else request_id
|
||||
|
||||
top_p = 1.0 if config['temperature'] <= 1e-5 else config['top_p']
|
||||
config = config.model_copy(update=dict(stop=list(stop), stop_token_ids=stop_token_ids, top_p=top_p))
|
||||
sampling_params = SamplingParams(**{
|
||||
k: getattr(config, k, None) for k in set(inspect.signature(SamplingParams).parameters.keys())
|
||||
})
|
||||
|
||||
try:
|
||||
async for it in self._model.generate(
|
||||
prompt, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_token_ids
|
||||
):
|
||||
yield it
|
||||
except Exception as err:
|
||||
raise RuntimeError(f'Failed to start generation task: {err}') from err
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
*,
|
||||
_generated: GenerationInput | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> GenerationOutput:
|
||||
if stop is not None:
|
||||
attrs.update({'stop': stop})
|
||||
if stop_token_ids is not None:
|
||||
attrs.update({'stop_token_ids': stop_token_ids})
|
||||
config = self.config.model_copy(update=attrs)
|
||||
texts, token_ids = [[]] * config['n'], [[]] * config['n']
|
||||
async for result in self.generate_iterator(
|
||||
prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id,
|
||||
adapter_name=adapter_name,
|
||||
**config.model_dump(),
|
||||
):
|
||||
for output in result.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
if (final_result := result) is None:
|
||||
raise RuntimeError('No result is returned.')
|
||||
return GenerationOutput.from_vllm(final_result).model_copy(
|
||||
update=dict(
|
||||
prompt=prompt,
|
||||
outputs=[
|
||||
output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
|
||||
for output in final_result.outputs
|
||||
],
|
||||
)
|
||||
)
|
||||
139
openllm-python/src/_openllm_tiny/_service.py
Normal file
139
openllm-python/src/_openllm_tiny/_service.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from http import HTTPStatus
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
import openllm, bentoml, logging, openllm_core as core
|
||||
import _service_vars as svars, typing as t
|
||||
from openllm_core._typing_compat import Annotated, Unpack
|
||||
from openllm_core._schemas import MessageParam, MessagesConverterInput
|
||||
from openllm_core.protocol.openai import ModelCard, ModelList, ChatCompletionRequest
|
||||
from _openllm_tiny._helpers import OpenAI, Error
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI
|
||||
except ImportError:
|
||||
raise ImportError("Make sure to install openllm with 'pip install openllm[openai]'") from None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
bentomodel = bentoml.models.get(svars.model_id.lower())
|
||||
model_id = bentomodel.path
|
||||
except Exception:
|
||||
bentomodel = None
|
||||
model_id = svars.model_id
|
||||
llm_config = core.AutoConfig.for_model(svars.model_name)
|
||||
GenerationInput = core.GenerationInput.from_config(llm_config)
|
||||
|
||||
app_v1 = FastAPI(debug=True, description='OpenAI Compatible API support')
|
||||
|
||||
|
||||
@bentoml.mount_asgi_app(app_v1)
|
||||
@bentoml.service(name=f"llm-{llm_config['start_name']}-service", **svars.services_config)
|
||||
class LLMService:
|
||||
bentomodel = bentomodel
|
||||
|
||||
def __init__(self):
|
||||
self.llm = openllm.LLM.from_model(
|
||||
model_id,
|
||||
dtype=svars.dtype,
|
||||
bentomodel=bentomodel,
|
||||
serialisation=svars.serialisation,
|
||||
quantise=svars.quantise,
|
||||
llm_config=llm_config,
|
||||
trust_remote_code=svars.trust_remote_code,
|
||||
services_config=svars.services_config,
|
||||
max_model_len=svars.max_model_len,
|
||||
gpu_memory_utilization=svars.gpu_memory_utilization,
|
||||
)
|
||||
self.openai = OpenAI(self.llm)
|
||||
|
||||
@core.utils.api(route='/v1/generate')
|
||||
async def generate_v1(self, **parameters: Unpack[core.GenerationInputDict]) -> core.GenerationOutput:
|
||||
return await self.llm.generate(**GenerationInput.from_dict(parameters).model_dump())
|
||||
|
||||
@core.utils.api(route='/v1/generate_stream')
|
||||
async def generate_stream_v1(self, **parameters: Unpack[core.GenerationInputDict]) -> t.AsyncGenerator[str, None]:
|
||||
async for generated in self.llm.generate_iterator(**GenerationInput.from_dict(parameters).model_dump()):
|
||||
yield f'data: {generated.model_dump_json()}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
@core.utils.api(route='/v1/metadata')
|
||||
async def metadata_v1(self) -> core.MetadataOutput:
|
||||
return core.MetadataOutput(
|
||||
timeout=self.llm.config['timeout'],
|
||||
model_name=self.llm.config['model_name'],
|
||||
backend='vllm', # deprecated
|
||||
model_id=self.llm.model_id,
|
||||
configuration=self.llm.config.model_dump_json(),
|
||||
)
|
||||
|
||||
@core.utils.api(route='/v1/helpers/messages')
|
||||
def helpers_messages_v1(
|
||||
self,
|
||||
message: Annotated[t.Dict[str, t.Any], MessagesConverterInput] = MessagesConverterInput(
|
||||
add_generation_prompt=False,
|
||||
messages=[
|
||||
MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
|
||||
MessageParam(role='user', content='Hi there!'),
|
||||
MessageParam(role='assistant', content='Yes?'),
|
||||
],
|
||||
),
|
||||
) -> str:
|
||||
return self.llm._tokenizer.apply_chat_template(
|
||||
message['messages'], add_generation_prompt=message['add_generation_prompt'], tokenize=False
|
||||
)
|
||||
|
||||
@app_v1.post(
|
||||
'/v1/chat/completions',
|
||||
tags=['OpenAI'],
|
||||
status_code=HTTPStatus.OK,
|
||||
summary='Given a list of messages comprising a conversation, the model will return a response.',
|
||||
operation_id='openai__chat_completions',
|
||||
)
|
||||
async def chat_completions_v1(
|
||||
self,
|
||||
raw_request: Request,
|
||||
request: ChatCompletionRequest = ChatCompletionRequest(
|
||||
messages=[
|
||||
MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
|
||||
MessageParam(role='user', content='Hi there!'),
|
||||
MessageParam(role='assistant', content='Yes?'),
|
||||
],
|
||||
model=core.utils.normalise_model_name(model_id),
|
||||
n=1,
|
||||
stream=True,
|
||||
),
|
||||
):
|
||||
generator = await self.openai.chat_completions(request, raw_request)
|
||||
if isinstance(generator, Error):
|
||||
# NOTE: safe to cast here as we know it's an error
|
||||
return JSONResponse(content=generator.model_dump(), status_code=int(t.cast(str, generator.error.code)))
|
||||
if request.stream is True:
|
||||
return StreamingResponse(generator, media_type='text/event-stream')
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
# GET /v1/models
|
||||
@app_v1.get(
|
||||
'/v1/models',
|
||||
tags=['OpenAI'],
|
||||
status_code=HTTPStatus.OK,
|
||||
summary='Describes a model offering that can be used with the API.',
|
||||
operation_id='openai__list_models',
|
||||
)
|
||||
def list_models(self) -> ModelList:
|
||||
"""
|
||||
List and describe the various models available in the API.
|
||||
|
||||
You can refer to the available supported models with `openllm models` for more information.
|
||||
"""
|
||||
return ModelList(
|
||||
data=[ModelCard(root=core.utils.normalise_model_name(model_id), id=core.utils.normalise_model_name(model_id))]
|
||||
)
|
||||
|
||||
|
||||
LLMService.mount_asgi_app(app_v1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
LLMService.serve_http(reload=core.utils.check_bool_env('RELOAD', False))
|
||||
30
openllm-python/src/_openllm_tiny/_service_vars.py
Normal file
30
openllm-python/src/_openllm_tiny/_service_vars.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import orjson, typing as t, openllm_core.utils as coreutils
|
||||
from openllm_core._typing_compat import LiteralSerialisation, LiteralQuantise
|
||||
from _openllm_tiny._llm import Dtype
|
||||
|
||||
(
|
||||
model_id,
|
||||
model_name,
|
||||
quantise,
|
||||
serialisation,
|
||||
dtype,
|
||||
trust_remote_code,
|
||||
max_model_len,
|
||||
gpu_memory_utilization,
|
||||
services_config,
|
||||
) = (
|
||||
coreutils.getenv('model_id', var=['MODEL_ID'], return_type=str),
|
||||
coreutils.getenv('model_name', return_type=str),
|
||||
coreutils.getenv('quantize', var=['QUANTISE'], return_type=LiteralQuantise),
|
||||
coreutils.getenv('serialization', default='safetensors', var=['SERIALISATION'], return_type=LiteralSerialisation),
|
||||
coreutils.getenv('dtype', default='auto', var=['TORCH_DTYPE'], return_type=Dtype),
|
||||
coreutils.check_bool_env('TRUST_REMOTE_CODE', False),
|
||||
t.cast(t.Optional[int], orjson.loads(coreutils.getenv('max_model_len', default=orjson.dumps(None)))),
|
||||
t.cast(
|
||||
float,
|
||||
orjson.loads(
|
||||
coreutils.getenv('gpu_memory_utilization', default=orjson.dumps(0.9), var=['GPU_MEMORY_UTILISATION'])
|
||||
),
|
||||
),
|
||||
t.cast(t.Dict[str, t.Any], orjson.loads(coreutils.getenv('services_config', orjson.dumps({})))),
|
||||
)
|
||||
1
openllm-python/src/_openllm_tiny/bentofile.yaml
Normal file
1
openllm-python/src/_openllm_tiny/bentofile.yaml
Normal file
@@ -0,0 +1 @@
|
||||
service: '_service:LLMService'
|
||||
15
openllm-python/src/_service_vars.pyi
Normal file
15
openllm-python/src/_service_vars.pyi
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import Dict, Optional, Any
|
||||
from openllm_core._typing_compat import LiteralSerialisation, LiteralQuantise, LiteralString
|
||||
from _openllm_tiny._llm import Dtype
|
||||
|
||||
model_id: str = ...
|
||||
model_name: LiteralString = ...
|
||||
model_tag: Optional[str] = ...
|
||||
model_version: Optional[str] = ...
|
||||
quantise: LiteralQuantise = ...
|
||||
serialisation: LiteralSerialisation = ...
|
||||
dtype: Dtype = ...
|
||||
trust_remote_code: bool = ...
|
||||
max_model_len: Optional[int] = ...
|
||||
gpu_memory_utilization: int = ...
|
||||
services_config: Dict[str, Any] = ...
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging as _logging, os as _os, pathlib as _pathlib, warnings as _warnings
|
||||
from openllm_cli import _sdk
|
||||
import logging as _logging, os as _os, pathlib as _pathlib, warnings as _warnings, typing as _t
|
||||
|
||||
from . import utils as utils
|
||||
|
||||
if utils.DEBUG:
|
||||
@@ -9,10 +9,16 @@ else:
|
||||
# configuration for bitsandbytes before import
|
||||
_os.environ['BITSANDBYTES_NOWELCOME'] = _os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
# NOTE: The following warnings from bitsandbytes, and probably not that important for users to see when DEBUG is False
|
||||
_warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
|
||||
_warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
|
||||
_warnings.filterwarnings(
|
||||
'ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization'
|
||||
)
|
||||
_warnings.filterwarnings(
|
||||
'ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization'
|
||||
)
|
||||
_warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
|
||||
_warnings.filterwarnings('ignore', message='Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated')
|
||||
_warnings.filterwarnings(
|
||||
'ignore', message='Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated'
|
||||
)
|
||||
COMPILED = _pathlib.Path(__file__).suffix in ('.pyd', '.so')
|
||||
__lazy = utils.LazyModule( # NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
|
||||
__name__,
|
||||
@@ -22,21 +28,42 @@ __lazy = utils.LazyModule( # NOTE: update this to sys.modules[__name__] once my
|
||||
'client': ['HTTPClient', 'AsyncHTTPClient'],
|
||||
'bundle': [],
|
||||
'testing': [],
|
||||
'protocol': [],
|
||||
'utils': [],
|
||||
'_deprecated': ['Runner'],
|
||||
'_strategies': ['CascadingResourceStrategy', 'get_resource'],
|
||||
'utils': ['api'],
|
||||
'entrypoints': ['mount_entrypoints'],
|
||||
'serialisation': ['ggml', 'transformers'],
|
||||
'_quantisation': ['infer_quantisation_config'],
|
||||
'serialisation': ['ggml', 'transformers', 'vllm'],
|
||||
'_llm': ['LLM'],
|
||||
'_deprecated': ['Runner'],
|
||||
'_runners': ['runner'],
|
||||
'_quantisation': ['infer_quantisation_config'],
|
||||
'_strategies': ['CascadingResourceStrategy', 'get_resource'],
|
||||
},
|
||||
extra_objects={
|
||||
'COMPILED': COMPILED,
|
||||
'start': _sdk.start,
|
||||
'build': _sdk.build, #
|
||||
'import_model': _sdk.import_model,
|
||||
'list_models': _sdk.list_models, #
|
||||
},
|
||||
extra_objects={'COMPILED': COMPILED},
|
||||
)
|
||||
__all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__
|
||||
__all__, __dir__ = __lazy.__all__, __lazy.__dir__
|
||||
|
||||
_BREAKING_INTERNAL = ['_service', '_service_vars']
|
||||
_NEW_IMPL = ['LLM', *_BREAKING_INTERNAL]
|
||||
|
||||
if (_BENTOML_VERSION := utils.pkg.pkg_version_info('bentoml')) > (1, 2):
|
||||
import _openllm_tiny as _tiny
|
||||
else:
|
||||
_tiny = None
|
||||
|
||||
|
||||
def __getattr__(name: str) -> _t.Any:
|
||||
if name in _NEW_IMPL:
|
||||
if utils.getenv('IMPLEMENTATION', default='new_impl') == 'deprecated' or _tiny is None:
|
||||
if name in _BREAKING_INTERNAL:
|
||||
raise ImportError(
|
||||
f'"{name}" is an internal implementation and considered breaking with older OpenLLM. Please migrate your code if you depend on this.'
|
||||
)
|
||||
_warnings.warn(
|
||||
f'"{name}" is considered deprecated implementation and will be removed in the future. Make sure to upgrade to OpenLLM 0.5.x',
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return __lazy.__getattr__(name)
|
||||
else:
|
||||
return getattr(_tiny, name)
|
||||
else:
|
||||
return __lazy.__getattr__(name)
|
||||
|
||||
@@ -10,22 +10,29 @@ Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
* Native integration with BentoML, LangChain, OpenAI compatible endpoints, LlamaIndex for custom LLM apps
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
# update-config-stubs.py: import stubs start
|
||||
from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MistralConfig as MistralConfig, MixtralConfig as MixtralConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, PhiConfig as PhiConfig, QwenConfig as QwenConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, YiConfig as YiConfig
|
||||
from openllm_client import AsyncHTTPClient as AsyncHTTPClient, HTTPClient as HTTPClient
|
||||
from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GemmaConfig as GemmaConfig, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MistralConfig as MistralConfig, MixtralConfig as MixtralConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, PhiConfig as PhiConfig, QwenConfig as QwenConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, YiConfig as YiConfig
|
||||
from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig
|
||||
from openllm_core._schemas import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, MetadataOutput as MetadataOutput, MessageParam as MessageParam
|
||||
from openllm_core.utils import api as api
|
||||
# update-config-stubs.py: import stubs stop
|
||||
# fmt: on
|
||||
|
||||
from openllm_cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start
|
||||
from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams
|
||||
from openllm_core._schemas import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, MetadataOutput as MetadataOutput
|
||||
|
||||
from . import bundle as bundle, client as client, exceptions as exceptions, serialisation as serialisation, utils as utils
|
||||
from . import (
|
||||
bundle as bundle,
|
||||
client as client,
|
||||
exceptions as exceptions,
|
||||
serialisation as serialisation,
|
||||
utils as utils,
|
||||
entrypoints as entrypoints,
|
||||
)
|
||||
from .serialisation import ggml as ggml, transformers as transformers, vllm as vllm
|
||||
from ._deprecated import Runner as Runner
|
||||
from ._llm import LLM as LLM
|
||||
from ._runners import runner as runner
|
||||
from ._quantisation import infer_quantisation_config as infer_quantisation_config
|
||||
from ._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource
|
||||
from .client import AsyncHTTPClient as AsyncHTTPClient, HTTPClient as HTTPClient
|
||||
from .entrypoints import mount_entrypoints as mount_entrypoints
|
||||
from .protocol import openai as openai
|
||||
from .serialisation import ggml as ggml, transformers as transformers
|
||||
from _openllm_tiny import LLM as LLM
|
||||
|
||||
COMPILED: bool = ...
|
||||
|
||||
@@ -19,7 +19,9 @@ def Runner(
|
||||
if llm_config is None:
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
if not ensure_available:
|
||||
logger.warning("'ensure_available=False' won't have any effect as LLM will always check to download the model on initialisation.")
|
||||
logger.warning(
|
||||
"'ensure_available=False' won't have any effect as LLM will always check to download the model on initialisation."
|
||||
)
|
||||
model_id = attrs.get('model_id', os.getenv('OPENLLM_MODEL_ID', llm_config['default_id']))
|
||||
warnings.warn(
|
||||
f"""\
|
||||
@@ -40,8 +42,14 @@ def Runner(
|
||||
attrs.update({
|
||||
'model_id': model_id,
|
||||
'quantize': getenv('QUANTIZE', var=['QUANTISE'], default=attrs.get('quantize', None)), #
|
||||
'serialisation': getenv('serialization', default=attrs.get('serialisation', llm_config['serialisation']), var=['SERIALISATION']),
|
||||
'serialisation': getenv(
|
||||
'serialization', default=attrs.get('serialisation', llm_config['serialisation']), var=['SERIALISATION']
|
||||
),
|
||||
})
|
||||
# XXX: Make this back to Runnable implementation
|
||||
return openllm.LLM(
|
||||
backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'), llm_config=llm_config, embedded=init_local, **attrs
|
||||
backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'),
|
||||
llm_config=llm_config,
|
||||
embedded=init_local,
|
||||
**attrs,
|
||||
).runner
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import functools, logging, os, warnings, typing as t
|
||||
import attr, inflection, orjson, bentoml, openllm
|
||||
import attr, orjson, bentoml, openllm, openllm_core
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
from openllm_core._typing_compat import (
|
||||
AdapterMap,
|
||||
@@ -10,22 +10,20 @@ from openllm_core._typing_compat import (
|
||||
LiteralDtype,
|
||||
LiteralQuantise,
|
||||
LiteralSerialisation,
|
||||
M,
|
||||
T,
|
||||
)
|
||||
from openllm.serialisation import _make_tag_components
|
||||
from openllm_core.exceptions import MissingDependencyError
|
||||
from openllm_core.utils import (
|
||||
DEBUG,
|
||||
apply,
|
||||
check_bool_env,
|
||||
codegen,
|
||||
first_not_none,
|
||||
normalise_model_name,
|
||||
flatten_attrs,
|
||||
gen_random_uuid,
|
||||
generate_hash_from_file,
|
||||
getenv,
|
||||
is_ctranslate_available,
|
||||
is_peft_available,
|
||||
is_transformers_available,
|
||||
is_vllm_available,
|
||||
resolve_filepath,
|
||||
validate_is_path,
|
||||
@@ -43,29 +41,50 @@ if t.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
_AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple', ['adapter_id', 'name', 'config'])
|
||||
ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]]]
|
||||
CONFIG_FILE_NAME = 'config.json'
|
||||
|
||||
M = t.TypeVar('M')
|
||||
T = t.TypeVar('T')
|
||||
|
||||
|
||||
@attr.define(slots=False, repr=False, init=False)
|
||||
class LLM(t.Generic[M, T]):
|
||||
async def generate(self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs):
|
||||
async def generate(
|
||||
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
|
||||
):
|
||||
if adapter_name is not None and self.__llm_backend__ != 'pt':
|
||||
raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.')
|
||||
if stop is not None:
|
||||
attrs.update({'stop': stop})
|
||||
if stop_token_ids is not None:
|
||||
attrs.update({'stop_token_ids': stop_token_ids})
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
texts, token_ids = [[]] * config['n'], [[]] * config['n']
|
||||
async for result in self.generate_iterator(
|
||||
prompt, prompt_token_ids, stop, stop_token_ids, request_id, adapter_name, **config.model_dump(flatten=True)
|
||||
prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id,
|
||||
adapter_name=adapter_name,
|
||||
**config.model_dump(),
|
||||
):
|
||||
for output in result.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
if (final_result := result) is None:
|
||||
raise RuntimeError('No result is returned.')
|
||||
return final_result.with_options(
|
||||
prompt=prompt,
|
||||
outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs],
|
||||
return final_result.model_copy(
|
||||
update=dict(
|
||||
prompt=prompt,
|
||||
outputs=[
|
||||
output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
|
||||
for output in final_result.outputs
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async def generate_iterator(self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs):
|
||||
async def generate_iterator(
|
||||
self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs
|
||||
):
|
||||
from bentoml._internal.runner.runner_handle import DummyRunnerHandle
|
||||
|
||||
if adapter_name is not None and self.__llm_backend__ != 'pt':
|
||||
@@ -76,6 +95,7 @@ class LLM(t.Generic[M, T]):
|
||||
raise RuntimeError('Runner client failed to set up correctly.')
|
||||
else:
|
||||
self.runner.init_local(quiet=True)
|
||||
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
|
||||
stop_token_ids = stop_token_ids or []
|
||||
@@ -93,36 +113,32 @@ class LLM(t.Generic[M, T]):
|
||||
stop = {stop}
|
||||
else:
|
||||
stop = set(stop)
|
||||
for tid in stop_token_ids:
|
||||
if tid:
|
||||
stop.add(self.tokenizer.decode(tid))
|
||||
|
||||
if prompt_token_ids is None:
|
||||
if prompt is None:
|
||||
raise ValueError('Either prompt or prompt_token_ids must be specified.')
|
||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||
|
||||
request_id = gen_random_uuid() if request_id is None else request_id
|
||||
previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
|
||||
|
||||
config = config.model_construct_env(stop=list(stop), stop_token_ids=stop_token_ids)
|
||||
|
||||
try:
|
||||
generator = self.runner.generate_iterator.async_stream(
|
||||
prompt_token_ids, request_id, stop=list(stop), adapter_name=adapter_name, **config.model_dump(flatten=True)
|
||||
generator = bentoml.io.SSE.from_iterator(
|
||||
self.runner.generate_iterator.async_stream(
|
||||
prompt, request_id, prompt_token_ids=prompt_token_ids, adapter_name=adapter_name, **config.model_dump()
|
||||
)
|
||||
)
|
||||
generator = bentoml.io.SSE.from_iterator(generator)
|
||||
except Exception as err:
|
||||
raise RuntimeError(f'Failed to start generation task: {err}') from err
|
||||
|
||||
try:
|
||||
async for out in generator:
|
||||
out = out.data
|
||||
generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
|
||||
generated = GenerationOutput.from_runner(out).model_copy(update=dict(prompt=prompt))
|
||||
delta_outputs = [None] * len(generated.outputs)
|
||||
for output in generated.outputs:
|
||||
i = output.index
|
||||
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
|
||||
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
|
||||
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
|
||||
yield generated.with_options(outputs=delta_outputs)
|
||||
delta_outputs[i] = output.model_copy(update=dict(text=delta_text, token_ids=delta_tokens))
|
||||
yield generated.model_copy(update=dict(outputs=delta_outputs))
|
||||
except Exception as err:
|
||||
raise RuntimeError(f'Exception caught during generation: {err}') from err
|
||||
|
||||
@@ -130,7 +146,9 @@ class LLM(t.Generic[M, T]):
|
||||
# The below are mainly for internal implementation that you don't have to worry about.
|
||||
_model_id: str
|
||||
_revision: t.Optional[str] #
|
||||
_quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]]
|
||||
_quantization_config: t.Optional[
|
||||
t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
|
||||
]
|
||||
_quantise: t.Optional[LiteralQuantise]
|
||||
_model_decls: t.Tuple[t.Any, ...]
|
||||
__model_attrs: t.Dict[str, t.Any] #
|
||||
@@ -146,7 +164,9 @@ class LLM(t.Generic[M, T]):
|
||||
__llm_torch_dtype__: 'torch.dtype' = None
|
||||
__llm_config__: t.Optional[LLMConfig] = None
|
||||
__llm_backend__: LiteralBackend = None
|
||||
__llm_quantization_config__: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]] = None
|
||||
__llm_quantization_config__: t.Optional[
|
||||
t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]
|
||||
] = None
|
||||
__llm_runner__: t.Optional[Runner[M, T]] = None
|
||||
__llm_model__: t.Optional[M] = None
|
||||
__llm_tokenizer__: t.Optional[T] = None
|
||||
@@ -177,7 +197,9 @@ class LLM(t.Generic[M, T]):
|
||||
torch_dtype = attrs.pop('torch_dtype', None) # backward compatible
|
||||
if torch_dtype is not None:
|
||||
warnings.warn(
|
||||
'The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.', DeprecationWarning, stacklevel=3
|
||||
'The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.',
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
dtype = torch_dtype
|
||||
_local = False
|
||||
@@ -195,7 +217,7 @@ class LLM(t.Generic[M, T]):
|
||||
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
|
||||
model_attrs, tokenizer_attrs = flatten_attrs(**attrs)
|
||||
if model_tag is None:
|
||||
model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend)
|
||||
model_tag, model_version = _make_tag_components(model_id, model_version)
|
||||
if model_version:
|
||||
model_tag = f'{model_tag}:{model_version}'
|
||||
|
||||
@@ -233,25 +255,12 @@ class LLM(t.Generic[M, T]):
|
||||
|
||||
class _Quantise:
|
||||
@staticmethod
|
||||
def pt(llm: LLM, quantise=None): return quantise
|
||||
@staticmethod
|
||||
def vllm(llm: LLM, quantise=None): return quantise
|
||||
@staticmethod
|
||||
def ctranslate(llm: LLM, quantise=None):
|
||||
if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}: raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'")
|
||||
if quantise == 'int8': quantise = 'int8_float16' if llm._has_gpus else 'int8_float32'
|
||||
def pt(llm: LLM, quantise=None):
|
||||
return quantise
|
||||
|
||||
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(self, model_id: str, model_version: str | None, backend: str) -> tuple[str, str | None]:
|
||||
model_id, *maybe_revision = model_id.rsplit(':')
|
||||
if len(maybe_revision) > 0:
|
||||
if model_version is not None:
|
||||
logger.warning("revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version)
|
||||
model_version = maybe_revision[0]
|
||||
if validate_is_path(model_id):
|
||||
model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id))
|
||||
return f'{backend}-{normalise_model_name(model_id)}', model_version
|
||||
@staticmethod
|
||||
def vllm(llm: LLM, quantise=None):
|
||||
return quantise
|
||||
|
||||
@functools.cached_property
|
||||
def _has_gpus(self):
|
||||
@@ -259,9 +268,11 @@ class LLM(t.Generic[M, T]):
|
||||
from cuda import cuda
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
err, _ = cuda.cuDeviceGetCount()
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to get CUDA device count.')
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to get CUDA device count.')
|
||||
return True
|
||||
except (ImportError, RuntimeError):
|
||||
return False
|
||||
@@ -272,16 +283,13 @@ class LLM(t.Generic[M, T]):
|
||||
|
||||
_map = _torch_dtype_mapping()
|
||||
if not isinstance(self.__llm_torch_dtype__, torch.dtype):
|
||||
try:
|
||||
hf_config = transformers.AutoConfig.from_pretrained(self.bentomodel.path, trust_remote_code=self.trust_remote_code)
|
||||
except OpenLLMException:
|
||||
hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
|
||||
config_dtype = getattr(hf_config, 'torch_dtype', None)
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
|
||||
config_dtype = getattr(hf_config, 'torch_dtype', torch.float32)
|
||||
if self.__llm_dtype__ == 'auto':
|
||||
if config_dtype == torch.float32:
|
||||
if torch.cuda.is_available() and config_dtype is torch.float32:
|
||||
torch_dtype = torch.float16
|
||||
elif not torch.cuda.is_available():
|
||||
torch_dtype = torch.float32
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
else:
|
||||
@@ -304,14 +312,11 @@ class LLM(t.Generic[M, T]):
|
||||
return {**self.import_kwargs[1], **self.__tokenizer_attrs}
|
||||
|
||||
def _cascade_backend(self) -> LiteralBackend:
|
||||
logger.warning('It is recommended to specify the backend explicitly. Cascading backend might lead to unexpected behaviour.')
|
||||
if self._has_gpus:
|
||||
if is_vllm_available():
|
||||
return 'vllm'
|
||||
elif is_ctranslate_available():
|
||||
return 'ctranslate'
|
||||
elif is_ctranslate_available():
|
||||
return 'ctranslate'
|
||||
logger.warning(
|
||||
'It is recommended to specify the backend explicitly. Cascading backend might lead to unexpected behaviour.'
|
||||
)
|
||||
if self._has_gpus and is_vllm_available():
|
||||
return 'vllm'
|
||||
else:
|
||||
return 'pt'
|
||||
|
||||
@@ -339,7 +344,10 @@ class LLM(t.Generic[M, T]):
|
||||
|
||||
@property
|
||||
def import_kwargs(self):
|
||||
return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'}
|
||||
return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {
|
||||
'padding_side': 'left',
|
||||
'truncation_side': 'left',
|
||||
}
|
||||
|
||||
@property
|
||||
def trust_remote_code(self):
|
||||
@@ -362,7 +370,7 @@ class LLM(t.Generic[M, T]):
|
||||
|
||||
@property
|
||||
def bentomodel(self):
|
||||
return openllm.serialisation.get(self)
|
||||
return bentoml.models.get(self.tag)
|
||||
|
||||
@property
|
||||
def quantization_config(self):
|
||||
@@ -372,7 +380,9 @@ class LLM(t.Generic[M, T]):
|
||||
if self._quantization_config is not None:
|
||||
self.__llm_quantization_config__ = self._quantization_config
|
||||
elif self._quantise is not None:
|
||||
self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(self, self._quantise, **self._model_attrs)
|
||||
self.__llm_quantization_config__, self._model_attrs = infer_quantisation_config(
|
||||
self, self._quantise, **self._model_attrs
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either 'quantization_config' or 'quantise' must be specified.")
|
||||
return self.__llm_quantization_config__
|
||||
@@ -400,7 +410,7 @@ class LLM(t.Generic[M, T]):
|
||||
@property
|
||||
def identifying_params(self):
|
||||
return {
|
||||
'configuration': self.config.model_dump_json().decode(),
|
||||
'configuration': self.config.model_dump_json(),
|
||||
'model_ids': orjson.dumps(self.config['model_ids']).decode(),
|
||||
'model_id': self.model_id,
|
||||
}
|
||||
@@ -427,7 +437,11 @@ class LLM(t.Generic[M, T]):
|
||||
|
||||
model = get_peft_model(
|
||||
prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking),
|
||||
self.config['fine_tune_strategies'].get(adapter_type, self.config.make_fine_tune_config(adapter_type)).train().with_config(**attrs).build(),
|
||||
self.config['fine_tune_strategies']
|
||||
.get(adapter_type, self.config.make_fine_tune_config(adapter_type))
|
||||
.train()
|
||||
.with_config(**attrs)
|
||||
.build(),
|
||||
)
|
||||
if DEBUG:
|
||||
model.print_trainable_parameters()
|
||||
@@ -447,7 +461,10 @@ class LLM(t.Generic[M, T]):
|
||||
if self.__llm_adapter_map__ is None:
|
||||
_map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
|
||||
for adapter_type, adapter_tuple in self._adapter_map.items():
|
||||
base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type))
|
||||
base = first_not_none(
|
||||
self.config['fine_tune_strategies'].get(adapter_type),
|
||||
default=self.config.make_fine_tune_config(adapter_type),
|
||||
)
|
||||
for adapter in adapter_tuple:
|
||||
_map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
|
||||
self.__llm_adapter_map__ = _map
|
||||
@@ -456,50 +473,44 @@ class LLM(t.Generic[M, T]):
|
||||
@property
|
||||
def model(self):
|
||||
if self.__llm_model__ is None:
|
||||
model = openllm.serialisation.load_model(self, *self._model_decls, **self._model_attrs)
|
||||
# If OOM, then it is probably you don't have enough VRAM to run this model.
|
||||
if self.__llm_backend__ == 'pt':
|
||||
import torch
|
||||
|
||||
loaded_in_kbit = (
|
||||
getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False)
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit:
|
||||
try:
|
||||
model = model.to('cuda')
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f'Failed to load model into GPU: {err}.\n') from err
|
||||
if self.has_adapters:
|
||||
logger.debug('Applying the following adapters: %s', self.adapter_map)
|
||||
for adapter_dict in self.adapter_map.values():
|
||||
for adapter_name, (peft_config, peft_model_id) in adapter_dict.items():
|
||||
model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config)
|
||||
self.__llm_model__ = model
|
||||
self.__llm_model__ = openllm.serialisation.load_model(self, *self._model_decls, **self._model_attrs)
|
||||
return self.__llm_model__
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
import transformers
|
||||
|
||||
if self.__llm_config__ is None:
|
||||
if self.__llm_backend__ == 'ctranslate':
|
||||
try:
|
||||
config = transformers.AutoConfig.from_pretrained(self.bentomodel.path_of('/hf'), trust_remote_code=self.trust_remote_code)
|
||||
except OpenLLMException:
|
||||
config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
|
||||
for architecture in config.architectures:
|
||||
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
|
||||
config = openllm.AutoConfig.infer_class_from_name(
|
||||
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
|
||||
).model_construct_env(**self._model_attrs)
|
||||
break
|
||||
else:
|
||||
raise OpenLLMException(
|
||||
f"Failed to infer the configuration class. Make sure the model is a supported model. Supported models are: {', '.join(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE.keys())}"
|
||||
)
|
||||
if self._local:
|
||||
config_file = os.path.join(self.model_id, CONFIG_FILE_NAME)
|
||||
else:
|
||||
config = openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
|
||||
self.__llm_config__ = config
|
||||
try:
|
||||
config_file = self.bentomodel.path_of(CONFIG_FILE_NAME)
|
||||
except OpenLLMException as err:
|
||||
if not is_transformers_available():
|
||||
raise MissingDependencyError(
|
||||
"Requires 'transformers' to be available. Do 'pip install transformers'"
|
||||
) from err
|
||||
from transformers.utils import cached_file
|
||||
|
||||
try:
|
||||
config_file = cached_file(self.model_id, CONFIG_FILE_NAME)
|
||||
except Exception as err:
|
||||
raise ValueError(
|
||||
"Failed to determine architecture from 'config.json'. If this is a gated model, make sure to pass in HUGGING_FACE_HUB_TOKEN"
|
||||
) from err
|
||||
if not os.path.exists(config_file):
|
||||
raise ValueError(f"Failed to find 'config.json' (config_json_path={config_file})")
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
loaded_config = orjson.loads(f.read())
|
||||
|
||||
if 'architectures' in loaded_config:
|
||||
for architecture in loaded_config['architectures']:
|
||||
if architecture in self._architecture_mappings:
|
||||
self.__llm_config__ = openllm_core.AutoConfig.for_model(
|
||||
self._architecture_mappings[architecture]
|
||||
).model_construct_env()
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Failed to find architecture from 'config.json' (config_json_path={config_file})")
|
||||
return self.__llm_config__
|
||||
|
||||
|
||||
@@ -516,13 +527,11 @@ def _torch_dtype_mapping() -> dict[str, torch.dtype]:
|
||||
}
|
||||
|
||||
|
||||
def normalise_model_name(name: str) -> str:
|
||||
return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/', '--'))
|
||||
|
||||
|
||||
def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
if not is_peft_available():
|
||||
raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'")
|
||||
raise RuntimeError(
|
||||
"LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
resolved: AdapterMap = {}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Optional, Tuple, TypedDict, Union, TypeVar
|
||||
|
||||
import attr
|
||||
import torch
|
||||
@@ -7,13 +7,26 @@ from peft.peft_model import PeftModel, PeftModelForCausalLM, PeftModelForSeq2Seq
|
||||
|
||||
from bentoml import Model, Tag
|
||||
from openllm_core import LLMConfig
|
||||
from openllm_core._schemas import GenerationOutput
|
||||
from openllm_core._typing_compat import AdapterMap, AdapterType, LiteralBackend, LiteralDtype, LiteralQuantise, LiteralSerialisation, M, T
|
||||
from openllm_core._schemas import GenerationOutput, GenerationInputDict, MetadataOutput
|
||||
from openllm_core._typing_compat import (
|
||||
AdapterMap,
|
||||
AdapterType,
|
||||
LiteralBackend,
|
||||
LiteralQuantise,
|
||||
LiteralSerialisation,
|
||||
ParamSpec,
|
||||
MessagesConverterInput,
|
||||
)
|
||||
from openllm_core.utils import api
|
||||
|
||||
from ._quantisation import QuantizationConfig
|
||||
from ._runners import Runner
|
||||
from _openllm_tiny._llm import Dtype
|
||||
|
||||
InjectedModel = Union[PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM]
|
||||
P = ParamSpec('P')
|
||||
M = TypeVar('M')
|
||||
T = TypeVar('T')
|
||||
|
||||
class IdentifyingParams(TypedDict):
|
||||
configuration: str
|
||||
@@ -21,8 +34,16 @@ class IdentifyingParams(TypedDict):
|
||||
model_id: str
|
||||
|
||||
ResolvedAdapterMap = Dict[AdapterType, Dict[str, Tuple[PeftConfig, str]]]
|
||||
CTranslateDtype = Literal['int8_float32', 'int8_float16', 'int8_bfloat16']
|
||||
Dtype = Union[LiteralDtype, CTranslateDtype, Literal['auto', 'half', 'float']]
|
||||
|
||||
class LLMService:
|
||||
@api
|
||||
async def generate_v1(self, parameters: GenerationInputDict = ...) -> GenerationOutput: ...
|
||||
@api
|
||||
async def generate_stream_v1(self, parameters: GenerationInputDict = ...) -> AsyncGenerator[str, None]: ...
|
||||
@api
|
||||
def metadata_v1(self) -> MetadataOutput: ...
|
||||
@api
|
||||
def helpers_messages_v1(self, message: MessagesConverterInput = ...) -> str: ...
|
||||
|
||||
@attr.define(slots=True, repr=False, init=False)
|
||||
class LLM(Generic[M, T]):
|
||||
@@ -37,6 +58,8 @@ class LLM(Generic[M, T]):
|
||||
_adapter_map: Optional[AdapterMap]
|
||||
_serialisation: LiteralSerialisation
|
||||
_local: bool
|
||||
_max_model_len: Optional[int]
|
||||
_gpu_memory_utilization: float
|
||||
|
||||
__llm_dtype__: Dtype = ...
|
||||
__llm_torch_dtype__: Optional[torch.dtype] = ...
|
||||
@@ -49,7 +72,26 @@ class LLM(Generic[M, T]):
|
||||
__llm_adapter_map__: Optional[ResolvedAdapterMap] = ...
|
||||
__llm_trust_remote_code__: bool = ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
async def generate(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]] = ...,
|
||||
stop: Optional[Union[str, Iterable[str]]] = ...,
|
||||
stop_token_ids: Optional[List[int]] = ...,
|
||||
request_id: Optional[str] = ...,
|
||||
adapter_name: Optional[str] = ...,
|
||||
**attrs: Any,
|
||||
) -> GenerationOutput: ...
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]] = ...,
|
||||
stop: Optional[Union[str, Iterable[str]]] = ...,
|
||||
stop_token_ids: Optional[List[int]] = ...,
|
||||
request_id: Optional[str] = ...,
|
||||
adapter_name: Optional[str] = ...,
|
||||
**attrs: Any,
|
||||
) -> AsyncGenerator[GenerationOutput, None]: ...
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
@@ -66,6 +108,8 @@ class LLM(Generic[M, T]):
|
||||
embedded: bool = ...,
|
||||
dtype: Dtype = ...,
|
||||
low_cpu_mem_usage: bool = ...,
|
||||
max_model_len: Optional[int] = ...,
|
||||
gpu_memory_utilization: float = ...,
|
||||
**attrs: Any,
|
||||
) -> None: ...
|
||||
@property
|
||||
@@ -91,8 +135,6 @@ class LLM(Generic[M, T]):
|
||||
@property
|
||||
def quantization_config(self) -> QuantizationConfig: ...
|
||||
@property
|
||||
def has_adapters(self) -> bool: ...
|
||||
@property
|
||||
def local(self) -> bool: ...
|
||||
@property
|
||||
def quantise(self) -> Optional[LiteralQuantise]: ...
|
||||
@@ -112,24 +154,6 @@ class LLM(Generic[M, T]):
|
||||
def runner(self) -> Runner[M, T]: ...
|
||||
@property
|
||||
def adapter_map(self) -> ResolvedAdapterMap: ...
|
||||
def prepare(self, adapter_type: AdapterType = ..., use_gradient_checking: bool = ..., **attrs: Any) -> Tuple[InjectedModel, T]: ...
|
||||
async def generate(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]] = ...,
|
||||
stop: Optional[Union[str, Iterable[str]]] = ...,
|
||||
stop_token_ids: Optional[List[int]] = ...,
|
||||
request_id: Optional[str] = ...,
|
||||
adapter_name: Optional[str] = ...,
|
||||
**attrs: Any,
|
||||
) -> GenerationOutput: ...
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]] = ...,
|
||||
stop: Optional[Union[str, Iterable[str]]] = ...,
|
||||
stop_token_ids: Optional[List[int]] = ...,
|
||||
request_id: Optional[str] = ...,
|
||||
adapter_name: Optional[str] = ...,
|
||||
**attrs: Any,
|
||||
) -> AsyncGenerator[GenerationOutput, None]: ...
|
||||
def prepare(
|
||||
self, adapter_type: AdapterType = ..., use_gradient_checking: bool = ..., **attrs: Any
|
||||
) -> Tuple[InjectedModel, T]: ...
|
||||
|
||||
@@ -83,19 +83,25 @@ def infer_quantisation_config(llm, quantise, **attrs):
|
||||
|
||||
# NOTE: Quantization setup quantize is a openllm.LLM feature, where we can quantize the model with bitsandbytes or quantization aware training.
|
||||
if not is_bitsandbytes_available():
|
||||
raise RuntimeError('Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with \'pip install "openllm[fine-tune]"\'')
|
||||
raise RuntimeError(
|
||||
'Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with \'pip install "openllm[fine-tune]"\''
|
||||
)
|
||||
if quantise == 'int8':
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == 'int4':
|
||||
quantisation_config = create_int4_config()
|
||||
elif quantise == 'gptq':
|
||||
if not is_autogptq_available():
|
||||
raise MissingDependencyError("GPTQ requires 'auto-gptq' and 'optimum>=0.12' to be installed. Do it with 'pip install \"openllm[gptq]\"'")
|
||||
raise MissingDependencyError(
|
||||
"GPTQ requires 'auto-gptq' and 'optimum>=0.12' to be installed. Do it with 'pip install \"openllm[gptq]\"'"
|
||||
)
|
||||
else:
|
||||
quantisation_config = create_gptq_config()
|
||||
elif quantise == 'awq':
|
||||
if not is_autoawq_available():
|
||||
raise MissingDependencyError("AWQ requires 'auto-awq' to be installed. Do it with 'pip install \"openllm[awq]\"'.")
|
||||
raise MissingDependencyError(
|
||||
"AWQ requires 'auto-awq' to be installed. Do it with 'pip install \"openllm[awq]\"'."
|
||||
)
|
||||
else:
|
||||
quantisation_config = create_awq_config()
|
||||
else:
|
||||
|
||||
@@ -9,10 +9,18 @@ from ._llm import LLM
|
||||
QuantizationConfig = Union[BitsAndBytesConfig, GPTQConfig, AwqConfig]
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[M, T], quantise: Literal['int8', 'int4'], **attrs: Any) -> tuple[BitsAndBytesConfig, Dict[str, Any]]: ...
|
||||
def infer_quantisation_config(
|
||||
self: LLM[M, T], quantise: Literal['int8', 'int4'], **attrs: Any
|
||||
) -> tuple[BitsAndBytesConfig, Dict[str, Any]]: ...
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[M, T], quantise: Literal['gptq'], **attrs: Any) -> tuple[GPTQConfig, Dict[str, Any]]: ...
|
||||
def infer_quantisation_config(
|
||||
self: LLM[M, T], quantise: Literal['gptq'], **attrs: Any
|
||||
) -> tuple[GPTQConfig, Dict[str, Any]]: ...
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[M, T], quantise: Literal['awq'], **attrs: Any) -> tuple[AwqConfig, Dict[str, Any]]: ...
|
||||
def infer_quantisation_config(
|
||||
self: LLM[M, T], quantise: Literal['awq'], **attrs: Any
|
||||
) -> tuple[AwqConfig, Dict[str, Any]]: ...
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[M, T], quantise: LiteralQuantise, **attrs: Any) -> tuple[QuantizationConfig, Dict[str, Any]]: ...
|
||||
def infer_quantisation_config(
|
||||
self: LLM[M, T], quantise: LiteralQuantise, **attrs: Any
|
||||
) -> tuple[QuantizationConfig, Dict[str, Any]]: ...
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import gc, traceback, types, typing as t
|
||||
import gc, types, typing as t
|
||||
import torch, bentoml, openllm
|
||||
from openllm_core._schemas import CompletionChunk, GenerationOutput, SampleLogprobs
|
||||
from openllm_core.utils import ReprMixin, is_ctranslate_available, is_vllm_available
|
||||
from openllm_core.utils import ReprMixin, is_vllm_available
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import M, T
|
||||
@@ -46,11 +46,14 @@ def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]:
|
||||
(
|
||||
'runner_methods',
|
||||
{
|
||||
method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None}
|
||||
method.name: {
|
||||
'batchable': method.config.batchable,
|
||||
'batch_dim': method.config.batch_dim if method.config.batchable else None,
|
||||
}
|
||||
for method in _.runner_methods
|
||||
},
|
||||
),
|
||||
('config', llm.config.model_dump(flatten=True)),
|
||||
('config', llm.config.model_dump()),
|
||||
('llm_type', llm.llm_type),
|
||||
('backend', llm.__llm_backend__),
|
||||
('llm_tag', llm.tag),
|
||||
@@ -68,49 +71,6 @@ def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]:
|
||||
)
|
||||
|
||||
|
||||
@registry
|
||||
class CTranslateRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self, llm):
|
||||
if not is_ctranslate_available():
|
||||
raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`')
|
||||
self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
config, sampling_params = self.config.model_construct_env(stop=list(stop), **attrs).inference_options(self.llm)
|
||||
cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids)
|
||||
tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids)
|
||||
async for request_output in self.model.async_generate_tokens(tokens, **sampling_params):
|
||||
if config['logprobs']:
|
||||
cumulative_logprob += request_output.log_prob
|
||||
output_token_ids.append(request_output.token_id)
|
||||
text = self.tokenizer.decode(
|
||||
output_token_ids[input_len:],
|
||||
skip_special_tokens=True, #
|
||||
spaces_between_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True, #
|
||||
)
|
||||
out = GenerationOutput(
|
||||
prompt_token_ids=prompt_token_ids, #
|
||||
prompt='',
|
||||
finished=request_output.is_last,
|
||||
request_id=request_id, #
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=0,
|
||||
text=text,
|
||||
finish_reason=None, #
|
||||
token_ids=output_token_ids[input_len:],
|
||||
cumulative_logprob=cumulative_logprob, #
|
||||
# TODO: logprobs, but seems like we don't have access to the raw logits
|
||||
)
|
||||
],
|
||||
).model_dump_json()
|
||||
yield bentoml.io.SSE(out).marshal()
|
||||
|
||||
@registry
|
||||
class vLLMRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
@@ -119,41 +79,17 @@ class vLLMRunnable(bentoml.Runnable):
|
||||
def __init__(self, llm):
|
||||
if not is_vllm_available():
|
||||
raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.')
|
||||
import vllm
|
||||
|
||||
self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer
|
||||
num_gpus, dev = 1, openllm.utils.device_count()
|
||||
if dev >= 2:
|
||||
num_gpus = min(dev // 2 * 2, dev)
|
||||
quantise = llm.quantise if llm.quantise and llm.quantise in {'gptq', 'awq', 'squeezellm'} else None
|
||||
dtype = torch.float16 if quantise == 'gptq' else llm._torch_dtype # NOTE: quantise GPTQ doesn't support bfloat16 yet.
|
||||
try:
|
||||
self.model = vllm.AsyncLLMEngine.from_engine_args(
|
||||
vllm.AsyncEngineArgs(
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False, #
|
||||
tokenizer_mode='auto',
|
||||
tensor_parallel_size=num_gpus, #
|
||||
model=llm.bentomodel.path,
|
||||
tokenizer=llm.bentomodel.path, #
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
dtype=dtype, #
|
||||
max_model_len=llm._max_model_len,
|
||||
gpu_memory_utilization=llm._gpu_memory_utilization, #
|
||||
quantization=quantise,
|
||||
)
|
||||
)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err
|
||||
self.llm, self.config, self.model = llm, llm.config, llm.model
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
async def generate_iterator(self, prompt, request_id, prompt_token_ids=None, stop=None, adapter_name=None, **attrs):
|
||||
_, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm)
|
||||
async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
|
||||
async for request_output in self.model.generate(
|
||||
prompt, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_token_ids
|
||||
):
|
||||
out = GenerationOutput.from_vllm(request_output).model_dump_json()
|
||||
out = bentoml.io.SSE(out).marshal()
|
||||
yield out
|
||||
yield bentoml.io.SSE(out).marshal()
|
||||
|
||||
|
||||
@registry(alias='pt')
|
||||
@@ -170,12 +106,17 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs):
|
||||
async def generate_iterator(self, prompt, request_id, prompt_token_ids=None, stop=None, adapter_name=None, **attrs):
|
||||
from ._generation import get_context_length, prepare_logits_processor
|
||||
|
||||
if adapter_name is not None:
|
||||
self.model.set_adapter(adapter_name)
|
||||
|
||||
if prompt_token_ids is None:
|
||||
if prompt is None:
|
||||
raise ValueError('Either prompt or prompt_token_ids must be specified.')
|
||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||
|
||||
max_new_tokens = attrs.pop('max_new_tokens', 256)
|
||||
context_length = attrs.pop('context_length', None)
|
||||
if context_length is None:
|
||||
@@ -202,7 +143,9 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
if config['logprobs']: # FIXME: logprobs is not supported
|
||||
raise NotImplementedError('Logprobs is yet to be supported with encoder-decoder models.')
|
||||
encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0]
|
||||
start_ids = torch.as_tensor([[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device)
|
||||
start_ids = torch.as_tensor(
|
||||
[[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device
|
||||
)
|
||||
else:
|
||||
start_ids = torch.as_tensor([prompt_token_ids], device=self.device)
|
||||
|
||||
@@ -230,7 +173,9 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
)
|
||||
logits = self.model.lm_head(out[0])
|
||||
else:
|
||||
out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True)
|
||||
out = self.model(
|
||||
input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True
|
||||
)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
if logits_processor:
|
||||
@@ -274,7 +219,12 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
|
||||
tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
|
||||
# XXX: Move this to API server
|
||||
text = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)
|
||||
text = self.tokenizer.decode(
|
||||
tmp_output_ids,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
if len(stop) > 0:
|
||||
for it in stop:
|
||||
|
||||
@@ -1,40 +1,67 @@
|
||||
from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Protocol, Tuple, Type, TypeVar, Union, final
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
from bentoml import Model, Strategy, Tag
|
||||
from typing import (
|
||||
Any,
|
||||
TypeVar,
|
||||
Protocol,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional,
|
||||
Iterable,
|
||||
Dict,
|
||||
Union,
|
||||
Tuple,
|
||||
Generic,
|
||||
Type,
|
||||
Literal,
|
||||
)
|
||||
import bentoml
|
||||
from bentoml._internal.runner.runner_handle import RunnerHandle
|
||||
|
||||
from openllm_core import LLMConfig
|
||||
from openllm_core._typing_compat import LiteralBackend, M, T
|
||||
from openllm_core._typing_compat import M, T, LiteralBackend
|
||||
|
||||
from ._llm import LLM
|
||||
|
||||
try:
|
||||
from vllm import AsyncLLMEngine
|
||||
except ImportError:
|
||||
AsyncLLMEngine = Any
|
||||
|
||||
try:
|
||||
from ctranslate2 import Generator, Translator
|
||||
except ImportError:
|
||||
Translator = Generator = Any
|
||||
|
||||
Mo = TypeVar('Mo')
|
||||
To = TypeVar('To')
|
||||
|
||||
__all_ = ['Runner', 'runner']
|
||||
|
||||
def runner(llm: LLM[M, T]) -> Runner[M, T]: ...
|
||||
|
||||
# class Runner(Protocol[Mo, To]):
|
||||
# __doc__: str = ...
|
||||
# __module__: str = ...
|
||||
# llm: LLM[Mo, To] = ...
|
||||
# llm_config: LLMConfig = ...
|
||||
# llm_type: str = ...
|
||||
# llm_tag: bentoml.Tag = ...
|
||||
# llm_bentomodel: bentoml.Model = ...
|
||||
# identifying_params: Dict[str, Any] = ...
|
||||
# backend: LiteralBackend = ...
|
||||
# template: str = ...
|
||||
# system_message: str = ...
|
||||
#
|
||||
# @api # type: ignore[arg-type] # XXX: I don't really know how to fix this for marking positional-only arg as self?
|
||||
# async def generate_iterator(
|
||||
# self,
|
||||
# prompt_token_ids: List[int],
|
||||
# request_id: str,
|
||||
# stop: Optional[Iterable[str]] = ...,
|
||||
# adapter_name: Optional[str] = ...,
|
||||
# **attrs: Any,
|
||||
# ) -> AsyncGenerator[GenerationOutput, None]: ...
|
||||
|
||||
class _Runnable(Protocol[Mo, To]):
|
||||
SUPPORTED_RESOURCES: Tuple[Literal['nvidia.com/gpu', 'amd.com/gpu', 'cpu'], ...] = ...
|
||||
SUPPORTS_CPU_MULTI_THREADING: bool = ...
|
||||
llm: LLM[Mo, To] = ...
|
||||
config: LLMConfig = ...
|
||||
model: Mo = ...
|
||||
tokenizer: To = ...
|
||||
def __init__(self, llm: LLM[Mo, T]) -> None: ...
|
||||
def __init__(self, llm: LLM[Mo, To]) -> None: ...
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt_token_ids: List[int],
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = ...,
|
||||
stop: Optional[Union[str, Iterable[str]]] = ...,
|
||||
adapter_name: Optional[str] = ...,
|
||||
**attrs: Any,
|
||||
@@ -45,36 +72,27 @@ Ret = TypeVar('Ret')
|
||||
|
||||
class RunnerMethod(Generic[In, Ret]): ...
|
||||
|
||||
@final
|
||||
class vLLMRunnable(_Runnable[AsyncLLMEngine, PreTrainedTokenizer]): ...
|
||||
|
||||
@final
|
||||
class CTranslateRunnable(_Runnable[Union[Translator, Generator], PreTrainedTokenizer]): ...
|
||||
|
||||
@final
|
||||
class PyTorchRunnable(_Runnable[PreTrainedModel, PreTrainedTokenizer]):
|
||||
is_encoder_decoder: bool = ...
|
||||
device: torch.device = ...
|
||||
|
||||
def runner(llm: LLM[M, T]) -> Runner[M, T]: ...
|
||||
|
||||
class Runner(Protocol[Mo, To]):
|
||||
__doc__: str = ...
|
||||
__module__: str = ...
|
||||
llm_type: str = ...
|
||||
llm_tag: Tag = ...
|
||||
llm_tag: bentoml.Tag = ...
|
||||
identifying_params: Dict[str, Any] = ...
|
||||
llm: LLM[Mo, To] = ...
|
||||
config: LLMConfig = ...
|
||||
backend: LiteralBackend = ...
|
||||
has_adapters: bool = ...
|
||||
template: str = ...
|
||||
system_message: str = ...
|
||||
|
||||
class generate_iterator(RunnerMethod[List[int], AsyncGenerator[str, None]]):
|
||||
@staticmethod
|
||||
def async_stream(
|
||||
prompt_token_ids: List[int], request_id: str, stop: Optional[Union[Iterable[str], str]] = ..., adapter_name: Optional[str] = ..., **attrs: Any
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = ...,
|
||||
stop: Optional[Union[Iterable[str], str]] = ...,
|
||||
adapter_name: Optional[str] = ...,
|
||||
**attrs: Any,
|
||||
) -> AsyncGenerator[str, None]: ...
|
||||
|
||||
def __init__(
|
||||
@@ -83,8 +101,8 @@ class Runner(Protocol[Mo, To]):
|
||||
*,
|
||||
runnable_init_params: Optional[Dict[str, Any]] = ...,
|
||||
name: Optional[str] = ...,
|
||||
scheduling_strategy: Type[Strategy] = ...,
|
||||
models: Optional[List[Model]] = ...,
|
||||
scheduling_strategy: Type[bentoml.Strategy] = ...,
|
||||
models: Optional[List[bentoml.Model]] = ...,
|
||||
max_batch_size: Optional[int] = ...,
|
||||
max_latency_ms: Optional[int] = ...,
|
||||
method_configs: Optional[Dict[str, Dict[str, int]]] = ...,
|
||||
@@ -92,12 +110,12 @@ class Runner(Protocol[Mo, To]):
|
||||
) -> None: ...
|
||||
|
||||
name: str = ...
|
||||
models: List[Model] = ...
|
||||
models: List[bentoml.Model] = ...
|
||||
resource_config: Dict[str, Any]
|
||||
runnable_class: Type[_Runnable[Mo, To]]
|
||||
embedded: bool
|
||||
runner_methods: List[RunnerMethod[Any, Any]]
|
||||
scheduling_strategy: Type[Strategy]
|
||||
scheduling_strategy: Type[bentoml.Strategy]
|
||||
workers_per_resource: Union[int, float] = ...
|
||||
runnable_init_params: Dict[str, Any] = ...
|
||||
_runner_handle: RunnerHandle = ...
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import logging, typing as t
|
||||
import bentoml, openllm, _service_vars as svars
|
||||
from openllm_core._schemas import MessageParam
|
||||
from bentoml.io import JSON, Text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
model_id=svars.model_id,
|
||||
model_tag=svars.model_tag,
|
||||
adapter_map=svars.adapter_map, #
|
||||
serialisation=svars.serialization,
|
||||
trust_remote_code=svars.trust_remote_code, #
|
||||
max_model_len=svars.max_model_len,
|
||||
gpu_memory_utilization=svars.gpu_memory_utilization, #
|
||||
)
|
||||
svc = bentoml.Service(name=f"llm-{llm.config['start_name']}-service", runners=[llm.runner])
|
||||
llm_model_class = openllm.GenerationInput.from_llm_config(llm.config)
|
||||
|
||||
|
||||
@svc.api(route='/v1/generate', input=JSON.from_sample(llm_model_class.examples()), output=JSON.from_sample(openllm.GenerationOutput.examples()))
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> dict[str, t.Any]:
|
||||
return (await llm.generate(**llm_model_class(**input_dict).model_dump())).model_dump()
|
||||
|
||||
|
||||
@svc.api(route='/v1/generate_stream', input=JSON.from_sample(llm_model_class.examples()), output=Text(content_type='text/event-stream'))
|
||||
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
|
||||
async for it in llm.generate_iterator(**llm_model_class(**input_dict).model_dump()):
|
||||
yield f'data: {it.model_dump_json()}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
|
||||
_Metadata = openllm.MetadataOutput(
|
||||
timeout=llm.config['timeout'],
|
||||
model_name=llm.config['model_name'], #
|
||||
backend=llm.__llm_backend__,
|
||||
model_id=llm.model_id,
|
||||
configuration=llm.config.model_dump_json().decode(), #
|
||||
)
|
||||
|
||||
|
||||
@svc.api(route='/v1/metadata', input=Text(), output=JSON.from_sample(_Metadata.model_dump()))
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return _Metadata
|
||||
|
||||
|
||||
class MessagesConverterInput(t.TypedDict):
|
||||
add_generation_prompt: bool
|
||||
messages: t.List[t.Dict[str, t.Any]]
|
||||
|
||||
|
||||
@svc.api(
|
||||
route='/v1/helpers/messages',
|
||||
input=JSON.from_sample(
|
||||
MessagesConverterInput(
|
||||
add_generation_prompt=False,
|
||||
messages=[
|
||||
MessageParam(role='system', content='You are acting as Ernest Hemmingway.'),
|
||||
MessageParam(role='user', content='Hi there!'),
|
||||
MessageParam(role='assistant', content='Yes?'), #
|
||||
],
|
||||
)
|
||||
),
|
||||
output=Text(),
|
||||
)
|
||||
def helpers_messages_v1(message: MessagesConverterInput) -> str:
|
||||
add_generation_prompt, messages = message['add_generation_prompt'], message['messages']
|
||||
return llm.tokenizer.apply_chat_template(messages, add_generation_prompt=add_generation_prompt, tokenize=False)
|
||||
|
||||
|
||||
openllm.mount_entrypoints(svc, llm) # HACK: This must always be the last line in this file, as we will do some MK for OpenAPI schema.
|
||||
@@ -1,13 +0,0 @@
|
||||
import os, orjson, openllm_core.utils as coreutils
|
||||
|
||||
model_id, model_tag, adapter_map, serialization, trust_remote_code = (
|
||||
os.environ['OPENLLM_MODEL_ID'],
|
||||
None,
|
||||
orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None))),
|
||||
os.getenv('OPENLLM_SERIALIZATION', default='safetensors'),
|
||||
coreutils.check_bool_env('TRUST_REMOTE_CODE', False),
|
||||
)
|
||||
max_model_len, gpu_memory_utilization = (
|
||||
orjson.loads(os.getenv('MAX_MODEL_LEN', orjson.dumps(None).decode())),
|
||||
orjson.loads(os.getenv('GPU_MEMORY_UTILIZATION', orjson.dumps(0.9).decode())),
|
||||
)
|
||||
@@ -1,9 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from openllm_core._typing_compat import LiteralSerialisation
|
||||
|
||||
model_id: str = ...
|
||||
model_tag: Optional[str] = ...
|
||||
adapter_map: Optional[Dict[str, str]] = ...
|
||||
serialization: LiteralSerialisation = ...
|
||||
trust_remote_code: bool = ...
|
||||
@@ -158,12 +158,16 @@ class _ResourceMixin:
|
||||
elif isinstance(spec, list):
|
||||
return [str(x) for x in spec]
|
||||
else:
|
||||
raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
|
||||
raise TypeError(
|
||||
f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate(cls, val: list[t.Any]) -> None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'")
|
||||
raise RuntimeError(
|
||||
"AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'"
|
||||
)
|
||||
if not all(isinstance(i, str) for i in val):
|
||||
raise ValueError('Input list should be all string type.')
|
||||
|
||||
@@ -307,12 +311,18 @@ class CascadingResourceStrategy(bentoml.Strategy, coreutils.ReprMixin):
|
||||
worker_index,
|
||||
assigned_resource_per_worker,
|
||||
)
|
||||
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
|
||||
assigned_gpu = gpus[assigned_resource_per_worker * worker_index : assigned_resource_per_worker * (worker_index + 1)]
|
||||
raise IndexError(
|
||||
f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}]."
|
||||
)
|
||||
assigned_gpu = gpus[
|
||||
assigned_resource_per_worker * worker_index : assigned_resource_per_worker * (worker_index + 1)
|
||||
]
|
||||
dev = ','.join(assigned_gpu)
|
||||
else:
|
||||
idx = worker_index // workers_per_resource
|
||||
if idx >= len(gpus):
|
||||
raise ValueError(f'Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}')
|
||||
raise ValueError(
|
||||
f'Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}'
|
||||
)
|
||||
dev = str(gpus[idx])
|
||||
return dev
|
||||
|
||||
@@ -13,7 +13,12 @@ class CascadingResourceStrategy:
|
||||
TODO: Support CloudTPUResource
|
||||
"""
|
||||
@classmethod
|
||||
def get_worker_count(cls, runnable_class: Type[bentoml.Runnable], resource_request: Optional[Dict[str, Any]], workers_per_resource: float) -> int:
|
||||
def get_worker_count(
|
||||
cls,
|
||||
runnable_class: Type[bentoml.Runnable],
|
||||
resource_request: Optional[Dict[str, Any]],
|
||||
workers_per_resource: float,
|
||||
) -> int:
|
||||
"""Return the number of workers to be used for the given runnable class.
|
||||
|
||||
Note that for all available GPU, the number of workers will always be 1.
|
||||
@@ -35,5 +40,7 @@ class CascadingResourceStrategy:
|
||||
worker_index: The index of the worker, start from 0.
|
||||
"""
|
||||
@staticmethod
|
||||
def transpile_workers_to_cuda_envvar(workers_per_resource: Union[float, int], gpus: List[str], worker_index: int) -> str:
|
||||
def transpile_workers_to_cuda_envvar(
|
||||
workers_per_resource: Union[float, int], gpus: List[str], worker_index: int
|
||||
) -> str:
|
||||
"""Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string."""
|
||||
|
||||
@@ -10,7 +10,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'
|
||||
_service_file = pathlib.Path(os.path.abspath(__file__)).parent.parent / '_service.py'
|
||||
_SERVICE_VARS = '''import orjson;model_id,model_tag,adapter_map,serialization,trust_remote_code,max_model_len,gpu_memory_utilization='{__model_id__}','{__model_tag__}',orjson.loads("""{__model_adapter_map__}"""),'{__model_serialization__}',{__model_trust_remote_code__},{__max_model_len__},{__gpu_memory_utilization__}'''
|
||||
_SERVICE_VARS = '''import orjson;model_id,model_tag,adapter_map,serialization,trust_remote_code,max_model_len,gpu_memory_utilization,services_config='{__model_id__}','{__model_tag__}',orjson.loads("""{__model_adapter_map__}"""),'{__model_serialization__}',{__model_trust_remote_code__},{__max_model_len__},{__gpu_memory_utilization__},orjson.loads("""{__services_config__}""")'''
|
||||
|
||||
|
||||
def build_editable(path, package='openllm'):
|
||||
@@ -38,8 +38,11 @@ def build_editable(path, package='openllm'):
|
||||
def construct_python_options(llm, llm_fs, extra_dependencies=None, adapter_map=None):
|
||||
from . import RefResolver
|
||||
|
||||
openllm_package = 'openllm[vllm]' if llm.__llm_backend__.lower() == "vllm" else "openllm"
|
||||
packages = ['scipy', 'bentoml[tracing]>=1.1.11,<1.2', f'{openllm_package}>={RefResolver.from_strategy("release").version}'] # apparently bnb misses this one
|
||||
packages = [
|
||||
'scipy',
|
||||
'bentoml[tracing]>=1.2',
|
||||
f'openllm[vllm]>={RefResolver.from_strategy("release").version}',
|
||||
] # apparently bnb misses this one
|
||||
if adapter_map is not None:
|
||||
packages += ['openllm[fine-tune]']
|
||||
if extra_dependencies is not None:
|
||||
@@ -57,7 +60,18 @@ def construct_python_options(llm, llm_fs, extra_dependencies=None, adapter_map=N
|
||||
def construct_docker_options(llm, _, quantize, adapter_map, dockerfile_template, serialisation):
|
||||
from openllm_cli.entrypoint import process_environ
|
||||
|
||||
environ = process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm, use_current_env=False)
|
||||
environ = process_environ(
|
||||
llm.config,
|
||||
llm.config['timeout'],
|
||||
1.0,
|
||||
None,
|
||||
True,
|
||||
llm.model_id,
|
||||
None,
|
||||
llm._serialisation,
|
||||
llm,
|
||||
use_current_env=False,
|
||||
)
|
||||
# XXX: We need to quote this so that the envvar in container recognize as valid json
|
||||
environ['OPENLLM_CONFIG'] = f"'{environ['OPENLLM_CONFIG']}'"
|
||||
environ.pop('BENTOML_HOME', None) # NOTE: irrelevant in container
|
||||
@@ -86,7 +100,10 @@ def create_bento(
|
||||
'start_name': llm.config['start_name'],
|
||||
'base_name_or_path': llm.model_id,
|
||||
'bundler': 'openllm.bundle',
|
||||
**{f'{package.replace("-","_")}_version': importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}},
|
||||
**{
|
||||
f'{package.replace("-", "_")}_version': importlib.metadata.version(package)
|
||||
for package in {'openllm', 'openllm-core', 'openllm-client'}
|
||||
},
|
||||
})
|
||||
if adapter_map:
|
||||
labels.update(adapter_map)
|
||||
|
||||
@@ -13,7 +13,10 @@ from .._llm import LLM
|
||||
|
||||
def build_editable(path: str, package: LiteralString) -> Optional[str]: ...
|
||||
def construct_python_options(
|
||||
llm: LLM[M, T], llm_fs: FS, extra_dependencies: Optional[Tuple[str, ...]] = ..., adapter_map: Optional[Dict[str, str]] = ...
|
||||
llm: LLM[M, T],
|
||||
llm_fs: FS,
|
||||
extra_dependencies: Optional[Tuple[str, ...]] = ...,
|
||||
adapter_map: Optional[Dict[str, str]] = ...,
|
||||
) -> PythonOptions: ...
|
||||
def construct_docker_options(
|
||||
llm: LLM[M, T],
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import importlib
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
_import_structure = {'openai': [], 'hf': [], 'cohere': []}
|
||||
_import_structure = {'openai': [], 'hf': []}
|
||||
|
||||
|
||||
def mount_entrypoints(svc, llm):
|
||||
for module_name in _import_structure:
|
||||
module = importlib.import_module(f'.{module_name}', __name__)
|
||||
svc = module.mount_to_svc(svc, llm)
|
||||
svc = importlib.import_module(f'.{module_name}', __name__).mount_to_svc(svc, llm)
|
||||
return svc
|
||||
|
||||
|
||||
__lazy = LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints})
|
||||
__lazy = LazyModule(
|
||||
__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints}
|
||||
)
|
||||
__all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
"""Entrypoint for all third-party apps.
|
||||
|
||||
Currently support OpenAI, Cohere compatible API.
|
||||
Currently support OpenAI compatible API.
|
||||
|
||||
Each module should implement the following API:
|
||||
|
||||
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
|
||||
"""
|
||||
|
||||
from bentoml import Service
|
||||
from typing import Any
|
||||
from _bentoml_sdk import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from . import cohere as cohere, hf as hf, openai as openai
|
||||
from . import hf as hf, openai as openai
|
||||
from .._llm import LLM
|
||||
|
||||
def mount_entrypoints(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def mount_entrypoints(svc: Service[Any], llm: LLM[M, T]) -> Service: ...
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
@@ -9,6 +11,9 @@ from starlette.schemas import EndpointInfo, SchemaGenerator
|
||||
|
||||
from openllm_core.utils import first_not_none
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import pydantic
|
||||
|
||||
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
|
||||
# NOTE: OpenAI schema
|
||||
LIST_MODELS_SCHEMA = """\
|
||||
@@ -186,9 +191,7 @@ COMPLETIONS_SCHEMA = """\
|
||||
consumes:
|
||||
- application/json
|
||||
description: >-
|
||||
Given a prompt, the model will return one or more predicted completions, and
|
||||
can also return the probabilities of alternative tokens at each position. We
|
||||
recommend most users use our Chat completions API.
|
||||
Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position. We recommend most users use our Chat completions API.
|
||||
operationId: openai__completions
|
||||
produces:
|
||||
- application/json
|
||||
@@ -210,7 +213,7 @@ requestBody:
|
||||
model: __model_id__
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
logprobs: 1
|
||||
logprobs: null
|
||||
top_p: 0.43
|
||||
n: 1
|
||||
stream: false
|
||||
@@ -222,7 +225,7 @@ requestBody:
|
||||
max_tokens: 256
|
||||
temperature: 0.7
|
||||
top_p: 0.43
|
||||
logprobs: 1
|
||||
logprobs: null
|
||||
n: 1
|
||||
stream: true
|
||||
stop:
|
||||
@@ -501,7 +504,11 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
endpoints_info.extend(sub_endpoints)
|
||||
elif not isinstance(route, Route) or not route.include_in_schema:
|
||||
continue
|
||||
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint) or isinstance(route.endpoint, functools.partial):
|
||||
elif (
|
||||
inspect.isfunction(route.endpoint)
|
||||
or inspect.ismethod(route.endpoint)
|
||||
or isinstance(route.endpoint, functools.partial)
|
||||
):
|
||||
endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint
|
||||
path = self._remove_converter(route.path)
|
||||
for method in route.methods or ['GET']:
|
||||
@@ -546,11 +553,13 @@ def get_generator(title, components=None, tags=None, inject=True):
|
||||
return OpenLLMSchemaGenerator(base_schema)
|
||||
|
||||
|
||||
def component_schema_generator(attr_cls, description=None):
|
||||
def component_schema_generator(attr_cls: pydantic.BaseModel, description=None):
|
||||
schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
schema['description'] = first_not_none(getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}')
|
||||
for field in attr.fields(attr.resolve_types(attr_cls)):
|
||||
attr_type = field.type
|
||||
schema['description'] = first_not_none(
|
||||
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
|
||||
)
|
||||
for name, field in attr_cls.model_fields.items():
|
||||
attr_type = field.annotation
|
||||
origin_type = t.get_origin(attr_type)
|
||||
args_type = t.get_args(attr_type)
|
||||
|
||||
@@ -582,15 +591,18 @@ def component_schema_generator(attr_cls, description=None):
|
||||
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
|
||||
prop_schema['default'] = field.default
|
||||
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
|
||||
schema['required'].append(field.name)
|
||||
schema['properties'][field.name] = prop_schema
|
||||
schema['required'].append(name)
|
||||
schema['properties'][name] = prop_schema
|
||||
locals().pop('prop_schema', None)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
_SimpleSchema = types.new_class(
|
||||
'_SimpleSchema', (object,), {}, lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it})
|
||||
'_SimpleSchema',
|
||||
(object,),
|
||||
{},
|
||||
lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,8 +17,13 @@ class OpenLLMSchemaGenerator:
|
||||
|
||||
def apply_schema(func: Callable[P, Any], **attrs: Any) -> Callable[P, Any]: ...
|
||||
def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ...
|
||||
def append_schemas(svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...) -> Service: ...
|
||||
def append_schemas(
|
||||
svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...
|
||||
) -> Service: ...
|
||||
def component_schema_generator(attr_cls: Type[AttrsInstance], description: Optional[str] = ...) -> Dict[str, Any]: ...
|
||||
def get_generator(
|
||||
title: str, components: Optional[List[Type[AttrsInstance]]] = ..., tags: Optional[List[Dict[str, Any]]] = ..., inject: bool = ...
|
||||
title: str,
|
||||
components: Optional[List[Type[AttrsInstance]]] = ...,
|
||||
tags: Optional[List[Dict[str, Any]]] = ...,
|
||||
inject: bool = ...,
|
||||
) -> OpenLLMSchemaGenerator: ...
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import functools, json, logging, traceback
|
||||
from http import HTTPStatus
|
||||
import orjson
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
from starlette.routing import Route
|
||||
from openllm_core.utils import DEBUG, converter, gen_random_uuid
|
||||
from ._openapi import add_schema_definitions, append_schemas, get_generator
|
||||
from ..protocol.cohere import (
|
||||
Chat,
|
||||
ChatStreamEnd,
|
||||
ChatStreamStart,
|
||||
ChatStreamTextGeneration,
|
||||
CohereChatRequest,
|
||||
CohereErrorResponse,
|
||||
CohereGenerateRequest,
|
||||
Generation,
|
||||
Generations,
|
||||
StreamingGenerations,
|
||||
StreamingText,
|
||||
)
|
||||
|
||||
schemas = get_generator(
|
||||
'cohere',
|
||||
components=[
|
||||
CohereChatRequest,
|
||||
CohereErrorResponse,
|
||||
CohereGenerateRequest,
|
||||
Generation,
|
||||
Generations,
|
||||
StreamingGenerations,
|
||||
StreamingText,
|
||||
Chat,
|
||||
ChatStreamStart,
|
||||
ChatStreamEnd,
|
||||
ChatStreamTextGeneration,
|
||||
],
|
||||
tags=[
|
||||
{
|
||||
'name': 'Cohere',
|
||||
'description': 'Cohere compatible API. Currently support /generate, /chat',
|
||||
'externalDocs': 'https://docs.cohere.com/docs/the-cohere-platform',
|
||||
}
|
||||
],
|
||||
inject=DEBUG,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def jsonify_attr(obj):
|
||||
return json.dumps(converter.unstructure(obj))
|
||||
|
||||
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(converter.unstructure(CohereErrorResponse(text=message)), status_code=status_code.value)
|
||||
|
||||
|
||||
async def check_model(request, model):
|
||||
if request.model is None or request.model == model:
|
||||
return None
|
||||
return error_response(HTTPStatus.NOT_FOUND, f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see current running models.")
|
||||
|
||||
|
||||
def mount_to_svc(svc, llm):
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
|
||||
Route('/v1/chat', endpoint=functools.partial(cohere_chat, llm=llm), name='cohere_chat', methods=['POST']),
|
||||
Route('/v1/generate', endpoint=functools.partial(cohere_generate, llm=llm), name='cohere_generate', methods=['POST']),
|
||||
],
|
||||
)
|
||||
mount_path = '/cohere'
|
||||
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append', inject=DEBUG)
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
async def cohere_generate(req, llm):
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), CohereGenerateRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received generate request: %s', request)
|
||||
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
request_id = gen_random_uuid('cohere-generate')
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
if request.prompt_vars is not None:
|
||||
prompt = request.prompt.format(**request.prompt_vars)
|
||||
else:
|
||||
prompt = request.prompt
|
||||
|
||||
# TODO: support end_sequences, stop_sequences, logit_bias, return_likelihoods, truncate
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, stop=request.stop_sequences, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
def create_stream_response_json(index, text, is_finished):
|
||||
return f'{jsonify_attr(StreamingText(index=index, text=text, is_finished=is_finished))}\n'
|
||||
|
||||
async def generate_stream_generator():
|
||||
async for res in result_generator:
|
||||
for output in res.outputs:
|
||||
yield create_stream_response_json(index=output.index, text=output.text, is_finished=output.finish_reason)
|
||||
|
||||
try:
|
||||
# streaming case
|
||||
if request.stream:
|
||||
return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')
|
||||
# None-streaming case
|
||||
final_result = None
|
||||
texts, token_ids = [[]] * config['n'], [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs]
|
||||
)
|
||||
return JSONResponse(
|
||||
converter.unstructure(
|
||||
Generations(
|
||||
id=request_id,
|
||||
generations=[
|
||||
Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason) for output in final_result.outputs
|
||||
],
|
||||
)
|
||||
),
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
|
||||
def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str, str]]:
|
||||
def convert_role(role):
|
||||
return {'User': 'user', 'Chatbot': 'assistant'}[role]
|
||||
|
||||
chat_history = request.chat_history
|
||||
if chat_history:
|
||||
messages = [{'role': convert_role(msg['role']), 'content': msg['message']} for msg in chat_history]
|
||||
else:
|
||||
messages = []
|
||||
messages.append({'role': 'user', 'content': request.message})
|
||||
return messages
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
async def cohere_chat(req, llm):
|
||||
json_str = await req.body()
|
||||
try:
|
||||
request = converter.structure(orjson.loads(json_str), CohereChatRequest)
|
||||
except orjson.JSONDecodeError as err:
|
||||
logger.debug('Sent body: %s', json_str)
|
||||
logger.error('Invalid JSON input received: %s', err)
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received chat completion request: %s', request)
|
||||
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
request_id = gen_random_uuid('cohere-chat')
|
||||
prompt: str = llm.tokenizer.apply_chat_template(
|
||||
_transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
|
||||
)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
def create_stream_generation_json(index: int, text: str, is_finished: bool) -> str:
|
||||
return f'{jsonify_attr(ChatStreamTextGeneration(index=index, text=text, is_finished=is_finished))}\n'
|
||||
|
||||
async def completion_stream_generator():
|
||||
texts, token_ids = [], []
|
||||
yield f'{jsonify_attr(ChatStreamStart(is_finished=False, index=0, generation_id=request_id))}\n'
|
||||
|
||||
it = None
|
||||
async for res in result_generator:
|
||||
yield create_stream_generation_json(index=res.outputs[0].index, text=res.outputs[0].text, is_finished=False)
|
||||
texts.append(res.outputs[0].text)
|
||||
token_ids.extend(res.outputs[0].token_ids)
|
||||
it = res
|
||||
|
||||
if it is None:
|
||||
raise ValueError('No response from model.')
|
||||
num_prompt_tokens, num_response_tokens = len(it.prompt_token_ids), len(token_ids)
|
||||
|
||||
json_str = jsonify_attr(
|
||||
ChatStreamEnd(
|
||||
is_finished=True,
|
||||
finish_reason='COMPLETE',
|
||||
index=0,
|
||||
response=Chat(
|
||||
response_id=request_id,
|
||||
message=request.message,
|
||||
text=''.join(texts),
|
||||
prompt=prompt,
|
||||
chat_history=request.chat_history,
|
||||
token_count={
|
||||
'prompt_tokens': num_prompt_tokens,
|
||||
'response_tokens': num_response_tokens,
|
||||
'total_tokens': num_prompt_tokens + num_response_tokens,
|
||||
},
|
||||
),
|
||||
)
|
||||
)
|
||||
yield f'{json_str}\n'
|
||||
|
||||
try:
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result = None
|
||||
texts, token_ids = [], []
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
texts.append(res.outputs[0].text)
|
||||
token_ids.extend(res.outputs[0].token_ids)
|
||||
final_result = res
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)])
|
||||
num_prompt_tokens, num_response_tokens = len(final_result.prompt_token_ids), len(token_ids)
|
||||
return JSONResponse(
|
||||
converter.unstructure(
|
||||
Chat(
|
||||
response_id=request_id,
|
||||
message=request.message,
|
||||
text=''.join(texts),
|
||||
prompt=prompt,
|
||||
chat_history=request.chat_history,
|
||||
token_count={
|
||||
'prompt_tokens': num_prompt_tokens,
|
||||
'response_tokens': num_response_tokens,
|
||||
'total_tokens': num_prompt_tokens + num_response_tokens,
|
||||
},
|
||||
)
|
||||
),
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
@@ -1,19 +0,0 @@
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
|
||||
from attr import AttrsInstance
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from bentoml import Service
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from .._llm import LLM
|
||||
from ..protocol.cohere import CohereChatRequest, CohereGenerateRequest
|
||||
|
||||
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def jsonify_attr(obj: AttrsInstance) -> str: ...
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
|
||||
async def check_model(request: Union[CohereGenerateRequest, CohereChatRequest], model: str) -> Optional[JSONResponse]: ...
|
||||
async def cohere_generate(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
async def cohere_chat(req: Request, llm: LLM[M, T]) -> Response: ...
|
||||
@@ -27,7 +27,6 @@ def mount_to_svc(svc, llm):
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
|
||||
Route('/adapters', endpoint=functools.partial(hf_adapters, llm=llm), name='adapters', methods=['GET']),
|
||||
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
|
||||
],
|
||||
)
|
||||
@@ -37,7 +36,10 @@ def mount_to_svc(svc, llm):
|
||||
|
||||
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)), status_code=status_code.value)
|
||||
return JSONResponse(
|
||||
converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)),
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
@@ -53,20 +55,9 @@ async def hf_agent(req, llm):
|
||||
stop = request.parameters.pop('stop', [])
|
||||
try:
|
||||
result = await llm.generate(request.inputs, stop=stop, **request.parameters)
|
||||
return JSONResponse(converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value)
|
||||
return JSONResponse(
|
||||
converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error('Error while generating: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')
|
||||
|
||||
|
||||
@add_schema_definitions
|
||||
def hf_adapters(req, llm):
|
||||
if not llm.has_adapters:
|
||||
return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
|
||||
return JSONResponse(
|
||||
{
|
||||
adapter_tuple[1]: {'adapter_name': k, 'adapter_type': adapter_tuple[0].peft_type.value}
|
||||
for k, adapter_tuple in dict(*llm.adapter_map.values()).items()
|
||||
},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ from starlette.routing import Route
|
||||
from openllm_core.utils import converter, gen_random_uuid
|
||||
|
||||
from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
|
||||
from ..protocol.openai import (
|
||||
from openllm_core.protocol.openai import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
@@ -61,7 +61,11 @@ def jsonify_attr(obj):
|
||||
|
||||
def error_response(status_code, message):
|
||||
return JSONResponse(
|
||||
{'error': converter.unstructure(ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value)))},
|
||||
{
|
||||
'error': converter.unstructure(
|
||||
ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value))
|
||||
)
|
||||
},
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
@@ -95,7 +99,11 @@ def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initi
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in step_top_logprobs.items()} if step_top_logprobs else None)
|
||||
logprobs.top_logprobs.append(
|
||||
{llm.tokenizer.convert_ids_to_tokens(i): p for i, p in step_top_logprobs.items()}
|
||||
if step_top_logprobs
|
||||
else None
|
||||
)
|
||||
return logprobs
|
||||
|
||||
|
||||
@@ -106,8 +114,14 @@ def mount_to_svc(svc, llm):
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']),
|
||||
Route('/completions', functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm), methods=['POST']),
|
||||
Route(
|
||||
'/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']
|
||||
),
|
||||
Route(
|
||||
'/completions',
|
||||
functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm),
|
||||
methods=['POST'],
|
||||
),
|
||||
Route(
|
||||
'/chat/completions',
|
||||
functools.partial(
|
||||
@@ -132,7 +146,9 @@ def mount_to_svc(svc, llm):
|
||||
# GET /v1/models
|
||||
@add_schema_definitions
|
||||
def list_models(_, llm):
|
||||
return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)
|
||||
return JSONResponse(
|
||||
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
|
||||
# POST /v1/chat/completions
|
||||
@@ -166,7 +182,9 @@ async def chat_completions(req, llm):
|
||||
config = llm.config.compatible_options(request)
|
||||
|
||||
def get_role() -> str:
|
||||
return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant' # TODO: Support custom role here.
|
||||
return (
|
||||
request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant'
|
||||
) # TODO: Support custom role here.
|
||||
|
||||
try:
|
||||
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
|
||||
@@ -180,7 +198,9 @@ async def chat_completions(req, llm):
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)],
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
|
||||
],
|
||||
)
|
||||
if usage is not None:
|
||||
response.usage = usage
|
||||
@@ -230,13 +250,20 @@ async def chat_completions(req, llm):
|
||||
final_result = res
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs]
|
||||
final_result = final_result.model_copy(
|
||||
update=dict(
|
||||
outputs=[
|
||||
output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
|
||||
for output in final_result.outputs
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
role = get_role()
|
||||
choices = [
|
||||
ChatCompletionResponseChoice(index=output.index, message=ChatMessage(role=role, content=output.text), finish_reason=output.finish_reason)
|
||||
ChatCompletionResponseChoice(
|
||||
index=output.index, message=ChatMessage(role=role, content=output.text), finish_reason=output.finish_reason
|
||||
)
|
||||
for output in final_result.outputs
|
||||
]
|
||||
if request.echo:
|
||||
@@ -250,7 +277,9 @@ async def chat_completions(req, llm):
|
||||
num_prompt_tokens = len(final_result.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
|
||||
response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
|
||||
)
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
@@ -342,7 +371,13 @@ async def completions(req, llm):
|
||||
top_logprobs = res.prompt_logprobs
|
||||
previous_echo[i] = True
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i] :], request.logprobs, len(previous_texts[i]), llm=llm)
|
||||
logprobs = create_logprobs(
|
||||
output.token_ids,
|
||||
output.logprobs[previous_num_tokens[i] :],
|
||||
request.logprobs,
|
||||
len(previous_texts[i]),
|
||||
llm=llm,
|
||||
)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
previous_texts[i] += output.text
|
||||
yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
|
||||
@@ -368,8 +403,13 @@ async def completions(req, llm):
|
||||
final_result = res
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs]
|
||||
final_result = final_result.model_copy(
|
||||
update=dict(
|
||||
outputs=[
|
||||
output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
|
||||
for output in final_result.outputs
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
choices = []
|
||||
@@ -392,7 +432,9 @@ async def completions(req, llm):
|
||||
output_text = prompt_text + output_text
|
||||
else:
|
||||
output_text = prompt_text
|
||||
choice_data = CompletionResponseChoice(index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason)
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_result.prompt_token_ids)
|
||||
|
||||
@@ -14,7 +14,9 @@ from ..protocol.openai import ChatCompletionRequest, CompletionRequest, LogProbs
|
||||
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
|
||||
def jsonify_attr(obj: AttrsInstance) -> str: ...
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
|
||||
async def check_model(request: Union[CompletionRequest, ChatCompletionRequest], model: str) -> Optional[JSONResponse]: ...
|
||||
async def check_model(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest], model: str
|
||||
) -> Optional[JSONResponse]: ...
|
||||
def create_logprobs(
|
||||
token_ids: List[int],
|
||||
top_logprobs: List[Dict[int, float]], #
|
||||
|
||||
@@ -7,4 +7,5 @@ from openllm_core.exceptions import (
|
||||
ValidationError as ValidationError, #
|
||||
MissingAnnotationAttributeError as MissingAnnotationAttributeError,
|
||||
MissingDependencyError as MissingDependencyError,
|
||||
ModelNotFound as ModelNotFound, #
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import openllm, transformers, typing as t
|
||||
import openllm, typing as t
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM, config: transformers.PretrainedConfig, **attrs: t.Any): ...
|
||||
def load_model(llm: openllm.LLM, *args: t.Any, **attrs: t.Any): ...
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'openai': [], 'cohere': [], 'hf': []}
|
||||
if t.TYPE_CHECKING:
|
||||
from . import cohere as cohere, hf as hf, openai as openai
|
||||
__lazy = LazyModule(__name__, os.path.abspath('__file__'), _import_structure)
|
||||
__all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__
|
||||
@@ -1,154 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from enum import Enum
|
||||
|
||||
import attr
|
||||
|
||||
from openllm_core.utils import converter
|
||||
|
||||
|
||||
@attr.define
|
||||
class CohereErrorResponse:
|
||||
text: str
|
||||
|
||||
|
||||
converter.register_unstructure_hook(CohereErrorResponse, lambda obj: obj.text)
|
||||
|
||||
|
||||
@attr.define
|
||||
class CohereGenerateRequest:
|
||||
prompt: str
|
||||
prompt_vars: t.Optional[t.Dict[str, t.Any]] = None
|
||||
model: t.Optional[str] = None
|
||||
preset: t.Optional[str] = None
|
||||
num_generations: t.Optional[int] = None
|
||||
max_tokens: t.Optional[int] = None
|
||||
temperature: t.Optional[float] = None
|
||||
k: t.Optional[int] = None
|
||||
p: t.Optional[float] = None
|
||||
frequency_penalty: t.Optional[float] = None
|
||||
presence_penalty: t.Optional[float] = None
|
||||
end_sequences: t.Optional[t.List[str]] = None
|
||||
stop_sequences: t.Optional[t.List[str]] = None
|
||||
return_likelihoods: t.Optional[t.Literal['GENERATION', 'ALL', 'NONE']] = None
|
||||
truncate: t.Optional[str] = None
|
||||
logit_bias: t.Optional[t.Dict[int, float]] = None
|
||||
stream: bool = False
|
||||
|
||||
|
||||
@attr.define
|
||||
class TokenLikelihood: # pretty sure this is similar to token_logprobs
|
||||
token: str
|
||||
likelihood: float
|
||||
|
||||
|
||||
@attr.define
|
||||
class Generation:
|
||||
id: str
|
||||
text: str
|
||||
prompt: str
|
||||
likelihood: t.Optional[float] = None
|
||||
token_likelihoods: t.List[TokenLikelihood] = attr.field(factory=list)
|
||||
finish_reason: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class Generations:
|
||||
id: str
|
||||
generations: t.List[Generation]
|
||||
meta: t.Optional[t.Dict[str, t.Any]] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class StreamingText:
|
||||
index: int
|
||||
text: str
|
||||
is_finished: bool
|
||||
|
||||
|
||||
@attr.define
|
||||
class StreamingGenerations:
|
||||
id: str
|
||||
generations: Generations
|
||||
texts: t.List[str]
|
||||
meta: t.Optional[t.Dict[str, t.Any]] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class CohereChatRequest:
|
||||
message: str
|
||||
conversation_id: t.Optional[str] = ''
|
||||
model: t.Optional[str] = None
|
||||
return_chat_history: t.Optional[bool] = False
|
||||
return_prompt: t.Optional[bool] = False
|
||||
return_preamble: t.Optional[bool] = False
|
||||
chat_history: t.Optional[t.List[t.Dict[str, str]]] = None
|
||||
preamble_override: t.Optional[str] = None
|
||||
user_name: t.Optional[str] = None
|
||||
temperature: t.Optional[float] = 0.8
|
||||
max_tokens: t.Optional[int] = None
|
||||
stream: t.Optional[bool] = False
|
||||
p: t.Optional[float] = None
|
||||
k: t.Optional[float] = None
|
||||
logit_bias: t.Optional[t.Dict[int, float]] = None
|
||||
search_queries_only: t.Optional[bool] = None
|
||||
documents: t.Optional[t.List[t.Dict[str, t.Any]]] = None
|
||||
citation_quality: t.Optional[str] = None
|
||||
prompt_truncation: t.Optional[str] = None
|
||||
connectors: t.Optional[t.List[t.Dict[str, t.Any]]] = None
|
||||
|
||||
|
||||
class StreamEvent(str, Enum):
|
||||
STREAM_START = 'stream-start'
|
||||
TEXT_GENERATION = 'text-generation'
|
||||
STREAM_END = 'stream-end'
|
||||
# TODO: The following are yet to be implemented
|
||||
SEARCH_QUERIES_GENERATION = 'search-queries-generation'
|
||||
SEARCH_RESULTS = 'search-results'
|
||||
CITATION_GENERATION = 'citation-generation'
|
||||
|
||||
|
||||
@attr.define
|
||||
class Chat:
|
||||
response_id: str
|
||||
message: str
|
||||
text: str
|
||||
generation_id: t.Optional[str] = None
|
||||
conversation_id: t.Optional[str] = None
|
||||
meta: t.Optional[t.Dict[str, t.Any]] = None
|
||||
prompt: t.Optional[str] = None
|
||||
chat_history: t.Optional[t.List[t.Dict[str, t.Any]]] = None
|
||||
preamble: t.Optional[str] = None
|
||||
token_count: t.Optional[t.Dict[str, int]] = None
|
||||
is_search_required: t.Optional[bool] = None
|
||||
citations: t.Optional[t.List[t.Dict[str, t.Any]]] = None
|
||||
documents: t.Optional[t.List[t.Dict[str, t.Any]]] = None
|
||||
search_results: t.Optional[t.List[t.Dict[str, t.Any]]] = None
|
||||
search_queries: t.Optional[t.List[t.Dict[str, t.Any]]] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatStreamResponse:
|
||||
is_finished: bool
|
||||
event_type: StreamEvent
|
||||
index: int
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatStreamStart(ChatStreamResponse):
|
||||
generation_id: str
|
||||
conversation_id: t.Optional[str] = None
|
||||
event_type: StreamEvent = StreamEvent.STREAM_START
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatStreamTextGeneration(ChatStreamResponse):
|
||||
text: str
|
||||
event_type: StreamEvent = StreamEvent.TEXT_GENERATION
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatStreamEnd(ChatStreamResponse):
|
||||
finish_reason: str
|
||||
response: Chat
|
||||
event_type: StreamEvent = StreamEvent.STREAM_END
|
||||
@@ -1,21 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
|
||||
|
||||
@attr.define
|
||||
class AgentRequest:
|
||||
inputs: str
|
||||
parameters: t.Dict[str, t.Any]
|
||||
|
||||
|
||||
@attr.define
|
||||
class AgentResponse:
|
||||
generated_text: str
|
||||
|
||||
|
||||
@attr.define
|
||||
class HFErrorResponse:
|
||||
error_code: int
|
||||
message: str
|
||||
@@ -1,188 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
|
||||
import openllm_core
|
||||
from openllm_core._schemas import FinishReason
|
||||
from openllm_core.utils import converter
|
||||
|
||||
|
||||
@attr.define
|
||||
class ErrorResponse:
|
||||
message: str
|
||||
type: str
|
||||
object: str = 'error'
|
||||
param: t.Optional[str] = None
|
||||
code: t.Optional[str] = None
|
||||
|
||||
|
||||
def _stop_converter(data: t.Union[str, t.List[str]]) -> t.List[str]:
|
||||
if not data:
|
||||
return None
|
||||
return [data] if isinstance(data, str) else data
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionRequest:
|
||||
prompt: str
|
||||
model: str = attr.field(default=None)
|
||||
suffix: t.Optional[str] = attr.field(default=None)
|
||||
max_tokens: t.Optional[int] = attr.field(default=16)
|
||||
temperature: t.Optional[float] = attr.field(default=1.0)
|
||||
top_p: t.Optional[float] = attr.field(default=1.0)
|
||||
n: t.Optional[int] = attr.field(default=1)
|
||||
stream: t.Optional[bool] = attr.field(default=False)
|
||||
logprobs: t.Optional[int] = attr.field(default=None)
|
||||
echo: t.Optional[bool] = attr.field(default=False)
|
||||
stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
|
||||
presence_penalty: t.Optional[float] = attr.field(default=0.0)
|
||||
frequency_penalty: t.Optional[float] = attr.field(default=0.0)
|
||||
logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
|
||||
user: t.Optional[str] = attr.field(default=None)
|
||||
# supported by vLLM and us
|
||||
top_k: t.Optional[int] = attr.field(default=None)
|
||||
best_of: t.Optional[int] = attr.field(default=1)
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionRequest:
|
||||
messages: t.List[t.Dict[str, str]]
|
||||
model: str = attr.field(default=None)
|
||||
functions: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
|
||||
function_calls: t.List[t.Dict[str, str]] = attr.field(default=attr.Factory(list))
|
||||
temperature: t.Optional[float] = attr.field(default=None)
|
||||
top_p: t.Optional[float] = attr.field(default=None)
|
||||
n: t.Optional[int] = attr.field(default=None)
|
||||
stream: t.Optional[bool] = attr.field(default=False)
|
||||
stop: t.Optional[t.Union[str, t.List[str]]] = attr.field(default=None, converter=_stop_converter)
|
||||
max_tokens: t.Optional[int] = attr.field(default=None)
|
||||
presence_penalty: t.Optional[float] = attr.field(default=None)
|
||||
frequency_penalty: t.Optional[float] = attr.field(default=None)
|
||||
echo: t.Optional[bool] = attr.field(default=False)
|
||||
logit_bias: t.Optional[t.Dict[str, float]] = attr.field(default=None)
|
||||
user: t.Optional[str] = attr.field(default=None)
|
||||
# supported by vLLM and us
|
||||
top_k: t.Optional[int] = attr.field(default=None)
|
||||
best_of: t.Optional[int] = attr.field(default=1)
|
||||
# Additional features to support chat_template
|
||||
chat_template: str = attr.field(default=None)
|
||||
add_generation_prompt: bool = attr.field(default=True)
|
||||
|
||||
|
||||
@attr.define
|
||||
class LogProbs:
|
||||
text_offset: t.List[int] = attr.field(default=attr.Factory(list))
|
||||
token_logprobs: t.List[float] = attr.field(default=attr.Factory(list))
|
||||
tokens: t.List[str] = attr.field(default=attr.Factory(list))
|
||||
top_logprobs: t.List[t.Dict[str, t.Any]] = attr.field(default=attr.Factory(list))
|
||||
|
||||
|
||||
@attr.define
|
||||
class UsageInfo:
|
||||
prompt_tokens: int = attr.field(default=0)
|
||||
completion_tokens: int = attr.field(default=0)
|
||||
total_tokens: int = attr.field(default=0)
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionResponseChoice:
|
||||
index: int
|
||||
text: str
|
||||
logprobs: t.Optional[LogProbs] = None
|
||||
finish_reason: t.Optional[FinishReason] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionResponseStreamChoice:
|
||||
index: int
|
||||
text: str
|
||||
logprobs: t.Optional[LogProbs] = None
|
||||
finish_reason: t.Optional[FinishReason] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionStreamResponse:
|
||||
model: str
|
||||
choices: t.List[CompletionResponseStreamChoice]
|
||||
object: str = 'text_completion'
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
usage: t.Optional[UsageInfo] = attr.field(default=None)
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionResponse:
|
||||
choices: t.List[CompletionResponseChoice]
|
||||
model: str
|
||||
usage: UsageInfo
|
||||
object: str = 'text_completion'
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
|
||||
|
||||
LiteralRole = t.Literal['system', 'user', 'assistant']
|
||||
|
||||
|
||||
@attr.define
|
||||
class Delta:
|
||||
role: t.Optional[LiteralRole] = None
|
||||
content: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatMessage:
|
||||
role: LiteralRole
|
||||
content: str
|
||||
|
||||
|
||||
converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role, 'content': msg.content})
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionResponseStreamChoice:
|
||||
index: int
|
||||
delta: Delta
|
||||
finish_reason: t.Optional[FinishReason] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionResponseChoice:
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: t.Optional[FinishReason] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionResponse:
|
||||
choices: t.List[ChatCompletionResponseChoice]
|
||||
model: str
|
||||
object: str = 'chat.completion'
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
usage: UsageInfo = attr.field(default=attr.Factory(lambda: UsageInfo()))
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionStreamResponse:
|
||||
choices: t.List[ChatCompletionResponseStreamChoice]
|
||||
model: str
|
||||
object: str = 'chat.completion.chunk'
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
usage: t.Optional[UsageInfo] = attr.field(default=None)
|
||||
|
||||
|
||||
@attr.define
|
||||
class ModelCard:
|
||||
id: str
|
||||
object: str = 'model'
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
owned_by: str = 'na'
|
||||
|
||||
|
||||
@attr.define
|
||||
class ModelList:
|
||||
object: str = 'list'
|
||||
data: t.List[ModelCard] = attr.field(factory=list)
|
||||
@@ -1,16 +1,43 @@
|
||||
from __future__ import annotations
|
||||
import importlib, typing as t
|
||||
from openllm_core._typing_compat import M, ParamSpec, T, TypeGuard, Concatenate
|
||||
import importlib, logging, inspect, typing as t
|
||||
from openllm_core._typing_compat import ParamSpec, Concatenate, TypeGuard
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import (
|
||||
apply,
|
||||
first_not_none,
|
||||
generate_hash_from_file,
|
||||
resolve_filepath,
|
||||
validate_is_path,
|
||||
normalise_model_name,
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml import Model
|
||||
from .._llm import LLM
|
||||
|
||||
P = ParamSpec('P')
|
||||
M = t.TypeVar('M')
|
||||
T = t.TypeVar('T')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]:
|
||||
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(model_id: str, model_version: t.Optional[str]) -> t.Tuple[str, t.Optional[str]]:
|
||||
model_id, *maybe_revision = model_id.rsplit(':')
|
||||
if len(maybe_revision) > 0:
|
||||
if model_version is not None:
|
||||
logger.warning(
|
||||
"revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version
|
||||
)
|
||||
model_version = maybe_revision[0]
|
||||
if validate_is_path(model_id):
|
||||
model_id = resolve_filepath(model_id)
|
||||
model_version = first_not_none(model_version, default=generate_hash_from_file(model_id))
|
||||
return normalise_model_name(model_id), model_version
|
||||
|
||||
|
||||
def load_tokenizer(llm, **tokenizer_attrs):
|
||||
import cloudpickle, fs, transformers
|
||||
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
|
||||
from .transformers._helpers import process_config
|
||||
@@ -30,7 +57,9 @@ def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]:
|
||||
'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"'
|
||||
) from None
|
||||
else:
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs
|
||||
)
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if config.pad_token_id is not None:
|
||||
@@ -42,32 +71,34 @@ def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]:
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[LLM[M, T], P], TypeGuard[M | T | Model]]:
|
||||
def caller(llm: LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> TypeGuard[M | T | Model]:
|
||||
def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[LLM[M, T], P], TypeGuard[t.Union[M, T, Model]]]:
|
||||
def caller(llm: t.Optional[LLM[M, T]] = None, *args: P.args, **kwargs: P.kwargs) -> TypeGuard[t.Union[M, T, Model]]:
|
||||
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"'
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.ctranslate' if 'llm.__llm_backend__="ctranslate"'
|
||||
"""
|
||||
if llm.__llm_backend__ == 'ggml':
|
||||
serde = 'ggml'
|
||||
elif llm.__llm_backend__ == 'ctranslate':
|
||||
serde = 'ctranslate'
|
||||
elif llm.__llm_backend__ in {'pt', 'vllm'}:
|
||||
serde = 'transformers'
|
||||
else:
|
||||
raise OpenLLMException(f'Not supported backend {llm.__llm_backend__}')
|
||||
return getattr(importlib.import_module(f'.{serde}', 'openllm.serialisation'), fn)(llm, *args, **kwargs)
|
||||
backend = kwargs.get('_backend', None)
|
||||
if backend is None:
|
||||
if llm is None:
|
||||
raise OpenLLMException('Cannot dispatch without LLM instance.')
|
||||
backend = llm.__llm_backend__
|
||||
serde_mapping = {'pt': 'transformers', 'vllm': 'vllm', 'ggml': 'ggml'}
|
||||
try:
|
||||
serde = serde_mapping[backend]
|
||||
except KeyError:
|
||||
raise OpenLLMException(f'Not supported backend {backend}')
|
||||
call = getattr(importlib.import_module(f'.{serde}', 'openllm.serialisation'), fn)
|
||||
params = inspect.signature(call).parameters
|
||||
return call(llm, *args, **kwargs) if next(iter(params.keys())) == 'llm' else call(*args, **kwargs)
|
||||
|
||||
return caller
|
||||
|
||||
|
||||
_extras = ['get', 'import_model', 'load_model']
|
||||
_import_structure = {'ggml', 'transformers', 'ctranslate', 'constants'}
|
||||
__all__ = ['load_tokenizer', *_extras, *_import_structure]
|
||||
_extras = ['import_model', 'load_model']
|
||||
_import_structure = {'ggml', 'transformers', 'vllm', 'constants'}
|
||||
__all__ = ['_make_tag_components', 'load_tokenizer', *_extras, *_import_structure]
|
||||
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
@@ -75,9 +106,7 @@ def __dir__() -> t.Sequence[str]:
|
||||
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == 'load_tokenizer':
|
||||
return load_tokenizer
|
||||
elif name in _import_structure:
|
||||
if name in _import_structure:
|
||||
return importlib.import_module(f'.{name}', __name__)
|
||||
elif name in _extras:
|
||||
return _make_dispatch_function(name)
|
||||
|
||||
@@ -6,18 +6,10 @@ Currently, GGML format is working in progress.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from bentoml import Model
|
||||
from openllm import LLM
|
||||
from openllm_core._typing_compat import M, T
|
||||
from . import constants as constants, ggml as ggml, transformers as transformers
|
||||
from bentoml import Model as _Model
|
||||
from openllm import LLM as _LLM
|
||||
from . import constants as constants, ggml as ggml, transformers as transformers, vllm as vllm
|
||||
|
||||
def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T:
|
||||
"""Load the tokenizer from BentoML store.
|
||||
|
||||
By default, it will try to find the bentomodel whether it is in store..
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
|
||||
def get(llm: LLM[M, T]) -> Model: ...
|
||||
def import_model(llm: LLM[M, T], *args: Any, trust_remote_code: bool, **attrs: Any) -> Model: ...
|
||||
def load_model(llm: LLM[M, T], *args: Any, **attrs: Any) -> M: ...
|
||||
def import_model(*args: Any, trust_remote_code: bool, **attrs: Any) -> _Model: ...
|
||||
def load_model(llm: _LLM, *args: Any, **attrs: Any) -> Any: ...
|
||||
def load_tokenizer(llm: _LLM, **attrs: Any) -> Any: ...
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
import contextlib, attr
|
||||
from __future__ import annotations
|
||||
import contextlib, attr, bentoml, openllm, types, logging, typing as t
|
||||
from simple_di import Provide, inject
|
||||
import bentoml, openllm
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.models.model import ModelOptions, ModelSignature
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import is_autogptq_available
|
||||
from openllm_core._typing_compat import LiteralSerialisation, LiteralQuantise, LiteralBackend
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
from bentoml._internal.models import ModelStore
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_hash(config) -> str:
|
||||
def get_hash(config: transformers.PretrainedConfig) -> str:
|
||||
_commit_hash = getattr(config, '_commit_hash', None)
|
||||
if _commit_hash is None:
|
||||
raise ValueError(f'Cannot find commit hash in {config}')
|
||||
logger.warning('Cannot find commit hash in %r', config)
|
||||
return _commit_hash
|
||||
|
||||
|
||||
@@ -29,7 +32,9 @@ def patch_correct_tag(llm, config, _revision=None) -> None:
|
||||
if _revision is None and llm.tag.version is not None:
|
||||
_revision = llm.tag.version
|
||||
if llm.tag.version is None:
|
||||
_object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
|
||||
_object_setattr(
|
||||
llm, '_tag', attr.evolve(llm.tag, version=_revision)
|
||||
) # HACK: This copies the correct revision into llm.tag
|
||||
if llm._revision is None:
|
||||
_object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
|
||||
|
||||
@@ -37,7 +42,7 @@ def patch_correct_tag(llm, config, _revision=None) -> None:
|
||||
def _create_metadata(llm, config, safe_serialisation, trust_remote_code, metadata=None):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata.update({'safe_serialisation': safe_serialisation, '_framework': llm.__llm_backend__})
|
||||
metadata.update({'_framework': llm.__llm_backend__})
|
||||
if llm.quantise:
|
||||
metadata['_quantize'] = llm.quantise
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
@@ -45,7 +50,9 @@ def _create_metadata(llm, config, safe_serialisation, trust_remote_code, metadat
|
||||
if trust_remote_code:
|
||||
auto_map = getattr(config, 'auto_map', {})
|
||||
if not auto_map:
|
||||
raise RuntimeError(f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}')
|
||||
raise RuntimeError(
|
||||
f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}'
|
||||
)
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if autoclass not in auto_map:
|
||||
raise RuntimeError(
|
||||
@@ -56,89 +63,97 @@ def _create_metadata(llm, config, safe_serialisation, trust_remote_code, metadat
|
||||
raise RuntimeError(
|
||||
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
|
||||
)
|
||||
metadata.update({'_pretrained_class': architectures[0], '_revision': get_hash(config) if not llm.local else llm.revision})
|
||||
metadata.update({
|
||||
'_pretrained_class': architectures[0],
|
||||
'_revision': get_hash(config) if not llm.local else llm.revision,
|
||||
'_local': llm.local,
|
||||
'serialisation': llm._serialisation,
|
||||
'model_name': llm.config['model_name'],
|
||||
'architecture': llm.config['architecture'],
|
||||
'model_id': llm.model_id,
|
||||
})
|
||||
return metadata
|
||||
|
||||
|
||||
def _create_signatures(llm, signatures=None):
|
||||
if signatures is None:
|
||||
signatures = {}
|
||||
if llm.__llm_backend__ == 'pt':
|
||||
if llm.quantise == 'gptq':
|
||||
if not is_autogptq_available():
|
||||
raise OpenLLMException("Requires 'auto-gptq' and 'optimum'. Install it with 'pip install \"openllm[gptq]\"'")
|
||||
signatures['generate'] = {'batchable': False}
|
||||
else:
|
||||
signatures.update({
|
||||
k: ModelSignature(batchable=False)
|
||||
for k in (
|
||||
'__call__',
|
||||
'forward',
|
||||
'generate', #
|
||||
'contrastive_search',
|
||||
'greedy_search', #
|
||||
'sample',
|
||||
'beam_search',
|
||||
'beam_sample', #
|
||||
'group_beam_search',
|
||||
'constrained_beam_search', #
|
||||
)
|
||||
})
|
||||
elif llm.__llm_backend__ == 'ctranslate':
|
||||
if llm.config['model_type'] == 'seq2seq_lm':
|
||||
non_batch_keys = {'score_file', 'translate_file'}
|
||||
batch_keys = {'generate_tokens', 'score_batch', 'translate_batch', 'translate_iterable', 'score_iterable'}
|
||||
else:
|
||||
non_batch_keys = set()
|
||||
batch_keys = {
|
||||
'async_generate_tokens',
|
||||
'forward_batch',
|
||||
'generate_batch', #
|
||||
'generate_iterable',
|
||||
'generate_tokens',
|
||||
'score_batch',
|
||||
'score_iterable', #
|
||||
}
|
||||
signatures.update({k: ModelSignature(batchable=False) for k in non_batch_keys})
|
||||
signatures.update({k: ModelSignature(batchable=True) for k in batch_keys})
|
||||
return signatures
|
||||
@attr.define(init=False)
|
||||
class _Model(bentoml.Model):
|
||||
_imported_modules: t.List[types.ModuleType] = None
|
||||
|
||||
@property
|
||||
def imported_modules(self):
|
||||
if self._imported_modules is None:
|
||||
self._imported_modules = []
|
||||
return self._imported_modules
|
||||
|
||||
@imported_modules.setter
|
||||
def imported_modules(self, value):
|
||||
self._imported_modules = value
|
||||
|
||||
@classmethod
|
||||
def create(cls, tag, *, module, api_version, labels=None, metadata=None):
|
||||
return super().create(
|
||||
tag,
|
||||
module=module,
|
||||
api_version=api_version,
|
||||
signatures={},
|
||||
labels=labels,
|
||||
metadata=metadata,
|
||||
context=openllm.utils.generate_context('openllm'),
|
||||
)
|
||||
|
||||
|
||||
@inject
|
||||
@contextlib.contextmanager
|
||||
def save_model(
|
||||
llm,
|
||||
config,
|
||||
safe_serialisation, #
|
||||
trust_remote_code,
|
||||
module,
|
||||
external_modules, #
|
||||
_model_store=Provide[BentoMLContainer.model_store],
|
||||
_api_version='v2.1.0', #
|
||||
):
|
||||
tag: bentoml.Tag,
|
||||
config: transformers.PretrainedConfig,
|
||||
serialisation: LiteralSerialisation, #
|
||||
trust_remote_code: bool,
|
||||
module: str,
|
||||
external_modules: list[types.ModuleType], #
|
||||
model_id: str,
|
||||
quantise: LiteralQuantise,
|
||||
backend: LiteralBackend,
|
||||
_local: bool,
|
||||
_dtype: str,
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
_api_version: str = 'v3.0.0', #
|
||||
) -> bentoml.Model:
|
||||
imported_modules = []
|
||||
bentomodel = bentoml.Model.create(
|
||||
llm.tag,
|
||||
module=f'openllm.serialisation.{module}', #
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
_metadata = {
|
||||
'model_id': model_id,
|
||||
'backend': backend,
|
||||
'dtype': _dtype,
|
||||
'architectures': architectures,
|
||||
'_revision': get_hash(config) or tag.version,
|
||||
'_local': _local,
|
||||
'serialisation': serialisation,
|
||||
}
|
||||
if quantise:
|
||||
_metadata['_quantize'] = quantise
|
||||
bentomodel = _Model.create(
|
||||
tag,
|
||||
module=f'openllm.serialisation.{module}',
|
||||
api_version=_api_version,
|
||||
options=ModelOptions(), #
|
||||
context=openllm.utils.generate_context('openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
metadata=_create_metadata(llm, config, safe_serialisation, trust_remote_code),
|
||||
signatures=_create_signatures(llm),
|
||||
labels=openllm.utils.generate_labels(serialisation),
|
||||
metadata=_metadata,
|
||||
)
|
||||
with openllm.utils.analytics.set_bentoml_tracking():
|
||||
try:
|
||||
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
|
||||
yield bentomodel, imported_modules
|
||||
bentomodel.imported_modules = imported_modules
|
||||
yield bentomodel
|
||||
except Exception:
|
||||
raise
|
||||
else:
|
||||
bentomodel.flush()
|
||||
bentomodel.save(_model_store)
|
||||
openllm.utils.analytics.track(
|
||||
openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024)
|
||||
openllm.utils.analytics.ModelSaveEvent(
|
||||
module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024
|
||||
)
|
||||
)
|
||||
finally:
|
||||
bentomodel.exit_cloudpickle_context(imported_modules)
|
||||
bentomodel.exit_cloudpickle_context(bentomodel.imported_modules)
|
||||
return bentomodel
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, Optional, Sequence, Tuple
|
||||
|
||||
import transformers
|
||||
|
||||
from bentoml import Model
|
||||
from openllm_core._typing_compat import M, T
|
||||
|
||||
from .._llm import LLM
|
||||
|
||||
def get_hash(config: transformers.PretrainedConfig) -> str: ...
|
||||
def patch_correct_tag(llm: LLM[M, T], config: transformers.PretrainedConfig, _revision: Optional[str] = ...) -> None: ...
|
||||
@contextmanager
|
||||
def save_model(
|
||||
llm: LLM[M, T],
|
||||
config: transformers.PretrainedConfig,
|
||||
safe_serialisation: bool,
|
||||
trust_remote_code: bool,
|
||||
module: str,
|
||||
external_modules: Sequence[types.ModuleType],
|
||||
) -> Iterator[Tuple[Model, Sequence[types.ModuleType]]]: ...
|
||||
@@ -1,90 +0,0 @@
|
||||
import importlib
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
import transformers
|
||||
|
||||
import bentoml
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import is_ctranslate_available
|
||||
|
||||
from .._helpers import patch_correct_tag, save_model
|
||||
from ..transformers._helpers import get_tokenizer, process_config
|
||||
|
||||
if not is_ctranslate_available():
|
||||
raise RuntimeError("'ctranslate2' is required to use with backend 'ctranslate'. Install it with 'pip install \"openllm[ctranslate]\"'")
|
||||
|
||||
import ctranslate2
|
||||
from ctranslate2.converters.transformers import TransformersConverter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_class(llm):
|
||||
return ctranslate2.Translator if llm.config['model_type'] == 'seq2seq_lm' else ctranslate2.Generator
|
||||
|
||||
|
||||
def import_model(llm, *decls, trust_remote_code, **attrs):
|
||||
(_base_decls, _base_attrs), tokenizer_attrs = llm.llm_parameters
|
||||
for it in {'device_map', 'torch_dtype'}:
|
||||
_base_attrs.pop(it, None) # pop out hf-specific attributes
|
||||
decls = (*_base_decls, *decls)
|
||||
attrs = {**_base_attrs, **attrs}
|
||||
low_cpu_mem_usage = attrs.pop('low_cpu_mem_usage', True)
|
||||
logger.debug(
|
||||
'Note that CTranslate2 will load into memory for conversion. Refer to https://opennmt.net/CTranslate2/guides/transformers.html for more information.'
|
||||
)
|
||||
if not llm._local:
|
||||
logger.warning(
|
||||
"It is RECOMMENDED to convert '%s' to CTranslate2 format yourself to utilise CTranslate2's features, then start with `openllm start /path/to/ct2-dir`. OpenLLM will conservely apply quantization for conversion if specified.",
|
||||
llm.model_id,
|
||||
)
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
patch_correct_tag(llm, config)
|
||||
tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
with save_model(llm, config, False, trust_remote_code, 'ctranslate', [importlib.import_module(tokenizer.__module__)]) as save_metadata:
|
||||
bentomodel, _ = save_metadata
|
||||
if llm._local:
|
||||
shutil.copytree(
|
||||
llm.model_id, bentomodel.path, symlinks=False, ignore=shutil.ignore_patterns('.git', 'venv', '__pycache__', '.venv'), dirs_exist_ok=True
|
||||
)
|
||||
else:
|
||||
TransformersConverter(
|
||||
llm.model_id,
|
||||
load_as_float16=llm.quantise in ('float16', 'int8_float16'),
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
trust_remote_code=trust_remote_code,
|
||||
).convert(bentomodel.path, quantization=llm.quantise, force=True)
|
||||
# Save the original HF configuration to hf
|
||||
config.save_pretrained(bentomodel.path_of('/hf/'))
|
||||
tokenizer.save_pretrained(bentomodel.path)
|
||||
return bentomodel
|
||||
|
||||
|
||||
def get(llm):
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
backend = model.info.labels['backend']
|
||||
if backend != llm.__llm_backend__:
|
||||
raise OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
|
||||
patch_correct_tag(
|
||||
llm,
|
||||
transformers.AutoConfig.from_pretrained(model.path_of('/hf/'), trust_remote_code=llm.trust_remote_code),
|
||||
_revision=model.info.metadata.get('_revision'),
|
||||
)
|
||||
return model
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
|
||||
|
||||
|
||||
def load_model(llm, *decls, **attrs):
|
||||
device = 'cuda' if llm._has_gpus else 'cpu'
|
||||
if llm.quantise:
|
||||
compute_type = llm.quantise
|
||||
elif llm.__llm_dtype__ == 'half':
|
||||
compute_type = 'float16'
|
||||
elif llm.__llm_dtype__ == 'float':
|
||||
compute_type = 'float32'
|
||||
else:
|
||||
compute_type = llm.__llm_dtype__
|
||||
return _get_class(llm)(llm.bentomodel.path, device=device, compute_type=compute_type)
|
||||
@@ -5,9 +5,5 @@ def import_model(llm, *decls, trust_remote_code=True, **attrs):
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
|
||||
def get(llm):
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
|
||||
def load_model(llm, *decls, **attrs):
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
@@ -1,84 +1,153 @@
|
||||
from __future__ import annotations
|
||||
import importlib, logging
|
||||
import orjson, torch, transformers, bentoml, openllm
|
||||
|
||||
import functools, importlib, logging, orjson, torch, transformers, openllm, attr
|
||||
from huggingface_hub import snapshot_download
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
from openllm_core.utils import first_not_none, is_autogptq_available, is_flash_attn_2_available
|
||||
from openllm_core.utils import is_autogptq_available, is_flash_attn_2_available
|
||||
|
||||
from ._helpers import get_tokenizer, infer_autoclass_from_llm, process_config
|
||||
from .weights import HfIgnore
|
||||
from .._helpers import patch_correct_tag, save_model
|
||||
from .._helpers import save_model, get_hash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['import_model', 'get', 'load_model']
|
||||
__all__ = ['import_model', 'load_model']
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
TOKENIZER_ATTRS = {'padding_side': 'left', 'truncation_side': 'left'}
|
||||
|
||||
def import_model(llm, *decls, trust_remote_code, **attrs):
|
||||
(_base_decls, _base_attrs), tokenizer_attrs = llm.llm_parameters
|
||||
decls = (*_base_decls, *decls)
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def has_gpus() -> bool:
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
err, _ = cuda.cuDeviceGetCount()
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError('Failed to get CUDA device count.')
|
||||
return True
|
||||
except (ImportError, RuntimeError):
|
||||
return False
|
||||
|
||||
|
||||
_TORCH_DTYPE_MAPPING = {
|
||||
'half': torch.float16,
|
||||
'float16': torch.float16, #
|
||||
'float': torch.float32,
|
||||
'float32': torch.float32, #
|
||||
'bfloat16': torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def _torch_dtype(dtype: str, model_id: str, trust_remote_code: bool) -> 'torch.dtype':
|
||||
hf_config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
||||
config_dtype = getattr(hf_config, 'torch_dtype', None)
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
if dtype == 'auto':
|
||||
if config_dtype == torch.float32:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
else:
|
||||
torch_dtype = _TORCH_DTYPE_MAPPING.get(dtype, None)
|
||||
if torch_dtype is None:
|
||||
raise OpenLLMException(f'dtype not yet supported: {dtype}')
|
||||
if not torch.cuda.is_available() and torch_dtype != torch.float32:
|
||||
torch_dtype = torch.float32
|
||||
return torch_dtype
|
||||
|
||||
|
||||
def import_model(
|
||||
*decls,
|
||||
_model_id=None,
|
||||
_bentomodel_tag=None,
|
||||
_backend=None,
|
||||
_local=False,
|
||||
_quantization_config=None,
|
||||
_quantize=None,
|
||||
_dtype='auto',
|
||||
_serialisation='safetensors',
|
||||
trust_remote_code,
|
||||
**attrs,
|
||||
):
|
||||
_base_attrs = {
|
||||
'device_map': 'auto' if has_gpus() else None,
|
||||
'safe_serialization': _serialisation == 'safetensors',
|
||||
'torch_dtype': _torch_dtype(_dtype, _model_id, trust_remote_code),
|
||||
}
|
||||
attrs = {**_base_attrs, **attrs}
|
||||
if llm._local:
|
||||
logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
patch_correct_tag(llm, config)
|
||||
safe_serialisation = first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
|
||||
if llm.quantise != 'gptq':
|
||||
attrs['use_safetensors'] = safe_serialisation
|
||||
|
||||
model = None
|
||||
tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
config, hub_attrs, attrs = process_config(_model_id, trust_remote_code, **attrs)
|
||||
_revision = get_hash(config) if not _local else None
|
||||
if _revision:
|
||||
_bentomodel_tag = attr.evolve(_bentomodel_tag, version=_revision)
|
||||
model, tokenizer = (
|
||||
None,
|
||||
get_tokenizer(_model_id, trust_remote_code=trust_remote_code, **hub_attrs, **TOKENIZER_ATTRS),
|
||||
)
|
||||
with save_model(
|
||||
llm, config, safe_serialisation, trust_remote_code, 'transformers', [importlib.import_module(tokenizer.__module__)]
|
||||
) as save_metadata:
|
||||
bentomodel, imported_modules = save_metadata
|
||||
_bentomodel_tag,
|
||||
config,
|
||||
_serialisation,
|
||||
trust_remote_code,
|
||||
'transformers',
|
||||
[importlib.import_module(tokenizer.__module__)],
|
||||
_model_id,
|
||||
_quantize,
|
||||
_backend,
|
||||
_local,
|
||||
_dtype,
|
||||
) as bentomodel:
|
||||
tokenizer.save_pretrained(bentomodel.path)
|
||||
if llm._quantization_config or (llm.quantise and llm.quantise not in {'squeezellm', 'awq'}):
|
||||
attrs['quantization_config'] = llm.quantization_config
|
||||
if llm.quantise == 'gptq' and llm.__llm_backend__ == 'pt':
|
||||
if _quantization_config or (_quantize and _quantize not in {'squeezellm', 'awq'}):
|
||||
attrs['quantization_config'] = _quantization_config
|
||||
if _quantize == 'gptq' and _backend == 'pt':
|
||||
from optimum.gptq.constants import GPTQ_CONFIG
|
||||
|
||||
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
|
||||
if llm._local: # possible local path
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(
|
||||
llm.model_id, *decls, local_files_only=True, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
|
||||
)
|
||||
if _local:
|
||||
try:
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
_model_id,
|
||||
*decls,
|
||||
local_files_only=True,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
|
||||
_model_id,
|
||||
*decls,
|
||||
local_files_only=True,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f'Failed to load model from {_model_id}') from err
|
||||
# for trust_remote_code to work
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], bentomodel.imported_modules)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=_serialisation == 'safetensors')
|
||||
del model
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# we will clone the all tings into the bentomodel path without loading model into memory
|
||||
snapshot_download(
|
||||
llm.model_id,
|
||||
local_dir=bentomodel.path, #
|
||||
_model_id,
|
||||
local_dir=bentomodel.path,
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=HfIgnore.ignore_patterns(llm), #
|
||||
ignore_patterns=HfIgnore.ignore_patterns(_backend, _model_id),
|
||||
)
|
||||
return bentomodel
|
||||
|
||||
|
||||
def get(llm):
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
backend = model.info.labels['backend']
|
||||
if backend != llm.__llm_backend__:
|
||||
raise OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
|
||||
patch_correct_tag(
|
||||
llm,
|
||||
transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code),
|
||||
_revision=model.info.metadata.get('_revision'),
|
||||
)
|
||||
return model
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
|
||||
|
||||
|
||||
def check_unintialised_params(model):
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0:
|
||||
@@ -124,7 +193,9 @@ def load_model(llm, *decls, **attrs):
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs)
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
|
||||
)
|
||||
else:
|
||||
try:
|
||||
model = auto_class.from_pretrained(
|
||||
@@ -139,7 +210,25 @@ def load_model(llm, *decls, **attrs):
|
||||
except Exception as err:
|
||||
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **attrs
|
||||
llm.bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
device_map=device_map,
|
||||
**attrs,
|
||||
)
|
||||
check_unintialised_params(model)
|
||||
|
||||
# If OOM, then it is probably you don't have enough VRAM to run this model.
|
||||
loaded_in_kbit = (
|
||||
getattr(model, 'is_loaded_in_8bit', False)
|
||||
or getattr(model, 'is_loaded_in_4bit', False)
|
||||
or getattr(model, 'is_quantized', False)
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit:
|
||||
try:
|
||||
model = model.to('cuda')
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f'Failed to load model into GPU: {err}.\n') from err
|
||||
|
||||
return model
|
||||
|
||||
@@ -6,7 +6,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_tokenizer(model_id_or_path, trust_remote_code, **attrs):
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code, **attrs)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
model_id_or_path, trust_remote_code=trust_remote_code, **attrs
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
|
||||
@@ -6,7 +6,6 @@ from openllm_core.utils import resolve_filepath, validate_is_path
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from huggingface_hub.hf_api import ModelInfo as HfModelInfo
|
||||
import openllm
|
||||
|
||||
__global_inst__ = None
|
||||
__cached_id__: dict[str, HfModelInfo] = dict()
|
||||
@@ -49,16 +48,16 @@ class HfIgnore:
|
||||
gguf = '*.gguf'
|
||||
|
||||
@classmethod
|
||||
def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
|
||||
if llm.__llm_backend__ in {'vllm', 'pt'}:
|
||||
def ignore_patterns(cls, backend, model_id) -> list[str]:
|
||||
if backend in {'vllm', 'pt'}:
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if has_safetensors_weights(llm.model_id):
|
||||
if has_safetensors_weights(model_id):
|
||||
base.extend([cls.pt, '*.pt'])
|
||||
elif has_pt_weights(llm.model_id):
|
||||
elif has_pt_weights(model_id):
|
||||
base.extend([cls.safetensors, cls.pt])
|
||||
else:
|
||||
base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'ggml':
|
||||
elif backend == 'ggml':
|
||||
base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
|
||||
else:
|
||||
raise ValueError('Unknown backend (should never happen at all.)')
|
||||
|
||||
42
openllm-python/src/openllm/serialisation/vllm/__init__.py
Normal file
42
openllm-python/src/openllm/serialisation/vllm/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import openllm, traceback
|
||||
from openllm_core.utils import is_vllm_available
|
||||
from ..transformers import import_model
|
||||
|
||||
__all__ = ['import_model', 'load_model']
|
||||
|
||||
|
||||
def load_model(llm, *decls, **attrs):
|
||||
if not is_vllm_available():
|
||||
raise RuntimeError(
|
||||
"'vllm' is required to use with backend 'vllm'. Install it with 'pip install \"openllm[vllm]\"'"
|
||||
)
|
||||
import vllm, torch
|
||||
|
||||
num_gpus, dev = 1, openllm.utils.device_count()
|
||||
if dev >= 2:
|
||||
num_gpus = min(dev // 2 * 2, dev)
|
||||
quantise = llm.quantise if llm.quantise and llm.quantise in {'gptq', 'awq', 'squeezellm'} else None
|
||||
dtype = (
|
||||
torch.float16 if quantise == 'gptq' else llm._torch_dtype
|
||||
) # NOTE: quantise GPTQ doesn't support bfloat16 yet.
|
||||
try:
|
||||
return vllm.AsyncLLMEngine.from_engine_args(
|
||||
vllm.AsyncEngineArgs(
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False,
|
||||
tokenizer_mode='auto',
|
||||
tensor_parallel_size=num_gpus,
|
||||
model=llm.bentomodel.path,
|
||||
tokenizer=llm.bentomodel.path,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=llm._max_model_len,
|
||||
gpu_memory_utilization=llm._gpu_memory_utilization,
|
||||
quantization=quantise,
|
||||
)
|
||||
)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f'Failed to initialise vLLMEngine due to the following error:\n{err}'
|
||||
) from err
|
||||
@@ -1,15 +1,12 @@
|
||||
import functools, importlib.metadata, openllm_core
|
||||
|
||||
__all__ = ['generate_labels', 'available_devices', 'device_count']
|
||||
__all__ = ['available_devices', 'device_count', 'generate_labels']
|
||||
|
||||
|
||||
def generate_labels(llm):
|
||||
def generate_labels(serialisation):
|
||||
return {
|
||||
'backend': llm.__llm_backend__,
|
||||
'framework': 'openllm',
|
||||
'model_name': llm.config['model_name'], #
|
||||
'architecture': llm.config['architecture'],
|
||||
'serialisation': llm._serialisation, #
|
||||
'serialisation': serialisation,
|
||||
**{package: importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}},
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from openllm_core.utils import (
|
||||
LazyModule as LazyModule,
|
||||
ReprMixin as ReprMixin,
|
||||
VersionInfo as VersionInfo,
|
||||
correct_closure as correct_closure,
|
||||
analytics as analytics,
|
||||
calc_dir_size as calc_dir_size,
|
||||
check_bool_env as check_bool_env,
|
||||
@@ -34,7 +35,6 @@ from openllm_core.utils import (
|
||||
is_autogptq_available as is_autogptq_available,
|
||||
is_bentoml_available as is_bentoml_available,
|
||||
is_bitsandbytes_available as is_bitsandbytes_available,
|
||||
is_ctranslate_available as is_ctranslate_available,
|
||||
is_flash_attn_2_available as is_flash_attn_2_available,
|
||||
is_grpc_available as is_grpc_available,
|
||||
is_jupyter_available as is_jupyter_available,
|
||||
@@ -49,15 +49,15 @@ from openllm_core.utils import (
|
||||
resolve_filepath as resolve_filepath,
|
||||
resolve_user_filepath as resolve_user_filepath,
|
||||
serde as serde,
|
||||
pkg as pkg,
|
||||
set_debug_mode as set_debug_mode,
|
||||
set_disable_warnings as set_disable_warnings,
|
||||
set_quiet_mode as set_quiet_mode,
|
||||
validate_is_path as validate_is_path,
|
||||
)
|
||||
from openllm_core.utils.serde import converter as converter
|
||||
|
||||
from ._llm import LLM
|
||||
from openllm_core._typing_compat import LiteralSerialisation as _LiteralSerialisation
|
||||
|
||||
def available_devices() -> Tuple[str, ...]: ...
|
||||
def device_count() -> int: ...
|
||||
def generate_labels(llm: LLM[Any, Any]) -> Dict[str, Any]: ...
|
||||
def generate_labels(serialisation: _LiteralSerialisation) -> Dict[str, Any]: ...
|
||||
|
||||
@@ -5,20 +5,18 @@ from bentoml_cli.utils import BentoMLCommandGroup
|
||||
from click import shell_completion as sc
|
||||
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._typing_compat import Concatenate, DictStrAny, LiteralBackend, LiteralSerialisation, ParamSpec, AnyCallable, get_literal_args
|
||||
from openllm_core._typing_compat import (
|
||||
Concatenate,
|
||||
DictStrAny,
|
||||
LiteralBackend,
|
||||
LiteralSerialisation,
|
||||
ParamSpec,
|
||||
AnyCallable,
|
||||
get_literal_args,
|
||||
)
|
||||
from openllm_core.utils import DEBUG, compose, dantic, resolve_user_filepath
|
||||
|
||||
|
||||
class _OpenLLM_GenericInternalConfig(LLMConfig):
|
||||
__config__ = {'name_type': 'lowercase', 'default_id': 'openllm/generic', 'model_ids': ['openllm/generic'], 'architecture': 'PreTrainedModel'}
|
||||
|
||||
class GenerationConfig:
|
||||
top_k: int = 15
|
||||
top_p: float = 0.78
|
||||
temperature: float = 0.75
|
||||
max_new_tokens: int = 128
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec('P')
|
||||
@@ -37,11 +35,20 @@ def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete
|
||||
|
||||
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(inflection.dasherize(it), help='Model') for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
|
||||
return [
|
||||
sc.CompletionItem(inflection.dasherize(it), help='Model')
|
||||
for it in openllm.CONFIG_MAPPING
|
||||
if it.startswith(incomplete)
|
||||
]
|
||||
|
||||
|
||||
def parse_config_options(
|
||||
config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny
|
||||
config: LLMConfig,
|
||||
server_timeout: int,
|
||||
workers_per_resource: float,
|
||||
device: t.Tuple[str, ...] | None,
|
||||
cors: bool,
|
||||
environ: DictStrAny,
|
||||
) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
|
||||
@@ -56,14 +63,21 @@ def parse_config_options(
|
||||
if device:
|
||||
if len(device) > 1:
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}'
|
||||
for idx, dev in enumerate(device)
|
||||
])
|
||||
else:
|
||||
_bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_opts.append(
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]'
|
||||
)
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
|
||||
'api_server.http.cors.enabled=true',
|
||||
'api_server.http.cors.access_control_allow_origins="*"',
|
||||
])
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"'
|
||||
for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
|
||||
])
|
||||
_bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts)
|
||||
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
|
||||
@@ -96,7 +110,7 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
def optimization_decorator(fn: FC, *, factory=click, _eager=True) -> FC | list[AnyCallable]:
|
||||
shared = [
|
||||
dtype_option(factory=factory),
|
||||
model_version_option(factory=factory), #
|
||||
revision_option(factory=factory), #
|
||||
backend_option(factory=factory),
|
||||
quantize_option(factory=factory), #
|
||||
serialisation_option(factory=factory),
|
||||
@@ -106,52 +120,6 @@ def optimization_decorator(fn: FC, *, factory=click, _eager=True) -> FC | list[A
|
||||
return compose(*shared)(fn)
|
||||
|
||||
|
||||
def start_decorator(fn: FC) -> FC:
|
||||
composed = compose(
|
||||
_OpenLLM_GenericInternalConfig.parse,
|
||||
parse_serve_args(),
|
||||
cog.optgroup.group(
|
||||
'LLM Options',
|
||||
help="""The following options are related to running LLM Server as well as optimization options.
|
||||
|
||||
OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
|
||||
The following are either in our roadmap or currently being worked on:
|
||||
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
""",
|
||||
),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
cors_option(factory=cog.optgroup),
|
||||
*optimization_decorator(fn, factory=cog.optgroup, _eager=False),
|
||||
cog.optgroup.option(
|
||||
'--device',
|
||||
type=dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help='Assign GPU devices (if available)',
|
||||
show_envvar=True,
|
||||
),
|
||||
adapter_id_option(factory=cog.optgroup),
|
||||
click.option('--return-process', is_flag=True, default=False, help='Internal use only.', hidden=True),
|
||||
)
|
||||
|
||||
return composed(fn)
|
||||
|
||||
|
||||
def parse_device_callback(_: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
|
||||
if value is None:
|
||||
return value
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == 'all':
|
||||
return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
|
||||
|
||||
# NOTE: A list of bentoml option that is not needed for parsing.
|
||||
# NOTE: User shouldn't set '--working-dir', as OpenLLM will setup this.
|
||||
# NOTE: production is also deprecated
|
||||
@@ -161,13 +129,19 @@ _IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}
|
||||
def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
|
||||
from bentoml_cli.cli import cli
|
||||
|
||||
group = cog.optgroup.group('Start a HTTP server options', help='Related to serving the model [synonymous to `bentoml serve-http`]')
|
||||
group = cog.optgroup.group(
|
||||
'Start a HTTP server options', help='Related to serving the model [synonymous to `bentoml serve-http`]'
|
||||
)
|
||||
|
||||
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
serve_command = cli.commands['serve']
|
||||
# The first variable is the argument bento
|
||||
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
|
||||
serve_options = [p for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
|
||||
serve_options = [
|
||||
p
|
||||
for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS]
|
||||
if p.name not in _IGNORED_OPTIONS
|
||||
]
|
||||
for options in reversed(serve_options):
|
||||
attrs = options.to_info_dict()
|
||||
# we don't need param_type_name, since it should all be options
|
||||
@@ -183,6 +157,48 @@ def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[F
|
||||
return decorator
|
||||
|
||||
|
||||
def start_decorator(fn: FC) -> FC:
|
||||
composed = compose(
|
||||
cog.optgroup.group(
|
||||
'LLM Options',
|
||||
help="""The following options are related to running LLM Server as well as optimization options.
|
||||
|
||||
OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
|
||||
The following are either in our roadmap or currently being worked on:
|
||||
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
""",
|
||||
),
|
||||
*optimization_decorator(fn, factory=cog.optgroup, _eager=False),
|
||||
cog.optgroup.option(
|
||||
'--device',
|
||||
type=dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help='Assign GPU devices (if available)',
|
||||
show_envvar=True,
|
||||
),
|
||||
adapter_id_option(factory=cog.optgroup),
|
||||
)
|
||||
|
||||
return composed(fn)
|
||||
|
||||
|
||||
def parse_device_callback(
|
||||
_: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
|
||||
) -> t.Tuple[str, ...] | None:
|
||||
if value is None:
|
||||
return value
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == 'all':
|
||||
return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
|
||||
|
||||
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
|
||||
"""General ``@click`` decorator with some sauce.
|
||||
|
||||
@@ -221,7 +237,13 @@ def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callab
|
||||
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs
|
||||
'--cors/--no-cors',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar='OPENLLM_CORS',
|
||||
show_envvar=True,
|
||||
help='Enable CORS for the server.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
@@ -235,7 +257,7 @@ def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[F
|
||||
type=str,
|
||||
envvar='TORCH_DTYPE',
|
||||
default='auto',
|
||||
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']. For CTranslate2, it also accepts the following ['int8_float32', 'int8_float16', 'int8_bfloat16']",
|
||||
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
@@ -252,12 +274,14 @@ def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable
|
||||
)(f)
|
||||
|
||||
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
def revision_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--revision',
|
||||
'--model-version',
|
||||
'model_version',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
help='Optional model version to save for this model. It will be inferred automatically from model-id.',
|
||||
help='Optional model revision to save for this model. It will be inferred automatically from model-id.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
@@ -275,7 +299,12 @@ def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
|
||||
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
|
||||
return cli_argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]),
|
||||
required=required,
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
@@ -313,36 +342,6 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
)(f)
|
||||
|
||||
|
||||
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--workers-per-resource',
|
||||
default=None,
|
||||
callback=workers_per_resource_callback,
|
||||
type=str,
|
||||
required=False,
|
||||
help="""Number of workers per resource assigned.
|
||||
|
||||
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
|
||||
for more information. By default, this is set to 1.
|
||||
|
||||
> [!NOTE] ``--workers-per-resource`` will also accept the following strategies:
|
||||
|
||||
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
|
||||
|
||||
- ``conserved``: This will determine the number of available GPU resources. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
|
||||
"""
|
||||
+ (
|
||||
"""\n
|
||||
> [!NOTE] The workers value passed into 'build' will determine how the LLM can
|
||||
> be provisioned in Kubernetes as well as in standalone container. This will
|
||||
> ensure it has the same effect with 'openllm start --api-workers ...'"""
|
||||
if build
|
||||
else ''
|
||||
),
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--serialisation',
|
||||
@@ -365,23 +364,3 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
""",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
_wpr_strategies = {'round_robin', 'conserved'}
|
||||
|
||||
|
||||
def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
value = inflection.underscore(value)
|
||||
if value in _wpr_strategies:
|
||||
return value
|
||||
else:
|
||||
try:
|
||||
float(value) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
raise click.BadParameter(
|
||||
f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param
|
||||
) from None
|
||||
else:
|
||||
return value
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import itertools, logging, os, re, subprocess, sys, typing as t
|
||||
import itertools, logging, os, re, subprocess, sys, typing as t, bentoml, openllm_core, orjson
|
||||
from simple_di import Provide, inject
|
||||
import bentoml, openllm_core, orjson
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from openllm_core._typing_compat import LiteralSerialisation
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
@@ -69,7 +68,10 @@ def _start(
|
||||
if timeout:
|
||||
args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
|
||||
args.extend([
|
||||
'--workers-per-resource',
|
||||
str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource,
|
||||
])
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'):
|
||||
args.extend(['--device', ','.join(device)])
|
||||
if quantize:
|
||||
@@ -77,7 +79,11 @@ def _start(
|
||||
if cors:
|
||||
args.append('--cors')
|
||||
if adapter_map:
|
||||
args.extend(list(itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])))
|
||||
args.extend(
|
||||
list(
|
||||
itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])
|
||||
)
|
||||
)
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if __test__:
|
||||
@@ -148,7 +154,9 @@ def _build(
|
||||
'--machine',
|
||||
'--quiet',
|
||||
'--serialisation',
|
||||
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'),
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
]
|
||||
if quantize:
|
||||
args.extend(['--quantize', quantize])
|
||||
@@ -265,4 +273,4 @@ start, build, import_model, list_models = (
|
||||
codegen.gen_sdk(_import_model),
|
||||
codegen.gen_sdk(_list_models),
|
||||
)
|
||||
__all__ = ['start', 'build', 'import_model', 'list_models']
|
||||
__all__ = ['build', 'import_model', 'list_models', 'start']
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,7 @@ import click
|
||||
import inflection
|
||||
import orjson
|
||||
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core._typing_compat import DictStrAny, TypedDict
|
||||
from openllm_core.utils import get_debug_mode
|
||||
|
||||
logger = logging.getLogger('openllm')
|
||||
@@ -25,7 +25,14 @@ class Level(enum.IntEnum):
|
||||
|
||||
@property
|
||||
def color(self) -> str | None:
|
||||
return {Level.NOTSET: None, Level.DEBUG: 'cyan', Level.INFO: 'green', Level.WARNING: 'yellow', Level.ERROR: 'red', Level.CRITICAL: 'red'}[self]
|
||||
return {
|
||||
Level.NOTSET: None,
|
||||
Level.DEBUG: 'cyan',
|
||||
Level.INFO: 'green',
|
||||
Level.WARNING: 'yellow',
|
||||
Level.ERROR: 'red',
|
||||
Level.CRITICAL: 'red',
|
||||
}[self]
|
||||
|
||||
@classmethod
|
||||
def from_logging_level(cls, level: int) -> Level:
|
||||
@@ -38,7 +45,7 @@ class Level(enum.IntEnum):
|
||||
}[level]
|
||||
|
||||
|
||||
class JsonLog(t.TypedDict):
|
||||
class JsonLog(TypedDict):
|
||||
log_level: Level
|
||||
content: str
|
||||
|
||||
@@ -75,5 +82,9 @@ def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json:
|
||||
|
||||
|
||||
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
|
||||
CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore}
|
||||
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS', 'log', 'warning', 'error', 'critical', 'debug', 'info', 'Level']
|
||||
CONTEXT_SETTINGS: DictStrAny = {
|
||||
'help_option_names': ['-h', '--help'],
|
||||
'max_content_width': COLUMNS,
|
||||
'token_normalize_func': inflection.underscore,
|
||||
}
|
||||
__all__ = ['COLUMNS', 'CONTEXT_SETTINGS', 'Level', 'critical', 'debug', 'echo', 'error', 'info', 'log', 'warning']
|
||||
|
||||
Reference in New Issue
Block a user