From 0d83cefcb6e03762cbb27cfa972243a2c2338919 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Wed, 13 Dec 2023 18:18:51 -0500 Subject: [PATCH] fix(mixtral): setup hack atm to load weights from pt specifically instead of safetensors (#776) fix(mixtral): setup hack atm to load weights from pt specifically instead of safetensors Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- .../serialisation/transformers/weights.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/openllm-python/src/openllm/serialisation/transformers/weights.py b/openllm-python/src/openllm/serialisation/transformers/weights.py index bda9fa7c..d97e8a5d 100644 --- a/openllm-python/src/openllm/serialisation/transformers/weights.py +++ b/openllm-python/src/openllm/serialisation/transformers/weights.py @@ -1,5 +1,5 @@ from __future__ import annotations -import attr, traceback, pathlib, typing as t +import attr, traceback, functools, pathlib, typing as t from huggingface_hub import HfApi from openllm_core.exceptions import Error from openllm_core.utils import resolve_filepath, validate_is_path @@ -27,10 +27,13 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo: traceback.print_exc() raise Error(f'Failed to fetch {model_id} from huggingface.co') from err -def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool: +def has_weights(model_id: str, revision: str | None = None, *, extensions: str) -> bool: if validate_is_path(model_id): - return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob('*.safetensors')), False) - return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings) + return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob(f'*.{extensions}')), False) + return any(s.rfilename.endswith(f'.{extensions}') for s in ModelInfo(model_id, revision=revision).siblings) + +has_safetensors_weights = functools.partial(has_weights, extensions='safetensors') +has_pt_weights = functools.partial(has_weights, extensions='pt') @attr.define(slots=True) class HfIgnore: @@ -43,8 +46,10 @@ class HfIgnore: def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]: if llm.__llm_backend__ in {'vllm', 'pt'}: base = [cls.tf, cls.flax, cls.gguf] - if has_safetensors_weights(llm.model_id): - base.append(cls.pt) + if llm.config['architecture'] == 'MixtralForCausalLM': # XXX: Hack for Mixtral as safetensors is yet to be working atm + base.append(cls.safetensors) + elif has_safetensors_weights(llm.model_id): + base.extend([cls.pt, '*.pt']) else: base.append(cls.safetensors) elif llm.__llm_backend__ == 'ggml':