fix: loading correct local models (#599)

* fix(model): loading local correctly

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update repr and correct bentomodel processor

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

* chore: cleanup transformers implementation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* fix: ruff to ignore I001 on all stubs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
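For illustration only: one way a stub file can opt out of import-sorting is ruff's file-level suppression directive. The repository's actual configuration most likely uses per-file ignores in its ruff settings instead, and the stub below is a hypothetical example, not a file from this commit.

# example.pyi - hypothetical stub file. The file-level directive tells ruff to
# skip import-sorting (I001) for this file, preserving the import order that
# generated stubs carry.
# ruff: noqa: I001
from __future__ import annotations
import typing as t
import types

def load(name: str, /, **attrs: t.Any) -> types.ModuleType: ...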

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Commit fa2038f4e2 (parent 5e45245457)
Authored by Aaron Pham on 2023-11-10 02:36:12 -05:00, committed via GitHub
11 changed files with 121 additions and 111 deletions
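
For readers skimming the hunks below: the substance of this commit is that revision and tag resolution now distinguishes local checkouts from Hub models. Local models keep the version generated from the local path hash, while Hub models pin to the commit hash recorded on the loaded config, which is also stored as `_revision` metadata on the bento model. A minimal sketch of that rule, using a hypothetical helper name and a simplified path hash rather than OpenLLM's actual API, might look like:

# Minimal sketch of the revision rule this commit enforces. The helper name
# and the path hashing below are hypothetical, not OpenLLM's actual API.
from __future__ import annotations

import hashlib
import os

def resolve_revision(model_id: str, hub_commit_hash: str | None) -> str:
  if os.path.isdir(model_id):
    # Local checkout: keep a version derived from the resolved local path,
    # never overwrite it with a Hub commit hash.
    return hashlib.sha1(os.path.abspath(model_id).encode()).hexdigest()[:12]
  if hub_commit_hash is None:
    raise ValueError(f'Cannot find commit hash for {model_id!r}')
  # Hub model: pin the tag version to the config's _commit_hash.
  return hub_commit_hash

print(resolve_revision('.', None))  # local directory -> path-derived hash
print(resolve_revision('facebook/opt-125m', 'abc123'))  # Hub model -> commit hash

In the actual diff this shows up as `_patch_correct_tag` copying the resolved revision into `llm.tag` via `attr.evolve`, and `import_model` storing `llm.revision` (local) or `get_hash(config)` (Hub) under `metadata['_revision']`.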


@@ -3,7 +3,6 @@
from __future__ import annotations
import importlib
import logging
import typing as t
import attr
import orjson
@@ -29,25 +28,17 @@ from ._helpers import process_config
from .weights import HfIgnore
if t.TYPE_CHECKING:
import types
from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny
logger = logging.getLogger(__name__)
__all__ = ['import_model', 'get', 'load_model']
_object_setattr = object.__setattr__
def _patch_correct_tag(
llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None
) -> None:
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None):
# NOTE: The following won't hit during local since we generated a correct version based on the local path hash. It will only hit if we use a model from the HF Hub
if llm.revision is not None:
return
if not llm._local:
if not llm.local:
try:
if _revision is None:
_revision = get_hash(config)
@@ -55,7 +46,7 @@ def _patch_correct_tag(
pass
if _revision is None and llm.tag.version is not None:
_revision = llm.tag.version
if llm._tag.version is None:
if llm.tag.version is None:
_object_setattr(
llm, '_tag', attr.evolve(llm.tag, version=_revision)
) # HACK: This copies the correct revision into llm.tag
@@ -64,13 +55,12 @@ def _patch_correct_tag(
@inject
def import_model(
llm: openllm.LLM[M, T],
*decls: t.Any,
trust_remote_code: bool,
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
**attrs: t.Any,
) -> bentoml.Model:
def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLContainer.model_store], **attrs):
_base_decls, _base_attrs = llm.llm_parameters[0]
decls = (*_base_decls, *decls)
attrs = {**_base_attrs, **attrs}
if llm._local:
logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
@@ -78,7 +68,7 @@ def import_model(
safe_serialisation = openllm.utils.first_not_none(
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
)
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
metadata = {'safe_serialisation': safe_serialisation}
if quantize:
metadata['_quantize'] = quantize
architectures = getattr(config, 'architectures', [])
@@ -87,9 +77,12 @@ def import_model(
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
)
metadata['_pretrained_class'] = architectures[0]
metadata['_revision'] = get_hash(config)
if not llm._local:
metadata['_revision'] = get_hash(config)
else:
metadata['_revision'] = llm.revision
signatures: DictStrAny = {}
signatures = {}
if quantize == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
@@ -125,8 +118,8 @@ def import_model(
tokenizer.pad_token = tokenizer.eos_token
model = None
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
imported_modules: list[types.ModuleType] = []
external_modules = [importlib.import_module(tokenizer.__module__)]
imported_modules = []
bentomodel = bentoml.Model.create(
llm.tag,
module='openllm.serialisation.transformers',
@@ -150,11 +143,20 @@ def import_model(
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
if llm._local: # possible local path
model = infer_autoclass_from_llm(llm, config).from_pretrained(
llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
llm.model_id,
*decls,
local_files_only=True,
config=config,
trust_remote_code=trust_remote_code,
**hub_attrs,
**attrs,
)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# we will clone everything into the bentomodel path without loading the model into memory
snapshot_download(
@@ -175,15 +177,10 @@ def import_model(
)
finally:
bentomodel.exit_cloudpickle_context(imported_modules)
# NOTE: We need to free up the cache after importing the model
# in the case where users first run openllm start without the model available locally.
if openllm.utils.is_torch_available() and torch.cuda.is_available():
torch.cuda.empty_cache()
del model
return bentomodel
def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
def get(llm):
try:
model = bentoml.models.get(llm.tag)
backend = model.info.labels['backend']
@@ -203,7 +200,13 @@ def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
) from err
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
def check_unintialised_params(model):
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0:
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
def load_model(llm, *decls, **attrs):
if llm._quantise in {'awq', 'squeezellm'}:
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(
@@ -232,7 +235,9 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
)
# TODO: investigate load with flash attention
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
)
else:
model = auto_class.from_pretrained(
llm.bentomodel.path,
@@ -242,4 +247,5 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
device_map=device_map,
**attrs,
)
return t.cast('M', model)
check_unintialised_params(model)
return model


@@ -2,43 +2,28 @@ from __future__ import annotations
import copy
import typing as t
import torch
import transformers
import openllm
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from openllm.serialisation.constants import HUB_ATTRS
from openllm_core.exceptions import OpenLLMException
if t.TYPE_CHECKING:
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
from ..._llm import LLM
def get_hash(config: transformers.PretrainedConfig) -> str:
def get_hash(config) -> str:
_commit_hash = getattr(config, '_commit_hash', None)
if _commit_hash is None:
raise ValueError(f'Cannot find commit hash in {config}')
return _commit_hash
def process_config(
model_id: str, trust_remote_code: bool, **attrs: t.Any
) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
Args:
model_id: Model id to pass into ``transformers.AutoConfig``.
trust_remote_code: Whether to trust_remote_code or not
attrs: All possible attributes that can be processed by ``transformers.AutoConfig.from_pretrained
Returns:
A tuple of ``transformers.PretrainedConfig``, all hub attributes, and remanining attributes that can be used by the Model class.
"""
def process_config(model_id, trust_remote_code, **attrs):
config = attrs.pop('config', None)
# this logic below is synonymous to handling `from_pretrained` attrs.
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
@@ -52,7 +37,7 @@ def process_config(
return config, hub_attrs, attrs
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
def infer_autoclass_from_llm(llm: LLM[M, T], config, /):
if llm.trust_remote_code:
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if not hasattr(config, 'auto_map'):
@@ -70,11 +55,5 @@ def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.Pretra
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
idx = 1
else:
raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
raise OpenLLMException(f'Model type {type(config)} is not supported yet.')
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_backend__][idx])
def check_unintialised_params(model: torch.nn.Module) -> None:
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0:
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')


@@ -2,11 +2,15 @@ from __future__ import annotations
import traceback
import typing as t
from pathlib import Path
import attr
from huggingface_hub import HfApi
from openllm_core.exceptions import Error
from openllm_core.utils import resolve_filepath
from openllm_core.utils import validate_is_path
if t.TYPE_CHECKING:
@@ -40,6 +44,8 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
if validate_is_path(model_id):
return next((True for item in Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)