mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-22 16:07:24 -04:00
fix: check bettertransformer directly, since it is already a bool
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -73,11 +73,11 @@ all = [
|
||||
"openllm[falcon]",
|
||||
"openllm[mpt]",
|
||||
"openllm[flan-t5]",
|
||||
"openllm[openai]",
|
||||
"openllm[ggml]",
|
||||
"openllm[fine-tune]",
|
||||
"openllm[agents]",
|
||||
"openllm[openai]",
|
||||
"openllm[playground]",
|
||||
"openllm[ggml]",
|
||||
]
|
||||
chatglm = ["cpm-kernels", "sentencepiece"]
|
||||
falcon = ["einops", "xformers", "safetensors"]
|
||||
|
||||
@@ -472,7 +472,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin):
|
||||
def save_pretrained(self, save_directory: str | Path, **attrs: t.Any) -> None:
|
||||
if isinstance(save_directory, Path):
|
||||
save_directory = str(save_directory)
|
||||
if self.__llm_model__ is not None and self.bettertransformer.upper() in ENV_VARS_TRUE_VALUES:
|
||||
if self.__llm_model__ is not None and self.bettertransformer:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
self.__llm_model__ = BetterTransformer.reverse(self.__llm_model__)
|
||||
|
||||
@@ -28,7 +28,6 @@ import cloudpickle
|
||||
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
|
||||
from ..utils import LazyLoader, is_torch_available
|
||||
from ..utils import generate_context, normalize_attrs_to_model_tokenizer_pair
|
||||
from ..utils import ENV_VARS_TRUE_VALUES
|
||||
from .constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, MODEL_TO_AUTOCLASS_MAPPING
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -221,11 +220,7 @@ def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> Mo
|
||||
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(
|
||||
llm._bentomodel.path, *decls, config=config, **hub_attrs, **attrs
|
||||
)
|
||||
if (
|
||||
llm.bettertransformer.upper() in ENV_VARS_TRUE_VALUES
|
||||
and llm.__llm_implementation__ == "pt"
|
||||
and not isinstance(model, transformers.Pipeline)
|
||||
):
|
||||
if llm.bettertransformer and llm.__llm_implementation__ == "pt" and not isinstance(model, transformers.Pipeline):
|
||||
# BetterTransformer is currently only supported on PyTorch.
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
|
||||
local_path = tmp_path_factory.mktemp("local_t5")
|
||||
llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
|
||||
|
||||
if llm.bettertransformer.upper() in openllm.utils.ENV_VARS_TRUE_VALUES:
|
||||
if llm.bettertransformer:
|
||||
llm.__llm_model__ = llm.model.reverse_bettertransformer()
|
||||
|
||||
llm.save_pretrained(local_path)
|
||||
|
||||
Reference in New Issue
Block a user