From fa2038f4e2e27614e45f89bce750e37f2fae0102 Mon Sep 17 00:00:00 2001
From: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Date: Fri, 10 Nov 2023 02:36:12 -0500
Subject: [PATCH] fix: loading correct local models (#599)

* fix(model): loading local correctly

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update repr and correct bentomodel processor

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

* chore: cleanup transformers implementation

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* fix: ruff to ignore I001 on all stubs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .pre-commit-config.yaml                        |  2 +-
 openllm-python/src/openllm/_llm.py             | 17 ++++-
 openllm-python/src/openllm/cli/_factory.py     |  2 +
 openllm-python/src/openllm/cli/_sdk.py         |  4 +-
 openllm-python/src/openllm/cli/entrypoint.py   | 16 +++-
 .../src/openllm/serialisation/__init__.py      | 50 +++---------
 .../src/openllm/serialisation/__init__.pyi     | 22 ++++++
 .../serialisation/transformers/__init__.py     | 76 ++++++++++---------
 .../serialisation/transformers/_helpers.py     | 35 ++-------
 .../serialisation/transformers/weights.py      |  6 ++
 ruff.toml                                      |  2 +-
 11 files changed, 121 insertions(+), 111 deletions(-)
 create mode 100644 openllm-python/src/openllm/serialisation/__init__.pyi

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3436c178..a6fc9a9e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,7 +24,7 @@ repos:
         language: system
         always_run: true
         pass_filenames: false
-        entry: mypy --strict openllm-client/src/openllm_client/__init__.pyi openllm-core/src/openllm_core/_typing_compat.py
+        entry: mypy --strict openllm-client/src/openllm_client/__init__.pyi openllm-core/src/openllm_core/_typing_compat.py openllm-python/src/openllm/serialisation/__init__.pyi
  - repo: https://github.com/editorconfig-checker/editorconfig-checker.python
    rev: '2.7.3'
    hooks:
diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index a67ba1ce..1b21a1cb 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -121,7 +121,7 @@ _AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple'
 
 
 @attr.define(slots=True, repr=False, init=False)
-class LLM(t.Generic[M, T]):
+class LLM(t.Generic[M, T], ReprMixin):
   _model_id: str
   _revision: str | None
   _quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None
@@ -136,6 +136,7 @@ class LLM(t.Generic[M, T]):
   _prompt_template: PromptTemplate | None
   _system_message: str | None
 
+  _bentomodel: bentoml.Model = attr.field(init=False)
   __llm_config__: LLMConfig | None = None
   __llm_backend__: LiteralBackend = None  # type: ignore
   __llm_quantization_config__: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None
@@ -229,6 +230,7 @@ class LLM(t.Generic[M, T]):
       model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code)
       # resolve the tag
       self._tag = model.tag
+      self._bentomodel = model
 
   @apply(lambda val: tuple(str.lower(i) if i else i for i in val))
   def _make_tag_components(
@@ -256,7 +258,14 @@ class LLM(t.Generic[M, T]):
     if attr in _reserved_namespace:raise ForbiddenAttributeError(f'{attr} should not be set during runtime.')
     super().__setattr__(attr,value)
   @property
-  def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
+  def __repr_keys__(self): return {'model_id', 'revision', 'backend', 'type'}
+  def __repr_args__(self) -> ReprArgs:
+    yield 'model_id', self._model_id if not self._local else self.tag.name
+    yield 'revision', self._revision if self._revision else self.tag.version
+    yield 'backend', self.__llm_backend__
+    yield 'type', self.llm_type
+  @property
+  def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
   @property
   def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.__llm_trust_remote_code__)
   @property
@@ -268,7 +277,7 @@ class LLM(t.Generic[M, T]):
   @property
   def tag(self)->bentoml.Tag:return self._tag
   @property
-  def bentomodel(self)->bentoml.Model:return openllm.serialisation.get(self)
+  def bentomodel(self)->bentoml.Model:return self._bentomodel
   @property
   def config(self)->LLMConfig:
     if self.__llm_config__ is None:self.__llm_config__=openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs)
