refactor: packages (#249)

This commit is contained in:
Aaron Pham
2023-08-22 08:55:46 -04:00
committed by GitHub
parent a964e659c1
commit 3ffb25a872
148 changed files with 2899 additions and 1937 deletions

View File

@@ -4,244 +4,19 @@ User can import these function for convenience, but
we won't ensure backward compatibility for these functions. So use with caution.
"""
from __future__ import annotations
import contextlib, functools, hashlib, logging, logging.config, os, sys, types, typing as t, openllm
from pathlib import Path
from circus.exc import ConflictError
from bentoml._internal.configuration import (
DEBUG_ENV_VAR as DEBUG_ENV_VAR,
GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR,
QUIET_ENV_VAR as QUIET_ENV_VAR,
get_debug_mode as _get_debug_mode,
get_quiet_mode as _get_quiet_mode,
set_quiet_mode as set_quiet_mode,
)
from bentoml._internal.models.model import ModelContext as _ModelContext
from bentoml._internal.types import LazyType as LazyType
from bentoml._internal.utils import (
LazyLoader as LazyLoader,
bentoml_cattr as bentoml_cattr,
calc_dir_size as calc_dir_size,
first_not_none as first_not_none,
pkg as pkg,
reserve_free_port as reserve_free_port,
resolve_user_filepath as resolve_user_filepath,
)
from openllm.utils.lazy import (
LazyModule as LazyModule,
VersionInfo as VersionInfo,
import typing as t, openllm_core
from . import (
dummy_flax_objects as dummy_flax_objects,
dummy_pt_objects as dummy_pt_objects,
dummy_tf_objects as dummy_tf_objects,
dummy_vllm_objects as dummy_vllm_objects,
)
if t.TYPE_CHECKING:
from openllm._typing_compat import AnyCallable, LiteralRuntime
logger = logging.getLogger(__name__)
# ``typing.GenericAlias`` only exists from Python 3.9; on older interpreters fall
# back to an empty tuple so the isinstance() checks against it simply never match.
try: from typing import GenericAlias as _TypingGenericAlias  # type: ignore
except ImportError: _TypingGenericAlias = ()  # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
# On 3.10+ also recognise the builtin generic alias and PEP 604 unions (int | str).
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
else: _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType)  # type: ignore # _GenericAlias is the actual GenericAlias implementation
# Name of the environment variable that enables OpenLLM developer debugging.
DEV_DEBUG_VAR = "OPENLLMDEVDEBUG"
def set_debug_mode(enabled: bool, level: int = 1) -> None:
  """Propagate debug state into the environment variables BentoML and OpenLLM consult.

  Monkeypatches bentoml._internal.configuration.set_debug_mode to remove unused logs.
  """
  if enabled:
    os.environ[DEV_DEBUG_VAR] = str(level)
  os.environ[DEBUG_ENV_VAR] = str(enabled)
  os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR"
def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool:
  """issubclass() that returns False instead of raising on generic aliases."""
  if not isinstance(cls, type):
    return False
  try:
    return issubclass(cls, class_or_tuple)  # type: ignore[arg-type]
  except TypeError:
    # Parameterised generics (list[int], int | str, ...) are not classes.
    if isinstance(cls, _WithArgsTypes):
      return False
    raise
def available_devices() -> tuple[str, ...]:
  """Return available GPU under system. Currently only supports NVIDIA GPUs."""
  # Imported lazily so merely importing openllm.utils does not pull in the strategies module.
  from openllm._strategies import NvidiaGpuResource
  return tuple(NvidiaGpuResource.from_system())
@functools.lru_cache(maxsize=128)
def generate_hash_from_file(f: str, algorithm: t.Literal["md5", "sha1"] = "sha1") -> str:
  """Generate a hash from given file's modification time.

  Args:
    f: The file to generate the hash from.
    algorithm: The hashing algorithm to use. Defaults to 'sha1' (similar to how Git generate its commit hash.)

  Returns:
    The generated hash.
  """
  mtime = os.path.getmtime(resolve_filepath(f))
  hasher = getattr(hashlib, algorithm)
  return hasher(str(mtime).encode()).hexdigest()
# Cached once: the visible GPU set is assumed static for the process lifetime.
@functools.lru_cache(maxsize=1)
def device_count() -> int: return len(available_devices())
# equivocal setattr to save one lookup per assignment
_object_setattr = object.__setattr__
def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None:
  """Set *name* on *obj* only when the attribute does not already exist."""
  if hasattr(obj, name):
    return
  if isinstance(obj, type):
    # Classes take the ordinary setattr path.
    setattr(obj, name, value)
  else:
    # Instances go through object.__setattr__, bypassing any custom __setattr__.
    _object_setattr(obj, name, value)
def field_env_key(model_name: str, key: str, suffix: str | None = None) -> str: return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
# Special debug flag controlled via OPENLLMDEVDEBUG
DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get(DEV_DEBUG_VAR)))
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
MYPY = False
# Codegen tracing only kicks in at developer debug level > 3.
SHOW_CODEGEN: bool = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3
def get_debug_mode() -> bool: return DEBUG or _get_debug_mode()
def get_quiet_mode() -> bool: return not DEBUG and _get_quiet_mode()
class ExceptionFilter(logging.Filter):
  """Logging filter dropping records whose exception type is in an exclude list.

  circus's ConflictError is always excluded.
  """
  def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
    """A filter of all exception."""
    if exclude_exceptions is None:
      exclude_exceptions = [ConflictError]
    if ConflictError not in exclude_exceptions:
      exclude_exceptions.append(ConflictError)
    super().__init__(**kwargs)
    self.EXCLUDE_EXCEPTIONS = exclude_exceptions
  def filter(self, record: logging.LogRecord) -> bool:
    exc_info = record.exc_info
    if not exc_info:
      return True
    etype = exc_info[0]
    if etype is None:
      return True
    # Reject the record when its exception type matches any excluded class.
    return not any(issubclass(etype, excluded) for excluded in self.EXCLUDE_EXCEPTIONS)
class InfoFilter(logging.Filter):
  """Pass only records at INFO level (at or above INFO, strictly below WARNING)."""
  def filter(self, record: logging.LogRecord) -> bool:
    return logging.INFO <= record.levelno < logging.WARNING
# dictConfig payload applied by configure_logging() below: bentomlhandler prints
# INFO-level records (minus excluded exceptions) to stdout, while WARNING and above
# go through defaulthandler (stderr by default).
_LOGGING_CONFIG: dict[str, t.Any] = {
  "version": 1, "disable_existing_loggers": True,
  "filters": {"excfilter": {"()": "openllm.utils.ExceptionFilter"}, "infofilter": {"()": "openllm.utils.InfoFilter"}},
  "handlers": {"bentomlhandler": {"class": "logging.StreamHandler", "filters": ["excfilter", "infofilter"], "stream": "ext://sys.stdout"}, "defaulthandler": {"class": "logging.StreamHandler", "level": logging.WARNING}},
  "loggers": {"bentoml": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False}, "openllm": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False,}},
  "root": {"level": logging.WARNING},
}
def configure_logging() -> None:
  """Configure logging for OpenLLM.

  Behaves similar to how BentoML loggers are being configured.
  """
  # Pick one effective level from the quiet/debug flags, then apply it uniformly.
  if get_quiet_mode():
    level = logging.ERROR
  elif get_debug_mode() or DEBUG:
    level = logging.DEBUG
  else:
    level = logging.INFO
  _LOGGING_CONFIG["loggers"]["openllm"]["level"] = level
  _LOGGING_CONFIG["loggers"]["bentoml"]["level"] = level
  _LOGGING_CONFIG["root"]["level"] = level
  logging.config.dictConfig(_LOGGING_CONFIG)
@functools.lru_cache(maxsize=1)
def in_notebook() -> bool:
  """Best-effort detection of whether we are running inside a Jupyter kernel."""
  try:
    from IPython.core.getipython import get_ipython
  except ImportError:
    return False
  try:
    # get_ipython() returns None outside an IPython shell, which raises
    # AttributeError on .config — treated as "not a notebook".
    config = get_ipython().config
  except AttributeError:
    return False
  return "IPKernelApp" in config
# Marker files consulted by in_docker() below.
_dockerenv, _cgroup = Path("/.dockerenv"), Path("/proc/self/cgroup")
class suppress(contextlib.suppress, contextlib.ContextDecorator):
  """A version of contextlib.suppress with decorator support.

  >>> @suppress(KeyError)
  ... def key_error():
  ...   {}['']
  >>> key_error()
  """
def compose(*funcs: AnyCallable) -> AnyCallable:
  """Compose any number of unary functions into a single unary function.

  >>> import textwrap
  >>> expected = str.strip(textwrap.dedent(compose.__doc__))
  >>> strip_and_dedent = compose(str.strip, textwrap.dedent)
  >>> strip_and_dedent(compose.__doc__) == expected
  True

  Compose also allows the innermost function to take arbitrary arguments.

  >>> round_three = lambda x: round(x, ndigits=3)
  >>> f = compose(round_three, int.__truediv__)
  >>> [f(3*x, x+1) for x in range(1,10)]
  [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
  """
  def _pair(outer: AnyCallable, inner: AnyCallable) -> AnyCallable:
    def chained(*args: t.Any, **kwargs: t.Any) -> t.Any:
      return outer(inner(*args, **kwargs))
    return chained
  # Left fold: compose(f, g, h) == f(g(h(...))).
  return functools.reduce(_pair, funcs)
def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
  """Decorate a function with a transform function that is invoked on results returned from the decorated function.

  ```python
  @apply(reversed)
  def get_numbers(start):
      "doc for get_numbers"
      return range(start, start+3)
  list(get_numbers(4))
  # [6, 5, 4]
  ```
  ```python
  get_numbers.__doc__
  # 'doc for get_numbers'
  ```
  """
  def decorator(func: AnyCallable) -> AnyCallable:
    @functools.wraps(func)
    def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
      return transform(func(*args, **kwargs))
    return wrapper
  return decorator
@apply(bool)
@suppress(FileNotFoundError)
def _text_in_file(text: str, filename: Path) -> bool:
  """Return whether *text* occurs on any line of *filename*.

  A missing file counts as "not found": FileNotFoundError is suppressed (the
  suppressed call yields None) and the ``bool`` transform maps that to False.
  """
  # Use a context manager so the handle is closed deterministically — the
  # original left the open file object to the garbage collector (a leak on
  # interpreters without refcounting).
  with filename.open() as fh:
    return any(text in line for line in fh)
def in_docker() -> bool:
  """Whether the current process appears to be running inside a Docker container.

  Checks the /.dockerenv marker first, then the "docker" token in /proc/self/cgroup.
  """
  if _dockerenv.exists():
    return True
  return _text_in_file("docker", _cgroup)
T, K = t.TypeVar("T"), t.TypeVar("K")
def resolve_filepath(path: str, ctx: str | None = None) -> str:
  """Resolve a file path to an absolute path, expand user and environment variables.

  Paths that cannot be resolved are returned unchanged rather than raising.
  """
  try:
    return resolve_user_filepath(path, ctx)
  except FileNotFoundError:
    return path
def validate_is_path(maybe_path: str) -> bool: return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
def generate_context(framework_name: str) -> _ModelContext:
  """Build a BentoML ModelContext recording the versions of available ML frameworks."""
  framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
  if openllm.utils.is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
  if openllm.utils.is_tf_available():
    # TensorFlow version lookup is special-cased (handles tf/tf-gpu/tf-macos dists) via bentoml's helper.
    from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
    framework_versions["tensorflow"] = get_tf_version()
  if openllm.utils.is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")})
  return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
from openllm_core._typing_compat import LiteralRuntime
import openllm
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]: return {"runtime": llm.runtime, "framework": "openllm", "model_name": llm.config["model_name"], "architecture": llm.config["architecture"], "serialisation_format": llm._serialisation_format}
_TOKENIZER_PREFIX = "_tokenizer_"
def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
  """Split **attrs into ``(model_kwargs, tokenizer_kwargs)``.

  Keys prefixed with ``_tokenizer_`` move (prefix stripped) into the tokenizer
  dict; everything else stays with the model.
  """
  model_attrs: dict[str, t.Any] = {}
  tokenizer_attrs: dict[str, t.Any] = {}
  for key, value in attrs.items():
    if key.startswith(_TOKENIZER_PREFIX):
      tokenizer_attrs[key[len(_TOKENIZER_PREFIX):]] = value
    else:
      model_attrs[key] = value
  return model_attrs, tokenizer_attrs
def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | openllm.AutoTFLLM | openllm.AutoFlaxLLM | openllm.AutoVLLM]:
import openllm
if implementation == "tf": return openllm.AutoTFLLM
@@ -250,62 +25,8 @@ def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | o
elif implementation == "vllm": return openllm.AutoVLLM
else: raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')")
# NOTE: The set marks contains a set of modules name
# that are available above and are whitelisted
# to be included in the extra_objects map.
_whitelist_modules = {"pkg"}
# XXX: define all classes, functions import above this line
# since _extras will be the locals() import from this file.
_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_"))}
_extras["__openllm_migration__"] = {"ModelEnv": "EnvVarMixin"}
_import_structure: dict[str, list[str]] = {
"analytics": [], "codegen": [], "dantic": [], "dummy_flax_objects": [], "dummy_pt_objects": [], "dummy_tf_objects": [], "dummy_vllm_objects": [], "representation": ["ReprMixin"], "lazy": ["LazyModule"],
"import_utils": ["OPTIONAL_DEPENDENCIES", "ENV_VARS_TRUE_VALUES", "DummyMetaclass", "EnvVarMixin", "require_backends",
"is_cpm_kernels_available", "is_einops_available", "is_flax_available", "is_tf_available", "is_vllm_available", "is_torch_available", "is_bitsandbytes_available", "is_peft_available", "is_datasets_available",
"is_transformers_supports_kbit", "is_transformers_supports_agent", "is_jupyter_available", "is_jupytext_available", "is_notebook_available", "is_triton_available", "is_autogptq_available", "is_sentencepiece_available",
"is_xformers_available", "is_fairscale_available"]}
if t.TYPE_CHECKING:
# NOTE: The following exports useful utils from bentoml
from . import (
analytics as analytics,
codegen as codegen,
dantic as dantic,
dummy_flax_objects as dummy_flax_objects,
dummy_pt_objects as dummy_pt_objects,
dummy_tf_objects as dummy_tf_objects,
dummy_vllm_objects as dummy_vllm_objects,
)
from .import_utils import (
ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES,
OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES,
DummyMetaclass as DummyMetaclass,
EnvVarMixin as EnvVarMixin,
is_autogptq_available as is_autogptq_available,
is_bitsandbytes_available as is_bitsandbytes_available,
is_cpm_kernels_available as is_cpm_kernels_available,
is_datasets_available as is_datasets_available,
is_einops_available as is_einops_available,
is_fairscale_available as is_fairscale_available,
is_flax_available as is_flax_available,
is_jupyter_available as is_jupyter_available,
is_jupytext_available as is_jupytext_available,
is_notebook_available as is_notebook_available,
is_peft_available as is_peft_available,
is_sentencepiece_available as is_sentencepiece_available,
is_tf_available as is_tf_available,
is_torch_available as is_torch_available,
is_transformers_supports_agent as is_transformers_supports_agent,
is_transformers_supports_kbit as is_transformers_supports_kbit,
is_triton_available as is_triton_available,
is_vllm_available as is_vllm_available,
is_xformers_available as is_xformers_available,
require_backends as require_backends,
)
from .representation import ReprMixin as ReprMixin
__lazy = LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects=_extras)
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
__getattr__ = __lazy.__getattr__
# Eagerly-defined public surface of this module; anything else is resolved
# dynamically through the module-level __getattr__ proxy to openllm_core.utils.
__all__ = ["generate_labels", "infer_auto_class", "dummy_flax_objects", "dummy_pt_objects", "dummy_tf_objects", "dummy_vllm_objects"]
def __dir__() -> t.Sequence[str]: return sorted(__all__)
def __getattr__(it: str) -> t.Any:
  """Module-level fallback: proxy unknown attribute lookups to openllm_core.utils."""
  try:
    return getattr(openllm_core.utils, it)
  except AttributeError:
    raise AttributeError(f"module {__name__} has no attribute {it}") from None

View File

@@ -1,80 +0,0 @@
"""Telemetry related for OpenLLM tracking.
Users can disable this with OPENLLM_DO_NOT_TRACK envvar.
"""
from __future__ import annotations
import contextlib, functools, logging, os, re, typing as t, importlib.metadata
import attr, openllm
from bentoml._internal.utils import analytics as _internal_analytics
from openllm._typing_compat import ParamSpec
P = ParamSpec("P")
T = t.TypeVar("T")
logger = logging.getLogger(__name__)
# This variable is a proxy that will control BENTOML_DO_NOT_TRACK
OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK"
DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper()
@functools.lru_cache(maxsize=1)
def do_not_track() -> bool:
  """Whether usage tracking was disabled via OPENLLM_DO_NOT_TRACK (cached)."""
  return DO_NOT_TRACK in openllm.utils.ENV_VARS_TRUE_VALUES
@functools.lru_cache(maxsize=1)
def _usage_event_debugging() -> bool: return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true"
def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
  """Swallow any exception raised by *func*, logging it instead of propagating.

  The wrapped call returns None when an exception was suppressed.
  """
  @functools.wraps(func)
  def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
    try:
      return func(*args, **kwargs)
    except Exception as err:
      # Log verbosity depends on the usage-event debug switch.
      if not _usage_event_debugging():
        logger.debug("Tracking Error: %s", err)
      elif openllm.utils.get_debug_mode():
        logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3)
      else:
        logger.info("Tracking Error: %s", err)
  return wrapper
@silent
def track(event_properties: attr.AttrsInstance) -> None:
  # Forward the event to BentoML's analytics pipeline; a no-op when the user
  # opted out, and any errors are swallowed/logged by @silent.
  if do_not_track(): return
  _internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties))
@contextlib.contextmanager
def set_bentoml_tracking() -> t.Generator[None, None, None]:
  """Temporarily force BentoML's do-not-track flag to mirror OpenLLM's setting.

  NOTE(review): when the variable was previously unset, it is restored to the
  literal string "False" on exit rather than removed — preserved behaviour,
  confirm this is intended.
  """
  original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, str(False))
  try:
    os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track())
    yield
  finally: os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value
class EventMeta:
  """Base class whose subclasses derive an analytics event name from the class name."""
  @property
  def event_name(self) -> str:
    # CamelCase -> snake_case, then drop a trailing "_event".
    snake = re.sub(r"(?<!^)(?=[A-Z])", "_", type(self).__name__).lower()
    suffix = "_event"
    return snake[:-len(suffix)] if snake.endswith(suffix) else snake
@attr.define
class ModelSaveEvent(EventMeta):
  # Emitted when a model is saved to the local store.
  module: str  # serialisation module used for the save
  model_size_in_kb: float
@attr.define
class OpenllmCliEvent(EventMeta):
  # Emitted per CLI invocation (command group + command name).
  cmd_group: str
  cmd_name: str
  # Resolved once at import time, not per-event.
  openllm_version: str = importlib.metadata.version("openllm")
  # NOTE: reserved for the do_not_track logics
  duration_in_ms: t.Any = attr.field(default=None)
  # NOTE(review): annotated str/int but default to None — confirm intended.
  error_type: str = attr.field(default=None)
  return_code: int = attr.field(default=None)
@attr.define
class StartInitEvent(EventMeta):
  # Emitted once when `openllm start` initialises a model.
  model_name: str
  llm_config: t.Dict[str, t.Any] = attr.field(default=None)
  @staticmethod
  def handler(llm_config: openllm.LLMConfig) -> StartInitEvent: return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump())
def track_start_init(llm_config: openllm.LLMConfig) -> None:
  """Emit a StartInitEvent for *llm_config* unless tracking is disabled."""
  if do_not_track():
    return
  track(StartInitEvent.handler(llm_config))

View File

@@ -1,141 +0,0 @@
from __future__ import annotations
import functools, inspect, linecache, os, logging, string, types, typing as t
from operator import itemgetter
from pathlib import Path
import orjson
if t.TYPE_CHECKING:
from fs.base import FS
import openllm
from openllm._typing_compat import LiteralString, AnyCallable, DictStrAny, ListStr
PartialAny = functools.partial[t.Any]
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
logger = logging.getLogger(__name__)
# Sentinel comments searched for inside the `_service.py` template; lines that
# contain them get the matching placeholder substituted during `openllm build`.
OPENLLM_MODEL_NAME = "# openllm: model name"
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
class ModelNameFormatter(string.Formatter):
  """Formatter substituting the ``__model_name__`` placeholder in the generated service file."""
  model_keyword: LiteralString = "__model_name__"
  def __init__(self, model_name: str):
    """The formatter that extends model_name to be formatted the 'service.py'."""
    super().__init__()
    self.model_name = model_name
  def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any:
    # Caller-supplied args/kwargs are ignored: the only substitution is the keyword.
    return super().vformat(format_string, (), {self.model_keyword: self.model_name})
  def can_format(self, value: str) -> bool:
    """Whether *value* is a well-formed format string.

    string.Formatter.parse returns a lazy iterator, so it must be fully
    consumed for malformed templates to raise ValueError — the previous
    implementation never consumed it and therefore always returned True.
    """
    try:
      list(self.parse(value))
      return True
    except ValueError: return False
class ModelIdFormatter(ModelNameFormatter):
  # Same substitution machinery, but targets the ``__model_id__`` placeholder.
  model_keyword: LiteralString = "__model_id__"
class ModelAdapterMapFormatter(ModelNameFormatter):
  # Same substitution machinery, but targets the ``__model_adapter_map__`` placeholder.
  model_keyword: LiteralString = "__model_adapter_map__"
# Path to the bundled `_service.py` template, resolved relative to this package.
_service_file = Path(os.path.abspath(__file__)).parent.parent/"_service.py"
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
  """Render the service template for *llm* into *llm_fs*, substituting placeholders."""
  from openllm.utils import DEBUG
  model_name = llm.config["model_name"]
  logger.debug("Generating service file for %s at %s (dir=%s)", model_name, llm.config["service_name"], llm_fs.getsyspath("/"))
  with open(_service_file.__fspath__(), "r") as f: src_contents = f.readlines()
  for it in src_contents:
    # NOTE(review): the `[:-(len(marker) + 3)]` slice presumably strips the trailing
    # " # <marker>" comment (marker plus separator) before re-appending "\n" —
    # confirm against the _service.py template layout.
    if OPENLLM_MODEL_NAME in it: src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + "\n")
    elif OPENLLM_MODEL_ADAPTER_MAP in it: src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + "\n")
  script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
  if DEBUG: logger.info("Generated script:\n%s", script)
  llm_fs.writetext(llm.config["service_name"], script)
# sentinel object for unequivocal object() getattr
_sentinel = object()
def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
  """Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
  value = getattr(cls, attrib_name, _sentinel)
  if value is _sentinel:
    return False
  # If any base class resolves to the very same object, the attribute is inherited.
  return all(getattr(base, attrib_name, None) is not value for base in cls.__mro__[1:])
def get_annotations(cls: type[t.Any]) -> DictStrAny:
  """Return *cls*'s own ``__annotations__``, or an empty dict when inherited/absent."""
  if not has_own_attribute(cls, "__annotations__"):
    return t.cast("DictStrAny", {})
  return cls.__annotations__
def is_class_var(annot: str | t.Any) -> bool:
  """Whether the (possibly quoted, possibly stringified) annotation denotes a ClassVar."""
  text = str(annot)
  quotes = ("'", '"')
  # Annotation can be quoted.
  if text.startswith(quotes) and text.endswith(quotes):
    text = text[1:-1]
  prefixes = ("typing.ClassVar", "t.ClassVar", "ClassVar", "typing_extensions.ClassVar")
  return text.startswith(prefixes)
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
  """Stamp identity dunders (__module__, __qualname__, __doc__) from *cls* onto a generated method/class.

  Each assignment is attempted independently; objects that reject attribute
  writes are left untouched.
  """
  try:
    method_or_cls.__module__ = cls.__module__
  except AttributeError:
    pass
  try:
    method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
  except AttributeError:
    pass
  try:
    method_or_cls.__doc__ = _overwrite_doc or "Generated by ``openllm.LLMConfig`` for class " f"{cls.__qualname__}."
  except AttributeError:
    pass
  return method_or_cls
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None: eval(compile(script, filename, "exec"), globs, locs) # noqa: S307
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
  """Exec *script* under *globs* and return the function it defines named *name*.

  A synthetic linecache entry is registered for the script so debuggers can
  step through the generated source.
  """
  locs: DictStrAny = {}
  # In order of debuggers like PDB being able to step through the code, we add a fake linecache entry.
  count = 1
  base_filename = filename
  while True:
    linecache_tuple = (len(script), None, script.splitlines(True), filename)
    old_val = linecache.cache.setdefault(filename, linecache_tuple)
    if old_val == linecache_tuple: break
    else:
      # Filename already taken by a different script: derive "<base-N>" and retry.
      filename = f"{base_filename[:-1]}-{count}>"
      count += 1
  _compile_and_eval(script, globs, locs, filename)
  return locs[name]
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]:
  """Create a tuple subclass to hold class attributes.

  The subclass is a bare tuple with properties for names.

  class MyClassAttributes(tuple):
    __slots__ = ()
    x = property(itemgetter(0))
  """
  from . import SHOW_CODEGEN
  attr_class_name = f"{cls_name}Attributes"
  attr_class_template = [f"class {attr_class_name}(tuple):", "  __slots__ = ()",]
  if attr_names:
    # One read-only property per name, mapped positionally onto the tuple.
    for i, attr_name in enumerate(attr_names): attr_class_template.append(f"  {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
  else: attr_class_template.append("  pass")
  globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
  if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template))
  _compile_and_eval("\n".join(attr_class_template), globs)
  return globs[attr_class_name]
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None) -> AnyCallable:
  """Assemble and exec a function body for *typ*, returning the resulting callable.

  Empty *lines* produce a ``pass`` body; *annotations* are attached post-hoc.
  """
  from openllm.utils import SHOW_CODEGEN
  script = "def %s(%s):\n  %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n  ".join(lines) if lines else "pass")
  meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
  if annotations: meth.__annotations__ = annotations
  if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
  return meth
def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable:
  """Generate an attrs field-transformer that overlays environment-variable values onto field defaults.

  The generated ``__auto_env`` evolves each attrs field so its default is read
  from the corresponding OPENLLM_* env var, and its metadata records the env
  name and description.
  """
  from openllm.utils import dantic, field_env_key
  # Identity callback used when the caller does not post-process defaults.
  def identity(_: str, x_value: t.Any) -> t.Any: return x_value
  default_callback = identity if default_callback is None else default_callback
  globs = {} if globs is None else globs
  globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,})
  lines: ListStr = ["__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", "return [", "  f.evolve(", "    default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", "    metadata={", "      'env': f.metadata.get('env', __env(f.name)),", "      'description': f.metadata.get('description', '(not provided)'),", "    },", "  )", "  for f in fields", "]"]
  fields_ann = "list[attr.Attribute[t.Any]]"
  return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann})
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
  """Enhance sdk with nice repr that plays well with your brain.

  Wraps *func* in a dynamically-created partial subclass mixing in ReprMixin,
  so the resulting SDK object keeps the callable behaviour but reprs as a
  JSON summary of its (public) signature parameters.
  """
  from openllm.utils import ReprMixin
  if name is None: name = func.__name__.strip("_")
  _signatures = inspect.signature(func).parameters
  # Repr helpers bound onto the generated class below.
  def _repr(self: ReprMixin) -> str: return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
  def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
  if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
  else: doc = func.__doc__
  return t.cast(_T, functools.update_wrapper(types.new_class(name, (t.cast("PartialAny", functools.partial), ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,))
__all__ = ["gen_sdk", "make_attr_tuple_class", "make_env_transformer", "generate_unique_filename", "generate_function", "OPENLLM_MODEL_NAME", "OPENLLM_MODEL_ADAPTER_MAP"]

View File

@@ -1,387 +0,0 @@
"""An interface provides the best of pydantic and attrs."""
from __future__ import annotations
import functools, importlib, os, sys, typing as t
from enum import Enum
import attr, click, click_option_group as cog, inflection, orjson
from click import (
ParamType,
shell_completion as sc,
types as click_types,
)
if t.TYPE_CHECKING: from attr import _ValidatorType
AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command])
# Explicit public API of the dantic module; dir() is kept sorted for stable introspection.
__all__ = ["FC", "attrs_to_options", "Field", "parse_type", "is_typing", "is_literal", "ModuleType", "EnumChoice", "LiteralChoice", "allows_multiple", "is_mapping", "is_container", "parse_container_args", "parse_single_arg", "CUDA", "JsonType", "BytesType"]
def __dir__() -> list[str]: return sorted(__all__)
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]:
  """Translate an attrs field into a click option-group option.

  The option name is dasherized; booleans get a ``/--no-`` toggle; env var and
  description come from the field's metadata.
  """
  # TODO: support parsing nested attrs class and Union
  envvar = field.metadata["env"]
  dasherized = inflection.dasherize(name)
  underscored = inflection.underscore(name)
  if typ in (None, attr.NOTHING):
    typ = field.type
    if typ is None: raise RuntimeError(f"Failed to parse type for {name}")
  full_option_name = f"--{dasherized}"
  if field.type is bool: full_option_name += f"/--no-{dasherized}"
  # Destination identifier is namespaced per model (and generation/sampling group).
  if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
  elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
  else: identifier = f"{model_name}_{underscored}"
  return cog.optgroup.option(identifier, full_option_name, type=parse_type(typ), required=field.default is attr.NOTHING, default=field.default if field.default not in (attr.NOTHING, None) else None, show_default=True, multiple=allows_multiple(typ) if typ else False, help=field.metadata.get("description", "(No description provided)"), show_envvar=True, envvar=envvar,)
