mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-01-28 17:31:57 -05:00
infra: using ruff formatter (#594)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,7 @@ deploy, and monitor any LLMs with ease.
|
||||
* Online Serving with HTTP, gRPC, SSE(coming soon) or custom API
|
||||
* Native integration with BentoML and LangChain for custom LLM apps
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import logging as _logging
|
||||
import os as _os
|
||||
@@ -51,29 +52,41 @@ else:
|
||||
# configuration for bitsandbytes before import
|
||||
_os.environ['BITSANDBYTES_NOWELCOME'] = _os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
# NOTE: The following warnings from bitsandbytes, and probably not that important for users to see when DEBUG is False
|
||||
_warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
|
||||
_warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
|
||||
_warnings.filterwarnings(
|
||||
'ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization'
|
||||
)
|
||||
_warnings.filterwarnings(
|
||||
'ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization'
|
||||
)
|
||||
_warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
|
||||
# NOTE: ignore the following warning from ghapi as it is not important for users
|
||||
_warnings.filterwarnings('ignore', message='Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated')
|
||||
_warnings.filterwarnings(
|
||||
'ignore', message='Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated'
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
'exceptions': [],
|
||||
'client': [],
|
||||
'bundle': [],
|
||||
'playground': [],
|
||||
'testing': [],
|
||||
'prompts': ['PromptTemplate'],
|
||||
'protocol': [],
|
||||
'utils': [],
|
||||
'_deprecated': ['Runner'],
|
||||
'_strategies': ['CascadingResourceStrategy', 'get_resource'],
|
||||
'entrypoints': ['mount_entrypoints'],
|
||||
'serialisation': ['ggml', 'transformers'],
|
||||
'cli._sdk': ['start', 'start_grpc', 'build', 'import_model', 'list_models'],
|
||||
'_quantisation': ['infer_quantisation_config'],
|
||||
'_llm': ['LLM', 'LLMRunner', 'LLMRunnable'],
|
||||
'_generation': ['StopSequenceCriteria', 'StopOnTokens', 'LogitsProcessorList', 'StoppingCriteriaList', 'prepare_logits_processor'],
|
||||
'exceptions': [],
|
||||
'client': [],
|
||||
'bundle': [],
|
||||
'playground': [],
|
||||
'testing': [],
|
||||
'prompts': ['PromptTemplate'],
|
||||
'protocol': [],
|
||||
'utils': [],
|
||||
'_deprecated': ['Runner'],
|
||||
'_strategies': ['CascadingResourceStrategy', 'get_resource'],
|
||||
'entrypoints': ['mount_entrypoints'],
|
||||
'serialisation': ['ggml', 'transformers'],
|
||||
'cli._sdk': ['start', 'start_grpc', 'build', 'import_model', 'list_models'],
|
||||
'_quantisation': ['infer_quantisation_config'],
|
||||
'_llm': ['LLM', 'LLMRunner', 'LLMRunnable'],
|
||||
'_generation': [
|
||||
'StopSequenceCriteria',
|
||||
'StopOnTokens',
|
||||
'LogitsProcessorList',
|
||||
'StoppingCriteriaList',
|
||||
'prepare_logits_processor',
|
||||
],
|
||||
}
|
||||
COMPILED = _Path(__file__).suffix in ('.pyd', '.so')
|
||||
|
||||
@@ -109,7 +122,9 @@ if _t.TYPE_CHECKING:
|
||||
from .serialisation import transformers as transformers
|
||||
|
||||
# NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
|
||||
__lazy = openllm_core.utils.LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'COMPILED': COMPILED})
|
||||
__lazy = openllm_core.utils.LazyModule(
|
||||
__name__, globals()['__file__'], _import_structure, extra_objects={'COMPILED': COMPILED}
|
||||
)
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
@@ -6,8 +6,11 @@ Usage:
|
||||
To start any OpenLLM model:
|
||||
openllm start <model_name> --options ...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from openllm.cli.entrypoint import cli
|
||||
|
||||
cli()
|
||||
|
||||
@@ -9,27 +9,33 @@ from openllm_core._typing_compat import LiteralBackend
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core import LLMConfig
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
|
||||
from ._llm import LLMRunner
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def _mark_deprecated(fn: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
_object_setattr(fn, '__deprecated__', True)
|
||||
return fn
|
||||
|
||||
|
||||
@_mark_deprecated
|
||||
def Runner(model_name: str,
|
||||
ensure_available: bool = True,
|
||||
init_local: bool = False,
|
||||
backend: LiteralBackend | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
**attrs: t.Any) -> LLMRunner[t.Any, t.Any]:
|
||||
'''Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'.
|
||||
def Runner(
|
||||
model_name: str,
|
||||
ensure_available: bool = True,
|
||||
init_local: bool = False,
|
||||
backend: LiteralBackend | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> LLMRunner[t.Any, t.Any]:
|
||||
"""Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'.
|
||||
|
||||
> [!WARNING]
|
||||
> This method is now deprecated and in favor of 'openllm.LLM.runner'
|
||||
@@ -54,11 +60,13 @@ def Runner(model_name: str,
|
||||
llm_config: Optional ``openllm.LLMConfig`` to initialise this ``openllm.LLMRunner``.
|
||||
init_local: If True, it will initialize the model locally. This is useful if you want to run the model locally. (Symmetrical to bentoml.Runner.init_local())
|
||||
**attrs: The rest of kwargs will then be passed to the LLM. Refer to the LLM documentation for the kwargs behaviour
|
||||
'''
|
||||
"""
|
||||
from ._llm import LLM
|
||||
if llm_config is None: llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
|
||||
if llm_config is None:
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
model_id = attrs.get('model_id', default=os.getenv('OPENLLM_MODEL_ID', llm_config['default_id']))
|
||||
_RUNNER_MSG = f'''\
|
||||
_RUNNER_MSG = f"""\
|
||||
Using 'openllm.Runner' is now deprecated. Make sure to switch to the following syntax:
|
||||
|
||||
```python
|
||||
@@ -70,24 +78,31 @@ def Runner(model_name: str,
|
||||
async def chat(input: str) -> str:
|
||||
async for it in llm.generate_iterator(input): print(it)
|
||||
```
|
||||
'''
|
||||
"""
|
||||
warnings.warn(_RUNNER_MSG, DeprecationWarning, stacklevel=2)
|
||||
attrs.update({
|
||||
attrs.update(
|
||||
{
|
||||
'model_id': model_id,
|
||||
'quantize': os.getenv('OPENLLM_QUANTIZE', attrs.get('quantize', None)),
|
||||
'serialisation': first_not_none(attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']),
|
||||
'serialisation': first_not_none(
|
||||
attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']
|
||||
),
|
||||
'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
|
||||
'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'))
|
||||
llm = LLM[t.Any, t.Any](backend=backend, llm_config=llm_config, **attrs)
|
||||
if init_local: llm.runner.init_local(quiet=True)
|
||||
if init_local:
|
||||
llm.runner.init_local(quiet=True)
|
||||
return llm.runner
|
||||
|
||||
|
||||
_DEPRECATED = {k: v for k, v in locals().items() if getattr(v, '__deprecated__', False)}
|
||||
|
||||
__all__ = list(_DEPRECATED)
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(_DEPRECATED.keys())
|
||||
|
||||
@@ -4,6 +4,7 @@ import typing as t
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
@@ -13,18 +14,30 @@ if t.TYPE_CHECKING:
|
||||
LogitsProcessorList = transformers.LogitsProcessorList
|
||||
StoppingCriteriaList = transformers.StoppingCriteriaList
|
||||
|
||||
|
||||
class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerBase | transformers.PreTrainedTokenizerFast):
|
||||
if isinstance(stop_sequences, str): stop_sequences = [stop_sequences]
|
||||
def __init__(
|
||||
self,
|
||||
stop_sequences: str | list[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer
|
||||
| transformers.PreTrainedTokenizerBase
|
||||
| transformers.PreTrainedTokenizerFast,
|
||||
):
|
||||
if isinstance(stop_sequences, str):
|
||||
stop_sequences = [stop_sequences]
|
||||
self.stop_sequences, self.tokenizer = stop_sequences, tokenizer
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: t.Any, **_: t.Any) -> bool:
|
||||
return any(self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences)
|
||||
return any(
|
||||
self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(transformers.StoppingCriteria):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **_: t.Any) -> bool:
|
||||
return input_ids[0][-1] in {50278, 50279, 50277, 1, 0}
|
||||
|
||||
|
||||
def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsProcessorList:
|
||||
generation_config = config.generation_config
|
||||
logits_processor = transformers.LogitsProcessorList()
|
||||
@@ -34,24 +47,31 @@ def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsPr
|
||||
logits_processor.append(transformers.RepetitionPenaltyLogitsProcessor(generation_config['repetition_penalty']))
|
||||
if 1e-8 <= generation_config['top_p']:
|
||||
logits_processor.append(transformers.TopPLogitsWarper(generation_config['top_p']))
|
||||
if generation_config['top_k'] > 0: logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k']))
|
||||
if generation_config['top_k'] > 0:
|
||||
logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k']))
|
||||
return logits_processor
|
||||
|
||||
|
||||
# NOTE: The ordering here is important. Some models have two of these and we have a preference for which value gets used.
|
||||
SEQLEN_KEYS = ['max_sequence_length', 'seq_length', 'max_position_embeddings', 'max_seq_len', 'model_max_length']
|
||||
|
||||
|
||||
def get_context_length(config: transformers.PretrainedConfig) -> int:
|
||||
rope_scaling = getattr(config, 'rope_scaling', None)
|
||||
rope_scaling_factor = config.rope_scaling['factor'] if rope_scaling else 1.0
|
||||
for key in SEQLEN_KEYS:
|
||||
if getattr(config, key, None) is not None: return int(rope_scaling_factor * getattr(config, key))
|
||||
if getattr(config, key, None) is not None:
|
||||
return int(rope_scaling_factor * getattr(config, key))
|
||||
return 2048
|
||||
|
||||
|
||||
def is_sentence_complete(output: str) -> bool:
|
||||
return output.endswith(('.', '?', '!', '...', '。', '?', '!', '…', '"', "'", '”'))
|
||||
|
||||
|
||||
def is_partial_stop(output: str, stop_str: str) -> bool:
|
||||
"""Check whether the output contains a partial stop str."""
|
||||
for i in range(0, min(len(output), len(stop_str))):
|
||||
if stop_str.startswith(output[-i:]): return True
|
||||
if stop_str.startswith(output[-i:]):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -54,6 +54,7 @@ from .exceptions import ForbiddenAttributeError
|
||||
from .exceptions import OpenLLMException
|
||||
from .serialisation.constants import PEFT_CONFIG_NAME
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
import torch
|
||||
@@ -77,11 +78,14 @@ P = ParamSpec('P')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def normalise_model_name(name: str) -> str:
|
||||
if validate_is_path(name): return os.path.basename(resolve_filepath(name))
|
||||
if validate_is_path(name):
|
||||
return os.path.basename(resolve_filepath(name))
|
||||
name = name.replace('/', '--')
|
||||
return inflection.dasherize(name)
|
||||
|
||||
|
||||
def resolve_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
"""Resolve the type of the PeftConfig given the adapter_map.
|
||||
|
||||
@@ -93,7 +97,8 @@ def resolve_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
resolved: AdapterMap = {}
|
||||
_has_set_default = False
|
||||
for path_or_adapter_id, name in adapter_map.items():
|
||||
if name is None: raise ValueError('Adapter name must be specified.')
|
||||
if name is None:
|
||||
raise ValueError('Adapter name must be specified.')
|
||||
if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)):
|
||||
config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
else:
|
||||
@@ -105,13 +110,16 @@ def resolve_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
resolved_config = orjson.loads(file.read())
|
||||
# all peft_type should be available in PEFT_CONFIG_NAME
|
||||
_peft_type: AdapterType = resolved_config['peft_type'].lower()
|
||||
if _peft_type not in resolved: resolved[_peft_type] = ()
|
||||
if _peft_type not in resolved:
|
||||
resolved[_peft_type] = ()
|
||||
resolved[_peft_type] += (_AdapterTuple((path_or_adapter_id, name, resolved_config)),)
|
||||
return resolved
|
||||
|
||||
|
||||
_reserved_namespace = {'model', 'tokenizer', 'runner', 'import_kwargs'}
|
||||
_AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple', ['adapter_id', 'name', 'config'])
|
||||
|
||||
|
||||
@attr.define(slots=True, repr=False, init=False)
|
||||
class LLM(t.Generic[M, T]):
|
||||
_model_id: str
|
||||
@@ -140,30 +148,44 @@ class LLM(t.Generic[M, T]):
|
||||
device: 'torch.device | None' = None
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
if self.__llm_backend__ == 'pt': self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
if self.__llm_backend__ == 'pt':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def __init__(self,
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
model_tag: str | bentoml.Tag | None = None,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
*args: t.Any,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
serialisation: LiteralSerialisation = 'safetensors',
|
||||
trust_remote_code: bool = False,
|
||||
**attrs: t.Any):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
model_tag: str | bentoml.Tag | None = None,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
*args: t.Any,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
quantization_config: transformers.BitsAndBytesConfig
|
||||
| transformers.GPTQConfig
|
||||
| transformers.AwqConfig
|
||||
| None = None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
serialisation: LiteralSerialisation = 'safetensors',
|
||||
trust_remote_code: bool = False,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
# low_cpu_mem_usage is only available for model this is helpful on system with low memory to avoid OOM
|
||||
low_cpu_mem_usage = attrs.pop('low_cpu_mem_usage', True)
|
||||
_local = False
|
||||
if validate_is_path(model_id): model_id, _local = resolve_filepath(model_id), True
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'))
|
||||
if validate_is_path(model_id):
|
||||
model_id, _local = resolve_filepath(model_id), True
|
||||
backend = t.cast(
|
||||
LiteralBackend,
|
||||
first_not_none(
|
||||
backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'
|
||||
),
|
||||
)
|
||||
|
||||
quantize = first_not_none(quantize, t.cast(t.Optional[LiteralQuantise], os.getenv('OPENLLM_QUANTIZE')), default=None)
|
||||
quantize = first_not_none(
|
||||
quantize, t.cast(t.Optional[LiteralQuantise], os.getenv('OPENLLM_QUANTIZE')), default=None
|
||||
)
|
||||
# elif quantization_config is None and quantize is not None:
|
||||
# quantization_config, attrs = infer_quantisation_config(self, quantize, **attrs)
|
||||
attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
|
||||
@@ -171,28 +193,35 @@ class LLM(t.Generic[M, T]):
|
||||
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
|
||||
model_attrs, tokenizer_attrs = flatten_attrs(**attrs)
|
||||
|
||||
if adapter_map is not None and not is_peft_available(): raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'")
|
||||
if isinstance(prompt_template, str): prompt_template = PromptTemplate(prompt_template)
|
||||
if adapter_map is not None and not is_peft_available():
|
||||
raise RuntimeError(
|
||||
"LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
if isinstance(prompt_template, str):
|
||||
prompt_template = PromptTemplate(prompt_template)
|
||||
if model_tag is None:
|
||||
model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend)
|
||||
if model_version: model_tag = f'{model_tag}:{model_version}'
|
||||
if model_version:
|
||||
model_tag = f'{model_tag}:{model_version}'
|
||||
|
||||
self.__attrs_init__(model_id=model_id,
|
||||
revision=model_version,
|
||||
tag=bentoml.Tag.from_taglike(t.cast(t.Union[str, bentoml.Tag], model_tag)),
|
||||
quantization_config=quantization_config,
|
||||
quantise=quantize,
|
||||
model_decls=args,
|
||||
model_attrs=dict(**self.import_kwargs[0], **model_attrs),
|
||||
tokenizer_attrs=dict(**self.import_kwargs[-1], **tokenizer_attrs),
|
||||
adapter_map=resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
|
||||
serialisation=serialisation,
|
||||
local=_local,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_backend__=backend,
|
||||
llm_config__=llm_config,
|
||||
llm_trust_remote_code__=trust_remote_code)
|
||||
self.__attrs_init__(
|
||||
model_id=model_id,
|
||||
revision=model_version,
|
||||
tag=bentoml.Tag.from_taglike(t.cast(t.Union[str, bentoml.Tag], model_tag)),
|
||||
quantization_config=quantization_config,
|
||||
quantise=quantize,
|
||||
model_decls=args,
|
||||
model_attrs=dict(**self.import_kwargs[0], **model_attrs),
|
||||
tokenizer_attrs=dict(**self.import_kwargs[-1], **tokenizer_attrs),
|
||||
adapter_map=resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
|
||||
serialisation=serialisation,
|
||||
local=_local,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_backend__=backend,
|
||||
llm_config__=llm_config,
|
||||
llm_trust_remote_code__=trust_remote_code,
|
||||
)
|
||||
|
||||
try:
|
||||
model = bentoml.models.get(self.tag)
|
||||
@@ -202,13 +231,24 @@ class LLM(t.Generic[M, T]):
|
||||
self._tag = model.tag
|
||||
|
||||
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(self, model_id: str, model_version: str | None, backend: LiteralBackend) -> tuple[str, str | None]:
|
||||
def _make_tag_components(
|
||||
self, model_id: str, model_version: str | None, backend: LiteralBackend
|
||||
) -> tuple[str, str | None]:
|
||||
"""Return a valid tag name (<backend>-<repo>--<model_id>) and its tag version."""
|
||||
model_id, *maybe_revision = model_id.rsplit(':')
|
||||
if len(maybe_revision) > 0:
|
||||
if model_version is not None: logger.warning("revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.", maybe_revision[0], model_version)
|
||||
if model_version is not None:
|
||||
logger.warning(
|
||||
"revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.",
|
||||
maybe_revision[0],
|
||||
model_version,
|
||||
)
|
||||
model_version = maybe_revision[0]
|
||||
if validate_is_path(model_id): model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id))
|
||||
if validate_is_path(model_id):
|
||||
model_id, model_version = (
|
||||
resolve_filepath(model_id),
|
||||
first_not_none(model_version, default=generate_hash_from_file(model_id)),
|
||||
)
|
||||
return f'{backend}-{normalise_model_name(model_id)}', model_version
|
||||
|
||||
# yapf: disable
|
||||
@@ -257,28 +297,44 @@ class LLM(t.Generic[M, T]):
|
||||
try:
|
||||
import peft as _ # noqa: F401
|
||||
except ImportError as err:
|
||||
raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'") from err
|
||||
if not self.has_adapters: raise AttributeError('Adapter map is not available.')
|
||||
raise MissingDependencyError(
|
||||
"Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'"
|
||||
) from err
|
||||
if not self.has_adapters:
|
||||
raise AttributeError('Adapter map is not available.')
|
||||
assert self._adapter_map is not None
|
||||
if self.__llm_adapter_map__ is None:
|
||||
_map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
|
||||
for adapter_type, adapter_tuple in self._adapter_map.items():
|
||||
base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type))
|
||||
base = first_not_none(
|
||||
self.config['fine_tune_strategies'].get(adapter_type),
|
||||
default=self.config.make_fine_tune_config(adapter_type),
|
||||
)
|
||||
for adapter in adapter_tuple:
|
||||
_map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
|
||||
self.__llm_adapter_map__ = _map
|
||||
return self.__llm_adapter_map__
|
||||
|
||||
def prepare_for_training(self,
|
||||
adapter_type: AdapterType = 'lora',
|
||||
use_gradient_checking: bool = True,
|
||||
**attrs: t.Any) -> tuple[peft.PeftModel | peft.PeftModelForCausalLM | peft.PeftModelForSeq2SeqLM, T]:
|
||||
def prepare_for_training(
|
||||
self, adapter_type: AdapterType = 'lora', use_gradient_checking: bool = True, **attrs: t.Any
|
||||
) -> tuple[peft.PeftModel | peft.PeftModelForCausalLM | peft.PeftModelForSeq2SeqLM, T]:
|
||||
from peft import get_peft_model
|
||||
from peft import prepare_model_for_kbit_training
|
||||
peft_config = self.config['fine_tune_strategies'].get(adapter_type, self.config.make_fine_tune_config(adapter_type)).train().with_config(**attrs).build()
|
||||
if self.has_adapters: raise ValueError('Adapter should not be specified when fine-tuning.')
|
||||
model = get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking), peft_config) # type: ignore[no-untyped-call]
|
||||
if DEBUG: model.print_trainable_parameters() # type: ignore[no-untyped-call]
|
||||
|
||||
peft_config = (
|
||||
self.config['fine_tune_strategies']
|
||||
.get(adapter_type, self.config.make_fine_tune_config(adapter_type))
|
||||
.train()
|
||||
.with_config(**attrs)
|
||||
.build()
|
||||
)
|
||||
if self.has_adapters:
|
||||
raise ValueError('Adapter should not be specified when fine-tuning.')
|
||||
model = get_peft_model(
|
||||
prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking), peft_config
|
||||
) # type: ignore[no-untyped-call]
|
||||
if DEBUG:
|
||||
model.print_trainable_parameters() # type: ignore[no-untyped-call]
|
||||
return model, self.tokenizer
|
||||
|
||||
@property
|
||||
@@ -288,13 +344,22 @@ class LLM(t.Generic[M, T]):
|
||||
# If OOM, then it is probably you don't have enough VRAM to run this model.
|
||||
if self.__llm_backend__ == 'pt':
|
||||
if is_torch_available():
|
||||
loaded_in_kbit = getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit and not isinstance(model, transformers.Pipeline):
|
||||
loaded_in_kbit = (
|
||||
getattr(model, 'is_loaded_in_8bit', False)
|
||||
or getattr(model, 'is_loaded_in_4bit', False)
|
||||
or getattr(model, 'is_quantized', False)
|
||||
)
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.cuda.device_count() == 1
|
||||
and not loaded_in_kbit
|
||||
and not isinstance(model, transformers.Pipeline)
|
||||
):
|
||||
try:
|
||||
model = model.to('cuda')
|
||||
except Exception as err:
|
||||
raise OpenLLMException(
|
||||
f'Failed to load {self} into GPU: {err}\nTip: If you run into OOM issue, maybe try different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.'
|
||||
f'Failed to load {self} into GPU: {err}\nTip: If you run into OOM issue, maybe try different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.'
|
||||
) from err
|
||||
if self.has_adapters:
|
||||
logger.debug('Applying the following adapters: %s', self.adapter_map)
|
||||
@@ -307,83 +372,117 @@ class LLM(t.Generic[M, T]):
|
||||
@property
|
||||
def tokenizer(self) -> T:
|
||||
# NOTE: the signature of load_tokenizer here is the wrapper under _wrapped_load_tokenizer
|
||||
if self.__llm_tokenizer__ is None: self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
if self.__llm_tokenizer__ is None:
|
||||
self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
return self.__llm_tokenizer__
|
||||
|
||||
@property
|
||||
def runner(self) -> LLMRunner[M, T]:
|
||||
if self.__llm_runner__ is None: self.__llm_runner__ = _RunnerFactory(self)
|
||||
if self.__llm_runner__ is None:
|
||||
self.__llm_runner__ = _RunnerFactory(self)
|
||||
return self.__llm_runner__
|
||||
|
||||
async def generate(self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> GenerationOutput:
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> GenerationOutput:
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
texts: list[list[str]] = [[]] * config['n']
|
||||
token_ids: list[list[int]] = [[]] * config['n']
|
||||
final_result: GenerationOutput | None = None
|
||||
async for result in self.generate_iterator(prompt, prompt_token_ids, stop, stop_token_ids, request_id, adapter_name, **config.model_dump(flatten=True)):
|
||||
async for result in self.generate_iterator(
|
||||
prompt, prompt_token_ids, stop, stop_token_ids, request_id, adapter_name, **config.model_dump(flatten=True)
|
||||
):
|
||||
for output in result.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = result
|
||||
if final_result is None: raise RuntimeError('No result is returned.')
|
||||
return final_result.with_options(prompt=prompt, outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
|
||||
if final_result is None:
|
||||
raise RuntimeError('No result is returned.')
|
||||
return final_result.with_options(
|
||||
prompt=prompt,
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
],
|
||||
)
|
||||
|
||||
async def generate_iterator(self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> t.AsyncGenerator[GenerationOutput, None]:
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.AsyncGenerator[GenerationOutput, None]:
|
||||
if isinstance(self.runner._runner_handle, DummyRunnerHandle):
|
||||
if os.getenv('BENTO_PATH') is not None: raise RuntimeError('Runner client failed to set up correctly.')
|
||||
else: self.runner.init_local(quiet=True)
|
||||
if os.getenv('BENTO_PATH') is not None:
|
||||
raise RuntimeError('Runner client failed to set up correctly.')
|
||||
else:
|
||||
self.runner.init_local(quiet=True)
|
||||
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
|
||||
if stop_token_ids is None: stop_token_ids = []
|
||||
if self.tokenizer.eos_token_id not in stop_token_ids: stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
if stop is None: stop = set()
|
||||
elif isinstance(stop, str): stop = {stop}
|
||||
else: stop = set(stop)
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
if self.tokenizer.eos_token_id not in stop_token_ids:
|
||||
stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
if stop is None:
|
||||
stop = set()
|
||||
elif isinstance(stop, str):
|
||||
stop = {stop}
|
||||
else:
|
||||
stop = set(stop)
|
||||
for tid in stop_token_ids:
|
||||
if tid: stop.add(self.tokenizer.decode(tid))
|
||||
if tid:
|
||||
stop.add(self.tokenizer.decode(tid))
|
||||
|
||||
if prompt_token_ids is None:
|
||||
if prompt is None: raise ValueError('Either prompt or prompt_token_ids must be specified.')
|
||||
if prompt is None:
|
||||
raise ValueError('Either prompt or prompt_token_ids must be specified.')
|
||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||
|
||||
if request_id is None: request_id = openllm_core.utils.gen_random_uuid()
|
||||
if request_id is None:
|
||||
request_id = openllm_core.utils.gen_random_uuid()
|
||||
previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
|
||||
async for out in self.runner.generate_iterator.async_stream(prompt_token_ids, request_id, stop, adapter_name, **config.model_dump(flatten=True)):
|
||||
async for out in self.runner.generate_iterator.async_stream(
|
||||
prompt_token_ids, request_id, stop, adapter_name, **config.model_dump(flatten=True)
|
||||
):
|
||||
generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
|
||||
delta_outputs = t.cast(t.List[CompletionChunk], [None] * len(generated.outputs))
|
||||
if generated.finished: break
|
||||
if generated.finished:
|
||||
break
|
||||
for output in generated.outputs:
|
||||
i = output.index
|
||||
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i]:], output.text[len(previous_texts[i]):]
|
||||
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
|
||||
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
|
||||
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
|
||||
yield generated.with_options(outputs=delta_outputs)
|
||||
|
||||
def _RunnerFactory(self: openllm.LLM[M, T],
|
||||
/,
|
||||
models: list[bentoml.Model] | None = None,
|
||||
max_batch_size: int | None = None,
|
||||
max_latency_ms: int | None = None,
|
||||
scheduling_strategy: type[bentoml.Strategy] = CascadingResourceStrategy,
|
||||
*,
|
||||
backend: LiteralBackend | None = None) -> LLMRunner[M, T]:
|
||||
|
||||
def _RunnerFactory(
|
||||
self: openllm.LLM[M, T],
|
||||
/,
|
||||
models: list[bentoml.Model] | None = None,
|
||||
max_batch_size: int | None = None,
|
||||
max_latency_ms: int | None = None,
|
||||
scheduling_strategy: type[bentoml.Strategy] = CascadingResourceStrategy,
|
||||
*,
|
||||
backend: LiteralBackend | None = None,
|
||||
) -> LLMRunner[M, T]:
|
||||
from ._runners import runnable
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, os.environ.get('OPENLLM_BACKEND'), default=self.__llm_backend__))
|
||||
|
||||
backend = t.cast(
|
||||
LiteralBackend, first_not_none(backend, os.environ.get('OPENLLM_BACKEND'), default=self.__llm_backend__)
|
||||
)
|
||||
|
||||
models = models if models is not None else []
|
||||
try:
|
||||
@@ -391,12 +490,18 @@ def _RunnerFactory(self: openllm.LLM[M, T],
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
raise RuntimeError(f'Failed to locate {self.bentomodel}:{err}') from err
|
||||
|
||||
if self._prompt_template: prompt_template = self._prompt_template.to_string()
|
||||
elif hasattr(self.config, 'default_prompt_template'): prompt_template = self.config.default_prompt_template
|
||||
else: prompt_template = None
|
||||
if self._system_message: system_message = self._system_message
|
||||
elif hasattr(self.config, 'default_system_message'): system_message = self.config.default_system_message
|
||||
else: system_message = None
|
||||
if self._prompt_template:
|
||||
prompt_template = self._prompt_template.to_string()
|
||||
elif hasattr(self.config, 'default_prompt_template'):
|
||||
prompt_template = self.config.default_prompt_template
|
||||
else:
|
||||
prompt_template = None
|
||||
if self._system_message:
|
||||
system_message = self._system_message
|
||||
elif hasattr(self.config, 'default_system_message'):
|
||||
system_message = self.config.default_system_message
|
||||
else:
|
||||
system_message = None
|
||||
|
||||
# yapf: disable
|
||||
def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]: return {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}
|
||||
@@ -408,31 +513,39 @@ def _RunnerFactory(self: openllm.LLM[M, T],
|
||||
yield 'llm_tag', self.tag
|
||||
# yapf: enable
|
||||
|
||||
return types.new_class(self.__class__.__name__ + 'Runner', (bentoml.Runner,),
|
||||
exec_body=lambda ns: ns.update({
|
||||
'llm_type': self.llm_type,
|
||||
'identifying_params': self.identifying_params,
|
||||
'llm_tag': self.tag,
|
||||
'llm': self,
|
||||
'config': self.config,
|
||||
'backend': backend,
|
||||
'__module__': self.__module__,
|
||||
'__doc__': getattr(openllm_core.config, f'START_{self.config["model_name"].upper()}_COMMAND_DOCSTRING'),
|
||||
'__repr__': ReprMixin.__repr__,
|
||||
'__repr_keys__': property(_wrapped_repr_keys),
|
||||
'__repr_args__': _wrapped_repr_args,
|
||||
'has_adapters': self.has_adapters,
|
||||
'prompt_template': prompt_template,
|
||||
'system_message': system_message,
|
||||
}))(runnable(backend),
|
||||
name=self.runner_name,
|
||||
embedded=False,
|
||||
models=models,
|
||||
max_batch_size=max_batch_size,
|
||||
max_latency_ms=max_latency_ms,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
runnable_init_params=dict(llm=self),
|
||||
method_configs=converter.unstructure({'generate_iterator': ModelSignature(batchable=False)}))
|
||||
return types.new_class(
|
||||
self.__class__.__name__ + 'Runner',
|
||||
(bentoml.Runner,),
|
||||
exec_body=lambda ns: ns.update(
|
||||
{
|
||||
'llm_type': self.llm_type,
|
||||
'identifying_params': self.identifying_params,
|
||||
'llm_tag': self.tag,
|
||||
'llm': self,
|
||||
'config': self.config,
|
||||
'backend': backend,
|
||||
'__module__': self.__module__,
|
||||
'__doc__': getattr(openllm_core.config, f'START_{self.config["model_name"].upper()}_COMMAND_DOCSTRING'),
|
||||
'__repr__': ReprMixin.__repr__,
|
||||
'__repr_keys__': property(_wrapped_repr_keys),
|
||||
'__repr_args__': _wrapped_repr_args,
|
||||
'has_adapters': self.has_adapters,
|
||||
'prompt_template': prompt_template,
|
||||
'system_message': system_message,
|
||||
}
|
||||
),
|
||||
)(
|
||||
runnable(backend),
|
||||
name=self.runner_name,
|
||||
embedded=False,
|
||||
models=models,
|
||||
max_batch_size=max_batch_size,
|
||||
max_latency_ms=max_latency_ms,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
runnable_init_params=dict(llm=self),
|
||||
method_configs=converter.unstructure({'generate_iterator': ModelSignature(batchable=False)}),
|
||||
)
|
||||
|
||||
|
||||
@t.final
|
||||
class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
|
||||
@@ -440,6 +553,7 @@ class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
generate_iterator: RunnableMethod[LLMRunnable[M, T], [list[int], str, str | t.Iterable[str] | None, str | None], str]
|
||||
|
||||
|
||||
@t.final
|
||||
class LLMRunner(t.Protocol[M, T]):
|
||||
__doc__: str
|
||||
@@ -461,22 +575,23 @@ class LLMRunner(t.Protocol[M, T]):
|
||||
runnable_init_params: dict[str, t.Any]
|
||||
_runner_handle: RunnerHandle
|
||||
|
||||
def __init__(self,
|
||||
runnable_class: type[LLMRunnable[M, T]],
|
||||
*,
|
||||
runnable_init_params: dict[str, t.Any] | None = ...,
|
||||
name: str | None = ...,
|
||||
scheduling_strategy: type[Strategy] = ...,
|
||||
models: list[bentoml.Model] | None = ...,
|
||||
max_batch_size: int | None = ...,
|
||||
max_latency_ms: int | None = ...,
|
||||
method_configs: dict[str, dict[str, int]] | None = ...,
|
||||
embedded: bool = False) -> None:
|
||||
...
|
||||
def __init__(
|
||||
self,
|
||||
runnable_class: type[LLMRunnable[M, T]],
|
||||
*,
|
||||
runnable_init_params: dict[str, t.Any] | None = ...,
|
||||
name: str | None = ...,
|
||||
scheduling_strategy: type[Strategy] = ...,
|
||||
models: list[bentoml.Model] | None = ...,
|
||||
max_batch_size: int | None = ...,
|
||||
max_latency_ms: int | None = ...,
|
||||
method_configs: dict[str, dict[str, int]] | None = ...,
|
||||
embedded: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
...
|
||||
def __repr_keys__(self) -> set[str]: ...
|
||||
|
||||
|
||||
__all__ = ['LLMRunner', 'LLMRunnable', 'LLM']
|
||||
|
||||
@@ -14,6 +14,7 @@ from openllm_core.utils import is_autogptq_available
|
||||
from openllm_core.utils import is_bitsandbytes_available
|
||||
from openllm_core.utils import is_optimum_supports_gptq
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
@@ -21,20 +22,28 @@ if t.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['int8', 'int4'], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['gptq'], **attrs: t.Any) -> tuple[transformers.GPTQConfig, DictStrAny]:
|
||||
...
|
||||
def infer_quantisation_config(
|
||||
self: LLM[t.Any, t.Any], quantise: t.Literal['int8', 'int4'], **attrs: t.Any
|
||||
) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['awq'], **attrs: t.Any) -> tuple[transformers.AwqConfig, DictStrAny]:
|
||||
...
|
||||
def infer_quantisation_config(
|
||||
self: LLM[t.Any, t.Any], quantise: t.Literal['gptq'], **attrs: t.Any
|
||||
) -> tuple[transformers.GPTQConfig, DictStrAny]: ...
|
||||
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise,
|
||||
**attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig, DictStrAny]:
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(
|
||||
self: LLM[t.Any, t.Any], quantise: t.Literal['awq'], **attrs: t.Any
|
||||
) -> tuple[transformers.AwqConfig, DictStrAny]: ...
|
||||
|
||||
|
||||
def infer_quantisation_config(
|
||||
self: LLM[t.Any, t.Any], quantise: LiteralQuantise, **attrs: t.Any
|
||||
) -> tuple[transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig, DictStrAny]:
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop('llm_int8_threshhold', 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop('llm_int8_enable_fp32_cpu_offload', False)
|
||||
@@ -64,34 +73,39 @@ def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise
|
||||
gptq_pad_token_id = attrs.pop('pad_token_id', None)
|
||||
disable_exllama = attrs.pop('disable_exllama', False) # backward compatibility
|
||||
gptq_use_exllama = attrs.pop('use_exllama', True)
|
||||
if disable_exllama: gptq_use_exllama = False
|
||||
return transformers.GPTQConfig(bits=bits,
|
||||
tokenizer=gptq_tokenizer,
|
||||
dataset=gptq_dataset,
|
||||
group_size=group_size,
|
||||
damp_percent=gptq_damp_percent,
|
||||
desc_act=gptq_desc_act,
|
||||
sym=gptq_sym,
|
||||
true_sequential=gptq_true_sequential,
|
||||
use_cuda_fp16=gptq_use_cuda_fp16,
|
||||
model_seqlen=gptq_model_seqlen,
|
||||
block_name_to_quantize=gptq_block_name_to_quantize,
|
||||
module_name_preceding_first_block=gptq_module_name_preceding_first_block,
|
||||
batch_size=gptq_batch_size,
|
||||
pad_token_id=gptq_pad_token_id,
|
||||
use_exllama=gptq_use_exllama,
|
||||
exllama_config={'version': 1}) # XXX: See how to migrate to v2
|
||||
if disable_exllama:
|
||||
gptq_use_exllama = False
|
||||
return transformers.GPTQConfig(
|
||||
bits=bits,
|
||||
tokenizer=gptq_tokenizer,
|
||||
dataset=gptq_dataset,
|
||||
group_size=group_size,
|
||||
damp_percent=gptq_damp_percent,
|
||||
desc_act=gptq_desc_act,
|
||||
sym=gptq_sym,
|
||||
true_sequential=gptq_true_sequential,
|
||||
use_cuda_fp16=gptq_use_cuda_fp16,
|
||||
model_seqlen=gptq_model_seqlen,
|
||||
block_name_to_quantize=gptq_block_name_to_quantize,
|
||||
module_name_preceding_first_block=gptq_module_name_preceding_first_block,
|
||||
batch_size=gptq_batch_size,
|
||||
pad_token_id=gptq_pad_token_id,
|
||||
use_exllama=gptq_use_exllama,
|
||||
exllama_config={'version': 1},
|
||||
) # XXX: See how to migrate to v2
|
||||
|
||||
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
|
||||
# if int8_skip_modules is None: int8_skip_modules = []
|
||||
# if 'lm_head' not in int8_skip_modules and self.config_class.__openllm_model_type__ == 'causal_lm':
|
||||
# logger.debug("Skipping 'lm_head' for quantization for %s", self.__name__)
|
||||
# int8_skip_modules.append('lm_head')
|
||||
return transformers.BitsAndBytesConfig(load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight)
|
||||
return transformers.BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight,
|
||||
)
|
||||
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop('bnb_4bit_compute_dtype', torch.bfloat16)
|
||||
@@ -100,22 +114,30 @@ def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise
|
||||
|
||||
# NOTE: Quantization setup quantize is a openllm.LLM feature, where we can quantize the model with bitsandbytes or quantization aware training.
|
||||
if not is_bitsandbytes_available():
|
||||
raise RuntimeError("Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantise == 'int8': quantisation_config = create_int8_config(int8_skip_modules)
|
||||
raise RuntimeError(
|
||||
'Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with \'pip install "openllm[fine-tune]"\''
|
||||
)
|
||||
if quantise == 'int8':
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == 'int4':
|
||||
quantisation_config = transformers.BitsAndBytesConfig(load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=int4_compute_dtype,
|
||||
bnb_4bit_quant_type=int4_quant_type,
|
||||
bnb_4bit_use_double_quant=int4_use_double_quant)
|
||||
quantisation_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=int4_compute_dtype,
|
||||
bnb_4bit_quant_type=int4_quant_type,
|
||||
bnb_4bit_use_double_quant=int4_use_double_quant,
|
||||
)
|
||||
elif quantise == 'gptq':
|
||||
if not is_autogptq_available() or not is_optimum_supports_gptq():
|
||||
raise MissingDependencyError(
|
||||
"'quantize=\"gptq\"' requires 'auto-gptq' and 'optimum>=0.12' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[gptq]\"'")
|
||||
"'quantize=\"gptq\"' requires 'auto-gptq' and 'optimum>=0.12' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[gptq]\"'"
|
||||
)
|
||||
else:
|
||||
quantisation_config = create_gptq_config()
|
||||
elif quantise == 'awq':
|
||||
if not is_autoawq_available():
|
||||
raise MissingDependencyError("quantize='awq' requires 'auto-awq' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[awq]\"'.")
|
||||
raise MissingDependencyError(
|
||||
"quantize='awq' requires 'auto-awq' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[awq]\"'."
|
||||
)
|
||||
else:
|
||||
quantisation_config = create_awq_config()
|
||||
else:
|
||||
|
||||
@@ -19,6 +19,7 @@ from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import vllm
|
||||
|
||||
@@ -30,10 +31,15 @@ _DEFAULT_TOKENIZER = 'hf-internal-testing/llama-tokenizer'
|
||||
|
||||
__all__ = ['runnable']
|
||||
|
||||
|
||||
def runnable(backend: LiteralBackend | None = None) -> type[bentoml.Runnable]:
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if is_vllm_available() else 'pt'))
|
||||
backend = t.cast(
|
||||
LiteralBackend,
|
||||
first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if is_vllm_available() else 'pt'),
|
||||
)
|
||||
return vLLMRunnable if backend == 'vllm' else PyTorchRunnable
|
||||
|
||||
|
||||
class vLLMRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
@@ -41,47 +47,62 @@ class vLLMRunnable(bentoml.Runnable):
|
||||
def __init__(self, llm: openllm.LLM[M, T]) -> None:
|
||||
self.config = llm.config
|
||||
num_gpus, dev = 1, openllm.utils.device_count()
|
||||
if dev >= 2: num_gpus = min(dev // 2 * 2, dev)
|
||||
if dev >= 2:
|
||||
num_gpus = min(dev // 2 * 2, dev)
|
||||
quantization = None
|
||||
if llm._quantise and llm._quantise in {'awq', 'squeezellm'}: quantization = llm._quantise
|
||||
if llm._quantise and llm._quantise in {'awq', 'squeezellm'}:
|
||||
quantization = llm._quantise
|
||||
try:
|
||||
self.model = vllm.AsyncLLMEngine.from_engine_args(
|
||||
vllm.AsyncEngineArgs(model=llm.bentomodel.path,
|
||||
tokenizer=llm.bentomodel.path,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
tokenizer_mode='auto',
|
||||
tensor_parallel_size=num_gpus,
|
||||
dtype='auto',
|
||||
quantization=quantization,
|
||||
disable_log_requests=not get_debug_mode(),
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False))
|
||||
vllm.AsyncEngineArgs(
|
||||
model=llm.bentomodel.path,
|
||||
tokenizer=llm.bentomodel.path,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
tokenizer_mode='auto',
|
||||
tensor_parallel_size=num_gpus,
|
||||
dtype='auto',
|
||||
quantization=quantization,
|
||||
disable_log_requests=not get_debug_mode(),
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False,
|
||||
)
|
||||
)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None: raise NotImplementedError('Adapter is not supported with vLLM.')
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None:
|
||||
raise NotImplementedError('Adapter is not supported with vLLM.')
|
||||
stop_: set[str] = set()
|
||||
if isinstance(stop, str) and stop != '': stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable): stop_.update(stop)
|
||||
if isinstance(stop, str) and stop != '':
|
||||
stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable):
|
||||
stop_.update(stop)
|
||||
|
||||
temperature = attrs.pop('temperature', self.config['temperature'])
|
||||
top_p = attrs.pop('top_p', self.config['top_p'])
|
||||
if temperature <= 1e-5: top_p = 1.0
|
||||
sampling_params = self.config.model_construct_env(stop=list(stop_), temperature=temperature, top_p=top_p, **attrs).to_sampling_config()
|
||||
if temperature <= 1e-5:
|
||||
top_p = 1.0
|
||||
sampling_params = self.config.model_construct_env(
|
||||
stop=list(stop_), temperature=temperature, top_p=top_p, **attrs
|
||||
).to_sampling_config()
|
||||
|
||||
async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
|
||||
# XXX: Need to write a hook for serialisation None correctly
|
||||
if request_output.prompt_logprobs is not None: request_output.prompt_logprobs = [it if it else {} for it in request_output.prompt_logprobs]
|
||||
if request_output.prompt_logprobs is not None:
|
||||
request_output.prompt_logprobs = [it if it else {} for it in request_output.prompt_logprobs]
|
||||
yield GenerationOutput.from_vllm(request_output).model_dump_json()
|
||||
|
||||
|
||||
class PyTorchRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
@@ -93,23 +114,30 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None: self.model.set_adapter(adapter_name)
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None:
|
||||
self.model.set_adapter(adapter_name)
|
||||
async for generation_output in self.forward(prompt_token_ids, request_id, stop=stop, **attrs):
|
||||
yield generation_output.model_dump_json()
|
||||
|
||||
async def forward(self, prompt_token_ids: list[int], request_id: str, stop: str | t.Iterable[str] | None = None, **attrs: t.Any) -> t.AsyncGenerator[GenerationOutput, None]:
|
||||
async def forward(
|
||||
self, prompt_token_ids: list[int], request_id: str, stop: str | t.Iterable[str] | None = None, **attrs: t.Any
|
||||
) -> t.AsyncGenerator[GenerationOutput, None]:
|
||||
from ._generation import is_partial_stop
|
||||
from ._generation import prepare_logits_processor
|
||||
|
||||
stop_: set[str] = set()
|
||||
if isinstance(stop, str) and stop != '': stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable): stop_.update(stop)
|
||||
if isinstance(stop, str) and stop != '':
|
||||
stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable):
|
||||
stop_.update(stop)
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
|
||||
with torch.inference_mode():
|
||||
@@ -129,7 +157,9 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
if i == 0: # prefill
|
||||
out = self.model(torch.as_tensor([prompt_token_ids], device=self.device), use_cache=True)
|
||||
else: # decoding
|
||||
out = self.model(torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
|
||||
out = self.model(
|
||||
torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
|
||||
@@ -143,7 +173,8 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
last_token_logits = logits[0, -1, :]
|
||||
|
||||
# Switch to CPU by avoiding some bugs in mps backend.
|
||||
if self.device.type == 'mps': last_token_logits = last_token_logits.float().to('cpu')
|
||||
if self.device.type == 'mps':
|
||||
last_token_logits = last_token_logits.float().to('cpu')
|
||||
|
||||
if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: # greedy
|
||||
_, indices = torch.topk(last_token_logits, 2)
|
||||
@@ -160,7 +191,12 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
|
||||
tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
|
||||
# XXX: Move this to API server
|
||||
text = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)
|
||||
text = self.tokenizer.decode(
|
||||
tmp_output_ids,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
partially_stopped = False
|
||||
if stop_:
|
||||
for it in stop_:
|
||||
@@ -170,21 +206,41 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
break
|
||||
else:
|
||||
partially_stopped = is_partial_stop(text, it)
|
||||
if partially_stopped: break
|
||||
if partially_stopped:
|
||||
break
|
||||
if not partially_stopped:
|
||||
yield GenerationOutput(prompt='',
|
||||
finished=False,
|
||||
outputs=[CompletionChunk(index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None)],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id)
|
||||
if stopped: break
|
||||
else: finish_reason = 'length'
|
||||
if stopped: finish_reason = 'stop'
|
||||
yield GenerationOutput(prompt='',
|
||||
finished=True,
|
||||
outputs=[CompletionChunk(index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=finish_reason)],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id)
|
||||
yield GenerationOutput(
|
||||
prompt='',
|
||||
finished=False,
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None
|
||||
)
|
||||
],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id,
|
||||
)
|
||||
if stopped:
|
||||
break
|
||||
else:
|
||||
finish_reason = 'length'
|
||||
if stopped:
|
||||
finish_reason = 'stop'
|
||||
yield GenerationOutput(
|
||||
prompt='',
|
||||
finished=True,
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=0,
|
||||
text=text,
|
||||
token_ids=output_token_ids[input_len:],
|
||||
cumulative_logprob=0.0,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
# Clean
|
||||
del past_key_values, out
|
||||
|
||||
@@ -13,40 +13,60 @@ import openllm
|
||||
from bentoml.io import JSON
|
||||
from bentoml.io import Text
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
llm = openllm.LLM[t.Any, t.Any](svars.model_id,
|
||||
model_tag=svars.model_tag,
|
||||
prompt_template=openllm.utils.first_not_none(os.getenv('OPENLLM_PROMPT_TEMPLATE'), None),
|
||||
system_message=openllm.utils.first_not_none(os.getenv('OPENLLM_SYSTEM_MESSAGE'), None),
|
||||
serialisation=openllm.utils.first_not_none(os.getenv('OPENLLM_SERIALIZATION'), 'safetensors'),
|
||||
adapter_map=orjson.loads(svars.adapter_map),
|
||||
trust_remote_code=openllm.utils.check_bool_env('TRUST_REMOTE_CODE', default=False))
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
svars.model_id,
|
||||
model_tag=svars.model_tag,
|
||||
prompt_template=openllm.utils.first_not_none(os.getenv('OPENLLM_PROMPT_TEMPLATE'), None),
|
||||
system_message=openllm.utils.first_not_none(os.getenv('OPENLLM_SYSTEM_MESSAGE'), None),
|
||||
serialisation=openllm.utils.first_not_none(os.getenv('OPENLLM_SERIALIZATION'), 'safetensors'),
|
||||
adapter_map=orjson.loads(svars.adapter_map),
|
||||
trust_remote_code=openllm.utils.check_bool_env('TRUST_REMOTE_CODE', default=False),
|
||||
)
|
||||
llm_config = llm.config
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[llm.runner])
|
||||
|
||||
llm_model_class = openllm.GenerationInput.from_llm_config(llm_config)
|
||||
|
||||
@svc.api(route='/v1/generate', input=JSON.from_sample(llm_model_class.examples()), output=JSON.from_sample(openllm.GenerationOutput.examples()))
|
||||
|
||||
@svc.api(
|
||||
route='/v1/generate',
|
||||
input=JSON.from_sample(llm_model_class.examples()),
|
||||
output=JSON.from_sample(openllm.GenerationOutput.examples()),
|
||||
)
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
return await llm.generate(**llm_model_class(**input_dict).model_dump())
|
||||
|
||||
@svc.api(route='/v1/generate_stream', input=JSON.from_sample(llm_model_class.examples()), output=Text(content_type='text/event-stream'))
|
||||
|
||||
@svc.api(
|
||||
route='/v1/generate_stream',
|
||||
input=JSON.from_sample(llm_model_class.examples()),
|
||||
output=Text(content_type='text/event-stream'),
|
||||
)
|
||||
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
|
||||
async for it in llm.generate_iterator(**llm_model_class(**input_dict).model_dump()):
|
||||
yield f'data: {it.model_dump_json()}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
_Metadata = openllm.MetadataOutput(timeout=llm_config['timeout'],
|
||||
model_name=llm_config['model_name'],
|
||||
backend=llm.__llm_backend__,
|
||||
model_id=llm.model_id,
|
||||
configuration=llm_config.model_dump_json(flatten=True).decode(),
|
||||
prompt_template=llm.runner.prompt_template,
|
||||
system_message=llm.runner.system_message)
|
||||
|
||||
_Metadata = openllm.MetadataOutput(
|
||||
timeout=llm_config['timeout'],
|
||||
model_name=llm_config['model_name'],
|
||||
backend=llm.__llm_backend__,
|
||||
model_id=llm.model_id,
|
||||
configuration=llm_config.model_dump_json(flatten=True).decode(),
|
||||
prompt_template=llm.runner.prompt_template,
|
||||
system_message=llm.runner.system_message,
|
||||
)
|
||||
|
||||
|
||||
@svc.api(route='/v1/metadata', input=Text(), output=JSON.from_sample(_Metadata.model_dump()))
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return _Metadata
|
||||
|
||||
openllm.mount_entrypoints(svc, llm) # HACK: This must always be the last line in this file, as we will do some MK for OpenAPI schema.
|
||||
|
||||
openllm.mount_entrypoints(
|
||||
svc, llm
|
||||
) # HACK: This must always be the last line in this file, as we will do some MK for OpenAPI schema.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
model_id = '{__model_id__}' # openllm: model id
|
||||
model_tag = '{__model_tag__}' # openllm: model tag
|
||||
adapter_map = '''{__model_adapter_map__}''' # openllm: model adapter map
|
||||
adapter_map = """{__model_adapter_map__}""" # openllm: model adapter map
|
||||
|
||||
@@ -20,60 +20,74 @@ from openllm_core._typing_compat import overload
|
||||
from openllm_core.utils import DEBUG
|
||||
from openllm_core.utils import ReprMixin
|
||||
|
||||
|
||||
class DynResource(t.Protocol):
|
||||
resource_id: t.ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def from_system(cls) -> t.Sequence[t.Any]:
|
||||
...
|
||||
def from_system(cls) -> t.Sequence[t.Any]: ...
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _strtoul(s: str) -> int:
|
||||
"""Return -1 or positive integer sequence string starts with,."""
|
||||
if not s: return -1
|
||||
if not s:
|
||||
return -1
|
||||
idx = 0
|
||||
for idx, c in enumerate(s):
|
||||
if not (c.isdigit() or (idx == 0 and c in '+-')): break
|
||||
if idx + 1 == len(s): idx += 1 # noqa: PLW2901
|
||||
if not (c.isdigit() or (idx == 0 and c in '+-')):
|
||||
break
|
||||
if idx + 1 == len(s):
|
||||
idx += 1 # noqa: PLW2901
|
||||
# NOTE: idx will be set via enumerate
|
||||
return int(s[:idx]) if idx > 0 else -1
|
||||
|
||||
|
||||
def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
|
||||
rcs: list[str] = []
|
||||
for elem in lst.split(','):
|
||||
# Repeated id results in empty set
|
||||
if elem in rcs: return []
|
||||
if elem in rcs:
|
||||
return []
|
||||
# Anything other but prefix is ignored
|
||||
if not elem.startswith(prefix): break
|
||||
if not elem.startswith(prefix):
|
||||
break
|
||||
rcs.append(elem)
|
||||
return rcs
|
||||
|
||||
|
||||
_STACK_LEVEL = 3
|
||||
|
||||
|
||||
@overload # variant: default callback
|
||||
def _parse_visible_devices() -> list[str] | None:
|
||||
...
|
||||
def _parse_visible_devices() -> list[str] | None: ...
|
||||
|
||||
|
||||
@overload # variant: specify None, and respect_env
|
||||
def _parse_visible_devices(default_var: None, *, respect_env: t.Literal[True]) -> list[str] | None:
|
||||
...
|
||||
def _parse_visible_devices(default_var: None, *, respect_env: t.Literal[True]) -> list[str] | None: ...
|
||||
|
||||
|
||||
@overload # variant: default var is something other than None
|
||||
def _parse_visible_devices(default_var: str = ..., *, respect_env: t.Literal[False]) -> list[str]:
|
||||
...
|
||||
def _parse_visible_devices(default_var: str = ..., *, respect_env: t.Literal[False]) -> list[str]: ...
|
||||
|
||||
|
||||
def _parse_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None:
|
||||
"""CUDA_VISIBLE_DEVICES aware with default var for parsing spec."""
|
||||
if respect_env:
|
||||
spec = os.environ.get('CUDA_VISIBLE_DEVICES', default_var)
|
||||
if not spec: return None
|
||||
if not spec:
|
||||
return None
|
||||
else:
|
||||
if default_var is None: raise ValueError('spec is required to be not None when parsing spec.')
|
||||
if default_var is None:
|
||||
raise ValueError('spec is required to be not None when parsing spec.')
|
||||
spec = default_var
|
||||
|
||||
if spec.startswith('GPU-'): return _parse_list_with_prefix(spec, 'GPU-')
|
||||
if spec.startswith('MIG-'): return _parse_list_with_prefix(spec, 'MIG-')
|
||||
if spec.startswith('GPU-'):
|
||||
return _parse_list_with_prefix(spec, 'GPU-')
|
||||
if spec.startswith('MIG-'):
|
||||
return _parse_list_with_prefix(spec, 'MIG-')
|
||||
# XXX: We need to somehow handle cases such as '100m'
|
||||
# CUDA_VISIBLE_DEVICES uses something like strtoul
|
||||
# which makes `1gpu2,2ampere` is equivalent to `1,2`
|
||||
@@ -81,18 +95,22 @@ def _parse_visible_devices(default_var: str | None = None, respect_env: bool = T
|
||||
for el in spec.split(','):
|
||||
x = _strtoul(el.strip())
|
||||
# Repeated ordinal results in empty set
|
||||
if x in rc: return []
|
||||
if x in rc:
|
||||
return []
|
||||
# Negative value aborts the sequence
|
||||
if x < 0: break
|
||||
if x < 0:
|
||||
break
|
||||
rc.append(x)
|
||||
return [str(i) for i in rc]
|
||||
|
||||
|
||||
def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
visible_devices = _parse_visible_devices()
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
if not psutil.LINUX:
|
||||
if DEBUG: logger.debug('AMD GPUs is currently only supported on Linux.')
|
||||
if DEBUG:
|
||||
logger.debug('AMD GPUs is currently only supported on Linux.')
|
||||
return []
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
# So we need to use the ctypes bindings directly.
|
||||
@@ -108,7 +126,8 @@ def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
|
||||
device_count = c_uint32(0)
|
||||
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)]
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS:
|
||||
return [str(i) for i in range(device_count.value)]
|
||||
return []
|
||||
# In this case the binary is not found, returning empty list
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
@@ -118,6 +137,7 @@ def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
else:
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
cuda.cuInit(0)
|
||||
_, dev = cuda.cuDeviceGetCount()
|
||||
return [str(i) for i in range(dev)]
|
||||
@@ -125,31 +145,39 @@ def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
return []
|
||||
return visible_devices
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: int) -> list[str]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: list[int | str]) -> list[str]:
|
||||
...
|
||||
def _from_spec(cls: type[DynResource], spec: int) -> list[str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: str) -> list[str]:
|
||||
...
|
||||
def _from_spec(cls: type[DynResource], spec: list[int | str]) -> list[str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: str) -> list[str]: ...
|
||||
|
||||
|
||||
def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]:
|
||||
if isinstance(spec, int):
|
||||
if spec in (-1, 0): return []
|
||||
if spec < -1: raise ValueError('Spec cannot be < -1.')
|
||||
if spec in (-1, 0):
|
||||
return []
|
||||
if spec < -1:
|
||||
raise ValueError('Spec cannot be < -1.')
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
if not spec: return []
|
||||
if spec.isdigit(): spec = ','.join([str(i) for i in range(_strtoul(spec))])
|
||||
if not spec:
|
||||
return []
|
||||
if spec.isdigit():
|
||||
spec = ','.join([str(i) for i in range(_strtoul(spec))])
|
||||
return _parse_visible_devices(spec, respect_env=False)
|
||||
elif isinstance(spec, list):
|
||||
return [str(x) for x in spec]
|
||||
else:
|
||||
raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
|
||||
raise TypeError(
|
||||
f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead."
|
||||
)
|
||||
|
||||
|
||||
def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
from ctypes import CDLL
|
||||
@@ -190,10 +218,14 @@ def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
del nvml_h
|
||||
return uuids
|
||||
|
||||
|
||||
def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'")
|
||||
if not all(isinstance(i, str) for i in val): raise ValueError('Input list should be all string type.')
|
||||
raise RuntimeError(
|
||||
"AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'"
|
||||
)
|
||||
if not all(isinstance(i, str) for i in val):
|
||||
raise ValueError('Input list should be all string type.')
|
||||
|
||||
try:
|
||||
from cuda import cuda
|
||||
@@ -205,25 +237,36 @@ def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
|
||||
for el in val:
|
||||
if el.startswith('GPU-') or el.startswith('MIG-'):
|
||||
uuids = _raw_device_uuid_nvml()
|
||||
if uuids is None: raise ValueError('Failed to parse available GPUs UUID')
|
||||
if el not in uuids: raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})')
|
||||
if uuids is None:
|
||||
raise ValueError('Failed to parse available GPUs UUID')
|
||||
if el not in uuids:
|
||||
raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})')
|
||||
elif el.isdigit():
|
||||
err, _ = cuda.cuDeviceGet(int(el))
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f'Failed to get device {el}')
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise ValueError(f'Failed to get device {el}')
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[DynResource]:
|
||||
return types.new_class(
|
||||
name, (bentoml.Resource[t.List[str]], ReprMixin), {'resource_id': resource_kind}, lambda ns: ns.update({
|
||||
'resource_id': resource_kind,
|
||||
'from_spec': classmethod(_from_spec),
|
||||
'from_system': classmethod(_from_system),
|
||||
'validate': classmethod(_validate),
|
||||
'__repr_keys__': property(lambda _: {'resource_id'}),
|
||||
'__doc__': inspect.cleandoc(docstring),
|
||||
'__module__': 'openllm._strategies'
|
||||
}))
|
||||
name,
|
||||
(bentoml.Resource[t.List[str]], ReprMixin),
|
||||
{'resource_id': resource_kind},
|
||||
lambda ns: ns.update(
|
||||
{
|
||||
'resource_id': resource_kind,
|
||||
'from_spec': classmethod(_from_spec),
|
||||
'from_system': classmethod(_from_system),
|
||||
'validate': classmethod(_validate),
|
||||
'__repr_keys__': property(lambda _: {'resource_id'}),
|
||||
'__doc__': inspect.cleandoc(docstring),
|
||||
'__module__': 'openllm._strategies',
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# NOTE: we need to hint these t.Literal since mypy is to dumb to infer this as literal 🤦
|
||||
_TPU_RESOURCE: t.Literal['cloud-tpus.google.com/v2'] = 'cloud-tpus.google.com/v2'
|
||||
@@ -232,15 +275,22 @@ _NVIDIA_GPU_RESOURCE: t.Literal['nvidia.com/gpu'] = 'nvidia.com/gpu'
|
||||
_CPU_RESOURCE: t.Literal['cpu'] = 'cpu'
|
||||
|
||||
NvidiaGpuResource = _make_resource_class(
|
||||
'NvidiaGpuResource', _NVIDIA_GPU_RESOURCE, '''NVIDIA GPU resource.
|
||||
'NvidiaGpuResource',
|
||||
_NVIDIA_GPU_RESOURCE,
|
||||
"""NVIDIA GPU resource.
|
||||
|
||||
This is a modified version of internal's BentoML's NvidiaGpuResource
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.''')
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.""",
|
||||
)
|
||||
AmdGpuResource = _make_resource_class(
|
||||
'AmdGpuResource', _AMD_GPU_RESOURCE, '''AMD GPU resource.
|
||||
'AmdGpuResource',
|
||||
_AMD_GPU_RESOURCE,
|
||||
"""AMD GPU resource.
|
||||
|
||||
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.''')
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""",
|
||||
)
|
||||
|
||||
|
||||
class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
"""This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
|
||||
@@ -251,21 +301,27 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
|
||||
TODO: Support CloudTPUResource
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: float) -> int:
|
||||
def get_worker_count(
|
||||
cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: float
|
||||
) -> int:
|
||||
"""Return the number of workers to be used for the given runnable class.
|
||||
|
||||
Note that for all available GPU, the number of workers will always be 1.
|
||||
"""
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
if resource_request is None:
|
||||
resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = 'nvidia.com/gpu'
|
||||
nvidia_req = get_resource(resource_request, kind)
|
||||
if nvidia_req is not None: return 1
|
||||
if nvidia_req is not None:
|
||||
return 1
|
||||
# use AMD
|
||||
kind = 'amd.com/gpu'
|
||||
amd_req = get_resource(resource_request, kind, validate=False)
|
||||
if amd_req is not None: return 1
|
||||
if amd_req is not None:
|
||||
return 1
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, 'cpu')
|
||||
if cpus is not None and cpus > 0:
|
||||
@@ -279,10 +335,18 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
return math.ceil(cpus) * workers_per_resource
|
||||
|
||||
# this should not be reached by user since we always read system resource as default
|
||||
raise ValueError(f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.')
|
||||
raise ValueError(
|
||||
f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
|
||||
def get_worker_env(
|
||||
cls,
|
||||
runnable_class: type[bentoml.Runnable],
|
||||
resource_request: dict[str, t.Any] | None,
|
||||
workers_per_resource: int | float,
|
||||
worker_index: int,
|
||||
) -> dict[str, t.Any]:
|
||||
"""Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
@@ -295,7 +359,8 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
disabled = cuda_env in ('', '-1')
|
||||
environ: dict[str, t.Any] = {}
|
||||
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
if resource_request is None:
|
||||
resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = 'nvidia.com/gpu'
|
||||
typ = get_resource(resource_request, kind)
|
||||
@@ -340,20 +405,34 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
# NOTE: We hit this branch when workers_per_resource is set to
|
||||
# float, for example 0.5 or 0.25
|
||||
if workers_per_resource > 1:
|
||||
raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.")
|
||||
raise ValueError(
|
||||
"Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case."
|
||||
)
|
||||
# We are round the assigned resource here. This means if workers_per_resource=.4
|
||||
# then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
|
||||
assigned_resource_per_worker = round(1 / workers_per_resource)
|
||||
if len(gpus) < assigned_resource_per_worker:
|
||||
logger.warning('Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])', gpus, worker_index, assigned_resource_per_worker)
|
||||
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
|
||||
assigned_gpu = gpus[assigned_resource_per_worker * worker_index:assigned_resource_per_worker * (worker_index + 1)]
|
||||
logger.warning(
|
||||
'Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])',
|
||||
gpus,
|
||||
worker_index,
|
||||
assigned_resource_per_worker,
|
||||
)
|
||||
raise IndexError(
|
||||
f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}]."
|
||||
)
|
||||
assigned_gpu = gpus[
|
||||
assigned_resource_per_worker * worker_index : assigned_resource_per_worker * (worker_index + 1)
|
||||
]
|
||||
dev = ','.join(assigned_gpu)
|
||||
else:
|
||||
idx = worker_index // workers_per_resource
|
||||
if idx >= len(gpus):
|
||||
raise ValueError(f'Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}')
|
||||
raise ValueError(
|
||||
f'Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}'
|
||||
)
|
||||
dev = str(gpus[idx])
|
||||
return dev
|
||||
|
||||
|
||||
__all__ = ['CascadingResourceStrategy', 'get_resource']
|
||||
|
||||
@@ -2,15 +2,24 @@
|
||||
|
||||
These utilities will stay internal, and its API can be changed or updated without backward-compatibility.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
'_package': ['create_bento', 'build_editable', 'construct_python_options', 'construct_docker_options'],
|
||||
'oci': ['CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries', 'RefResolver']
|
||||
'_package': ['create_bento', 'build_editable', 'construct_python_options', 'construct_docker_options'],
|
||||
'oci': [
|
||||
'CONTAINER_NAMES',
|
||||
'get_base_container_tag',
|
||||
'build_container',
|
||||
'get_base_container_name',
|
||||
'supported_registries',
|
||||
'RefResolver',
|
||||
],
|
||||
}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@@ -27,6 +27,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
|
||||
from . import oci
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
|
||||
@@ -43,15 +44,22 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'
|
||||
|
||||
def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'openllm_client'] = 'openllm') -> str | None:
|
||||
|
||||
def build_editable(
|
||||
path: str, package: t.Literal['openllm', 'openllm_core', 'openllm_client'] = 'openllm'
|
||||
) -> str | None:
|
||||
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
|
||||
if openllm_core.utils.check_bool_env(OPENLLM_DEV_BUILD, default=False): return None
|
||||
if openllm_core.utils.check_bool_env(OPENLLM_DEV_BUILD, default=False):
|
||||
return None
|
||||
# We need to build the package in editable mode, so that we can import it
|
||||
from build import ProjectBuilder
|
||||
from build.env import IsolatedEnvBuilder
|
||||
|
||||
module_location = openllm_core.utils.pkg.source_locations(package)
|
||||
if not module_location:
|
||||
raise RuntimeError('Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.')
|
||||
raise RuntimeError(
|
||||
'Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.'
|
||||
)
|
||||
pyproject_path = Path(module_location).parent.parent / 'pyproject.toml'
|
||||
if os.path.isfile(pyproject_path.__fspath__()):
|
||||
logger.info('Generating built wheels for package %s...', package)
|
||||
@@ -61,57 +69,98 @@ def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'ope
|
||||
builder.scripts_dir = env.scripts_dir
|
||||
env.install(builder.build_system_requires)
|
||||
return builder.build('wheel', path, config_settings={'--global-option': '--quiet'})
|
||||
raise RuntimeError('Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.')
|
||||
raise RuntimeError(
|
||||
'Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.'
|
||||
)
|
||||
|
||||
def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str] | None = None) -> PythonOptions:
|
||||
|
||||
def construct_python_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
llm_fs: FS,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
) -> PythonOptions:
|
||||
packages = ['openllm', 'scipy'] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ['openllm[fine-tune]']
|
||||
if adapter_map is not None:
|
||||
packages += ['openllm[fine-tune]']
|
||||
# NOTE: add openllm to the default dependencies
|
||||
# if users has openllm custom built wheels, it will still respect
|
||||
# that since bentoml will always install dependencies from requirements.txt
|
||||
# first, then proceed to install everything inside the wheels/ folder.
|
||||
if extra_dependencies is not None: packages += [f'openllm[{k}]' for k in extra_dependencies]
|
||||
if extra_dependencies is not None:
|
||||
packages += [f'openllm[{k}]' for k in extra_dependencies]
|
||||
|
||||
req = llm.config['requirements']
|
||||
if req is not None: packages.extend(req)
|
||||
if req is not None:
|
||||
packages.extend(req)
|
||||
if str(os.environ.get('BENTOML_BUNDLE_LOCAL_BUILD', False)).lower() == 'false':
|
||||
packages.append(f"bentoml>={'.'.join([str(i) for i in openllm_core.utils.pkg.pkg_version_info('bentoml')])}")
|
||||
|
||||
if not openllm_core.utils.is_torch_available():
|
||||
raise ValueError('PyTorch is not available. Make sure to have it locally installed.')
|
||||
packages.extend(['torch==2.0.1+cu118', 'vllm==0.2.1.post1', 'xformers==0.0.22', 'bentoml[tracing]==1.1.9']) # XXX: Currently locking this for correctness
|
||||
packages.extend(
|
||||
['torch==2.0.1+cu118', 'vllm==0.2.1.post1', 'xformers==0.0.22', 'bentoml[tracing]==1.1.9']
|
||||
) # XXX: Currently locking this for correctness
|
||||
wheels: list[str] = []
|
||||
built_wheels = [build_editable(llm_fs.getsyspath('/'), t.cast(t.Literal['openllm', 'openllm_core', 'openllm_client'], p)) for p in ('openllm_core', 'openllm_client', 'openllm')]
|
||||
built_wheels = [
|
||||
build_editable(llm_fs.getsyspath('/'), t.cast(t.Literal['openllm', 'openllm_core', 'openllm_client'], p))
|
||||
for p in ('openllm_core', 'openllm_client', 'openllm')
|
||||
]
|
||||
if all(i for i in built_wheels):
|
||||
wheels.extend([llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in t.cast(t.List[str], built_wheels)])
|
||||
return PythonOptions(packages=packages,
|
||||
wheels=wheels,
|
||||
lock_packages=False,
|
||||
extra_index_url=['https://download.pytorch.org/whl/cu118', 'https://huggingface.github.io/autogptq-index/whl/cu118/'])
|
||||
return PythonOptions(
|
||||
packages=packages,
|
||||
wheels=wheels,
|
||||
lock_packages=False,
|
||||
extra_index_url=[
|
||||
'https://download.pytorch.org/whl/cu118',
|
||||
'https://huggingface.github.io/autogptq-index/whl/cu118/',
|
||||
],
|
||||
)
|
||||
|
||||
def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, quantize: LiteralString | None, adapter_map: dict[str, str] | None, dockerfile_template: str | None,
|
||||
serialisation: LiteralSerialisation, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy) -> DockerOptions:
|
||||
|
||||
def construct_docker_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
_: FS,
|
||||
quantize: LiteralString | None,
|
||||
adapter_map: dict[str, str] | None,
|
||||
dockerfile_template: str | None,
|
||||
serialisation: LiteralSerialisation,
|
||||
container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy,
|
||||
) -> DockerOptions:
|
||||
from openllm.cli._factory import parse_config_options
|
||||
|
||||
environ = parse_config_options(llm.config, llm.config['timeout'], 1.0, None, True, os.environ.copy())
|
||||
env_dict = {
|
||||
'OPENLLM_BACKEND': llm.__llm_backend__,
|
||||
'OPENLLM_CONFIG': f"'{llm.config.model_dump_json(flatten=True).decode()}'",
|
||||
'OPENLLM_SERIALIZATION': serialisation,
|
||||
'BENTOML_DEBUG': str(True),
|
||||
'BENTOML_QUIET': str(False),
|
||||
'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
|
||||
'OPENLLM_BACKEND': llm.__llm_backend__,
|
||||
'OPENLLM_CONFIG': f"'{llm.config.model_dump_json(flatten=True).decode()}'",
|
||||
'OPENLLM_SERIALIZATION': serialisation,
|
||||
'BENTOML_DEBUG': str(True),
|
||||
'BENTOML_QUIET': str(False),
|
||||
'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
|
||||
}
|
||||
if adapter_map: env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
if llm._system_message: env_dict['OPENLLM_SYSTEM_MESSAGE'] = repr(llm._system_message)
|
||||
if llm._prompt_template: env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
|
||||
if quantize: env_dict['OPENLLM_QUANTISE'] = str(quantize)
|
||||
return DockerOptions(base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}', env=env_dict, dockerfile_template=dockerfile_template)
|
||||
if adapter_map:
|
||||
env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
if llm._system_message:
|
||||
env_dict['OPENLLM_SYSTEM_MESSAGE'] = repr(llm._system_message)
|
||||
if llm._prompt_template:
|
||||
env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
|
||||
if quantize:
|
||||
env_dict['OPENLLM_QUANTISE'] = str(quantize)
|
||||
return DockerOptions(
|
||||
base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}',
|
||||
env=env_dict,
|
||||
dockerfile_template=dockerfile_template,
|
||||
)
|
||||
|
||||
|
||||
OPENLLM_MODEL_NAME = '# openllm: model name'
|
||||
OPENLLM_MODEL_ID = '# openllm: model id'
|
||||
OPENLLM_MODEL_TAG = '# openllm: model tag'
|
||||
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
|
||||
|
||||
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
model_keyword: LiteralString = '__model_name__'
|
||||
|
||||
@@ -130,75 +179,122 @@ class ModelNameFormatter(string.Formatter):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_id__'
|
||||
|
||||
|
||||
class ModelTagFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_tag__'
|
||||
|
||||
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_adapter_map__'
|
||||
|
||||
|
||||
_service_file = Path(os.path.abspath(__file__)).parent.parent / '_service.py'
|
||||
_service_vars_file = Path(os.path.abspath(__file__)).parent.parent / '_service_vars_pkg.py'
|
||||
|
||||
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str] | None, llm_fs: FS) -> None:
|
||||
from openllm_core.utils import DEBUG
|
||||
|
||||
model_name = llm.config['model_name']
|
||||
model_id = llm.model_id
|
||||
model_tag = str(llm.tag)
|
||||
logger.debug('Generating service vars file for %s at %s (dir=%s)', model_name, '_service_vars.py', llm_fs.getsyspath('/'))
|
||||
logger.debug(
|
||||
'Generating service vars file for %s at %s (dir=%s)', model_name, '_service_vars.py', llm_fs.getsyspath('/')
|
||||
)
|
||||
with open(_service_vars_file.__fspath__(), 'r') as f:
|
||||
src_contents = f.readlines()
|
||||
for it in src_contents:
|
||||
if OPENLLM_MODEL_NAME in it:
|
||||
src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + '\n')
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelNameFormatter(model_name).vformat(it)[: -(len(OPENLLM_MODEL_NAME) + 3)] + '\n'
|
||||
)
|
||||
if OPENLLM_MODEL_ID in it:
|
||||
src_contents[src_contents.index(it)] = (ModelIdFormatter(model_id).vformat(it)[:-(len(OPENLLM_MODEL_ID) + 3)] + '\n')
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelIdFormatter(model_id).vformat(it)[: -(len(OPENLLM_MODEL_ID) + 3)] + '\n'
|
||||
)
|
||||
elif OPENLLM_MODEL_TAG in it:
|
||||
src_contents[src_contents.index(it)] = (ModelTagFormatter(model_tag).vformat(it)[:-(len(OPENLLM_MODEL_TAG) + 3)] + '\n')
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelTagFormatter(model_tag).vformat(it)[: -(len(OPENLLM_MODEL_TAG) + 3)] + '\n'
|
||||
)
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it:
|
||||
src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + '\n')
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[
|
||||
: -(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)
|
||||
]
|
||||
+ '\n'
|
||||
)
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + ''.join(src_contents)
|
||||
if DEBUG: logger.info('Generated script:\n%s', script)
|
||||
if DEBUG:
|
||||
logger.info('Generated script:\n%s', script)
|
||||
llm_fs.writetext('_service_vars.py', script)
|
||||
|
||||
logger.debug('Generating service file for %s at %s (dir=%s)', model_name, llm.config['service_name'], llm_fs.getsyspath('/'))
|
||||
logger.debug(
|
||||
'Generating service file for %s at %s (dir=%s)', model_name, llm.config['service_name'], llm_fs.getsyspath('/')
|
||||
)
|
||||
with open(_service_file.__fspath__(), 'r') as f:
|
||||
service_src = f.read()
|
||||
llm_fs.writetext(llm.config['service_name'], service_src)
|
||||
|
||||
|
||||
@inject
|
||||
def create_bento(bento_tag: bentoml.Tag,
|
||||
llm_fs: FS,
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
quantize: LiteralString | None,
|
||||
dockerfile_template: str | None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
container_registry: LiteralContainerRegistry = 'ecr',
|
||||
container_version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store]) -> bentoml.Bento:
|
||||
_serialisation: LiteralSerialisation = openllm_core.utils.first_not_none(serialisation, default=llm.config['serialisation'])
|
||||
def create_bento(
|
||||
bento_tag: bentoml.Tag,
|
||||
llm_fs: FS,
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
quantize: LiteralString | None,
|
||||
dockerfile_template: str | None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
container_registry: LiteralContainerRegistry = 'ecr',
|
||||
container_version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
) -> bentoml.Bento:
|
||||
_serialisation: LiteralSerialisation = openllm_core.utils.first_not_none(
|
||||
serialisation, default=llm.config['serialisation']
|
||||
)
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({'_type': llm.llm_type, '_framework': llm.__llm_backend__, 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle'})
|
||||
if adapter_map: labels.update(adapter_map)
|
||||
labels.update(
|
||||
{
|
||||
'_type': llm.llm_type,
|
||||
'_framework': llm.__llm_backend__,
|
||||
'start_name': llm.config['start_name'],
|
||||
'base_name_or_path': llm.model_id,
|
||||
'bundler': 'openllm.bundle',
|
||||
}
|
||||
)
|
||||
if adapter_map:
|
||||
labels.update(adapter_map)
|
||||
logger.debug("Building Bento '%s' with model backend '%s'", bento_tag, llm.__llm_backend__)
|
||||
# add service.py definition to this temporary folder
|
||||
write_service(llm, adapter_map, llm_fs)
|
||||
|
||||
llm_spec = ModelSpec.from_item({'tag': str(llm.tag), 'alias': llm.tag.name})
|
||||
build_config = BentoBuildConfig(service=f"{llm.config['service_name']}:svc",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
models=[llm_spec],
|
||||
description=f"OpenLLM service for {llm.config['start_name']}",
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
docker=construct_docker_options(llm, llm_fs, quantize, adapter_map, dockerfile_template, _serialisation, container_registry, container_version_strategy))
|
||||
build_config = BentoBuildConfig(
|
||||
service=f"{llm.config['service_name']}:svc",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
models=[llm_spec],
|
||||
description=f"OpenLLM service for {llm.config['start_name']}",
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
docker=construct_docker_options(
|
||||
llm,
|
||||
llm_fs,
|
||||
quantize,
|
||||
adapter_map,
|
||||
dockerfile_template,
|
||||
_serialisation,
|
||||
container_registry,
|
||||
container_version_strategy,
|
||||
),
|
||||
)
|
||||
|
||||
bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath('/'))
|
||||
# NOTE: the model_id_path here are only used for setting this environment variable within the container built with for BentoLLM.
|
||||
@@ -208,10 +304,12 @@ def create_bento(bento_tag: bentoml.Tag,
|
||||
service_contents = f.readlines()
|
||||
|
||||
for it in service_contents:
|
||||
if '__bento_name__' in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
if '__bento_name__' in it:
|
||||
service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
|
||||
script = ''.join(service_contents)
|
||||
if openllm_core.utils.DEBUG: logger.info('Generated script:\n%s', script)
|
||||
if openllm_core.utils.DEBUG:
|
||||
logger.info('Generated script:\n%s', script)
|
||||
|
||||
bento._fs.writetext(service_fs_path, script)
|
||||
if 'model_store' in inspect.signature(bento.save).parameters:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# mypy: disable-error-code="misc"
|
||||
"""OCI-related utilities for OpenLLM. This module is considered to be internal and API are subjected to change."""
|
||||
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import importlib
|
||||
@@ -23,6 +24,7 @@ import openllm_core
|
||||
|
||||
from openllm_core.utils.lazy import VersionInfo
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ghapi import all
|
||||
|
||||
@@ -42,7 +44,11 @@ ROOT_DIR = pathlib.Path(os.path.abspath('__file__')).parent.parent.parent
|
||||
# but in the future, we can infer based on git repo and everything to make it more options for users
|
||||
# to build the base image. For now, all of the base image will be <registry>/bentoml/openllm:...
|
||||
# NOTE: The ECR registry is the public one and currently only @bentoml team has access to push it.
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {'docker': 'docker.io/bentoml/openllm', 'gh': 'ghcr.io/bentoml/openllm', 'ecr': 'public.ecr.aws/y5w8i4y6/bentoml/openllm'}
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {
|
||||
'docker': 'docker.io/bentoml/openllm',
|
||||
'gh': 'ghcr.io/bentoml/openllm',
|
||||
'ecr': 'public.ecr.aws/y5w8i4y6/bentoml/openllm',
|
||||
}
|
||||
|
||||
# TODO: support custom fork. Currently it only support openllm main.
|
||||
_OWNER = 'bentoml'
|
||||
@@ -50,21 +56,29 @@ _REPO = 'openllm'
|
||||
|
||||
_module_location = openllm_core.utils.pkg.source_locations('openllm')
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
@openllm_core.utils.apply(str.lower)
|
||||
def get_base_container_name(reg: LiteralContainerRegistry) -> str:
|
||||
return _CONTAINER_REGISTRY[reg]
|
||||
|
||||
|
||||
def _convert_version_from_string(s: str) -> VersionInfo:
|
||||
return VersionInfo.from_version_string(s)
|
||||
|
||||
|
||||
def _commit_time_range(r: int = 5) -> str:
|
||||
return (datetime.now(timezone.utc) - timedelta(days=r)).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
|
||||
|
||||
class VersionNotSupported(openllm.exceptions.OpenLLMException):
|
||||
"""Raised when the stable release is too low that it doesn't include OpenLLM base container."""
|
||||
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class('_RefTuple', ['git_hash', 'version', 'strategy'])
|
||||
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class(
|
||||
'_RefTuple', ['git_hash', 'version', 'strategy']
|
||||
)
|
||||
|
||||
|
||||
def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
# NOTE: all openllm container will have sha-<git_hash[:7]>
|
||||
@@ -73,12 +87,27 @@ def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
docker_bin = shutil.which('docker')
|
||||
if docker_bin is None:
|
||||
logger.warning(
|
||||
'To get the correct available nightly container, make sure to have docker available. Fallback to previous behaviour for determine nightly hash (container might not exists due to the lack of GPU machine at a time. See https://github.com/bentoml/OpenLLM/pkgs/container/openllm for available image.)'
|
||||
'To get the correct available nightly container, make sure to have docker available. Fallback to previous behaviour for determine nightly hash (container might not exists due to the lack of GPU machine at a time. See https://github.com/bentoml/OpenLLM/pkgs/container/openllm for available image.)'
|
||||
)
|
||||
commits = t.cast('list[dict[str, t.Any]]', cls._ghapi.repos.list_commits(since=_commit_time_range()))
|
||||
return next(f'sha-{it["sha"][:7]}' for it in commits if '[skip ci]' not in it['commit']['message'])
|
||||
# now is the correct behaviour
|
||||
return orjson.loads(subprocess.check_output([docker_bin, 'run', '--rm', '-it', 'quay.io/skopeo/stable:latest', 'list-tags', 'docker://ghcr.io/bentoml/openllm']).decode().strip())['Tags'][-2]
|
||||
return orjson.loads(
|
||||
subprocess.check_output(
|
||||
[
|
||||
docker_bin,
|
||||
'run',
|
||||
'--rm',
|
||||
'-it',
|
||||
'quay.io/skopeo/stable:latest',
|
||||
'list-tags',
|
||||
'docker://ghcr.io/bentoml/openllm',
|
||||
]
|
||||
)
|
||||
.decode()
|
||||
.strip()
|
||||
)['Tags'][-2]
|
||||
|
||||
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
|
||||
class RefResolver:
|
||||
@@ -98,80 +127,124 @@ class RefResolver:
|
||||
# NOTE: This strategy will only support openllm>0.2.12
|
||||
meta: dict[str, t.Any] = cls._ghapi.repos.get_latest_release()
|
||||
version_str = meta['name'].lstrip('v')
|
||||
version: tuple[str, str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")['object']['sha'], version_str)
|
||||
version: tuple[str, str | None] = (
|
||||
cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")['object']['sha'],
|
||||
version_str,
|
||||
)
|
||||
else:
|
||||
version = ('', version_str)
|
||||
if openllm_core.utils.VersionInfo.from_version_string(t.cast(str, version_str)) < (0, 2, 12):
|
||||
raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'")
|
||||
raise VersionNotSupported(
|
||||
f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'"
|
||||
)
|
||||
return _RefTuple((*version, 'release' if _use_base_strategy else 'custom'))
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def from_strategy(cls, strategy_or_version: t.Literal['release', 'nightly'] | LiteralString | None = None) -> RefResolver:
|
||||
def from_strategy(
|
||||
cls, strategy_or_version: t.Literal['release', 'nightly'] | LiteralString | None = None
|
||||
) -> RefResolver:
|
||||
# using default strategy
|
||||
if strategy_or_version is None or strategy_or_version == 'release': return cls(*cls._release_ref())
|
||||
elif strategy_or_version == 'latest': return cls('latest', '0.0.0', 'latest')
|
||||
if strategy_or_version is None or strategy_or_version == 'release':
|
||||
return cls(*cls._release_ref())
|
||||
elif strategy_or_version == 'latest':
|
||||
return cls('latest', '0.0.0', 'latest')
|
||||
elif strategy_or_version == 'nightly':
|
||||
_ref = cls._nightly_ref()
|
||||
return cls(_ref[0], '0.0.0', _ref[-1])
|
||||
else:
|
||||
logger.warning('Using custom %s. Make sure that it is at lease 0.2.12 for base container support.', strategy_or_version)
|
||||
logger.warning(
|
||||
'Using custom %s. Make sure that it is at lease 0.2.12 for base container support.', strategy_or_version
|
||||
)
|
||||
return cls(*cls._release_ref(version_str=strategy_or_version))
|
||||
|
||||
@property
|
||||
def tag(self) -> str:
|
||||
# NOTE: latest tag can also be nightly, but discouraged to use it. For nightly refer to use sha-<git_hash_short>
|
||||
if self.strategy == 'latest': return 'latest'
|
||||
elif self.strategy == 'nightly': return self.git_hash
|
||||
else: return repr(self.version)
|
||||
if self.strategy == 'latest':
|
||||
return 'latest'
|
||||
elif self.strategy == 'nightly':
|
||||
return self.git_hash
|
||||
else:
|
||||
return repr(self.version)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=256)
|
||||
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str:
|
||||
return RefResolver.from_strategy(strategy).tag
|
||||
|
||||
def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None,
|
||||
version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
push: bool = False,
|
||||
machine: bool = False) -> dict[str | LiteralContainerRegistry, str]:
|
||||
|
||||
def build_container(
|
||||
registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None,
|
||||
version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
push: bool = False,
|
||||
machine: bool = False,
|
||||
) -> dict[str | LiteralContainerRegistry, str]:
|
||||
try:
|
||||
if not _BUILDER.health(): raise openllm.exceptions.Error
|
||||
if not _BUILDER.health():
|
||||
raise openllm.exceptions.Error
|
||||
except (openllm.exceptions.Error, subprocess.CalledProcessError):
|
||||
raise RuntimeError('Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.') from None
|
||||
raise RuntimeError(
|
||||
'Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.'
|
||||
) from None
|
||||
if not shutil.which('nvidia-container-runtime'):
|
||||
raise RuntimeError('NVIDIA Container Toolkit is required to compile CUDA kernel in container.')
|
||||
if not _module_location:
|
||||
raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
|
||||
pyproject_path = pathlib.Path(_module_location).parent.parent / 'pyproject.toml'
|
||||
if not pyproject_path.exists():
|
||||
raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
|
||||
raise ValueError(
|
||||
"This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'"
|
||||
)
|
||||
if not registries:
|
||||
tags: dict[str | LiteralContainerRegistry, str] = {alias: f'{value}:{get_base_container_tag(version_strategy)}' for alias, value in _CONTAINER_REGISTRY.items()}
|
||||
tags: dict[str | LiteralContainerRegistry, str] = {
|
||||
alias: f'{value}:{get_base_container_tag(version_strategy)}' for alias, value in _CONTAINER_REGISTRY.items()
|
||||
}
|
||||
else:
|
||||
registries = [registries] if isinstance(registries, str) else list(registries)
|
||||
tags = {name: f'{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}' for name in registries}
|
||||
try:
|
||||
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath('Dockerfile').resolve().__fspath__(),
|
||||
context_path=pyproject_path.parent.__fspath__(),
|
||||
tag=tuple(tags.values()),
|
||||
push=push,
|
||||
progress='plain' if openllm_core.utils.get_debug_mode() else 'auto',
|
||||
quiet=machine)
|
||||
if machine and outputs is not None: tags['image_sha'] = outputs.decode('utf-8').strip()
|
||||
outputs = _BUILDER.build(
|
||||
file=pathlib.Path(__file__).parent.joinpath('Dockerfile').resolve().__fspath__(),
|
||||
context_path=pyproject_path.parent.__fspath__(),
|
||||
tag=tuple(tags.values()),
|
||||
push=push,
|
||||
progress='plain' if openllm_core.utils.get_debug_mode() else 'auto',
|
||||
quiet=machine,
|
||||
)
|
||||
if machine and outputs is not None:
|
||||
tags['image_sha'] = outputs.decode('utf-8').strip()
|
||||
except Exception as err:
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}') from err
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f'Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}'
|
||||
) from err
|
||||
return tags
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
|
||||
supported_registries: list[str]
|
||||
|
||||
__all__ = ['CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries', 'RefResolver']
|
||||
__all__ = [
|
||||
'CONTAINER_NAMES',
|
||||
'get_base_container_tag',
|
||||
'build_container',
|
||||
'get_base_container_name',
|
||||
'supported_registries',
|
||||
'RefResolver',
|
||||
]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == 'supported_registries': return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
|
||||
elif name == 'CONTAINER_NAMES': return _CONTAINER_REGISTRY
|
||||
elif name in __all__: return importlib.import_module('.' + name, __name__)
|
||||
else: raise AttributeError(f'{name} does not exists under {__name__}')
|
||||
if name == 'supported_registries':
|
||||
return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
|
||||
elif name == 'CONTAINER_NAMES':
|
||||
return _CONTAINER_REGISTRY
|
||||
elif name in __all__:
|
||||
return importlib.import_module('.' + name, __name__)
|
||||
else:
|
||||
raise AttributeError(f'{name} does not exists under {__name__}')
|
||||
|
||||
@@ -25,8 +25,14 @@ from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core._typing_compat import get_literal_args
|
||||
from openllm_core.utils import DEBUG
|
||||
|
||||
|
||||
class _OpenLLM_GenericInternalConfig(LLMConfig):
|
||||
__config__ = {'name_type': 'lowercase', 'default_id': 'openllm/generic', 'model_ids': ['openllm/generic'], 'architecture': 'PreTrainedModel'}
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'default_id': 'openllm/generic',
|
||||
'model_ids': ['openllm/generic'],
|
||||
'architecture': 'PreTrainedModel',
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
top_k: int = 15
|
||||
@@ -34,6 +40,7 @@ class _OpenLLM_GenericInternalConfig(LLMConfig):
|
||||
temperature: float = 0.75
|
||||
max_new_tokens: int = 128
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec('P')
|
||||
@@ -42,37 +49,76 @@ LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
|
||||
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(str(it.tag), help='Bento') for it in bentoml.list() if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})]
|
||||
return [
|
||||
sc.CompletionItem(str(it.tag), help='Bento')
|
||||
for it in bentoml.list()
|
||||
if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})
|
||||
]
|
||||
|
||||
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(inflection.dasherize(it), help='Model') for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
|
||||
return [
|
||||
sc.CompletionItem(inflection.dasherize(it), help='Model')
|
||||
for it in openllm.CONFIG_MAPPING
|
||||
if it.startswith(incomplete)
|
||||
]
|
||||
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny) -> DictStrAny:
|
||||
|
||||
def parse_config_options(
|
||||
config: LLMConfig,
|
||||
server_timeout: int,
|
||||
workers_per_resource: float,
|
||||
device: t.Tuple[str, ...] | None,
|
||||
cors: bool,
|
||||
environ: DictStrAny,
|
||||
) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
|
||||
_bentoml_config_options_opts = [
|
||||
'tracing.sample_rate=1.0', f'api_server.traffic.timeout={server_timeout}', f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
|
||||
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}'
|
||||
'tracing.sample_rate=1.0',
|
||||
f'api_server.traffic.timeout={server_timeout}',
|
||||
f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
|
||||
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
|
||||
]
|
||||
if device:
|
||||
if len(device) > 1:
|
||||
_bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
|
||||
_bentoml_config_options_opts.extend(
|
||||
[
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}'
|
||||
for idx, dev in enumerate(device)
|
||||
]
|
||||
)
|
||||
else:
|
||||
_bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_opts.append(
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]'
|
||||
)
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])])
|
||||
_bentoml_config_options_opts.extend(
|
||||
['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"']
|
||||
)
|
||||
_bentoml_config_options_opts.extend(
|
||||
[
|
||||
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"'
|
||||
for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
|
||||
]
|
||||
)
|
||||
_bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts)
|
||||
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
|
||||
if DEBUG: logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
|
||||
if DEBUG:
|
||||
logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
|
||||
return environ
|
||||
|
||||
|
||||
_adapter_mapping_key = 'adapter_map'
|
||||
|
||||
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] | None) -> None:
|
||||
if not value: return None
|
||||
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
|
||||
if not value:
|
||||
return None
|
||||
if _adapter_mapping_key not in ctx.params:
|
||||
ctx.params[_adapter_mapping_key] = {}
|
||||
for v in value:
|
||||
adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
|
||||
# try to resolve the full path if users pass in relative,
|
||||
@@ -81,20 +127,28 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
if len(adapter_name) == 0: raise ClickException(f'Adapter name is required for {adapter_id}')
|
||||
if len(adapter_name) == 0:
|
||||
raise ClickException(f'Adapter name is required for {adapter_id}')
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0]
|
||||
return None
|
||||
|
||||
|
||||
def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
|
||||
composed = openllm.utils.compose(
|
||||
_OpenLLM_GenericInternalConfig().to_click_options, _http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'), model_version_option(factory=cog.optgroup),
|
||||
system_message_option(factory=cog.optgroup), prompt_template_file_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'), workers_per_resource_option(factory=cog.optgroup), cors_option(factory=cog.optgroup),
|
||||
backend_option(factory=cog.optgroup),
|
||||
cog.optgroup.group('LLM Optimization Options',
|
||||
help='''Optimization related options.
|
||||
_OpenLLM_GenericInternalConfig().to_click_options,
|
||||
_http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
system_message_option(factory=cog.optgroup),
|
||||
prompt_template_file_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
cors_option(factory=cog.optgroup),
|
||||
backend_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
'LLM Optimization Options',
|
||||
help="""Optimization related options.
|
||||
|
||||
OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
|
||||
@@ -102,16 +156,22 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
|
||||
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
'''), quantize_option(factory=cog.optgroup), serialisation_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--device',
|
||||
type=openllm.utils.dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help='Assign GPU devices (if available)',
|
||||
show_envvar=True),
|
||||
cog.optgroup.group('Fine-tuning related options',
|
||||
help='''\
|
||||
""",
|
||||
),
|
||||
quantize_option(factory=cog.optgroup),
|
||||
serialisation_option(factory=cog.optgroup),
|
||||
cog.optgroup.option(
|
||||
'--device',
|
||||
type=openllm.utils.dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help='Assign GPU devices (if available)',
|
||||
show_envvar=True,
|
||||
),
|
||||
cog.optgroup.group(
|
||||
'Fine-tuning related options',
|
||||
help="""\
|
||||
Note that the argument `--adapter-id` can accept the following format:
|
||||
|
||||
- `--adapter-id /path/to/adapter` (local adapter)
|
||||
@@ -125,46 +185,62 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
|
||||
$ openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora
|
||||
|
||||
```
|
||||
'''),
|
||||
cog.optgroup.option('--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter',
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]'), click.option('--return-process', is_flag=True, default=False, help='Internal use only.',
|
||||
hidden=True),
|
||||
""",
|
||||
),
|
||||
cog.optgroup.option(
|
||||
'--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter',
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
|
||||
),
|
||||
click.option('--return-process', is_flag=True, default=False, help='Internal use only.', hidden=True),
|
||||
)
|
||||
return composed(fn)
|
||||
|
||||
return wrapper
|
||||
|
||||
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
|
||||
if value is None: return value
|
||||
if not isinstance(value, tuple): ctx.fail(f'{param} only accept multiple values, not {type(value)} (value: {value})')
|
||||
|
||||
def parse_device_callback(
|
||||
ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
|
||||
) -> t.Tuple[str, ...] | None:
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, tuple):
|
||||
ctx.fail(f'{param} only accept multiple values, not {type(value)} (value: {value})')
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == 'all': return tuple(map(str, openllm.utils.available_devices()))
|
||||
if len(el) == 1 and el[0] == 'all':
|
||||
return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
|
||||
|
||||
# NOTE: A list of bentoml option that is not needed for parsing.
|
||||
# NOTE: User shouldn't set '--working-dir', as OpenLLM will setup this.
|
||||
# NOTE: production is also deprecated
|
||||
_IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}
|
||||
|
||||
|
||||
def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
|
||||
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`."""
|
||||
from bentoml_cli.cli import cli
|
||||
|
||||
command = 'serve' if not serve_grpc else 'serve-grpc'
|
||||
group = cog.optgroup.group(f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options",
|
||||
help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",
|
||||
)
|
||||
group = cog.optgroup.group(
|
||||
f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options",
|
||||
help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",
|
||||
)
|
||||
|
||||
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
serve_command = cli.commands[command]
|
||||
# The first variable is the argument bento
|
||||
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
|
||||
serve_options = [p for p in serve_command.params[1:-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
|
||||
serve_options = [
|
||||
p
|
||||
for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS]
|
||||
if p.name not in _IGNORED_OPTIONS
|
||||
]
|
||||
for options in reversed(serve_options):
|
||||
attrs = options.to_info_dict()
|
||||
# we don't need param_type_name, since it should all be options
|
||||
@@ -179,8 +255,10 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
_http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True)
|
||||
|
||||
|
||||
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
|
||||
"""General ``@click`` decorator with some sauce.
|
||||
|
||||
@@ -189,68 +267,114 @@ def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC |
|
||||
"""
|
||||
factory = attrs.pop('factory', click)
|
||||
factory_attr = attrs.pop('attr', 'option')
|
||||
if factory_attr != 'argument': attrs.setdefault('help', 'General option for OpenLLM CLI.')
|
||||
if factory_attr != 'argument':
|
||||
attrs.setdefault('help', 'General option for OpenLLM CLI.')
|
||||
|
||||
def decorator(f: FC | None) -> FC:
|
||||
callback = getattr(factory, factory_attr, None)
|
||||
if callback is None: raise ValueError(f'Factory {factory} has no attribute {factory_attr}.')
|
||||
if callback is None:
|
||||
raise ValueError(f'Factory {factory} has no attribute {factory_attr}.')
|
||||
return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs))
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
cli_option = functools.partial(_click_factory_type, attr='option')
|
||||
cli_argument = functools.partial(_click_factory_type, attr='argument')
|
||||
|
||||
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs)(f)
|
||||
return cli_option(
|
||||
'--cors/--no-cors',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar='OPENLLM_CORS',
|
||||
show_envvar=True,
|
||||
help='Enable CORS for the server.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--model-id', type=click.STRING, default=None, envvar='OPENLLM_MODEL_ID', show_envvar=True, help='Optional model_id name or path for (fine-tune) weight.', **attrs)(f)
|
||||
return cli_option(
|
||||
'--model-id',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_MODEL_ID',
|
||||
show_envvar=True,
|
||||
help='Optional model_id name or path for (fine-tune) weight.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f)
|
||||
return cli_option(
|
||||
'--model-version',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
help='Optional model version to save for this model. It will be inferred automatically from model-id.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--system-message',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_SYSTEM_MESSAGE',
|
||||
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
|
||||
**attrs)(f)
|
||||
return cli_option(
|
||||
'--system-message',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_SYSTEM_MESSAGE',
|
||||
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--prompt-template-file',
|
||||
type=click.File(),
|
||||
default=None,
|
||||
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
|
||||
**attrs)(f)
|
||||
return cli_option(
|
||||
'--prompt-template-file',
|
||||
type=click.File(),
|
||||
default=None,
|
||||
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
|
||||
# XXX: remove the check for __args__ once we have ggml and mlc supports
|
||||
return cli_option('--backend',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
|
||||
default=None,
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='The implementation for saving this LLM.',
|
||||
**attrs)(f)
|
||||
return cli_option(
|
||||
'--backend',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
|
||||
default=None,
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='The implementation for saving this LLM.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
|
||||
return cli_argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]),
|
||||
required=required,
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=click.Choice(get_literal_args(LiteralQuantise)),
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
help='''Dynamic quantization for running this LLM.
|
||||
return cli_option(
|
||||
'--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=click.Choice(get_literal_args(LiteralQuantise)),
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
help="""Dynamic quantization for running this LLM.
|
||||
|
||||
The following quantization strategies are supported:
|
||||
|
||||
@@ -261,18 +385,29 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
|
||||
|
||||
> [!NOTE] that the model can also be served with quantized weights.
|
||||
''' + ('''
|
||||
> [!NOTE] that this will set the mode for serving within deployment.''' if build else '') + '''
|
||||
> [!NOTE] that quantization are currently only available in *PyTorch* models.''',
|
||||
**attrs)(f)
|
||||
"""
|
||||
+ (
|
||||
"""
|
||||
> [!NOTE] that this will set the mode for serving within deployment."""
|
||||
if build
|
||||
else ''
|
||||
)
|
||||
+ """
|
||||
> [!NOTE] that quantization are currently only available in *PyTorch* models.""",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--workers-per-resource',
|
||||
default=None,
|
||||
callback=workers_per_resource_callback,
|
||||
type=str,
|
||||
required=False,
|
||||
help='''Number of workers per resource assigned.
|
||||
|
||||
def workers_per_resource_option(
|
||||
f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any
|
||||
) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--workers-per-resource',
|
||||
default=None,
|
||||
callback=workers_per_resource_callback,
|
||||
type=str,
|
||||
required=False,
|
||||
help="""Number of workers per resource assigned.
|
||||
|
||||
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -282,22 +417,30 @@ def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool =
|
||||
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
|
||||
|
||||
- ``conserved``: This will determine the number of available GPU resources, and only assign one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
|
||||
''' + ("""\n
|
||||
"""
|
||||
+ (
|
||||
"""\n
|
||||
> [!NOTE] The workers value passed into 'build' will determine how the LLM can
|
||||
> be provisioned in Kubernetes as well as in standalone container. This will
|
||||
> ensure it has the same effect with 'openllm start --api-workers ...'""" if build else ''),
|
||||
**attrs)(f)
|
||||
> ensure it has the same effect with 'openllm start --api-workers ...'"""
|
||||
if build
|
||||
else ''
|
||||
),
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--serialisation',
|
||||
'--serialization',
|
||||
'serialisation',
|
||||
type=click.Choice(get_literal_args(LiteralSerialisation)),
|
||||
default=None,
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help='''Serialisation format for save/load LLM.
|
||||
return cli_option(
|
||||
'--serialisation',
|
||||
'--serialization',
|
||||
'serialisation',
|
||||
type=click.Choice(get_literal_args(LiteralSerialisation)),
|
||||
default=None,
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help="""Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
@@ -306,37 +449,51 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
> [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed.
|
||||
|
||||
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
|
||||
''',
|
||||
**attrs)(f)
|
||||
""",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--container-registry',
|
||||
'container_registry',
|
||||
type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)),
|
||||
default='ecr',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_CONTAINER_REGISTRY',
|
||||
callback=container_registry_callback,
|
||||
help='The default container registry to get the base image for building BentoLLM. Currently, it supports ecr, ghcr, docker',
|
||||
**attrs)(f)
|
||||
return cli_option(
|
||||
'--container-registry',
|
||||
'container_registry',
|
||||
type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)),
|
||||
default='ecr',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_CONTAINER_REGISTRY',
|
||||
callback=container_registry_callback,
|
||||
help='The default container registry to get the base image for building BentoLLM. Currently, it supports ecr, ghcr, docker',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
_wpr_strategies = {'round_robin', 'conserved'}
|
||||
|
||||
|
||||
def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
if value is None:
|
||||
return value
|
||||
value = inflection.underscore(value)
|
||||
if value in _wpr_strategies: return value
|
||||
if value in _wpr_strategies:
|
||||
return value
|
||||
else:
|
||||
try:
|
||||
float(value) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
|
||||
raise click.BadParameter(
|
||||
f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.",
|
||||
ctx,
|
||||
param,
|
||||
) from None
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def container_registry_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
if value is None:
|
||||
return value
|
||||
if value not in openllm.bundle.supported_registries:
|
||||
raise click.BadParameter(f'Value must be one of {openllm.bundle.supported_registries}', ctx, param)
|
||||
return value
|
||||
|
||||
@@ -22,6 +22,7 @@ from openllm_core.utils import codegen
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from openllm_core._configuration import LLMConfig
|
||||
@@ -33,20 +34,23 @@ if t.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _start(model_id: str,
|
||||
timeout: int = 30,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
cors: bool = False,
|
||||
_serve_grpc: bool = False,
|
||||
__test__: bool = False,
|
||||
**_: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
|
||||
def _start(
|
||||
model_id: str,
|
||||
timeout: int = 30,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
cors: bool = False,
|
||||
_serve_grpc: bool = False,
|
||||
__test__: bool = False,
|
||||
**_: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
"""Python API to start a LLM server. These provides one-to-one mapping to CLI arguments.
|
||||
|
||||
For all additional arguments, pass it as string to ``additional_args``. For example, if you want to
|
||||
@@ -85,45 +89,68 @@ def _start(model_id: str,
|
||||
"""
|
||||
from .entrypoint import start_command
|
||||
from .entrypoint import start_grpc_command
|
||||
os.environ['OPENLLM_BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
|
||||
|
||||
os.environ['OPENLLM_BACKEND'] = openllm_core.utils.first_not_none(
|
||||
backend, default='vllm' if is_vllm_available() else 'pt'
|
||||
)
|
||||
|
||||
args: list[str] = [model_id]
|
||||
if system_message: args.extend(['--system-message', system_message])
|
||||
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if timeout: args.extend(['--server-timeout', str(timeout)])
|
||||
if system_message:
|
||||
args.extend(['--system-message', system_message])
|
||||
if prompt_template_file:
|
||||
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if timeout:
|
||||
args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'): args.extend(['--device', ','.join(device)])
|
||||
if quantize: args.extend(['--quantize', str(quantize)])
|
||||
if cors: args.append('--cors')
|
||||
args.extend(
|
||||
[
|
||||
'--workers-per-resource',
|
||||
str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource,
|
||||
]
|
||||
)
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'):
|
||||
args.extend(['--device', ','.join(device)])
|
||||
if quantize:
|
||||
args.extend(['--quantize', str(quantize)])
|
||||
if cors:
|
||||
args.append('--cors')
|
||||
if adapter_map:
|
||||
args.extend(list(itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])))
|
||||
if additional_args: args.extend(additional_args)
|
||||
if __test__: args.append('--return-process')
|
||||
args.extend(
|
||||
list(
|
||||
itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])
|
||||
)
|
||||
)
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if __test__:
|
||||
args.append('--return-process')
|
||||
|
||||
cmd = start_command if not _serve_grpc else start_grpc_command
|
||||
return cmd.main(args=args, standalone_mode=False)
|
||||
|
||||
|
||||
@inject
|
||||
def _build(model_id: str,
|
||||
model_version: str | None = None,
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
dockerfile_template: str | None = None,
|
||||
overwrite: bool = False,
|
||||
container_registry: LiteralContainerRegistry | None = None,
|
||||
container_version_strategy: LiteralContainerVersionStrategy | None = None,
|
||||
push: bool = False,
|
||||
force_push: bool = False,
|
||||
containerize: bool = False,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> bentoml.Bento:
|
||||
def _build(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
dockerfile_template: str | None = None,
|
||||
overwrite: bool = False,
|
||||
container_registry: LiteralContainerRegistry | None = None,
|
||||
container_version_strategy: LiteralContainerVersionStrategy | None = None,
|
||||
push: bool = False,
|
||||
force_push: bool = False,
|
||||
containerize: bool = False,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
) -> bentoml.Bento:
|
||||
"""Package a LLM into a BentoLLM.
|
||||
|
||||
The LLM will be built into a BentoService with the following structure:
|
||||
@@ -161,49 +188,83 @@ def _build(model_id: str,
|
||||
``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from ..serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
args: list[str] = [
|
||||
sys.executable, '-m', 'openllm', 'build', model_id, '--machine', '--serialisation',
|
||||
t.cast(LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'))
|
||||
sys.executable,
|
||||
'-m',
|
||||
'openllm',
|
||||
'build',
|
||||
model_id,
|
||||
'--machine',
|
||||
'--serialisation',
|
||||
t.cast(
|
||||
LiteralSerialisation,
|
||||
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
|
||||
),
|
||||
]
|
||||
if quantize: args.extend(['--quantize', quantize])
|
||||
if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
if push: args.extend(['--push'])
|
||||
if containerize: args.extend(['--containerize'])
|
||||
if build_ctx: args.extend(['--build-ctx', build_ctx])
|
||||
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite: args.append('--overwrite')
|
||||
if system_message: args.extend(['--system-message', system_message])
|
||||
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version: args.extend(['--model-version', model_version])
|
||||
if bento_version: args.extend(['--bento-version', bento_version])
|
||||
if dockerfile_template: args.extend(['--dockerfile-template', dockerfile_template])
|
||||
if container_registry is None: container_registry = 'ecr'
|
||||
if container_version_strategy is None: container_version_strategy = 'release'
|
||||
if quantize:
|
||||
args.extend(['--quantize', quantize])
|
||||
if containerize and push:
|
||||
raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
if push:
|
||||
args.extend(['--push'])
|
||||
if containerize:
|
||||
args.extend(['--containerize'])
|
||||
if build_ctx:
|
||||
args.extend(['--build-ctx', build_ctx])
|
||||
if enable_features:
|
||||
args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite:
|
||||
args.append('--overwrite')
|
||||
if system_message:
|
||||
args.extend(['--system-message', system_message])
|
||||
if prompt_template_file:
|
||||
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if adapter_map:
|
||||
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version:
|
||||
args.extend(['--model-version', model_version])
|
||||
if bento_version:
|
||||
args.extend(['--bento-version', bento_version])
|
||||
if dockerfile_template:
|
||||
args.extend(['--dockerfile-template', dockerfile_template])
|
||||
if container_registry is None:
|
||||
container_registry = 'ecr'
|
||||
if container_version_strategy is None:
|
||||
container_version_strategy = 'release'
|
||||
args.extend(['--container-registry', container_registry, '--container-version-strategy', container_version_strategy])
|
||||
if additional_args: args.extend(additional_args)
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
|
||||
try:
|
||||
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error("Exception caught while building Bento for '%s'", model_id, exc_info=e)
|
||||
if e.stderr: raise OpenLLMException(e.stderr.decode('utf-8')) from None
|
||||
if e.stderr:
|
||||
raise OpenLLMException(e.stderr.decode('utf-8')) from None
|
||||
raise OpenLLMException(str(e)) from None
|
||||
matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip())
|
||||
if matched is None:
|
||||
raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
|
||||
raise ValueError(
|
||||
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
)
|
||||
try:
|
||||
result = orjson.loads(matched.group(1))
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.") from e
|
||||
raise ValueError(
|
||||
f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
) from e
|
||||
return bentoml.get(result['tag'], _bento_store=bento_store)
|
||||
|
||||
def _import_model(model_id: str,
|
||||
model_version: str | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: t.Sequence[str] | None = None) -> dict[str, t.Any]:
|
||||
|
||||
def _import_model(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: t.Sequence[str] | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""Import a LLM into local store.
|
||||
|
||||
> [!NOTE]
|
||||
@@ -232,19 +293,32 @@ def _import_model(model_id: str,
|
||||
``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from .entrypoint import import_command
|
||||
|
||||
args = [model_id, '--quiet']
|
||||
if backend is not None: args.extend(['--backend', backend])
|
||||
if model_version is not None: args.extend(['--model-version', str(model_version)])
|
||||
if quantize is not None: args.extend(['--quantize', quantize])
|
||||
if serialisation is not None: args.extend(['--serialisation', serialisation])
|
||||
if additional_args is not None: args.extend(additional_args)
|
||||
if backend is not None:
|
||||
args.extend(['--backend', backend])
|
||||
if model_version is not None:
|
||||
args.extend(['--model-version', str(model_version)])
|
||||
if quantize is not None:
|
||||
args.extend(['--quantize', quantize])
|
||||
if serialisation is not None:
|
||||
args.extend(['--serialisation', serialisation])
|
||||
if additional_args is not None:
|
||||
args.extend(additional_args)
|
||||
return import_command.main(args=args, standalone_mode=False)
|
||||
|
||||
|
||||
def _list_models() -> dict[str, t.Any]:
|
||||
"""List all available models within the local store."""
|
||||
from .entrypoint import models_command
|
||||
|
||||
return models_command.main(args=['--show-available', '--quiet'], standalone_mode=False)
|
||||
|
||||
|
||||
start, start_grpc = codegen.gen_sdk(_start, _serve_grpc=False), codegen.gen_sdk(_start, _serve_grpc=True)
|
||||
build, import_model, list_models = codegen.gen_sdk(_build), codegen.gen_sdk(_import_model), codegen.gen_sdk(_list_models)
|
||||
build, import_model, list_models = (
|
||||
codegen.gen_sdk(_build),
|
||||
codegen.gen_sdk(_import_model),
|
||||
codegen.gen_sdk(_list_models),
|
||||
)
|
||||
__all__ = ['start', 'start_grpc', 'build', 'import_model', 'list_models']
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,13 +10,16 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import container_registry_option
|
||||
from openllm.cli._factory import machine_option
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry
|
||||
from openllm_core._typing_compat import LiteralContainerVersionStrategy
|
||||
|
||||
@click.command('build_base_container',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help='''Base image builder for BentoLLM.
|
||||
|
||||
@click.command(
|
||||
'build_base_container',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help="""Base image builder for BentoLLM.
|
||||
|
||||
By default, the base image will include custom kernels (PagedAttention via vllm, FlashAttention-v2, etc.) built with CUDA 11.8, Python 3.9 on Ubuntu22.04.
|
||||
Optionally, this can also be pushed directly to remote registry. Currently support ``docker.io``, ``ghcr.io`` and ``quay.io``.
|
||||
@@ -26,12 +29,24 @@ if t.TYPE_CHECKING:
|
||||
This command is only useful for debugging and for building custom base image for extending BentoML with custom base images and custom kernels.
|
||||
|
||||
Note that we already release images on our CI to ECR and GHCR, so you don't need to build it yourself.
|
||||
''')
|
||||
""",
|
||||
)
|
||||
@container_registry_option
|
||||
@click.option('--version-strategy', type=click.Choice(['release', 'latest', 'nightly']), default='nightly', help='Version strategy to use for tagging the image.')
|
||||
@click.option(
|
||||
'--version-strategy',
|
||||
type=click.Choice(['release', 'latest', 'nightly']),
|
||||
default='nightly',
|
||||
help='Version strategy to use for tagging the image.',
|
||||
)
|
||||
@click.option('--push/--no-push', help='Whether to push to remote repository', is_flag=True, default=False)
|
||||
@machine_option
|
||||
def cli(container_registry: tuple[LiteralContainerRegistry, ...] | None, version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]:
|
||||
def cli(
|
||||
container_registry: tuple[LiteralContainerRegistry, ...] | None,
|
||||
version_strategy: LiteralContainerVersionStrategy,
|
||||
push: bool,
|
||||
machine: bool,
|
||||
) -> dict[str, str]:
|
||||
mapping = openllm.bundle.build_container(container_registry, version_strategy, push, machine)
|
||||
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
if machine:
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return mapping
|
||||
|
||||
@@ -16,24 +16,33 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import bento_complete_envvar
|
||||
from openllm.cli._factory import machine_option
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
|
||||
@click.command('dive_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
|
||||
def cli(
|
||||
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> str | None:
|
||||
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
|
||||
if 'bundler' not in bentomodel.info.labels or bentomodel.info.labels['bundler'] != 'openllm.bundle':
|
||||
ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
|
||||
if machine: return bentomodel.path
|
||||
ctx.fail(
|
||||
f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness."
|
||||
)
|
||||
if machine:
|
||||
return bentomodel.path
|
||||
# copy and paste this into a new shell
|
||||
if psutil.WINDOWS: subprocess.check_call([shutil.which('dir') or 'dir'], cwd=bentomodel.path)
|
||||
else: subprocess.check_call([shutil.which('ls') or 'ls', '-Rrthla'], cwd=bentomodel.path)
|
||||
if psutil.WINDOWS:
|
||||
subprocess.check_call([shutil.which('dir') or 'dir'], cwd=bentomodel.path)
|
||||
else:
|
||||
subprocess.check_call([shutil.which('ls') or 'ls', '-Rrthla'], cwd=bentomodel.path)
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -16,10 +16,14 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import bento_complete_envvar
|
||||
from openllm_core.utils import converter
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.')
|
||||
|
||||
@click.command(
|
||||
'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.'
|
||||
)
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@click.pass_context
|
||||
@inject
|
||||
@@ -41,6 +45,13 @@ def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[Bento
|
||||
# for the reconstruction of the Dockerfile.
|
||||
if 'dockerfile_template' in docker_attrs and docker_attrs['dockerfile_template'] is not None:
|
||||
docker_attrs['dockerfile_template'] = 'env/docker/Dockerfile.template'
|
||||
doc = generate_containerfile(docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True)
|
||||
doc = generate_containerfile(
|
||||
docker=DockerOptions(**docker_attrs),
|
||||
build_ctx=bentomodel.path,
|
||||
conda=options.conda,
|
||||
bento_fs=bentomodel._fs,
|
||||
enable_buildkit=True,
|
||||
add_header=True,
|
||||
)
|
||||
termui.echo(doc, fg='white')
|
||||
return bentomodel.path
|
||||
|
||||
@@ -16,20 +16,30 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import model_complete_envvar
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), shell_complete=model_complete_envvar)
|
||||
@click.argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
|
||||
shell_complete=model_complete_envvar,
|
||||
)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option('--format', type=click.STRING, default=None)
|
||||
@click.option('--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]')
|
||||
@click.option(
|
||||
'--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]',
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
|
||||
def cli(
|
||||
ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any
|
||||
) -> str | None:
|
||||
"""Get the default prompt used by OpenLLM."""
|
||||
module = getattr(openllm_core.config, f'configuration_{model_name}')
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
@@ -42,11 +52,18 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None,
|
||||
if format is None:
|
||||
if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None:
|
||||
raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.')
|
||||
raise click.BadOptionUsage('format', f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
|
||||
raise click.BadOptionUsage(
|
||||
'format',
|
||||
f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})",
|
||||
)
|
||||
if prompt_mapping is None:
|
||||
raise click.BadArgumentUsage(f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.') from None
|
||||
raise click.BadArgumentUsage(
|
||||
f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.'
|
||||
) from None
|
||||
if format not in prompt_mapping:
|
||||
raise click.BadOptionUsage('format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})')
|
||||
raise click.BadOptionUsage(
|
||||
'format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})'
|
||||
)
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
@@ -55,7 +72,9 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None,
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
except RuntimeError as err:
|
||||
logger.debug('Exception caught while formatting prompt: %s', err)
|
||||
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(prompt, prompt_template=_prompt_template)[0]
|
||||
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(
|
||||
prompt, prompt_template=_prompt_template
|
||||
)[0]
|
||||
termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -10,20 +10,25 @@ import openllm
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
from openllm.cli import termui
|
||||
|
||||
|
||||
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context) -> None:
|
||||
"""List available bentos built by OpenLLM."""
|
||||
mapping = {
|
||||
k: [{
|
||||
'tag': str(b.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
'models': [{
|
||||
'tag': str(m.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(m.path))
|
||||
} for m in (bentoml.models.get(_.tag) for _ in b.info.models)]
|
||||
} for b in tuple(i for i in bentoml.list() if all(
|
||||
k in i.info.labels for k in {'start_name', 'bundler'})) if b.info.labels['start_name'] == k] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
k: [
|
||||
{
|
||||
'tag': str(b.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
'models': [
|
||||
{'tag': str(m.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(m.path))}
|
||||
for m in (bentoml.models.get(_.tag) for _ in b.info.models)
|
||||
],
|
||||
}
|
||||
for b in tuple(i for i in bentoml.list() if all(k in i.info.labels for k in {'start_name', 'bundler'}))
|
||||
if b.info.labels['start_name'] == k
|
||||
]
|
||||
for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
}
|
||||
mapping = {k: v for k, v in mapping.items() if v}
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
|
||||
@@ -13,21 +13,40 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import model_complete_envvar
|
||||
from openllm.cli._factory import model_name_argument
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
|
||||
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False, shell_complete=model_complete_envvar)
|
||||
def cli(model_name: str | None) -> DictStrAny:
|
||||
"""This is equivalent to openllm models --show-available less the nice table."""
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {
|
||||
k: [i for i in bentoml.models.list() if 'framework' in i.info.labels and i.info.labels['framework'] == 'openllm' and 'model_name' in i.info.labels and i.info.labels['model_name'] == k]
|
||||
for k in models
|
||||
k: [
|
||||
i
|
||||
for i in bentoml.models.list()
|
||||
if 'framework' in i.info.labels
|
||||
and i.info.labels['framework'] == 'openllm'
|
||||
and 'model_name' in i.info.labels
|
||||
and i.info.labels['model_name'] == k
|
||||
]
|
||||
for k in models
|
||||
}
|
||||
if model_name is not None:
|
||||
ids_in_local_store = {k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)] for k, v in ids_in_local_store.items()}
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
i
|
||||
for i in v
|
||||
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
|
||||
]
|
||||
for k, v in ids_in_local_store.items()
|
||||
}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
local_models = {k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()}
|
||||
local_models = {
|
||||
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
|
||||
for k, val in ids_in_local_store.items()
|
||||
}
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return local_models
|
||||
|
||||
@@ -19,11 +19,13 @@ from openllm_core.utils import is_jupyter_available
|
||||
from openllm_core.utils import is_jupytext_available
|
||||
from openllm_core.utils import is_notebook_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_notebook_metadata() -> DictStrAny:
|
||||
with open(os.path.join(os.path.dirname(playground.__file__), '_meta.yml'), 'r') as f:
|
||||
content = yaml.safe_load(f)
|
||||
@@ -31,9 +33,17 @@ def load_notebook_metadata() -> DictStrAny:
|
||||
raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
|
||||
return content
|
||||
|
||||
|
||||
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('output-dir', default=None, required=False)
|
||||
@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server')
|
||||
@click.option(
|
||||
'--port',
|
||||
envvar='JUPYTER_PORT',
|
||||
show_envvar=True,
|
||||
show_default=True,
|
||||
default=8888,
|
||||
help='Default port for Jupyter server',
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
"""OpenLLM Playground.
|
||||
@@ -54,7 +64,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
|
||||
"""
|
||||
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
|
||||
raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'")
|
||||
raise RuntimeError(
|
||||
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
|
||||
)
|
||||
metadata = load_notebook_metadata()
|
||||
_temp_dir = False
|
||||
if output_dir is None:
|
||||
@@ -66,20 +78,37 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
|
||||
for module in pkgutil.iter_modules(playground.__path__):
|
||||
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
|
||||
logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module')
|
||||
logger.debug(
|
||||
'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module'
|
||||
)
|
||||
continue
|
||||
if not isinstance(module.module_finder, importlib.machinery.FileFinder):
|
||||
continue
|
||||
if not isinstance(module.module_finder, importlib.machinery.FileFinder): continue
|
||||
termui.echo('Generating notebook for: ' + module.name, fg='magenta')
|
||||
markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]['description'])
|
||||
f = jupytext.read(os.path.join(module.module_finder.path, module.name + '.py'))
|
||||
f.cells.insert(0, markdown_cell)
|
||||
jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook')
|
||||
try:
|
||||
subprocess.check_output([sys.executable, '-m', 'jupyter', 'notebook', '--notebook-dir', output_dir, '--port', str(port), '--no-browser', '--debug'])
|
||||
subprocess.check_output(
|
||||
[
|
||||
sys.executable,
|
||||
'-m',
|
||||
'jupyter',
|
||||
'notebook',
|
||||
'--notebook-dir',
|
||||
output_dir,
|
||||
'--port',
|
||||
str(port),
|
||||
'--no-browser',
|
||||
'--debug',
|
||||
]
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
termui.echo(e.output, fg='red')
|
||||
raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None
|
||||
except KeyboardInterrupt:
|
||||
termui.echo('\nShutting down Jupyter server...', fg='yellow')
|
||||
if _temp_dir: termui.echo('Note: You can access the generated notebooks in: ' + output_dir, fg='blue')
|
||||
if _temp_dir:
|
||||
termui.echo('Note: You can access the generated notebooks in: ' + output_dir, fg='blue')
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -13,8 +13,10 @@ from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import get_quiet_mode
|
||||
|
||||
|
||||
logger = logging.getLogger('openllm')
|
||||
|
||||
|
||||
class Level(enum.IntEnum):
|
||||
NOTSET = logging.DEBUG
|
||||
DEBUG = logging.DEBUG
|
||||
@@ -25,19 +27,31 @@ class Level(enum.IntEnum):
|
||||
|
||||
@property
|
||||
def color(self) -> str | None:
|
||||
return {Level.NOTSET: None, Level.DEBUG: 'cyan', Level.INFO: 'green', Level.WARNING: 'yellow', Level.ERROR: 'red', Level.CRITICAL: 'red'}[self]
|
||||
return {
|
||||
Level.NOTSET: None,
|
||||
Level.DEBUG: 'cyan',
|
||||
Level.INFO: 'green',
|
||||
Level.WARNING: 'yellow',
|
||||
Level.ERROR: 'red',
|
||||
Level.CRITICAL: 'red',
|
||||
}[self]
|
||||
|
||||
|
||||
class JsonLog(t.TypedDict):
|
||||
log_level: Level
|
||||
content: str
|
||||
|
||||
|
||||
def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
|
||||
def caller(text: str) -> None:
|
||||
if get_debug_mode(): logger.log(level.value, text)
|
||||
else: echo(JsonLog(log_level=level, content=content), json=True, fg=fg)
|
||||
if get_debug_mode():
|
||||
logger.log(level.value, text)
|
||||
else:
|
||||
echo(JsonLog(log_level=level, content=content), json=True, fg=fg)
|
||||
|
||||
caller(orjson.dumps(JsonLog(log_level=level, content=content)).decode())
|
||||
|
||||
|
||||
warning = functools.partial(log, level=Level.WARNING)
|
||||
error = functools.partial(log, level=Level.ERROR)
|
||||
critical = functools.partial(log, level=Level.CRITICAL)
|
||||
@@ -45,8 +59,10 @@ debug = functools.partial(log, level=Level.DEBUG)
|
||||
info = functools.partial(log, level=Level.INFO)
|
||||
notset = functools.partial(log, level=Level.NOTSET)
|
||||
|
||||
|
||||
def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
|
||||
if json and not isinstance(text, dict): raise TypeError('text must be a dict')
|
||||
if json and not isinstance(text, dict):
|
||||
raise TypeError('text must be a dict')
|
||||
if json:
|
||||
if 'content' in text and 'log_level' in text:
|
||||
content = t.cast(DictStrAny, text)['content']
|
||||
@@ -58,8 +74,14 @@ def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: boo
|
||||
content = t.cast(str, text)
|
||||
attrs['fg'] = fg if not get_debug_mode() else None
|
||||
|
||||
if not get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(content, **attrs)
|
||||
if not get_quiet_mode():
|
||||
t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(content, **attrs)
|
||||
|
||||
|
||||
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
|
||||
CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore}
|
||||
CONTEXT_SETTINGS: DictStrAny = {
|
||||
'help_option_names': ['-h', '--help'],
|
||||
'max_content_width': COLUMNS,
|
||||
'token_normalize_func': inflection.underscore,
|
||||
}
|
||||
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS', 'log', 'warning', 'error', 'critical', 'debug', 'info', 'Level']
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
'''OpenLLM Python client.
|
||||
"""OpenLLM Python client.
|
||||
|
||||
```python
|
||||
client = openllm.client.HTTPClient("http://localhost:8080")
|
||||
client.query("What is the difference between gather and scatter?")
|
||||
```
|
||||
'''
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_client
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_client import AsyncHTTPClient as AsyncHTTPClient
|
||||
from openllm_client import HTTPClient as HTTPClient
|
||||
# from openllm_client import AsyncGrpcClient as AsyncGrpcClient
|
||||
# from openllm_client import GrpcClient as GrpcClient
|
||||
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
return sorted(dir(openllm_client))
|
||||
|
||||
|
||||
def __getattr__(it: str) -> t.Any:
|
||||
return getattr(openllm_client, it)
|
||||
|
||||
@@ -6,6 +6,7 @@ Each module should implement the following API:
|
||||
|
||||
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
@@ -14,16 +15,21 @@ from openllm_core.utils import LazyModule
|
||||
from . import hf as hf
|
||||
from . import openai as openai
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'openai': [], 'hf': []}
|
||||
|
||||
|
||||
def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
|
||||
return openai.mount_to_svc(hf.mount_to_svc(svc, llm), llm)
|
||||
|
||||
__lazy = LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints})
|
||||
|
||||
__lazy = LazyModule(
|
||||
__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints}
|
||||
)
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
@@ -15,6 +15,7 @@ from starlette.schemas import SchemaGenerator
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core.utils import first_not_none
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import AttrsInstance
|
||||
|
||||
@@ -23,7 +24,7 @@ if t.TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
|
||||
# NOTE: OpenAI schema
|
||||
LIST_MODEL_SCHEMA = '''\
|
||||
LIST_MODEL_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -53,8 +54,8 @@ responses:
|
||||
owned_by: 'na'
|
||||
schema:
|
||||
$ref: '#/components/schemas/ModelList'
|
||||
'''
|
||||
CHAT_COMPLETION_SCHEMA = '''\
|
||||
"""
|
||||
CHAT_COMPLETION_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -191,8 +192,8 @@ responses:
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
'''
|
||||
COMPLETION_SCHEMA = '''\
|
||||
"""
|
||||
COMPLETION_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -344,8 +345,8 @@ responses:
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
'''
|
||||
HF_AGENT_SCHEMA = '''\
|
||||
"""
|
||||
HF_AGENT_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -389,8 +390,8 @@ responses:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
'''
|
||||
HF_ADAPTERS_SCHEMA = '''\
|
||||
"""
|
||||
HF_ADAPTERS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -420,16 +421,19 @@ responses:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def add_schema_definitions(append_str: str) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
|
||||
def docstring_decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
if func.__doc__ is None: func.__doc__ = ''
|
||||
if func.__doc__ is None:
|
||||
func.__doc__ = ''
|
||||
func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()
|
||||
return func
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
|
||||
endpoints_info: list[EndpointInfo] = []
|
||||
@@ -437,20 +441,29 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
if isinstance(route, (Mount, Host)):
|
||||
routes = route.routes or []
|
||||
path = self._remove_converter(route.path) if isinstance(route, Mount) else ''
|
||||
sub_endpoints = [EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func) for sub_endpoint in self.get_endpoints(routes)]
|
||||
sub_endpoints = [
|
||||
EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func)
|
||||
for sub_endpoint in self.get_endpoints(routes)
|
||||
]
|
||||
endpoints_info.extend(sub_endpoints)
|
||||
elif not isinstance(route, Route) or not route.include_in_schema:
|
||||
continue
|
||||
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint) or isinstance(route.endpoint, functools.partial):
|
||||
elif (
|
||||
inspect.isfunction(route.endpoint)
|
||||
or inspect.ismethod(route.endpoint)
|
||||
or isinstance(route.endpoint, functools.partial)
|
||||
):
|
||||
endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint
|
||||
path = self._remove_converter(route.path)
|
||||
for method in route.methods or ['GET']:
|
||||
if method == 'HEAD': continue
|
||||
if method == 'HEAD':
|
||||
continue
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), endpoint))
|
||||
else:
|
||||
path = self._remove_converter(route.path)
|
||||
for method in ['get', 'post', 'put', 'patch', 'delete', 'options']:
|
||||
if not hasattr(route.endpoint, method): continue
|
||||
if not hasattr(route.endpoint, method):
|
||||
continue
|
||||
func = getattr(route.endpoint, method)
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), func))
|
||||
return endpoints_info
|
||||
@@ -459,37 +472,52 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault('paths', {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
if mount_path: mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
|
||||
if mount_path:
|
||||
mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
|
||||
|
||||
for endpoint in endpoints_info:
|
||||
parsed = self.parse_docstring(endpoint.func)
|
||||
if not parsed: continue
|
||||
if not parsed:
|
||||
continue
|
||||
|
||||
path = endpoint.path if mount_path is None else mount_path + endpoint.path
|
||||
if path not in schema['paths']: schema['paths'][path] = {}
|
||||
if path not in schema['paths']:
|
||||
schema['paths'][path] = {}
|
||||
schema['paths'][path][endpoint.http_method] = parsed
|
||||
|
||||
return schema
|
||||
|
||||
def get_generator(title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None) -> OpenLLMSchemaGenerator:
|
||||
|
||||
def get_generator(
|
||||
title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None
|
||||
) -> OpenLLMSchemaGenerator:
|
||||
base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
|
||||
if components: base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
|
||||
if tags is not None and tags: base_schema['tags'] = tags
|
||||
if components:
|
||||
base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
|
||||
if tags is not None and tags:
|
||||
base_schema['tags'] = tags
|
||||
return OpenLLMSchemaGenerator(base_schema)
|
||||
|
||||
|
||||
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
|
||||
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
schema['description'] = first_not_none(getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}')
|
||||
schema['description'] = first_not_none(
|
||||
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
|
||||
)
|
||||
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var]
|
||||
attr_type = field.type
|
||||
origin_type = t.get_origin(attr_type)
|
||||
args_type = t.get_args(attr_type)
|
||||
|
||||
# Map Python types to OpenAPI schema types
|
||||
if attr_type == str: schema_type = 'string'
|
||||
elif attr_type == int: schema_type = 'integer'
|
||||
elif attr_type == float: schema_type = 'number'
|
||||
elif attr_type == bool: schema_type = 'boolean'
|
||||
if attr_type == str:
|
||||
schema_type = 'string'
|
||||
elif attr_type == int:
|
||||
schema_type = 'integer'
|
||||
elif attr_type == float:
|
||||
schema_type = 'number'
|
||||
elif attr_type == bool:
|
||||
schema_type = 'boolean'
|
||||
elif origin_type is list or origin_type is tuple:
|
||||
schema_type = 'array'
|
||||
elif origin_type is dict:
|
||||
@@ -504,14 +532,18 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
|
||||
else:
|
||||
schema_type = 'string'
|
||||
|
||||
if 'prop_schema' not in locals(): prop_schema = {'type': schema_type}
|
||||
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory): prop_schema['default'] = field.default # type: ignore[arg-type]
|
||||
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)): schema['required'].append(field.name)
|
||||
if 'prop_schema' not in locals():
|
||||
prop_schema = {'type': schema_type}
|
||||
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
|
||||
prop_schema['default'] = field.default # type: ignore[arg-type]
|
||||
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
|
||||
schema['required'].append(field.name)
|
||||
schema['properties'][field.name] = prop_schema
|
||||
locals().pop('prop_schema', None)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
class MKSchema:
|
||||
def __init__(self, it: dict[str, t.Any]) -> None:
|
||||
self.it = it
|
||||
@@ -519,19 +551,30 @@ class MKSchema:
|
||||
def asdict(self) -> dict[str, t.Any]:
|
||||
return self.it
|
||||
|
||||
def append_schemas(svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend') -> bentoml.Service:
|
||||
|
||||
def append_schemas(
|
||||
svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend'
|
||||
) -> bentoml.Service:
|
||||
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
|
||||
from bentoml._internal.service.openapi.specification import OpenAPISpecification
|
||||
|
||||
svc_schema: t.Any = svc.openapi_spec
|
||||
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)): svc_schema = svc_schema.asdict()
|
||||
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
|
||||
svc_schema = svc_schema.asdict()
|
||||
if 'tags' in generated_schema:
|
||||
if tags_order == 'prepend': svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
|
||||
elif tags_order == 'append': svc_schema['tags'].extend(generated_schema['tags'])
|
||||
else: raise ValueError(f'Invalid tags_order: {tags_order}')
|
||||
if 'components' in generated_schema: svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
|
||||
if tags_order == 'prepend':
|
||||
svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
|
||||
elif tags_order == 'append':
|
||||
svc_schema['tags'].extend(generated_schema['tags'])
|
||||
else:
|
||||
raise ValueError(f'Invalid tags_order: {tags_order}')
|
||||
if 'components' in generated_schema:
|
||||
svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
|
||||
svc_schema['paths'].update(generated_schema['paths'])
|
||||
|
||||
from bentoml._internal.service import openapi # HACK: mk this attribute until we have a better way to add starlette schemas.
|
||||
from bentoml._internal.service import (
|
||||
openapi, # HACK: mk this attribute until we have a better way to add starlette schemas.
|
||||
)
|
||||
|
||||
# yapf: disable
|
||||
def mk_generate_spec(svc:bentoml.Service,openapi_version:str=OPENAPI_VERSION)->MKSchema:return MKSchema(svc_schema)
|
||||
|
||||
@@ -23,17 +23,21 @@ from ..protocol.hf import AgentRequest
|
||||
from ..protocol.hf import AgentResponse
|
||||
from ..protocol.hf import HFErrorResponse
|
||||
|
||||
schemas = get_generator('hf',
|
||||
components=[AgentRequest, AgentResponse, HFErrorResponse],
|
||||
tags=[{
|
||||
'name': 'HF',
|
||||
'description': 'HF integration, including Agent and others schema endpoints.',
|
||||
'externalDocs': 'https://huggingface.co/docs/transformers/main_classes/agent'
|
||||
}])
|
||||
|
||||
schemas = get_generator(
|
||||
'hf',
|
||||
components=[AgentRequest, AgentResponse, HFErrorResponse],
|
||||
tags=[
|
||||
{
|
||||
'name': 'HF',
|
||||
'description': 'HF integration, including Agent and others schema endpoints.',
|
||||
'externalDocs': 'https://huggingface.co/docs/transformers/main_classes/agent',
|
||||
}
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
from peft.config import PeftConfig
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
@@ -44,20 +48,28 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
|
||||
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
|
||||
app = Starlette(debug=True,
|
||||
routes=[
|
||||
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
|
||||
Route('/adapters', endpoint=functools.partial(adapters_map, llm=llm), name='adapters', methods=['GET']),
|
||||
Route('/schema', endpoint=openapi_schema, include_in_schema=False)
|
||||
])
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
|
||||
Route('/adapters', endpoint=functools.partial(adapters_map, llm=llm), name='adapters', methods=['GET']),
|
||||
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
|
||||
],
|
||||
)
|
||||
mount_path = '/hf'
|
||||
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, generated_schema, tags_order='append')
|
||||
|
||||
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
return JSONResponse(converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)), status_code=status_code.value)
|
||||
return JSONResponse(
|
||||
converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)),
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
|
||||
@add_schema_definitions(HF_AGENT_SCHEMA)
|
||||
async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
@@ -72,22 +84,26 @@ async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
stop = request.parameters.pop('stop', ['\n'])
|
||||
try:
|
||||
result = await llm.generate(request.inputs, stop=stop, **request.parameters)
|
||||
return JSONResponse(converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value)
|
||||
return JSONResponse(
|
||||
converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error('Error while generating: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')
|
||||
|
||||
|
||||
@add_schema_definitions(HF_ADAPTERS_SCHEMA)
|
||||
def adapters_map(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
if not llm.has_adapters: return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
|
||||
if not llm.has_adapters:
|
||||
return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
|
||||
return JSONResponse(
|
||||
{
|
||||
adapter_tuple[1]: {
|
||||
'adapter_name': k,
|
||||
'adapter_type': t.cast(Enum, adapter_tuple[0].peft_type).value
|
||||
} for k, adapter_tuple in t.cast(t.Dict[str, t.Tuple['PeftConfig', str]], dict(*llm.adapter_map.values())).items()
|
||||
},
|
||||
status_code=HTTPStatus.OK.value)
|
||||
{
|
||||
adapter_tuple[1]: {'adapter_name': k, 'adapter_type': t.cast(Enum, adapter_tuple[0].peft_type).value}
|
||||
for k, adapter_tuple in t.cast(t.Dict[str, t.Tuple['PeftConfig', str]], dict(*llm.adapter_map.values())).items()
|
||||
},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
|
||||
|
||||
def openapi_schema(req: Request) -> Response:
|
||||
return schemas.OpenAPIResponse(req)
|
||||
|
||||
@@ -42,14 +42,27 @@ from ..protocol.openai import ModelCard
|
||||
from ..protocol.openai import ModelList
|
||||
from ..protocol.openai import UsageInfo
|
||||
|
||||
|
||||
schemas = get_generator(
|
||||
'openai',
|
||||
components=[ErrorResponse, ModelList, ChatCompletionResponse, ChatCompletionRequest, ChatCompletionStreamResponse, CompletionRequest, CompletionResponse, CompletionStreamResponse],
|
||||
tags=[{
|
||||
'name': 'OpenAI',
|
||||
'description': 'OpenAI Compatible API support',
|
||||
'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object'
|
||||
}])
|
||||
'openai',
|
||||
components=[
|
||||
ErrorResponse,
|
||||
ModelList,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionStreamResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionStreamResponse,
|
||||
],
|
||||
tags=[
|
||||
{
|
||||
'name': 'OpenAI',
|
||||
'description': 'OpenAI Compatible API support',
|
||||
'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object',
|
||||
}
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -64,20 +77,34 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
|
||||
def jsonify_attr(obj: AttrsInstance) -> str:
|
||||
return orjson.dumps(converter.unstructure(obj)).decode()
|
||||
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
return JSONResponse({'error': converter.unstructure(ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value)))}, status_code=status_code.value)
|
||||
|
||||
async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None:
|
||||
if request.model == model: return None
|
||||
return error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request."
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
{
|
||||
'error': converter.unstructure(
|
||||
ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value))
|
||||
)
|
||||
},
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
def create_logprobs(token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T]) -> LogProbs:
|
||||
|
||||
async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None:
|
||||
if request.model == model:
|
||||
return None
|
||||
return error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
)
|
||||
|
||||
|
||||
def create_logprobs(
|
||||
token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T]
|
||||
) -> LogProbs:
|
||||
# Create OpenAI-style logprobs.
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
@@ -94,22 +121,29 @@ def create_logprobs(token_ids: list[int], id_logprobs: list[dict[int, float]], i
|
||||
logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
|
||||
return logprobs
|
||||
|
||||
|
||||
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
|
||||
app = Starlette(debug=True,
|
||||
routes=[
|
||||
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
|
||||
Route('/completions', functools.partial(create_completions, llm=llm), methods=['POST']),
|
||||
Route('/chat/completions', functools.partial(create_chat_completions, llm=llm), methods=['POST'])
|
||||
])
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
|
||||
Route('/completions', functools.partial(create_completions, llm=llm), methods=['POST']),
|
||||
Route('/chat/completions', functools.partial(create_chat_completions, llm=llm), methods=['POST']),
|
||||
],
|
||||
)
|
||||
mount_path = '/v1'
|
||||
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, generated_schema)
|
||||
|
||||
|
||||
# GET /v1/models
|
||||
@add_schema_definitions(LIST_MODEL_SCHEMA)
|
||||
def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)
|
||||
return JSONResponse(
|
||||
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
|
||||
# POST /v1/chat/completions
|
||||
@add_schema_definitions(CHAT_COMPLETION_SCHEMA)
|
||||
@@ -124,11 +158,14 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received chat completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None: return err_check
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
|
||||
created_time = int(time.monotonic())
|
||||
prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt'])
|
||||
prompt = llm.tokenizer.apply_chat_template(
|
||||
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
|
||||
)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.with_openai_request(request)
|
||||
|
||||
@@ -141,10 +178,15 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
|
||||
def create_stream_response_json(index: int, text: str, finish_reason: str | None = None) -> str:
|
||||
return jsonify_attr(
|
||||
ChatCompletionStreamResponse(id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)]))
|
||||
ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
# first chunk with role
|
||||
@@ -160,25 +202,47 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[list[str]] = [[]] * config['n']
|
||||
token_ids: list[list[int]] = [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
]
|
||||
)
|
||||
choices = [
|
||||
ChatCompletionResponseChoice(index=output.index, message=ChatMessage(role='assistant', content=output.text), finish_reason=output.finish_reason) for output in final_result.outputs
|
||||
ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role='assistant', content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
for output in final_result.outputs
|
||||
]
|
||||
num_prompt_tokens, num_generated_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids)), sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens)
|
||||
response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
num_prompt_tokens, num_generated_tokens = (
|
||||
len(t.cast(t.List[int], final_result.prompt_token_ids)),
|
||||
sum(len(output.token_ids) for output in final_result.outputs),
|
||||
)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
|
||||
)
|
||||
|
||||
if request.stream: # type: ignore[unreachable]
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
@@ -187,7 +251,9 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
yield f'data: {jsonify_attr(response)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
|
||||
return StreamingResponse(
|
||||
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
@@ -195,6 +261,7 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
|
||||
# POST /v1/completions
|
||||
@add_schema_definitions(COMPLETION_SCHEMA)
|
||||
async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
@@ -208,18 +275,25 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received legacy completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None: return err_check
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
|
||||
if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
if request.echo:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
|
||||
if request.suffix is not None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
|
||||
if not request.prompt: return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
if not request.prompt:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
prompt = request.prompt
|
||||
# TODO: Support multiple prompts
|
||||
|
||||
if request.logprobs is not None and llm.__llm_backend__ == 'pt': # TODO: support logprobs generation for PyTorch
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`.")
|
||||
return error_response(
|
||||
HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`."
|
||||
)
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('cmpl')
|
||||
created_time = int(time.monotonic())
|
||||
@@ -236,12 +310,19 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
# TODO: support use_beam_search
|
||||
stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
|
||||
|
||||
def create_stream_response_json(index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None) -> str:
|
||||
def create_stream_response_json(
|
||||
index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None
|
||||
) -> str:
|
||||
return jsonify_attr(
|
||||
CompletionStreamResponse(id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)]))
|
||||
CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
previous_num_tokens = [0] * config['n']
|
||||
@@ -249,7 +330,11 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i]:], llm=llm)
|
||||
logprobs = create_logprobs(
|
||||
token_ids=output.token_ids,
|
||||
id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i] :],
|
||||
llm=llm,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
@@ -261,32 +346,50 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
if stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[list[str]] = [[]] * config['n']
|
||||
token_ids: list[list[int]] = [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
]
|
||||
)
|
||||
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
for output in final_result.outputs:
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm)
|
||||
logprobs = create_logprobs(
|
||||
token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids)) # XXX: We will always return prompt_token_ids, so this won't be None
|
||||
num_prompt_tokens = len(
|
||||
t.cast(t.List[int], final_result.prompt_token_ids)
|
||||
) # XXX: We will always return prompt_token_ids, so this won't be None
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
|
||||
if request.stream:
|
||||
@@ -296,7 +399,9 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
yield f'data: {jsonify_attr(response)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
|
||||
return StreamingResponse(
|
||||
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from openllm_core.exceptions import Error as Error
|
||||
|
||||
@@ -23,14 +23,15 @@ logger = logging.getLogger(__name__)
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
|
||||
DEFAULT_MODEL_ID = "ybelkada/falcon-7b-sharded-bf16"
|
||||
DATASET_NAME = "timdettmers/openassistant-guanaco"
|
||||
DEFAULT_MODEL_ID = 'ybelkada/falcon-7b-sharded-bf16'
|
||||
DATASET_NAME = 'timdettmers/openassistant-guanaco'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
gradient_accumulation_steps: int = dataclasses.field(default=4)
|
||||
optim: str = dataclasses.field(default="paged_adamw_32bit")
|
||||
optim: str = dataclasses.field(default='paged_adamw_32bit')
|
||||
save_steps: int = dataclasses.field(default=10)
|
||||
warmup_steps: int = dataclasses.field(default=10)
|
||||
max_steps: int = dataclasses.field(default=500)
|
||||
@@ -40,47 +41,56 @@ class TrainingArguments:
|
||||
warmup_ratio: float = dataclasses.field(default=0.03)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
group_by_length: bool = dataclasses.field(default=True)
|
||||
lr_scheduler_type: str = dataclasses.field(default="constant")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "falcon"))
|
||||
lr_scheduler_type: str = dataclasses.field(default='constant')
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'falcon'))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
max_sequence_length: int = dataclasses.field(default=512)
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
llm = openllm.LLM(model_args.model_id, quantize="int4", bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
|
||||
model, tokenizer = llm.prepare_for_training(adapter_type="lora",
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=16,
|
||||
bias="none",
|
||||
target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"])
|
||||
llm = openllm.LLM(
|
||||
model_args.model_id, quantize='int4', bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type='lora',
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=16,
|
||||
bias='none',
|
||||
target_modules=['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h'],
|
||||
)
|
||||
model.config.use_cache = False
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split="train")
|
||||
dataset = load_dataset(DATASET_NAME, split='train')
|
||||
|
||||
trainer = SFTTrainer(model=model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=model_args.max_sequence_length,
|
||||
tokenizer=tokenizer,
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=model_args.max_sequence_length,
|
||||
tokenizer=tokenizer,
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
)
|
||||
|
||||
# upcast layernorm in float32 for more stable training
|
||||
for name, module in trainer.model.named_modules():
|
||||
if "norm" in name:
|
||||
if 'norm' in name:
|
||||
module = module.to(torch.float32)
|
||||
|
||||
trainer.train()
|
||||
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
|
||||
@@ -15,6 +15,7 @@ MAX_NEW_TOKENS = 384
|
||||
Q = 'Answer the following question, step by step:\n{q}\nA:'
|
||||
question = 'What is the meaning of life?'
|
||||
|
||||
|
||||
async def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('question', default=question)
|
||||
@@ -37,6 +38,7 @@ async def main() -> int:
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _mp_fn(index: t.Any): # type: ignore
|
||||
# For xla_spawn (TPUs)
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -30,39 +30,45 @@ from random import randint, randrange
|
||||
import bitsandbytes as bnb
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
|
||||
def find_all_linear_names(model):
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, bnb.nn.Linear4bit):
|
||||
names = name.split(".")
|
||||
names = name.split('.')
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove("lm_head")
|
||||
if 'lm_head' in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove('lm_head')
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
# Change this to the local converted path if you don't have access to the meta-llama model
|
||||
DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf"
|
||||
DEFAULT_MODEL_ID = 'meta-llama/Llama-2-7b-hf'
|
||||
# change this to 'main' if you want to use the latest llama
|
||||
DEFAULT_MODEL_VERSION = "335a02887eb6684d487240bbc28b5699298c3135"
|
||||
DATASET_NAME = "databricks/databricks-dolly-15k"
|
||||
DEFAULT_MODEL_VERSION = '335a02887eb6684d487240bbc28b5699298c3135'
|
||||
DATASET_NAME = 'databricks/databricks-dolly-15k'
|
||||
|
||||
|
||||
def format_dolly(sample):
|
||||
instruction = f"### Instruction\n{sample['instruction']}"
|
||||
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
|
||||
context = f"### Context\n{sample['context']}" if len(sample['context']) > 0 else None
|
||||
response = f"### Answer\n{sample['response']}"
|
||||
# join all the parts together
|
||||
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
|
||||
prompt = '\n\n'.join([i for i in [instruction, context, response] if i is not None])
|
||||
return prompt
|
||||
|
||||
|
||||
# template dataset to add prompt to each sample
|
||||
def template_dataset(sample, tokenizer):
|
||||
sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
|
||||
sample['text'] = f'{format_dolly(sample)}{tokenizer.eos_token}'
|
||||
return sample
|
||||
|
||||
|
||||
# empty list to save remainder from batches to use in next batch
|
||||
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
|
||||
remainder = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
|
||||
|
||||
|
||||
def chunk(sample, chunk_length=2048):
|
||||
# define global remainder variable to save remainder from batches to use in next batch
|
||||
@@ -78,61 +84,76 @@ def chunk(sample, chunk_length=2048):
|
||||
batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
|
||||
|
||||
# Split by chunks of max_len.
|
||||
result = {k: [t[i:i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] for k, t in concatenated_examples.items()}
|
||||
result = {
|
||||
k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
# add remainder to global variable for next batch
|
||||
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
|
||||
# prepare labels
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
result['labels'] = result['input_ids'].copy()
|
||||
return result
|
||||
|
||||
|
||||
def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
|
||||
# Load dataset from the hub
|
||||
dataset = load_dataset(dataset_name, split="train")
|
||||
dataset = load_dataset(dataset_name, split='train')
|
||||
|
||||
print(f"dataset size: {len(dataset)}")
|
||||
print(f'dataset size: {len(dataset)}')
|
||||
print(dataset[randrange(len(dataset))])
|
||||
|
||||
# apply prompt template per sample
|
||||
dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features))
|
||||
# print random sample
|
||||
print("Sample from dolly-v2 ds:", dataset[randint(0, len(dataset))]["text"])
|
||||
print('Sample from dolly-v2 ds:', dataset[randint(0, len(dataset))]['text'])
|
||||
|
||||
# tokenize and chunk dataset
|
||||
lm_dataset = dataset.map(lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)).map(partial(chunk, chunk_length=2048), batched=True)
|
||||
lm_dataset = dataset.map(
|
||||
lambda sample: tokenizer(sample['text']), batched=True, remove_columns=list(dataset.features)
|
||||
).map(partial(chunk, chunk_length=2048), batched=True)
|
||||
|
||||
# Print total number of samples
|
||||
print(f"Total number of samples: {len(lm_dataset)}")
|
||||
print(f'Total number of samples: {len(lm_dataset)}')
|
||||
return lm_dataset
|
||||
|
||||
def prepare_for_int4_training(model_id: str,
|
||||
model_version: str | None = None,
|
||||
gradient_checkpointing: bool = True,
|
||||
bf16: bool = True,
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
|
||||
def prepare_for_int4_training(
|
||||
model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
llm = openllm.LLM(model_id, revision=model_version, quantize="int4", bnb_4bit_compute_dtype=torch.bfloat16, use_cache=not gradient_checkpointing, device_map="auto")
|
||||
print("Model summary:", llm.model)
|
||||
llm = openllm.LLM(
|
||||
model_id,
|
||||
revision=model_version,
|
||||
quantize='int4',
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
use_cache=not gradient_checkpointing,
|
||||
device_map='auto',
|
||||
)
|
||||
print('Model summary:', llm.model)
|
||||
|
||||
# get lora target modules
|
||||
modules = find_all_linear_names(llm.model)
|
||||
print(f"Found {len(modules)} modules to quantize: {modules}")
|
||||
print(f'Found {len(modules)} modules to quantize: {modules}')
|
||||
|
||||
model, tokenizer = llm.prepare_for_training(adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type='lora', use_gradient_checkpointing=gradient_checkpointing, target_modules=modules
|
||||
)
|
||||
|
||||
# pre-process the model by upcasting the layer norms in float 32 for
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
if bf16:
|
||||
module = module.to(torch.bfloat16)
|
||||
if "norm" in name:
|
||||
if 'norm' in name:
|
||||
module = module.to(torch.float32)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
if 'lm_head' in name or 'embed_tokens' in name:
|
||||
if hasattr(module, 'weight'):
|
||||
if bf16 and module.weight.dtype == torch.float32:
|
||||
module = module.to(torch.bfloat16)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=1)
|
||||
@@ -141,9 +162,10 @@ class TrainingArguments:
|
||||
learning_rate: float = dataclasses.field(default=5e-5)
|
||||
num_train_epochs: int = dataclasses.field(default=3)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
report_to: str = dataclasses.field(default="none")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "llama"))
|
||||
save_strategy: str = dataclasses.field(default="no")
|
||||
report_to: str = dataclasses.field(default='none')
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'llama'))
|
||||
save_strategy: str = dataclasses.field(default='no')
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
@@ -152,32 +174,42 @@ class ModelArguments:
|
||||
seed: int = dataclasses.field(default=42)
|
||||
merge_weights: bool = dataclasses.field(default=False)
|
||||
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
model_args, training_rags = ModelArguments(), TrainingArguments()
|
||||
else:
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
model_args, training_args = t.cast(
|
||||
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
|
||||
)
|
||||
|
||||
# import the model first hand
|
||||
openllm.import_model(model_id=model_args.model_id, model_version=model_args.model_version)
|
||||
|
||||
|
||||
def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
import peft
|
||||
|
||||
transformers.set_seed(model_args.seed)
|
||||
|
||||
model, tokenizer = prepare_for_int4_training(model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16,)
|
||||
model, tokenizer = prepare_for_int4_training(
|
||||
model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16
|
||||
)
|
||||
datasets = prepare_datasets(tokenizer)
|
||||
|
||||
trainer = transformers.Trainer(model=model,
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
train_dataset=datasets,
|
||||
data_collator=transformers.default_data_collator)
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
train_dataset=datasets,
|
||||
data_collator=transformers.default_data_collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
@@ -192,11 +224,16 @@ def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
del model, trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model = peft.AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16)
|
||||
model = peft.AutoPeftModelForCausalLM.from_pretrained(
|
||||
training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16
|
||||
)
|
||||
# merge lora with base weights and save
|
||||
model = model.merge_and_unload()
|
||||
model.save_pretrained(os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB")
|
||||
model.save_pretrained(
|
||||
os.path.join(os.getcwd(), 'outputs', 'merged_llama_lora'), safe_serialization=True, max_shard_size='2GB'
|
||||
)
|
||||
else:
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
|
||||
|
||||
train_loop(model_args, training_args)
|
||||
|
||||
@@ -24,13 +24,21 @@ from datasets import load_dataset
|
||||
if t.TYPE_CHECKING:
|
||||
from peft import PeftModel
|
||||
|
||||
DEFAULT_MODEL_ID = "facebook/opt-6.7b"
|
||||
DEFAULT_MODEL_ID = 'facebook/opt-6.7b'
|
||||
|
||||
|
||||
def load_trainer(
|
||||
model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments
|
||||
):
|
||||
return transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=dataset_dict['train'],
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
|
||||
def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments):
|
||||
return transformers.Trainer(model=model,
|
||||
train_dataset=dataset_dict["train"],
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False))
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
@@ -41,30 +49,34 @@ class TrainingArguments:
|
||||
learning_rate: float = dataclasses.field(default=3e-4)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "opt"))
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'opt'))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
llm = openllm.LLM(model_args.model_id, quantize="int8")
|
||||
model, tokenizer = llm.prepare_for_training(adapter_type="lora", r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
|
||||
llm = openllm.LLM(model_args.model_id, quantize='int8')
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type='lora', r=16, lora_alpha=32, target_modules=['q_proj', 'v_proj'], lora_dropout=0.05, bias='none'
|
||||
)
|
||||
|
||||
# ft on english_quotes
|
||||
data = load_dataset("Abirate/english_quotes")
|
||||
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
|
||||
data = load_dataset('Abirate/english_quotes')
|
||||
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)
|
||||
|
||||
trainer = load_trainer(model, tokenizer, data, training_args)
|
||||
model.config.use_cache = False # silence just for warning, reenable for inference later
|
||||
|
||||
trainer.train()
|
||||
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
|
||||
@@ -2,12 +2,14 @@
|
||||
|
||||
Currently support OpenAI compatible API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'openai': []}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@@ -3,15 +3,18 @@ import typing as t
|
||||
|
||||
import attr
|
||||
|
||||
|
||||
@attr.define
|
||||
class AgentRequest:
|
||||
inputs: str
|
||||
parameters: t.Dict[str, t.Any]
|
||||
|
||||
|
||||
@attr.define
|
||||
class AgentResponse:
|
||||
generated_text: str
|
||||
|
||||
|
||||
@attr.define
|
||||
class HFErrorResponse:
|
||||
error_code: int
|
||||
|
||||
@@ -8,6 +8,7 @@ import openllm_core
|
||||
|
||||
from openllm_core.utils import converter
|
||||
|
||||
|
||||
@attr.define
|
||||
class ErrorResponse:
|
||||
message: str
|
||||
@@ -16,6 +17,7 @@ class ErrorResponse:
|
||||
param: t.Optional[str] = None
|
||||
code: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionRequest:
|
||||
prompt: str
|
||||
@@ -37,6 +39,7 @@ class CompletionRequest:
|
||||
top_k: t.Optional[int] = attr.field(default=None)
|
||||
best_of: t.Optional[int] = attr.field(default=1)
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionRequest:
|
||||
messages: t.List[t.Dict[str, str]]
|
||||
@@ -57,6 +60,7 @@ class ChatCompletionRequest:
|
||||
top_k: t.Optional[int] = attr.field(default=None)
|
||||
best_of: t.Optional[int] = attr.field(default=1)
|
||||
|
||||
|
||||
@attr.define
|
||||
class LogProbs:
|
||||
text_offset: t.List[int] = attr.field(default=attr.Factory(list))
|
||||
@@ -64,12 +68,14 @@ class LogProbs:
|
||||
tokens: t.List[str] = attr.field(default=attr.Factory(list))
|
||||
top_logprobs: t.List[t.Dict[str, t.Any]] = attr.field(default=attr.Factory(list))
|
||||
|
||||
|
||||
@attr.define
|
||||
class UsageInfo:
|
||||
prompt_tokens: int = attr.field(default=0)
|
||||
completion_tokens: int = attr.field(default=0)
|
||||
total_tokens: int = attr.field(default=0)
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionResponseChoice:
|
||||
index: int
|
||||
@@ -77,6 +83,7 @@ class CompletionResponseChoice:
|
||||
logprobs: t.Optional[LogProbs] = None
|
||||
finish_reason: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionResponseStreamChoice:
|
||||
index: int
|
||||
@@ -84,6 +91,7 @@ class CompletionResponseStreamChoice:
|
||||
logprobs: t.Optional[LogProbs] = None
|
||||
finish_reason: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionStreamResponse:
|
||||
model: str
|
||||
@@ -92,6 +100,7 @@ class CompletionStreamResponse:
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionResponse:
|
||||
choices: t.List[CompletionResponseChoice]
|
||||
@@ -101,32 +110,39 @@ class CompletionResponse:
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
|
||||
|
||||
LiteralRole = t.Literal['system', 'user', 'assistant']
|
||||
|
||||
|
||||
@attr.define
|
||||
class Delta:
|
||||
role: t.Optional[LiteralRole] = None
|
||||
content: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatMessage:
|
||||
role: LiteralRole
|
||||
content: str
|
||||
|
||||
|
||||
converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role, 'content': msg.content})
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionResponseStreamChoice:
|
||||
index: int
|
||||
delta: Delta
|
||||
finish_reason: t.Optional[str] = attr.field(default=None)
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionResponseChoice:
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: t.Optional[str] = attr.field(default=None)
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionResponse:
|
||||
choices: t.List[ChatCompletionResponseChoice]
|
||||
@@ -136,6 +152,7 @@ class ChatCompletionResponse:
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
usage: UsageInfo = attr.field(default=attr.Factory(lambda: UsageInfo()))
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionStreamResponse:
|
||||
choices: t.List[ChatCompletionResponseStreamChoice]
|
||||
@@ -144,6 +161,7 @@ class ChatCompletionStreamResponse:
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
|
||||
|
||||
@attr.define
|
||||
class ModelCard:
|
||||
id: str
|
||||
@@ -151,19 +169,25 @@ class ModelCard:
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
owned_by: str = 'na'
|
||||
|
||||
|
||||
@attr.define
|
||||
class ModelList:
|
||||
object: str = 'list'
|
||||
data: t.List[ModelCard] = attr.field(factory=list)
|
||||
|
||||
|
||||
async def get_conversation_prompt(request: ChatCompletionRequest, llm_config: openllm_core.LLMConfig) -> str:
|
||||
conv = llm_config.get_conversation_template()
|
||||
for message in request.messages:
|
||||
msg_role = message['role']
|
||||
if msg_role == 'system': conv.set_system_message(message['content'])
|
||||
elif msg_role == 'user': conv.append_message(conv.roles[0], message['content'])
|
||||
elif msg_role == 'assistant': conv.append_message(conv.roles[1], message['content'])
|
||||
else: raise ValueError(f'Unknown role: {msg_role}')
|
||||
if msg_role == 'system':
|
||||
conv.set_system_message(message['content'])
|
||||
elif msg_role == 'user':
|
||||
conv.append_message(conv.roles[0], message['content'])
|
||||
elif msg_role == 'assistant':
|
||||
conv.append_message(conv.roles[1], message['content'])
|
||||
else:
|
||||
raise ValueError(f'Unknown role: {msg_role}')
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], '')
|
||||
return conv.get_prompt()
|
||||
|
||||
@@ -4,6 +4,7 @@ Currently supports transformers for PyTorch, and vLLM.
|
||||
|
||||
Currently, GGML format is working in progress.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import typing as t
|
||||
@@ -18,6 +19,7 @@ from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers as _transformers
|
||||
|
||||
@@ -31,6 +33,7 @@ else:
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
"""Load the tokenizer from BentoML store.
|
||||
|
||||
@@ -47,24 +50,34 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
try:
|
||||
tokenizer = cloudpickle.load(t.cast('t.IO[bytes]', cofile))['tokenizer']
|
||||
except KeyError:
|
||||
raise openllm.exceptions.OpenLLMException("Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
|
||||
"For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\"") from None
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
|
||||
'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"'
|
||||
) from None
|
||||
else:
|
||||
tokenizer = _transformers.AutoTokenizer.from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)
|
||||
tokenizer = _transformers.AutoTokenizer.from_pretrained(
|
||||
bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs
|
||||
)
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if config.pad_token_id is not None: tokenizer.pad_token_id = config.pad_token_id
|
||||
elif config.eos_token_id is not None: tokenizer.pad_token_id = config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else: tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
if config.pad_token_id is not None:
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
elif config.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
return tokenizer
|
||||
|
||||
|
||||
class _Caller(t.Protocol[P]):
|
||||
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
...
|
||||
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any: ...
|
||||
|
||||
|
||||
_extras = ['get', 'import_model', 'load_model']
|
||||
|
||||
|
||||
def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
@@ -74,30 +87,36 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"'
|
||||
"""
|
||||
serde = 'transformers'
|
||||
if llm.__llm_backend__ == 'ggml': serde = 'ggml'
|
||||
if llm.__llm_backend__ == 'ggml':
|
||||
serde = 'ggml'
|
||||
return getattr(importlib.import_module(f'.{serde}', __name__), fn)(llm, *args, **kwargs)
|
||||
|
||||
return caller
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
|
||||
...
|
||||
def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
|
||||
|
||||
def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
|
||||
...
|
||||
def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M: ...
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M:
|
||||
...
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
|
||||
__all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == 'load_tokenizer': return load_tokenizer
|
||||
elif name in _import_structure: return importlib.import_module(f'.{name}', __name__)
|
||||
elif name in _extras: return _make_dispatch_function(name)
|
||||
else: raise AttributeError(f'{__name__} has no attribute {name}')
|
||||
if name == 'load_tokenizer':
|
||||
return load_tokenizer
|
||||
elif name in _import_structure:
|
||||
return importlib.import_module(f'.{name}', __name__)
|
||||
elif name in _extras:
|
||||
return _make_dispatch_function(name)
|
||||
else:
|
||||
raise AttributeError(f'{__name__} has no attribute {name}')
|
||||
|
||||
@@ -1,7 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'), 'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')}
|
||||
HUB_ATTRS = ['cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token']
|
||||
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
|
||||
'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
|
||||
'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
|
||||
}
|
||||
HUB_ATTRS = [
|
||||
'cache_dir',
|
||||
'code_revision',
|
||||
'force_download',
|
||||
'local_files_only',
|
||||
'proxies',
|
||||
'resume_download',
|
||||
'revision',
|
||||
'subfolder',
|
||||
'use_auth_token',
|
||||
]
|
||||
CONFIG_FILE_NAME = 'config.json'
|
||||
# the below is similar to peft.utils.other.CONFIG_NAME
|
||||
PEFT_CONFIG_NAME = 'adapter_config.json'
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
|
||||
This requires ctransformers to be installed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
import openllm
|
||||
@@ -13,11 +15,16 @@ if t.TYPE_CHECKING:
|
||||
|
||||
_conversion_strategy = {'pt': 'ggml'}
|
||||
|
||||
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
|
||||
|
||||
def import_model(
|
||||
llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any
|
||||
) -> bentoml.Model:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
|
||||
def get(llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Model:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Serialisation related implementation for Transformers-based implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import logging
|
||||
@@ -27,6 +28,7 @@ from ._helpers import infer_autoclass_from_llm
|
||||
from ._helpers import process_config
|
||||
from .weights import HfIgnore
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
|
||||
@@ -38,29 +40,52 @@ logger = logging.getLogger(__name__)
|
||||
__all__ = ['import_model', 'get', 'load_model']
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None) -> None:
|
||||
|
||||
def _patch_correct_tag(
|
||||
llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None
|
||||
) -> None:
|
||||
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
|
||||
if llm.revision is not None: return
|
||||
if llm.revision is not None:
|
||||
return
|
||||
if not llm._local:
|
||||
try:
|
||||
if _revision is None: _revision = get_hash(config)
|
||||
if _revision is None:
|
||||
_revision = get_hash(config)
|
||||
except ValueError:
|
||||
pass
|
||||
if _revision is None and llm.tag.version is not None: _revision = llm.tag.version
|
||||
if llm._tag.version is None: _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
|
||||
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
|
||||
if _revision is None and llm.tag.version is not None:
|
||||
_revision = llm.tag.version
|
||||
if llm._tag.version is None:
|
||||
_object_setattr(
|
||||
llm, '_tag', attr.evolve(llm.tag, version=_revision)
|
||||
) # HACK: This copies the correct revision into llm.tag
|
||||
if llm._revision is None:
|
||||
_object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
|
||||
|
||||
|
||||
@inject
|
||||
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
|
||||
def import_model(
|
||||
llm: openllm.LLM[M, T],
|
||||
*decls: t.Any,
|
||||
trust_remote_code: bool,
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Model:
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
_patch_correct_tag(llm, config)
|
||||
_, tokenizer_attrs = llm.llm_parameters
|
||||
quantize = llm._quantise
|
||||
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
|
||||
safe_serialisation = openllm.utils.first_not_none(
|
||||
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
|
||||
)
|
||||
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
|
||||
if quantize: metadata['_quantize'] = quantize
|
||||
if quantize:
|
||||
metadata['_quantize'] = quantize
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
if not architectures: raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
|
||||
if not architectures:
|
||||
raise RuntimeError(
|
||||
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
|
||||
)
|
||||
metadata['_pretrained_class'] = architectures[0]
|
||||
metadata['_revision'] = get_hash(config)
|
||||
|
||||
@@ -69,93 +94,152 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
if quantize == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
)
|
||||
signatures['generate'] = {'batchable': False}
|
||||
else:
|
||||
attrs['use_safetensors'] = safe_serialisation
|
||||
metadata['_framework'] = llm.__llm_backend__
|
||||
signatures.update({
|
||||
signatures.update(
|
||||
{
|
||||
k: ModelSignature(batchable=False)
|
||||
for k in ('__call__', 'forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search')
|
||||
})
|
||||
for k in (
|
||||
'__call__',
|
||||
'forward',
|
||||
'generate',
|
||||
'contrastive_search',
|
||||
'greedy_search',
|
||||
'sample',
|
||||
'beam_search',
|
||||
'beam_sample',
|
||||
'group_beam_search',
|
||||
'constrained_beam_search',
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = None
|
||||
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
|
||||
imported_modules: list[types.ModuleType] = []
|
||||
bentomodel = bentoml.Model.create(llm.tag,
|
||||
module='openllm.serialisation.transformers',
|
||||
api_version='v2.1.0',
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
metadata=metadata,
|
||||
signatures=signatures)
|
||||
bentomodel = bentoml.Model.create(
|
||||
llm.tag,
|
||||
module='openllm.serialisation.transformers',
|
||||
api_version='v2.1.0',
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
metadata=metadata,
|
||||
signatures=signatures,
|
||||
)
|
||||
with openllm.utils.analytics.set_bentoml_tracking():
|
||||
try:
|
||||
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
|
||||
tokenizer.save_pretrained(bentomodel.path)
|
||||
if llm._quantization_config or (llm._quantise and llm._quantise not in {'squeezellm', 'awq'}): attrs['quantization_config'] = llm.quantization_config
|
||||
if llm._quantization_config or (llm._quantise and llm._quantise not in {'squeezellm', 'awq'}):
|
||||
attrs['quantization_config'] = llm.quantization_config
|
||||
if quantize == 'gptq':
|
||||
from optimum.gptq.constants import GPTQ_CONFIG
|
||||
|
||||
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
|
||||
if llm._local: # possible local path
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(
|
||||
llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
|
||||
)
|
||||
# for trust_remote_code to work
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
|
||||
else:
|
||||
# we will clone the all tings into the bentomodel path without loading model into memory
|
||||
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
|
||||
snapshot_download(
|
||||
llm.model_id,
|
||||
local_dir=bentomodel.path,
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=HfIgnore.ignore_patterns(llm),
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
else:
|
||||
bentomodel.flush() # type: ignore[no-untyped-call]
|
||||
bentomodel.save(_model_store)
|
||||
openllm.utils.analytics.track(openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024))
|
||||
openllm.utils.analytics.track(
|
||||
openllm.utils.analytics.ModelSaveEvent(
|
||||
module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024
|
||||
)
|
||||
)
|
||||
finally:
|
||||
bentomodel.exit_cloudpickle_context(imported_modules)
|
||||
# NOTE: We need to free up the cache after importing the model
|
||||
# in the case where users first run openllm start without the model available locally.
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
del model
|
||||
return bentomodel
|
||||
|
||||
|
||||
def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
backend = model.info.labels['backend']
|
||||
if backend != llm.__llm_backend__: raise openllm.exceptions.OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
|
||||
_patch_correct_tag(llm, transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code), _revision=model.info.metadata.get('_revision'))
|
||||
if backend != llm.__llm_backend__:
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'."
|
||||
)
|
||||
_patch_correct_tag(
|
||||
llm,
|
||||
transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code),
|
||||
_revision=model.info.metadata.get('_revision'),
|
||||
)
|
||||
return model
|
||||
except Exception as err:
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f'Failed while getting stored artefact (lookup for traceback):\n{err}'
|
||||
) from err
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
if llm._quantise in {'awq', 'squeezellm'}: raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs)
|
||||
if llm._quantise in {'awq', 'squeezellm'}:
|
||||
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs
|
||||
)
|
||||
auto_class = infer_autoclass_from_llm(llm, config)
|
||||
device_map = attrs.pop('device_map', None)
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.device_count() > 1: device_map = 'auto'
|
||||
elif torch.cuda.device_count() == 1: device_map = 'cuda:0'
|
||||
if llm._quantise in {'int8', 'int4'}: attrs['quantization_config'] = llm.quantization_config
|
||||
if torch.cuda.device_count() > 1:
|
||||
device_map = 'auto'
|
||||
elif torch.cuda.device_count() == 1:
|
||||
device_map = 'cuda:0'
|
||||
if llm._quantise in {'int8', 'int4'}:
|
||||
attrs['quantization_config'] = llm.quantization_config
|
||||
|
||||
if '_quantize' in llm.bentomodel.info.metadata:
|
||||
_quantise = llm.bentomodel.info.metadata['_quantize']
|
||||
if _quantise == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm':
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
|
||||
# TODO: investigate load with flash attention
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
|
||||
else:
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **attrs)
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
device_map=device_map,
|
||||
**attrs,
|
||||
)
|
||||
return t.cast('M', model)
|
||||
|
||||
@@ -10,6 +10,7 @@ import openllm
|
||||
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
|
||||
from openllm.serialisation.constants import HUB_ATTRS
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
@@ -17,12 +18,17 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
|
||||
def get_hash(config: transformers.PretrainedConfig) -> str:
|
||||
_commit_hash = getattr(config, '_commit_hash', None)
|
||||
if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}')
|
||||
if _commit_hash is None:
|
||||
raise ValueError(f'Cannot find commit hash in {config}')
|
||||
return _commit_hash
|
||||
|
||||
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
|
||||
def process_config(
|
||||
model_id: str, trust_remote_code: bool, **attrs: t.Any
|
||||
) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
|
||||
|
||||
Args:
|
||||
@@ -38,25 +44,36 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
|
||||
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
|
||||
if not isinstance(config, transformers.PretrainedConfig):
|
||||
copied_attrs = copy.deepcopy(attrs)
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto':
|
||||
copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs
|
||||
)
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
|
||||
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
|
||||
if llm.trust_remote_code:
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if not hasattr(config, 'auto_map'):
|
||||
raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
|
||||
raise ValueError(
|
||||
f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping'
|
||||
)
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map: autoclass = 'AutoModel'
|
||||
if autoclass not in config.auto_map:
|
||||
autoclass = 'AutoModel'
|
||||
return getattr(transformers, autoclass)
|
||||
else:
|
||||
if type(config) in transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
|
||||
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
|
||||
else: raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
|
||||
if type(config) in transformers.MODEL_FOR_CAUSAL_LM_MAPPING:
|
||||
idx = 0
|
||||
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
|
||||
idx = 1
|
||||
else:
|
||||
raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
|
||||
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_backend__][idx])
|
||||
|
||||
|
||||
def check_unintialised_params(model: torch.nn.Module) -> None:
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0:
|
||||
|
||||
@@ -8,6 +8,7 @@ from huggingface_hub import HfApi
|
||||
|
||||
from openllm_core.exceptions import Error
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from huggingface_hub.hf_api import ModelInfo as HfModelInfo
|
||||
|
||||
@@ -19,13 +20,17 @@ if t.TYPE_CHECKING:
|
||||
__global_inst__ = None
|
||||
__cached_id__: dict[str, HfModelInfo] = dict()
|
||||
|
||||
|
||||
def Client() -> HfApi:
|
||||
global __global_inst__ # noqa: PLW0603
|
||||
if __global_inst__ is None: __global_inst__ = HfApi()
|
||||
if __global_inst__ is None:
|
||||
__global_inst__ = HfApi()
|
||||
return __global_inst__
|
||||
|
||||
|
||||
def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
if model_id in __cached_id__: return __cached_id__[model_id]
|
||||
if model_id in __cached_id__:
|
||||
return __cached_id__[model_id]
|
||||
try:
|
||||
__cached_id__[model_id] = Client().model_info(model_id, revision=revision)
|
||||
return __cached_id__[model_id]
|
||||
@@ -33,9 +38,11 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
traceback.print_exc()
|
||||
raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
|
||||
|
||||
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
|
||||
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class HfIgnore:
|
||||
safetensors = '*.safetensors'
|
||||
@@ -48,9 +55,12 @@ class HfIgnore:
|
||||
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
|
||||
if llm.__llm_backend__ in {'vllm', 'pt'}:
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
|
||||
else: base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'ggml': base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
|
||||
if has_safetensors_weights(llm.model_id):
|
||||
base.append(cls.pt)
|
||||
else:
|
||||
base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'ggml':
|
||||
base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
|
||||
else:
|
||||
raise ValueError('Unknown backend (should never happen at all.)')
|
||||
# filter out these files, since we probably don't need them for now.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests utilities for OpenLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import logging
|
||||
@@ -9,14 +10,18 @@ import typing as t
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralBackend
|
||||
from openllm_core._typing_compat import LiteralQuantise
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_bento(model: str, model_id: str | None = None, quantize: LiteralQuantise | None = None, cleanup: bool = False) -> t.Iterator[bentoml.Bento]:
|
||||
def build_bento(
|
||||
model: str, model_id: str | None = None, quantize: LiteralQuantise | None = None, cleanup: bool = False
|
||||
) -> t.Iterator[bentoml.Bento]:
|
||||
logger.info('Building BentoML for %s', model)
|
||||
bento = openllm.build(model, model_id=model_id, quantize=quantize)
|
||||
yield bento
|
||||
@@ -24,29 +29,39 @@ def build_bento(model: str, model_id: str | None = None, quantize: LiteralQuanti
|
||||
logger.info('Deleting %s', bento.tag)
|
||||
bentoml.bentos.delete(bento.tag)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
|
||||
else: bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
if image_tag is None: image_tag = str(bento_tag)
|
||||
def build_container(
|
||||
bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any
|
||||
) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento):
|
||||
bento_tag = bento.tag
|
||||
else:
|
||||
bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
if image_tag is None:
|
||||
image_tag = str(bento_tag)
|
||||
executable = shutil.which('docker')
|
||||
if not executable: raise RuntimeError('docker executable not found')
|
||||
if not executable:
|
||||
raise RuntimeError('docker executable not found')
|
||||
try:
|
||||
logger.info('Building container for %s', bento_tag)
|
||||
bentoml.container.build(bento_tag, backend='docker', image_tag=(image_tag,), progress='plain', **attrs,)
|
||||
bentoml.container.build(bento_tag, backend='docker', image_tag=(image_tag,), progress='plain', **attrs)
|
||||
yield image_tag
|
||||
finally:
|
||||
if cleanup:
|
||||
logger.info('Deleting container %s', image_tag)
|
||||
subprocess.check_output([executable, 'rmi', '-f', image_tag])
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def prepare(model: str,
|
||||
model_id: str,
|
||||
backend: LiteralBackend = 'pt',
|
||||
deployment_mode: t.Literal['container', 'local'] = 'local',
|
||||
clean_context: contextlib.ExitStack | None = None,
|
||||
cleanup: bool = True) -> t.Iterator[str]:
|
||||
def prepare(
|
||||
model: str,
|
||||
model_id: str,
|
||||
backend: LiteralBackend = 'pt',
|
||||
deployment_mode: t.Literal['container', 'local'] = 'local',
|
||||
clean_context: contextlib.ExitStack | None = None,
|
||||
cleanup: bool = True,
|
||||
) -> t.Iterator[str]:
|
||||
if clean_context is None:
|
||||
clean_context = contextlib.ExitStack()
|
||||
cleanup = True
|
||||
@@ -60,4 +75,5 @@ def prepare(model: str,
|
||||
if deployment_mode == 'container':
|
||||
container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
|
||||
yield container_name
|
||||
if cleanup: clean_context.close()
|
||||
if cleanup:
|
||||
clean_context.close()
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
User can import these function for convenience, but
|
||||
we won't ensure backward compatibility for these functions. So use with caution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
|
||||
@@ -62,23 +64,38 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core.utils import validate_is_path as validate_is_path
|
||||
from openllm_core.utils.serde import converter as converter
|
||||
|
||||
|
||||
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]:
|
||||
return {'backend': llm.__llm_backend__, 'framework': 'openllm', 'model_name': llm.config['model_name'], 'architecture': llm.config['architecture'], 'serialisation': llm._serialisation}
|
||||
return {
|
||||
'backend': llm.__llm_backend__,
|
||||
'framework': 'openllm',
|
||||
'model_name': llm.config['model_name'],
|
||||
'architecture': llm.config['architecture'],
|
||||
'serialisation': llm._serialisation,
|
||||
}
|
||||
|
||||
|
||||
def available_devices() -> tuple[str, ...]:
|
||||
"""Return available GPU under system. Currently only supports NVIDIA GPUs."""
|
||||
from .._strategies import NvidiaGpuResource
|
||||
|
||||
return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def device_count() -> int:
|
||||
return len(available_devices())
|
||||
|
||||
|
||||
__all__ = ['generate_labels', 'available_devices', 'device_count']
|
||||
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
|
||||
def __getattr__(it: str) -> t.Any:
|
||||
if hasattr(openllm_core.utils, it): return getattr(openllm_core.utils, it)
|
||||
else: raise AttributeError(f'module {__name__} has no attribute {it}')
|
||||
if hasattr(openllm_core.utils, it):
|
||||
return getattr(openllm_core.utils, it)
|
||||
else:
|
||||
raise AttributeError(f'module {__name__} has no attribute {it}')
|
||||
|
||||
Reference in New Issue
Block a user