infra: using ruff formatter (#594)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-09 12:44:05 -05:00
committed by GitHub
parent 021fd453b9
commit ac377fe490
102 changed files with 5577 additions and 2540 deletions

View File

@@ -1,4 +1,5 @@
"""Serialisation related implementation for Transformers-based implementation."""
from __future__ import annotations
import importlib
import logging
@@ -27,6 +28,7 @@ from ._helpers import infer_autoclass_from_llm
from ._helpers import process_config
from .weights import HfIgnore
if t.TYPE_CHECKING:
import types
@@ -38,29 +40,52 @@ logger = logging.getLogger(__name__)
__all__ = ['import_model', 'get', 'load_model']
_object_setattr = object.__setattr__
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None) -> None:
def _patch_correct_tag(
llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None
) -> None:
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
if llm.revision is not None: return
if llm.revision is not None:
return
if not llm._local:
try:
if _revision is None: _revision = get_hash(config)
if _revision is None:
_revision = get_hash(config)
except ValueError:
pass
if _revision is None and llm.tag.version is not None: _revision = llm.tag.version
if llm._tag.version is None: _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
if _revision is None and llm.tag.version is not None:
_revision = llm.tag.version
if llm._tag.version is None:
_object_setattr(
llm, '_tag', attr.evolve(llm.tag, version=_revision)
) # HACK: This copies the correct revision into llm.tag
if llm._revision is None:
_object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
@inject
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
def import_model(
llm: openllm.LLM[M, T],
*decls: t.Any,
trust_remote_code: bool,
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
**attrs: t.Any,
) -> bentoml.Model:
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
quantize = llm._quantise
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
safe_serialisation = openllm.utils.first_not_none(
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
)
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
if quantize: metadata['_quantize'] = quantize
if quantize:
metadata['_quantize'] = quantize
architectures = getattr(config, 'architectures', [])
if not architectures: raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
if not architectures:
raise RuntimeError(
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
)
metadata['_pretrained_class'] = architectures[0]
metadata['_revision'] = get_hash(config)
@@ -69,93 +94,152 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
if quantize == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
signatures['generate'] = {'batchable': False}
else:
attrs['use_safetensors'] = safe_serialisation
metadata['_framework'] = llm.__llm_backend__
signatures.update({
signatures.update(
{
k: ModelSignature(batchable=False)
for k in ('__call__', 'forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search')
})
for k in (
'__call__',
'forward',
'generate',
'contrastive_search',
'greedy_search',
'sample',
'beam_search',
'beam_sample',
'group_beam_search',
'constrained_beam_search',
)
}
)
tokenizer = transformers.AutoTokenizer.from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
tokenizer = transformers.AutoTokenizer.from_pretrained(
llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = None
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
imported_modules: list[types.ModuleType] = []
bentomodel = bentoml.Model.create(llm.tag,
module='openllm.serialisation.transformers',
api_version='v2.1.0',
options=ModelOptions(),
context=openllm.utils.generate_context(framework_name='openllm'),
labels=openllm.utils.generate_labels(llm),
metadata=metadata,
signatures=signatures)
bentomodel = bentoml.Model.create(
llm.tag,
module='openllm.serialisation.transformers',
api_version='v2.1.0',
options=ModelOptions(),
context=openllm.utils.generate_context(framework_name='openllm'),
labels=openllm.utils.generate_labels(llm),
metadata=metadata,
signatures=signatures,
)
with openllm.utils.analytics.set_bentoml_tracking():
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
tokenizer.save_pretrained(bentomodel.path)
if llm._quantization_config or (llm._quantise and llm._quantise not in {'squeezellm', 'awq'}): attrs['quantization_config'] = llm.quantization_config
if llm._quantization_config or (llm._quantise and llm._quantise not in {'squeezellm', 'awq'}):
attrs['quantization_config'] = llm.quantization_config
if quantize == 'gptq':
from optimum.gptq.constants import GPTQ_CONFIG
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
if llm._local: # possible local path
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
model = infer_autoclass_from_llm(llm, config).from_pretrained(
llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
else:
# we will clone the all tings into the bentomodel path without loading model into memory
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
snapshot_download(
llm.model_id,
local_dir=bentomodel.path,
local_dir_use_symlinks=False,
ignore_patterns=HfIgnore.ignore_patterns(llm),
)
except Exception:
raise
else:
bentomodel.flush() # type: ignore[no-untyped-call]
bentomodel.save(_model_store)
openllm.utils.analytics.track(openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024))
openllm.utils.analytics.track(
openllm.utils.analytics.ModelSaveEvent(
module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024
)
)
finally:
bentomodel.exit_cloudpickle_context(imported_modules)
# NOTE: We need to free up the cache after importing the model
# in the case where users first run openllm start without the model available locally.
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
if openllm.utils.is_torch_available() and torch.cuda.is_available():
torch.cuda.empty_cache()
del model
return bentomodel
def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
try:
model = bentoml.models.get(llm.tag)
backend = model.info.labels['backend']
if backend != llm.__llm_backend__: raise openllm.exceptions.OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
_patch_correct_tag(llm, transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code), _revision=model.info.metadata.get('_revision'))
if backend != llm.__llm_backend__:
raise openllm.exceptions.OpenLLMException(
f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'."
)
_patch_correct_tag(
llm,
transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code),
_revision=model.info.metadata.get('_revision'),
)
return model
except Exception as err:
raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
raise openllm.exceptions.OpenLLMException(
f'Failed while getting stored artefact (lookup for traceback):\n{err}'
) from err
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
if llm._quantise in {'awq', 'squeezellm'}: raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs)
if llm._quantise in {'awq', 'squeezellm'}:
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(
llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs
)
auto_class = infer_autoclass_from_llm(llm, config)
device_map = attrs.pop('device_map', None)
if torch.cuda.is_available():
if torch.cuda.device_count() > 1: device_map = 'auto'
elif torch.cuda.device_count() == 1: device_map = 'cuda:0'
if llm._quantise in {'int8', 'int4'}: attrs['quantization_config'] = llm.quantization_config
if torch.cuda.device_count() > 1:
device_map = 'auto'
elif torch.cuda.device_count() == 1:
device_map = 'cuda:0'
if llm._quantise in {'int8', 'int4'}:
attrs['quantization_config'] = llm.quantization_config
if '_quantize' in llm.bentomodel.info.metadata:
_quantise = llm.bentomodel.info.metadata['_quantize']
if _quantise == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm':
raise openllm.exceptions.OpenLLMException(
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})"
)
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
# TODO: investigate load with flash attention
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
else:
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **attrs)
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
**attrs,
)
return t.cast('M', model)

View File

@@ -10,6 +10,7 @@ import openllm
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from openllm.serialisation.constants import HUB_ATTRS
if t.TYPE_CHECKING:
from transformers.models.auto.auto_factory import _BaseAutoModelClass
@@ -17,12 +18,17 @@ if t.TYPE_CHECKING:
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
def get_hash(config: transformers.PretrainedConfig) -> str:
_commit_hash = getattr(config, '_commit_hash', None)
if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}')
if _commit_hash is None:
raise ValueError(f'Cannot find commit hash in {config}')
return _commit_hash
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
def process_config(
model_id: str, trust_remote_code: bool, **attrs: t.Any
) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
Args:
@@ -38,25 +44,36 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
if not isinstance(config, transformers.PretrainedConfig):
copied_attrs = copy.deepcopy(attrs)
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
if copied_attrs.get('torch_dtype', None) == 'auto':
copied_attrs.pop('torch_dtype')
config, attrs = transformers.AutoConfig.from_pretrained(
model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs
)
return config, hub_attrs, attrs
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
if llm.trust_remote_code:
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if not hasattr(config, 'auto_map'):
raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
raise ValueError(
f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping'
)
# in case this model doesn't use the correct auto class for model type, for example like chatglm
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
if autoclass not in config.auto_map: autoclass = 'AutoModel'
if autoclass not in config.auto_map:
autoclass = 'AutoModel'
return getattr(transformers, autoclass)
else:
if type(config) in transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
else: raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
if type(config) in transformers.MODEL_FOR_CAUSAL_LM_MAPPING:
idx = 0
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
idx = 1
else:
raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_backend__][idx])
def check_unintialised_params(model: torch.nn.Module) -> None:
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0:

View File

@@ -8,6 +8,7 @@ from huggingface_hub import HfApi
from openllm_core.exceptions import Error
if t.TYPE_CHECKING:
from huggingface_hub.hf_api import ModelInfo as HfModelInfo
@@ -19,13 +20,17 @@ if t.TYPE_CHECKING:
__global_inst__ = None
__cached_id__: dict[str, HfModelInfo] = dict()
def Client() -> HfApi:
global __global_inst__ # noqa: PLW0603
if __global_inst__ is None: __global_inst__ = HfApi()
if __global_inst__ is None:
__global_inst__ = HfApi()
return __global_inst__
def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
if model_id in __cached_id__: return __cached_id__[model_id]
if model_id in __cached_id__:
return __cached_id__[model_id]
try:
__cached_id__[model_id] = Client().model_info(model_id, revision=revision)
return __cached_id__[model_id]
@@ -33,9 +38,11 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
traceback.print_exc()
raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
@attr.define(slots=True)
class HfIgnore:
safetensors = '*.safetensors'
@@ -48,9 +55,12 @@ class HfIgnore:
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
if llm.__llm_backend__ in {'vllm', 'pt'}:
base = [cls.tf, cls.flax, cls.gguf]
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
else: base.append(cls.safetensors)
elif llm.__llm_backend__ == 'ggml': base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
if has_safetensors_weights(llm.model_id):
base.append(cls.pt)
else:
base.append(cls.safetensors)
elif llm.__llm_backend__ == 'ggml':
base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
else:
raise ValueError('Unknown backend (should never happen at all.)')
# filter out these files, since we probably don't need them for now.