mirror of https://github.com/bentoml/OpenLLM.git
chore(style): cleanup bytes
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
@@ -1,32 +1,18 @@
from __future__ import annotations
import importlib
import typing as t

from openllm_core._typing_compat import M, ParamSpec, T, TypeGuard
import importlib, typing as t
from openllm_core._typing_compat import M, ParamSpec, T, TypeGuard, Concatenate
from openllm_core.exceptions import OpenLLMException

if t.TYPE_CHECKING:
from bentoml import Model

from .._llm import LLM
if t.TYPE_CHECKING: from bentoml import Model; from .._llm import LLM

P = ParamSpec('P')


def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]:
'''Load the tokenizer from BentoML store.
By default, it will try to find whether the bentomodel is in the store.
If the model is not found, it raises a ``bentoml.exceptions.NotFound``.
'''
import cloudpickle
import fs
from transformers import AutoTokenizer
import cloudpickle, fs, transformers
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
from .transformers._helpers import process_config

tokenizer_attrs = {**llm.llm_parameters[-1], **tokenizer_attrs}
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME

from .transformers._helpers import process_config

config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code)
@@ -41,9 +27,7 @@ def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]:
'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"'
) from None
else:
tokenizer = AutoTokenizer.from_pretrained(
bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs
)
tokenizer = transformers.AutoTokenizer.from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)

if tokenizer.pad_token_id is None:
if config.pad_token_id is not None:
@@ -56,17 +40,16 @@ def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
return tokenizer


def _make_dispatch_function(fn):
def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[LLM[M, T], P], TypeGuard[M | T | Model]]:
def caller(llm: LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> TypeGuard[M | T | Model]:
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
'''Generic function dispatch to correct serialisation submodules based on LLM runtime.

> [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'

> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"'

> [!NOTE] See 'openllm.serialisation.ctranslate' if 'llm.__llm_backend__="ctranslate"'
"""
'''
if llm.__llm_backend__ == 'ggml':
serde = 'ggml'
elif llm.__llm_backend__ == 'ctranslate':
@@ -76,19 +59,12 @@ def _make_dispatch_function(fn):
else:
raise OpenLLMException(f'Not supported backend {llm.__llm_backend__}')
return getattr(importlib.import_module(f'.{serde}', 'openllm.serialisation'), fn)(llm, *args, **kwargs)
return caller
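For readers skimming the hunk above: the dispatch helper simply maps a backend label to a submodule name and lazily imports the matching serialisation module before forwarding the call. A minimal self-contained sketch of that pattern, using hypothetical names rather than OpenLLM's real modules:

import importlib
from typing import Any, Callable

# Hypothetical backend-to-submodule table; the real mapping lives in the code above.
_BACKEND_TO_SERDE = {'pt': 'transformers', 'vllm': 'transformers', 'ggml': 'ggml', 'ctranslate': 'ctranslate'}

def make_dispatch(fn_name: str, package: str) -> Callable[..., Any]:
  def caller(backend: str, *args: Any, **kwargs: Any) -> Any:
    serde = _BACKEND_TO_SERDE.get(backend)
    if serde is None:
      raise RuntimeError(f'Not supported backend {backend}')
    # Import the submodule only when the call actually happens, then forward to it.
    return getattr(importlib.import_module(f'.{serde}', package), fn_name)(*args, **kwargs)
  return caller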
_extras = ['get', 'import_model', 'load_model']
_import_structure = {'ggml', 'transformers', 'ctranslate', 'constants'}
__all__ = ['load_tokenizer', *_extras, *_import_structure]


def __dir__() -> t.Sequence[str]:
return sorted(__all__)


def __dir__() -> t.Sequence[str]: return sorted(__all__)
def __getattr__(name: str) -> t.Any:
if name == 'load_tokenizer':
return load_tokenizer
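The `__dir__`/`__getattr__` pair in this hunk leans on module-level attribute hooks (PEP 562) so that submodules are only imported on first access. A rough standalone illustration of the mechanism, with placeholder names, as it would appear at the top level of a package `__init__.py`:

import importlib
import typing as t

_lazy_submodules = {'ggml', 'transformers', 'ctranslate', 'constants'}  # illustrative set
__all__ = ['load_tokenizer', *_lazy_submodules]

def __dir__() -> t.Sequence[str]:
  return sorted(__all__)

def __getattr__(name: str) -> t.Any:
  # Only called when `name` is not already a regular attribute of this module.
  if name in _lazy_submodules:
    return importlib.import_module(f'.{name}', __name__)  # imported on first access
  raise AttributeError(f'module {__name__!r} has no attribute {name!r}')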
@@ -4,16 +4,18 @@ Currently supports transformers for PyTorch, and vLLM.
Currently, GGML format support is a work in progress.
'''

from typing import Any

from bentoml import Model
from openllm import LLM
from openllm_core._typing_compat import M, T

from . import constants as constants, ggml as ggml, transformers as transformers

def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T: ...
def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T:
'''Load the tokenizer from BentoML store.
By default, it will try to find whether the bentomodel is in the store.
If the model is not found, it raises a ``bentoml.exceptions.NotFound``.
'''
def get(llm: LLM[M, T]) -> Model: ...
def import_model(llm: LLM[M, T], *args: Any, trust_remote_code: bool, **attrs: Any) -> Model: ...
def load_model(llm: LLM[M, T], *args: Any, **attrs: Any) -> M: ...
@@ -1,10 +1,6 @@
import contextlib

import attr
import contextlib, attr
from simple_di import Provide, inject

import bentoml
import openllm
import bentoml, openllm
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelOptions, ModelSignature
from openllm_core.exceptions import OpenLLMException
@@ -12,88 +8,58 @@ from openllm_core.utils import is_autogptq_available
_object_setattr = object.__setattr__


def get_hash(config) -> str:
_commit_hash = getattr(config, '_commit_hash', None)
if _commit_hash is None:
raise ValueError(f'Cannot find commit hash in {config}')
if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}')
return _commit_hash


def patch_correct_tag(llm, config, _revision=None):
def patch_correct_tag(llm, config, _revision=None) -> None:
# NOTE: The following is not hit for local models, since we generate a correct version based on the local path hash; it only applies when the model comes from the HF Hub.
if llm.revision is not None:
return
if llm.revision is not None: return
if not llm.local:
try:
if _revision is None:
_revision = get_hash(config)
if _revision is None: _revision = get_hash(config)
except ValueError:
pass
if _revision is None and llm.tag.version is not None:
_revision = llm.tag.version
if _revision is None and llm.tag.version is not None: _revision = llm.tag.version
if llm.tag.version is None:
# HACK: This copies the correct revision into llm.tag
_object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision))
if llm._revision is None:
_object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version

_object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
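The tag patching above combines `attr.evolve` (copy an immutable attrs instance with one field changed) with `object.__setattr__` (bypass the frozen guard to write the result back). A small sketch of that technique on a hypothetical frozen class, not the real tag type:

from __future__ import annotations

import attr

_object_setattr = object.__setattr__

@attr.define(frozen=True)
class Tag:  # stand-in for the real tag object, illustrative only
  name: str
  version: str | None = None

tag = Tag('facebook--opt-125m')
patched = attr.evolve(tag, version='abc123')  # returns a copy with `version` replaced
assert patched.version == 'abc123' and tag.version is None
# For frozen instances, object.__setattr__ side-steps the immutability check,
# which is how the resolved revision is written onto the live llm object above.
_object_setattr(tag, 'version', 'abc123')
assert tag.version == 'abc123'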
def _create_metadata(llm, config, safe_serialisation, trust_remote_code, metadata=None):
if metadata is None:
metadata = {}
if metadata is None: metadata = {}
metadata.update({'safe_serialisation': safe_serialisation, '_framework': llm.__llm_backend__})
if llm.quantise:
metadata['_quantize'] = llm.quantise
if llm.quantise: metadata['_quantize'] = llm.quantise
architectures = getattr(config, 'architectures', [])
if not architectures:
if trust_remote_code:
auto_map = getattr(config, 'auto_map', {})
if not auto_map:
raise RuntimeError(
f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}'
)
if not auto_map: raise RuntimeError(f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}')
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if autoclass not in auto_map:
raise RuntimeError(
f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. OpenLLM currently only support encoder-decoders or decoders only models."
)
raise RuntimeError(f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. OpenLLM currently only support encoder-decoders or decoders only models.")
architectures = [auto_map[autoclass]]
else:
raise RuntimeError(
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
)
metadata.update(
{'_pretrained_class': architectures[0], '_revision': get_hash(config) if not llm.local else llm.revision}
)
raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
metadata.update({'_pretrained_class': architectures[0], '_revision': get_hash(config) if not llm.local else llm.revision})
return metadata


def _create_signatures(llm, signatures=None):
if signatures is None:
signatures = {}
if signatures is None: signatures = {}
if llm.__llm_backend__ == 'pt':
if llm.quantise == 'gptq':
if not is_autogptq_available():
raise OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
raise OpenLLMException("Requires 'auto-gptq' and 'optimum'. Install it with 'pip install \"openllm[gptq]\"'")
signatures['generate'] = {'batchable': False}
else:
signatures.update(
{
k: ModelSignature(batchable=False)
for k in (
'__call__',
'forward',
'generate',
'contrastive_search',
'greedy_search',
'sample',
'beam_search',
'beam_sample',
'group_beam_search',
'constrained_beam_search',
'__call__', 'forward', 'generate', #
'contrastive_search', 'greedy_search', #
'sample', 'beam_search', 'beam_sample', #
'group_beam_search', 'constrained_beam_search', #
)
}
)
@@ -104,37 +70,24 @@ def _create_signatures(llm, signatures=None):
else:
non_batch_keys = set()
batch_keys = {
'async_generate_tokens',
'forward_batch',
'generate_batch',
'generate_iterable',
'generate_tokens',
'score_batch',
'score_iterable',
'async_generate_tokens', 'forward_batch', 'generate_batch', #
'generate_iterable', 'generate_tokens', 'score_batch', 'score_iterable', #
}
signatures.update({k: ModelSignature(batchable=False) for k in non_batch_keys})
signatures.update({k: ModelSignature(batchable=True) for k in batch_keys})
return signatures


@inject
@contextlib.contextmanager
def save_model(
llm,
config,
safe_serialisation,
trust_remote_code,
module,
external_modules,
_model_store=Provide[BentoMLContainer.model_store],
_api_version='v2.1.0',
llm, config, safe_serialisation, #
trust_remote_code, module, external_modules, #
_model_store=Provide[BentoMLContainer.model_store], _api_version='v2.1.0', #
):
imported_modules = []
bentomodel = bentoml.Model.create(
llm.tag,
module=f'openllm.serialisation.{module}',
api_version=_api_version,
options=ModelOptions(),
llm.tag, module=f'openllm.serialisation.{module}', #
api_version=_api_version, options=ModelOptions(), #
context=openllm.utils.generate_context('openllm'),
labels=openllm.utils.generate_labels(llm),
metadata=_create_metadata(llm, config, safe_serialisation, trust_remote_code),
@@ -1,13 +1,7 @@
HUB_ATTRS = [
'cache_dir',
'code_revision',
'force_download',
'local_files_only',
'proxies',
'resume_download',
'revision',
'subfolder',
'use_auth_token',
'cache_dir', 'code_revision', 'force_download', #
'local_files_only', 'proxies', 'resume_download', #
'revision', 'subfolder', 'use_auth_token', #
]
CONFIG_FILE_NAME = 'config.json'
# the below is similar to peft.utils.other.CONFIG_NAME
@@ -15,18 +15,15 @@ logger = logging.getLogger(__name__)
__all__ = ['import_model', 'get', 'load_model']
_object_setattr = object.__setattr__


def import_model(llm, *decls, trust_remote_code, **attrs):
(_base_decls, _base_attrs), tokenizer_attrs = llm.llm_parameters
decls = (*_base_decls, *decls)
attrs = {**_base_attrs, **attrs}
if llm._local:
logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
if llm._local: logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
patch_correct_tag(llm, config)
safe_serialisation = first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
if llm.quantise != 'gptq':
attrs['use_safetensors'] = safe_serialisation
if llm.quantise != 'gptq': attrs['use_safetensors'] = safe_serialisation

model = None
tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
@@ -39,7 +36,6 @@ def import_model(llm, *decls, trust_remote_code, **attrs):
attrs['quantization_config'] = llm.quantization_config
if llm.quantise == 'gptq':
from optimum.gptq.constants import GPTQ_CONFIG

with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
if llm._local: # possible local path
@@ -56,27 +52,21 @@ def import_model(llm, *decls, trust_remote_code, **attrs):
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
if torch.cuda.is_available(): torch.cuda.empty_cache()
else:
# clone everything into the bentomodel path without loading the model into memory
snapshot_download(
|
||||
llm.model_id,
|
||||
local_dir=bentomodel.path,
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=HfIgnore.ignore_patterns(llm),
|
||||
llm.model_id, local_dir=bentomodel.path, #
|
||||
local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm), #
|
||||
)
|
||||
return bentomodel
|
||||
|
||||
|
||||
def get(llm):
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
backend = model.info.labels['backend']
|
||||
if backend != llm.__llm_backend__:
|
||||
raise OpenLLMException(
|
||||
f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'."
|
||||
)
|
||||
raise OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
|
||||
patch_correct_tag(
|
||||
llm,
|
||||
transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code),
|
||||
@@ -89,18 +79,15 @@ def get(llm):
def check_unintialised_params(model):
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0:
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')


def load_model(llm, *decls, **attrs):
if llm.quantise in {'awq', 'squeezellm'}:
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
if llm.quantise in {'awq', 'squeezellm'}: raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(
llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs
)
if llm.__llm_backend__ == 'triton':
return openllm.models.load_model(llm, config, **attrs)
if llm.__llm_backend__ == 'triton': return openllm.models.load_model(llm, config, **attrs)

auto_class = infer_autoclass_from_llm(llm, config)
device_map = attrs.pop('device_map', None)
@@ -152,6 +139,5 @@ def load_model(llm, *decls, **attrs):
device_map=device_map,
**attrs,
)

check_unintialised_params(model)
return model
@@ -1,45 +1,28 @@
from __future__ import annotations
import copy
import logging
import typing as t

import copy, logging
import transformers

from openllm.serialisation.constants import HUB_ATTRS

logger = logging.getLogger(__name__)


def get_tokenizer(model_id_or_path, trust_remote_code, **attrs):
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_id_or_path, trust_remote_code=trust_remote_code, **attrs
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code, **attrs)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
return tokenizer


def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any):
def process_config(model_id, trust_remote_code, **attrs):
config = attrs.pop('config', None)
# the logic below mirrors how `from_pretrained` handles its attrs.
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
|
||||
if not isinstance(config, transformers.PretrainedConfig):
|
||||
copied_attrs = copy.deepcopy(attrs)
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto':
|
||||
copied_attrs.pop('torch_dtype')
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs
|
||||
)
|
||||
return config, hub_attrs, attrs
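`process_config` relies on `AutoConfig.from_pretrained(..., return_unused_kwargs=True)`, which returns the config plus the keyword arguments it did not consume; that is what lets hub options be split out from model kwargs here. A hedged usage sketch (fetches the small `gpt2` config from the Hub; the extra keyword is a made-up name purely to show the split):

import transformers

config, unused = transformers.AutoConfig.from_pretrained(
  'gpt2', return_unused_kwargs=True, output_attentions=True, my_extra_flag=123
)
# `output_attentions` matches a config attribute, so it is applied to `config`;
# `my_extra_flag` does not, so it is handed back for the caller to forward elsewhere.
assert config.output_attentions is True
assert unused == {'my_extra_flag': 123}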
def infer_autoclass_from_llm(llm, config, /):
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if llm.trust_remote_code:
if not hasattr(config, 'auto_map'):
raise ValueError(
f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping'
)
raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
# in case this model doesn't use the correct auto class for its model type, for example chatglm,
# which uses AutoModel instead of AutoModelForCausalLM, we fall back to AutoModel
if autoclass not in config.auto_map:
@@ -1,30 +1,22 @@
from __future__ import annotations
import traceback
import typing as t
from pathlib import Path

import attr
import attr, traceback, pathlib, typing as t
from huggingface_hub import HfApi

from openllm_core.exceptions import Error
from openllm_core.utils import resolve_filepath, validate_is_path

if t.TYPE_CHECKING:
from huggingface_hub.hf_api import ModelInfo as HfModelInfo

import openllm

__global_inst__ = None
__cached_id__: dict[str, HfModelInfo] = dict()


def Client() -> HfApi:
global __global_inst__ # noqa: PLW0603
if __global_inst__ is None:
__global_inst__ = HfApi()
return __global_inst__


def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
if model_id in __cached_id__:
return __cached_id__[model_id]
@@ -35,13 +27,11 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
traceback.print_exc()
raise Error(f'Failed to fetch {model_id} from huggingface.co') from err


def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
if validate_is_path(model_id):
return next((True for _ in Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)


@attr.define(slots=True)
class HfIgnore:
safetensors = '*.safetensors'
@@ -49,7 +39,6 @@ class HfIgnore:
tf = '*.h5'
flax = '*.msgpack'
gguf = '*.gguf'

@classmethod
def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
if llm.__llm_backend__ in {'vllm', 'pt'}: