chore: export generation items for lazy loading

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
aarnphm-ec2-dev
2023-07-23 06:50:00 +00:00
parent e2cdd767ef
commit 4cd0784ee2
17 changed files with 135 additions and 389 deletions

View File

@@ -60,8 +60,10 @@ else:
_import_structure: dict[str, list[str]] = {
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"],
"_configuration": ["LLMConfig"],
"exceptions": [],
"_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput"],
"_generation": ["StopSequenceCriteria", "StopOnTokens"],
"_quantisation": ["infer_quantisation_config"],
"exceptions": [],
"utils": ["infer_auto_class"],
"models": [],
"client": [],
@@ -202,10 +204,13 @@ if t.TYPE_CHECKING:
# Specific types import
from ._configuration import LLMConfig as LLMConfig
from ._generation import StopOnTokens as StopOnTokens
from ._generation import StopSequenceCriteria as StopSequenceCriteria
from ._llm import LLM as LLM
from ._llm import LLMRunnable as LLMRunnable
from ._llm import LLMRunner as LLMRunner
from ._llm import Runner as Runner
from ._quantisation import infer_quantisation_config as infer_quantisation_config
from ._schema import EmbeddingsOutput as EmbeddingsOutput
from ._schema import GenerationInput as GenerationInput
from ._schema import GenerationOutput as GenerationOutput
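The export table together with the `if t.TYPE_CHECKING` block is the usual LazyModule setup: type checkers see real imports, while at runtime names resolve on first attribute access. A minimal sketch of the same idea using module-level __getattr__ (PEP 562) rather than OpenLLM's actual LazyModule helper:

import importlib
import typing as t

# Hypothetical stand-in for the export table above.
_import_structure: dict[str, list[str]] = {"_generation": ["StopSequenceCriteria", "StopOnTokens"]}
# Invert the table so each public name maps back to its owning submodule.
_name_to_module = {name: mod for mod, names in _import_structure.items() for name in names}

def __getattr__(name: str) -> t.Any:
    # The submodule is imported only on first access to one of its exports.
    if name in _name_to_module:
        return getattr(importlib.import_module("." + _name_to_module[name], __package__), name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")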

View File

@@ -41,7 +41,6 @@ from ._configuration import AdapterType
from ._configuration import FineTuneConfig
from ._configuration import _object_getattribute
from ._configuration import _setattr_class
from ._quantisation import infer_quantisation_config
from .exceptions import ForbiddenAttributeError
from .exceptions import GpuNotAvailableError
from .utils import DEBUG
@@ -378,7 +377,6 @@ class LLMInterface(ABC, t.Generic[M, T]):
# NOTE: All fields below are attributes that can be accessed by users.
config_class: type[openllm.LLMConfig]
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
bettertransformer: bool
"""Whether to load this LLM with FasterTransformer enabled. The order of loading is:
@@ -388,8 +386,7 @@ class LLMInterface(ABC, t.Generic[M, T]):
> **Note** that if LoRA is enabled, bettertransformer will be disabled.
"""
device: torch.device
device: "torch.device"
"""The device to be used for this LLM. If the implementation is 'pt', then it will be torch.device, else string."""
# NOTE: The following will be populated by __init_subclass__, note that these should be immutable.
@@ -418,7 +415,6 @@ class LLMInterface(ABC, t.Generic[M, T]):
"""A reference to the bentomodel used for this LLM. Instead of access this directly, you should use `_bentomodel` property instead."""
__llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None
"""A reference to the the cached LoRA adapter mapping."""
__llm_supports_embeddings__: bool
"""A boolean to determine whether models does implement ``LLM.embeddings``."""
__llm_supports_generate__: bool
@@ -616,10 +612,8 @@ class LLM(LLMInterface[M, T], ReprMixin):
config: openllm.LLMConfig
"""The config instance to use for this LLM. This will be created based on config_class and available
when initialising the LLM."""
quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None
"""Quantisation config for quantised model on the fly."""
_model_id: str
_runtime: t.Literal["ggml", "transformers"]
_model_decls: TupleAny
@@ -631,8 +625,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
_quantize_method: t.Literal["int8", "int4", "gptq"] | None
_serialisation_format: t.Literal["safetensors", "legacy"]
tokenizer_cls: t.ClassVar[str | None] = None
@staticmethod
def _infer_implementation_from_name(name: str) -> tuple[LiteralRuntime, str]:
if name.startswith("Flax"):
@@ -656,6 +648,29 @@ class LLM(LLMInterface[M, T], ReprMixin):
raise RuntimeError("Missing required key 'config_class'. Make sure to define it within the LLM subclass.")
_make_assignment_script(cls)(cls)
# fmt: off
@overload
def __getitem__(self, item: t.Literal["trust_remote_code"]) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["implementation"]) -> LiteralRuntime: ...
@overload
def __getitem__(self, item: t.Literal["model"]) -> M | None: ...
@overload
def __getitem__(self, item: t.Literal["tokenizer"]) -> T | None: ...
@overload
def __getitem__(self, item: t.Literal["bentomodel"]) -> bentoml.Model | None: ...
@overload
def __getitem__(self, item: t.Literal["adapter_map"]) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None: ...
@overload
def __getitem__(self, item: t.Literal["supports_embeddings"]) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["supports_generate"]) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["supports_generate_one"]) -> bool: ...
@overload
def __getitem__(self, item: t.Literal["supports_generate_iterator"]) -> bool: ...
# fmt: on
def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any:
if item is None:
raise TypeError(f"{self} doesn't understand how to index None.")
@@ -804,11 +819,13 @@ class LLM(LLMInterface[M, T], ReprMixin):
your quantization_config or use the 'quantize' argument."""
)
if quantization_config is None and quantize is not None:
quantization_config, attrs = infer_quantisation_config(cls, quantize, **attrs)
quantization_config, attrs = openllm.infer_quantisation_config(cls, quantize, **attrs)
if quantize == "gptq":
# We will use safetensors for gptq
serialisation = "safetensors"
elif cls.__llm_implementation__ == "vllm":
serialisation = "legacy"
# NOTE: LoRA adapter setup
if adapter_map and adapter_id:
@@ -1516,7 +1533,7 @@ def Runner(
default_implementation = llm_config["default_implementation"] if llm_config is not None else "pt"
implementation = first_not_none(
implementation: LiteralRuntime = first_not_none(
implementation, default=EnvVarMixin(model_name, default_implementation)["framework_value"]
)
@@ -1705,8 +1722,8 @@ def llm_runner_class(self: openllm.LLM[M, T]) -> type[LLMRunner]:
"__repr__": ReprMixin.__repr__,
"__repr_keys__": property(_wrapped_repr_keys),
"__repr_args__": _wrapped_repr_args,
"supports_embeddings": self["supports-embeddings"],
"supports_hf_agent": self["supports-generate-one"],
"supports_embeddings": self["supports_embeddings"],
"supports_hf_agent": self["supports_generate_one"],
"has_adapters": self._adapters_mapping is not None,
}
),

View File

@@ -1243,7 +1243,12 @@ def start_model(
@output_option
@quantize_option(click)
@machine_option(click)
@click.option("--implementation", type=click.Choice(["pt", "tf", "flax", "vllm"]), default=None, hidden=True)
@click.option(
"--implementation",
type=click.Choice(["pt", "tf", "flax", "vllm"]),
default=None,
help="The implementation for saving this LLM.",
)
@serialisation_option(click)
def download_models_command(
model: str,

View File

@@ -11,21 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module is derived from HuggingFace's AutoConfig, AutoModel, etc."""
from __future__ import annotations
import sys
import typing as t
import openllm
from ...utils import LazyModule
from ...utils import is_flax_available
from ...utils import is_tf_available
from ...utils import is_torch_available
from ...utils import is_vllm_available
_import_structure: dict[str, list[str]] = {
"configuration_auto": ["AutoConfig", "CONFIG_MAPPING", "CONFIG_MAPPING_NAMES"],
"modeling_auto": ["MODEL_MAPPING_NAMES"],
@@ -33,40 +27,22 @@ _import_structure: dict[str, list[str]] = {
"modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"],
"modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"],
}
try:
if not is_torch_available():
raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: _import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])
try:
if not is_vllm_available():
raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: _import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
try:
if not is_flax_available():
raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: _import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
try:
if not is_tf_available():
raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
_import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: _import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
if t.TYPE_CHECKING:
from .configuration_auto import CONFIG_MAPPING as CONFIG_MAPPING
from .configuration_auto import CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
@@ -75,43 +51,20 @@ if t.TYPE_CHECKING:
from .modeling_flax_auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
from .modeling_tf_auto import MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES
from .modeling_vllm_auto import MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
try:
if not is_torch_available():
raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING
from .modeling_auto import AutoLLM as AutoLLM
if not is_torch_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM
try:
if not is_vllm_available():
raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING
from .modeling_vllm_auto import AutoVLLM as AutoVLLM
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM
try:
if not is_flax_available():
raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING
from .modeling_flax_auto import AutoFlaxLLM as AutoFlaxLLM
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
try:
if not is_tf_available():
raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError:
pass
else:
from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING
from .modeling_tf_auto import AutoTFLLM as AutoTFLLM
else:
import sys
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
except openllm.exceptions.MissingDependencyError: pass
else: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
else: sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
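Each try/except gate appends framework-specific exports only when the dependency is importable, so the LazyModule never references them on installs without that framework. A hedged sketch of the pattern, with importlib.util.find_spec standing in for is_torch_available() and friends:

import importlib.util

def _is_available(package: str) -> bool:
    # Stand-in for openllm's is_torch_available()/is_vllm_available() helpers.
    return importlib.util.find_spec(package) is not None

_import_structure: dict[str, list[str]] = {"modeling_auto": ["MODEL_MAPPING_NAMES"]}
if _is_available("torch"):
    _import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])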

View File

@@ -11,33 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
from collections import OrderedDict
import inflection
import openllm
from ...utils import ReprMixin
if t.TYPE_CHECKING:
import types
from collections import _odict_items
from collections import _odict_keys
from collections import _odict_values
from collections import _odict_items, _odict_keys, _odict_values
ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]]
ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]]
ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]]
ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]]
else:
ConfigKeysView = ConfigValuesView = ConfigItemsView = t.Any
ConfigOrderedDict = OrderedDict
# NOTE: This is the entrypoint when adding a new model config
CONFIG_MAPPING_NAMES = OrderedDict(
[
@@ -54,68 +43,34 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("baichuan", "BaichuanConfig"),
]
)
class _LazyConfigMapping(ConfigOrderedDict, ReprMixin):
def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]):
self._mapping = mapping
self._extra_content: dict[str, t.Any] = {}
self._modules: dict[str, types.ModuleType] = {}
def __getitem__(self, key: str) -> t.Any:
if key in self._extra_content:
return self._extra_content[key]
if key in self._extra_content: return self._extra_content[key]
if key not in self._mapping:
if inflection.underscore(key) in self._mapping:
return self.__getitem__(inflection.underscore(key))
if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key))
raise KeyError(key)
value = self._mapping[key]
module_name = inflection.underscore(key)
if module_name not in self._modules:
self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module
if hasattr(self._modules[module_name], value):
return getattr(self._modules[module_name], value)
# Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the
# object at the top level.
value, module_name = self._mapping[key], inflection.underscore(key)
if module_name not in self._modules: self._modules[module_name] = openllm.utils.EnvVarMixin(module_name).module
if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value)
# Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level.
return getattr(openllm, value)
@property
def __repr_keys__(self) -> set[str]:
return set(self._mapping.keys())
def __repr__(self) -> str:
return ReprMixin.__repr__(self)
def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]:
yield from self._mapping.items()
def keys(self):
return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys()))
def values(self):
return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
def items(self):
return t.cast(
ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())
)
def __iter__(self):
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
def __contains__(self, item: t.Any):
return item in self._mapping or item in self._extra_content
def __repr_keys__(self) -> set[str]: return set(self._mapping.keys())
def __repr__(self) -> str: return ReprMixin.__repr__(self)
def __repr_args__(self) -> t.Generator[tuple[str, t.Any], t.Any, t.Any]: yield from self._mapping.items()
def keys(self): return t.cast(ConfigKeysView, list(self._mapping.keys()) + list(self._extra_content.keys()))
def values(self): return t.cast(ConfigValuesView, [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
def items(self): return t.cast(ConfigItemsView, [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()))
def __iter__(self): return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
def __contains__(self, item: t.Any): return item in self._mapping or item in self._extra_content
def register(self, key: str, value: t.Any):
"""Register a new configuration in this mapping."""
if key in self._mapping.keys():
raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by an OpenLLM config, pick another name.")
self._extra_content[key] = value
CONFIG_MAPPING: dict[str, type[openllm.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
# The aliases below handle special cases where ``underscore`` is applied to the name
# directly without first normalising the camelcase.
CONFIG_NAME_ALIASES: dict[str, str] = {
@@ -125,33 +80,16 @@ CONFIG_NAME_ALIASES: dict[str, str] = {
"gpt_neo_x": "gpt_neox",
"lla_ma": "llama",
}
class AutoConfig:
def __init__(self, *_: t.Any, **__: t.Any):
"""This metaclass should be initialised via `AutoConfig.for_model`."""
raise EnvironmentError(
"Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead."
)
def __init__(self, *_: t.Any, **__: t.Any): raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.")
@classmethod
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig:
model_name = inflection.underscore(model_name)
if model_name in CONFIG_MAPPING:
return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
raise ValueError(
f"Unrecognized configuration class for {model_name}. "
f"Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
)
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
@classmethod
def infer_class_from_name(cls, name: str) -> type[openllm.LLMConfig]:
model_name = inflection.underscore(name)
if model_name in CONFIG_NAME_ALIASES:
model_name = CONFIG_NAME_ALIASES[model_name]
if model_name in CONFIG_MAPPING:
return CONFIG_MAPPING[model_name]
raise ValueError(
f"Unrecognized configuration class for {model_name}. "
f"Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
)
if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name]
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name]
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import importlib
import inspect
@@ -19,31 +18,18 @@ import logging
import sys
import typing as t
from collections import OrderedDict
import inflection
import openllm
from .configuration_auto import AutoConfig
from ...utils import ReprMixin
# NOTE: We need this so that `overload` registers the
# overloads correctly with the typing registry
if sys.version_info[:2] >= (3, 11):
from typing import overload
else:
from typing_extensions import overload
if sys.version_info[:2] >= (3, 11): from typing import overload
else: from typing_extensions import overload
if t.TYPE_CHECKING:
import types
from collections import _odict_items
from collections import _odict_keys
from collections import _odict_values
from ..._llm import LLMRunner
from collections import _odict_items, _odict_keys, _odict_values
ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]]
@@ -51,58 +37,18 @@ if t.TYPE_CHECKING:
else:
ConfigModelKeysView = ConfigModelValuesView = ConfigModelItemsView = t.Any
ConfigModelOrderedDict = OrderedDict
logger = logging.getLogger(__name__)
class BaseAutoLLMClass:
_model_mapping: _LazyAutoMapping
def __init__(self, *args: t.Any, **attrs: t.Any): # noqa
raise EnvironmentError(
f"Cannot instantiate {self.__class__.__name__} directly. "
"Please use '{self.__class__.__name__}.Runner(model_name)' instead."
)
def __init__(self, *args: t.Any, **attrs: t.Any): raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.")
@overload
@classmethod
def for_model(
cls,
model: str,
model_id: str | None = None,
model_version: str | None = None,
return_runner_kwargs: t.Literal[False] = False,
llm_config: openllm.LLMConfig | None = ...,
ensure_available: t.Literal[False, True] = ...,
**attrs: t.Any,
) -> openllm.LLM[t.Any, t.Any]:
...
def for_model(cls, model: str, model_id: str | None = None, model_version: str | None = None, return_runner_kwargs: t.Literal[False] = False, llm_config: openllm.LLMConfig | None = ..., ensure_available: t.Literal[False, True] = ..., **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]: ...
@overload
@classmethod
def for_model(
cls,
model: str,
model_id: str | None = None,
model_version: str | None = None,
return_runner_kwargs: t.Literal[True] = ...,
llm_config: openllm.LLMConfig | None = ...,
ensure_available: t.Literal[False, True] = ...,
**attrs: t.Any,
) -> tuple[openllm.LLM[t.Any, t.Any], dict[str, t.Any]]:
...
def for_model(cls, model: str, model_id: str | None = None, model_version: str | None = None, return_runner_kwargs: t.Literal[True] = ..., llm_config: openllm.LLMConfig | None = ..., ensure_available: t.Literal[False, True] = ..., **attrs: t.Any) -> tuple[openllm.LLM[t.Any, t.Any], dict[str, t.Any]]: ...
@classmethod
def for_model(
cls,
model: str,
model_id: str | None = None,
model_version: str | None = None,
return_runner_kwargs: bool = False,
llm_config: openllm.LLMConfig | None = None,
ensure_available: bool = False,
**attrs: t.Any,
) -> openllm.LLM[t.Any, t.Any] | tuple[openllm.LLM[t.Any, t.Any], dict[str, t.Any]]:
def for_model(cls, model: str, model_id: str | None = None, model_version: str | None = None, return_runner_kwargs: bool = False, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any] | tuple[openllm.LLM[t.Any, t.Any], dict[str, t.Any]]:
"""The lower level API for creating a LLM instance.
```python
@@ -128,16 +74,10 @@ class BaseAutoLLMClass:
if type(llm_config) in cls._model_mapping.keys():
model_class = cls._model_mapping[type(llm_config)]
llm = model_class.from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs)
if ensure_available:
llm.ensure_model_id_exists()
if not return_runner_kwargs:
return llm
if ensure_available: llm.ensure_model_id_exists()
if not return_runner_kwargs: return llm
return llm, to_runner_attrs
raise ValueError(
f"Unrecognized configuration class {llm_config.__class__} for this kind of AutoLLM: {cls.__name__}.\n"
f"LLM type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
)
raise ValueError(f"Unrecognized configuration class {llm_config.__class__} for this kind of AutoLLM: {cls.__name__}.\nLLM type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}.")
@classmethod
def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner:
"""Create a LLM Runner for the given model name.
@@ -152,7 +92,6 @@ class BaseAutoLLMClass:
"""
llm, runner_attrs = cls.for_model(model, model_id, return_runner_kwargs=True, **attrs)
return llm.to_runner(**runner_attrs)
@classmethod
def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]):
"""Register a new model for this class.
@@ -161,142 +100,60 @@ class BaseAutoLLMClass:
config_class: The configuration corresponding to the model to register.
llm_class: The runnable to register.
"""
if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class:
raise ValueError(
"The model class you are passing has a `config_class` attribute that is not consistent with the "
f"config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix "
"one of those so they match!"
)
if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class: raise ValueError("The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!")
cls._model_mapping.register(config_class, llm_class)
def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
if attr is None:
return
if isinstance(attr, tuple):
return tuple(getattribute_from_module(module, a) for a in attr)
if hasattr(module, attr):
return getattr(module, attr)
if attr is None: return
if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr)
if hasattr(module, attr): return getattr(module, attr)
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
# object at the top level.
openllm_module = importlib.import_module("openllm")
if module != openllm_module:
try:
return getattribute_from_module(openllm_module, attr)
except ValueError:
raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
else:
raise ValueError(f"Could not find {attr} in {openllm_module}!")
try: return getattribute_from_module(openllm_module, attr)
except ValueError: raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
else: raise ValueError(f"Could not find {attr} in {openllm_module}!")
class _LazyAutoMapping(ConfigModelOrderedDict, ReprMixin):
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
This OrderedDict's values() and keys() return lists directly, so you don't
have to call list(mapping.values()) to get the list of values.
"""
def __init__(
self,
config_mapping: OrderedDict[t.LiteralString, t.LiteralString],
model_mapping: OrderedDict[t.LiteralString, t.LiteralString],
):
def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]):
self._config_mapping = config_mapping
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
self._model_mapping = model_mapping
self._extra_content: dict[t.Any, t.Any] = {}
self._modules: dict[str, types.ModuleType] = {}
def __len__(self):
common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys())
return len(common_keys) + len(self._extra_content)
def __len__(self): return len(set(self._config_mapping.keys()).intersection(self._model_mapping.keys())) + len(self._extra_content)
def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]:
if key in self._extra_content:
return self._extra_content[key]
if key in self._extra_content: return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__]
if model_type in self._model_mapping:
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)
if model_type in self._model_mapping: return self._load_attr_from_module(model_type, self._model_mapping[model_type])
# Maybe there were several model types associated with this config.
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
for mtype in model_types:
if mtype in self._model_mapping:
model_name = self._model_mapping[mtype]
return self._load_attr_from_module(mtype, model_name)
if mtype in self._model_mapping: return self._load_attr_from_module(mtype, self._model_mapping[mtype])
raise KeyError(key)
def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any:
module_name = inflection.underscore(model_type)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models")
if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models")
return getattribute_from_module(self._modules[module_name], attr)
def keys(self):
mapping_keys = [
self._load_attr_from_module(key, name)
for key, name in self._config_mapping.items()
if key in self._model_mapping.keys()
]
return t.cast(ConfigModelKeysView, mapping_keys + list(self._extra_content.keys()))
def keys(self): return t.cast(ConfigModelKeysView, [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys()))
@property
def __repr_keys__(self) -> set[str]:
return set(self._config_mapping.keys())
def __repr__(self) -> str:
return ReprMixin.__repr__(self)
def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]:
yield from (
(key, (value, self._model_mapping[key]))
for key, value in self._config_mapping.items()
if key in self._model_mapping
)
def __bool__(self):
return bool(self.keys())
def values(self):
mapping_values = [
self._load_attr_from_module(key, name)
for key, name in self._model_mapping.items()
if key in self._config_mapping.keys()
]
return t.cast(ConfigModelValuesView, mapping_values + list(self._extra_content.values()))
def items(self):
mapping_items = [
(
self._load_attr_from_module(key, self._config_mapping[key]),
self._load_attr_from_module(key, self._model_mapping[key]),
)
for key in self._model_mapping.keys()
if key in self._config_mapping.keys()
]
return t.cast(ConfigModelItemsView, mapping_items + list(self._extra_content.items()))
def __iter__(self):
return iter(self.keys())
def __repr_keys__(self) -> set[str]: return set(self._config_mapping.keys())
def __repr__(self) -> str: return ReprMixin.__repr__(self)
def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]: yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping)
def __bool__(self): return bool(self.keys())
def values(self): return t.cast(ConfigModelValuesView, [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(self._extra_content.values()))
def items(self): return t.cast(ConfigModelItemsView, [(self._load_attr_from_module(key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items()))
def __iter__(self): return iter(self.keys())
def __contains__(self, item: t.Any):
if item in self._extra_content:
return True
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
return False
model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping
if item in self._extra_content: return True
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: return False
return self._reverse_config_mapping[item.__name__] in self._model_mapping
def register(self, key: t.Any, value: t.Any):
"""Register a new model in this mapping."""
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
model_type = self._reverse_config_mapping[key.__name__]
if model_type in self._model_mapping.keys():
raise ValueError(f"'{key}' is already used by a OpenLLM model.")
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by an OpenLLM model.")
self._extra_content[key] = value
__all__ = ["BaseAutoLLMClass", "_LazyAutoMapping"]

View File

@@ -14,12 +14,9 @@
from __future__ import annotations
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
MODEL_MAPPING_NAMES = OrderedDict(
[
("chatglm", "ChatGLM"),
@@ -35,9 +32,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("baichuan", "Baichuan"),
]
)
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
class AutoLLM(BaseAutoLLMClass):
_model_mapping = MODEL_MAPPING

View File

@@ -11,24 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
MODEL_FLAX_MAPPING_NAMES = OrderedDict(
[
("flan_t5", "FlaxFlanT5"),
("opt", "FlaxOPT"),
]
)
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
class AutoFlaxLLM(BaseAutoLLMClass):
_model_mapping = MODEL_FLAX_MAPPING

View File

@@ -11,24 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
MODEL_TF_MAPPING_NAMES = OrderedDict(
[
("flan_t5", "TFFlanT5"),
("opt", "TFOPT"),
]
)
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
class AutoTFLLM(BaseAutoLLMClass):
_model_mapping = MODEL_TF_MAPPING

View File

@@ -11,23 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
from .configuration_auto import CONFIG_MAPPING_NAMES
from .factory import BaseAutoLLMClass
from .factory import _LazyAutoMapping
MODEL_VLLM_MAPPING_NAMES = OrderedDict(
[
("llama", "VLLMLlaMA"),
]
)
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
class AutoVLLM(BaseAutoLLMClass):
_model_mapping = MODEL_VLLM_MAPPING

View File

@@ -40,10 +40,9 @@ class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTraine
eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env( eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True)
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
from ..._generation import StopSequenceCriteria
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
stopping_criteria.append(StopSequenceCriteria(stop, self.tokenizer))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:

View File

@@ -39,6 +39,4 @@ class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNe
if self.config.use_half_precision: model.half()
return model
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
from ..._generation import StopOnTokens
generation_kwargs = {"do_sample": True, "generation_config": self.config.model_construct_env(**attrs).to_generation_config(), "pad_token_id": self.tokenizer.eos_token_id, "stopping_criteria": transformers.StoppingCriteriaList([StopOnTokens()])}
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, **generation_kwargs))
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])))

View File

@@ -37,9 +37,7 @@ class LlaMA(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaToke
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.device_count() > 1 else None}, {}
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
from ..._generation import StopOnTokens
generation_kwargs = {"generation_config": self.config.model_construct_env(**attrs).to_generation_config(), "stopping_criteria": transformers.StoppingCriteriaList([StopOnTokens()])}
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), **generation_kwargs), skip_special_tokens=True, clean_up_tokenization_spaces=True)
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), generation_config=self.config.model_construct_env(**attrs).to_generation_config(), stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()])), skip_special_tokens=True, clean_up_tokenization_spaces=True)
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device)
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]

View File

@@ -36,5 +36,4 @@ class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTN
return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {}
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
from ..._generation import StopOnTokens
with torch.inference_mode(): return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([StopOnTokens()]))[0], skip_special_tokens=True)]
with torch.inference_mode(): return [self.tokenizer.decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config(), pad_token_id=self.tokenizer.eos_token_id, stopping_criteria=transformers.StoppingCriteriaList([openllm.StopOnTokens()]))[0], skip_special_tokens=True)]

View File

@@ -55,10 +55,9 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
from ..._generation import StopSequenceCriteria
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", transformers.StoppingCriteriaList([]))
stopping_criteria.append(StopSequenceCriteria(stop, self.tokenizer))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:

View File

@@ -17,6 +17,8 @@ from __future__ import annotations
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
# NOTE: vllm will use PyTorch implementation of transformers for serialisation
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
"tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"),
"flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
}
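A hedged sketch of how this table would be consumed during serialisation: pick the (causal, seq2seq) pair for a framework and resolve the chosen name on transformers.

import transformers

causal_name, _seq2seq_name = FRAMEWORK_TO_AUTOCLASS_MAPPING["vllm"]
autoclass = getattr(transformers, causal_name)  # transformers.AutoModelForCausalLM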

View File

@@ -329,6 +329,9 @@ def save_pretrained(
save_function = first_not_none(save_function, default=torch.save)
model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs)
safe_serialization = safe_serialization or llm._serialisation_format == "safetensors"
if llm.__llm_implementation__ == "vllm":
# NOTE: disable safetensors for vllm
safe_serialization = False
if llm._quantize_method == "gptq":
if not is_autogptq_available():
raise OpenLLMException(