fix(llm): ignore quantization config when --quantize int4 is passed

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Author: aarnphm-ec2-dev
Date: 2023-07-22 22:45:46 +00:00
Commit: d4f3cf8b75 (parent: 6f4c58175d)
4 changed files with 29 additions and 17 deletions

@@ -37,7 +37,7 @@ pip install --upgrade openllm==${TAG}
 ## Usage
-All available models: ```python -m openllm.models```
+All available models: ```openllm models```
 To start a LLM: ```python -m openllm start opt```

@@ -58,7 +58,6 @@ from .utils import is_peft_available
 from .utils import is_torch_available
 from .utils import non_intrusive_setattr
 from .utils import normalize_attrs_to_model_tokenizer_pair
-from .utils import pkg
 from .utils import requires_dependencies
 from .utils import resolve_filepath
 from .utils import validate_is_path
@@ -1144,10 +1143,7 @@ class LLM(LLMInterface[M, T], ReprMixin):
         use_gradient_checkpointing: bool = True,
         **attrs: t.Any,
     ) -> tuple[peft.PeftModel, T]:
-        if pkg.pkg_version_info("peft")[:2] >= (0, 4):
-            from peft import prepare_model_for_kbit_training
-        else:
-            from peft import prepare_model_for_int8_training as prepare_model_for_kbit_training
+        from peft import prepare_model_for_kbit_training
         peft_config = (
             self.config["fine_tune_strategies"]
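The deleted branch above was a compatibility shim: peft 0.4 renamed `prepare_model_for_int8_training` to `prepare_model_for_kbit_training`, and with peft>=0.4 now assumed, the version probe (and the `pkg` import it needed) can go. A minimal sketch of the shim this replaces, assuming only that some version of peft is installed:

```python
# Version-agnostic import of peft's k-bit training helper.
try:
    # peft>=0.4.0 exposes the generalized k-bit name
    from peft import prepare_model_for_kbit_training
except ImportError:
    # older releases only ship the int8-specific helper
    from peft import prepare_model_for_int8_training as prepare_model_for_kbit_training
```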

@@ -45,6 +45,8 @@ def find_all_linear_names(model):
 # Change this to the local converted path if you don't have access to the meta-llama model
 DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf"
+# change this to 'main' if you want to use the latest llama
+DEFAULT_MODEL_VERSION = "335a02887eb6684d487240bbc28b5699298c3135"
 DATASET_NAME = "databricks/databricks-dolly-15k"
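Pinning `DEFAULT_MODEL_VERSION` to a commit hash makes the example reproducible: every run resolves to the same Hub snapshot instead of whatever `main` currently points at. As an illustration of what the pin means, a sketch using the plain `transformers` API (and assuming access to the gated `meta-llama` repo):

```python
from transformers import AutoTokenizer

# revision accepts a branch name, tag, or commit hash on the Hugging Face Hub;
# the hash below is the one pinned in the diff.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    revision="335a02887eb6684d487240bbc28b5699298c3135",
)
```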
@@ -119,13 +121,17 @@ def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
 @openllm.utils.requires_dependencies("peft", extra="fine-tune")
 def prepare_for_int4_training(
-    model_id: str, gradient_checkpointing: bool = True, bf16: bool = True
+    model_id: str,
+    model_version: str | None = None,
+    gradient_checkpointing: bool = True,
+    bf16: bool = True,
 ) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
     from peft.tuners.lora import LoraLayer
     llm = openllm.AutoLLM.for_model(
         "llama",
         model_id=model_id,
+        model_version=model_version,
         ensure_available=True,
         quantize="int4",
         bnb_4bit_compute_dtype=torch.bfloat16,
@@ -138,7 +144,9 @@ def prepare_for_int4_training(
     modules = find_all_linear_names(llm.model)
     print(f"Found {len(modules)} modules to quantize: {modules}")
-    model, tokenizer = llm.prepare_for_training(adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing)
+    model, tokenizer = llm.prepare_for_training(
+        adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules
+    )
     # pre-process the model by upcasting the layer norms in float 32 for
     for name, module in model.named_modules():
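Passing `target_modules=modules` matters because `find_all_linear_names` collects every linear projection in the Llama model, so the LoRA adapters cover all of them rather than peft's defaults. Roughly, `quantize="int4"` with a bfloat16 compute dtype corresponds to the following plain `transformers`/`peft` setup; this is a sketch, not OpenLLM's internals, and the LoRA hyperparameters and module list are illustrative assumptions:

```python
import torch
import transformers
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit loading: approximately what --quantize int4 turns into under the hood.
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", quantization_config=bnb_config, device_map="auto"
)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
# Attach LoRA adapters to an explicit module list (values here are assumed).
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, lora_config)
```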
@@ -170,23 +178,27 @@ class TrainingArguments:
 @dataclasses.dataclass
 class ModelArguments:
     model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
+    model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION)
     seed: int = dataclasses.field(default=42)
     merge_weights: bool = dataclasses.field(default=False)
-parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
-if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
-    # If we pass only one argument to the script and it's the path to a json file,
-    # let's parse it to get our arguments.
-    model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+if openllm.utils.in_notebook():
+    model_args, training_args = ModelArguments(), TrainingArguments()
 else:
-    model_args, training_args = t.cast(
-        t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
-    )
+    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, training_args = t.cast(
+            t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
+        )
 # import the model first hand
-openllm.import_model("llama", model_id=model_args.model_id)
+openllm.import_model("llama", model_id=model_args.model_id, model_version=model_args.model_version)
 def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
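The new `in_notebook` guard exists because inside Jupyter `sys.argv` carries the kernel's own flags (`-f /path/to/kernel.json`), which `HfArgumentParser` rejects; falling back to the dataclass defaults keeps the script usable from a notebook. A self-contained sketch of that failure mode, with a hypothetical `Args` dataclass standing in for the real argument classes:

```python
import dataclasses

import transformers


@dataclasses.dataclass
class Args:  # hypothetical stand-in for ModelArguments/TrainingArguments
    model_id: str = "meta-llama/Llama-2-7b-hf"


parser = transformers.HfArgumentParser(Args)
try:
    (args,) = parser.parse_args_into_dataclasses()
except (ValueError, SystemExit):
    # Under ipykernel, sys.argv is ["...ipykernel_launcher.py", "-f", "...json"],
    # which the parser refuses; defaults are the sane fallback there.
    args = Args()
print(args)
```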

@@ -166,6 +166,10 @@ def import_model(
         metadata["_framework"] = model.model.framework
         signatures["generate"] = {"batchable": False}
     else:
+        if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False):
+            # this model might be called with --quantize int4, therefore we need to pop this out
+            # since saving int4 is not yet supported
+            attrs.pop("quantization_config")
         model = t.cast(
             "_transformers.PreTrainedModel",
             infer_autoclass_from_llm_config(llm, config).from_pretrained(
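This last hunk is the fix named in the commit title: a `--quantize int4` load attaches a `BitsAndBytesConfig` whose `load_in_4bit` is `True`, and because serializing 4-bit weights is not yet supported, the config has to be dropped before the model reaches the store. A self-contained sketch of the same guard:

```python
import transformers

# Simulate the kwargs an int4-quantized load would carry into import_model.
attrs = {"quantization_config": transformers.BitsAndBytesConfig(load_in_4bit=True)}

# getattr(..., False) keeps this safe for configs without a load_in_4bit field.
if getattr(attrs.get("quantization_config"), "load_in_4bit", False):
    # int4 checkpoints can't be saved yet, so keep the config away from save_pretrained.
    attrs.pop("quantization_config")

assert "quantization_config" not in attrs
```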