mirror of https://github.com/bentoml/OpenLLM.git
fix(gptq): use upstream integration (#297)
* wip
* feat: GPTQ transformers integration
* fix: only load if variable is available and add changelog
* chore: remove boilerplate check

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
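In short, the fix drops the hand-rolled auto_gptq.AutoGPTQForCausalLM code path in favour of the GPTQ support that ships upstream in transformers/optimum. As a rough illustration (not part of this commit), loading a GPTQ checkpoint through that upstream integration looks roughly like the sketch below; it assumes transformers>=4.32 with optimum, auto-gptq and accelerate installed, and the model id is only a placeholder example:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'TheBloke/Llama-2-7B-Chat-GPTQ'  # placeholder GPTQ checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
# transformers picks up the quantization_config stored in the checkpoint's config.json,
# so no direct auto_gptq call is needed on the loading side.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', torch_dtype=torch.float16)
inputs = tokenizer('Hello, world', return_tensors='pt').to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0], skip_special_tokens=True))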
@@ -37,6 +37,7 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
from .transformers._helpers import process_config

config, *_ = process_config(llm._bentomodel.path, llm.trust_remote_code)

bentomodel_fs = fs.open_fs(llm._bentomodel.path)
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
@@ -14,13 +14,14 @@ import openllm

from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.models.model import ModelOptions
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T

from ._helpers import check_unintialised_params
from ._helpers import infer_autoclass_from_llm
from ._helpers import infer_tokenizers_from_llm
from ._helpers import make_model_signatures
from ._helpers import process_config
from ._helpers import update_model
from .weights import HfIgnore

if t.TYPE_CHECKING:
@@ -32,8 +33,6 @@ if t.TYPE_CHECKING:

from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import M
from openllm_core._typing_compat import T
else:
autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
torch = openllm.utils.LazyLoader('torch', globals(), 'torch')
@@ -63,16 +62,23 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
"""
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_, tokenizer_attrs = llm.llm_parameters
quantize_method = llm._quantize_method
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation_format == 'safetensors')
quantize = llm._quantize
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
# Disable safe serialization with vLLM
if llm.__llm_backend__ == 'vllm': safe_serialisation = False
metadata: DictStrAny = {'safe_serialisation': safe_serialisation, '_quantize': quantize_method is not None and quantize_method}
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
if quantize: metadata['_quantize'] = quantize
architectures = getattr(config, 'architectures', [])
if not architectures: raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
metadata['_pretrained_class'] = architectures[0]

signatures: DictStrAny = {}

if quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if quantize == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm':
raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
signatures['generate'] = {'batchable': False}
@@ -82,7 +88,8 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
if 'quantization_config' in attrs and getattr(attrs['quantization_config'], 'load_in_4bit', False):
attrs.pop('quantization_config')
if llm.__llm_backend__ != 'flax': attrs['use_safetensors'] = safe_serialisation
metadata['_framework'] = 'pt' if llm.__llm_backend__ == 'vllm' else llm.__llm_backend__
metadata['_framework'] = llm.__llm_backend__
signatures.update(make_model_signatures(llm))

tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
@@ -95,42 +102,22 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
options=ModelOptions(),
context=openllm.utils.generate_context(framework_name='openllm'),
labels=openllm.utils.generate_labels(llm),
signatures=signatures if signatures else make_model_signatures(llm))
metadata=metadata,
signatures=signatures)
with openllm.utils.analytics.set_bentoml_tracking():
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
tokenizer.save_pretrained(bentomodel.path)
if quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config['model_type'] != 'causal_lm':
raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
logger.debug('Saving model with GPTQ quantisation will require loading model into memory.')
model = autogptq.AutoGPTQForCausalLM.from_quantized(llm.model_id,
*decls,
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
trust_remote_code=trust_remote_code,
use_safetensors=safe_serialisation,
**hub_attrs,
**attrs)
update_model(bentomodel, metadata={'_pretrained_class': model.__class__.__name__, '_framework': model.model.framework})
model.save_quantized(bentomodel.path, use_safetensors=safe_serialisation)
if llm._local:
# possible local path
logger.debug('Model will be loaded into memory to save to target store as it is from local path.')
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
else:
architectures = getattr(config, 'architectures', [])
if not architectures:
raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
architecture = architectures[0]
update_model(bentomodel, metadata={'_pretrained_class': architecture})
if llm._local:
# possible local path
logger.debug('Model will be loaded into memory to save to target store as it is from local path.')
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
else:
# we will clone the all tings into the bentomodel path without loading model into memory
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
# we will clone the all tings into the bentomodel path without loading model into memory
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
except Exception:
raise
else:
@@ -165,29 +152,27 @@ def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:

def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
config, hub_attrs, attrs = process_config(llm.model_id, llm.trust_remote_code, **attrs)
safe_serialization = openllm.utils.first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)),
attrs.pop('safe_serialization', None),
default=llm._serialisation_format == 'safetensors')
if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config['model_type'] != 'causal_lm':
raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
return autogptq.AutoGPTQForCausalLM.from_quantized(llm._bentomodel.path,
*decls,
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
trust_remote_code=llm.trust_remote_code,
use_safetensors=safe_serialization,
**hub_attrs,
**attrs)
auto_class = infer_autoclass_from_llm(llm, config)
device_map: str | None = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)

device_map = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm._bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
**hub_attrs,
**attrs).eval()
if llm.__llm_backend__ in {'pt', 'vllm'}: check_unintialised_params(model)
if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
raise openllm.exceptions.OpenLLMException(
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
)
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")

model = auto_class.from_pretrained(llm._bentomodel.path, device_map='auto', **hub_attrs, **attrs)
# TODO: Use the below logic once TheBloke finished migration to new GPTQConfig from transformers
# from accelerate import init_empty_weights
# from optimum.gptq import load_quantized_model
# # disable exllama if gptq is loaded on CPU
# disable_exllama = not torch.cuda.is_available()
# with init_empty_weights():
# empty = auto_class.from_pretrained(llm.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map='auto')
# empty.tie_weights()
# model = load_quantized_model(empty, save_folder=llm._bentomodel.path, device_map='auto', disable_exllama=disable_exllama)
else:
model = auto_class.from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **hub_attrs, **attrs).eval()
if llm.__llm_backend__ in {'pt', 'vllm'}: check_unintialised_params(model)
return t.cast('M', model)
@@ -5,7 +5,6 @@ import typing as t
import openllm
import openllm_core

from bentoml._internal.models.model import ModelInfo
from bentoml._internal.models.model import ModelSignature
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
from openllm.serialisation.constants import HUB_ATTRS
@@ -16,8 +15,6 @@ if t.TYPE_CHECKING:

from transformers.models.auto.auto_factory import _BaseAutoModelClass

import bentoml

from bentoml._internal.models.model import ModelSignaturesType
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import M
@@ -25,8 +22,6 @@ if t.TYPE_CHECKING:
else:
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')

_object_setattr = object.__setattr__

def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
'''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
@@ -73,24 +68,6 @@ def check_unintialised_params(model: torch.nn.Module) -> None:
if len(unintialized) > 0:
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')

def update_model(bentomodel: bentoml.Model, metadata: DictStrAny) -> bentoml.Model:
based: DictStrAny = copy.deepcopy(bentomodel.info.metadata)
based.update(metadata)
_object_setattr(
bentomodel,
'_info',
ModelInfo( # type: ignore[call-arg] # XXX: remove me once upstream is merged
tag=bentomodel.info.tag,
module=bentomodel.info.module,
labels=bentomodel.info.labels,
options=bentomodel.info.options.to_dict(),
signatures=bentomodel.info.signatures,
context=bentomodel.info.context,
api_version=bentomodel.info.api_version,
creation_time=bentomodel.info.creation_time,
metadata=based))
return bentomodel

# NOTE: sync with bentoml/_internal/frameworks/transformers.py#make_default_signatures
def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
infer_fn: tuple[str, ...] = ('__call__',)