mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-11 03:33:44 -04:00
fix(mixtral): setup hack atm to load weights from pt specifically instead of safetensors (#776)
fix(mixtral): setup hack atm to load weights from pt specifically instead of safetensors

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import attr, traceback, pathlib, typing as t
|
||||
import attr, traceback, functools, pathlib, typing as t
|
||||
from huggingface_hub import HfApi
|
||||
from openllm_core.exceptions import Error
|
||||
from openllm_core.utils import resolve_filepath, validate_is_path
|
||||
@@ -27,10 +27,13 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
traceback.print_exc()
|
||||
raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
|
||||
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
|
||||
def has_weights(model_id: str, revision: str | None = None, *, extensions: str) -> bool:
|
||||
if validate_is_path(model_id):
|
||||
return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
|
||||
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
|
||||
return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob(f'*.{extensions}')), False)
|
||||
return any(s.rfilename.endswith(f'.{extensions}') for s in ModelInfo(model_id, revision=revision).siblings)
|
||||
|
||||
has_safetensors_weights = functools.partial(has_weights, extensions='safetensors')
|
||||
has_pt_weights = functools.partial(has_weights, extensions='pt')
|
||||
|
||||
@attr.define(slots=True)
|
||||
class HfIgnore:
|
||||
@@ -43,8 +46,10 @@ class HfIgnore:
|
||||
def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
|
||||
if llm.__llm_backend__ in {'vllm', 'pt'}:
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if has_safetensors_weights(llm.model_id):
|
||||
base.append(cls.pt)
|
||||
if llm.config['architecture'] == 'MixtralForCausalLM': # XXX: Hack for Mixtral as safetensors is yet to be working atm
|
||||
base.append(cls.safetensors)
|
||||
elif has_safetensors_weights(llm.model_id):
|
||||
base.extend([cls.pt, '*.pt'])
|
||||
else:
|
||||
base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'ggml':
|
||||
|
||||
Reference in New Issue
Block a user