feat(type): provide structured annotations stubs (#663)

* feat(type): provide client stubs

separation of concern for more brevity code base

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* docs: update changelog

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-16 02:58:45 -05:00
committed by GitHub
parent c6264f3af7
commit 4a6f13ddd2
32 changed files with 795 additions and 582 deletions

View File

@@ -1,12 +1,5 @@
# mypy: disable-error-code="name-defined,no-redef"
from __future__ import annotations
import logging
import typing as t
import torch
import transformers
from openllm_core._typing_compat import LiteralQuantise, overload
from openllm_core.exceptions import MissingDependencyError
from openllm_core.utils import (
is_autoawq_available,
@@ -15,35 +8,11 @@ from openllm_core.utils import (
is_optimum_supports_gptq,
)
if t.TYPE_CHECKING:
from openllm_core._typing_compat import DictStrAny
from ._llm import LLM
def infer_quantisation_config(llm, quantise, **attrs):
import torch
import transformers
logger = logging.getLogger(__name__)
@overload
def infer_quantisation_config(
self: LLM[t.Any, t.Any], quantise: t.Literal['int8', 'int4'], **attrs: t.Any
) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]: ...
@overload
def infer_quantisation_config(
self: LLM[t.Any, t.Any], quantise: t.Literal['gptq'], **attrs: t.Any
) -> tuple[transformers.GPTQConfig, DictStrAny]: ...
@overload
def infer_quantisation_config(
self: LLM[t.Any, t.Any], quantise: t.Literal['awq'], **attrs: t.Any
) -> tuple[transformers.AwqConfig, DictStrAny]: ...
def infer_quantisation_config(
self: LLM[t.Any, t.Any], quantise: LiteralQuantise, **attrs: t.Any
) -> tuple[transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig, DictStrAny]:
# 8 bit configuration
int8_threshold = attrs.pop('llm_int8_threshhold', 6.0)
int8_enable_fp32_cpu_offload = attrs.pop('llm_int8_enable_fp32_cpu_offload', False)
@@ -54,12 +23,17 @@ def infer_quantisation_config(
bits = attrs.pop('bits', 4)
group_size = attrs.pop('group_size', 128)
def create_awq_config() -> transformers.AwqConfig:
# 4 bit configuration
int4_compute_dtype = attrs.pop('bnb_4bit_compute_dtype', torch.bfloat16)
int4_quant_type = attrs.pop('bnb_4bit_quant_type', 'nf4')
int4_use_double_quant = attrs.pop('bnb_4bit_use_double_quant', True)
def create_awq_config():
zero_point = attrs.pop('zero_point', True)
return transformers.AwqConfig(bits=bits, group_size=group_size, zero_point=zero_point)
def create_gptq_config() -> transformers.GPTQConfig:
gptq_tokenizer = attrs.pop('tokenizer', self.model_id)
def create_gptq_config():
gptq_tokenizer = attrs.pop('tokenizer', llm.model_id)
gptq_dataset = attrs.pop('dataset', 'c4')
gptq_damp_percent = attrs.pop('damp_percent', 0.1)
gptq_desc_act = attrs.pop('desc_act', False)
@@ -94,10 +68,9 @@ def infer_quantisation_config(
exllama_config={'version': 1},
) # XXX: See how to migrate to v2
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
def create_int8_config(int8_skip_modules):
# if int8_skip_modules is None: int8_skip_modules = []
# if 'lm_head' not in int8_skip_modules and self.config_class.__openllm_model_type__ == 'causal_lm':
# logger.debug("Skipping 'lm_head' for quantization for %s", self.__name__)
# int8_skip_modules.append('lm_head')
return transformers.BitsAndBytesConfig(
load_in_8bit=True,
@@ -107,10 +80,13 @@ def infer_quantisation_config(
llm_int8_has_fp16_weight=int8_has_fp16_weight,
)
# 4 bit configuration
int4_compute_dtype = attrs.pop('bnb_4bit_compute_dtype', torch.bfloat16)
int4_quant_type = attrs.pop('bnb_4bit_quant_type', 'nf4')
int4_use_double_quant = attrs.pop('bnb_4bit_use_double_quant', True)
def create_int4_config():
return transformers.BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=int4_compute_dtype,
bnb_4bit_quant_type=int4_quant_type,
bnb_4bit_use_double_quant=int4_use_double_quant,
)
# NOTE: Quantization setup quantize is a openllm.LLM feature, where we can quantize the model with bitsandbytes or quantization aware training.
if not is_bitsandbytes_available():
@@ -120,23 +96,18 @@ def infer_quantisation_config(
if quantise == 'int8':
quantisation_config = create_int8_config(int8_skip_modules)
elif quantise == 'int4':
quantisation_config = transformers.BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=int4_compute_dtype,
bnb_4bit_quant_type=int4_quant_type,
bnb_4bit_use_double_quant=int4_use_double_quant,
)
quantisation_config = create_int4_config()
elif quantise == 'gptq':
if not is_autogptq_available() or not is_optimum_supports_gptq():
raise MissingDependencyError(
"'quantize=\"gptq\"' requires 'auto-gptq' and 'optimum>=0.12' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[gptq]\"'"
"GPTQ requires 'auto-gptq' and 'optimum>=0.12' to be installed. Do it with 'pip install \"openllm[gptq]\"'"
)
else:
quantisation_config = create_gptq_config()
elif quantise == 'awq':
if not is_autoawq_available():
raise MissingDependencyError(
"quantize='awq' requires 'auto-awq' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[awq]\"'."
"AWQ requires 'auto-awq' to be installed. Do it with 'pip install \"openllm[awq]\"'."
)
else:
quantisation_config = create_awq_config()