From ff8b6377c8e6d38b1d6cf0d19ea7cdf3ba57734d Mon Sep 17 00:00:00 2001
From: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Date: Wed, 8 Nov 2023 06:57:11 -0500
Subject: [PATCH] fix(awq): correct awq detection for support (#586)

* fix(awq): correct detection for awq

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update base docker to work

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: disable awq on pytorch for now

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../src/openllm_core/utils/import_utils.py    |  7 +-
 openllm-python/src/openllm/_llm.py            |  8 +-
 .../src/openllm/bundle/oci/Dockerfile         |  3 +
 .../serialisation/transformers/__init__.py    | 88 +++++--------------
 .../serialisation/transformers/_helpers.py    | 11 +--
 openllm-python/tests/conftest.py              |  2 +-
 tools/dependencies.py                         |  2 +-
 7 files changed, 44 insertions(+), 77 deletions(-)

diff --git a/openllm-core/src/openllm_core/utils/import_utils.py b/openllm-core/src/openllm_core/utils/import_utils.py
index 15fc2f2e..679976f1 100644
--- a/openllm-core/src/openllm_core/utils/import_utils.py
+++ b/openllm-core/src/openllm_core/utils/import_utils.py
@@ -54,7 +54,7 @@ _jupyter_available = _is_package_available('jupyter')
 _jupytext_available = _is_package_available('jupytext')
 _notebook_available = _is_package_available('notebook')
 _autogptq_available = _is_package_available('auto_gptq')
-_autoawq_available = _is_package_available('awq')
+_autoawq_available = importlib.util.find_spec('awq') is not None
 _sentencepiece_available = _is_package_available('sentencepiece')
 _xformers_available = _is_package_available('xformers')
 _fairscale_available = _is_package_available('fairscale')
@@ -124,6 +124,11 @@ def is_torch_available() -> bool:
   return _torch_available
 
 def is_autoawq_available() -> bool:
+  global _autoawq_available
+  try:
+    importlib.metadata.version('autoawq')
+  except importlib.metadata.PackageNotFoundError:
+    _autoawq_available = False
   return _autoawq_available
 
 def is_vllm_available() -> bool:
diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index c22161b4..2e390ace 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -132,6 +132,7 @@ class LLM(t.Generic[M, T]):
   __llm_model__: t.Optional[M] = None
   __llm_tokenizer__: t.Optional[T] = None
   __llm_adapter_map__: t.Optional[ResolvedAdapterMap] = None
+  __llm_trust_remote_code__: bool = False
 
   device: 'torch.device | None' = None
 
@@ -151,6 +152,7 @@ class LLM(t.Generic[M, T]):
                quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None,
                adapter_map: dict[str, str] | None = None,
                serialisation: LiteralSerialisation = 'safetensors',
+               trust_remote_code: bool = False,
                **attrs: t.Any):
     # low_cpu_mem_usage is only available for model this is helpful on system with low memory to avoid OOM
     low_cpu_mem_usage = attrs.pop('low_cpu_mem_usage', True)
@@ -186,8 +188,8 @@ class LLM(t.Generic[M, T]):
                      prompt_template=prompt_template,
                      system_message=system_message,
                      llm_backend__=backend,
-                     llm_config__=llm_config)
-    self.__llm_backend__ = backend
+                     llm_config__=llm_config,
+                     llm_trust_remote_code__=trust_remote_code)
 
   @apply(lambda val: tuple(str.lower(i) if i else i for i in val))
   def _make_tag_components(self, model_id: str, model_version: str | None, backend: LiteralBackend) -> tuple[str, str | None]:
@@ -206,7 +208,7 @@ class LLM(t.Generic[M, T]):
   @property
   def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
   @property
-  def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.config['trust_remote_code'])
+  def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.__llm_trust_remote_code__)
   @property
   def runner_name(self)->str:return f"llm-{self.config['start_name']}-runner"
   @property
diff --git a/openllm-python/src/openllm/bundle/oci/Dockerfile b/openllm-python/src/openllm/bundle/oci/Dockerfile
index 946974a0..fc2be86e 100644
--- a/openllm-python/src/openllm/bundle/oci/Dockerfile
+++ b/openllm-python/src/openllm/bundle/oci/Dockerfile
@@ -91,6 +91,9 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     rm -rf /var/lib/apt/lists/*
 
 # Install all required dependencies
+# We have to install autoawq first to avoid conflict with torch, then reinstall torch with vllm
+# below
+# pip install autoawq --no-cache-dir && \
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install --extra-index-url "https://download.pytorch.org/whl/cu118" \
     --extra-index-url "https://huggingface.github.io/autogptq-index/whl/cu118/" \
diff --git a/openllm-python/src/openllm/serialisation/transformers/__init__.py b/openllm-python/src/openllm/serialisation/transformers/__init__.py
index e2a8c9c6..4a15fed3 100644
--- a/openllm-python/src/openllm/serialisation/transformers/__init__.py
+++ b/openllm-python/src/openllm/serialisation/transformers/__init__.py
@@ -13,6 +13,8 @@ from simple_di import inject
 
 import bentoml
 import openllm
+import torch
+import transformers
 
 from bentoml._internal.configuration.containers import BentoMLContainer
 from bentoml._internal.models.model import ModelOptions
@@ -20,7 +22,6 @@ from bentoml._internal.models.model import ModelSignature
 from openllm_core._typing_compat import M
 from openllm_core._typing_compat import T
 
-from ._helpers import check_unintialised_params
 from ._helpers import get_hash
 from ._helpers import infer_autoclass_from_llm
 from ._helpers import process_config
@@ -28,18 +29,8 @@ from .weights import HfIgnore
 
 if t.TYPE_CHECKING:
   import types
-
-  import auto_gptq as autogptq
-  import torch
-  import torch.nn
-  import transformers
-
   from bentoml._internal.models import ModelStore
   from openllm_core._typing_compat import DictStrAny
-else:
-  transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
-  autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
-  torch = openllm.utils.LazyLoader('torch', globals(), 'torch')
 
 logger = logging.getLogger(__name__)
 
@@ -48,34 +39,19 @@ _object_setattr = object.__setattr__
 
 def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None) -> None:
   # NOTE: The following won't hit during local since we generated a correct version based on local path hash
   # It will only hit if we use model from HF Hub
+  if llm.revision is not None: return
   if not llm._local:
     try:
       if _revision is None: _revision = get_hash(config)
     except ValueError: pass
+    if _revision is None and llm.tag.version is not None: _revision = llm.tag.version
     if llm._tag.version is None: _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision))  # HACK: This copies the correct revision into llm.tag
-  else: _revision = llm._tag.version
   if llm._revision is None: _object_setattr(llm, '_revision', _revision)  # HACK: This copies the correct revision into llm._model_version
 
 @inject
 def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
-  """Auto detect model type from given model_id and import it to bentoml's model store.
-
-  For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
-  returning all of the unused kwargs.
-  The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
-  For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
-
-  Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
-
-  Refer to Transformers documentation for more information about kwargs.
-
-  Args:
-      llm: The LLM instance for this given model.
-      trust_remote_code: Whether to trust the remote code when loading the model.
-      *decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
-      **attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
-  """
+  if llm._quantise == 'awq': raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
   config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
   _patch_correct_tag(llm, config)
   _, tokenizer_attrs = llm.llm_parameters
@@ -150,53 +126,37 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
   return bentomodel
 
 def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
-  """Return an instance of ``bentoml.Model`` from given LLM instance.
-
-  By default, it will try to check the model in the local store.
-  If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
-
-  Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
-  """
   try:
     model = bentoml.models.get(llm.tag)
     backend = model.info.labels['backend']
     if backend != llm.__llm_backend__: raise openllm.exceptions.OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
-    _patch_correct_tag(llm, process_config(model.path, llm.trust_remote_code)[0], _revision=t.cast(t.Optional[str], model.info.metadata.get('_revision')))
+    _patch_correct_tag(llm, transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code), _revision=model.info.metadata.get('_revision'))
     return model
   except Exception as err:
     if auto_import: return import_model(llm, trust_remote_code=llm.trust_remote_code)
     raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
 
 def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
-  config, hub_attrs, attrs = process_config(llm.bentomodel.path, llm.trust_remote_code, **attrs)
-  _patch_correct_tag(llm, config, _revision=t.cast(t.Optional[str], llm.bentomodel.info.metadata.get('_revision')))
+  if llm._quantise == 'awq': raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
+  config, attrs = transformers.AutoConfig.from_pretrained(llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs)
   auto_class = infer_autoclass_from_llm(llm, config)
-  device_map: str | None = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
-  if llm._quantise or llm._quantization_config: attrs['quantization_config'] = llm.quantization_config
+  device_map = attrs.pop('device_map', None)
+  if torch.cuda.is_available():
+    if torch.cuda.device_count() > 1: device_map = 'auto'
+    elif torch.cuda.device_count() == 1: device_map = 'cuda:0'
+  if llm._quantise in {'int8', 'int4'}: attrs['quantization_config'] = llm.quantization_config
 
-  if '_quantize' in llm.bentomodel.info.metadata and llm.bentomodel.info.metadata['_quantize'] == 'gptq':
-    if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
-      raise openllm.exceptions.OpenLLMException(
-          "GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
-      )
-    if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
+  if '_quantize' in llm.bentomodel.info.metadata:
+    _quantise = llm.bentomodel.info.metadata['_quantize']
+    if _quantise == 'gptq':
+      if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
+        raise openllm.exceptions.OpenLLMException(
+            "GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
+        )
+      if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
 
-    try:
-      model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=True, **hub_attrs, **attrs)
-    except Exception as err:
-      logger.debug("Exception caught while trying to load with 'flash_attention_2': %s", err)
-      model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=False, **hub_attrs, **attrs)
-    # XXX: Use the below logic once TheBloke finished migration to new GPTQConfig from transformers
-    # Seems like the logic below requires to add support for safetensors on accelerate
-    # from accelerate import init_empty_weights
-    # from optimum.gptq import load_quantized_model
-    # # disable exllama if gptq is loaded on CPU
-    # disable_exllama = not torch.cuda.is_available()
-    # with init_empty_weights():
-    #   empty = auto_class.from_pretrained(llm.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map='auto')
-    #   empty.tie_weights()
-    #   model = load_quantized_model(empty, save_folder=llm._bentomodel.path, device_map='auto', disable_exllama=disable_exllama)
+    # TODO: investigate load with flash attention
+    model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
   else:
-    model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
-    if llm.__llm_backend__ == 'pt': check_unintialised_params(model)
+    model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **attrs)
   return t.cast('M', model)
diff --git a/openllm-python/src/openllm/serialisation/transformers/_helpers.py b/openllm-python/src/openllm/serialisation/transformers/_helpers.py
index ef2b1f3e..6b50052c 100644
--- a/openllm-python/src/openllm/serialisation/transformers/_helpers.py
+++ b/openllm-python/src/openllm/serialisation/transformers/_helpers.py
@@ -1,24 +1,21 @@
 from __future__ import annotations
-import copy
+import copy, re
+from pathlib import Path
 import typing as t
 
 import openllm
-import openllm_core
+import transformers
+import torch
 
 from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
 from openllm.serialisation.constants import HUB_ATTRS
 
 if t.TYPE_CHECKING:
-  import torch
-  import transformers
-
   from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
   from openllm_core._typing_compat import DictStrAny
   from openllm_core._typing_compat import M
   from openllm_core._typing_compat import T
-else:
-  transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
 
 def get_hash(config: transformers.PretrainedConfig) -> str:
   _commit_hash = getattr(config, '_commit_hash', None)
diff --git a/openllm-python/tests/conftest.py b/openllm-python/tests/conftest.py
index 3318a4ab..60fb13cd 100644
--- a/openllm-python/tests/conftest.py
+++ b/openllm-python/tests/conftest.py
@@ -13,7 +13,7 @@ if t.TYPE_CHECKING:
 
 _MODELING_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B'}
 _PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'}
-def parametrise_local_llm(model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
+def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
   if model not in _MODELING_MAPPING: pytest.skip(f"'{model}' is not yet supported in framework testing.")
   backends: tuple[LiteralBackend, ...] = ('pt',)
   for backend, prompt in itertools.product(backends, _PROMPT_MAPPING.keys()):
diff --git a/tools/dependencies.py b/tools/dependencies.py
index 8e44ce86..d71f0005 100755
--- a/tools/dependencies.py
+++ b/tools/dependencies.py
@@ -139,7 +139,7 @@ OPENAI_DEPS = ['openai[datalib]>=1', 'tiktoken']
 AGENTS_DEPS = ['transformers[agents]>=4.34.0', 'diffusers', 'soundfile']
 PLAYGROUND_DEPS = ['jupyter', 'notebook', 'ipython', 'jupytext', 'nbformat']
 GGML_DEPS = ['ctransformers']
-AWQ_DEPS = ["autoawq"]
+AWQ_DEPS = ['autoawq']
 GPTQ_DEPS = ['auto-gptq[triton]>=0.4.2', 'optimum>=1.12.0']
 VLLM_DEPS = ['vllm>=0.2.1post1', 'ray']
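
The heart of this fix is a naming mismatch: AutoAWQ is imported as the `awq` module, while its distribution metadata is registered under `autoawq`, so a metadata lookup keyed on 'awq' raises PackageNotFoundError even when the package is installed, which is presumably why the old `_is_package_available('awq')` check came up empty. The sketch below is an illustrative, self-contained recreation of that detection pattern using only the standard library; it is not the patched OpenLLM helper itself (which additionally caches the result in a module-level flag), and the standalone script framing is an assumption for demonstration.

# minimal sketch of the AWQ availability check this patch adopts (illustrative only)
from __future__ import annotations

import importlib.metadata
import importlib.util


def is_autoawq_available() -> bool:
  """Return True only when the `awq` module is importable and the `autoawq` distribution resolves."""
  # 1. Is the module importable at all? (import name: 'awq')
  if importlib.util.find_spec('awq') is None:
    return False
  # 2. Does the installed distribution expose metadata under its real name, 'autoawq'?
  #    (Looking up 'awq' here would raise PackageNotFoundError even when AutoAWQ is installed.)
  try:
    importlib.metadata.version('autoawq')
  except importlib.metadata.PackageNotFoundError:
    return False
  return True


if __name__ == '__main__':
  print('AutoAWQ available:', is_autoawq_available())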