mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-11 19:47:03 -04:00
fix: loading correct local models (#599)
* fix(model): loading local correctly Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> * chore: update repr and correct bentomodel processor Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * ci: auto fixes from pre-commit.ci For more information, see https://pre-commit.ci * chore: cleanup transformers implementation Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * fix: ruff to ignore I001 on all stubs Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
import orjson
|
||||
@@ -29,25 +28,17 @@ from ._helpers import process_config
|
||||
from .weights import HfIgnore
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
|
||||
from bentoml._internal.models import ModelStore
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['import_model', 'get', 'load_model']
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def _patch_correct_tag(
|
||||
llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None
|
||||
) -> None:
|
||||
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None):
|
||||
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
|
||||
if llm.revision is not None:
|
||||
return
|
||||
if not llm._local:
|
||||
if not llm.local:
|
||||
try:
|
||||
if _revision is None:
|
||||
_revision = get_hash(config)
|
||||
@@ -55,7 +46,7 @@ def _patch_correct_tag(
|
||||
pass
|
||||
if _revision is None and llm.tag.version is not None:
|
||||
_revision = llm.tag.version
|
||||
if llm._tag.version is None:
|
||||
if llm.tag.version is None:
|
||||
_object_setattr(
|
||||
llm, '_tag', attr.evolve(llm.tag, version=_revision)
|
||||
) # HACK: This copies the correct revision into llm.tag
|
||||
@@ -64,13 +55,12 @@ def _patch_correct_tag(
|
||||
|
||||
|
||||
@inject
|
||||
def import_model(
|
||||
llm: openllm.LLM[M, T],
|
||||
*decls: t.Any,
|
||||
trust_remote_code: bool,
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Model:
|
||||
def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLContainer.model_store], **attrs):
|
||||
_base_decls, _base_attrs = llm.llm_parameters[0]
|
||||
decls = (*_base_decls, *decls)
|
||||
attrs = {**_base_attrs, **attrs}
|
||||
if llm._local:
|
||||
logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
_patch_correct_tag(llm, config)
|
||||
_, tokenizer_attrs = llm.llm_parameters
|
||||
@@ -78,7 +68,7 @@ def import_model(
|
||||
safe_serialisation = openllm.utils.first_not_none(
|
||||
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
|
||||
)
|
||||
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
|
||||
metadata = {'safe_serialisation': safe_serialisation}
|
||||
if quantize:
|
||||
metadata['_quantize'] = quantize
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
@@ -87,9 +77,12 @@ def import_model(
|
||||
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
|
||||
)
|
||||
metadata['_pretrained_class'] = architectures[0]
|
||||
metadata['_revision'] = get_hash(config)
|
||||
if not llm._local:
|
||||
metadata['_revision'] = get_hash(config)
|
||||
else:
|
||||
metadata['_revision'] = llm.revision
|
||||
|
||||
signatures: DictStrAny = {}
|
||||
signatures = {}
|
||||
|
||||
if quantize == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
|
||||
@@ -125,8 +118,8 @@ def import_model(
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = None
|
||||
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
|
||||
imported_modules: list[types.ModuleType] = []
|
||||
external_modules = [importlib.import_module(tokenizer.__module__)]
|
||||
imported_modules = []
|
||||
bentomodel = bentoml.Model.create(
|
||||
llm.tag,
|
||||
module='openllm.serialisation.transformers',
|
||||
@@ -150,11 +143,20 @@ def import_model(
|
||||
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
|
||||
if llm._local: # possible local path
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(
|
||||
llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
|
||||
llm.model_id,
|
||||
*decls,
|
||||
local_files_only=True,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
# for trust_remote_code to work
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
|
||||
del model
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# we will clone the all tings into the bentomodel path without loading model into memory
|
||||
snapshot_download(
|
||||
@@ -175,15 +177,10 @@ def import_model(
|
||||
)
|
||||
finally:
|
||||
bentomodel.exit_cloudpickle_context(imported_modules)
|
||||
# NOTE: We need to free up the cache after importing the model
|
||||
# in the case where users first run openllm start without the model available locally.
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
del model
|
||||
return bentomodel
|
||||
|
||||
|
||||
def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
|
||||
def get(llm):
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
backend = model.info.labels['backend']
|
||||
@@ -203,7 +200,13 @@ def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
|
||||
) from err
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
def check_unintialised_params(model):
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0:
|
||||
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
|
||||
|
||||
def load_model(llm, *decls, **attrs):
|
||||
if llm._quantise in {'awq', 'squeezellm'}:
|
||||
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
@@ -232,7 +235,9 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
)
|
||||
|
||||
# TODO: investigate load with flash attention
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
|
||||
)
|
||||
else:
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path,
|
||||
@@ -242,4 +247,5 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
device_map=device_map,
|
||||
**attrs,
|
||||
)
|
||||
return t.cast('M', model)
|
||||
check_unintialised_params(model)
|
||||
return model
|
||||
|
||||
@@ -2,43 +2,28 @@ from __future__ import annotations
|
||||
import copy
|
||||
import typing as t
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
|
||||
from openllm.serialisation.constants import HUB_ATTRS
|
||||
from openllm_core.exceptions import OpenLLMException
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
from ..._llm import LLM
|
||||
|
||||
def get_hash(config: transformers.PretrainedConfig) -> str:
|
||||
|
||||
def get_hash(config) -> str:
|
||||
_commit_hash = getattr(config, '_commit_hash', None)
|
||||
if _commit_hash is None:
|
||||
raise ValueError(f'Cannot find commit hash in {config}')
|
||||
return _commit_hash
|
||||
|
||||
|
||||
def process_config(
|
||||
model_id: str, trust_remote_code: bool, **attrs: t.Any
|
||||
) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
|
||||
|
||||
Args:
|
||||
model_id: Model id to pass into ``transformers.AutoConfig``.
|
||||
trust_remote_code: Whether to trust_remote_code or not
|
||||
attrs: All possible attributes that can be processed by ``transformers.AutoConfig.from_pretrained
|
||||
|
||||
Returns:
|
||||
A tuple of ``transformers.PretrainedConfig``, all hub attributes, and remanining attributes that can be used by the Model class.
|
||||
"""
|
||||
def process_config(model_id, trust_remote_code, **attrs):
|
||||
config = attrs.pop('config', None)
|
||||
# this logic below is synonymous to handling `from_pretrained` attrs.
|
||||
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
|
||||
@@ -52,7 +37,7 @@ def process_config(
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
|
||||
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
|
||||
def infer_autoclass_from_llm(llm: LLM[M, T], config, /):
|
||||
if llm.trust_remote_code:
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if not hasattr(config, 'auto_map'):
|
||||
@@ -70,11 +55,5 @@ def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.Pretra
|
||||
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
|
||||
idx = 1
|
||||
else:
|
||||
raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
|
||||
raise OpenLLMException(f'Model type {type(config)} is not supported yet.')
|
||||
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_backend__][idx])
|
||||
|
||||
|
||||
def check_unintialised_params(model: torch.nn.Module) -> None:
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0:
|
||||
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
|
||||
@@ -2,11 +2,15 @@ from __future__ import annotations
|
||||
import traceback
|
||||
import typing as t
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import attr
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from openllm_core.exceptions import Error
|
||||
from openllm_core.utils import resolve_filepath
|
||||
from openllm_core.utils import validate_is_path
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -40,6 +44,8 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
|
||||
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
|
||||
if validate_is_path(model_id):
|
||||
return next((True for item in Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
|
||||
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user