mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-01-26 16:27:51 -05:00
chore(style): reduce line length and truncate compression
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -47,9 +47,11 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
try:
|
||||
tokenizer = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"]
|
||||
except KeyError:
|
||||
raise openllm.exceptions.OpenLLMException("Bento model does not have tokenizer. Make sure to save"
|
||||
" the tokenizer within the model via 'custom_objects'."
|
||||
" For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\"") from None
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"Bento model does not have tokenizer. Make sure to save"
|
||||
" the tokenizer within the model via 'custom_objects'."
|
||||
" For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\""
|
||||
) from None
|
||||
else:
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath("/"), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs)
|
||||
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
from __future__ import annotations
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"), "vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM")}
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
|
||||
"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
|
||||
"tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"),
|
||||
"flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
|
||||
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM")
|
||||
}
|
||||
HUB_ATTRS = ["cache_dir", "code_revision", "force_download", "local_files_only", "proxies", "resume_download", "revision", "subfolder", "use_auth_token"]
|
||||
|
||||
@@ -55,7 +55,8 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
signatures: DictStrAny = {}
|
||||
|
||||
if quantize_method == "gptq":
|
||||
if not openllm.utils.is_autogptq_available(): raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
signatures["generate"] = {"batchable": False}
|
||||
else:
|
||||
@@ -70,16 +71,33 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
|
||||
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
|
||||
imported_modules: list[types.ModuleType] = []
|
||||
bentomodel = bentoml.Model.create(llm.tag, module="openllm.serialisation.transformers", api_version="v1", options=ModelOptions(), context=openllm.utils.generate_context(framework_name="openllm"), labels=openllm.utils.generate_labels(llm), signatures=signatures if signatures else make_model_signatures(llm))
|
||||
bentomodel = bentoml.Model.create(
|
||||
llm.tag,
|
||||
module="openllm.serialisation.transformers",
|
||||
api_version="v1",
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name="openllm"),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
signatures=signatures if signatures else make_model_signatures(llm)
|
||||
)
|
||||
with openllm.utils.analytics.set_bentoml_tracking():
|
||||
try:
|
||||
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
|
||||
tokenizer.save_pretrained(bentomodel.path)
|
||||
if quantize_method == "gptq":
|
||||
if not openllm.utils.is_autogptq_available(): raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
logger.debug("Saving model with GPTQ quantisation will require loading model into memory.")
|
||||
model = autogptq.AutoGPTQForCausalLM.from_quantized(llm.model_id, *decls, quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), trust_remote_code=trust_remote_code, use_safetensors=safe_serialisation, **hub_attrs, **attrs,)
|
||||
model = autogptq.AutoGPTQForCausalLM.from_quantized(
|
||||
llm.model_id,
|
||||
*decls,
|
||||
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_safetensors=safe_serialisation,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
update_model(bentomodel, metadata={"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework})
|
||||
model.save_quantized(bentomodel.path, use_safetensors=safe_serialisation)
|
||||
else:
|
||||
@@ -120,8 +138,10 @@ def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
|
||||
"""
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
if model.info.module not in ("openllm.serialisation.transformers"
|
||||
"bentoml.transformers", "bentoml._internal.frameworks.transformers", __name__): # NOTE: backward compatible with previous version of OpenLLM.
|
||||
if model.info.module not in (
|
||||
"openllm.serialisation.transformers"
|
||||
"bentoml.transformers", "bentoml._internal.frameworks.transformers", __name__
|
||||
): # NOTE: backward compatible with previous version of OpenLLM.
|
||||
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
|
||||
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
|
||||
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
@@ -136,26 +156,51 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs)
|
||||
safe_serialization = openllm.utils.first_not_none(t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors")
|
||||
safe_serialization = openllm.utils.first_not_none(
|
||||
t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors"
|
||||
)
|
||||
if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq":
|
||||
if not openllm.utils.is_autogptq_available(): raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
return autogptq.AutoGPTQForCausalLM.from_quantized(llm._bentomodel.path, *decls, quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config), trust_remote_code=llm.__llm_trust_remote_code__, use_safetensors=safe_serialization, **hub_attrs, **attrs)
|
||||
return autogptq.AutoGPTQForCausalLM.from_quantized(
|
||||
llm._bentomodel.path,
|
||||
*decls,
|
||||
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
use_safetensors=safe_serialization,
|
||||
**hub_attrs,
|
||||
**attrs
|
||||
)
|
||||
|
||||
device_map = attrs.pop("device_map", "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, device_map=device_map, **hub_attrs, **attrs).eval()
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(
|
||||
llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, device_map=device_map, **hub_attrs, **attrs
|
||||
).eval()
|
||||
# BetterTransformer is currently only supported on PyTorch.
|
||||
if llm.bettertransformer and isinstance(model, transformers.PreTrainedModel): model = model.to_bettertransformer()
|
||||
if llm.__llm_implementation__ in {"pt", "vllm"}: check_unintialised_params(model)
|
||||
return t.cast("M", model)
|
||||
def save_pretrained(llm: openllm.LLM[M, T], save_directory: str, is_main_process: bool = True, state_dict: DictStrAny | None = None, save_function: t.Any | None = None, push_to_hub: bool = False, max_shard_size: int | str = "10GB", safe_serialization: bool = False, variant: str | None = None, **attrs: t.Any) -> None:
|
||||
def save_pretrained(
|
||||
llm: openllm.LLM[M, T],
|
||||
save_directory: str,
|
||||
is_main_process: bool = True,
|
||||
state_dict: DictStrAny | None = None,
|
||||
save_function: t.Any | None = None,
|
||||
push_to_hub: bool = False,
|
||||
max_shard_size: int | str = "10GB",
|
||||
safe_serialization: bool = False,
|
||||
variant: str | None = None,
|
||||
**attrs: t.Any
|
||||
) -> None:
|
||||
save_function = t.cast(t.Callable[..., None], openllm.utils.first_not_none(save_function, default=torch.save))
|
||||
model_save_attrs, tokenizer_save_attrs = openllm.utils.normalize_attrs_to_model_tokenizer_pair(**attrs)
|
||||
safe_serialization = safe_serialization or llm._serialisation_format == "safetensors"
|
||||
# NOTE: disable safetensors for vllm
|
||||
if llm.__llm_implementation__ == "vllm": safe_serialization = False
|
||||
if llm._quantize_method == "gptq":
|
||||
if not openllm.utils.is_autogptq_available(): raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
if not openllm.utils.lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})")
|
||||
t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
|
||||
@@ -165,5 +210,15 @@ def save_pretrained(llm: openllm.LLM[M, T], save_directory: str, is_main_process
|
||||
llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
|
||||
else:
|
||||
# We can safely cast here since it will be the PreTrainedModel protocol.
|
||||
t.cast("transformers.PreTrainedModel", llm.model).save_pretrained(save_directory, is_main_process=is_main_process, state_dict=state_dict, save_function=save_function, push_to_hub=push_to_hub, max_shard_size=max_shard_size, safe_serialization=safe_serialization, variant=variant, **model_save_attrs)
|
||||
t.cast("transformers.PreTrainedModel", llm.model).save_pretrained(
|
||||
save_directory,
|
||||
is_main_process=is_main_process,
|
||||
state_dict=state_dict,
|
||||
save_function=save_function,
|
||||
push_to_hub=push_to_hub,
|
||||
max_shard_size=max_shard_size,
|
||||
safe_serialization=safe_serialization,
|
||||
variant=variant,
|
||||
**model_save_attrs
|
||||
)
|
||||
llm.tokenizer.save_pretrained(save_directory, push_to_hub=push_to_hub, **tokenizer_save_attrs)
|
||||
|
||||
@@ -38,7 +38,8 @@ def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
|
||||
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
|
||||
if llm.config["trust_remote_code"]:
|
||||
autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM"
|
||||
if not hasattr(config, "auto_map"): raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping")
|
||||
if not hasattr(config, "auto_map"):
|
||||
raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping")
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map: autoclass = "AutoModel"
|
||||
|
||||
Reference in New Issue
Block a user