def env_converter(value: t.Any, env: str | None = None) -> t.Any:
  """Resolve *value* against environment variable *env*, JSON-decoding string values.

  NOTE(review): strings are lowercased before orjson.loads so "True"/"False"
  parse as JSON booleans — this also lowercases quoted string payloads;
  confirm that is intended for free-form values.
  """
  if env is not None:
    value = os.environ.get(env, value)
  if value is not None and isinstance(value, str):
    try: return orjson.loads(value.lower())
    except orjson.JSONDecodeError as err: raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
  return value
def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[t.Any] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any) -> t.Any:
  """A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.

  By default, if both validator and ge are provided, then ge will be
  piped into first, then all of the other validator will be run afterwards.

  Args:
    default: The default value for ``dantic.Field``. Defaults to ``None``.
    ge: Greater than or equal to. Defaults to None.
    le: Less than or equal to. Defaults to None.
    validator: Optional attrs-compatible validators type. Default to None
    description: the documentation for the field. Defaults to None.
    env: the environment variable to read from. Defaults to None.
    auto_default: a bool indicating whether to use the default value as the environment.
                  Defaults to False. If set to True, the behaviour of this Field will also depends
                  on kw_only. If kw_only=True, the this field will become 'Required' and the default
                  value is omitted. If kw_only=False, then the default value will be used as before.
    use_default_converter: a bool indicating whether to use the default converter. Defaults
                           to True. If set to False, then the default converter will not be used.
                           The default converter converts a given value from the environment variable
                           for this given Field.
    **attrs: The rest of the arguments are passed to attr.field
  """
  metadata = attrs.pop("metadata", {})
  if description is None: description = "(No description provided)"
  metadata["description"] = description
  if env is not None: metadata["env"] = env
  piped: list[_ValidatorType[t.Any]] = []
  converter = attrs.pop("converter", None)
  # The env converter (when enabled) takes precedence over any caller-supplied converter.
  if use_default_converter: converter = functools.partial(env_converter, env=env)
  if ge is not None: piped.append(attr.validators.ge(ge))
  if le is not None: piped.append(attr.validators.le(le))
  if validator is not None: piped.append(validator)
  if len(piped) == 0: _validator = None
  elif len(piped) == 1: _validator = piped[0]
  else: _validator = attr.validators.and_(*piped)
  factory = attrs.pop("factory", None)
  if factory is not None and default is not None: raise RuntimeError("'factory' and 'default' are mutually exclusive.")
  # NOTE: the behaviour of this is we will respect factory over the default
  if factory is not None: attrs["factory"] = factory
  else: attrs["default"] = default
  kw_only = attrs.pop("kw_only", False)
  if auto_default and kw_only:
    # NOTE(review): assumes a 'default' key is present here — when a factory was
    # supplied this pop would raise KeyError; confirm factory+auto_default+kw_only
    # is never combined.
    attrs.pop("default")
  return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]:
  """Transforms the pydantic field's type into a click-compatible type.

  Args:
    field_type: pydantic field type

  Returns:
    ParamType: click type equivalent
  """
  from . import lenient_issubclass
  if t.get_origin(field_type) is t.Union:
    raise NotImplementedError("Unions are not supported")
  # NOTE: the checks below are order-sensitive — Literal and Enum must be
  # recognised before the generic container/mapping fallbacks.
  # enumeration strings or other Enum derivatives
  if lenient_issubclass(field_type, Enum):
    return EnumChoice(enum=field_type, case_sensitive=True)
  # literals are enum-like with way less functionality
  if is_literal(field_type):
    return LiteralChoice(value=field_type, case_sensitive=True)
  # modules, classes, functions
  if is_typing(field_type): return ModuleType()
  # entire dictionaries:
  # using a Dict, convert in advance
  if is_mapping(field_type): return JsonType()
  # list, List[p], Tuple[p], Set[p] and so on
  if is_container(field_type): return parse_container_args(field_type)
  # bytes are not natively supported by click
  if lenient_issubclass(field_type, bytes): return BytesType()
  # return the current type: it should be a primitive
  return field_type
