From 4cd0784ee23586e66acc1497925eb407fc47c53a Mon Sep 17 00:00:00 2001
From: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Date: Sun, 23 Jul 2023 06:50:00 +0000
Subject: [PATCH] chore: export generation items for lazy loading

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
---
 src/openllm/__init__.py | 7 +-
 src/openllm/_llm.py | 43 ++--
 src/openllm/cli.py | 7 +-
 src/openllm/models/auto/__init__.py | 99 +++-----
 src/openllm/models/auto/configuration_auto.py | 106 ++-------
 src/openllm/models/auto/factory.py | 213 +++---------------
 src/openllm/models/auto/modeling_auto.py | 6 -
 src/openllm/models/auto/modeling_flax_auto.py | 7 -
 src/openllm/models/auto/modeling_tf_auto.py | 7 -
 src/openllm/models/auto/modeling_vllm_auto.py | 7 -
 src/openllm/models/falcon/modeling_falcon.py | 3 +-
 .../models/gpt_neox/modeling_gpt_neox.py | 4 +-
 src/openllm/models/llama/modeling_llama.py | 4 +-
 .../models/stablelm/modeling_stablelm.py | 3 +-
 .../models/starcoder/modeling_starcoder.py | 3 +-
 src/openllm/serialisation/constants.py | 2 +
 src/openllm/serialisation/transformers.py | 3 +
 17 files changed, 135 insertions(+), 389 deletions(-)

diff --git a/src/openllm/__init__.py b/src/openllm/__init__.py
index da14fafa..553d1ab9 100644
--- a/src/openllm/__init__.py
+++ b/src/openllm/__init__.py
@@ -60,8 +60,10 @@ else:
 _import_structure: dict[str, list[str]] = {
 "_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"],
 "_configuration": ["LLMConfig"],
- "exceptions": [],
 "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput"],
+ "_generation": ["StopSequenceCriteria", "StopOnTokens"],
+ "_quantisation": ["infer_quantisation_config"],
+ "exceptions": [],
 "utils": ["infer_auto_class"],
 "models": [],
 "client": [],
@@ -202,10 +204,13 @@ if t.TYPE_CHECKING:
 # Specific types import
 from ._configuration import LLMConfig as LLMConfig
+ from ._generation import StopOnTokens as StopOnTokens
+ from ._generation import StopSequenceCriteria as StopSequenceCriteria
 from ._llm import LLM as LLM
 from ._llm import LLMRunnable as LLMRunnable
 from ._llm import LLMRunner as LLMRunner
 from ._llm import Runner as Runner
+ from ._quantisation import infer_quantisation_config as infer_quantisation_config
 from ._schema import EmbeddingsOutput as EmbeddingsOutput
 from ._schema import GenerationInput as GenerationInput
 from ._schema import GenerationOutput as GenerationOutput
diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py
index 65fc5a92..ad54a742 100644
--- a/src/openllm/_llm.py
+++ b/src/openllm/_llm.py
@@ -41,7 +41,6 @@ from ._configuration import AdapterType
 from ._configuration import FineTuneConfig
 from ._configuration import _object_getattribute
 from ._configuration import _setattr_class
-from ._quantisation import infer_quantisation_config
 from .exceptions import ForbiddenAttributeError
 from .exceptions import GpuNotAvailableError
 from .utils import DEBUG
@@ -378,7 +377,6 @@ class LLMInterface(ABC, t.Generic[M, T]):
 # NOTE: All fields below are attributes that can be accessed by users.
 config_class: type[openllm.LLMConfig]
 """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
-
 bettertransformer: bool
 """Whether to load this LLM with BetterTransformer enabled.
 The order of loading is:
 > **Note** that if LoRA is enabled, bettertransformer will be disabled.
 """
-
- device: torch.device
+ device: "torch.device"
 """The device to be used for this LLM. If the implementation is 'pt', then it will be torch.device, else string."""
 # NOTE: The following will be populated by __init_subclass__, note that these should be immutable.
@@ -418,7 +415,6 @@
 """A reference to the bentomodel used for this LLM. Instead of accessing this directly, you should use the `_bentomodel` property instead."""
 __llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None
 """A reference to the cached LoRA adapter mapping."""
-
 __llm_supports_embeddings__: bool
 """A boolean to determine whether the model implements ``LLM.embeddings``."""
 __llm_supports_generate__: bool
@@ -616,10 +612,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
 config: openllm.LLMConfig
 """The config instance to use for this LLM. This will be created based on config_class and available when initialising the LLM."""
-
 quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None
 """Quantisation config for quantised model on the fly."""
-
 _model_id: str
 _runtime: t.Literal["ggml", "transformers"]
 _model_decls: TupleAny
@@ -631,8 +625,6 @@
 _quantize_method: t.Literal["int8", "int4", "gptq"] | None
 _serialisation_format: t.Literal["safetensors", "legacy"]
- tokenizer_cls: t.ClassVar[str | None] = None
-
 @staticmethod
 def _infer_implementation_from_name(name: str) -> tuple[LiteralRuntime, str]:
 if name.startswith("Flax"):
@@ -656,6 +648,29 @@
 raise RuntimeError("Missing required key 'config_class'. Make sure to define it within the LLM subclass.")
 _make_assignment_script(cls)(cls)
+ # fmt: off
+ @overload
+ def __getitem__(self, item: t.Literal["trust_remote_code"]) -> bool: ...
+ @overload
+ def __getitem__(self, item: t.Literal["implementation"]) -> LiteralRuntime: ...
+ @overload
+ def __getitem__(self, item: t.Literal["model"]) -> M | None: ...
+ @overload
+ def __getitem__(self, item: t.Literal["tokenizer"]) -> T | None: ...
+ @overload
+ def __getitem__(self, item: t.Literal["bentomodel"]) -> bentoml.Model | None: ...
+ @overload
+ def __getitem__(self, item: t.Literal["adapter_map"]) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None: ...
+ @overload
+ def __getitem__(self, item: t.Literal["supports_embeddings"]) -> bool: ...
+ @overload
+ def __getitem__(self, item: t.Literal["supports_generate"]) -> bool: ...
+ @overload
+ def __getitem__(self, item: t.Literal["supports_generate_one"]) -> bool: ...
+ @overload
+ def __getitem__(self, item: t.Literal["supports_generate_iterator"]) -> bool: ...
+ # fmt: on + def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any: if item is None: raise TypeError(f"{self} doesn't understand how to index None.") @@ -804,11 +819,13 @@ class LLM(LLMInterface[M, T], ReprMixin): your quantization_config or use the 'quantize' argument.""" ) if quantization_config is None and quantize is not None: - quantization_config, attrs = infer_quantisation_config(cls, quantize, **attrs) + quantization_config, attrs = openllm.infer_quantisation_config(cls, quantize, **attrs) if quantize == "gptq": # We will use safetensors for gptq serialisation = "safetensors" + elif cls.__llm_implementation__ == "vllm": + serialisation = "legacy" # NOTE: LoRA adapter setup if adapter_map and adapter_id: @@ -1516,7 +1533,7 @@ def Runner( default_implementation = llm_config["default_implementation"] if llm_config is not None else "pt" - implementation = first_not_none( + implementation: LiteralRuntime = first_not_none( implementation, default=EnvVarMixin(model_name, default_implementation)["framework_value"] ) @@ -1705,8 +1722,8 @@ def llm_runner_class(self: openllm.LLM[M, T]) -> type[LLMRunner]: "__repr__": ReprMixin.__repr__, "__repr_keys__": property(_wrapped_repr_keys), "__repr_args__": _wrapped_repr_args, - "supports_embeddings": self["supports-embeddings"], - "supports_hf_agent": self["supports-generate-one"], + "supports_embeddings": self["supports_embeddings"], + "supports_hf_agent": self["supports_generate_one"], "has_adapters": self._adapters_mapping is not None, } ), diff --git a/src/openllm/cli.py b/src/openllm/cli.py index 24b25663..22fe146b 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -1243,7 +1243,12 @@ def start_model( @output_option @quantize_option(click) @machine_option(click) -@click.option("--implementation", type=click.Choice(["pt", "tf", "flax", "vllm"]), default=None, hidden=True) +@click.option( + "--implementation", + type=click.Choice(["pt", "tf", "flax", "vllm"]), + default=None, + help="The implementation for saving this LLM.", +) @serialisation_option(click) def download_models_command( model: str, diff --git a/src/openllm/models/auto/__init__.py b/src/openllm/models/auto/__init__.py index ef8518da..95ee087e 100644 --- a/src/openllm/models/auto/__init__.py +++ b/src/openllm/models/auto/__init__.py @@ -11,21 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
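For reference, the new `_generation` and `_quantisation` entries rely on the `LazyModule` machinery that already backs `openllm/__init__.py`: names declared in `_import_structure` are only materialised on first attribute access, while the `t.TYPE_CHECKING` imports keep static analysers happy. Below is a minimal sketch of that mechanism, assuming a stripped-down `LazyModule`; the real one in `openllm.utils` also handles `__dir__`, module specs, and error reporting.

```python
# Minimal sketch of the lazy-export pattern behind _import_structure.
# This LazyModule is a simplified stand-in for openllm.utils.LazyModule.
from __future__ import annotations

import importlib
import types
import typing as t

class LazyModule(types.ModuleType):
    def __init__(self, name: str, import_structure: dict[str, list[str]]):
        super().__init__(name)
        # Map every exported name back to the submodule that defines it.
        self._name_to_submodule = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}

    def __getattr__(self, name: str) -> t.Any:
        if name not in self._name_to_submodule:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        # The defining submodule (and its heavy dependencies) is imported on first access.
        submodule = importlib.import_module(f".{self._name_to_submodule[name]}", self.__name__)
        return getattr(submodule, name)
```

With `"_generation": ["StopSequenceCriteria", "StopOnTokens"]` registered, `openllm.StopOnTokens` triggers `import openllm._generation` only when first touched, which is what lets the model implementations further down replace their local `from ..._generation import ...` statements with `openllm.StopOnTokens` and `openllm.StopSequenceCriteria`.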
- -"""This module is derived from HuggingFace's AutoConfig, AutoModel, etc.""" - from __future__ import annotations +import sys import typing as t - import openllm - from ...utils import LazyModule from ...utils import is_flax_available from ...utils import is_tf_available from ...utils import is_torch_available from ...utils import is_vllm_available - - _import_structure: dict[str, list[str]] = { "configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"], "modeling_auto": ["MODEL_MAPPING_NAMES"], @@ -33,40 +27,22 @@ _import_structure: dict[str, list[str]] = { "modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"], "modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"], } - try: - if not is_torch_available(): - raise openllm.exceptions.MissingDependencyError -except openllm.exceptions.MissingDependencyError: - pass -else: - _import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"]) - + if not is_torch_available(): raise openllm.exceptions.MissingDependencyError +except openllm.exceptions.MissingDependencyError: pass +else: _import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"]) try: - if not is_vllm_available(): - raise openllm.exceptions.MissingDependencyError -except openllm.exceptions.MissingDependencyError: - pass -else: - _import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"]) - - + if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError +except openllm.exceptions.MissingDependencyError: pass +else: _import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"]) try: - if not is_flax_available(): - raise openllm.exceptions.MissingDependencyError -except openllm.exceptions.MissingDependencyError: - pass -else: - _import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"]) - + if not is_flax_available(): raise openllm.exceptions.MissingDependencyError +except openllm.exceptions.MissingDependencyError: pass +else: _import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"]) try: - if not is_tf_available(): - raise openllm.exceptions.MissingDependencyError -except openllm.exceptions.MissingDependencyError: - pass -else: - _import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"]) - + if not is_tf_available(): raise openllm.exceptions.MissingDependencyError +except openllm.exceptions.MissingDependencyError: pass +else: _import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"]) if t.TYPE_CHECKING: from .configuration_auto import CONFIG_MAPPING as CONFIG_MAPPING from .configuration_auto import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES @@ -75,43 +51,20 @@ if t.TYPE_CHECKING: from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES from .modeling_vllm_auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES - try: - if not is_torch_available(): - raise openllm.exceptions.MissingDependencyError - except openllm.exceptions.MissingDependencyError: - pass - else: - from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING - from .modeling_auto import AutoLLM as AutoLLM - + if not is_torch_available(): raise openllm.exceptions.MissingDependencyError + except openllm.exceptions.MissingDependencyError: pass + else: from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM try: - if not is_vllm_available(): - raise openllm.exceptions.MissingDependencyError - except 
openllm.exceptions.MissingDependencyError: - pass - else: - from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING - from .modeling_vllm_auto import AutoVLLM as AutoVLLM - + if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError + except openllm.exceptions.MissingDependencyError: pass + else: from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM try: - if not is_flax_available(): - raise openllm.exceptions.MissingDependencyError - except openllm.exceptions.MissingDependencyError: - pass - else: - from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING - from .modeling_flax_auto import AutoFlaxLLM as AutoFlaxLLM - + if not is_flax_available(): raise openllm.exceptions.MissingDependencyError + except openllm.exceptions.MissingDependencyError: pass + else: from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM try: - if not is_tf_available(): - raise openllm.exceptions.MissingDependencyError - except openllm.exceptions.MissingDependencyError: - pass - else: - from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING - from .modeling_tf_auto import AutoTFLLM as AutoTFLLM -else: - import sys - - sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + if not is_tf_available(): raise openllm.exceptions.MissingDependencyError + except openllm.exceptions.MissingDependencyError: pass + else: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM +else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/openllm/models/auto/configuration_auto.py b/src/openllm/models/auto/configuration_auto.py index dc54fb87..f30edaf3 100644 --- a/src/openllm/models/auto/configuration_auto.py +++ b/src/openllm/models/auto/configuration_auto.py @@ -11,33 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
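All four condensed try/except blocks above follow one gating idiom: probe whether an optional backend is importable, and register its exports in `_import_structure` only when the probe succeeds. Roughly, and assuming the `is_*_available` helpers in `openllm.utils` reduce to an importability check:

```python
# Rough shape of the optional-backend gating used above. ImportError stands in
# for openllm.exceptions.MissingDependencyError, and the availability helper is
# assumed to reduce to an importlib probe.
import importlib.util

def is_torch_available() -> bool:
    return importlib.util.find_spec("torch") is not None

_import_structure: dict[str, list[str]] = {"modeling_auto": ["MODEL_MAPPING_NAMES"]}

try:
    if not is_torch_available():
        raise ImportError
except ImportError:
    pass  # torch is missing, so the PyTorch-only exports are never registered
else:
    _import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])
```

The raise-then-catch shape looks roundabout, but it keeps every backend block structurally identical, which is what allows this patch to collapse each one onto single lines.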
- from __future__ import annotations import typing as t from collections import OrderedDict - import inflection - import openllm - from ...utils import ReprMixin - - if t.TYPE_CHECKING: import types - from collections import _odict_items - from collections import _odict_keys - from collections import _odict_values - + from collections import _odict_items, _odict_keys, _odict_values ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]] - ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]] ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]] ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]] else: ConfigKeysView = ConfigValuesView = ConfigItemsView = t.Any ConfigOrderedDict = OrderedDict - # NOTE: This is the entrypoint when adding new model config CONFIG_MAPPING_NAMES = OrderedDict( [ @@ -54,68 +43,34 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("baichuan", "BaichuanConfig"), ] ) - - class _LazyConfigMapping(ConfigOrderedDict, ReprMixin): def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]): self._mapping = mapping self._extra_content: dict[str, t.Any] = {} self._modules: dict[str, types.ModuleType] = {} - def __getitem__(self, key: str) -> t.Any: - if key in self._extra_content: - return self._extra_content[key] + if key in self._extra_content: return self._extra_content[key] if key not in self._mapping: - if inflection.underscore(key) in self._mapping: - return self.__getitem__(inflection.underscore(key)) + if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key)) raise KeyError(key) - value = self._mapping[key] - module_name = inflection.underscore(key) - if module_name not in self._modules: - self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module - if hasattr(self._modules[module_name], value): - return getattr(self._modules[module_name], value) - - # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the - # object at the top level. + value, module_name = self._mapping[key], inflection.underscore(key) + if module_name not in self._modules: self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module + if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value) + # Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level. 
return getattr(openllm, value)
-
 @property
- def __repr_keys__(self) -> set[str]:
- return set(self._mapping.keys())
-
- def __repr__(self) -> str:
- return ReprMixin.__repr__(self)
-
- def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]:
- yield from self._mapping.items()
-
- def keys(self):
- return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys()))
-
- def values(self):
- return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
-
- def items(self):
- return t.cast(
- ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
- )
-
- def __iter__(self):
- return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
-
- def __contains__(self, item: t.Any):
- return item in self._mapping or item in self._extra_content
-
+ def __repr_keys__(self) -> set[str]: return set(self._mapping.keys())
+ def __repr__(self) -> str: return ReprMixin.__repr__(self)
+ def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]: yield from self._mapping.items()
+ def keys(self): return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys()))
+ def values(self): return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
+ def items(self): return t.cast(ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()))
+ def __iter__(self): return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
+ def __contains__(self, item: t.Any): return item in self._mapping or item in self._extra_content
 def register(self, key: str, value: t.Any):
- """Register a new configuration in this mapping."""
- if key in self._mapping.keys():
- raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
+ if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by an OpenLLM config, pick another name.")
 self._extra_content[key] = value
-
-
 CONFIG_MAPPING: dict[str, type[openllm.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
-
 # The aliases below handle the special names produced when `underscore` is applied
 # to a name directly without first converting its camelcase form.
 CONFIG_NAME_ALIASES: dict[str, str] = {
@@ -125,33 +80,16 @@
 "gpt_neo_x": "gpt_neox",
 "lla_ma": "llama",
 }
-
-
 class AutoConfig:
- def __init__(self, *_: t.Any, **__: t.Any):
- """This metaclass should be initialised via `AutoConfig.for_model`."""
- raise EnvironmentError(
- "Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead."
- )
-
+ def __init__(self, *_: t.Any, **__: t.Any): raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.")
 @classmethod
 def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig:
 model_name = inflection.underscore(model_name)
- if model_name in CONFIG_MAPPING:
- return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
- raise ValueError(
- f"Unrecognized configuration class for {model_name}. "
- f"Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
- )
-
+ if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
+ raise ValueError(f"Unrecognized configuration class for {model_name}. 
Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.") @classmethod def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]: model_name = inflection.underscore(name) - if model_name in CONFIG_NAME_ALIASES: - model_name = CONFIG_NAME_ALIASES[model_name] - if model_name in CONFIG_MAPPING: - return CONFIG_MAPPING[model_name] - raise ValueError( - f"Unrecognized configuration class for {model_name}. " - f"Model name should be one of {', '.join(CONFIG_MAPPING.keys())}." - ) + if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name] + if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name] + raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.") diff --git a/src/openllm/models/auto/factory.py b/src/openllm/models/auto/factory.py index a9fd4aa9..d17c33be 100644 --- a/src/openllm/models/auto/factory.py +++ b/src/openllm/models/auto/factory.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations import importlib import inspect @@ -19,31 +18,18 @@ import logging import sys import typing as t from collections import OrderedDict - import inflection - import openllm - from .configuration_auto import AutoConfig from ...utils import ReprMixin - - # NOTE: We need to do this so that overload can register # correct overloads to typing registry -if sys.version_info[:2] >= (3, 11): - from typing import overload -else: - from typing_extensions import overload - - +if sys.version_info[:2] >= (3, 11): from typing import overload +else: from typing_extensions import overload if t.TYPE_CHECKING: import types - from collections import _odict_items - from collections import _odict_keys - from collections import _odict_values - from ..._llm import LLMRunner - + from collections import _odict_items, _odict_keys, _odict_values ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] @@ -51,58 +37,18 @@ if t.TYPE_CHECKING: else: ConfigModelKeysView = ConfigModelValuesView = ConfigModelItemsView = t.Any ConfigModelOrderedDict = OrderedDict - logger = logging.getLogger(__name__) - - class BaseAutoLLMClass: _model_mapping: _LazyAutoMapping - - def __init__(self, *args: t.Any, **attrs: t.Any): # noqa - raise EnvironmentError( - f"Cannot instantiate {self.__class__.__name__} directly. " - "Please use '{self.__class__.__name__}.Runner(model_name)' instead." - ) - + def __init__(self, *args: t.Any, **attrs: t.Any): raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.") @overload @classmethod - def for_model( - cls, - model: str, - model_id: str | None = None, - model_version: str | None = None, - return_runner_kwargs: t.Literal[False] = False, - llm_config: openllm.LLMConfig | None = ..., - ensure_available: t.Literal[False, True] = ..., - **attrs: t.Any, - ) -> openllm.LLM[t.Any, t.Any]: - ... 
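The overload stubs being collapsed in this hunk are what give `for_model` a return type that tracks the `return_runner_kwargs` flag. In isolation the trick looks like the sketch below; the names are hypothetical, only the mechanism matches:

```python
# Standalone illustration of literal-dependent return types, the trick the
# for_model overloads use. All names here are hypothetical.
from __future__ import annotations

import typing as t

@t.overload
def build(name: str, with_kwargs: t.Literal[False] = ...) -> str: ...
@t.overload
def build(name: str, with_kwargs: t.Literal[True] = ...) -> tuple[str, dict[str, t.Any]]: ...
def build(name: str, with_kwargs: bool = False) -> str | tuple[str, dict[str, t.Any]]:
    # Only this implementation runs; the stubs exist for type checkers, which
    # narrow the return type from the literal value passed at the call site.
    return (name, {}) if with_kwargs else name
```

This is also the reason for the version-gated import above: only Python 3.11's `typing.overload` records its stubs in a registry (queryable via `typing.get_overloads`), so older interpreters fall back to `typing_extensions.overload`.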
- + def for_model(cls, model: str, model_id: str | None = None, model_version: str | None = None, return_runner_kwargs: t.Literal[False] = False, llm_config: openllm.LLMConfig | None = ..., ensure_available: t.Literal[False, True] = ..., **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]: ... @overload @classmethod - def for_model( - cls, - model: str, - model_id: str | None = None, - model_version: str | None = None, - return_runner_kwargs: t.Literal[True] = ..., - llm_config: openllm.LLMConfig | None = ..., - ensure_available: t.Literal[False, True] = ..., - **attrs: t.Any, - ) -> tuple[openllm.LLM[t.Any, t.Any], dict[str, t.Any]]: - ... - + def for_model(cls, model: str, model_id: str | None = None, model_version: str | None = None, return_runner_kwargs: t.Literal[True] = ..., llm_config: openllm.LLMConfig | None = ..., ensure_available: t.Literal[False, True] = ..., **attrs: t.Any) -> tuple[openllm.LLM[t.Any, t.Any], dict[str, t.Any]]: ... @classmethod - def for_model( - cls, - model: str, - model_id: str | None = None, - model_version: str | None = None, - return_runner_kwargs: bool = False, - llm_config: openllm.LLMConfig | None = None, - ensure_available: bool = False, - **attrs: t.Any, - ) -> openllm.LLM[t.Any, t.Any] | tuple[openllm.LLM[t.Any, t.Any], dict[str, t.Any]]: + def for_model(cls, model: str, model_id: str | None = None, model_version: str | None = None, return_runner_kwargs: bool = False, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any] | tuple[openllm.LLM[t.Any, t.Any], dict[str, t.Any]]: """The lower level API for creating a LLM instance. ```python @@ -128,16 +74,10 @@ class BaseAutoLLMClass: if type(llm_config) in cls._model_mapping.keys(): model_class = cls._model_mapping[type(llm_config)] llm = model_class.from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs) - if ensure_available: - llm.ensure_model_id_exists() - if not return_runner_kwargs: - return llm + if ensure_available: llm.ensure_model_id_exists() + if not return_runner_kwargs: return llm return llm, to_runner_attrs - raise ValueError( - f"Unrecognized configuration class {llm_config.__class__} for this kind of AutoLLM: {cls.__name__}.\n" - f"LLM type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." - ) - + raise ValueError(f"Unrecognized configuration class {llm_config.__class__} for this kind of AutoLLM: {cls.__name__}.\nLLM type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}.") @classmethod def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner: """Create a LLM Runner for the given model name. @@ -152,7 +92,6 @@ class BaseAutoLLMClass: """ llm, runner_attrs = cls.for_model(model, model_id, return_runner_kwargs=True, **attrs) return llm.to_runner(**runner_attrs) - @classmethod def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]): """Register a new model for this class. @@ -161,142 +100,60 @@ class BaseAutoLLMClass: config_class: The configuration corresponding to the model to register. llm_class: The runnable to register. """ - if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class: - raise ValueError( - "The model class you are passing has a `config_class` attribute that is not consistent with the " - f"config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix " - "one of those so they match!" 
- )
+ if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class: raise ValueError(f"The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}). Fix one of those so they match!")
 cls._model_mapping.register(config_class, llm_class)
-
-
 def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
- if attr is None:
- return
- if isinstance(attr, tuple):
- return tuple(getattribute_from_module(module, a) for a in attr)
- if hasattr(module, attr):
- return getattr(module, attr)
+ if attr is None: return
+ if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr)
+ if hasattr(module, attr): return getattr(module, attr)
 # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
 # object at the top level.
 openllm_module = importlib.import_module("openllm")
- if module != openllm_module:
- try:
- return getattribute_from_module(openllm_module, attr)
- except ValueError:
- raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
- else:
- raise ValueError(f"Could not find {attr} in {openllm_module}!")
-
-
+ if module != openllm_module:
+ try: return getattribute_from_module(openllm_module, attr)
+ except ValueError: raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
+ else: raise ValueError(f"Could not find {attr} in {openllm_module}!")
 class _LazyAutoMapping(ConfigModelOrderedDict, ReprMixin):
 """Based on transformers.models.auto.configuration_auto._LazyAutoMapping.

 The values(), keys(), and items() methods of this OrderedDict return lists directly,
 so you don't have to call list(mapping.values()) to get the list of values.
 """
- def __init__(
- self,
- config_mapping: OrderedDict[t.LiteralString, t.LiteralString],
- model_mapping: OrderedDict[t.LiteralString, t.LiteralString],
- ):
+ def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]):
 self._config_mapping = config_mapping
 self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
 self._model_mapping = model_mapping
 self._extra_content: dict[t.Any, t.Any] = {}
 self._modules: dict[str, types.ModuleType] = {}
- def __len__(self):
- common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
- return len(common_keys) + len(self._extra_content)
+ def __len__(self): return len(set(self._config_mapping.keys()).intersection(self._model_mapping.keys())) + len(self._extra_content)
 def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]:
- if key in self._extra_content:
- return self._extra_content[key]
+ if key in self._extra_content: return self._extra_content[key]
 model_type = self._reverse_config_mapping[key.__name__]
- if model_type in self._model_mapping:
- model_name = self._model_mapping[model_type]
- return self._load_attr_from_module(model_type, model_name)
+ if model_type in self._model_mapping: return self._load_attr_from_module(model_type, self._model_mapping[model_type])
 # Maybe there were several model types associated with this config.
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__] for mtype in model_types: - if mtype in self._model_mapping: - model_name = self._model_mapping[mtype] - return self._load_attr_from_module(mtype, model_name) + if mtype in self._model_mapping: return self._load_attr_from_module(mtype, self._model_mapping[mtype]) raise KeyError(key) - def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any: module_name = inflection.underscore(model_type) - if module_name not in self._modules: - self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models") + if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models") return getattribute_from_module(self._modules[module_name], attr) - - def keys(self): - mapping_keys = [ - self._load_attr_from_module(key, name) - for key, name in self._config_mapping.items() - if key in self._model_mapping.keys() - ] - return t.cast(ConfigModelKeysView, mapping_keys + list(self._extra_content.keys())) - + def keys(self): return t.cast(ConfigModelKeysView, [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys())) @property - def __repr_keys__(self) -> set[str]: - return set(self._config_mapping.keys()) - - def __repr__(self) -> str: - return ReprMixin.__repr__(self) - - def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]: - yield from ( - (key, (value, self._model_mapping[key])) - for key, value in self._config_mapping.items() - if key in self._model_mapping - ) - - def __bool__(self): - return bool(self.keys()) - - def values(self): - mapping_values = [ - self._load_attr_from_module(key, name) - for key, name in self._model_mapping.items() - if key in self._config_mapping.keys() - ] - return t.cast(ConfigModelValuesView, mapping_values + list(self._extra_content.values())) - - def items(self): - mapping_items = [ - ( - self._load_attr_from_module(key, self._config_mapping[key]), - self._load_attr_from_module(key, self._model_mapping[key]), - ) - for key in self._model_mapping.keys() - if key in self._config_mapping.keys() - ] - return t.cast(ConfigModelItemsView, mapping_items + list(self._extra_content.items())) - - def __iter__(self): - return iter(self.keys()) - + def __repr_keys__(self) -> set[str]: return set(self._config_mapping.keys()) + def __repr__(self) -> str: return ReprMixin.__repr__(self) + def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]: yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping) + def __bool__(self): return bool(self.keys()) + def values(self): return t.cast(ConfigModelValuesView, [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(self._extra_content.values())) + def items(self): return t.cast(ConfigModelItemsView, [(self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items())) + def __iter__(self): return iter(self.keys()) def __contains__(self, item: t.Any): - if item in self._extra_content: - return True - if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: - return False - model_type = 
self._reverse_config_mapping[item.__name__]
- return model_type in self._model_mapping
-
+ if item in self._extra_content: return True
+ if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: return False
+ return self._reverse_config_mapping[item.__name__] in self._model_mapping
 def register(self, key: t.Any, value: t.Any):
- """Register a new model in this mapping."""
 if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
- model_type = self._reverse_config_mapping[key.__name__]
- if model_type in self._model_mapping.keys():
- raise ValueError(f"'{key}' is already used by a OpenLLM model.")
-
+ if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by an OpenLLM model.")
 self._extra_content[key] = value
-
-
 __all__ = ["BaseAutoLLMClass", "_LazyAutoMapping"]
diff --git a/src/openllm/models/auto/modeling_auto.py b/src/openllm/models/auto/modeling_auto.py
index 58c9e910..87b6261d 100644
--- a/src/openllm/models/auto/modeling_auto.py
+++ b/src/openllm/models/auto/modeling_auto.py
@@ -14,12 +14,9 @@
 from __future__ import annotations
 from collections import OrderedDict
-
 from .configuration_auto import CONFIG_MAPPING_NAMES
 from .factory import BaseAutoLLMClass
 from .factory import _LazyAutoMapping
-
-
 MODEL_MAPPING_NAMES = OrderedDict(
 [
 ("chatglm", "ChatGLM"),
@@ -35,9 +32,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
 ("baichuan", "Baichuan"),
 ]
 )
-
 MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
-
-
 class AutoLLM(BaseAutoLLMClass):
 _model_mapping = MODEL_MAPPING
diff --git a/src/openllm/models/auto/modeling_flax_auto.py b/src/openllm/models/auto/modeling_flax_auto.py
index 242c6f54..18f489cf 100644
--- a/src/openllm/models/auto/modeling_flax_auto.py
+++ b/src/openllm/models/auto/modeling_flax_auto.py
@@ -11,24 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from __future__ import annotations
 from collections import OrderedDict
-
 from .configuration_auto import CONFIG_MAPPING_NAMES
 from .factory import BaseAutoLLMClass
 from .factory import _LazyAutoMapping
-
-
 MODEL_FLAX_MAPPING_NAMES = OrderedDict(
 [
 ("flan_t5", "FlaxFlanT5"),
 ("opt", "FlaxOPT"),
 ]
 )
-
 MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
-
-
 class AutoFlaxLLM(BaseAutoLLMClass):
 _model_mapping = MODEL_FLAX_MAPPING
diff --git a/src/openllm/models/auto/modeling_tf_auto.py b/src/openllm/models/auto/modeling_tf_auto.py
index df2df55a..fd90f56b 100644
--- a/src/openllm/models/auto/modeling_tf_auto.py
+++ b/src/openllm/models/auto/modeling_tf_auto.py
@@ -11,24 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
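For orientation, the name tables above are what make the auto classes lazy: `_LazyAutoMapping` inverts `CONFIG_MAPPING_NAMES` to turn a config class into a model-type string, then imports the concrete class from the matching `openllm.models.<model_type>` submodule only when a lookup needs it. A trimmed sketch of that resolution path, with caching and fallback handling omitted:

```python
# Trimmed sketch of the config-class -> model-class resolution performed by
# _LazyAutoMapping; caching and fallback handling are omitted.
import importlib
from collections import OrderedDict

CONFIG_MAPPING_NAMES = OrderedDict([("flan_t5", "FlanT5Config"), ("llama", "LlaMAConfig")])
MODEL_MAPPING_NAMES = OrderedDict([("flan_t5", "FlanT5"), ("llama", "LlaMA")])

def resolve_model_class(config_cls: type) -> type:
    reverse = {v: k for k, v in CONFIG_MAPPING_NAMES.items()}
    model_type = reverse[config_cls.__name__]  # e.g. LlaMAConfig -> "llama"
    # The heavy framework import only happens once a lookup actually needs it.
    module = importlib.import_module(f".{model_type}", "openllm.models")
    return getattr(module, MODEL_MAPPING_NAMES[model_type])
```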
- from __future__ import annotations from collections import OrderedDict - from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass from .factory import _LazyAutoMapping - - MODEL_TF_MAPPING_NAMES = OrderedDict( [ ("flan_t5", "TFFlanT5"), ("opt", "TFOPT"), ] ) - MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES) - - class AutoTFLLM(BaseAutoLLMClass): _model_mapping = MODEL_TF_MAPPING diff --git a/src/openllm/models/auto/modeling_vllm_auto.py b/src/openllm/models/auto/modeling_vllm_auto.py index 3c24e3d2..e5356d09 100644 --- a/src/openllm/models/auto/modeling_vllm_auto.py +++ b/src/openllm/models/auto/modeling_vllm_auto.py @@ -11,23 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import annotations from collections import OrderedDict - from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass from .factory import _LazyAutoMapping - - MODEL_VLLM_MAPPING_NAMES = OrderedDict( [ ("llama", "VLLMLlaMA"), ] ) - MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES) - - class AutoVLLM(BaseAutoLLMClass): _model_mapping = MODEL_VLLM_MAPPING diff --git a/src/openllm/models/falcon/modeling_falcon.py b/src/openllm/models/falcon/modeling_falcon.py index 342d99ae..40fc1c60 100644 --- a/src/openllm/models/falcon/modeling_falcon.py +++ b/src/openllm/models/falcon/modeling_falcon.py @@ -40,10 +40,9 @@ class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTraine eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device) with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env( eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True) def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]: - from ..._generation import StopSequenceCriteria max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device) src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([])) - stopping_criteria.append(StopSequenceCriteria(stop, self.tokenizer)) + stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer)) result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:]) # Inference API returns the stop sequence for stop_seq in stop: diff --git a/src/openllm/models/gpt_neox/modeling_gpt_neox.py b/src/openllm/models/gpt_neox/modeling_gpt_neox.py index bd38b0a5..4ad5c30e 100644 --- a/src/openllm/models/gpt_neox/modeling_gpt_neox.py +++ b/src/openllm/models/gpt_neox/modeling_gpt_neox.py @@ -39,6 +39,4 @@ class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNe if self.config.use_half_precision: model.half() return model def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - from ..._generation import StopOnTokens - 
generation_kwargs = {"do_sample": True, "generation_config": self.config.model_construct_env(**attrs).to_generation_config(), "pad_token_id": self.tokenizer.eos_token_id, "stopping_criteria": transformers.StoppingCriteriaList([StopOnTokens()])} - with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, **generation_kwargs)) + with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))) diff --git a/src/openllm/models/llama/modeling_llama.py b/src/openllm/models/llama/modeling_llama.py index 1ed4ec5b..968a7e35 100644 --- a/src/openllm/models/llama/modeling_llama.py +++ b/src/openllm/models/llama/modeling_llama.py @@ -37,9 +37,7 @@ class LlaMA(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaToke def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {} def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0] def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - from ..._generation import StopOnTokens - generation_kwargs = {"generation_config": self.config.model_construct_env(**attrs).to_generation_config(), "stopping_criteria": transformers.StoppingCriteriaList([StopOnTokens()])} - with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), **generation_kwargs), skip_special_tokens=True, clean_up_tokenization_spaces=True) + with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), generation_config=self.config.model_construct_env(**attrs).to_generation_config(), stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])), skip_special_tokens=True, clean_up_tokenization_spaces=True) def embeddings(self, prompts: list[str]) -> LLMEmbeddings: encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device) input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"] diff --git a/src/openllm/models/stablelm/modeling_stablelm.py b/src/openllm/models/stablelm/modeling_stablelm.py index 0f8384bb..8d578477 100644 --- a/src/openllm/models/stablelm/modeling_stablelm.py +++ b/src/openllm/models/stablelm/modeling_stablelm.py @@ -36,5 +36,4 @@ class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTN return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {} def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0] def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - from ..._generation import StopOnTokens - with torch.inference_mode(): return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, 
stopping_criteria=transformers.StoppingCriteriaList([StopOnTokens()]))[0], skip_special_tokens=True)] + with torch.inference_mode(): return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))[0], skip_special_tokens=True)] diff --git a/src/openllm/models/starcoder/modeling_starcoder.py b/src/openllm/models/starcoder/modeling_starcoder.py index 93da678a..1744cff3 100644 --- a/src/openllm/models/starcoder/modeling_starcoder.py +++ b/src/openllm/models/starcoder/modeling_starcoder.py @@ -55,10 +55,9 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers. # return (skip_special_tokens=False, clean_up_tokenization_spaces=False)) return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]: - from ..._generation import StopSequenceCriteria max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device) src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([])) - stopping_criteria.append(StopSequenceCriteria(stop, self.tokenizer)) + stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer)) result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:]) # Inference API returns the stop sequence for stop_seq in stop: diff --git a/src/openllm/serialisation/constants.py b/src/openllm/serialisation/constants.py index af1844f9..0c40a046 100644 --- a/src/openllm/serialisation/constants.py +++ b/src/openllm/serialisation/constants.py @@ -17,6 +17,8 @@ from __future__ import annotations FRAMEWORK_TO_AUTOCLASS_MAPPING = { "pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), + # NOTE: vllm will use PyTorch implementation of transformers for serialisation + "vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"), } diff --git a/src/openllm/serialisation/transformers.py b/src/openllm/serialisation/transformers.py index ee5e86f7..a07cceb2 100644 --- a/src/openllm/serialisation/transformers.py +++ b/src/openllm/serialisation/transformers.py @@ -329,6 +329,9 @@ def save_pretrained( save_function = first_not_none(save_function, default=torch.save) model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs) safe_serialization = safe_serialization or llm._serialisation_format == "safetensors" + if llm.__llm_implementation__ == "vllm": + # NOTE: disable safetensors for vllm + safe_serialization = False if llm._quantize_method == "gptq": if not is_autogptq_available(): raise OpenLLMException(