mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-12 03:58:51 -04:00
chore(style): cleanup bytes
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -15,18 +15,15 @@ logger = logging.getLogger(__name__)
|
||||
__all__ = ['import_model', 'get', 'load_model']
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def import_model(llm, *decls, trust_remote_code, **attrs):
|
||||
(_base_decls, _base_attrs), tokenizer_attrs = llm.llm_parameters
|
||||
decls = (*_base_decls, *decls)
|
||||
attrs = {**_base_attrs, **attrs}
|
||||
if llm._local:
|
||||
logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
|
||||
if llm._local: logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.')
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
patch_correct_tag(llm, config)
|
||||
safe_serialisation = first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
|
||||
if llm.quantise != 'gptq':
|
||||
attrs['use_safetensors'] = safe_serialisation
|
||||
if llm.quantise != 'gptq': attrs['use_safetensors'] = safe_serialisation
|
||||
|
||||
model = None
|
||||
tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
@@ -39,7 +36,6 @@ def import_model(llm, *decls, trust_remote_code, **attrs):
|
||||
attrs['quantization_config'] = llm.quantization_config
|
||||
if llm.quantise == 'gptq':
|
||||
from optimum.gptq.constants import GPTQ_CONFIG
|
||||
|
||||
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
|
||||
if llm._local: # possible local path
|
||||
@@ -56,27 +52,21 @@ def import_model(llm, *decls, trust_remote_code, **attrs):
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation)
|
||||
del model
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
else:
|
||||
# we will clone the all tings into the bentomodel path without loading model into memory
|
||||
snapshot_download(
|
||||
llm.model_id,
|
||||
local_dir=bentomodel.path,
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=HfIgnore.ignore_patterns(llm),
|
||||
llm.model_id, local_dir=bentomodel.path, #
|
||||
local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm), #
|
||||
)
|
||||
return bentomodel
|
||||
|
||||
|
||||
def get(llm):
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
backend = model.info.labels['backend']
|
||||
if backend != llm.__llm_backend__:
|
||||
raise OpenLLMException(
|
||||
f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'."
|
||||
)
|
||||
raise OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
|
||||
patch_correct_tag(
|
||||
llm,
|
||||
transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code),
|
||||
@@ -89,18 +79,15 @@ def get(llm):
|
||||
|
||||
def check_unintialised_params(model):
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0:
|
||||
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
|
||||
|
||||
def load_model(llm, *decls, **attrs):
|
||||
if llm.quantise in {'awq', 'squeezellm'}:
|
||||
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
|
||||
if llm.quantise in {'awq', 'squeezellm'}: raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs
|
||||
)
|
||||
if llm.__llm_backend__ == 'triton':
|
||||
return openllm.models.load_model(llm, config, **attrs)
|
||||
if llm.__llm_backend__ == 'triton': return openllm.models.load_model(llm, config, **attrs)
|
||||
|
||||
auto_class = infer_autoclass_from_llm(llm, config)
|
||||
device_map = attrs.pop('device_map', None)
|
||||
@@ -152,6 +139,5 @@ def load_model(llm, *decls, **attrs):
|
||||
device_map=device_map,
|
||||
**attrs,
|
||||
)
|
||||
|
||||
check_unintialised_params(model)
|
||||
return model
|
||||
|
||||
@@ -1,45 +1,28 @@
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import copy, logging
|
||||
import transformers
|
||||
|
||||
from openllm.serialisation.constants import HUB_ATTRS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_tokenizer(model_id_or_path, trust_remote_code, **attrs):
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
model_id_or_path, trust_remote_code=trust_remote_code, **attrs
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code, **attrs)
|
||||
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
|
||||
|
||||
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any):
|
||||
def process_config(model_id, trust_remote_code, **attrs):
|
||||
config = attrs.pop('config', None)
|
||||
# this logic below is synonymous to handling `from_pretrained` attrs.
|
||||
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
|
||||
if not isinstance(config, transformers.PretrainedConfig):
|
||||
copied_attrs = copy.deepcopy(attrs)
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto':
|
||||
copied_attrs.pop('torch_dtype')
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs
|
||||
)
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
|
||||
def infer_autoclass_from_llm(llm, config, /):
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if llm.trust_remote_code:
|
||||
if not hasattr(config, 'auto_map'):
|
||||
raise ValueError(
|
||||
f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping'
|
||||
)
|
||||
raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map:
|
||||
|
||||
@@ -1,30 +1,22 @@
|
||||
from __future__ import annotations
|
||||
import traceback
|
||||
import typing as t
|
||||
from pathlib import Path
|
||||
|
||||
import attr
|
||||
import attr, traceback, pathlib, typing as t
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from openllm_core.exceptions import Error
|
||||
from openllm_core.utils import resolve_filepath, validate_is_path
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from huggingface_hub.hf_api import ModelInfo as HfModelInfo
|
||||
|
||||
import openllm
|
||||
|
||||
__global_inst__ = None
|
||||
__cached_id__: dict[str, HfModelInfo] = dict()
|
||||
|
||||
|
||||
def Client() -> HfApi:
|
||||
global __global_inst__ # noqa: PLW0603
|
||||
if __global_inst__ is None:
|
||||
__global_inst__ = HfApi()
|
||||
return __global_inst__
|
||||
|
||||
|
||||
def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
if model_id in __cached_id__:
|
||||
return __cached_id__[model_id]
|
||||
@@ -35,13 +27,11 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
traceback.print_exc()
|
||||
raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
|
||||
|
||||
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
|
||||
if validate_is_path(model_id):
|
||||
return next((True for _ in Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
|
||||
return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
|
||||
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class HfIgnore:
|
||||
safetensors = '*.safetensors'
|
||||
@@ -49,7 +39,6 @@ class HfIgnore:
|
||||
tf = '*.h5'
|
||||
flax = '*.msgpack'
|
||||
gguf = '*.gguf'
|
||||
|
||||
@classmethod
|
||||
def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
|
||||
if llm.__llm_backend__ in {'vllm', 'pt'}:
|
||||
|
||||
Reference in New Issue
Block a user