def is_typing(field_type: type) -> bool:
  """Report whether *field_type* is a ``type[...]``/``Type[...]`` annotation.

  Args:
    field_type: pydantic field type

  Returns:
    bool: True only when the generic origin is the builtin ``type`` itself
  """
  origin = t.get_origin(field_type)
  # Plain (non-generic) annotations have no origin and are never "typing" types.
  return origin is not None and (origin is type or origin is t.Type)
def is_literal(field_type: type) -> bool:
  """Report whether *field_type* is a ``typing.Literal[...]`` annotation.

  Literals are weird: ``isinstance``/``issubclass`` choke on them, so the only
  reliable probe is comparing the generic origin against ``Literal`` itself.

  Args:
    field_type: current pydantic type

  Returns:
    bool: true if Literal type, false otherwise
  """
  # get_origin returns None for non-generics, which also fails the identity test.
  return t.get_origin(field_type) is t.Literal
class ModuleType(ParamType):
  """Click type that resolves a dotted ``package.module.attribute`` path into the live object."""
  name = "module"
  def _import_object(self, value: str) -> t.Any:
    # Split "pkg.mod.attr" into the module path and the trailing attribute name.
    module_name, attr_name = value.rsplit(".", maxsplit=1)
    if not all(part.isidentifier() for part in module_name.split(".")):
      raise ValueError(f"'{value}' is not a valid module name")
    if not attr_name.isidentifier():
      raise ValueError(f"Variable '{attr_name}' is not a valid identifier")
    module = importlib.import_module(module_name)
    if attr_name:
      try:
        return getattr(module, attr_name)
      except AttributeError:
        raise ImportError(f"Module '{module_name}' does not define a '{attr_name}' variable.") from None
  def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
    # Non-string values are assumed to already be the resolved object.
    if not isinstance(value, str):
      return value
    try:
      return self._import_object(value)
    except Exception as exc:
      self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
