chore(style): synchronized style across packages [skip ci]

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-08-23 08:46:22 -04:00
parent bbd9aa7646
commit 787ce1b3b6
124 changed files with 2775 additions and 2771 deletions

View File

@@ -1,4 +1,4 @@
"""Serialisation related implementation for Transformers-based implementation."""
'''Serialisation related implementation for Transformers-based implementation.'''
from __future__ import annotations
import importlib, logging, typing as t
import bentoml, openllm
@@ -18,14 +18,14 @@ if t.TYPE_CHECKING:
from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny, M, T
else:
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq")
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
vllm = openllm.utils.LazyLoader('vllm', globals(), 'vllm')
autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
torch = openllm.utils.LazyLoader('torch', globals(), 'torch')
logger = logging.getLogger(__name__)
__all__ = ["import_model", "get", "load_model", "save_pretrained"]
__all__ = ['import_model', 'get', 'load_model', 'save_pretrained']
@inject
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
"""Auto detect model type from given model_id and import it to bentoml's model store.
@@ -48,23 +48,23 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_, tokenizer_attrs = llm.llm_parameters
quantize_method = llm._quantize_method
safe_serialisation = openllm.utils.first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors")
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation_format == 'safetensors')
# Disable safe serialization with vLLM
if llm.__llm_implementation__ == "vllm": safe_serialisation = False
metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method is not None and quantize_method}
if llm.__llm_implementation__ == 'vllm': safe_serialisation = False
metadata: DictStrAny = {'safe_serialisation': safe_serialisation, '_quantize': quantize_method is not None and quantize_method}
signatures: DictStrAny = {}
if quantize_method == "gptq":
if quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
signatures["generate"] = {"batchable": False}
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
signatures['generate'] = {'batchable': False}
else:
# this model might be called with --quantize int4, therefore we need to pop this out
# since saving int4 is not yet supported
if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False): attrs.pop("quantization_config")
if llm.__llm_implementation__ != "flax": attrs["use_safetensors"] = safe_serialisation
metadata["_framework"] = "pt" if llm.__llm_implementation__ == "vllm" else llm.__llm_implementation__
if 'quantization_config' in attrs and getattr(attrs['quantization_config'], 'load_in_4bit', False): attrs.pop('quantization_config')
if llm.__llm_implementation__ != 'flax': attrs['use_safetensors'] = safe_serialisation
metadata['_framework'] = 'pt' if llm.__llm_implementation__ == 'vllm' else llm.__llm_implementation__
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
@@ -73,10 +73,10 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
imported_modules: list[types.ModuleType] = []
bentomodel = bentoml.Model.create(
llm.tag,
module="openllm.serialisation.transformers",
api_version="v1",
module='openllm.serialisation.transformers',
api_version='v1',
options=ModelOptions(),
context=openllm.utils.generate_context(framework_name="openllm"),
context=openllm.utils.generate_context(framework_name='openllm'),
labels=openllm.utils.generate_labels(llm),
signatures=signatures if signatures else make_model_signatures(llm)
)
@@ -84,35 +84,35 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
tokenizer.save_pretrained(bentomodel.path)
if quantize_method == "gptq":
if quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
logger.debug("Saving model with GPTQ quantisation will require loading model into memory.")
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
logger.debug('Saving model with GPTQ quantisation will require loading model into memory.')
model = autogptq.AutoGPTQForCausalLM.from_quantized(
llm.model_id,
*decls,
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
trust_remote_code=trust_remote_code,
use_safetensors=safe_serialisation,
**hub_attrs,
**attrs,
)
update_model(bentomodel, metadata={"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework})
update_model(bentomodel, metadata={'_pretrained_class': model.__class__.__name__, '_framework': model.model.framework})
model.save_quantized(bentomodel.path, use_safetensors=safe_serialisation)
else:
architectures = getattr(config, "architectures", [])
architectures = getattr(config, 'architectures', [])
if not architectures:
raise RuntimeError("Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`")
raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
architecture = architectures[0]
update_model(bentomodel, metadata={"_pretrained_class": architecture})
update_model(bentomodel, metadata={'_pretrained_class': architecture})
if llm._local:
# possible local path
logger.debug("Model will be loaded into memory to save to target store as it is from local path.")
logger.debug('Model will be loaded into memory to save to target store as it is from local path.')
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size="5GB", safe_serialization=safe_serialisation)
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
else:
# we will clone all the things into the bentomodel path without loading model into memory
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
@@ -129,58 +129,58 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
return bentomodel
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
'''Return an instance of ``bentoml.Model`` from given LLM instance.
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
"""
'''
try:
model = bentoml.models.get(llm.tag)
if model.info.module not in (
"openllm.serialisation.transformers"
"bentoml.transformers", "bentoml._internal.frameworks.transformers", __name__
'openllm.serialisation.transformers'
'bentoml.transformers', 'bentoml._internal.frameworks.transformers', __name__
): # NOTE: backward compatible with previous version of OpenLLM.
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
if 'runtime' in model.info.labels and model.info.labels['runtime'] != llm.runtime:
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
return model
except bentoml.exceptions.NotFound as err:
if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
raise err from None
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
"""Load the model from BentoML store.
'''Load the model from BentoML store.
By default, it will try to find check the model in the local store.
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
'''
config, hub_attrs, attrs = process_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs)
safe_serialization = openllm.utils.first_not_none(
t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors"
t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)), attrs.pop('safe_serialization', None), default=llm._serialisation_format == 'safetensors'
)
if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq":
if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
return autogptq.AutoGPTQForCausalLM.from_quantized(
llm._bentomodel.path,
*decls,
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
trust_remote_code=llm.__llm_trust_remote_code__,
use_safetensors=safe_serialization,
**hub_attrs,
**attrs
)
device_map = attrs.pop("device_map", "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
device_map = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
model = infer_autoclass_from_llm(llm, config).from_pretrained(
llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, device_map=device_map, **hub_attrs, **attrs
).eval()
# BetterTransformer is currently only supported on PyTorch.
if llm.bettertransformer and isinstance(model, transformers.PreTrainedModel): model = model.to_bettertransformer()
if llm.__llm_implementation__ in {"pt", "vllm"}: check_unintialised_params(model)
return t.cast("M", model)
if llm.__llm_implementation__ in {'pt', 'vllm'}: check_unintialised_params(model)
return t.cast('M', model)
def save_pretrained(
llm: openllm.LLM[M, T],
save_directory: str,
@@ -188,29 +188,29 @@ def save_pretrained(
state_dict: DictStrAny | None = None,
save_function: t.Any | None = None,
push_to_hub: bool = False,
max_shard_size: int | str = "10GB",
max_shard_size: int | str = '10GB',
safe_serialization: bool = False,
variant: str | None = None,
**attrs: t.Any
) -> None:
save_function = t.cast(t.Callable[..., None], openllm.utils.first_not_none(save_function, default=torch.save))
model_save_attrs, tokenizer_save_attrs = openllm.utils.normalize_attrs_to_model_tokenizer_pair(**attrs)
safe_serialization = safe_serialization or llm._serialisation_format == "safetensors"
safe_serialization = safe_serialization or llm._serialisation_format == 'safetensors'
# NOTE: disable safetensors for vllm
if llm.__llm_implementation__ == "vllm": safe_serialization = False
if llm._quantize_method == "gptq":
if llm.__llm_implementation__ == 'vllm': safe_serialization = False
if llm._quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if not openllm.utils.lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})")
t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
elif openllm.utils.LazyType["vllm.LLMEngine"]("vllm.LLMEngine").isinstance(llm.model):
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if not openllm.utils.lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f'Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})')
t.cast('autogptq.modeling.BaseGPTQForCausalLM', llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
elif openllm.utils.LazyType['vllm.LLMEngine']('vllm.LLMEngine').isinstance(llm.model):
raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.")
elif isinstance(llm.model, transformers.Pipeline):
llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
else:
# We can safely cast here since it will be the PreTrainedModel protocol.
t.cast("transformers.PreTrainedModel", llm.model).save_pretrained(
t.cast('transformers.PreTrainedModel', llm.model).save_pretrained(
save_directory,
is_main_process=is_main_process,
state_dict=state_dict,

View File

@@ -9,11 +9,11 @@ if t.TYPE_CHECKING:
from bentoml._internal.models.model import ModelSignaturesType
from openllm_core._typing_compat import DictStrAny, M, T
else:
transformers, torch = openllm_core.utils.LazyLoader("transformers", globals(), "transformers"), openllm_core.utils.LazyLoader("torch", globals(), "torch")
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
_object_setattr = object.__setattr__
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
'''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
Args:
model_id: Model id to pass into ``transformers.AutoConfig``.
@@ -22,51 +22,51 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
Returns:
A tuple of ``transformers.PretrainedConfig``, all hub attributes, and remaining attributes that can be used by the Model class.
"""
config = attrs.pop("config", None)
'''
config = attrs.pop('config', None)
# this logic below is synonymous to handling `from_pretrained` attrs.
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
if not isinstance(config, transformers.PretrainedConfig):
copied_attrs = copy.deepcopy(attrs)
if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype")
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
return config, hub_attrs, attrs
def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config["tokenizer_class"], default="AutoTokenizer"), None)
if __cls is None: raise ValueError(f"Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`")
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
if __cls is None: raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
return __cls
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
if llm.config["trust_remote_code"]:
autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM"
if not hasattr(config, "auto_map"):
raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping")
if llm.config['trust_remote_code']:
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if not hasattr(config, 'auto_map'):
raise ValueError(f'Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
# in case this model doesn't use the correct auto class for model type, for example like chatglm
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
if autoclass not in config.auto_map: autoclass = "AutoModel"
if autoclass not in config.auto_map: autoclass = 'AutoModel'
return getattr(transformers, autoclass)
else:
if type(config) in transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
else: raise openllm.exceptions.OpenLLMException(f"Model type {type(config)} is not supported yet.")
else: raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx])
def check_unintialised_params(model: torch.nn.Module) -> None:
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device("meta")]
if len(unintialized) > 0: raise RuntimeError(f"Found the following unintialized parameters in {model}: {unintialized}")
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
def update_model(bentomodel: bentoml.Model, metadata: DictStrAny) -> bentoml.Model:
based: DictStrAny = copy.deepcopy(bentomodel.info.metadata)
based.update(metadata)
_object_setattr(bentomodel, "_info", ModelInfo( # type: ignore[call-arg] # XXX: remove me once upstream is merged
_object_setattr(bentomodel, '_info', ModelInfo( # type: ignore[call-arg] # XXX: remove me once upstream is merged
tag=bentomodel.info.tag, module=bentomodel.info.module, labels=bentomodel.info.labels, options=bentomodel.info.options.to_dict(), signatures=bentomodel.info.signatures, context=bentomodel.info.context, api_version=bentomodel.info.api_version, creation_time=bentomodel.info.creation_time, metadata=based
))
return bentomodel
# NOTE: sync with bentoml/_internal/frameworks/transformers.py#make_default_signatures
def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
infer_fn: tuple[str, ...] = ("__call__",)
infer_fn: tuple[str, ...] = ('__call__',)
default_config = ModelSignature(batchable=False)
if llm.__llm_implementation__ in {"pt", "vllm"}:
infer_fn += ("forward", "generate", "contrastive_search", "greedy_search", "sample", "beam_search", "beam_sample", "group_beam_search", "constrained_beam_search",)
elif llm.__llm_implementation__ == "tf":
infer_fn += ("predict", "call", "generate", "compute_transition_scores", "greedy_search", "sample", "beam_search", "contrastive_search",)
if llm.__llm_implementation__ in {'pt', 'vllm'}:
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search',)
elif llm.__llm_implementation__ == 'tf':
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search', 'contrastive_search',)
else:
infer_fn += ("generate",)
infer_fn += ('generate',)
return {k: default_config for k in infer_fn}

View File

@@ -5,22 +5,22 @@ if t.TYPE_CHECKING:
import openllm
from openllm_core._typing_compat import M, T
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
return any(s.rfilename.endswith(".safetensors") for s in HfApi().model_info(model_id, revision=revision).siblings)
return any(s.rfilename.endswith('.safetensors') for s in HfApi().model_info(model_id, revision=revision).siblings)
@attr.define(slots=True)
class HfIgnore:
safetensors = "*.safetensors"
pt = "*.bin"
tf = "*.h5"
flax = "*.msgpack"
safetensors = '*.safetensors'
pt = '*.bin'
tf = '*.h5'
flax = '*.msgpack'
@classmethod
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
if llm.__llm_implementation__ == "vllm": base = [cls.tf, cls.flax, cls.safetensors]
elif llm.__llm_implementation__ == "tf": base = [cls.flax, cls.pt]
elif llm.__llm_implementation__ == "flax": base = [cls.tf, cls.pt, cls.safetensors] # as of current, safetensors is not supported with flax
if llm.__llm_implementation__ == 'vllm': base = [cls.tf, cls.flax, cls.safetensors]
elif llm.__llm_implementation__ == 'tf': base = [cls.flax, cls.pt]
elif llm.__llm_implementation__ == 'flax': base = [cls.tf, cls.pt, cls.safetensors] # as of current, safetensors is not supported with flax
else:
base = [cls.tf, cls.flax]
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
# filter out these files, since we probably don't need them for now.
base.extend(["*.pdf", "*.md", ".gitattributes", "LICENSE.txt"])
base.extend(['*.pdf', '*.md', '.gitattributes', 'LICENSE.txt'])
return base