chore: cleanup loader (#729)
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
@@ -4,7 +4,7 @@ import orjson, torch, transformers, bentoml, openllm
 
 from huggingface_hub import snapshot_download
 from openllm_core.exceptions import OpenLLMException
-from openllm_core.utils import first_not_none, is_autogptq_available
+from openllm_core.utils import first_not_none, is_autogptq_available, is_flash_attn_2_available
 
 from ._helpers import get_tokenizer, infer_autoclass_from_llm, process_config
 from .weights import HfIgnore
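For context, the newly imported `is_flash_attn_2_available` is a capability probe from `openllm_core.utils`. Its actual implementation is not shown in this diff; a rough stand-in (an assumption, not the real source; `transformers.utils` also ships a helper of the same name that such a wrapper could defer to) might look like:

import importlib.util

def is_flash_attn_2_available() -> bool:
  # Hypothetical sketch: treat FlashAttention-2 as available when the
  # `flash_attn` package is importable. The real helper in
  # openllm_core.utils may perform additional CUDA/version checks.
  return importlib.util.find_spec('flash_attn') is not None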
@@ -122,18 +122,36 @@ def load_model(llm, *decls, **attrs):
     if llm.config['model_type'] != 'causal_lm':
       raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
 
-    # TODO: investigate load with flash attention
-    model = auto_class.from_pretrained(
-      llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
-    )
+    try:
+      model = auto_class.from_pretrained(
+        llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, use_flash_attention_2=is_flash_attn_2_available(), **attrs
+      )
+    except Exception as err:
+      logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
+      model = auto_class.from_pretrained(
+        llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs
+      )
   else:
-    model = auto_class.from_pretrained(
-      llm.bentomodel.path,
-      *decls,
-      config=config,
-      trust_remote_code=llm.trust_remote_code,
-      device_map=device_map,
-      **attrs,
-    )
+    try:
+      model = auto_class.from_pretrained(
+        llm.bentomodel.path,
+        *decls,
+        config=config,
+        trust_remote_code=llm.trust_remote_code,
+        device_map=device_map,
+        use_flash_attention_2=is_flash_attn_2_available(),
+        **attrs,
+      )
+    except Exception as err:
+      logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err)
+      model = auto_class.from_pretrained(
+        llm.bentomodel.path,
+        *decls,
+        config=config,
+        trust_remote_code=llm.trust_remote_code,
+        device_map=device_map,
+        **attrs,
+      )
+
   check_unintialised_params(model)
   return model
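The pattern this commit introduces is an opportunistic upgrade: attempt the load with FlashAttention-2 enabled, and fall back to the default attention backend if anything goes wrong. A minimal standalone sketch of the same idea, assuming a plain `transformers` install (the function name and `model_path` argument are placeholders, not OpenLLM API):

import logging

import transformers

logger = logging.getLogger(__name__)

def load_with_optional_flash_attn(model_path, **attrs):
  """Try FlashAttention-2 first; fall back to the default attention backend.

  Mirrors the try/except fallback from this commit. `model_path` can be a
  local checkpoint directory or a Hub id.
  """
  try:
    # `use_flash_attention_2` is the transformers kwarg used at the time of
    # this commit; newer releases favour `attn_implementation='flash_attention_2'`.
    return transformers.AutoModelForCausalLM.from_pretrained(
      model_path, use_flash_attention_2=True, **attrs
    )
  except Exception as err:
    logger.debug("FlashAttention-2 load failed, falling back:\n%s", err)
    return transformers.AutoModelForCausalLM.from_pretrained(model_path, **attrs)

Catching a broad `Exception` keeps the loader working on hardware or builds without flash-attn, at the cost of masking unrelated errors on the first attempt; the debug log preserves the traceback so that trade-off stays inspectable.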