chore(serialisation): dump quantization_config.json to conform with optimum load

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Author: aarnphm-ec2-dev
Date: 2023-09-07 16:50:25 +00:00
parent 6b6cf7a9cb
commit 8530a067ea

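For context: at the time of this commit, optimum's GPTQ loader resolves the config file name through the GPTQ_CONFIG constant in optimum.gptq.constants (the string 'quantization_config.json') and expects that file next to the saved weights. Below is a minimal standalone sketch of the serialisation step this diff adds; the path and the field values are illustrative stand-ins, since the real dict comes from the model's config.quantization_config:

import os
import orjson
from optimum.gptq.constants import GPTQ_CONFIG  # 'quantization_config.json'

save_folder = '/tmp/bentomodel'  # hypothetical stand-in for bentomodel.path
quantization_config = {'bits': 4, 'group_size': 128, 'desc_act': False}  # illustrative values

# orjson.dumps returns bytes; OPT_INDENT_2 pretty-prints, OPT_SORT_KEYS keeps the
# output deterministic, and .decode() yields a str for the text-mode file handle.
with open(os.path.join(save_folder, GPTQ_CONFIG), 'w', encoding='utf-8') as f:
  f.write(orjson.dumps(quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())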

@@ -4,6 +4,8 @@ import importlib
 import logging
 import typing as t
+import orjson
+from huggingface_hub import snapshot_download
 from packaging.version import Version
 from simple_di import Provide
@@ -108,6 +110,10 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
   try:
     bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
     tokenizer.save_pretrained(bentomodel.path)
+    if quantize == 'gptq':
+      from optimum.gptq.constants import GPTQ_CONFIG
+      with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
+        f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
     if llm._local:
       # possible local path
       logger.debug('Model will be loaded into memory to save to target store as it is from local path.')
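bentomodel.path_of(GPTQ_CONFIG) resolves a file name against the model's directory in the BentoML store, so the write above is equivalent to joining bentomodel.path with 'quantization_config.json'. A small round-trip check, reusing the names from the sketch above:

# Read the dumped config back; orjson.loads accepts bytes directly.
with open(os.path.join(save_folder, GPTQ_CONFIG), 'rb') as f:
  assert orjson.loads(f.read()) == quantization_config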
@@ -163,7 +169,9 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
   if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
   model = auto_class.from_pretrained(llm._bentomodel.path, device_map='auto', **hub_attrs, **attrs)
-  # TODO: Use the below logic once TheBloke finished migration to new GPTQConfig from transformers
+  # XXX: Use the below logic once TheBloke finishes the migration to the new GPTQConfig from transformers
+  # The logic below seems to require adding safetensors support to accelerate
+  #
   # from accelerate import init_empty_weights
   # from optimum.gptq import load_quantized_model
   # # disable exllama if gptq is loaded on CPU
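For illustration only, here is a hedged sketch of how the commented-out optimum path could be completed once that migration lands; save_folder stands in for llm._bentomodel.path, and the parameter choices (device_map, disable_exllama) are assumptions rather than the commit's code:

import torch
from accelerate import init_empty_weights
from optimum.gptq import load_quantized_model
from transformers import AutoConfig, AutoModelForCausalLM

save_folder = '/tmp/bentomodel'  # hypothetical stand-in for llm._bentomodel.path

# Build an empty skeleton on the meta device so no weights are materialised yet.
config = AutoConfig.from_pretrained(save_folder)
with init_empty_weights():
  empty_model = AutoModelForCausalLM.from_config(config)
empty_model.tie_weights()

# load_quantized_model reads quantization_config.json from save_folder;
# exllama kernels are CUDA-only, hence disabled when no GPU is available.
model = load_quantized_model(empty_model, save_folder=save_folder, device_map='auto', disable_exllama=not torch.cuda.is_available())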