fix: bettertransformer check is already a bool; drop ENV_VARS_TRUE_VALUES comparison

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-07-04 17:57:21 +00:00
parent d6303d306a
commit 4c5b27495c
4 changed files with 5 additions and 10 deletions

View File

@@ -73,11 +73,11 @@ all = [
"openllm[falcon]",
"openllm[mpt]",
"openllm[flan-t5]",
"openllm[openai]",
"openllm[ggml]",
"openllm[fine-tune]",
"openllm[agents]",
"openllm[openai]",
"openllm[playground]",
"openllm[ggml]",
]
chatglm = ["cpm-kernels", "sentencepiece"]
falcon = ["einops", "xformers", "safetensors"]

View File

@@ -472,7 +472,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin):
def save_pretrained(self, save_directory: str | Path, **attrs: t.Any) -> None:
if isinstance(save_directory, Path):
save_directory = str(save_directory)
if self.__llm_model__ is not None and self.bettertransformer.upper() in ENV_VARS_TRUE_VALUES:
if self.__llm_model__ is not None and self.bettertransformer:
from optimum.bettertransformer import BetterTransformer
self.__llm_model__ = BetterTransformer.reverse(self.__llm_model__)

View File

@@ -28,7 +28,6 @@ import cloudpickle
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
from ..utils import LazyLoader, is_torch_available
from ..utils import generate_context, normalize_attrs_to_model_tokenizer_pair
from ..utils import ENV_VARS_TRUE_VALUES
from .constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, MODEL_TO_AUTOCLASS_MAPPING
if t.TYPE_CHECKING:
@@ -221,11 +220,7 @@ def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> Mo
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(
llm._bentomodel.path, *decls, config=config, **hub_attrs, **attrs
)
if (
llm.bettertransformer.upper() in ENV_VARS_TRUE_VALUES
and llm.__llm_implementation__ == "pt"
and not isinstance(model, transformers.Pipeline)
):
if llm.bettertransformer and llm.__llm_implementation__ == "pt" and not isinstance(model, transformers.Pipeline):
# BetterTransformer is currently only supported on PyTorch.
from optimum.bettertransformer import BetterTransformer

View File

@@ -48,7 +48,7 @@ def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
local_path = tmp_path_factory.mktemp("local_t5")
llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
if llm.bettertransformer.upper() in openllm.utils.ENV_VARS_TRUE_VALUES:
if llm.bettertransformer:
llm.__llm_model__ = llm.model.reverse_bettertransformer()
llm.save_pretrained(local_path)