# Patch summary (explanatory preamble placed before the first `diff --git` header).
#
# File: openllm-python/src/openllm/serialisation/transformers/weights.py
#
# Hunk 1: adds `functools` to the module's import line.
# Hunk 2: renames `has_safetensors_weights` to `has_weights` and adds a required
#         keyword-only `extensions` argument. The value is interpolated as
#         f'*.{extensions}' / f'.{extensions}', i.e. a single extension given
#         without its leading dot. `has_safetensors_weights` is re-created, and
#         `has_pt_weights` is added, as `functools.partial` bindings of
#         `has_weights`.
# Hunk 3: in `HfIgnore.ignore_patterns`, for the 'vllm' and 'pt' backends:
#         - architecture 'MixtralForCausalLM' -> append `cls.safetensors`
#         - elif safetensors weights are found -> extend with `cls.pt` and '*.pt'
#         - else                               -> append `cls.safetensors` (unchanged)
#
# NOTE(review): the source of this patch had all line breaks and indentation
# collapsed. The layout below was rebuilt from the hunk headers; the leading
# indentation of the Python lines assumes two-space indentation -- TODO confirm
# against the target file before applying.
diff --git a/openllm-python/src/openllm/serialisation/transformers/weights.py b/openllm-python/src/openllm/serialisation/transformers/weights.py
index bda9fa7c..d97e8a5d 100644
--- a/openllm-python/src/openllm/serialisation/transformers/weights.py
+++ b/openllm-python/src/openllm/serialisation/transformers/weights.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import attr, traceback, pathlib, typing as t
+import attr, traceback, functools, pathlib, typing as t
 from huggingface_hub import HfApi
 from openllm_core.exceptions import Error
 from openllm_core.utils import resolve_filepath, validate_is_path
@@ -27,10 +27,13 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
     traceback.print_exc()
     raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
 
-def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
+def has_weights(model_id: str, revision: str | None = None, *, extensions: str) -> bool:
   if validate_is_path(model_id):
-    return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
-  return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
+    return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob(f'*.{extensions}')), False)
+  return any(s.rfilename.endswith(f'.{extensions}') for s in ModelInfo(model_id, revision=revision).siblings)
+
+has_safetensors_weights = functools.partial(has_weights, extensions='safetensors')
+has_pt_weights = functools.partial(has_weights, extensions='pt')
 
 @attr.define(slots=True)
 class HfIgnore:
@@ -43,8 +46,10 @@ class HfIgnore:
   def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
     if llm.__llm_backend__ in {'vllm', 'pt'}:
       base = [cls.tf, cls.flax, cls.gguf]
-      if has_safetensors_weights(llm.model_id):
-        base.append(cls.pt)
+      if llm.config['architecture'] == 'MixtralForCausalLM':  # XXX: Hack for Mixtral as safetensors is yet to be working atm
+        base.append(cls.safetensors)
+      elif has_safetensors_weights(llm.model_id):
+        base.extend([cls.pt, '*.pt'])
       else:
         base.append(cls.safetensors)
     elif llm.__llm_backend__ == 'ggml':