diff --git a/pyproject.toml b/pyproject.toml
index 1b55d690..aca5e7e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,11 +73,11 @@ all = [
     "openllm[falcon]",
     "openllm[mpt]",
     "openllm[flan-t5]",
-    "openllm[openai]",
+    "openllm[ggml]",
     "openllm[fine-tune]",
     "openllm[agents]",
+    "openllm[openai]",
     "openllm[playground]",
-    "openllm[ggml]",
 ]
 chatglm = ["cpm-kernels", "sentencepiece"]
 falcon = ["einops", "xformers", "safetensors"]
diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py
index 9d8747fe..e7b4f683 100644
--- a/src/openllm/_llm.py
+++ b/src/openllm/_llm.py
@@ -472,7 +472,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin):
     def save_pretrained(self, save_directory: str | Path, **attrs: t.Any) -> None:
         if isinstance(save_directory, Path):
             save_directory = str(save_directory)
-        if self.__llm_model__ is not None and self.bettertransformer.upper() in ENV_VARS_TRUE_VALUES:
+        if self.__llm_model__ is not None and self.bettertransformer:
            from optimum.bettertransformer import BetterTransformer

            self.__llm_model__ = BetterTransformer.reverse(self.__llm_model__)
diff --git a/src/openllm/serialisation/transformers.py b/src/openllm/serialisation/transformers.py
index 8cb6d6bd..cc2a1b1c 100644
--- a/src/openllm/serialisation/transformers.py
+++ b/src/openllm/serialisation/transformers.py
@@ -28,7 +28,6 @@ import cloudpickle
 from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
 from ..utils import LazyLoader, is_torch_available
 from ..utils import generate_context, normalize_attrs_to_model_tokenizer_pair
-from ..utils import ENV_VARS_TRUE_VALUES
 from .constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, MODEL_TO_AUTOCLASS_MAPPING

 if t.TYPE_CHECKING:
@@ -221,11 +220,7 @@ def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> Mo
     model = infer_autoclass_from_llm_config(llm, config).from_pretrained(
         llm._bentomodel.path, *decls, config=config, **hub_attrs, **attrs
     )
-    if (
-        llm.bettertransformer.upper() in ENV_VARS_TRUE_VALUES
-        and llm.__llm_implementation__ == "pt"
-        and not isinstance(model, transformers.Pipeline)
-    ):
+    if llm.bettertransformer and llm.__llm_implementation__ == "pt" and not isinstance(model, transformers.Pipeline):
        # BetterTransformer is currently only supported on PyTorch.
        from optimum.bettertransformer import BetterTransformer

diff --git a/tests/test_package.py b/tests/test_package.py
index 84fb56c7..d61d2581 100644
--- a/tests/test_package.py
+++ b/tests/test_package.py
@@ -48,7 +48,7 @@ def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
     local_path = tmp_path_factory.mktemp("local_t5")
     llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)

-    if llm.bettertransformer.upper() in openllm.utils.ENV_VARS_TRUE_VALUES:
+    if llm.bettertransformer:
        llm.__llm_model__ = llm.model.reverse_bettertransformer()

     llm.save_pretrained(local_path)