fix(awq): correct awq detection for support (#586)

* fix(awq): correct detection for awq

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update base docker to work

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: disable awq on pytorch for now

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* ci: auto fixes from pre-commit.ci

For more information, see https://pre-commit.ci

---------

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-08 06:57:11 -05:00
committed by GitHub
parent 655a4071de
commit ff8b6377c8
7 changed files with 44 additions and 77 deletions

View File

@@ -132,6 +132,7 @@ class LLM(t.Generic[M, T]):
__llm_model__: t.Optional[M] = None
__llm_tokenizer__: t.Optional[T] = None
__llm_adapter_map__: t.Optional[ResolvedAdapterMap] = None
__llm_trust_remote_code__: bool = False
device: 'torch.device | None' = None
@@ -151,6 +152,7 @@ class LLM(t.Generic[M, T]):
quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None,
adapter_map: dict[str, str] | None = None,
serialisation: LiteralSerialisation = 'safetensors',
trust_remote_code: bool = False,
**attrs: t.Any):
# low_cpu_mem_usage is only available for model this is helpful on system with low memory to avoid OOM
low_cpu_mem_usage = attrs.pop('low_cpu_mem_usage', True)
@@ -186,8 +188,8 @@ class LLM(t.Generic[M, T]):
prompt_template=prompt_template,
system_message=system_message,
llm_backend__=backend,
llm_config__=llm_config)
self.__llm_backend__ = backend
llm_config__=llm_config,
llm_trust_remote_code__=trust_remote_code)
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
def _make_tag_components(self, model_id: str, model_version: str | None, backend: LiteralBackend) -> tuple[str, str | None]:
@@ -206,7 +208,7 @@ class LLM(t.Generic[M, T]):
@property
def import_kwargs(self)->tuple[dict[str, t.Any],dict[str, t.Any]]: return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {'padding_side': 'left', 'truncation_side': 'left'}
@property
def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.config['trust_remote_code'])
def trust_remote_code(self)->bool:return first_not_none(check_bool_env('TRUST_REMOTE_CODE',False),default=self.__llm_trust_remote_code__)
@property
def runner_name(self)->str:return f"llm-{self.config['start_name']}-runner"
@property

View File

@@ -91,6 +91,9 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
rm -rf /var/lib/apt/lists/*
# Install all required dependencies
# We have to install autoawq first to avoid conflict with torch, then reinstall torch with vllm
# below
# pip install autoawq --no-cache-dir && \
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --extra-index-url "https://download.pytorch.org/whl/cu118" \
--extra-index-url "https://huggingface.github.io/autogptq-index/whl/cu118/" \

View File

@@ -13,6 +13,8 @@ from simple_di import inject
import bentoml
import openllm
import torch
import transformers
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelOptions
@@ -20,7 +22,6 @@ from bentoml._internal.models.model import ModelSignature
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
from ._helpers import check_unintialised_params
from ._helpers import get_hash
from ._helpers import infer_autoclass_from_llm
from ._helpers import process_config
@@ -28,18 +29,8 @@ from .weights import HfIgnore
if t.TYPE_CHECKING:
import types
import auto_gptq as autogptq
import torch
import torch.nn
import transformers
from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny
else:
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
torch = openllm.utils.LazyLoader('torch', globals(), 'torch')
logger = logging.getLogger(__name__)
@@ -48,34 +39,19 @@ _object_setattr = object.__setattr__
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None) -> None:
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
if llm.revision is not None: return
if not llm._local:
try:
if _revision is None: _revision = get_hash(config)
except ValueError:
pass
if _revision is None and llm.tag.version is not None: _revision = llm.tag.version
if llm._tag.version is None: _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
else: _revision = llm._tag.version
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
@inject
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
"""Auto detect model type from given model_id and import it to bentoml's model store.
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
returning all of the unused kwargs.
The unused kwargs then parsed directly into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
For all tokenizer kwargs, make sure to prefix it with `_tokenizer_` to avoid confusion.
Note: Currently, there are only two tasks supported: `text-generation` and `text2text-generation`.
Refer to Transformers documentation for more information about kwargs.
Args:
llm: The LLM instance for this given model.
trust_remote_code: Whether to trust the remote code when loading the model.
*decls: Args to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
**attrs: Kwargs to be passed into AutoModelForSeq2SeqLM or AutoModelForCausalLM.
"""
if llm._quantise == 'awq': raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
@@ -150,53 +126,37 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
return bentomodel
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
"""
try:
model = bentoml.models.get(llm.tag)
backend = model.info.labels['backend']
if backend != llm.__llm_backend__: raise openllm.exceptions.OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
_patch_correct_tag(llm, process_config(model.path, llm.trust_remote_code)[0], _revision=t.cast(t.Optional[str], model.info.metadata.get('_revision')))
_patch_correct_tag(llm, transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code), _revision=model.info.metadata.get('_revision'))
return model
except Exception as err:
if auto_import: return import_model(llm, trust_remote_code=llm.trust_remote_code)
raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
config, hub_attrs, attrs = process_config(llm.bentomodel.path, llm.trust_remote_code, **attrs)
_patch_correct_tag(llm, config, _revision=t.cast(t.Optional[str], llm.bentomodel.info.metadata.get('_revision')))
if llm._quantise == 'awq': raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs)
auto_class = infer_autoclass_from_llm(llm, config)
device_map: str | None = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
if llm._quantise or llm._quantization_config: attrs['quantization_config'] = llm.quantization_config
device_map = attrs.pop('device_map', None)
if torch.cuda.is_available():
if torch.cuda.device_count() > 1: device_map = 'auto'
elif torch.cuda.device_count() == 1: device_map = 'cuda:0'
if llm._quantise in {'int8', 'int4'}: attrs['quantization_config'] = llm.quantization_config
if '_quantize' in llm.bentomodel.info.metadata and llm.bentomodel.info.metadata['_quantize'] == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if '_quantize' in llm.bentomodel.info.metadata:
_quantise = llm.bentomodel.info.metadata['_quantize']
if _quantise == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
try:
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=True, **hub_attrs, **attrs)
except Exception as err:
logger.debug("Exception caught while trying to load with 'flash_attention_2': %s", err)
model = auto_class.from_pretrained(llm.bentomodel.path, device_map='auto', use_flash_attention_2=False, **hub_attrs, **attrs)
# XXX: Use the below logic once TheBloke finished migration to new GPTQConfig from transformers
# Seems like the logic below requires to add support for safetensors on accelerate
# from accelerate import init_empty_weights
# from optimum.gptq import load_quantized_model
# # disable exllama if gptq is loaded on CPU
# disable_exllama = not torch.cuda.is_available()
# with init_empty_weights():
# empty = auto_class.from_pretrained(llm.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map='auto')
# empty.tie_weights()
# model = load_quantized_model(empty, save_folder=llm._bentomodel.path, device_map='auto', disable_exllama=disable_exllama)
# TODO: investigate load with flash attention
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
else:
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
if llm.__llm_backend__ == 'pt': check_unintialised_params(model)
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **attrs)
return t.cast('M', model)

View File

@@ -1,24 +1,21 @@
from __future__ import annotations
import copy
import copy, re
from pathlib import Path
import typing as t
import openllm
import openllm_core
import transformers
import torch
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from openllm.serialisation.constants import HUB_ATTRS
if t.TYPE_CHECKING:
import torch
import transformers
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
else:
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
def get_hash(config: transformers.PretrainedConfig) -> str:
_commit_hash = getattr(config, '_commit_hash', None)