perf: unify LLM interface (#518)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Aaron Pham authored on 2023-11-06 20:39:43 -05:00 · committed by GitHub
parent f2639879af
commit e2029c934b
136 changed files with 9646 additions and 11244 deletions

View File

@@ -1,6 +1,6 @@
'''Serialisation utilities for OpenLLM.
Currently supports transformers for PyTorch, Tensorflow and Flax.
Currently supports transformers for PyTorch and vLLM.
GGML format support is currently a work in progress.
'''
@@ -19,11 +19,15 @@ from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import T
if t.TYPE_CHECKING:
import transformers as _transformers
import bentoml
from . import constants as constants
from . import ggml as ggml
from . import transformers as transformers
else:
_transformers = openllm.utils.LazyLoader('_transformers', globals(), 'transformers')
P = ParamSpec('P')
@@ -33,12 +37,11 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
By default, it will try to find the bentomodel in the local model store.
If the model is not found, it raises ``bentoml.exceptions.NotFound``.
'''
from .transformers._helpers import infer_tokenizers_from_llm
from .transformers._helpers import process_config
config, *_ = process_config(llm._bentomodel.path, llm.trust_remote_code)
config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code)
bentomodel_fs = fs.open_fs(llm._bentomodel.path)
bentomodel_fs = fs.open_fs(llm.bentomodel.path)
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
try:
@@ -47,7 +50,7 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
raise openllm.exceptions.OpenLLMException("Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
"For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\"") from None
else:
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)
tokenizer = _transformers.AutoTokenizer.from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)
if tokenizer.pad_token_id is None:
if config.pad_token_id is not None: tokenizer.pad_token_id = config.pad_token_id
@@ -66,7 +69,7 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "tf", "flax", "vllm")'
> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'
> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"'
"""

View File

@@ -1,9 +1,7 @@
from __future__ import annotations
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
'tf': ('TFAutoModelForCausalLM', 'TFAutoModelForSeq2SeqLM'),
'flax': ('FlaxAutoModelForCausalLM', 'FlaxAutoModelForSeq2SeqLM'),
'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')
}
FRAMEWORK_TO_AUTOCLASS_MAPPING = {'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'), 'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')}
HUB_ATTRS = ['cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token']
CONFIG_FILE_NAME = 'config.json'
# the below is similar to peft.utils.other.CONFIG_NAME
PEFT_CONFIG_NAME = 'adapter_config.json'
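A rough illustration of how the flattened mapping might be consumed; the tuple order (causal LM first, seq2seq second) and the `resolve_autoclass` helper are assumptions for this sketch:

import transformers

FRAMEWORK_TO_AUTOCLASS_MAPPING = {'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'), 'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')}

def resolve_autoclass(backend: str, model_type: str) -> type:
  # Pick the seq2seq variant for encoder-decoder models, the causal variant otherwise
  causal, seq2seq = FRAMEWORK_TO_AUTOCLASS_MAPPING[backend]
  return getattr(transformers, seq2seq if model_type == 'seq2seq_lm' else causal)

print(resolve_autoclass('pt', 'causal_lm'))  # -> transformers.AutoModelForCausalLM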

View File

@@ -4,10 +4,10 @@ import importlib
import logging
import typing as t
import attr
import orjson
from huggingface_hub import snapshot_download
from packaging.version import Version
from simple_di import Provide
from simple_di import inject
@@ -16,13 +16,13 @@ import openllm
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelOptions
from bentoml._internal.models.model import ModelSignature
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
from ._helpers import check_unintialised_params
from ._helpers import get_hash
from ._helpers import infer_autoclass_from_llm
from ._helpers import infer_tokenizers_from_llm
from ._helpers import make_model_signatures
from ._helpers import process_config
from .weights import HfIgnore
@@ -32,16 +32,30 @@ if t.TYPE_CHECKING:
import auto_gptq as autogptq
import torch
import torch.nn
import transformers
from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny
else:
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
torch = openllm.utils.LazyLoader('torch', globals(), 'torch')
logger = logging.getLogger(__name__)
__all__ = ['import_model', 'get', 'load_model']
_object_setattr = object.__setattr__
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None) -> None:
# NOTE: The following won't hit during local runs since we generate a correct version based on the local path hash. It will only hit when the model comes from the HF Hub.
if not llm._local:
try:
if _revision is None: _revision = get_hash(config)
except ValueError:
pass
if llm._tag.version is None: _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
else: _revision = llm._tag.version
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
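The tag-evolution trick above relies on `attr.evolve` producing a copy with the resolved revision; a toy reproduction (the `Tag` class and hash value here are hypothetical, not the BentoML tag type):

from typing import Optional

import attr

@attr.define
class Tag:
  name: str
  version: Optional[str] = None

tag = Tag('facebook--opt-125m')
# Copy the tag with the resolved HF commit hash as its version, as _patch_correct_tag does
tag = attr.evolve(tag, version='ab12cd34')  # hypothetical revision hash
print(tag)  # -> Tag(name='facebook--opt-125m', version='ab12cd34')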
@inject
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
@@ -49,7 +63,7 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
returning all of the unused kwargs.
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
The unused kwargs are then passed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
@@ -57,20 +71,22 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
Refer to Transformers documentation for more information about kwargs.
Args:
llm: The LLM instance for this given model.
trust_remote_code: Whether to trust the remote code when loading the model.
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM (+ TF, Flax variants).
llm: The LLM instance for this given model.
trust_remote_code: Whether to trust the remote code when loading the model.
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
"""
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
quantize = llm._quantize
quantize = llm._quantise
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
if quantize: metadata['_quantize'] = quantize
architectures = getattr(config, 'architectures', [])
if not architectures: raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
metadata['_pretrained_class'] = architectures[0]
metadata['_revision'] = get_hash(config)
signatures: DictStrAny = {}
@@ -79,26 +95,24 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm':
raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
signatures['generate'] = {'batchable': False}
else:
# this model might be called with --quantize int4, therefore we need to pop this out
# since saving int4 is not yet supported
if 'quantization_config' in attrs and getattr(attrs['quantization_config'], 'load_in_4bit', False):
attrs.pop('quantization_config')
if llm.__llm_backend__ != 'flax': attrs['use_safetensors'] = safe_serialisation
attrs['use_safetensors'] = safe_serialisation
metadata['_framework'] = llm.__llm_backend__
signatures.update(make_model_signatures(llm))
signatures.update({
k: ModelSignature(batchable=False)
for k in ('__call__', 'forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search')
})
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
tokenizer = transformers.AutoTokenizer.from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
model = None
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
imported_modules: list[types.ModuleType] = []
bentomodel = bentoml.Model.create(llm.tag,
module='openllm.serialisation.transformers',
api_version='v2',
api_version='v2.1.0',
options=ModelOptions(),
context=openllm.utils.generate_context(framework_name='openllm'),
labels=openllm.utils.generate_labels(llm),
@@ -108,13 +122,12 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
tokenizer.save_pretrained(bentomodel.path)
if llm._quantise or llm._quantization_config: attrs['quantization_config'] = llm.quantization_config
if quantize == 'gptq':
from optimum.gptq.constants import GPTQ_CONFIG
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
if llm._local:
# possible local path
logger.debug('Model will be loaded into memory and saved to the target store since it comes from a local path.')
if llm._local: # possible local path
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
@@ -133,6 +146,7 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
# NOTE: We need to free up the cache after importing the model
# in the case where users first run openllm start without the model available locally.
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
del model
return bentomodel
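A hedged usage sketch of the import path above; the model id is illustrative, and the `_tokenizer_` routing follows the convention stated in the docstring:

import openllm

llm = openllm.LLM('facebook/opt-125m')  # assumes the unified LLM constructor from this PR
bentomodel = import_model(
  llm,
  trust_remote_code=False,
  revision='main',                 # hub-related kwarg, consumed via transformers.AutoConfig.from_pretrained
  _tokenizer_padding_side='left',  # '_tokenizer_' prefix routes padding_side to the tokenizer
)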
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
@@ -145,31 +159,35 @@ def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
'''
try:
model = bentoml.models.get(llm.tag)
if Version(model.info.api_version) < Version('v2'):
raise openllm.exceptions.OpenLLMException('Please run "openllm prune -y --include-bentos" and upgrade all saved models to the latest release.')
if model.info.labels['backend'] != llm.__llm_backend__:
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with backend {model.info.labels['backend']}, while loading with {llm.__llm_backend__}.")
backend = model.info.labels['backend']
if backend != llm.__llm_backend__: raise openllm.exceptions.OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
_patch_correct_tag(llm, process_config(model.path, llm.trust_remote_code)[0], _revision=t.cast(t.Optional[str], model.info.metadata.get('_revision')))
return model
except Exception as err:
if auto_import: return import_model(llm, trust_remote_code=llm.trust_remote_code)
raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (see traceback for details):\n{err}') from err
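In practice the lookup-or-import behaviour reduces to a single call; a small sketch (model id illustrative):

import openllm

llm = openllm.LLM('facebook/opt-125m')
# Returns the stored bentomodel, importing it first if it is not in the local store
bentomodel = get(llm, auto_import=True)
print(bentomodel.tag)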
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
config, hub_attrs, attrs = process_config(llm.model_id, llm.trust_remote_code, **attrs)
config, hub_attrs, attrs = process_config(llm.bentomodel.path, llm.trust_remote_code, **attrs)
_patch_correct_tag(llm, config, _revision=t.cast(t.Optional[str], llm.bentomodel.info.metadata.get('_revision')))
auto_class = infer_autoclass_from_llm(llm, config)
device_map: str | None = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
if llm._quantise or llm._quantization_config: attrs['quantization_config'] = llm.quantization_config
if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
if '_quantize' in llm.bentomodel.info.metadata and llm.bentomodel.info.metadata['_quantize'] == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only supports causal LM (got {llm.__class__} of {llm.config['model_type']})")
model = auto_class.from_pretrained(llm._bentomodel.path, device_map='auto', **hub_attrs, **attrs)
try:
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=True, **hub_attrs, **attrs)
except Exception as err:
logger.debug("Exception caught while trying to load with 'flash_attention_2': %s", err)
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=False, **hub_attrs, **attrs)
# XXX: Use the below logic once TheBloke finishes migrating to the new GPTQConfig from transformers.
# It seems the logic below requires adding safetensors support to accelerate.
#
# from accelerate import init_empty_weights
# from optimum.gptq import load_quantized_model
# # disable exllama if gptq is loaded on CPU
@@ -179,6 +197,6 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
# empty.tie_weights()
# model = load_quantized_model(empty, save_folder=llm._bentomodel.path, device_map='auto', disable_exllama=disable_exllama)
else:
model = auto_class.from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
if llm.__llm_backend__ in {'pt', 'vllm'}: check_unintialised_params(model)
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
if llm.__llm_backend__ == 'pt': check_unintialised_params(model)
return t.cast('M', model)

View File

@@ -5,7 +5,6 @@ import typing as t
import openllm
import openllm_core
from bentoml._internal.models.model import ModelSignature
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from openllm.serialisation.constants import HUB_ATTRS
@@ -15,13 +14,17 @@ if t.TYPE_CHECKING:
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from bentoml._internal.models.model import ModelSignaturesType
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
else:
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
def get_hash(config: transformers.PretrainedConfig) -> str:
_commit_hash = getattr(config, '_commit_hash', None)
if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}')
return _commit_hash
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
'''A helper function that correctly parses config and attributes for transformers.PretrainedConfig.
@@ -42,17 +45,11 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
return config, hub_attrs, attrs
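A quick sketch of what the three-way split looks like in use; kwarg values are illustrative:

# 'revision' is hub-related (see HUB_ATTRS) and lands in hub_attrs;
# 'torch_dtype' is either consumed by AutoConfig or returned among the unused attrs.
config, hub_attrs, attrs = process_config('facebook/opt-125m', False, revision='main', torch_dtype='float16')
print(type(config).__name__, hub_attrs, attrs)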
def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
if __cls is None:
raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
return __cls
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
if llm.config['trust_remote_code']:
if llm.trust_remote_code:
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if not hasattr(config, 'auto_map'):
raise ValueError(f'Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain an `auto_map` mapping')
# In case this model doesn't use the correct auto class for its model type (for example chatglm,
# which uses AutoModel instead of AutoModelForCausalLM), we fall back to AutoModel
if autoclass not in config.auto_map: autoclass = 'AutoModel'
@@ -67,15 +64,3 @@ def check_unintialised_params(model: torch.nn.Module) -> None:
uninitialised = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(uninitialised) > 0:
raise RuntimeError(f'Found the following uninitialised parameters in {model}: {uninitialised}')
# NOTE: sync with bentoml/_internal/frameworks/transformers.py#make_default_signatures
def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
infer_fn: tuple[str, ...] = ('__call__',)
default_config = ModelSignature(batchable=False)
if llm.__llm_backend__ in {'pt', 'vllm'}:
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search',)
elif llm.__llm_backend__ == 'tf':
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search', 'contrastive_search',)
else:
infer_fn += ('generate',)
return {k: default_config for k in infer_fn}

View File

@@ -24,19 +24,11 @@ class HfIgnore:
@classmethod
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
if llm.__llm_backend__ == 'vllm':
if llm.__llm_backend__ in {'vllm', 'pt'}:
base = [cls.tf, cls.flax, cls.gguf]
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
else: base.append(cls.safetensors)
elif llm.__llm_backend__ == 'tf': base = [cls.flax, cls.pt, cls.gguf]
elif llm.__llm_backend__ == 'flax':
base = [cls.tf, cls.pt, cls.safetensors, cls.gguf] # as of current, safetensors is not supported with flax
elif llm.__llm_backend__ == 'pt':
base = [cls.tf, cls.flax, cls.gguf]
if has_safetensors_weights(llm.model_id) or llm._serialisation == 'safetensors': base.append(cls.pt)
else: base.append(cls.safetensors)
elif llm.__llm_backend__ == 'ggml':
base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
elif llm.__llm_backend__ == 'ggml': base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
else:
raise ValueError('Unknown backend (should never happen).')
# filter out these files, since we probably don't need them for now.
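Downstream, these patterns would feed straight into a Hub download; a minimal sketch with assumed glob strings (the real values live on the `HfIgnore` class attributes referenced above):

from huggingface_hub import snapshot_download

# Assumed patterns standing in for cls.tf, cls.flax, cls.gguf and cls.pt
ignore = ['tf_model*', 'flax_model*', '*.gguf', '*.bin']
snapshot_download('facebook/opt-125m', ignore_patterns=ignore)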