fix(awq): correct awq detection for support (#586)

* fix(awq): correct detection for awq

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update base docker to work

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: disable awq on pytorch for now

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-08 06:57:11 -05:00
committed by GitHub
parent 655a4071de
commit ff8b6377c8
7 changed files with 44 additions and 77 deletions

View File

@@ -13,6 +13,8 @@ from simple_di import inject
import bentoml
import openllm
import torch
import transformers
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelOptions
@@ -20,7 +22,6 @@ from bentoml._internal.models.model import ModelSignature
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
from ._helpers import check_unintialised_params
from ._helpers import get_hash
from ._helpers import infer_autoclass_from_llm
from ._helpers import process_config
@@ -28,18 +29,8 @@ from .weights import HfIgnore
if t.TYPE_CHECKING:
import types
import auto_gptq as autogptq
import torch
import torch.nn
import transformers
from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny
else:
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
torch = openllm.utils.LazyLoader('torch', globals(), 'torch')
logger = logging.getLogger(__name__)
@@ -48,34 +39,19 @@ _object_setattr = object.__setattr__
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None) -> None:
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
if llm.revision is not None: return
if not llm._local:
try:
if _revision is None: _revision = get_hash(config)
except ValueError:
pass
if _revision is None and llm.tag.version is not None: _revision = llm.tag.version
if llm._tag.version is None: _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
else: _revision = llm._tag.version
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
@inject
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
"""Auto detect model type from given model_id and import it to bentoml's model store.
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
returning all of the unused kwargs.
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
Refer to Transformers documentation for more information about kwargs.
Args:
llm: The LLM instance for this given model.
trust_remote_code: Whether to trust the remote code when loading the model.
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
"""
if llm._quantise == 'awq': raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
@@ -150,53 +126,37 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
return bentomodel
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
"""
try:
model = bentoml.models.get(llm.tag)
backend = model.info.labels['backend']
if backend != llm.__llm_backend__: raise openllm.exceptions.OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
_patch_correct_tag(llm, process_config(model.path, llm.trust_remote_code)[0], _revision=t.cast(t.Optional[str], model.info.metadata.get('_revision')))
_patch_correct_tag(llm, transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code), _revision=model.info.metadata.get('_revision'))
return model
except Exception as err:
if auto_import: return import_model(llm, trust_remote_code=llm.trust_remote_code)
raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
config, hub_attrs, attrs = process_config(llm.bentomodel.path, llm.trust_remote_code, **attrs)
_patch_correct_tag(llm, config, _revision=t.cast(t.Optional[str], llm.bentomodel.info.metadata.get('_revision')))
if llm._quantise == 'awq': raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs)
auto_class = infer_autoclass_from_llm(llm, config)
device_map: str | None = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
if llm._quantise or llm._quantization_config: attrs['quantization_config'] = llm.quantization_config
device_map = attrs.pop('device_map', None)
if torch.cuda.is_available():
if torch.cuda.device_count() > 1: device_map = 'auto'
elif torch.cuda.device_count() == 1: device_map = 'cuda:0'
if llm._quantise in {'int8', 'int4'}: attrs['quantization_config'] = llm.quantization_config
if '_quantize' in llm.bentomodel.info.metadata and llm.bentomodel.info.metadata['_quantize'] == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if '_quantize' in llm.bentomodel.info.metadata:
_quantise = llm.bentomodel.info.metadata['_quantize']
if _quantise == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
try:
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=True, **hub_attrs, **attrs)
except Exception as err:
logger.debug("Exception caught while trying to load with 'flash_attention_2': %s", err)
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=False, **hub_attrs, **attrs)
# XXX: Use the below logic once TheBloke finished migration to new GPTQConfig from transformers
# Seems like the logic below requires to add support for safetensors on accelerate
# from accelerate import init_empty_weights
# from optimum.gptq import load_quantized_model
# # disable exllama if gptq is loaded on CPU
# disable_exllama = not torch.cuda.is_available()
# with init_empty_weights():
# empty = auto_class.from_pretrained(llm.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map='auto')
# empty.tie_weights()
# model = load_quantized_model(empty, save_folder=llm._bentomodel.path, device_map='auto', disable_exllama=disable_exllama)
# TODO: investigate load with flash attention
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
else:
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
if llm.__llm_backend__ == 'pt': check_unintialised_params(model)
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **attrs)
return t.cast('M', model)

View File

@@ -1,24 +1,21 @@
from __future__ import annotations
import copy
import copy, re
from pathlib import Path
import typing as t
import openllm
import openllm_core
import transformers
import torch
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from openllm.serialisation.constants import HUB_ATTRS
if t.TYPE_CHECKING:
import torch
import transformers
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
else:
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
def get_hash(config: transformers.PretrainedConfig) -> str:
_commit_hash = getattr(config, '_commit_hash', None)