class EnumChoice(click.Choice):
  name = "enum"
  def __init__(self, enum: Enum, case_sensitive: bool = False):
    """Enum type support for click that extends ``click.Choice``.

    Accepts either an ``Enum`` subclass or a single member; ``parse_type``
    passes the class itself.

    Args:
      enum: Given enum (an ``Enum`` subclass, or a member whose class is used).
      case_sensitive: Whether this choice should be case case_sensitive.
    """
    self.mapping = enum
    # BUG FIX: the previous code iterated ``enum.__class__``, which is the
    # EnumMeta *metaclass* when an Enum class is passed (as parse_type does),
    # raising TypeError; it also set internal_type to that metaclass, breaking
    # convert(). Resolve the actual Enum class first.
    self.internal_type = enum if isinstance(enum, type) else type(enum)
    choices: list[t.Any] = [e.name for e in self.internal_type]
    super().__init__(choices, case_sensitive)
  def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
    # Values that are already members of the enum pass straight through.
    if isinstance(value, self.internal_type):
      return value
    result = super().convert(value, param, ctx)
    # click returns the matched choice *name*; translate it back into a member.
    if isinstance(result, str):
      result = self.internal_type[result]
    return result
class LiteralChoice(EnumChoice):
  """Click choice over the values of a ``typing.Literal[...]`` annotation."""
  name = "literal"
  def __init__(self, value: t.Any, case_sensitive: bool = False):
    """Literal support for click."""
    # All literal members must share one primitive type so conversion is unambiguous.
    members = list(value.__args__)
    member_type = type(members[0])
    if not all(isinstance(m, member_type) for m in members):
      raise ValueError(f"Field {value} contains items of different types.")
    lookup = {str(m): m for m in members}
    # Deliberately skip EnumChoice.__init__ and initialise click.Choice directly.
    super(EnumChoice, self).__init__(list(lookup), case_sensitive)
    self.internal_type = member_type
def allows_multiple(field_type: type[t.Any]) -> bool:
  """Report whether a field should map to a repeatable click option.

  Containers are exposed through click's "multiple" support, so the same
  option can be passed several times: ``--subsets train --subsets test``
  becomes ``subsets: ["train", "test"]``.

  Args:
    field_type: pydantic type.

  Returns:
    bool: true if it's a composite field (lists, containers and so on), false otherwise
  """
  # Mappings are handled as JSON strings, never as repeated options.
  if is_mapping(field_type):
    return False
  if not is_container(field_type):
    return False
  # parse_container_args returns a tuple only for fixed-length composites such
  # as Tuple[str, int, int]; only non-composite containers (List[int], ...)
  # may be repeated for the moment.
  return not isinstance(parse_container_args(field_type), tuple)
def is_mapping(field_type: type) -> bool:
  """Report whether *field_type* is a dict-like (JSON object) annotation.

  Args:
    field_type (type): pydantic type

  Returns:
    bool: true when the field is a dict-like object, false otherwise.
  """
  from . import lenient_issubclass
  # Concrete mapping classes (dict, Mapping subclasses) short-circuit here.
  if lenient_issubclass(field_type, t.Mapping):
    return True
  # Generic aliases such as Dict[str, int] expose the mapping via their origin.
  origin = t.get_origin(field_type)
  return origin is not None and lenient_issubclass(origin, t.Mapping)
def is_container(field_type: type) -> bool:
  """Report whether *field_type* 'contains' other types, like lists and tuples.

  Args:
    field_type: pydantic field type

  Returns:
    bool: true if a container, false otherwise
  """
  # str and bytes are technically containers but must be treated as scalars.
  if field_type in (str, bytes):
    return False
  from . import lenient_issubclass
  # Concrete container classes: list, tuple, set, range, ...
  if lenient_issubclass(field_type, t.Container):
    return True
  # Generic aliases (List[int], Tuple[str, ...]) expose containment via origin;
  # non-typing objects have no origin at all.
  origin = t.get_origin(field_type)
  return origin is not None and lenient_issubclass(origin, t.Container)
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType, ...]:
  """Parse the type arguments inside a container annotation (lists, tuples and so on).

  Args:
    field_type: pydantic field type

  Returns:
    ParamType | tuple[ParamType]: single click-compatible type or a tuple
  """
  if not is_container(field_type):
    raise ValueError("Field type is not a container type.")
  type_args = t.get_args(field_type)
  # Untyped containers (plain list, tuple, List[Any]): fall back to str so
  # click does not try to guess the element type.
  if not type_args:
    return click_types.convert_type(str)
  # Homogeneous containers -- Tuple[int], List[str] -- or homogeneous tuples
  # of indefinite length -- Tuple[int, ...].
  homogeneous = len(type_args) == 1 or (len(type_args) == 2 and type_args[1] is Ellipsis)
  if homogeneous:
    return parse_single_arg(type_args[0])
  # Fixed-length heterogeneous containers: Tuple[str, int, int].
  return tuple(parse_single_arg(arg) for arg in type_args)
def parse_single_arg(arg: type) -> ParamType:
  """Return the click-compatible type for one container element type.

  String when the type is not inferrable, JSON for mappings and nested
  containers, a dedicated bytes type, and the original primitive otherwise.

  Args:
    arg (type): single argument

  Returns:
    ParamType: click-compatible type
  """
  from . import lenient_issubclass
  # Unknown element type: parse as a plain string.
  if arg is t.Any:
    return click_types.convert_type(str)
  # Nested containers and models travel as JSON.
  if is_container(arg):
    return JsonType()
  # bytes is a special case, not natively handled by click.
  if lenient_issubclass(arg, bytes):
    return BytesType()
  return click_types.convert_type(arg)
class BytesType(ParamType):
  """Click type that encodes incoming strings to ``bytes``."""
  name = "bytes"
  def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
    # Already-encoded payloads pass straight through.
    if isinstance(value, bytes):
      return value
    try:
      return str.encode(value)
    except Exception as exc:
      self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
CYGWIN = sys.platform.startswith("cygwin")
WIN = sys.platform.startswith("win")
if sys.platform.startswith("win") and WIN:
def _get_argv_encoding() -> str:
import locale
return locale.getpreferredencoding()
else:
def _get_argv_encoding() -> str: return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
class CudaValueType(ParamType):
  """Click type for CUDA device lists: comma-separated device indices, where "-1" terminates the list."""
  name = "cuda"
  envvar_list_splitter = ","
  # Composite: convert() yields a tuple of per-device values, not a scalar.
  is_composite = True
  typ = click_types.convert_type(str)
  def split_envvar_value(self, rv: str) -> t.Sequence[str]:
    # Split "0,1,-1,2" on commas; everything from the first "-1" onward is
    # dropped, so "-1" acts as an explicit end-of-list marker.
    var = tuple(i for i in rv.split(self.envvar_list_splitter))
    if "-1" in var:
      return var[:var.index("-1")]
    return var
  def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
    """Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.

    Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.

    Args:
      ctx: Invocation context for this command.
      param: The parameter that is requesting completion.
      incomplete: Value being completed. May be empty.
    """
    from openllm.utils import available_devices
    # With no partial input, offer every visible device on this machine.
    mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
    return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]
  def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
    if isinstance(value, bytes):
      # Decode raw argv bytes: argv encoding first, then the filesystem
      # encoding, finally UTF-8 with replacement characters as a last resort.
      enc = _get_argv_encoding()
      try: value = value.decode(enc)
      except UnicodeError:
        fs_enc = sys.getfilesystemencoding()
        if fs_enc != enc:
          try: value = value.decode(fs_enc)
          except UnicodeError: value = value.decode("utf-8", "replace")
        else: value = value.decode("utf-8", "replace")
    # NOTE(review): splits on a literal "," rather than self.envvar_list_splitter —
    # the values coincide today; confirm before changing either.
    return tuple(self.typ(x, param, ctx) for x in value.split(","))
  def __repr__(self) -> str: return "STRING"
