mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-11 11:39:52 -04:00
fix(mixtral): setup hack atm to load weights from pt specifically instead of safetensors Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
62 lines
2.3 KiB
Python
62 lines
2.3 KiB
Python
from __future__ import annotations
|
|
import attr, traceback, functools, pathlib, typing as t
|
|
from huggingface_hub import HfApi
|
|
from openllm_core.exceptions import Error
|
|
from openllm_core.utils import resolve_filepath, validate_is_path
|
|
|
|
if t.TYPE_CHECKING:
|
|
from huggingface_hub.hf_api import ModelInfo as HfModelInfo
|
|
import openllm
|
|
|
|
# Shared HfApi client; stays None until Client() constructs it on first use.
__global_inst__: HfApi | None = None
# Memoized Hub metadata, filled in by ModelInfo() so repeated lookups skip the network.
__cached_id__: dict[str, HfModelInfo] = {}
|
|
|
|
def Client() -> HfApi:
    """Return the module's shared ``HfApi`` client.

    The client is built lazily on the first call and stored in the module-level
    ``__global_inst__``; every later call returns that same instance.
    """
    global __global_inst__  # noqa: PLW0603
    if __global_inst__ is not None:
        return __global_inst__
    __global_inst__ = HfApi()
    return __global_inst__
|
|
|
|
def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
    """Fetch Hugging Face Hub metadata for ``model_id`` and memoize it.

    Results are stored in the module-level ``__cached_id__`` dict. The cache key
    includes the revision, so different revisions of one repository are fetched
    and cached independently instead of sharing whichever was fetched first.

    Args:
        model_id: Hub repository id, e.g. ``'org/name'``.
        revision: Optional branch / tag / commit to query; ``None`` uses the Hub default.

    Returns:
        The ``ModelInfo`` object returned by ``HfApi.model_info``.

    Raises:
        Error: if the Hub request fails for any reason; the original exception is chained.
    """
    # The default revision keeps the bare model id as its key so existing entries
    # (and any code reading the cache by model id) keep working. Hub repo ids
    # cannot contain '@', so the revisioned key cannot collide with a bare id.
    key = model_id if revision is None else f'{model_id}@{revision}'
    if key in __cached_id__:
        return __cached_id__[key]
    try:
        info = Client().model_info(model_id, revision=revision)
    except Exception as err:
        # Boundary with the network: surface the traceback, then re-raise as the
        # project's Error type.
        traceback.print_exc()
        raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
    __cached_id__[key] = info
    return info
|
|
|
|
def has_weights(model_id: str, revision: str | None = None, *, extensions: str) -> bool:
    """Report whether ``model_id`` contains at least one file ending in ``.<extensions>``.

    Args:
        model_id: Either a local directory (when ``validate_is_path`` accepts it) or a
            Hugging Face Hub repository id.
        revision: Hub revision to inspect; not used for local directories.
        extensions: File extension without the leading dot, e.g. ``'safetensors'``.

    Returns:
        For a local directory, True if a matching file exists at its top level (the
        glob is not recursive). For a Hub id, True if any entry in the repository's
        ``siblings`` listing ends with the extension.
    """
    suffix = f'.{extensions}'
    if validate_is_path(model_id):
        return any(True for _ in pathlib.Path(resolve_filepath(model_id)).glob(f'*{suffix}'))
    # huggingface_hub declares ``ModelInfo.siblings`` as Optional; treat a missing
    # file listing as "no files" instead of failing with a TypeError.
    siblings = ModelInfo(model_id, revision=revision).siblings or []
    return any(s.rfilename.endswith(suffix) for s in siblings)


# Extension-specific predicates; call as (model_id, revision=None), like has_weights.
has_safetensors_weights = functools.partial(has_weights, extensions='safetensors')
# NOTE(review): this matches only '*.pt' files, while HfIgnore.pt uses '*.bin' as the
# PyTorch weight pattern; confirm that '.bin'-only repositories should report False here.
has_pt_weights = functools.partial(has_weights, extensions='pt')
|
|
|
|
@attr.define(slots=True)
class HfIgnore:
    """Glob patterns for repository files that should not be downloaded.

    Each class attribute names the file pattern of one weight format;
    ``ignore_patterns`` selects which of them to exclude for a given LLM.
    """

    safetensors = '*.safetensors'
    pt = '*.bin'
    tf = '*.h5'
    flax = '*.msgpack'
    gguf = '*.gguf'

    @classmethod
    def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
        """Return the glob patterns to exclude when fetching ``llm``'s files.

        * ``'vllm'`` / ``'pt'`` backends: TF, Flax and GGUF files are skipped. In
          addition, either the safetensors files or the PyTorch ``*.bin`` / ``*.pt``
          files are skipped, depending on the architecture and on whether the
          model provides safetensors weights.
        * ``'ggml'`` backend: TF, Flax, ``*.bin`` and safetensors files are skipped.

        For both known backends, docs, license and git metadata files are also skipped.

        Raises:
            ValueError: if ``llm.__llm_backend__`` is not one of the backends above.
        """
        backend = llm.__llm_backend__
        if backend == 'ggml':
            # NOTE(review): '*.pt' is not excluded here, unlike the pt/vllm branch — confirm intent.
            patterns = [cls.tf, cls.flax, cls.pt, cls.safetensors]
        elif backend in {'vllm', 'pt'}:
            patterns = [cls.tf, cls.flax, cls.gguf]
            # XXX: Hack for Mixtral as safetensors is yet to be working atm, so Mixtral
            # always keeps its PyTorch checkpoints. Other architectures keep safetensors
            # whenever the model has them (the lookup is skipped for Mixtral).
            keep_safetensors = (
                llm.config['architecture'] != 'MixtralForCausalLM'
                and has_safetensors_weights(llm.model_id)
            )
            if keep_safetensors:
                patterns += [cls.pt, '*.pt']
            else:
                patterns.append(cls.safetensors)
        else:
            raise ValueError('Unknown backend (should never happen at all.)')
        # filter out these files, since we probably don't need them for now.
        patterns += ['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt', 'Notice']
        return patterns
|