@@ -282,6 +291,8 @@ class LLM(t.Generic[M, T]):
     return self.__llm_quantization_config__
   @property
   def has_adapters(self)->bool:return self._adapter_map is not None
+  @property
+  def local(self)->bool:return self._local
   # NOTE: The section below defines a loose contract with langchain's LLM interface.
   @property
   def llm_type(self)->str:return normalise_model_name(self._model_id)
diff --git a/openllm-python/src/openllm/cli/_factory.py b/openllm-python/src/openllm/cli/_factory.py
index e061be71..185c6967 100644
--- a/openllm-python/src/openllm/cli/_factory.py
+++ b/openllm-python/src/openllm/cli/_factory.py
@@ -78,6 +78,8 @@ def parse_config_options(
   _bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
   _bentoml_config_options_opts = [
     'tracing.sample_rate=1.0',
+    'api_server.max_runner_connections=25',
+    f'runners."llm-{config["start_name"]}-runner".batching.max_batch_size=128',
     f'api_server.traffic.timeout={server_timeout}',
     f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
     f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
diff --git a/openllm-python/src/openllm/cli/_sdk.py b/openllm-python/src/openllm/cli/_sdk.py
index ff59b299..beda442b 100644
--- a/openllm-python/src/openllm/cli/_sdk.py
+++ b/openllm-python/src/openllm/cli/_sdk.py
@@ -199,7 +199,9 @@ def _build(
     '--serialisation',
     t.cast(
       LiteralSerialisation,
-      first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
+      first_not_none(
+        serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
+      ),
     ),
   ]
   if quantize:
diff --git a/openllm-python/src/openllm/cli/entrypoint.py b/openllm-python/src/openllm/cli/entrypoint.py
index 4740e1dc..79e9b219 100644
--- a/openllm-python/src/openllm/cli/entrypoint.py
+++ b/openllm-python/src/openllm/cli/entrypoint.py
@@ -420,7 +420,9 @@ def start_command(
   serialisation = t.cast(
     LiteralSerialisation,
-    first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
+    first_not_none(
+      serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
+    ),
   )
   if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
     termui.warning(
@@ -592,7 +594,9 @@ def start_grpc_command(
   serialisation = t.cast(
     LiteralSerialisation,
-    first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
+    first_not_none(
+      serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
+    ),
   )
   if serialisation == 'safetensors' and quantize is not None and check_bool_env('OPENLLM_SERIALIZATION_WARNING'):
     termui.warning(
@@ -789,7 +793,9 @@ def import_command(
     backend=backend,
     serialisation=t.cast(
       LiteralSerialisation,
-      first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
+      first_not_none(
+        serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
+      ),
     ),
   )
   backend_warning(llm.__llm_backend__)
@@ -963,7 +969,9 @@ def build_command(
     quantize=quantize,
     serialisation=t.cast(
       LiteralSerialisation,
-      first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
+      first_not_none(
+        serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
+      ),
     ),
   )
   backend_warning(llm.__llm_backend__)
diff --git a/openllm-python/src/openllm/serialisation/__init__.py b/openllm-python/src/openllm/serialisation/__init__.py
index 908199ca..ab9d9858 100644
--- a/openllm-python/src/openllm/serialisation/__init__.py
+++ b/openllm-python/src/openllm/serialisation/__init__.py
@@ -1,10 +1,3 @@
-"""Serialisation utilities for OpenLLM.
-
-Currently supports transformers for PyTorch, and vLLM.
-
-Currently, GGML format is working in progress.
-"""
-
 from __future__ import annotations
 import importlib
 import typing as t
@@ -14,32 +7,26 @@
 import fs
 
 import openllm
 
-from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
-from openllm_core._typing_compat import M
 from openllm_core._typing_compat import ParamSpec
-from openllm_core._typing_compat import T
 
 if t.TYPE_CHECKING:
-  import transformers as _transformers
-
-  import bentoml
-
-  from . import constants as constants
-  from . import ggml as ggml
-  from . import transformers as transformers
+  import transformers as transformers
 else:
-  _transformers = openllm.utils.LazyLoader('_transformers', globals(), 'transformers')
+  transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
 
 P = ParamSpec('P')
 
 
-def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
+def load_tokenizer(llm, **tokenizer_attrs):
   """Load the tokenizer from BentoML store.
 
   By default, it will try to find the bentomodel whether it is in store.
   If the model is not found, it will raise a ``bentoml.exceptions.NotFound``.
   """
+  tokenizer_attrs = {**llm.llm_parameters[-1], **tokenizer_attrs}
+  from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
+
   from .transformers._helpers import process_config
 
   config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code)
@@ -48,14 +35,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
   if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
     with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
       try:
-        tokenizer = cloudpickle.load(t.cast('t.IO[bytes]', cofile))['tokenizer']
+        tokenizer = cloudpickle.load(cofile)['tokenizer']
       except KeyError:
         raise openllm.exceptions.OpenLLMException(
           "Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
           'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"'
         ) from None
   else:
-    tokenizer = _transformers.AutoTokenizer.from_pretrained(
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
       bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs
     )
@@ -71,15 +58,11 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
   return tokenizer
 
 
-class _Caller(t.Protocol[P]):
-  def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any: ...
-
-
 _extras = ['get', 'import_model', 'load_model']
 
 
-def _make_dispatch_function(fn: str) -> _Caller[P]:
-  def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
+def _make_dispatch_function(fn):
+  def caller(llm, *args, **kwargs):
     """Generic function dispatch to correct serialisation submodules based on LLM runtime.
 
     > [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'
@@ -94,24 +77,15 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
   return caller
 
 
-if t.TYPE_CHECKING:
-
-  def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
-
-  def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
-
-  def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M: ...
-
-
 _import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
 
 __all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]
 
 
-def __dir__() -> list[str]:
+def __dir__():
   return sorted(__all__)
 
 
-def __getattr__(name: str) -> t.Any:
+def __getattr__(name):
   if name == 'load_tokenizer':
     return load_tokenizer
   elif name in _import_structure:
diff --git a/openllm-python/src/openllm/serialisation/__init__.pyi b/openllm-python/src/openllm/serialisation/__init__.pyi
new file mode 100644
index 00000000..0dc8d4c0
--- /dev/null
+++ b/openllm-python/src/openllm/serialisation/__init__.pyi
@@ -0,0 +1,22 @@
+"""Serialisation utilities for OpenLLM.
+
+Currently supports transformers for PyTorch, and vLLM.
+
+Currently, GGML format is a work in progress.
+"""
+
+from typing import Any
+
+from bentoml import Model
+from openllm import LLM
+from openllm_core._typing_compat import M
+from openllm_core._typing_compat import T
+
+from . import constants as constants
+from . import ggml as ggml
+from . import transformers as transformers
+
+def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T: ...
+def get(llm: LLM[M, T]) -> Model: ...
+def import_model(llm: LLM[M, T], *args: Any, trust_remote_code: bool, **attrs: Any) -> Model: ...
+def load_model(llm: LLM[M, T], *args: Any, **attrs: Any) -> M: ...
diff --git a/openllm-python/src/openllm/serialisation/transformers/__init__.py b/openllm-python/src/openllm/serialisation/transformers/__init__.py
index 9f2933a9..e3c91cb8 100644
--- a/openllm-python/src/openllm/serialisation/transformers/__init__.py
+++ b/openllm-python/src/openllm/serialisation/transformers/__init__.py
@@ -3,7 +3,6 @@ from __future__ import annotations
 import importlib
 import logging
-import typing as t
 
 import attr
 import orjson
@@ -29,25 +28,17 @@
 from ._helpers import process_config
 from .weights import HfIgnore
 
-if t.TYPE_CHECKING:
-  import types
-
-  from bentoml._internal.models import ModelStore
-  from openllm_core._typing_compat import DictStrAny
-
 logger = logging.getLogger(__name__)
 
 __all__ = ['import_model', 'get', 'load_model']
 
 _object_setattr = object.__setattr__
 
-def _patch_correct_tag(
-  llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None
-) -> None:
+def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None):
   # NOTE: The following won't hit during local since we generated a correct version based on local path hash. It will only hit if we use model from HF Hub
   if llm.revision is not None:
     return
-  if not llm._local:
+  if not llm.local:
     try:
       if _revision is None:
         _revision = get_hash(config)
@@ -55,7 +46,7 @@ def _patch_correct_tag(
       pass
   if _revision is None and llm.tag.version is not None:
     _revision = llm.tag.version
-  if llm._tag.version is None:
+  if llm.tag.version is None:
     _object_setattr(
       llm, '_tag', attr.evolve(llm.tag, version=_revision)
     )  # HACK: This copies the correct revision into llm.tag
@@ -64,13 +55,12 @@
 @inject
-def import_model(
-  llm: openllm.LLM[M, T],
-  *decls: t.Any,
-  trust_remote_code: bool,
-  _model_store: ModelStore = Provide[BentoMLContainer.model_store],
-  **attrs: t.Any,
-) -> bentoml.Model:
+def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLContainer.model_store], **attrs):
+  _base_decls, _base_attrs = llm.llm_parameters[0]
+  decls = (*_base_decls, *decls)
+  attrs = {**_base_attrs, **attrs}
+  if llm._local:
+    logger.warning('Given model is a local model, OpenLLM will load the model into memory for serialisation.')
   config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
   _patch_correct_tag(llm, config)
   _, tokenizer_attrs = llm.llm_parameters
@@ -78,7 +68,7 @@ def import_model(
   safe_serialisation = openllm.utils.first_not_none(
     attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
   )
-  metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
+  metadata = {'safe_serialisation': safe_serialisation}
   if quantize:
     metadata['_quantize'] = quantize
   architectures = getattr(config, 'architectures', [])
@@ -87,9 +77,12 @@
       'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
     )
   metadata['_pretrained_class'] = architectures[0]
-  metadata['_revision'] = get_hash(config)
+  if not llm._local:
+    metadata['_revision'] = get_hash(config)
+  else:
+    metadata['_revision'] = llm.revision
 
-  signatures: DictStrAny = {}
+  signatures = {}
 
   if quantize == 'gptq':
     if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
@@ -125,8 +118,8 @@
     tokenizer.pad_token = tokenizer.eos_token
 
   model = None
-  external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
-  imported_modules: list[types.ModuleType] = []
+  external_modules = [importlib.import_module(tokenizer.__module__)]
+  imported_modules = []
   bentomodel = bentoml.Model.create(
     llm.tag,
     module='openllm.serialisation.transformers',
@@ -150,11 +143,20 @@
       f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
   if llm._local:  # possible local path
     model = infer_autoclass_from_llm(llm, config).from_pretrained(
-      llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
+      llm.model_id,
+      *decls,
+      local_files_only=True,
+      config=config,
+      trust_remote_code=trust_remote_code,
+      **hub_attrs,
+      **attrs,
     )
     # for trust_remote_code to work
     bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
-    model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
+    model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
+    del model
+    if torch.cuda.is_available():
+      torch.cuda.empty_cache()
   else:
     # we will clone all things into the bentomodel path without loading the model into memory
     snapshot_download(
@@ -175,15 +177,10 @@
     )
   finally:
     bentomodel.exit_cloudpickle_context(imported_modules)
-    # NOTE: We need to free up the cache after importing the model
-    # in the case where users first run openllm start without the model available locally.
-    if openllm.utils.is_torch_available() and torch.cuda.is_available():
-      torch.cuda.empty_cache()
-    del model
   return bentomodel
 
 
-def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
+def get(llm):
   try:
     model = bentoml.models.get(llm.tag)
     backend = model.info.labels['backend']
@@ -203,7 +200,13 @@ def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
     ) from err
 
 
+def check_unintialised_params(model):
+  unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
+  if len(unintialized) > 0:
+    raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
+
+
-def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
+def load_model(llm, *decls, **attrs):
   if llm._quantise in {'awq', 'squeezellm'}:
     raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
   config, attrs = transformers.AutoConfig.from_pretrained(
@@ -232,7 +235,9 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
     )
     # TODO: investigate load with flash attention
-    model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
+    model = auto_class.from_pretrained(
+      llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
+    )
   else:
     model = auto_class.from_pretrained(
       llm.bentomodel.path,
@@ -242,4 +247,5 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
       device_map=device_map,
       **attrs,
     )
-  return t.cast('M', model)
+  check_unintialised_params(model)
+  return model
diff --git a/openllm-python/src/openllm/serialisation/transformers/_helpers.py b/openllm-python/src/openllm/serialisation/transformers/_helpers.py
index 0cf622af..708d6f7e 100644
--- a/openllm-python/src/openllm/serialisation/transformers/_helpers.py
+++ b/openllm-python/src/openllm/serialisation/transformers/_helpers.py
@@ -2,43 +2,28 @@ from __future__ import annotations
 
 import copy
 import typing as t
 
-import torch
 import transformers
 
-import openllm
-
 from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
 from openllm.serialisation.constants import HUB_ATTRS
+from openllm_core.exceptions import OpenLLMException
 
 if t.TYPE_CHECKING:
-  from transformers.models.auto.auto_factory import _BaseAutoModelClass
-
-  from openllm_core._typing_compat import DictStrAny
   from openllm_core._typing_compat import M
   from openllm_core._typing_compat import T
 
+  from ..._llm import LLM
 
-def get_hash(config: transformers.PretrainedConfig) -> str:
+
+def get_hash(config) -> str:
   _commit_hash = getattr(config, '_commit_hash', None)
   if _commit_hash is None:
     raise ValueError(f'Cannot find commit hash in {config}')
   return _commit_hash
 
 
-def process_config(
-  model_id: str, trust_remote_code: bool, **attrs: t.Any
-) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
-  """A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
-
-  Args:
-    model_id: Model id to pass into ``transformers.AutoConfig``.
-    trust_remote_code: Whether to trust_remote_code or not
-    attrs: All possible attributes that can be processed by ``transformers.AutoConfig.from_pretrained
-
-  Returns:
-    A tuple of ``transformers.PretrainedConfig``, all hub attributes, and remanining attributes that can be used by the Model class.
-  """
+def process_config(model_id, trust_remote_code, **attrs):
   config = attrs.pop('config', None)
   # this logic below is synonymous to handling `from_pretrained` attrs.
   hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
@@ -52,7 +37,7 @@ def process_config(
   return config, hub_attrs, attrs
 
 
-def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
+def infer_autoclass_from_llm(llm: LLM[M, T], config, /):
   if llm.trust_remote_code:
     autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
     if not hasattr(config, 'auto_map'):
@@ -70,11 +55,5 @@ def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.Pretra
   elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
     idx = 1
   else:
-    raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
+    raise OpenLLMException(f'Model type {type(config)} is not supported yet.')
   return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_backend__][idx])
-
-
-def check_unintialised_params(model: torch.nn.Module) -> None:
-  unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
-  if len(unintialized) > 0:
-    raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
diff --git a/openllm-python/src/openllm/serialisation/transformers/weights.py b/openllm-python/src/openllm/serialisation/transformers/weights.py
index a5123bea..4f526063 100644
--- a/openllm-python/src/openllm/serialisation/transformers/weights.py
+++ b/openllm-python/src/openllm/serialisation/transformers/weights.py
@@ -2,11 +2,15 @@ from __future__ import annotations
 
 import traceback
 import typing as t
+from pathlib import Path
+
 import attr
 
 from huggingface_hub import HfApi
 
 from openllm_core.exceptions import Error
+from openllm_core.utils import resolve_filepath
+from openllm_core.utils import validate_is_path
 
 if t.TYPE_CHECKING:
@@ -40,6 +44,8 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
 
 
 def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
+  if validate_is_path(model_id):
+    return next((True for item in Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
   return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
diff --git a/ruff.toml b/ruff.toml
index 1996764a..daa84865 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -93,7 +93,7 @@ multiline-quotes = "double"
 docstring-quotes = "double"
 
 [lint.extend-per-file-ignores]
-"openllm-client/src/openllm_client/__init__.pyi" = ["I001"]
+"/**/*.pyi" = ["I001"]
 "openllm-python/tests/**/*" = ["S101", "TID252", "PT011", "S307"]
 "openllm-python/src/openllm/_llm.py" = ["F811"]
 "openllm-core/src/openllm_core/utils/import_utils.py" = ["PLW0603", "F811"]