mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-10 19:22:10 -04:00
405 lines
18 KiB
Python
405 lines
18 KiB
Python
# Copyright 2023 BentoML Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Some import utils are vendored from transformers/utils/import_utils.py for performance reasons."""
|
|
from __future__ import annotations
|
|
import functools
|
|
import importlib
|
|
import importlib.metadata
|
|
import importlib.util
|
|
import logging
|
|
import os
|
|
import sys
|
|
import typing as t
|
|
from abc import ABCMeta
|
|
from collections import OrderedDict
|
|
|
|
import inflection
|
|
from packaging import version
|
|
|
|
from bentoml._internal.utils import LazyLoader
|
|
from bentoml._internal.utils import pkg
|
|
|
|
from .representation import ReprMixin
|
|
|
|
# NOTE: We need to do this so that overload can register
|
|
# correct overloads to typing registry
|
|
if sys.version_info[:2] >= (3, 11):
|
|
from typing import overload
|
|
else:
|
|
from typing_extensions import overload
|
|
|
|
# Typing-only declarations: static type checkers see the precise generic
# aliases below, while at runtime the names fall back to the plain classes.
if t.TYPE_CHECKING:
    BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
    from .._types import LiteralRuntime
    from .._types import P
    from .._types import T

    class _AnnotatedLazyLoader(LazyLoader, t.Generic[T]):
        """Typing-only view of ``LazyLoader`` that declares the attributes read from lazily loaded modules."""

        DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[[T], t.LiteralString]
        PROMPT_MAPPING: dict[T, t.LiteralString] | None

else:
    # Runtime fallbacks: no generic parameters are needed outside of type checking.
    _AnnotatedLazyLoader = LazyLoader
    BackendOrderredDict = OrderedDict
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# NOTE(review): presumably the names of the optional extras of the package -- confirm against the packaging metadata.
OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq",}
# Upper-cased strings that count as "enabled" for the environment switches below.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# Same as above, plus "AUTO" (the default used when a switch is unset).
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
# Framework switches, read once at import time and normalised to upper case.
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
# NOTE: the value is read from the ``USE_FLAX`` environment variable, not ``USE_JAX``.
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
|
|
|
|
def _is_package_available(package: str) -> bool:
|
|
_package_available = importlib.util.find_spec(package) is not None
|
|
if _package_available:
|
|
try:
|
|
importlib.metadata.version(package)
|
|
except importlib.metadata.PackageNotFoundError:
|
|
_package_available = False
|
|
return _package_available
|
|
|
|
# Availability flags, computed once at import time.
# The framework flags below only check that the module can be located; the
# matching ``is_*_available`` functions refine torch/tf/flax further on call.
_torch_available = importlib.util.find_spec("torch") is not None
_tf_available = importlib.util.find_spec("tensorflow") is not None
# Flax requires both ``jax`` and ``flax`` to be locatable.
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
_vllm_available = importlib.util.find_spec("vllm") is not None
# The remaining flags additionally require installed distribution metadata (see ``_is_package_available``).
_peft_available = _is_package_available("peft")
_einops_available = _is_package_available("einops")
_cpm_kernel_available = _is_package_available("cpm_kernels")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_datasets_available = _is_package_available("datasets")
_triton_available = _is_package_available("triton")
_jupyter_available = _is_package_available("jupyter")
_jupytext_available = _is_package_available("jupytext")
_notebook_available = _is_package_available("notebook")
_autogptq_available = _is_package_available("auto_gptq")
|
|
|
|
def is_transformers_supports_kbit() -> bool:
    """Return whether the installed ``transformers`` is at least version 4.30.

    The first two components reported by ``pkg.pkg_version_info`` are compared
    against ``(4, 30)``.
    """
    major_minor = pkg.pkg_version_info("transformers")[:2]
    return major_minor >= (4, 30)
|
|
|
|
def is_transformers_supports_agent() -> bool:
    """Return whether the installed ``transformers`` is at least version 4.29.

    The first two components reported by ``pkg.pkg_version_info`` are compared
    against ``(4, 29)``.
    """
    major_minor = pkg.pkg_version_info("transformers")[:2]
    return major_minor >= (4, 29)
|
|
|
|
def is_jupyter_available() -> bool:
    """Return the module-level ``_jupyter_available`` flag (computed at import time)."""
    return _jupyter_available
|
|
|
|
def is_jupytext_available() -> bool:
    """Return the module-level ``_jupytext_available`` flag (computed at import time)."""
    return _jupytext_available
