Files
OpenLLM/openllm-python/src/openllm/serialisation/transformers/weights.py
Aaron Pham 0d83cefcb6 fix(mixtral): setup hack atm to load weights from pt specifically instead of safetensors (#776)
fix(mixtral): setup hack atm to load weights from pt specifically
instead of safetensors

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-12-13 18:18:51 -05:00

62 lines
2.3 KiB
Python

from __future__ import annotations
import attr, traceback, functools, pathlib, typing as t
from huggingface_hub import HfApi
from openllm_core.exceptions import Error
from openllm_core.utils import resolve_filepath, validate_is_path
if t.TYPE_CHECKING:
from huggingface_hub.hf_api import ModelInfo as HfModelInfo
import openllm
# Shared HfApi client handed out by Client(); stays None until the first call creates it.
__global_inst__ = None
# Memoised results of ModelInfo() lookups so repeated queries skip the Hub round-trip.
__cached_id__: dict[str, HfModelInfo] = {}
def Client() -> HfApi:
    """Return the shared ``HfApi`` instance, constructing it on the first call.

    The instance is kept in the module-level ``__global_inst__`` so later calls reuse it
    instead of building a new client each time.
    """
    global __global_inst__  # noqa: PLW0603
    if __global_inst__ is not None:
        return __global_inst__
    __global_inst__ = HfApi()
    return __global_inst__
def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
    """Fetch Hugging Face Hub metadata for ``model_id``, memoised per revision.

    Args:
      model_id: Hub repository id, e.g. ``'org/name'``.
      revision: Revision to query; ``None`` is forwarded as-is to ``HfApi.model_info``.

    Returns:
      The object returned by ``Client().model_info(...)``. Successful lookups are stored in the
      module-level ``__cached_id__`` and reused by later calls with the same id and revision.
      Failed lookups are not cached.

    Raises:
      Error: if the Hub request raises anything; the original exception is chained as the cause.
    """
    # The cache key must include the revision; otherwise a lookup for one revision would be
    # answered with metadata cached for another. Unpinned lookups keep the bare id as key.
    key = model_id if revision is None else f'{model_id}@{revision}'
    if key in __cached_id__:
        return __cached_id__[key]
    try:
        info = Client().model_info(model_id, revision=revision)
    except Exception as err:
        # Print the underlying failure before wrapping it in the project's Error type.
        traceback.print_exc()
        raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
    __cached_id__[key] = info
    return info
def has_weights(model_id: str, revision: str | None = None, *, extensions: str) -> bool:
    """Report whether ``model_id`` contains at least one ``*.{extensions}`` file.

    Args:
      model_id: A local directory (when ``validate_is_path`` accepts it) or a Hub repo id.
      revision: Hub revision to inspect; not used for local paths.
      extensions: File extension without the leading dot, e.g. ``'safetensors'``.

    Returns:
      True if a matching file exists. Local directories are only globbed at their top level.
      For Hub repos every ``rfilename`` in the repo listing is checked. NOTE(review): these are
      repo-relative paths, so files in subfolders presumably match too; confirm this asymmetry
      with the local branch is intended.

    Raises:
      Error: propagated from ``ModelInfo`` when the Hub lookup fails.
    """
    suffix = f'.{extensions}'
    if validate_is_path(model_id):
        return any(True for _ in pathlib.Path(resolve_filepath(model_id)).glob(f'*{suffix}'))
    # ``siblings`` is Optional on huggingface_hub's ModelInfo; treat a missing file listing as
    # "no matching files" rather than failing with a TypeError while iterating None.
    siblings = ModelInfo(model_id, revision=revision).siblings or ()
    return any(sibling.rfilename.endswith(suffix) for sibling in siblings)
# Convenience predicates: has_weights with the extension fixed. Callers pass
# (model_id, revision=None) exactly as they would to has_weights.
has_safetensors_weights = functools.partial(has_weights, extensions='safetensors')
# NOTE(review): this only matches '*.pt' files, while HfIgnore.pt uses '*.bin' as the PyTorch
# weight pattern; a repo that ships only '*.bin' weights reports False here. Confirm intended.
has_pt_weights = functools.partial(has_weights, extensions='pt')
@attr.define(slots=True)
class HfIgnore:
    """Glob patterns for files to leave out when fetching a model from the Hugging Face Hub.

    Each class attribute is the glob for one weight-file format; ``ignore_patterns``
    combines them according to the LLM's backend.
    """

    # One glob per weight-file format.
    safetensors = '*.safetensors'
    pt = '*.bin'
    tf = '*.h5'
    flax = '*.msgpack'
    gguf = '*.gguf'

    @classmethod
    def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
        """Return the glob patterns to skip when downloading weights for ``llm``.

        * ``'vllm'`` / ``'pt'`` backends: ignore TF, Flax and GGUF weights. Then, if the repo has
          safetensors weights, also ignore ``*.bin`` and ``*.pt``; otherwise ignore safetensors.
          ``MixtralForCausalLM`` always ignores safetensors (a temporary workaround).
        * ``'ggml'`` backend: ignore TF, Flax, ``*.bin`` and safetensors weights.

        A fixed list of documentation/licence files is appended in every case.

        Raises:
          ValueError: if ``llm.__llm_backend__`` is none of the backends above.
        """
        backend = llm.__llm_backend__
        if backend == 'ggml':
            patterns = [cls.tf, cls.flax, cls.pt, cls.safetensors]
        elif backend in {'vllm', 'pt'}:
            patterns = [cls.tf, cls.flax, cls.gguf]
            # XXX: Hack for Mixtral as safetensors is yet to be working atm, so never prefer them there.
            is_mixtral = llm.config['architecture'] == 'MixtralForCausalLM'
            if not is_mixtral and has_safetensors_weights(llm.model_id):
                patterns += [cls.pt, '*.pt']
            else:
                patterns.append(cls.safetensors)
        else:
            raise ValueError('Unknown backend (should never happen at all.)')
        # Non-weight files that are not needed for serving.
        patterns += ['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt', 'Notice']
        return patterns