Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-06 06:29:21 -05:00)
fix: loading correct local models (#599)
* fix(model): loading local correctly
  Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* chore: update repr and correct bentomodel processor
  Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
* ci: auto fixes from pre-commit.ci
  For more information, see https://pre-commit.ci
* chore: cleanup transformers implementation
  Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
* fix: ruff to ignore I001 on all stubs
  Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
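The diff below makes a filesystem path given as the model id resolve to the on-disk checkpoint (loading with `local_files_only=True`) instead of being treated as a Hub repo, and caches the resulting Bento model on the LLM instance. A minimal sketch of the behaviour this targets, assuming the public `openllm.LLM(model_id, backend=...)` constructor; the checkpoint path is hypothetical:

    import openllm

    # Hypothetical local checkpoint directory; any directory readable by transformers.AutoConfig works.
    llm = openllm.LLM('/models/llama-2-7b-hf', backend='pt')

    print(llm.local)       # True, since the model id is a filesystem path
    print(llm.tag)         # version is derived from the local path hash rather than a Hub commit
    print(llm.bentomodel)  # now returns the _bentomodel cached when the model was imported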
@@ -24,7 +24,7 @@ repos:
language: system
always_run: true
pass_filenames: false
entry: mypy --strict openllm-client/src/openllm_client/__init__.pyi openllm-core/src/openllm_core/_typing_compat.py
entry: mypy --strict openllm-client/src/openllm_client/__init__.pyi openllm-core/src/openllm_core/_typing_compat.py openllm-python/src/openllm/serialisation/__init__.pyi
- repo: https://github.com/editorconfig-checker/editorconfig-checker.python
rev: '2.7.3'
hooks:
@@ -121,7 +121,7 @@ _AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple'

@attr.define(slots=True, repr=False, init=False)
class LLM(t.Generic[M, T]):
class LLM(t.Generic[M, T], ReprMixin):
_model_id: str
_revision: str | None
_quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None
@@ -136,6 +136,7 @@ class LLM(t.Generic[M, T]):
_prompt_template: PromptTemplate | None
_system_message: str | None

_bentomodel: bentoml.Model = attr.field(init=False)
__llm_config__: LLMConfig | None = None
__llm_backend__: LiteralBackend = None # type: ignore
__llm_quantization_config__: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None
@@ -229,6 +230,7 @@ class LLM(t.Generic[M, T]):
model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
# resolve the tag
self._tag = model.tag
self._bentomodel = model

@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
def _make_tag_components(
@@ -256,7 +258,14 @@ class LLM(t.Generic[M, T]):
if attr in _reserved_namespace:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
super().__setattr__(attr,value)
@property
def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
def __repr_args__(self) -> ReprArgs:
yield 'model_id', self._model_id if not self._local else self.tag.name
yield 'revision', self._revision if self._revision else self.tag.version
yield 'backend', self.__llm_backend__
yield 'type', self.llm_type
@property
def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
@property
def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.__llm_trust_remote_code__)
@property
@@ -268,7 +277,7 @@ class LLM(t.Generic[M, T]):
@property
def tag(self)->bentoml.Tag:return self._tag
@property
def bentomodel(self)->bentoml.Model:return openllm.serialisation.get(self)
def bentomodel(self)->bentoml.Model:return self._bentomodel
@property
def config(self)->LLMConfig:
if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
@@ -282,6 +291,8 @@ class LLM(t.Generic[M, T]):
return self.__llm_quantization_config__
@property
def has_adapters(self)->bool:return self._adapter_map is not None
@property
def local(self)->bool:return self._local
# NOTE: The section below defines a loose contract with langchain's LLM interface.
@property
def llm_type(self)->str:return normalise_model_name(self._model_id)
@@ -78,6 +78,8 @@ def parse_config_options(
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
_bentoml_config_options_opts = [
'tracing.sample_rate=1.0',
'api_server.max_runner_connections=25',
f'runners."llm-{config["start_name"]}-runner".batching.max_batch_size=128',
f'api_server.traffic.timeout={server_timeout}',
f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
@@ -199,7 +199,9 @@ def _build(
'--serialisation',
t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
),
]
if quantize:
@@ -420,7 +420,9 @@ def start_command(

serialisation = t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
)
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.warning(
@@ -592,7 +594,9 @@ def start_grpc_command(

serialisation = t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
)
if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
termui.warning(
@@ -789,7 +793,9 @@ def import_command(
backend=backend,
serialisation=t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
),
)
backend_warning(llm.__llm_backend__)
@@ -963,7 +969,9 @@ def build_command(
quantize=quantize,
serialisation=t.cast(
LiteralSerialisation,
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
),
)
backend_warning(llm.__llm_backend__)
@@ -1,10 +1,3 @@
"""Serialisation utilities for OpenLLM.

Currently supports transformers for PyTorch, and vLLM.

Currently, GGML format is working in progress.
"""

from __future__ import annotations
import importlib
import typing as t
@@ -14,32 +7,26 @@ import fs

import openllm

from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
from openllm_core._typing_compat import M
from openllm_core._typing_compat import ParamSpec
from openllm_core._typing_compat import T

if t.TYPE_CHECKING:
import transformers as _transformers

import bentoml

from . import constants as constants
from . import ggml as ggml
from . import transformers as transformers
import transformers as transformers
else:
_transformers = openllm.utils.LazyLoader('_transformers', globals(), 'transformers')
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')

P = ParamSpec('P')

def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
def load_tokenizer(llm, **tokenizer_attrs):
"""Load the tokenizer from BentoML store.

By default, it will try to find the bentomodel whether it is in store..
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
tokenizer_attrs = {**llm.llm_parameters[-1], **tokenizer_attrs}
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME

from .transformers._helpers import process_config

config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code)
@@ -48,14 +35,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
try:
tokenizer = cloudpickle.load(t.cast('t.IO[bytes]', cofile))['tokenizer']
tokenizer = cloudpickle.load(cofile)['tokenizer']
except KeyError:
raise openllm.exceptions.OpenLLMException(
"Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"'
) from None
else:
tokenizer = _transformers.AutoTokenizer.from_pretrained(
tokenizer = transformers.AutoTokenizer.from_pretrained(
bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs
)
@@ -71,15 +58,11 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
return tokenizer

class _Caller(t.Protocol[P]):
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any: ...

_extras = ['get', 'import_model', 'load_model']

def _make_dispatch_function(fn: str) -> _Caller[P]:
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
def _make_dispatch_function(fn):
def caller(llm, *args, **kwargs):
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.

> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'
@@ -94,24 +77,15 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
return caller

if t.TYPE_CHECKING:

def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...

def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...

def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M: ...

_import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
__all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]

def __dir__() -> list[str]:
def __dir__():
return sorted(__all__)

def __getattr__(name: str) -> t.Any:
def __getattr__(name):
if name == 'load_tokenizer':
return load_tokenizer
elif name in _import_structure:
openllm-python/src/openllm/serialisation/__init__.pyi (new file, 22 lines)
@@ -0,0 +1,22 @@
"""Serialisation utilities for OpenLLM.

Currently supports transformers for PyTorch, and vLLM.

Currently, GGML format is working in progress.
"""

from typing import Any

from bentoml import Model
from openllm import LLM
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T

from . import constants as constants
from . import ggml as ggml
from . import transformers as transformers

def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T: ...
def get(llm: LLM[M, T]) -> Model: ...
def import_model(llm: LLM[M, T], *args: Any, trust_remote_code: bool, **attrs: Any) -> Model: ...
def load_model(llm: LLM[M, T], *args: Any, **attrs: Any) -> M: ...
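The stub above only declares signatures; a short sketch of how these serialisation entry points are typically exercised from an `LLM` instance, assuming the same constructor as above (the model id is hypothetical):

    import openllm

    llm = openllm.LLM('facebook/opt-125m', backend='pt')  # hypothetical model id

    # mirrors the call made in _llm.py above; pulls the weights into the Bento model store
    bentomodel = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code)
    model = openllm.serialisation.load_model(llm)          # the framework model (M)
    tokenizer = openllm.serialisation.load_tokenizer(llm)  # the tokenizer (T)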
@@ -3,7 +3,6 @@
from __future__ import annotations
import importlib
import logging
import typing as t

import attr
import orjson
@@ -29,25 +28,17 @@ from ._helpers import process_config
from .weights import HfIgnore

if t.TYPE_CHECKING:
import types

from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny

logger = logging.getLogger(__name__)

__all__ = ['import_model', 'get', 'load_model']
_object_setattr = object.__setattr__

def _patch_correct_tag(
llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None
) -> None:
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None):
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
if llm.revision is not None:
return
if not llm._local:
if not llm.local:
try:
if _revision is None:
_revision = get_hash(config)
@@ -55,7 +46,7 @@ def _patch_correct_tag(
pass
if _revision is None and llm.tag.version is not None:
_revision = llm.tag.version
if llm._tag.version is None:
if llm.tag.version is None:
_object_setattr(
llm, '_tag', attr.evolve(llm.tag, version=_revision)
) # HACK: This copies the correct revision into llm.tag
@@ -64,13 +55,12 @@ def _patch_correct_tag(

@inject
def import_model(
llm: openllm.LLM[M, T],
*decls: t.Any,
trust_remote_code: bool,
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
**attrs: t.Any,
) -> bentoml.Model:
def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLContainer.model_store], **attrs):
_base_decls, _base_attrs = llm.llm_parameters[0]
decls = (*_base_decls, *decls)
attrs = {**_base_attrs, **attrs}
if llm._local:
logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
@@ -78,7 +68,7 @@ def import_model(
safe_serialisation = openllm.utils.first_not_none(
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
)
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
metadata = {'safe_serialisation': safe_serialisation}
if quantize:
metadata['_quantize'] = quantize
architectures = getattr(config, 'architectures', [])
@@ -87,9 +77,12 @@ def import_model(
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
)
metadata['_pretrained_class'] = architectures[0]
metadata['_revision'] = get_hash(config)
if not llm._local:
metadata['_revision'] = get_hash(config)
else:
metadata['_revision'] = llm.revision

signatures: DictStrAny = {}
signatures = {}

if quantize == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
@@ -125,8 +118,8 @@ def import_model(
tokenizer.pad_token = tokenizer.eos_token

model = None
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
imported_modules: list[types.ModuleType] = []
external_modules = [importlib.import_module(tokenizer.__module__)]
imported_modules = []
bentomodel = bentoml.Model.create(
llm.tag,
module='openllm.serialisation.transformers',
@@ -150,11 +143,20 @@ def import_model(
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
if llm._local: # possible local path
model = infer_autoclass_from_llm(llm, config).from_pretrained(
llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
llm.model_id,
*decls,
local_files_only=True,
config=config,
trust_remote_code=trust_remote_code,
**hub_attrs,
**attrs,
)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# we will clone the all tings into the bentomodel path without loading model into memory
snapshot_download(
@@ -175,15 +177,10 @@ def import_model(
)
finally:
bentomodel.exit_cloudpickle_context(imported_modules)
# NOTE: We need to free up the cache after importing the model
# in the case where users first run openllm start without the model available locally.
if openllm.utils.is_torch_available() and torch.cuda.is_available():
torch.cuda.empty_cache()
del model
return bentomodel

def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
def get(llm):
try:
model = bentoml.models.get(llm.tag)
backend = model.info.labels['backend']
@@ -203,7 +200,13 @@ def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
) from err

def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
def check_unintialised_params(model):
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0:
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')

def load_model(llm, *decls, **attrs):
if llm._quantise in {'awq', 'squeezellm'}:
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(
@@ -232,7 +235,9 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
)

# TODO: investigate load with flash attention
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
)
else:
model = auto_class.from_pretrained(
llm.bentomodel.path,
@@ -242,4 +247,5 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
device_map=device_map,
**attrs,
)
return t.cast('M', model)
check_unintialised_params(model)
return model
@@ -2,43 +2,28 @@ from __future__ import annotations
import copy
import typing as t

import torch
import transformers

import openllm

from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from openllm.serialisation.constants import HUB_ATTRS
from openllm_core.exceptions import OpenLLMException

if t.TYPE_CHECKING:
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T

from ..._llm import LLM

def get_hash(config: transformers.PretrainedConfig) -> str:

def get_hash(config) -> str:
_commit_hash = getattr(config, '_commit_hash', None)
if _commit_hash is None:
raise ValueError(f'Cannot find commit hash in {config}')
return _commit_hash

def process_config(
model_id: str, trust_remote_code: bool, **attrs: t.Any
) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.

Args:
model_id: Model id to pass into ``transformers.AutoConfig``.
trust_remote_code: Whether to trust_remote_code or not
attrs: All possible attributes that can be processed by ``transformers.AutoConfig.from_pretrained

Returns:
A tuple of ``transformers.PretrainedConfig``, all hub attributes, and remanining attributes that can be used by the Model class.
"""
def process_config(model_id, trust_remote_code, **attrs):
config = attrs.pop('config', None)
# this logic below is synonymous to handling `from_pretrained` attrs.
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
@@ -52,7 +37,7 @@ def process_config(
return config, hub_attrs, attrs

def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
def infer_autoclass_from_llm(llm: LLM[M, T], config, /):
if llm.trust_remote_code:
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if not hasattr(config, 'auto_map'):
@@ -70,11 +55,5 @@ def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.Pretra
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
idx = 1
else:
raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
raise OpenLLMException(f'Model type {type(config)} is not supported yet.')
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_backend__][idx])

def check_unintialised_params(model: torch.nn.Module) -> None:
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0:
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
@@ -2,11 +2,15 @@ from __future__ import annotations
import traceback
import typing as t

from pathlib import Path

import attr

from huggingface_hub import HfApi

from openllm_core.exceptions import Error
from openllm_core.utils import resolve_filepath
from openllm_core.utils import validate_is_path

if t.TYPE_CHECKING:
@@ -40,6 +44,8 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:

def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
if validate_is_path(model_id):
return next((True for item in Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
@@ -93,7 +93,7 @@ multiline-quotes = "double"
docstring-quotes = "double"

[lint.extend-per-file-ignores]
"openllm-client/src/openllm_client/__init__.pyi" = ["I001"]
"/**/*.pyi" = ["I001"]
"openllm-python/tests/**/*" = ["S101", "TID252", "PT011", "S307"]
"openllm-python/src/openllm/_llm.py" = ["F811"]
"openllm-core/src/openllm_core/utils/import_utils.py" = ["PLW0603", "F811"]