|
|
|
|
def is_notebook_available() -> bool:
    """Return the module-level ``_notebook_available`` flag (computed at import time)."""
    return _notebook_available
|
|
|
|
def is_triton_available() -> bool:
    """Return the module-level ``_triton_available`` flag (computed at import time)."""
    return _triton_available
|
|
|
|
def is_datasets_available() -> bool:
    """Return the module-level ``_datasets_available`` flag (computed at import time)."""
    return _datasets_available
|
|
|
|
def is_peft_available() -> bool:
    """Return the module-level ``_peft_available`` flag (computed at import time)."""
    return _peft_available
|
|
|
|
def is_einops_available() -> bool:
    """Return the module-level ``_einops_available`` flag (computed at import time)."""
    return _einops_available
|
|
|
|
def is_cpm_kernels_available() -> bool:
    """Return the module-level ``_cpm_kernel_available`` flag (computed at import time)."""
    return _cpm_kernel_available
|
|
|
|
def is_bitsandbytes_available() -> bool:
    """Return the module-level ``_bitsandbytes_available`` flag (computed at import time)."""
    return _bitsandbytes_available
|
|
|
|
def is_autogptq_available() -> bool:
    """Return the module-level ``_autogptq_available`` flag (computed at import time)."""
    return _autogptq_available
|
|
|
|
def is_vllm_available() -> bool:
    """Return the module-level ``_vllm_available`` flag (computed at import time)."""
    return _vllm_available
|
|
|
|
def is_torch_available() -> bool:
    """Return whether PyTorch may be used, updating the cached module flag.

    PyTorch is considered only when ``USE_TORCH`` is enabled or ``AUTO`` and
    ``USE_TF`` is not explicitly enabled. In that case a previously positive
    flag is confirmed by looking up the ``torch`` distribution metadata.
    Otherwise the flag is forced to ``False`` and an info message is logged.

    Returns:
        The (possibly updated) value of the module-level ``_torch_available``.
    """
    global _torch_available

    torch_requested = USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES
    tf_forced = USE_TF in ENV_VARS_TRUE_VALUES
    if not torch_requested or tf_forced:
        # NOTE: this branch is also taken when USE_TORCH itself is disabled,
        # although the message only mentions USE_TF.
        logger.info("Disabling PyTorch because USE_TF is set")
        _torch_available = False
        return _torch_available

    if _torch_available:
        try:
            importlib.metadata.version("torch")
        except importlib.metadata.PackageNotFoundError:
            # Module is locatable but no installed distribution backs it.
            _torch_available = False
    return _torch_available
|
|
|
|
def is_tf_available() -> bool:
    """Return whether TensorFlow 2.x may be used, updating the cached module flag.

    Resolution order:

    1. ``FORCE_TF_AVAILABLE`` enabled: the flag is set to ``True`` unconditionally.
    2. ``USE_TF`` not enabled/``AUTO``, or ``USE_TORCH`` explicitly enabled:
       the flag is set to ``False`` and an info message is logged.
    3. Otherwise, a previously positive flag is confirmed by finding the
       metadata of any known TensorFlow distribution and requiring a version
       of at least 2.

    Returns:
        The (possibly updated) value of the module-level ``_tf_available``.
    """
    global _tf_available

    if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
        _tf_available = True
        return _tf_available

    if USE_TF not in ENV_VARS_TRUE_AND_AUTO_VALUES or USE_TORCH in ENV_VARS_TRUE_VALUES:
        logger.info("Disabling Tensorflow because USE_TORCH is set")
        _tf_available = False
        return _tf_available

    if not _tf_available:
        # The module was not locatable at import time; nothing to confirm.
        return _tf_available

    # TensorFlow is published under several distribution names; the first one
    # with installed metadata supplies the version.
    distributions = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos", "tensorflow-aarch64",)
    found_version = None
    for dist_name in distributions:
        try:
            found_version = importlib.metadata.version(dist_name)
        except importlib.metadata.PackageNotFoundError:
            continue
        break
    _tf_available = found_version is not None

    if _tf_available and found_version and version.parse(found_version) < version.parse("2"):
        logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", found_version)
        _tf_available = False
    return _tf_available
|
|
|
|
def is_flax_available() -> bool:
    """Return whether Flax may be used, updating the cached module flag.

    When ``USE_JAX`` is enabled or ``AUTO``, a previously positive flag is
    confirmed by looking up the distribution metadata of both ``jax`` and
    ``flax``. For any other value of ``USE_JAX`` the flag is forced to ``False``.

    Returns:
        The (possibly updated) value of the module-level ``_flax_available``.
    """
    global _flax_available

    if USE_JAX not in ENV_VARS_TRUE_AND_AUTO_VALUES:
        _flax_available = False
        return _flax_available

    if _flax_available:
        try:
            for dist_name in ("jax", "flax"):
                importlib.metadata.version(dist_name)
        except importlib.metadata.PackageNotFoundError:
            # Either distribution missing disables Flax.
            _flax_available = False
    return _flax_available
|
|
|
|
def requires_dependencies(package: str | list[str], *, extra: str | list[str] | None = None) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
    """Build a decorator that checks for required packages before each call.

    On every invocation of the decorated function, each required package is
    checked first through ``openllm.utils.is_<package>_available`` (when such
    a function exists) and then through ``_is_package_available``.

    Args:
        package: A package name or a list of package names that must be available.
        extra: Optional extra name(s). When given, the install hint in the
            error message suggests ``openllm[<extras>]`` instead of the package.

    Returns:
        A decorator that wraps the target function with the availability check.

    Raises:
        ImportError: Raised by the wrapped function (at call time) for the
            first required package that is not available.
    """
    import openllm.utils

    packages = [package] if isinstance(package, str) else package
    extras = [extra] if isinstance(extra, str) else extra

    def decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
            for name in packages:
                probe: t.Callable[[], bool] | None = getattr(openllm.utils, f"is_{name}_available", None)
                if probe is not None and probe():
                    continue
                if _is_package_available(name):
                    continue
                target = name if not extras else "openllm[" + ", ".join(extras) + "]"
                raise ImportError(f"{func.__name__} requires '{name}' to be available locally (Currently missing). Make sure to have {name} to be installed: 'pip install \"{target}\"'")
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
# Error templates used when the requested backend is missing but another one
# is installed. ``{0}`` is replaced with the name of the requesting object.

# vLLM requested, only PyTorch found.
VLLM_IMPORT_ERROR_WITH_PYTORCH = """\
{0} requires the vLLM library but it was not found in your environment.
However, we were able to find a PyTorch installation. PyTorch classes do not begin
with "VLLM", but are otherwise identically named to our vLLM classes.
If you want to use PyTorch, please use those classes instead!

If you really do want to use vLLM, please follow the instructions on the
installation page https://github.com/vllm-project/vllm that match your environment.
"""

# vLLM requested, only TensorFlow found.
VLLM_IMPORT_ERROR_WITH_TF = """\
{0} requires the vLLM library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to the PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TF{0}".
If you want to use TensorFlow, please use TF classes instead!

If you really do want to use vLLM, please follow the instructions on the
installation page https://github.com/vllm-project/vllm that match your environment.
"""

# vLLM requested, only Flax found.
VLLM_IMPORT_ERROR_WITH_FLAX = """\
{0} requires the vLLM library but it was not found in your environment.
However, we were able to find a Flax installation. Flax classes begin
with "Flax", but are otherwise identically named to the PyTorch classes. This
means that the Flax equivalent of the class you tried to import would be "Flax{0}".
If you want to use Flax, please use Flax classes instead!

If you really do want to use vLLM, please follow the instructions on the
installation page https://github.com/vllm-project/vllm that match your environment.
"""

# PyTorch requested, only TensorFlow found.
PYTORCH_IMPORT_ERROR_WITH_TF = """\
{0} requires the PyTorch library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to the PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TF{0}".
If you want to use TensorFlow, please use TF classes instead!

If you really do want to use PyTorch please go to
https://pytorch.org/get-started/locally/ and follow the instructions that
match your environment.
"""

# TensorFlow requested, only PyTorch found.
TF_IMPORT_ERROR_WITH_PYTORCH = """\
{0} requires the TensorFlow library but it was not found in your environment.
However, we were able to find a PyTorch installation. PyTorch classes do not begin
with "TF", but are otherwise identically named to our TF classes.
If you want to use PyTorch, please use those classes instead!

If you really do want to use TensorFlow, please follow the instructions on the
installation page https://www.tensorflow.org/install that match your environment.
"""
|
|
|
|
# Error templates used when a single required library is missing.
# ``{0}`` is replaced with the name of the requesting object.
TENSORFLOW_IMPORT_ERROR = """{0} requires the TensorFlow library but it was not found in your environment.
Checkout the instructions on the installation page: https://www.tensorflow.org/install and follow the
ones that match your environment. Please note that you may need to restart your runtime after installation.
"""

FLAX_IMPORT_ERROR = """{0} requires the FLAX library but it was not found in your environment.
Checkout the instructions on the installation page: https://github.com/google/flax and follow the
ones that match your environment. Please note that you may need to restart your runtime after installation.
"""

PYTORCH_IMPORT_ERROR = """{0} requires the PyTorch library but it was not found in your environment.
Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the
ones that match your environment. Please note that you may need to restart your runtime after installation.
"""

# The first sentence previously ended after the URL, leaving "ones that match
# your environment." dangling; it now reads like the templates above.
VLLM_IMPORT_ERROR = """{0} requires the vLLM library but it was not found in your environment.
Checkout the instructions on the installation page: https://github.com/vllm-project/vllm and follow the
ones that match your environment. Please note that you may need to restart your runtime after installation.
"""

CPM_KERNELS_IMPORT_ERROR = """{0} requires the cpm_kernels library but it was not found in your environment.
You can install it with pip: `pip install cpm_kernels`. Please note that you may need to restart your
runtime after installation.
"""

EINOPS_IMPORT_ERROR = """{0} requires the einops library but it was not found in your environment.
You can install it with pip: `pip install einops`. Please note that you may need to restart
your runtime after installation.
"""

TRITON_IMPORT_ERROR = """{0} requires the triton library but it was not found in your environment.
You can install it with pip: 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'.
Please note that you may need to restart your runtime after installation.
"""

DATASETS_IMPORT_ERROR = """{0} requires the datasets library but it was not found in your environment.
You can install it with pip: `pip install datasets`. Please note that you may need to restart
your runtime after installation.
"""

PEFT_IMPORT_ERROR = """{0} requires the peft library but it was not found in your environment.
You can install it with pip: `pip install peft`. Please note that you may need to restart
your runtime after installation.
"""

BITSANDBYTES_IMPORT_ERROR = """{0} requires the bitsandbytes library but it was not found in your environment.
You can install it with pip: `pip install bitsandbytes`. Please note that you may need to restart
your runtime after installation.
"""