# Singleton instance used as the click parameter type.
CUDA = CudaValueType()
class JsonType(ParamType):
  name = "json"
  def __init__(self, should_load: bool = True) -> None:
    """Support JSON type for click.ParamType.

    Args:
      should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
    """
    super().__init__()
    self.should_load = should_load
  def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
    # Opt-out mode and already-parsed dicts pass through untouched.
    if not self.should_load or isinstance(value, dict):
      return value
    try:
      return orjson.loads(value)
    except orjson.JSONDecodeError as exc:
      self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)

View File

@@ -2,7 +2,7 @@
# To update this, run ./tools/update-dummy.py
from __future__ import annotations
import typing as _t
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
# Raise-on-use placeholder emitted when the "flax" backend is not installed (generated file; do not hand-edit).
class FlaxFlanT5(metaclass=_DummyMetaclass):
  _backends=["flax"]
  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["flax"])

View File

@@ -2,7 +2,7 @@
# To update this, run ./tools/update-dummy.py
from __future__ import annotations
import typing as _t
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
# Raise-on-use placeholder emitted when torch/cpm_kernels/sentencepiece are not installed (generated file; do not hand-edit).
class ChatGLM(metaclass=_DummyMetaclass):
  _backends=["torch","cpm_kernels","sentencepiece"]
  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","cpm_kernels","sentencepiece"])

View File

@@ -2,7 +2,7 @@
# To update this, run ./tools/update-dummy.py
from __future__ import annotations
import typing as _t
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
# Raise-on-use placeholder emitted when the "tensorflow" backend is not installed (generated file; do not hand-edit).
class TFFlanT5(metaclass=_DummyMetaclass):
  _backends=["tensorflow"]
  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["tensorflow"])

View File

@@ -2,7 +2,7 @@
# To update this, run ./tools/update-dummy.py
from __future__ import annotations
import typing as _t
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
# Raise-on-use placeholder emitted when vllm/cpm_kernels/sentencepiece are not installed (generated file; do not hand-edit).
class VLLMBaichuan(metaclass=_DummyMetaclass):
  _backends=["vllm","cpm_kernels","sentencepiece"]
  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","cpm_kernels","sentencepiece"])

View File

@@ -1,312 +0,0 @@
"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons."""
from __future__ import annotations
import importlib, importlib.metadata, importlib.util, logging, os, abc, typing as t
from collections import OrderedDict
import inflection, packaging.version
from bentoml._internal.utils import LazyLoader, pkg
from openllm._typing_compat import overload, LiteralString
from .representation import ReprMixin
if t.TYPE_CHECKING:
BackendOrderedDict = OrderedDict[str, t.Tuple[t.Callable[[], bool], str]]
from openllm._typing_compat import LiteralRuntime
logger = logging.getLogger(__name__)
# Extras that map to optional pip dependency groups.
OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq",}
# Truthy spellings accepted from environment variables (mirrors transformers' import_utils).
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
# Backend-selection knobs, read once at import time.
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
# NOTE(review): USE_JAX is populated from the USE_FLAX environment variable — confirm the mismatch is intentional.
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
def _is_package_available(package: str) -> bool:
_package_available = importlib.util.find_spec(package) is not None
if _package_available:
try: importlib.metadata.version(package)
except importlib.metadata.PackageNotFoundError: _package_available = False
return _package_available
# Cheap probes: find_spec only checks importability, not installed metadata.
_torch_available = importlib.util.find_spec("torch") is not None
_tf_available = importlib.util.find_spec("tensorflow") is not None
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
_vllm_available = importlib.util.find_spec("vllm") is not None
# Stricter probes: require both an import spec and distribution metadata.
_peft_available = _is_package_available("peft")
_einops_available = _is_package_available("einops")
_cpm_kernel_available = _is_package_available("cpm_kernels")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_datasets_available = _is_package_available("datasets")
_triton_available = _is_package_available("triton")
_jupyter_available = _is_package_available("jupyter")
_jupytext_available = _is_package_available("jupytext")
_notebook_available = _is_package_available("notebook")
_autogptq_available = _is_package_available("auto_gptq")
_sentencepiece_available = _is_package_available("sentencepiece")
_xformers_available = _is_package_available("xformers")
_fairscale_available = _is_package_available("fairscale")
# Thin public accessors over the probe results computed at import time; the
# two transformers checks gate features introduced in 4.30 (k-bit) and 4.29 (agents).
def is_transformers_supports_kbit() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 30)
def is_transformers_supports_agent() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 29)
def is_jupyter_available() -> bool: return _jupyter_available
def is_jupytext_available() -> bool: return _jupytext_available
def is_notebook_available() -> bool: return _notebook_available
def is_triton_available() -> bool: return _triton_available
def is_datasets_available() -> bool: return _datasets_available
def is_peft_available() -> bool: return _peft_available
def is_einops_available() -> bool: return _einops_available
def is_cpm_kernels_available() -> bool: return _cpm_kernel_available
def is_bitsandbytes_available() -> bool: return _bitsandbytes_available
def is_autogptq_available() -> bool: return _autogptq_available
def is_vllm_available() -> bool: return _vllm_available
def is_sentencepiece_available() -> bool: return _sentencepiece_available
def is_xformers_available() -> bool: return _xformers_available
def is_fairscale_available() -> bool: return _fairscale_available
def is_torch_available() -> bool:
  """Return whether PyTorch is usable, honouring the USE_TORCH/USE_TF overrides.

  Mutates the module-level ``_torch_available`` cache as a side effect.
  """
  global _torch_available
  if USE_TORCH not in ENV_VARS_TRUE_AND_AUTO_VALUES or USE_TF in ENV_VARS_TRUE_VALUES:
    logger.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False
  elif _torch_available:
    # find_spec said torch exists; double-check it carries installed metadata.
    try:
      importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
      _torch_available = False
  return _torch_available
def is_tf_available() -> bool:
  """Return whether a TF 2.x distribution is usable, honouring FORCE_TF_AVAILABLE/USE_TF/USE_TORCH.

  Mutates the module-level ``_tf_available`` cache as a side effect.
  """
  global _tf_available
  # Explicit escape hatch: trust the user and skip all probing.
  if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: _tf_available = True
  else:
    _tf_version = None
    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
      if _tf_available:
        candidates = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos", "tensorflow-aarch64",)
        _tf_version = None
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
        for _pkg in candidates:
          try:
            _tf_version = importlib.metadata.version(_pkg)
            break
          except importlib.metadata.PackageNotFoundError: pass # noqa: PERF203 # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
        _tf_available = _tf_version is not None
      if _tf_available:
        # Any TF 1.x distribution is treated as unavailable.
        if _tf_version and packaging.version.parse(_tf_version) < packaging.version.parse("2"):
          logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
          _tf_available = False
    else:
      logger.info("Disabling Tensorflow because USE_TORCH is set")
      _tf_available = False
  return _tf_available
def is_flax_available() -> bool:
  """Return whether JAX+Flax are usable, honouring the USE_FLAX override.

  Mutates the module-level ``_flax_available`` cache as a side effect.
  """
  global _flax_available
  if USE_JAX not in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available = False
  elif _flax_available:
    # Both distributions must carry installed metadata, not just import specs.
    try:
      importlib.metadata.version("jax")
      importlib.metadata.version("flax")
    except importlib.metadata.PackageNotFoundError:
      _flax_available = False
  return _flax_available
