chore(llm): expose quantise and lazy load heavy imports (#617)

* chore(llm): expose quantise and lazy load heavy imports

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: move transformers to TYPE_CHECKING block

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-12 14:55:37 -05:00
committed by GitHub
parent 106e8617c1
commit 7e1fb35a71
8 changed files with 189 additions and 116 deletions

View File

@@ -62,7 +62,7 @@ def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLCon
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_patch_correct_tag(llm, config)
_, tokenizer_attrs = llm.llm_parameters
quantize = llm._quantise
quantize = llm.quantise
safe_serialisation = openllm.utils.first_not_none(
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
)
@@ -132,7 +132,7 @@ def import_model(llm, *decls, trust_remote_code, _model_store=Provide[BentoMLCon
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
tokenizer.save_pretrained(bentomodel.path)
if llm._quantization_config or (llm._quantise and llm._quantise not in {'squeezellm', 'awq'}):
if llm._quantization_config or (llm.quantise and llm.quantise not in {'squeezellm', 'awq'}):
attrs['quantization_config'] = llm.quantization_config
if quantize == 'gptq':
from optimum.gptq.constants import GPTQ_CONFIG
@@ -205,7 +205,7 @@ def check_unintialised_params(model):
def load_model(llm, *decls, **attrs):
if llm._quantise in {'awq', 'squeezellm'}:
if llm.quantise in {'awq', 'squeezellm'}:
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
config, attrs = transformers.AutoConfig.from_pretrained(
llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs
@@ -217,7 +217,7 @@ def load_model(llm, *decls, **attrs):
device_map = 'auto'
elif torch.cuda.device_count() == 1:
device_map = 'cuda:0'
if llm._quantise in {'int8', 'int4'}:
if llm.quantise in {'int8', 'int4'}:
attrs['quantization_config'] = llm.quantization_config
if '_quantize' in llm.bentomodel.info.metadata: