infra: using ruff formatter (#594)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-09 12:44:05 -05:00
committed by GitHub
parent 021fd453b9
commit ac377fe490
102 changed files with 5577 additions and 2540 deletions

View File

@@ -14,6 +14,7 @@ from openllm_core.utils import is_autogptq_available
from openllm_core.utils import is_bitsandbytes_available
from openllm_core.utils import is_optimum_supports_gptq
if t.TYPE_CHECKING:
from openllm_core._typing_compat import DictStrAny
@@ -21,20 +22,28 @@ if t.TYPE_CHECKING:
logger = logging.getLogger(__name__)
@overload
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['int8', 'int4'], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
...
@overload
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['gptq'], **attrs: t.Any) -> tuple[transformers.GPTQConfig, DictStrAny]:
...
def infer_quantisation_config(
self: LLM[t.Any, t.Any], quantise: t.Literal['int8', 'int4'], **attrs: t.Any
) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]: ...
@overload
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['awq'], **attrs: t.Any) -> tuple[transformers.AwqConfig, DictStrAny]:
...
def infer_quantisation_config(
self: LLM[t.Any, t.Any], quantise: t.Literal['gptq'], **attrs: t.Any
) -> tuple[transformers.GPTQConfig, DictStrAny]: ...
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise,
**attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig, DictStrAny]:
@overload
def infer_quantisation_config(
self: LLM[t.Any, t.Any], quantise: t.Literal['awq'], **attrs: t.Any
) -> tuple[transformers.AwqConfig, DictStrAny]: ...
def infer_quantisation_config(
self: LLM[t.Any, t.Any], quantise: LiteralQuantise, **attrs: t.Any
) -> tuple[transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig, DictStrAny]:
# 8 bit configuration
int8_threshold = attrs.pop('llm_int8_threshhold', 6.0)
int8_enable_fp32_cpu_offload = attrs.pop('llm_int8_enable_fp32_cpu_offload', False)
@@ -64,34 +73,39 @@ def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise
gptq_pad_token_id = attrs.pop('pad_token_id', None)
disable_exllama = attrs.pop('disable_exllama', False) # backward compatibility
gptq_use_exllama = attrs.pop('use_exllama', True)
if disable_exllama: gptq_use_exllama = False
return transformers.GPTQConfig(bits=bits,
tokenizer=gptq_tokenizer,
dataset=gptq_dataset,
group_size=group_size,
damp_percent=gptq_damp_percent,
desc_act=gptq_desc_act,
sym=gptq_sym,
true_sequential=gptq_true_sequential,
use_cuda_fp16=gptq_use_cuda_fp16,
model_seqlen=gptq_model_seqlen,
block_name_to_quantize=gptq_block_name_to_quantize,
module_name_preceding_first_block=gptq_module_name_preceding_first_block,
batch_size=gptq_batch_size,
pad_token_id=gptq_pad_token_id,
use_exllama=gptq_use_exllama,
exllama_config={'version': 1}) # XXX: See how to migrate to v2
if disable_exllama:
gptq_use_exllama = False
return transformers.GPTQConfig(
bits=bits,
tokenizer=gptq_tokenizer,
dataset=gptq_dataset,
group_size=group_size,
damp_percent=gptq_damp_percent,
desc_act=gptq_desc_act,
sym=gptq_sym,
true_sequential=gptq_true_sequential,
use_cuda_fp16=gptq_use_cuda_fp16,
model_seqlen=gptq_model_seqlen,
block_name_to_quantize=gptq_block_name_to_quantize,
module_name_preceding_first_block=gptq_module_name_preceding_first_block,
batch_size=gptq_batch_size,
pad_token_id=gptq_pad_token_id,
use_exllama=gptq_use_exllama,
exllama_config={'version': 1},
) # XXX: See how to migrate to v2
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
# if int8_skip_modules is None: int8_skip_modules = []
# if 'lm_head' not in int8_skip_modules and self.config_class.__openllm_model_type__ == 'causal_lm':
# logger.debug("Skipping 'lm_head' for quantization for %s", self.__name__)
# int8_skip_modules.append('lm_head')
return transformers.BitsAndBytesConfig(load_in_8bit=True,
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
llm_int8_threshhold=int8_threshold,
llm_int8_skip_modules=int8_skip_modules,
llm_int8_has_fp16_weight=int8_has_fp16_weight)
return transformers.BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
llm_int8_threshhold=int8_threshold,
llm_int8_skip_modules=int8_skip_modules,
llm_int8_has_fp16_weight=int8_has_fp16_weight,
)
# 4 bit configuration
int4_compute_dtype = attrs.pop('bnb_4bit_compute_dtype', torch.bfloat16)
@@ -100,22 +114,30 @@ def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise
# NOTE: Quantization setup quantize is a openllm.LLM feature, where we can quantize the model with bitsandbytes or quantization aware training.
if not is_bitsandbytes_available():
raise RuntimeError("Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
if quantise == 'int8': quantisation_config = create_int8_config(int8_skip_modules)
raise RuntimeError(
'Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with \'pip install "openllm[fine-tune]"\''
)
if quantise == 'int8':
quantisation_config = create_int8_config(int8_skip_modules)
elif quantise == 'int4':
quantisation_config = transformers.BitsAndBytesConfig(load_in_4bit=True,
bnb_4bit_compute_dtype=int4_compute_dtype,
bnb_4bit_quant_type=int4_quant_type,
bnb_4bit_use_double_quant=int4_use_double_quant)
quantisation_config = transformers.BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=int4_compute_dtype,
bnb_4bit_quant_type=int4_quant_type,
bnb_4bit_use_double_quant=int4_use_double_quant,
)
elif quantise == 'gptq':
if not is_autogptq_available() or not is_optimum_supports_gptq():
raise MissingDependencyError(
"'quantize=\"gptq\"' requires 'auto-gptq' and 'optimum>=0.12' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[gptq]\"'")
"'quantize=\"gptq\"' requires 'auto-gptq' and 'optimum>=0.12' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[gptq]\"'"
)
else:
quantisation_config = create_gptq_config()
elif quantise == 'awq':
if not is_autoawq_available():
raise MissingDependencyError("quantize='awq' requires 'auto-awq' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[awq]\"'.")
raise MissingDependencyError(
"quantize='awq' requires 'auto-awq' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[awq]\"'."
)
else:
quantisation_config = create_awq_config()
else: