fix(mixtral): setup hack atm to load weights from pt specifically instead of safetensors (#776)

fix(mixtral): setup hack atm to load weights from pt specifically
instead of safetensors

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-12-13 18:18:51 -05:00
committed by GitHub
parent 2dbcfa8a0c
commit 0d83cefcb6

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import attr, traceback, pathlib, typing as t
import attr, traceback, functools, pathlib, typing as t
from huggingface_hub import HfApi
from openllm_core.exceptions import Error
from openllm_core.utils import resolve_filepath, validate_is_path
@@ -27,10 +27,13 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
traceback.print_exc()
raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
def has_weights(model_id: str, revision: str | None = None, *, extensions: str) -> bool:
if validate_is_path(model_id):
return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob(f'*.{extensions}')), False)
return any(s.rfilename.endswith(f'.{extensions}') for s in ModelInfo(model_id, revision=revision).siblings)
has_safetensors_weights = functools.partial(has_weights, extensions='safetensors')
has_pt_weights = functools.partial(has_weights, extensions='pt')
@attr.define(slots=True)
class HfIgnore:
@@ -43,8 +46,10 @@ class HfIgnore:
def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
if llm.__llm_backend__ in {'vllm', 'pt'}:
base = [cls.tf, cls.flax, cls.gguf]
if has_safetensors_weights(llm.model_id):
base.append(cls.pt)
if llm.config['architecture'] == 'MixtralForCausalLM': # XXX: Hack for Mixtral as safetensors is yet to be working atm
base.append(cls.safetensors)
elif has_safetensors_weights(llm.model_id):
base.extend([cls.pt, '*.pt'])
else:
base.append(cls.safetensors)
elif llm.__llm_backend__ == 'ggml':