mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-08 00:59:50 -05:00
perf: unify LLM interface (#518)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
'''Serialisation utilities for OpenLLM.
|
||||
|
||||
Currently supports transformers for PyTorch, Tensorflow and Flax.
|
||||
Currently supports transformers for PyTorch, and vLLM.
|
||||
|
||||
Currently, GGML format is working in progress.
|
||||
'''
|
||||
@@ -19,11 +19,15 @@ from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers as _transformers
|
||||
|
||||
import bentoml
|
||||
|
||||
from . import constants as constants
|
||||
from . import ggml as ggml
|
||||
from . import transformers as transformers
|
||||
else:
|
||||
_transformers = openllm.utils.LazyLoader('_transformers', globals(), 'transformers')
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
@@ -33,12 +37,11 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
By default, it will try to find the bentomodel whether it is in store..
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
'''
|
||||
from .transformers._helpers import infer_tokenizers_from_llm
|
||||
from .transformers._helpers import process_config
|
||||
|
||||
config, *_ = process_config(llm._bentomodel.path, llm.trust_remote_code)
|
||||
config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code)
|
||||
|
||||
bentomodel_fs = fs.open_fs(llm._bentomodel.path)
|
||||
bentomodel_fs = fs.open_fs(llm.bentomodel.path)
|
||||
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
|
||||
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
|
||||
try:
|
||||
@@ -47,7 +50,7 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
raise openllm.exceptions.OpenLLMException("Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
|
||||
"For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\"") from None
|
||||
else:
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)
|
||||
tokenizer = _transformers.AutoTokenizer.from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if config.pad_token_id is not None: tokenizer.pad_token_id = config.pad_token_id
|
||||
@@ -66,7 +69,7 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "tf", "flax", "vllm")'
|
||||
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"'
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
|
||||
'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
|
||||
'tf': ('TFAutoModelForCausalLM', 'TFAutoModelForSeq2SeqLM'),
|
||||
'flax': ('FlaxAutoModelForCausalLM', 'FlaxAutoModelForSeq2SeqLM'),
|
||||
'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')
|
||||
}
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'), 'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')}
|
||||
HUB_ATTRS = ['cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token']
|
||||
CONFIG_FILE_NAME = 'config.json'
|
||||
# the below is similar to peft.utils.other.CONFIG_NAME
|
||||
PEFT_CONFIG_NAME = 'adapter_config.json'
|
||||
|
||||
@@ -4,10 +4,10 @@ import importlib
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import orjson
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from packaging.version import Version
|
||||
from simple_di import Provide
|
||||
from simple_di import inject
|
||||
|
||||
@@ -16,13 +16,13 @@ import openllm
|
||||
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
from bentoml._internal.models.model import ModelOptions
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
from ._helpers import check_unintialised_params
|
||||
from ._helpers import get_hash
|
||||
from ._helpers import infer_autoclass_from_llm
|
||||
from ._helpers import infer_tokenizers_from_llm
|
||||
from ._helpers import make_model_signatures
|
||||
from ._helpers import process_config
|
||||
from .weights import HfIgnore
|
||||
|
||||
@@ -32,16 +32,30 @@ if t.TYPE_CHECKING:
|
||||
import auto_gptq as autogptq
|
||||
import torch
|
||||
import torch.nn
|
||||
import transformers
|
||||
|
||||
from bentoml._internal.models import ModelStore
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
else:
|
||||
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
|
||||
autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
|
||||
torch = openllm.utils.LazyLoader('torch', globals(), 'torch')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['import_model', 'get', 'load_model']
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None) -> None:
|
||||
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
|
||||
if not llm._local:
|
||||
try:
|
||||
if _revision is None: _revision = get_hash(config)
|
||||
except ValueError:
|
||||
pass
|
||||
if llm._tag.version is None: _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
|
||||
else: _revision = llm._tag.version
|
||||
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
|
||||
|
||||
@inject
|
||||
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
|
||||
@@ -49,7 +63,7 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
|
||||
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
|
||||
returning all of the unused kwargs.
|
||||
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
|
||||
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
|
||||
For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
|
||||
|
||||
Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
|
||||
@@ -57,20 +71,22 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
Refer to Transformers documentation for more information about kwargs.
|
||||
|
||||
Args:
|
||||
llm: The LLM instance for this given model.
|
||||
trust_remote_code: Whether to trust the remote code when loading the model.
|
||||
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
|
||||
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
|
||||
llm: The LLM instance for this given model.
|
||||
trust_remote_code: Whether to trust the remote code when loading the model.
|
||||
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
|
||||
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
|
||||
"""
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
_patch_correct_tag(llm, config)
|
||||
_, tokenizer_attrs = llm.llm_parameters
|
||||
quantize = llm._quantize
|
||||
quantize = llm._quantise
|
||||
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
|
||||
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
|
||||
if quantize: metadata['_quantize'] = quantize
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
if not architectures: raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
|
||||
metadata['_pretrained_class'] = architectures[0]
|
||||
metadata['_revision'] = get_hash(config)
|
||||
|
||||
signatures: DictStrAny = {}
|
||||
|
||||
@@ -79,26 +95,24 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm':
|
||||
raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
signatures['generate'] = {'batchable': False}
|
||||
else:
|
||||
# this model might be called with --quantize int4, therefore we need to pop this out
|
||||
# since saving int4 is not yet supported
|
||||
if 'quantization_config' in attrs and getattr(attrs['quantization_config'], 'load_in_4bit', False):
|
||||
attrs.pop('quantization_config')
|
||||
if llm.__llm_backend__ != 'flax': attrs['use_safetensors'] = safe_serialisation
|
||||
attrs['use_safetensors'] = safe_serialisation
|
||||
metadata['_framework'] = llm.__llm_backend__
|
||||
signatures.update(make_model_signatures(llm))
|
||||
signatures.update({
|
||||
k: ModelSignature(batchable=False)
|
||||
for k in ('__call__', 'forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search')
|
||||
})
|
||||
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = None
|
||||
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
|
||||
imported_modules: list[types.ModuleType] = []
|
||||
bentomodel = bentoml.Model.create(llm.tag,
|
||||
module='openllm.serialisation.transformers',
|
||||
api_version='v2',
|
||||
api_version='v2.1.0',
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
@@ -108,13 +122,12 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
try:
|
||||
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
|
||||
tokenizer.save_pretrained(bentomodel.path)
|
||||
if llm._quantise or llm._quantization_config: attrs['quantization_config'] = llm.quantization_config
|
||||
if quantize == 'gptq':
|
||||
from optimum.gptq.constants import GPTQ_CONFIG
|
||||
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
|
||||
if llm._local:
|
||||
# possible local path
|
||||
logger.debug('Model will be loaded into memory to save to target store as it is from local path.')
|
||||
if llm._local: # possible local path
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
|
||||
# for trust_remote_code to work
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
|
||||
@@ -133,6 +146,7 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
# NOTE: We need to free up the cache after importing the model
|
||||
# in the case where users first run openllm start without the model available locally.
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
del model
|
||||
return bentomodel
|
||||
|
||||
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
|
||||
@@ -145,31 +159,35 @@ def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
|
||||
'''
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
if Version(model.info.api_version) < Version('v2'):
|
||||
raise openllm.exceptions.OpenLLMException('Please run "openllm prune -y --include-bentos" and upgrade all saved model to latest release.')
|
||||
if model.info.labels['backend'] != llm.__llm_backend__:
|
||||
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with backend {model.info.labels['backend']}, while loading with {llm.__llm_backend__}.")
|
||||
backend = model.info.labels['backend']
|
||||
if backend != llm.__llm_backend__: raise openllm.exceptions.OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
|
||||
_patch_correct_tag(llm, process_config(model.path, llm.trust_remote_code)[0], _revision=t.cast(t.Optional[str], model.info.metadata.get('_revision')))
|
||||
return model
|
||||
except Exception as err:
|
||||
if auto_import: return import_model(llm, trust_remote_code=llm.trust_remote_code)
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, llm.trust_remote_code, **attrs)
|
||||
config, hub_attrs, attrs = process_config(llm.bentomodel.path, llm.trust_remote_code, **attrs)
|
||||
_patch_correct_tag(llm, config, _revision=t.cast(t.Optional[str], llm.bentomodel.info.metadata.get('_revision')))
|
||||
auto_class = infer_autoclass_from_llm(llm, config)
|
||||
device_map: str | None = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
|
||||
if llm._quantise or llm._quantization_config: attrs['quantization_config'] = llm.quantization_config
|
||||
|
||||
if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
|
||||
if '_quantize' in llm.bentomodel.info.metadata and llm.bentomodel.info.metadata['_quantize'] == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
|
||||
model = auto_class.from_pretrained(llm._bentomodel.path, device_map='auto', **hub_attrs, **attrs)
|
||||
try:
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=True, **hub_attrs, **attrs)
|
||||
except Exception as err:
|
||||
logger.debug("Exception caught while trying to load with 'flash_attention_2': %s", err)
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=False, **hub_attrs, **attrs)
|
||||
# XXX: Use the below logic once TheBloke finished migration to new GPTQConfig from transformers
|
||||
# Seems like the logic below requires to add support for safetensors on accelerate
|
||||
#
|
||||
# from accelerate import init_empty_weights
|
||||
# from optimum.gptq import load_quantized_model
|
||||
# # disable exllama if gptq is loaded on CPU
|
||||
@@ -179,6 +197,6 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
# empty.tie_weights()
|
||||
# model = load_quantized_model(empty, save_folder=llm._bentomodel.path, device_map='auto', disable_exllama=disable_exllama)
|
||||
else:
|
||||
model = auto_class.from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
|
||||
if llm.__llm_backend__ in {'pt', 'vllm'}: check_unintialised_params(model)
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
|
||||
if llm.__llm_backend__ == 'pt': check_unintialised_params(model)
|
||||
return t.cast('M', model)
|
||||
|
||||
@@ -5,7 +5,6 @@ import typing as t
|
||||
import openllm
|
||||
import openllm_core
|
||||
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
|
||||
from openllm.serialisation.constants import HUB_ATTRS
|
||||
|
||||
@@ -15,13 +14,17 @@ if t.TYPE_CHECKING:
|
||||
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from bentoml._internal.models.model import ModelSignaturesType
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
else:
|
||||
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
|
||||
|
||||
def get_hash(config: transformers.PretrainedConfig) -> str:
|
||||
_commit_hash = getattr(config, '_commit_hash', None)
|
||||
if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}')
|
||||
return _commit_hash
|
||||
|
||||
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
'''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
|
||||
|
||||
@@ -42,17 +45,11 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
|
||||
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
|
||||
if __cls is None:
|
||||
raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
|
||||
return __cls
|
||||
|
||||
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
|
||||
if llm.config['trust_remote_code']:
|
||||
if llm.trust_remote_code:
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if not hasattr(config, 'auto_map'):
|
||||
raise ValueError(f'Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
|
||||
raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map: autoclass = 'AutoModel'
|
||||
@@ -67,15 +64,3 @@ def check_unintialised_params(model: torch.nn.Module) -> None:
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0:
|
||||
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
|
||||
# NOTE: sync with bentoml/_internal/frameworks/transformers.py#make_default_signatures
|
||||
def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
|
||||
infer_fn: tuple[str, ...] = ('__call__',)
|
||||
default_config = ModelSignature(batchable=False)
|
||||
if llm.__llm_backend__ in {'pt', 'vllm'}:
|
||||
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search',)
|
||||
elif llm.__llm_backend__ == 'tf':
|
||||
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search', 'contrastive_search',)
|
||||
else:
|
||||
infer_fn += ('generate',)
|
||||
return {k: default_config for k in infer_fn}
|
||||
|
||||
@@ -24,19 +24,11 @@ class HfIgnore:
|
||||
|
||||
@classmethod
|
||||
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
|
||||
if llm.__llm_backend__ == 'vllm':
|
||||
if llm.__llm_backend__ in {'vllm', 'pt'}:
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
|
||||
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
|
||||
else: base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'tf': base = [cls.flax, cls.pt, cls.gguf]
|
||||
elif llm.__llm_backend__ == 'flax':
|
||||
base = [cls.tf, cls.pt, cls.safetensors, cls.gguf] # as of current, safetensors is not supported with flax
|
||||
elif llm.__llm_backend__ == 'pt':
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
|
||||
else: base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'ggml':
|
||||
base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
|
||||
elif llm.__llm_backend__ == 'ggml': base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
|
||||
else:
|
||||
raise ValueError('Unknown backend (should never happen at all.)')
|
||||
# filter out these files, since we probably don't need them for now.
|
||||
|
||||
Reference in New Issue
Block a user