diff --git a/openllm-python/src/openllm/_quantisation.py b/openllm-python/src/openllm/_quantisation.py
index f7f7e679..326cb663 100644
--- a/openllm-python/src/openllm/_quantisation.py
+++ b/openllm-python/src/openllm/_quantisation.py
@@ -62,7 +62,9 @@ def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise
     gptq_module_name_preceding_first_block = attrs.pop('module_name_preceding_first_block', None)
     gptq_batch_size = attrs.pop('batch_size', 1)
     gptq_pad_token_id = attrs.pop('pad_token_id', None)
-    gptq_disable_exllama = attrs.pop('disable_exllama', False)
+    disable_exllama = attrs.pop('disable_exllama', False)  # backward compatibility
+    gptq_use_exllama = attrs.pop('use_exllama', True)
+    if disable_exllama: gptq_use_exllama = False
     return transformers.GPTQConfig(bits=bits,
                                    tokenizer=gptq_tokenizer,
                                    dataset=gptq_dataset,
@@ -77,7 +79,8 @@ def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise
                                    module_name_preceding_first_block=gptq_module_name_preceding_first_block,
                                    batch_size=gptq_batch_size,
                                    pad_token_id=gptq_pad_token_id,
-                                   disable_exllama=gptq_disable_exllama)
+                                   use_exllama=gptq_use_exllama,
+                                   exllama_config={'version': 1})  # XXX: See how to migrate to v2
 
   def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
     # if int8_skip_modules is None: int8_skip_modules = []
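
The sketch below (not part of the patch) illustrates the deprecation this diff handles: recent transformers versions replace GPTQConfig's `disable_exllama` flag with `use_exllama` plus an `exllama_config` dict, so the old boolean is translated into the new one before constructing the config. The `attrs` dict and `build_gptq_config` helper are hypothetical illustrations, not code from this repository.

import transformers

def build_gptq_config(attrs: dict) -> transformers.GPTQConfig:
  # Old callers may still pass `disable_exllama`; honour it for backward compatibility.
  disable_exllama = attrs.pop('disable_exllama', False)
  use_exllama = attrs.pop('use_exllama', True)
  if disable_exllama:
    use_exllama = False
  return transformers.GPTQConfig(
    bits=4,
    use_exllama=use_exllama,
    exllama_config={'version': 1},  # pinned to exllama v1; v2 migration is still open, per the XXX above
  )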