From ce6efc2a9eb886b5231d2ee5a2ef050136ebcf04 Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Tue, 28 Nov 2023 01:27:27 -0500 Subject: [PATCH] chore(style): cleanup bytes Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- .../src/openllm/serialisation/__init__.py | 46 ++------ .../src/openllm/serialisation/__init__.pyi | 10 +- .../src/openllm/serialisation/_helpers.py | 103 +++++------------- .../src/openllm/serialisation/constants.py | 12 +- .../serialisation/transformers/__init__.py | 32 ++---- .../serialisation/transformers/_helpers.py | 29 +---- .../serialisation/transformers/weights.py | 15 +-- 7 files changed, 65 insertions(+), 182 deletions(-) diff --git a/openllm-python/src/openllm/serialisation/__init__.py b/openllm-python/src/openllm/serialisation/__init__.py index 2a05d90a..ff6a57e2 100644 --- a/openllm-python/src/openllm/serialisation/__init__.py +++ b/openllm-python/src/openllm/serialisation/__init__.py @@ -1,32 +1,18 @@ from __future__ import annotations -import importlib -import typing as t - -from openllm_core._typing_compat import M, ParamSpec, T, TypeGuard +import importlib, typing as t +from openllm_core._typing_compat import M, ParamSpec, T, TypeGuard, Concatenate from openllm_core.exceptions import OpenLLMException -if t.TYPE_CHECKING: - from bentoml import Model - - from .._llm import LLM +if t.TYPE_CHECKING: from bentoml import Model; from .._llm import LLM P = ParamSpec('P') - def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]: - '''Load the tokenizer from BentoML store. - - By default, it will try to find the bentomodel whether it is in store.. - If model is not found, it will raises a ``bentoml.exceptions.NotFound``. - ''' - import cloudpickle - import fs - from transformers import AutoTokenizer + import cloudpickle, fs, transformers + from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME + from .transformers._helpers import process_config tokenizer_attrs = {**llm.llm_parameters[-1], **tokenizer_attrs} - from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME - - from .transformers._helpers import process_config config, *_ = process_config(llm.bentomodel.path, llm.trust_remote_code) @@ -41,9 +27,7 @@ def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]: 'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"' ) from None else: - tokenizer = AutoTokenizer.from_pretrained( - bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs - ) + tokenizer = transformers.AutoTokenizer.from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs) if tokenizer.pad_token_id is None: if config.pad_token_id is not None: @@ -56,17 +40,16 @@ def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]: tokenizer.add_special_tokens({'pad_token': '[PAD]'}) return tokenizer - -def _make_dispatch_function(fn): +def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[LLM[M, T], P], TypeGuard[M | T | Model]]: def caller(llm: LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> TypeGuard[M | T | Model]: - """Generic function dispatch to correct serialisation submodules based on LLM runtime. + '''Generic function dispatch to correct serialisation submodules based on LLM runtime. 
     > [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")'
     > [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"'
     > [!NOTE] See 'openllm.serialisation.ctranslate' if 'llm.__llm_backend__="ctranslate"'
-    """
+    '''
     if llm.__llm_backend__ == 'ggml':
       serde = 'ggml'
     elif llm.__llm_backend__ == 'ctranslate':
@@ -76,19 +59,12 @@ def _make_dispatch_function(fn):
     else:
       raise OpenLLMException(f'Not supported backend {llm.__llm_backend__}')
     return getattr(importlib.import_module(f'.{serde}', 'openllm.serialisation'), fn)(llm, *args, **kwargs)
-
   return caller
-
 _extras = ['get', 'import_model', 'load_model']
 _import_structure = {'ggml', 'transformers', 'ctranslate', 'constants'}
 __all__ = ['load_tokenizer', *_extras, *_import_structure]
-
-
-def __dir__() -> t.Sequence[str]:
-  return sorted(__all__)
-
-
+def __dir__() -> t.Sequence[str]: return sorted(__all__)
 def __getattr__(name: str) -> t.Any:
   if name == 'load_tokenizer':
     return load_tokenizer
diff --git a/openllm-python/src/openllm/serialisation/__init__.pyi b/openllm-python/src/openllm/serialisation/__init__.pyi
index 9a6b3fad..e8ea2b91 100644
--- a/openllm-python/src/openllm/serialisation/__init__.pyi
+++ b/openllm-python/src/openllm/serialisation/__init__.pyi
@@ -4,16 +4,18 @@ Currently supports transformers for PyTorch, and vLLM.
 
 Currently, GGML format is working in progress.
 '''
-
 from typing import Any
-
 from bentoml import Model
 from openllm import LLM
 from openllm_core._typing_compat import M, T
-
 from . import constants as constants, ggml as ggml, transformers as transformers
 
-def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T: ...
+def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T:
+  '''Load the tokenizer from the BentoML store.
+
+  By default, it will try to find the bentomodel in the BentoML store.
+  If the model is not found, it raises a ``bentoml.exceptions.NotFound``.
+  '''
 def get(llm: LLM[M, T]) -> Model: ...
 def import_model(llm: LLM[M, T], *args: Any, trust_remote_code: bool, **attrs: Any) -> Model: ...
 def load_model(llm: LLM[M, T], *args: Any, **attrs: Any) -> M: ...
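A minimal usage sketch of the interface declared in the stub above. The model id and the LLM constructor arguments are illustrative assumptions; only the 'openllm.serialisation' call signatures come from the stub itself:

    import openllm
    from openllm import serialisation

    # Hypothetical handle; the constructor arguments here are assumptions, not part of this patch.
    llm = openllm.LLM('facebook/opt-125m', backend='pt')

    bentomodel = serialisation.import_model(llm, trust_remote_code=False)  # save the model into the BentoML store
    bentomodel = serialisation.get(llm)            # look up the stored model by llm.tag
    model = serialisation.load_model(llm)          # backend-specific model object (M)
    tokenizer = serialisation.load_tokenizer(llm)  # tokenizer (T) from the stored bentomodel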
diff --git a/openllm-python/src/openllm/serialisation/_helpers.py b/openllm-python/src/openllm/serialisation/_helpers.py index 23643bae..ea364b07 100644 --- a/openllm-python/src/openllm/serialisation/_helpers.py +++ b/openllm-python/src/openllm/serialisation/_helpers.py @@ -1,10 +1,6 @@ -import contextlib - -import attr +import contextlib, attr from simple_di import Provide, inject - -import bentoml -import openllm +import bentoml, openllm from bentoml._internal.configuration.containers import BentoMLContainer from bentoml._internal.models.model import ModelOptions, ModelSignature from openllm_core.exceptions import OpenLLMException @@ -12,88 +8,58 @@ from openllm_core.utils import is_autogptq_available _object_setattr = object.__setattr__ - def get_hash(config) -> str: _commit_hash = getattr(config, '_commit_hash', None) - if _commit_hash is None: - raise ValueError(f'Cannot find commit hash in {config}') + if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}') return _commit_hash - -def patch_correct_tag(llm, config, _revision=None): +def patch_correct_tag(llm, config, _revision=None) -> None: # NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub - if llm.revision is not None: - return + if llm.revision is not None: return if not llm.local: try: - if _revision is None: - _revision = get_hash(config) + if _revision is None: _revision = get_hash(config) except ValueError: pass - if _revision is None and llm.tag.version is not None: - _revision = llm.tag.version + if _revision is None and llm.tag.version is not None: _revision = llm.tag.version if llm.tag.version is None: - # HACK: This copies the correct revision into llm.tag - _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) - if llm._revision is None: - _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version - + _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag + if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version def _create_metadata(llm, config, safe_serialisation, trust_remote_code, metadata=None): - if metadata is None: - metadata = {} + if metadata is None: metadata = {} metadata.update({'safe_serialisation': safe_serialisation, '_framework': llm.__llm_backend__}) - if llm.quantise: - metadata['_quantize'] = llm.quantise + if llm.quantise: metadata['_quantize'] = llm.quantise architectures = getattr(config, 'architectures', []) if not architectures: if trust_remote_code: auto_map = getattr(config, 'auto_map', {}) - if not auto_map: - raise RuntimeError( - f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}' - ) + if not auto_map: raise RuntimeError(f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}') autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM' if autoclass not in auto_map: - raise RuntimeError( - f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. OpenLLM currently only support encoder-decoders or decoders only models." - ) + raise RuntimeError(f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. 
OpenLLM currently only support encoder-decoders or decoders only models.") architectures = [auto_map[autoclass]] else: - raise RuntimeError( - 'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`' - ) - metadata.update( - {'_pretrained_class': architectures[0], '_revision': get_hash(config) if not llm.local else llm.revision} - ) + raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`') + metadata.update({'_pretrained_class': architectures[0], '_revision': get_hash(config) if not llm.local else llm.revision}) return metadata - def _create_signatures(llm, signatures=None): - if signatures is None: - signatures = {} + if signatures is None: signatures = {} if llm.__llm_backend__ == 'pt': if llm.quantise == 'gptq': if not is_autogptq_available(): - raise OpenLLMException( - "GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'" - ) + raise OpenLLMException("Requires 'auto-gptq' and 'optimum'. Install it with 'pip install \"openllm[gptq]\"'") signatures['generate'] = {'batchable': False} else: signatures.update( { k: ModelSignature(batchable=False) for k in ( - '__call__', - 'forward', - 'generate', - 'contrastive_search', - 'greedy_search', - 'sample', - 'beam_search', - 'beam_sample', - 'group_beam_search', - 'constrained_beam_search', + '__call__', 'forward', 'generate', # + 'contrastive_search', 'greedy_search', # + 'sample', 'beam_search', 'beam_sample', # + 'group_beam_search', 'constrained_beam_search', # ) } ) @@ -104,37 +70,24 @@ def _create_signatures(llm, signatures=None): else: non_batch_keys = set() batch_keys = { - 'async_generate_tokens', - 'forward_batch', - 'generate_batch', - 'generate_iterable', - 'generate_tokens', - 'score_batch', - 'score_iterable', + 'async_generate_tokens', 'forward_batch', 'generate_batch', # + 'generate_iterable', 'generate_tokens', 'score_batch', 'score_iterable', # } signatures.update({k: ModelSignature(batchable=False) for k in non_batch_keys}) signatures.update({k: ModelSignature(batchable=True) for k in batch_keys}) return signatures - @inject @contextlib.contextmanager def save_model( - llm, - config, - safe_serialisation, - trust_remote_code, - module, - external_modules, - _model_store=Provide[BentoMLContainer.model_store], - _api_version='v2.1.0', + llm, config, safe_serialisation, # + trust_remote_code, module, external_modules, # + _model_store=Provide[BentoMLContainer.model_store], _api_version='v2.1.0', # ): imported_modules = [] bentomodel = bentoml.Model.create( - llm.tag, - module=f'openllm.serialisation.{module}', - api_version=_api_version, - options=ModelOptions(), + llm.tag, module=f'openllm.serialisation.{module}', # + api_version=_api_version, options=ModelOptions(), # context=openllm.utils.generate_context('openllm'), labels=openllm.utils.generate_labels(llm), metadata=_create_metadata(llm, config, safe_serialisation, trust_remote_code), diff --git a/openllm-python/src/openllm/serialisation/constants.py b/openllm-python/src/openllm/serialisation/constants.py index 247e8a72..88614513 100644 --- a/openllm-python/src/openllm/serialisation/constants.py +++ b/openllm-python/src/openllm/serialisation/constants.py @@ -1,13 +1,7 @@ HUB_ATTRS = [ - 'cache_dir', - 'code_revision', - 
'force_download', - 'local_files_only', - 'proxies', - 'resume_download', - 'revision', - 'subfolder', - 'use_auth_token', + 'cache_dir', 'code_revision', 'force_download', # + 'local_files_only', 'proxies', 'resume_download', # + 'revision', 'subfolder', 'use_auth_token', # ] CONFIG_FILE_NAME = 'config.json' # the below is similar to peft.utils.other.CONFIG_NAME diff --git a/openllm-python/src/openllm/serialisation/transformers/__init__.py b/openllm-python/src/openllm/serialisation/transformers/__init__.py index f2b40cfb..20d917af 100644 --- a/openllm-python/src/openllm/serialisation/transformers/__init__.py +++ b/openllm-python/src/openllm/serialisation/transformers/__init__.py @@ -15,18 +15,15 @@ logger = logging.getLogger(__name__) __all__ = ['import_model', 'get', 'load_model'] _object_setattr = object.__setattr__ - def import_model(llm, *decls, trust_remote_code, **attrs): (_base_decls, _base_attrs), tokenizer_attrs = llm.llm_parameters decls = (*_base_decls, *decls) attrs = {**_base_attrs, **attrs} - if llm._local: - logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.') + if llm._local: logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.') config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs) patch_correct_tag(llm, config) safe_serialisation = first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors') - if llm.quantise != 'gptq': - attrs['use_safetensors'] = safe_serialisation + if llm.quantise != 'gptq': attrs['use_safetensors'] = safe_serialisation model = None tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs) @@ -39,7 +36,6 @@ def import_model(llm, *decls, trust_remote_code, **attrs): attrs['quantization_config'] = llm.quantization_config if llm.quantise == 'gptq': from optimum.gptq.constants import GPTQ_CONFIG - with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f: f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()) if llm._local: # possible local path @@ -56,27 +52,21 @@ def import_model(llm, *decls, trust_remote_code, **attrs): bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules) model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation) del model - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if torch.cuda.is_available(): torch.cuda.empty_cache() else: # we will clone the all tings into the bentomodel path without loading model into memory snapshot_download( - llm.model_id, - local_dir=bentomodel.path, - local_dir_use_symlinks=False, - ignore_patterns=HfIgnore.ignore_patterns(llm), + llm.model_id, local_dir=bentomodel.path, # + local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm), # ) return bentomodel - def get(llm): try: model = bentoml.models.get(llm.tag) backend = model.info.labels['backend'] if backend != llm.__llm_backend__: - raise OpenLLMException( - f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'." 
- ) + raise OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.") patch_correct_tag( llm, transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code), @@ -89,18 +79,15 @@ def get(llm): def check_unintialised_params(model): unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')] - if len(unintialized) > 0: - raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}') + if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}') def load_model(llm, *decls, **attrs): - if llm.quantise in {'awq', 'squeezellm'}: - raise RuntimeError('AWQ is not yet supported with PyTorch backend.') + if llm.quantise in {'awq', 'squeezellm'}: raise RuntimeError('AWQ is not yet supported with PyTorch backend.') config, attrs = transformers.AutoConfig.from_pretrained( llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs ) - if llm.__llm_backend__ == 'triton': - return openllm.models.load_model(llm, config, **attrs) + if llm.__llm_backend__ == 'triton': return openllm.models.load_model(llm, config, **attrs) auto_class = infer_autoclass_from_llm(llm, config) device_map = attrs.pop('device_map', None) @@ -152,6 +139,5 @@ def load_model(llm, *decls, **attrs): device_map=device_map, **attrs, ) - check_unintialised_params(model) return model diff --git a/openllm-python/src/openllm/serialisation/transformers/_helpers.py b/openllm-python/src/openllm/serialisation/transformers/_helpers.py index 3ab49df5..f77d639a 100644 --- a/openllm-python/src/openllm/serialisation/transformers/_helpers.py +++ b/openllm-python/src/openllm/serialisation/transformers/_helpers.py @@ -1,45 +1,28 @@ -from __future__ import annotations -import copy -import logging -import typing as t - +import copy, logging import transformers - from openllm.serialisation.constants import HUB_ATTRS logger = logging.getLogger(__name__) - - def get_tokenizer(model_id_or_path, trust_remote_code, **attrs): - tokenizer = transformers.AutoTokenizer.from_pretrained( - model_id_or_path, trust_remote_code=trust_remote_code, **attrs - ) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + tokenizer = transformers.AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code, **attrs) + if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer - - -def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any): +def process_config(model_id, trust_remote_code, **attrs): config = attrs.pop('config', None) # this logic below is synonymous to handling `from_pretrained` attrs. 
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs} if not isinstance(config, transformers.PretrainedConfig): copied_attrs = copy.deepcopy(attrs) - if copied_attrs.get('torch_dtype', None) == 'auto': - copied_attrs.pop('torch_dtype') + if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype') config, attrs = transformers.AutoConfig.from_pretrained( model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs ) return config, hub_attrs, attrs - - def infer_autoclass_from_llm(llm, config, /): autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM' if llm.trust_remote_code: if not hasattr(config, 'auto_map'): - raise ValueError( - f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping' - ) + raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping') # in case this model doesn't use the correct auto class for model type, for example like chatglm # where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel if autoclass not in config.auto_map: diff --git a/openllm-python/src/openllm/serialisation/transformers/weights.py b/openllm-python/src/openllm/serialisation/transformers/weights.py index 8f0016cb..bda9fa7c 100644 --- a/openllm-python/src/openllm/serialisation/transformers/weights.py +++ b/openllm-python/src/openllm/serialisation/transformers/weights.py @@ -1,30 +1,22 @@ from __future__ import annotations -import traceback -import typing as t -from pathlib import Path - -import attr +import attr, traceback, pathlib, typing as t from huggingface_hub import HfApi - from openllm_core.exceptions import Error from openllm_core.utils import resolve_filepath, validate_is_path if t.TYPE_CHECKING: from huggingface_hub.hf_api import ModelInfo as HfModelInfo - import openllm __global_inst__ = None __cached_id__: dict[str, HfModelInfo] = dict() - def Client() -> HfApi: global __global_inst__ # noqa: PLW0603 if __global_inst__ is None: __global_inst__ = HfApi() return __global_inst__ - def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo: if model_id in __cached_id__: return __cached_id__[model_id] @@ -35,13 +27,11 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo: traceback.print_exc() raise Error(f'Failed to fetch {model_id} from huggingface.co') from err - def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool: if validate_is_path(model_id): - return next((True for _ in Path(resolve_filepath(model_id)).glob('*.safetensors')), False) + return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob('*.safetensors')), False) return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings) - @attr.define(slots=True) class HfIgnore: safetensors = '*.safetensors' @@ -49,7 +39,6 @@ class HfIgnore: tf = '*.h5' flax = '*.msgpack' gguf = '*.gguf' - @classmethod def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]: if llm.__llm_backend__ in {'vllm', 'pt'}: