chore: cleanup loader (#729)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-22 21:51:51 -05:00
committed by GitHub
parent 5442d9cd10
commit 52a44b1bfa
8 changed files with 125 additions and 264 deletions

View File

@@ -4,7 +4,7 @@ import orjson, torch, transformers, bentoml, openllm
from huggingface_hub import snapshot_download
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import first_not_none, is_autogptq_available
from openllm_core.utils import first_not_none, is_autogptq_available, is_flash_attn_2_available
from ._helpers import get_tokenizer, infer_autoclass_from_llm, process_config
from .weights import HfIgnore
@@ -122,18 +122,36 @@ def load_model(llm, *decls, **attrs):
if llm.config['model_type'] != 'causal_lm':
raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
# TODO: investigate load with flash attention
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
)
try:
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, use_flash_attention_2=is_flash_attn_2_available(), **attrs
)
except Exception as err:
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
model = auto_class.from_pretrained(
llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
)
else:
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
**attrs,
)
try:
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
use_flash_attention_2=is_flash_attn_2_available(),
**attrs,
)
except Exception as err:
logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
model = auto_class.from_pretrained(
llm.bentomodel.path,
*decls,
config=config,
trust_remote_code=llm.trust_remote_code,
device_map=device_map,
**attrs,
)
check_unintialised_params(model)
return model