mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-25 09:32:37 -04:00
chore(style): synchronized style across packages [skip ci]
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -31,21 +31,21 @@ from openllm_core._typing_compat import M, T, ParamSpec
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
from . import constants as constants, ggml as ggml, transformers as transformers
|
||||
P = ParamSpec("P")
|
||||
P = ParamSpec('P')
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
"""Load the tokenizer from BentoML store.
|
||||
'''Load the tokenizer from BentoML store.
|
||||
|
||||
By default, it will try to find the bentomodel whether it is in store..
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
'''
|
||||
from .transformers._helpers import infer_tokenizers_from_llm, process_config
|
||||
|
||||
config, *_ = process_config(llm._bentomodel.path, llm.__llm_trust_remote_code__)
|
||||
bentomodel_fs = fs.open_fs(llm._bentomodel.path)
|
||||
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
|
||||
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile:
|
||||
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
|
||||
try:
|
||||
tokenizer = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"]
|
||||
tokenizer = cloudpickle.load(t.cast('t.IO[bytes]', cofile))['tokenizer']
|
||||
except KeyError:
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"Bento model does not have tokenizer. Make sure to save"
|
||||
@@ -53,18 +53,18 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
" For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\""
|
||||
) from None
|
||||
else:
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath("/"), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs)
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs)
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if config.pad_token_id is not None: tokenizer.pad_token_id = config.pad_token_id
|
||||
elif config.eos_token_id is not None: tokenizer.pad_token_id = config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else: tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
else: tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
return tokenizer
|
||||
class _Caller(t.Protocol[P]):
|
||||
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
...
|
||||
_extras = ["get", "import_model", "save_pretrained", "load_model"]
|
||||
_extras = ['get', 'import_model', 'save_pretrained', 'load_model']
|
||||
def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
@@ -73,7 +73,7 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
|
||||
> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.runtime="ggml"'
|
||||
"""
|
||||
return getattr(importlib.import_module(f".{llm.runtime}", __name__), fn)(llm, *args, **kwargs)
|
||||
return getattr(importlib.import_module(f'.{llm.runtime}', __name__), fn)(llm, *args, **kwargs)
|
||||
|
||||
return caller
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -89,12 +89,12 @@ if t.TYPE_CHECKING:
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M:
|
||||
...
|
||||
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": [], "constants": []}
|
||||
__all__ = ["ggml", "transformers", "constants", "load_tokenizer", *_extras]
|
||||
_import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
|
||||
__all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == "load_tokenizer": return load_tokenizer
|
||||
elif name in _import_structure: return importlib.import_module(f".{name}", __name__)
|
||||
if name == 'load_tokenizer': return load_tokenizer
|
||||
elif name in _import_structure: return importlib.import_module(f'.{name}', __name__)
|
||||
elif name in _extras: return _make_dispatch_function(name)
|
||||
else: raise AttributeError(f"{__name__} has no attribute {name}")
|
||||
else: raise AttributeError(f'{__name__} has no attribute {name}')
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
|
||||
"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
|
||||
"tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"),
|
||||
"flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
|
||||
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM")
|
||||
'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
|
||||
'tf': ('TFAutoModelForCausalLM', 'TFAutoModelForSeq2SeqLM'),
|
||||
'flax': ('FlaxAutoModelForCausalLM', 'FlaxAutoModelForSeq2SeqLM'),
|
||||
'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')
|
||||
}
|
||||
HUB_ATTRS = ["cache_dir", "code_revision", "force_download", "local_files_only", "proxies", "resume_download", "revision", "subfolder", "use_auth_token"]
|
||||
HUB_ATTRS = ['cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token']
|
||||
|
||||
@@ -1,29 +1,29 @@
|
||||
"""Serialisation related implementation for GGML-based implementation.
|
||||
'''Serialisation related implementation for GGML-based implementation.
|
||||
|
||||
This requires ctransformers to be installed.
|
||||
"""
|
||||
'''
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
import bentoml, openllm
|
||||
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import M
|
||||
|
||||
_conversion_strategy = {"pt": "ggml"}
|
||||
_conversion_strategy = {'pt': 'ggml'}
|
||||
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
|
||||
"""Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
'''Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
|
||||
By default, it will try to check the model in the local store.
|
||||
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
|
||||
|
||||
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
'''
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
if model.info.module not in ("openllm.serialisation.ggml", __name__):
|
||||
if model.info.module not in ('openllm.serialisation.ggml', __name__):
|
||||
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
|
||||
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
|
||||
if 'runtime' in model.info.labels and model.info.labels['runtime'] != llm.runtime:
|
||||
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
return model
|
||||
except bentoml.exceptions.NotFound:
|
||||
@@ -31,6 +31,6 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
|
||||
return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
raise
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None:
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Serialisation related implementation for Transformers-based implementation."""
|
||||
'''Serialisation related implementation for Transformers-based implementation.'''
|
||||
from __future__ import annotations
|
||||
import importlib, logging, typing as t
|
||||
import bentoml, openllm
|
||||
@@ -18,14 +18,14 @@ if t.TYPE_CHECKING:
|
||||
from bentoml._internal.models import ModelStore
|
||||
from openllm_core._typing_compat import DictStrAny, M, T
|
||||
else:
|
||||
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
|
||||
autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq")
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
vllm = openllm.utils.LazyLoader('vllm', globals(), 'vllm')
|
||||
autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
|
||||
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
|
||||
torch = openllm.utils.LazyLoader('torch', globals(), 'torch')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["import_model", "get", "load_model", "save_pretrained"]
|
||||
__all__ = ['import_model', 'get', 'load_model', 'save_pretrained']
|
||||
@inject
|
||||
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
|
||||
"""Auto detect model type from given model_id and import it to bentoml's model store.
|
||||
@@ -48,23 +48,23 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
_, tokenizer_attrs = llm.llm_parameters
|
||||
quantize_method = llm._quantize_method
|
||||
safe_serialisation = openllm.utils.first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors")
|
||||
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation_format == 'safetensors')
|
||||
# Disable safe serialization with vLLM
|
||||
if llm.__llm_implementation__ == "vllm": safe_serialisation = False
|
||||
metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method is not None and quantize_method}
|
||||
if llm.__llm_implementation__ == 'vllm': safe_serialisation = False
|
||||
metadata: DictStrAny = {'safe_serialisation': safe_serialisation, '_quantize': quantize_method is not None and quantize_method}
|
||||
signatures: DictStrAny = {}
|
||||
|
||||
if quantize_method == "gptq":
|
||||
if quantize_method == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
signatures["generate"] = {"batchable": False}
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
signatures['generate'] = {'batchable': False}
|
||||
else:
|
||||
# this model might be called with --quantize int4, therefore we need to pop this out
|
||||
# since saving int4 is not yet supported
|
||||
if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False): attrs.pop("quantization_config")
|
||||
if llm.__llm_implementation__ != "flax": attrs["use_safetensors"] = safe_serialisation
|
||||
metadata["_framework"] = "pt" if llm.__llm_implementation__ == "vllm" else llm.__llm_implementation__
|
||||
if 'quantization_config' in attrs and getattr(attrs['quantization_config'], 'load_in_4bit', False): attrs.pop('quantization_config')
|
||||
if llm.__llm_implementation__ != 'flax': attrs['use_safetensors'] = safe_serialisation
|
||||
metadata['_framework'] = 'pt' if llm.__llm_implementation__ == 'vllm' else llm.__llm_implementation__
|
||||
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
@@ -73,10 +73,10 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
imported_modules: list[types.ModuleType] = []
|
||||
bentomodel = bentoml.Model.create(
|
||||
llm.tag,
|
||||
module="openllm.serialisation.transformers",
|
||||
api_version="v1",
|
||||
module='openllm.serialisation.transformers',
|
||||
api_version='v1',
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name="openllm"),
|
||||
context=openllm.utils.generate_context(framework_name='openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
signatures=signatures if signatures else make_model_signatures(llm)
|
||||
)
|
||||
@@ -84,35 +84,35 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
try:
|
||||
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
|
||||
tokenizer.save_pretrained(bentomodel.path)
|
||||
if quantize_method == "gptq":
|
||||
if quantize_method == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
logger.debug("Saving model with GPTQ quantisation will require loading model into memory.")
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
logger.debug('Saving model with GPTQ quantisation will require loading model into memory.')
|
||||
model = autogptq.AutoGPTQForCausalLM.from_quantized(
|
||||
llm.model_id,
|
||||
*decls,
|
||||
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
|
||||
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_safetensors=safe_serialisation,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
update_model(bentomodel, metadata={"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework})
|
||||
update_model(bentomodel, metadata={'_pretrained_class': model.__class__.__name__, '_framework': model.model.framework})
|
||||
model.save_quantized(bentomodel.path, use_safetensors=safe_serialisation)
|
||||
else:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
if not architectures:
|
||||
raise RuntimeError("Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`")
|
||||
raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
|
||||
architecture = architectures[0]
|
||||
update_model(bentomodel, metadata={"_pretrained_class": architecture})
|
||||
update_model(bentomodel, metadata={'_pretrained_class': architecture})
|
||||
if llm._local:
|
||||
# possible local path
|
||||
logger.debug("Model will be loaded into memory to save to target store as it is from local path.")
|
||||
logger.debug('Model will be loaded into memory to save to target store as it is from local path.')
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
|
||||
# for trust_remote_code to work
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size="5GB", safe_serialization=safe_serialisation)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
|
||||
else:
|
||||
# we will clone the all tings into the bentomodel path without loading model into memory
|
||||
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
|
||||
@@ -129,58 +129,58 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
return bentomodel
|
||||
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
|
||||
"""Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
'''Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
|
||||
By default, it will try to check the model in the local store.
|
||||
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
|
||||
|
||||
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
'''
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
if model.info.module not in (
|
||||
"openllm.serialisation.transformers"
|
||||
"bentoml.transformers", "bentoml._internal.frameworks.transformers", __name__
|
||||
'openllm.serialisation.transformers'
|
||||
'bentoml.transformers', 'bentoml._internal.frameworks.transformers', __name__
|
||||
): # NOTE: backward compatible with previous version of OpenLLM.
|
||||
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
|
||||
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
|
||||
if 'runtime' in model.info.labels and model.info.labels['runtime'] != llm.runtime:
|
||||
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
return model
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
raise err from None
|
||||
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
"""Load the model from BentoML store.
|
||||
'''Load the model from BentoML store.
|
||||
|
||||
By default, it will try to find check the model in the local store.
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
'''
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs)
|
||||
safe_serialization = openllm.utils.first_not_none(
|
||||
t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors"
|
||||
t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)), attrs.pop('safe_serialization', None), default=llm._serialisation_format == 'safetensors'
|
||||
)
|
||||
if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq":
|
||||
if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
return autogptq.AutoGPTQForCausalLM.from_quantized(
|
||||
llm._bentomodel.path,
|
||||
*decls,
|
||||
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
|
||||
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
use_safetensors=safe_serialization,
|
||||
**hub_attrs,
|
||||
**attrs
|
||||
)
|
||||
|
||||
device_map = attrs.pop("device_map", "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
|
||||
device_map = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(
|
||||
llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, device_map=device_map, **hub_attrs, **attrs
|
||||
).eval()
|
||||
# BetterTransformer is currently only supported on PyTorch.
|
||||
if llm.bettertransformer and isinstance(model, transformers.PreTrainedModel): model = model.to_bettertransformer()
|
||||
if llm.__llm_implementation__ in {"pt", "vllm"}: check_unintialised_params(model)
|
||||
return t.cast("M", model)
|
||||
if llm.__llm_implementation__ in {'pt', 'vllm'}: check_unintialised_params(model)
|
||||
return t.cast('M', model)
|
||||
def save_pretrained(
|
||||
llm: openllm.LLM[M, T],
|
||||
save_directory: str,
|
||||
@@ -188,29 +188,29 @@ def save_pretrained(
|
||||
state_dict: DictStrAny | None = None,
|
||||
save_function: t.Any | None = None,
|
||||
push_to_hub: bool = False,
|
||||
max_shard_size: int | str = "10GB",
|
||||
max_shard_size: int | str = '10GB',
|
||||
safe_serialization: bool = False,
|
||||
variant: str | None = None,
|
||||
**attrs: t.Any
|
||||
) -> None:
|
||||
save_function = t.cast(t.Callable[..., None], openllm.utils.first_not_none(save_function, default=torch.save))
|
||||
model_save_attrs, tokenizer_save_attrs = openllm.utils.normalize_attrs_to_model_tokenizer_pair(**attrs)
|
||||
safe_serialization = safe_serialization or llm._serialisation_format == "safetensors"
|
||||
safe_serialization = safe_serialization or llm._serialisation_format == 'safetensors'
|
||||
# NOTE: disable safetensors for vllm
|
||||
if llm.__llm_implementation__ == "vllm": safe_serialization = False
|
||||
if llm._quantize_method == "gptq":
|
||||
if llm.__llm_implementation__ == 'vllm': safe_serialization = False
|
||||
if llm._quantize_method == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
if not openllm.utils.lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})")
|
||||
t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
|
||||
elif openllm.utils.LazyType["vllm.LLMEngine"]("vllm.LLMEngine").isinstance(llm.model):
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
if not openllm.utils.lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f'Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})')
|
||||
t.cast('autogptq.modeling.BaseGPTQForCausalLM', llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
|
||||
elif openllm.utils.LazyType['vllm.LLMEngine']('vllm.LLMEngine').isinstance(llm.model):
|
||||
raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.")
|
||||
elif isinstance(llm.model, transformers.Pipeline):
|
||||
llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
|
||||
else:
|
||||
# We can safely cast here since it will be the PreTrainedModel protocol.
|
||||
t.cast("transformers.PreTrainedModel", llm.model).save_pretrained(
|
||||
t.cast('transformers.PreTrainedModel', llm.model).save_pretrained(
|
||||
save_directory,
|
||||
is_main_process=is_main_process,
|
||||
state_dict=state_dict,
|
||||
|
||||
@@ -9,11 +9,11 @@ if t.TYPE_CHECKING:
|
||||
from bentoml._internal.models.model import ModelSignaturesType
|
||||
from openllm_core._typing_compat import DictStrAny, M, T
|
||||
else:
|
||||
transformers, torch = openllm_core.utils.LazyLoader("transformers", globals(), "transformers"), openllm_core.utils.LazyLoader("torch", globals(), "torch")
|
||||
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
|
||||
'''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
|
||||
|
||||
Args:
|
||||
model_id: Model id to pass into ``transformers.AutoConfig``.
|
||||
@@ -22,51 +22,51 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
|
||||
|
||||
Returns:
|
||||
A tuple of ``transformers.PretrainedConfig``, all hub attributes, and remanining attributes that can be used by the Model class.
|
||||
"""
|
||||
config = attrs.pop("config", None)
|
||||
'''
|
||||
config = attrs.pop('config', None)
|
||||
# this logic below is synonymous to handling `from_pretrained` attrs.
|
||||
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
|
||||
if not isinstance(config, transformers.PretrainedConfig):
|
||||
copied_attrs = copy.deepcopy(attrs)
|
||||
if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype")
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
|
||||
return config, hub_attrs, attrs
|
||||
def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
|
||||
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config["tokenizer_class"], default="AutoTokenizer"), None)
|
||||
if __cls is None: raise ValueError(f"Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`")
|
||||
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
|
||||
if __cls is None: raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
|
||||
return __cls
|
||||
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
|
||||
if llm.config["trust_remote_code"]:
|
||||
autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM"
|
||||
if not hasattr(config, "auto_map"):
|
||||
raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping")
|
||||
if llm.config['trust_remote_code']:
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if not hasattr(config, 'auto_map'):
|
||||
raise ValueError(f'Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map: autoclass = "AutoModel"
|
||||
if autoclass not in config.auto_map: autoclass = 'AutoModel'
|
||||
return getattr(transformers, autoclass)
|
||||
else:
|
||||
if type(config) in transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
|
||||
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
|
||||
else: raise openllm.exceptions.OpenLLMException(f"Model type {type(config)} is not supported yet.")
|
||||
else: raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
|
||||
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx])
|
||||
def check_unintialised_params(model: torch.nn.Module) -> None:
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device("meta")]
|
||||
if len(unintialized) > 0: raise RuntimeError(f"Found the following unintialized parameters in {model}: {unintialized}")
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
def update_model(bentomodel: bentoml.Model, metadata: DictStrAny) -> bentoml.Model:
|
||||
based: DictStrAny = copy.deepcopy(bentomodel.info.metadata)
|
||||
based.update(metadata)
|
||||
_object_setattr(bentomodel, "_info", ModelInfo( # type: ignore[call-arg] # XXX: remove me once upstream is merged
|
||||
_object_setattr(bentomodel, '_info', ModelInfo( # type: ignore[call-arg] # XXX: remove me once upstream is merged
|
||||
tag=bentomodel.info.tag, module=bentomodel.info.module, labels=bentomodel.info.labels, options=bentomodel.info.options.to_dict(), signatures=bentomodel.info.signatures, context=bentomodel.info.context, api_version=bentomodel.info.api_version, creation_time=bentomodel.info.creation_time, metadata=based
|
||||
))
|
||||
return bentomodel
|
||||
# NOTE: sync with bentoml/_internal/frameworks/transformers.py#make_default_signatures
|
||||
def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
|
||||
infer_fn: tuple[str, ...] = ("__call__",)
|
||||
infer_fn: tuple[str, ...] = ('__call__',)
|
||||
default_config = ModelSignature(batchable=False)
|
||||
if llm.__llm_implementation__ in {"pt", "vllm"}:
|
||||
infer_fn += ("forward", "generate", "contrastive_search", "greedy_search", "sample", "beam_search", "beam_sample", "group_beam_search", "constrained_beam_search",)
|
||||
elif llm.__llm_implementation__ == "tf":
|
||||
infer_fn += ("predict", "call", "generate", "compute_transition_scores", "greedy_search", "sample", "beam_search", "contrastive_search",)
|
||||
if llm.__llm_implementation__ in {'pt', 'vllm'}:
|
||||
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search',)
|
||||
elif llm.__llm_implementation__ == 'tf':
|
||||
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search', 'contrastive_search',)
|
||||
else:
|
||||
infer_fn += ("generate",)
|
||||
infer_fn += ('generate',)
|
||||
return {k: default_config for k in infer_fn}
|
||||
|
||||
@@ -5,22 +5,22 @@ if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
from openllm_core._typing_compat import M, T
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
|
||||
return any(s.rfilename.endswith(".safetensors") for s in HfApi().model_info(model_id, revision=revision).siblings)
|
||||
return any(s.rfilename.endswith('.safetensors') for s in HfApi().model_info(model_id, revision=revision).siblings)
|
||||
@attr.define(slots=True)
|
||||
class HfIgnore:
|
||||
safetensors = "*.safetensors"
|
||||
pt = "*.bin"
|
||||
tf = "*.h5"
|
||||
flax = "*.msgpack"
|
||||
safetensors = '*.safetensors'
|
||||
pt = '*.bin'
|
||||
tf = '*.h5'
|
||||
flax = '*.msgpack'
|
||||
|
||||
@classmethod
|
||||
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
|
||||
if llm.__llm_implementation__ == "vllm": base = [cls.tf, cls.flax, cls.safetensors]
|
||||
elif llm.__llm_implementation__ == "tf": base = [cls.flax, cls.pt]
|
||||
elif llm.__llm_implementation__ == "flax": base = [cls.tf, cls.pt, cls.safetensors] # as of current, safetensors is not supported with flax
|
||||
if llm.__llm_implementation__ == 'vllm': base = [cls.tf, cls.flax, cls.safetensors]
|
||||
elif llm.__llm_implementation__ == 'tf': base = [cls.flax, cls.pt]
|
||||
elif llm.__llm_implementation__ == 'flax': base = [cls.tf, cls.pt, cls.safetensors] # as of current, safetensors is not supported with flax
|
||||
else:
|
||||
base = [cls.tf, cls.flax]
|
||||
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
|
||||
# filter out these files, since we probably don't need them for now.
|
||||
base.extend(["*.pdf", "*.md", ".gitattributes", "LICENSE.txt"])
|
||||
base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt'])
|
||||
return base
|
||||
|
||||
Reference in New Issue
Block a user