AUTOGPTQ_IMPORT_ERROR = """{0} requires the auto-gptq library but it was not found in your environment.
You can install it with pip: `pip install auto-gptq`. Please note that you may need to restart
your runtime after installation.
"""
|
|
|
|
# Ordered registry of backend name -> (availability check, error template).
# ``require_backends`` looks up each requested backend here, calls the check,
# and formats the template for every backend whose check returns False.
BACKENDS_MAPPING = BackendOrderredDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
                                        ("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)),])
|
|
|
|
class DummyMetaclass(ABCMeta):
    """Metaclass for dummy objects.

    Accessing any attribute whose name does not start with an underscore on a
    class using this metaclass goes through ``require_backends``, which raises
    ``ImportError`` when one of the class's ``_backends`` is unavailable.
    """

    _backends: t.List[str]

    def __getattribute__(cls, key: str) -> t.Any:
        """Resolve underscore-prefixed names normally; validate backends for everything else."""
        if not key.startswith("_"):
            # Raises ImportError when a backend is missing; otherwise the lookup yields None.
            require_backends(cls, cls._backends)
            return None
        return super().__getattribute__(key)
|
|
|
|
def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
    """Raise ``ImportError`` when any backend needed by ``o`` is unavailable.

    Args:
        o: The object (usually a class or function) that needs the backends.
            Its ``__name__`` (or its class name) is interpolated into the
            error message.
        backends: Backend names, i.e. keys of ``BACKENDS_MAPPING``. A single
            name given as a bare string is accepted and treated as a
            one-element list.

    Raises:
        ImportError: With a targeted message when another framework is
            installed instead of the requested one, or with the concatenated
            per-backend messages for every backend whose check fails.
        KeyError: When a backend name is not a key of ``BACKENDS_MAPPING``.
    """
    if isinstance(backends, str):
        # list("torch") would split the name into characters and fail with a
        # confusing KeyError below; a bare string means one backend.
        backends = [backends]
    elif not isinstance(backends, (list, tuple)):
        backends = list(backends)
    name = o.__name__ if hasattr(o, "__name__") else o.__class__.__name__
    # Raise an error for users who might not realize that classes without "TF" are torch-only
    if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
        raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
    # Raise the inverse error for PyTorch users trying to load TF classes
    if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
        raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
    # Raise an error when vLLM is not available to consider the alternative, order from PyTorch -> Tensorflow -> Flax
    if "vllm" in backends:
        if "torch" not in backends and is_torch_available() and not is_vllm_available():
            raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
        if "tf" not in backends and is_tf_available() and not is_vllm_available():
            raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
        if "flax" not in backends and is_flax_available() and not is_vllm_available():
            raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))

    # Generic path: collect the message of every requested backend whose check fails.
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))
|
|
|
|
class EnvVarMixin(ReprMixin):
    """Resolve per-model settings from environment variables.

    For each of ``config``, ``model_id``, ``quantize``, ``framework``,
    ``bettertransformer`` and ``runtime`` the instance stores the key produced
    by ``field_env_key`` as an attribute of the same name. The ``*_value``
    properties read that key from the environment and fall back to the value
    given to the constructor. Items can also be read with ``instance[name]``.
    """

    model_name: str
    if t.TYPE_CHECKING:
        config: str
        model_id: str
        quantize: str
        framework: str
        bettertransformer: str
        runtime: str

    # fmt: off
    @overload
    def __getitem__(self, item: t.Literal["config"]) -> str: ...
    @overload
    def __getitem__(self, item: t.Literal["model_id"]) -> str: ...
    @overload
    def __getitem__(self, item: t.Literal["quantize"]) -> str: ...
    @overload
    def __getitem__(self, item: t.Literal["framework"]) -> str: ...
    @overload
    def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
    @overload
    def __getitem__(self, item: t.Literal["runtime"]) -> str: ...
    @overload
    def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ...
    @overload
    def __getitem__(self, item: t.Literal["quantize_value"]) -> t.Literal["int8", "int4", "gptq"] | None: ...
    @overload
    def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ...
    @overload
    def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> bool: ...
    @overload
    def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ...
    # fmt: on
    def __getitem__(self, item: str | t.Any) -> t.Any:
        """Return the attribute named ``item``, raising ``KeyError`` when it does not exist."""
        if not hasattr(self, item):
            raise KeyError(f"Key {item} not found in {self}")
        return getattr(self, item)

    def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None,
                 runtime: t.Literal["ggml", "transformers"] = "transformers") -> None:
        """Store the fallback values and compute the environment key of every field.

        Args:
            model_name: Model name; stored in underscored form.
            implementation: Fallback for ``framework_value``.
            model_id: Fallback for ``model_id_value``.
            bettertransformer: Fallback for ``bettertransformer_value``.
            quantize: Fallback for ``quantize_value``.
            runtime: Fallback for ``runtime_value``.
        """
        from .._configuration import field_env_key

        self.model_name = inflection.underscore(model_name)
        self._implementation = implementation
        self._model_id = model_id
        self._bettertransformer = bettertransformer
        self._quantize = quantize
        self._runtime = runtime
        for field in ("config", "model_id", "quantize", "framework", "bettertransformer", "runtime"):
            setattr(self, field, field_env_key(self.model_name, field.upper()))

    @property
    def __repr_keys__(self) -> set[str]:
        """Names of the attributes reported as repr keys."""
        return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}

    @property
    def quantize_value(self) -> t.Literal["int8", "int4", "gptq"] | None:
        """Quantization setting from the environment, else the constructor value."""
        raw = os.getenv(self["quantize"], self._quantize)
        return t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], raw)

    @property
    def framework_value(self) -> LiteralRuntime:
        """Framework setting from the environment, else the constructor ``implementation``."""
        raw = os.getenv(self["framework"], self._implementation)
        return t.cast(t.Literal["pt", "tf", "flax", "vllm"], raw)

    @property
    def bettertransformer_value(self) -> bool:
        """Whether the upper-cased setting is one of ``ENV_VARS_TRUE_VALUES``."""
        raw = os.getenv(self["bettertransformer"], str(self._bettertransformer))
        return raw.upper() in ENV_VARS_TRUE_VALUES

    @property
    def model_id_value(self) -> str | None:
        """Model id from the environment, else the constructor value."""
        return os.getenv(self["model_id"], self._model_id)

    @property
    def runtime_value(self) -> t.Literal["ggml", "transformers"]:
        """Runtime setting from the environment, else the constructor value."""
        raw = os.getenv(self["runtime"], self._runtime)
        return t.cast(t.Literal["ggml", "transformers"], raw)

    @property
    def start_docstring(self) -> str:
        """The ``START_<MODEL_NAME>_COMMAND_DOCSTRING`` attribute of the model module."""
        attr_name = f"START_{self.model_name.upper()}_COMMAND_DOCSTRING"
        return getattr(self.module, attr_name)

    @property
    def module(self) -> _AnnotatedLazyLoader[t.LiteralString]:
        """A new lazy loader for ``openllm.models.<model_name>`` on every access."""
        target = f"openllm.models.{self.model_name}"
        return _AnnotatedLazyLoader(self.model_name, globals(), target)
|