Files
OpenLLM/openllm-python/src/openllm/serialisation/transformers/weights.py
2023-08-22 14:03:06 +00:00

27 lines
1.1 KiB
Python

from __future__ import annotations
import typing as t, attr
from huggingface_hub import HfApi
if t.TYPE_CHECKING:
import openllm
from openllm_core._typing_compat import M, T
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
    """Return whether a Hugging Face Hub repository lists any ``.safetensors`` file.

    Args:
        model_id: Repository id forwarded to ``HfApi.model_info``.
        revision: Optional revision forwarded to ``HfApi.model_info``.

    Returns:
        ``True`` if at least one listed sibling has an ``rfilename`` ending in
        ``.safetensors``; ``False`` otherwise, including when the Hub response
        carries no sibling list at all.
    """
    info = HfApi().model_info(model_id, revision=revision)
    # ``siblings`` may be ``None`` (it is declared optional in newer
    # huggingface_hub releases); treat that as "no files" instead of letting
    # the iteration below raise ``TypeError``.
    siblings = info.siblings or ()
    return any(sibling.rfilename.endswith(".safetensors") for sibling in siblings)
@attr.define(slots=True)
class HfIgnore:
    """Glob patterns for per-format weight files, plus a helper that picks which to ignore."""

    # One glob per serialized-weights format.
    safetensors = "*.safetensors"
    pt = "*.bin"
    tf = "*.h5"
    flax = "*.msgpack"

    @classmethod
    def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
        """Build the list of glob patterns to ignore for ``llm``.

        The selection is driven by ``llm.__llm_implementation__``. For any
        implementation other than ``"vllm"``, ``"tf"`` or ``"flax"``, the
        ``*.bin`` pattern is additionally ignored when the repository for
        ``llm.model_id`` already provides safetensors weights.

        Args:
            llm: The LLM whose implementation and model id decide the patterns.

        Returns:
            A new list of glob patterns; the weight-format patterns come first,
            followed by a fixed set of auxiliary-file patterns.
        """
        fixed_choices = {
            "vllm": (cls.tf, cls.flax, cls.safetensors),
            "tf": (cls.flax, cls.pt),
            # As of now, safetensors is not supported with flax.
            "flax": (cls.tf, cls.pt, cls.safetensors),
        }
        backend = llm.__llm_implementation__
        if backend in fixed_choices:
            patterns = list(fixed_choices[backend])
        else:
            patterns = [cls.tf, cls.flax]
            if has_safetensors_weights(llm.model_id):
                patterns.append(cls.pt)
        # Also filter out these files, since we probably don't need them for now.
        patterns += ["*.pdf", "*.md", ".gitattributes", "LICENSE.txt"]
        return patterns