# Error-message templates; each is .format()-ed with the name of the class a
# user tried to import from a missing backend, and points at either the
# correctly-prefixed sibling class or the backend's installation docs.
VLLM_IMPORT_ERROR_WITH_PYTORCH = """\
{0} requires the vLLM library but it was not found in your environment.
However, we were able to find a PyTorch installation. PyTorch classes do not begin
with "VLLM", but are otherwise identically named to our PyTorch classes.
If you want to use PyTorch, please use those classes instead!
If you really do want to use vLLM, please follow the instructions on the
installation page https://github.com/vllm-project/vllm that match your environment.
"""
VLLM_IMPORT_ERROR_WITH_TF = """\
{0} requires the vLLM library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to the PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TF{0}".
If you want to use TensorFlow, please use TF classes instead!
If you really do want to use vLLM, please follow the instructions on the
installation page https://github.com/vllm-project/vllm that match your environment.
"""
VLLM_IMPORT_ERROR_WITH_FLAX = """\
{0} requires the vLLM library but it was not found in your environment.
However, we were able to find a Flax installation. Flax classes begin
with "Flax", but are otherwise identically named to the PyTorch classes. This
means that the Flax equivalent of the class you tried to import would be "Flax{0}".
If you want to use Flax, please use Flax classes instead!
If you really do want to use vLLM, please follow the instructions on the
installation page https://github.com/vllm-project/vllm that match your environment.
"""
PYTORCH_IMPORT_ERROR_WITH_TF = """\
{0} requires the PyTorch library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to the PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TF{0}".
If you want to use TensorFlow, please use TF classes instead!
If you really do want to use PyTorch please go to
https://pytorch.org/get-started/locally/ and follow the instructions that
match your environment.
"""
TF_IMPORT_ERROR_WITH_PYTORCH = """\
{0} requires the TensorFlow library but it was not found in your environment.
However, we were able to find a PyTorch installation. PyTorch classes do not begin
with "TF", but are otherwise identically named to our TF classes.
If you want to use PyTorch, please use those classes instead!
If you really do want to use TensorFlow, please follow the instructions on the
installation page https://www.tensorflow.org/install that match your environment.
"""
TENSORFLOW_IMPORT_ERROR = """{0} requires the TensorFlow library but it was not found in your environment.
Checkout the instructions on the installation page: https://www.tensorflow.org/install and follow the
ones that match your environment. Please note that you may need to restart your runtime after installation.
"""
FLAX_IMPORT_ERROR = """{0} requires the FLAX library but it was not found in your environment.
Checkout the instructions on the installation page: https://github.com/google/flax and follow the
ones that match your environment. Please note that you may need to restart your runtime after installation.
"""
PYTORCH_IMPORT_ERROR = """{0} requires the PyTorch library but it was not found in your environment.
Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the
ones that match your environment. Please note that you may need to restart your runtime after installation.
"""
# FIX: this message previously jumped from the URL straight to "ones that
# match", dropping "and follow the" (cf. the TENSORFLOW/FLAX/PYTORCH templates).
VLLM_IMPORT_ERROR = """{0} requires the vLLM library but it was not found in your environment.
Checkout the instructions on the installation page: https://github.com/vllm-project/vllm and follow the
ones that match your environment. Please note that you may need to restart your runtime after installation.
"""
CPM_KERNELS_IMPORT_ERROR = """{0} requires the cpm_kernels library but it was not found in your environment.
You can install it with pip: `pip install cpm_kernels`. Please note that you may need to restart your
runtime after installation.
"""
EINOPS_IMPORT_ERROR = """{0} requires the einops library but it was not found in your environment.
You can install it with pip: `pip install einops`. Please note that you may need to restart
your runtime after installation.
"""
TRITON_IMPORT_ERROR = """{0} requires the triton library but it was not found in your environment.
You can install it with pip: 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'.
Please note that you may need to restart your runtime after installation.
"""
DATASETS_IMPORT_ERROR = """{0} requires the datasets library but it was not found in your environment.
You can install it with pip: `pip install datasets`. Please note that you may need to restart
your runtime after installation.
"""
PEFT_IMPORT_ERROR = """{0} requires the peft library but it was not found in your environment.
You can install it with pip: `pip install peft`. Please note that you may need to restart
your runtime after installation.
"""
BITSANDBYTES_IMPORT_ERROR = """{0} requires the bitsandbytes library but it was not found in your environment.
You can install it with pip: `pip install bitsandbytes`. Please note that you may need to restart
your runtime after installation.
"""
AUTOGPTQ_IMPORT_ERROR = """{0} requires the auto-gptq library but it was not found in your environment.
You can install it with pip: `pip install auto-gptq`. Please note that you may need to restart
your runtime after installation.
"""
SENTENCEPIECE_IMPORT_ERROR = """{0} requires the sentencepiece library but it was not found in your environment.
You can install it with pip: `pip install sentencepiece`. Please note that you may need to restart
your runtime after installation.
"""
XFORMERS_IMPORT_ERROR = """{0} requires the xformers library but it was not found in your environment.
You can install it with pip: `pip install xformers`. Please note that you may need to restart
your runtime after installation.
"""
FAIRSCALE_IMPORT_ERROR = """{0} requires the fairscale library but it was not found in your environment.
You can install it with pip: `pip install fairscale`. Please note that you may need to restart
your runtime after installation.
"""
# Backend name -> (availability probe, error-message template). Order matters
# only for readability; require_backends checks each requested backend in turn.
BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
                                                   ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
                                                   ("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
                                                   ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
                                                   ("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), ("fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))])
class DummyMetaclass(abc.ABCMeta):
  """Metaclass for dummy object.

  It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class.
  """
  _backends: t.List[str]
  def __getattribute__(cls, key: str) -> t.Any:
    # Private/dunder lookups must resolve normally (cls._backends itself, abc internals).
    if key.startswith("_"): return super().__getattribute__(key)
    require_backends(cls, cls._backends)
    # BUG FIX: previously fell through and implicitly returned None when all
    # backends were in fact available; now resolve the attribute normally.
    return super().__getattribute__(key)
def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
  """Raise ImportError when any backend required by *o* is unavailable.

  The messages are specialised: importing a class whose backend is missing
  while a sibling framework IS installed points the user at the correctly
  prefixed class ("TF{name}", "Flax{name}", ...) instead of a bare install hint.
  """
  if not isinstance(backends, (list, tuple)): backends = list(backends)
  try: name = o.__name__
  except AttributeError: name = o.__class__.__name__
  # torch-only class imported in a TF-only environment: suggest the TF class.
  if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
  # The inverse: a TF class imported in a torch-only environment.
  if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
  # Missing vLLM: suggest the framework-specific class, PyTorch -> TF -> Flax.
  if "vllm" in backends:
    if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
    if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
    if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
  # Generic fallback: collect one formatted message per missing backend.
  missing = [template.format(name) for probe, template in (BACKENDS_MAPPING[backend] for backend in backends) if not probe()]
  if missing: raise ImportError("".join(missing))
class EnvVarMixin(ReprMixin):
  """Resolves per-model environment-variable names and their current values.

  Attribute access (``self.config``, ``self.model_id``, ...) yields the
  environment-variable *name* built by ``field_env_key``; subscripting with a
  ``*_value`` key (e.g. ``self['quantize_value']``) yields the resolved value,
  preferring the environment over the constructor-supplied default.
  """
  model_name: str
  config: str
  model_id: str
  quantize: str
  framework: str
  bettertransformer: str
  runtime: str
  @overload
  def __getitem__(self, item: t.Literal["config"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["model_id"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["quantize"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["framework"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["runtime"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ...
  @overload
  def __getitem__(self, item: t.Literal["quantize_value"]) -> t.Literal["int8", "int4", "gptq"] | None: ...
  @overload
  def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ...
  @overload
  def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> bool: ...
  @overload
  def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ...
  def __getitem__(self, item: str | t.Any) -> t.Any:
    # "<field>_value" keys dispatch to the private resolver methods below;
    # plain keys return the env-var name set in __init__.
    if item.endswith("_value") and hasattr(self, f"_{item}"): return object.__getattribute__(self, f"_{item}")()
    elif hasattr(self, item): return getattr(self, item)
    raise KeyError(f"Key {item} not found in {self}")
  def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None:
    """EnvVarMixin is a mixin class that returns the value extracted from environment variables."""
    from openllm._configuration import field_env_key
    self.model_name = inflection.underscore(model_name)
    # Constructor-supplied fallbacks, consulted when the env var is unset.
    self._implementation = implementation
    self._model_id = model_id
    self._bettertransformer = bettertransformer
    self._quantize = quantize
    self._runtime = runtime
    # Each public attribute holds the env-var NAME (not the value) for its field.
    for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: setattr(self, att, field_env_key(self.model_name, att.upper()))
  def _quantize_value(self) -> t.Literal["int8", "int4", "gptq"] | None:
    from . import first_not_none
    return t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], first_not_none(os.environ.get(self["quantize"]), default=self._quantize))
  def _framework_value(self) -> LiteralRuntime:
    from . import first_not_none
    return t.cast(t.Literal["pt", "tf", "flax", "vllm"], first_not_none(os.environ.get(self["framework"]), default=self._implementation))
  def _bettertransformer_value(self) -> bool:
    from . import first_not_none
    # NOTE(review): the first argument is a bool expression and can never be
    # None, so the constructor default appears unreachable here — confirm intent.
    return t.cast(bool, first_not_none(os.environ.get(self["bettertransformer"], str(False)).upper() in ENV_VARS_TRUE_VALUES, default=self._bettertransformer))
  def _model_id_value(self) -> str | None:
    from . import first_not_none
    return first_not_none(os.environ.get(self["model_id"]), default=self._model_id)
  def _runtime_value(self) -> t.Literal["ggml", "transformers"]:
    from . import first_not_none
    return t.cast(t.Literal["ggml", "transformers"], first_not_none(os.environ.get(self["runtime"]), default=self._runtime))
  @property
  def __repr_keys__(self) -> set[str]: return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}
  @property
  def start_docstring(self) -> str: return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")
  @property
  def module(self) -> LazyLoader: return LazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")

View File

@@ -1,107 +0,0 @@
from __future__ import annotations
import functools, importlib, importlib.machinery, importlib.metadata, importlib.util, itertools, os, time, types, warnings, typing as t
import attr, openllm
__all__ = ["VersionInfo", "LazyModule"]
# vendorred from attrs
@functools.total_ordering
@attr.attrs(eq=False, order=False, slots=True, frozen=True, repr=False)
class VersionInfo:
  """Immutable (major, minor, micro, releaselevel) version quadruple, vendorred from attrs.

  Comparable against other VersionInfo instances and against plain tuples of
  1-4 elements; ordering beyond __eq__/__lt__ comes from functools.total_ordering.
  """
  major: int = attr.field()
  minor: int = attr.field()
  micro: int = attr.field()
  releaselevel: str = attr.field()
  @classmethod
  def from_version_string(cls, s: str) -> VersionInfo:
    # "1.2.3" is padded with a "final" releaselevel; "1.2.3.dev0" keeps its own.
    v = s.split(".")
    if len(v) == 3: v.append("final")
    return cls(major=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3])
  def _ensure_tuple(self, other: VersionInfo) -> tuple[tuple[int, int, int, str], tuple[int, int, int, str]]:
    # Normalise both operands to tuples, truncating self to the other's length
    # so a comparison against e.g. (1, 2) only considers major/minor.
    cmp = attr.astuple(other) if self.__class__ is other.__class__ else other
    if not isinstance(cmp, tuple): raise NotImplementedError
    if not (1 <= len(cmp) <= 4): raise NotImplementedError
    return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[:len(cmp)]), t.cast(t.Tuple[int, int, int, str], cmp)
  def __eq__(self, other: t.Any) -> bool:
    try: us, them = self._ensure_tuple(other)
    except NotImplementedError: return NotImplemented
    return us == them
  def __lt__(self, other: t.Any) -> bool:
    try: us, them = self._ensure_tuple(other)
    except NotImplementedError: return NotImplemented
    # Since alphabetically "dev0" < "final" < "post1" < "post2", we don't have to do anything special with releaselevel for now.
    return us < them
  def __repr__(self) -> str: return "{0}.{1}.{2}".format(*attr.astuple(self)[:3])
# _sentinel: unique missing-value marker; _reserved_namespace: attribute names LazyModule treats specially.
_sentinel, _reserved_namespace = object(), {"__openllm_migration__"}
class LazyModule(types.ModuleType):
    """Module subclass that defers importing its submodules/attributes until first access."""
    # Very heavily inspired by optuna.integration._IntegrationModule: https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
    def __init__(self, name: str, module_file: str, import_structure: dict[str, list[str]], module_spec: importlib.machinery.ModuleSpec | None = None, doc: str | None = None, extra_objects: dict[str, t.Any] | None = None):
        """Lazily load this module as an object.

        It does instantiate a __all__ and __dir__ for IDE support

        Args:
            name: module name
            module_file: the given file. Often default to 'globals()['__file__']'
            import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module'
            module_spec: __spec__ of the lazily loaded module
            doc: Optional docstring for this module.
            extra_objects: Any additional objects that this module can also be accessed. Useful for additional metadata as well as any locals() functions
        """
        super().__init__(name)
        self._modules = set(import_structure.keys())
        self._class_to_module: dict[str, str] = {}
        _extra_objects = {} if extra_objects is None else extra_objects
        # Reverse-map every exported attribute to the submodule that defines it,
        # so __getattr__ can resolve "which submodule do I import for this name".
        for key, values in import_structure.items():
            for value in values: self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
        self.__all__: list[str] = list(import_structure.keys()) + list(itertools.chain(*import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec or importlib.util.find_spec(name)
        self.__path__ = [os.path.dirname(module_file)]
        self.__doc__ = doc
        self._name = name
        self._objects = _extra_objects
        self._import_structure = import_structure
    def __dir__(self) -> list[str]:
        result = t.cast("list[str]", super().__dir__())
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        return result + [i for i in self.__all__ if i not in result]
    def __getattr__(self, name: str) -> t.Any:
        """Equivocal __getattr__ implementation.

        It checks from _objects > _modules and does it recursively.
        It also contains a special case for all of the metadata information, such as __version__ and __version_info__.
        """
        # Reserved names may neither be read nor shadowed via this lazy protocol.
        if name in _reserved_namespace: raise openllm.exceptions.ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.")
        # Maps well-known dunder metadata names to fields of the installed 'openllm'
        # distribution metadata; empty values are handled by the explicit branches below.
        dunder_to_metadata = {"__title__": "Name", "__copyright__": "", "__version__": "version", "__version_info__": "version", "__description__": "summary", "__uri__": "", "__url__": "", "__author__": "", "__email__": "", "__license__": "license", "__homepage__": ""}
        if name in dunder_to_metadata:
            # Only version/copyright lookups stay silent; every other dunder is
            # deprecated in favour of querying importlib.metadata directly.
            if name not in {"__version_info__", "__copyright__", "__version__"}: warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2)
            meta = importlib.metadata.metadata("openllm")
            # 'Project-URL' entries look like 'GitHub, https://...': split into name -> url.
            project_url = dict(url.split(", ") for url in t.cast(t.List[str], meta.get_all("Project-URL")))
            if name == "__license__": return "Apache-2.0"
            elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
            elif name in ("__uri__", "__url__"): return project_url["GitHub"]
            elif name == "__homepage__": return project_url["Homepage"]
            elif name == "__version_info__": return VersionInfo.from_version_string(meta["version"]) # similar to how attrs handle __version_info__
            elif name == "__author__": return meta["Author-email"].rsplit(" ", 1)[0]
            elif name == "__email__": return meta["Author-email"].rsplit("<", 1)[1][:-1]
            return meta[dunder_to_metadata[name]]
        # Migration shim: deprecated old names warn and forward to their replacement.
        if "__openllm_migration__" in self._objects:
            cur_value = self._objects["__openllm_migration__"].get(name, _sentinel)
            if cur_value is not _sentinel:
                warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3)
                return getattr(self, cur_value)
        # Resolution order: extra objects first, then submodules, then attributes
        # re-exported from submodules.
        if name in self._objects: return self._objects.__getitem__(name)
        if name in self._modules: value = self._get_module(name)
        elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name)
        else: raise AttributeError(f"module {self.__name__} has no attribute {name}")
        # Cache the resolved value so later lookups bypass __getattr__ entirely.
        setattr(self, name, value)
        return value
    def _get_module(self, module_name: str) -> types.ModuleType:
        # Import the submodule relative to this package, re-raising with context
        # so the original traceback is preserved via exception chaining.
        try: return importlib.import_module("." + module_name, self.__name__)
        except Exception as e: raise RuntimeError(f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}") from e
    # make sure this module is picklable
    def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]: return (self.__class__, (self._name, self.__file__, self._import_structure))

View File

@@ -1,32 +0,0 @@
from __future__ import annotations
import typing as t
from abc import abstractmethod
import attr, orjson
from openllm import utils
if t.TYPE_CHECKING: from openllm._typing_compat import TypeAlias
ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None]
class ReprMixin:
    """Mixin that renders ``__repr__``/``__str__`` from the attributes named in ``__repr_keys__``.

    NOTE: the original placed each method's documentation as a bare string literal
    AFTER the one-line ``def`` — those were dead no-op statements and never became
    docstrings. They are now proper docstrings inside each method.
    """

    @property
    @abstractmethod
    def __repr_keys__(self) -> set[str]:
        """Attribute names to include in the repr. This can be overridden by base class using this mixin."""
        raise NotImplementedError

    def __repr__(self) -> str:
        """The `__repr__` for any subclass of Mixin.

        It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
        """
        return f"{self.__class__.__name__} {orjson.dumps({k: utils.bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}"

    def __str__(self) -> str:
        """The string representation of the given Mixin subclass.

        It will contains all of the attributes from __repr_keys__
        """
        return self.__repr_str__(" ")

    def __repr_name__(self) -> str:
        """Name of the instance's class, used in __repr__."""
        return self.__class__.__name__

    def __repr_str__(self, join_str: str) -> str:
        """To be used with __str__."""
        return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())

    def __repr_args__(self) -> ReprArgs:
        """This can also be overridden by base class using this mixin.

        By default it does a getattr of the current object from __repr_keys__.
        """
        return ((k, getattr(self, k)) for k in self.__repr_keys__)