Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-03-05 23:56:47 -05:00).
Commit: "refactor: packages (#249)". This commit contains the following changes:
@@ -4,244 +4,19 @@ User can import these function for convenience, but
|
||||
we won't ensure backward compatibility for these functions. So use with caution.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import contextlib, functools, hashlib, logging, logging.config, os, sys, types, typing as t, openllm
|
||||
from pathlib import Path
|
||||
from circus.exc import ConflictError
|
||||
from bentoml._internal.configuration import (
|
||||
DEBUG_ENV_VAR as DEBUG_ENV_VAR,
|
||||
GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR,
|
||||
QUIET_ENV_VAR as QUIET_ENV_VAR,
|
||||
get_debug_mode as _get_debug_mode,
|
||||
get_quiet_mode as _get_quiet_mode,
|
||||
set_quiet_mode as set_quiet_mode,
|
||||
)
|
||||
from bentoml._internal.models.model import ModelContext as _ModelContext
|
||||
from bentoml._internal.types import LazyType as LazyType
|
||||
from bentoml._internal.utils import (
|
||||
LazyLoader as LazyLoader,
|
||||
bentoml_cattr as bentoml_cattr,
|
||||
calc_dir_size as calc_dir_size,
|
||||
first_not_none as first_not_none,
|
||||
pkg as pkg,
|
||||
reserve_free_port as reserve_free_port,
|
||||
resolve_user_filepath as resolve_user_filepath,
|
||||
)
|
||||
from openllm.utils.lazy import (
|
||||
LazyModule as LazyModule,
|
||||
VersionInfo as VersionInfo,
|
||||
import typing as t, openllm_core
|
||||
from . import (
|
||||
dummy_flax_objects as dummy_flax_objects,
|
||||
dummy_pt_objects as dummy_pt_objects,
|
||||
dummy_tf_objects as dummy_tf_objects,
|
||||
dummy_vllm_objects as dummy_vllm_objects,
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._typing_compat import AnyCallable, LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try: from typing import GenericAlias as _TypingGenericAlias # type: ignore
|
||||
except ImportError: _TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
|
||||
else: _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore # _GenericAlias is the actual GenericAlias implementation
|
||||
|
||||
DEV_DEBUG_VAR = "OPENLLMDEVDEBUG"
|
||||
|
||||
def set_debug_mode(enabled: bool, level: int = 1) -> None:
  """Turn debug mode on or off via environment variables.

  Replaces ``bentoml._internal.configuration.set_debug_mode`` to avoid its extra log output.

  Args:
    enabled: whether debug mode should be active.
    level: verbosity written to ``OPENLLMDEVDEBUG`` when enabling.
  """
  updates = {DEBUG_ENV_VAR: str(enabled), _GRPC_DEBUG_ENV_VAR: "DEBUG" if enabled else "ERROR"}
  if enabled:
    updates[DEV_DEBUG_VAR] = str(level)
  os.environ.update(updates)
|
||||
|
||||
def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool:
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
if isinstance(cls, _WithArgsTypes): return False
|
||||
raise
|
||||
|
||||
def available_devices() -> tuple[str, ...]:
  """Return available GPU under system. Currently only supports NVIDIA GPUs."""
  # Local import; presumably deferred to avoid a circular import at module load — TODO confirm.
  from openllm._strategies import NvidiaGpuResource
  return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
@functools.lru_cache(maxsize=128)
def generate_hash_from_file(f: str, algorithm: t.Literal["md5", "sha1"] = "sha1") -> str:
  """Generate a hash from given file's modification time.

  Args:
    f: The file to generate the hash from.
    algorithm: The hashing algorithm to use. Defaults to 'sha1' (similar to how Git generate its commit hash.)

  Returns:
    The generated hash.
  """
  mtime = os.path.getmtime(resolve_filepath(f))
  return hashlib.new(algorithm, str(mtime).encode()).hexdigest()
|
||||
|
||||
@functools.lru_cache(maxsize=1)
# Cached number of NVIDIA GPUs visible to this process (see available_devices).
def device_count() -> int: return len(available_devices())
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
_object_setattr = object.__setattr__


def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None:
  """Set ``name`` on *obj* only when *obj* does not already expose that attribute.

  Classes are assigned via plain ``setattr``; instances bypass any custom
  ``__setattr__`` by going through ``object.__setattr__``.
  """
  if hasattr(obj, name):
    return
  if isinstance(obj, type):
    setter = functools.partial(setattr, obj)
  else:
    setter = _object_setattr.__get__(obj)
  setter(name, value)
|
||||
|
||||
def field_env_key(model_name: str, key: str, suffix: str | None = None) -> str:
  """Build the canonical ``OPENLLM_<MODEL>[_<SUFFIX>]_<KEY>`` environment variable name."""
  segments = ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key]
  return "_".join(segment.upper() for segment in segments if segment)
|
||||
|
||||
# Special debug flag controlled via OPENLLMDEVDEBUG
DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get(DEV_DEBUG_VAR)))
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
MYPY = False
# Dump generated code only at high dev-debug verbosity (OPENLLMDEVDEBUG > 3).
SHOW_CODEGEN: bool = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3


# Debug is on when either the OpenLLM dev flag or BentoML's debug mode is set.
def get_debug_mode() -> bool: return DEBUG or _get_debug_mode()


# Quiet mode never applies while the OpenLLM dev-debug flag is active.
def get_quiet_mode() -> bool: return not DEBUG and _get_quiet_mode()
|
||||
|
||||
class ExceptionFilter(logging.Filter):
  """Logging filter that drops records carrying an excluded exception type.

  circus' ``ConflictError`` is always part of the exclusion list.
  """

  def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
    """A filter of all exception."""
    if exclude_exceptions is None:
      exclude_exceptions = [ConflictError]
    # NOTE: the caller-supplied list is mutated on purpose so later reads see ConflictError too.
    if ConflictError not in exclude_exceptions:
      exclude_exceptions.append(ConflictError)
    super().__init__(**kwargs)
    self.EXCLUDE_EXCEPTIONS = exclude_exceptions

  def filter(self, record: logging.LogRecord) -> bool:
    info = record.exc_info
    if not info:
      return True
    etype = info[0]
    if etype is None:
      return True
    return not any(issubclass(etype, excluded) for excluded in self.EXCLUDE_EXCEPTIONS)
|
||||
|
||||
class InfoFilter(logging.Filter):
  """Pass only records at INFO level or above but strictly below WARNING."""

  def filter(self, record: logging.LogRecord) -> bool:
    return record.levelno >= logging.INFO and record.levelno < logging.WARNING
|
||||
|
||||
# dictConfig schema used by configure_logging() below: 'bentoml' and 'openllm' records between
# INFO and WARNING go to stdout (minus excluded exceptions); everything else uses the default
# StreamHandler (stderr) at WARNING+.
_LOGGING_CONFIG: dict[str, t.Any] = {
  "version": 1, "disable_existing_loggers": True,
  "filters": {"excfilter": {"()": "openllm.utils.ExceptionFilter"}, "infofilter": {"()": "openllm.utils.InfoFilter"}},
  "handlers": {"bentomlhandler": {"class": "logging.StreamHandler", "filters": ["excfilter", "infofilter"], "stream": "ext://sys.stdout"}, "defaulthandler": {"class": "logging.StreamHandler", "level": logging.WARNING}},
  "loggers": {"bentoml": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False}, "openllm": {"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False,}},
  "root": {"level": logging.WARNING},
}
|
||||
|
||||
def configure_logging() -> None:
  """Configure logging for OpenLLM.

  Behaves similar to how BentoML loggers are being configured: picks one level
  (ERROR when quiet, DEBUG when debugging, INFO otherwise), applies it to the
  'openllm'/'bentoml' loggers and the root logger, then installs the dict config.
  """
  if get_quiet_mode():
    level = logging.ERROR
  elif get_debug_mode() or DEBUG:
    level = logging.DEBUG
  else:
    level = logging.INFO
  for logger_name in ("openllm", "bentoml"):
    _LOGGING_CONFIG["loggers"][logger_name]["level"] = level
  _LOGGING_CONFIG["root"]["level"] = level
  logging.config.dictConfig(_LOGGING_CONFIG)
|
||||
|
||||
@functools.lru_cache(maxsize=1)
def in_notebook() -> bool:
  """Return True when running inside a Jupyter kernel.

  Detection: the IPython shell config exposes an ``IPKernelApp`` key under a kernel.
  Missing IPython or no active shell both yield False.
  """
  try:
    from IPython.core.getipython import get_ipython
    shell_config = t.cast("dict[str, t.Any]", get_ipython().config)
    return "IPKernelApp" in shell_config
  except (ImportError, AttributeError):
    return False
|
||||
|
||||
# Marker paths consulted by in_docker(): Docker creates /.dockerenv in containers, and a
# containerized process' cgroup file mentions "docker".
_dockerenv, _cgroup = Path("/.dockerenv"), Path("/proc/self/cgroup")


class suppress(contextlib.suppress, contextlib.ContextDecorator):
  """A version of contextlib.suppress with decorator support.

  ``ContextDecorator`` supplies ``__call__`` so instances work as function decorators.

  >>> @suppress(KeyError)
  ... def key_error():
  ...     {}['']
  >>> key_error()
  """
|
||||
|
||||
def compose(*funcs: AnyCallable) -> AnyCallable:
  """Compose any number of unary functions into a single unary function.

  >>> import textwrap
  >>> expected = str.strip(textwrap.dedent(compose.__doc__))
  >>> strip_and_dedent = compose(str.strip, textwrap.dedent)
  >>> strip_and_dedent(compose.__doc__) == expected
  True

  Compose also allows the innermost function to take arbitrary arguments.

  >>> round_three = lambda x: round(x, ndigits=3)
  >>> f = compose(round_three, int.__truediv__)
  >>> [f(3*x, x+1) for x in range(1,10)]
  [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
  """
  def _pair(outer: AnyCallable, inner: AnyCallable) -> AnyCallable:
    def _composed(*args: t.Any, **kwargs: t.Any) -> t.Any:
      return outer(inner(*args, **kwargs))
    return _composed
  return functools.reduce(_pair, funcs)
|
||||
|
||||
def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
  """Decorate a function with a transform function that is invoked on results returned from the decorated function.

  ```python
  @apply(reversed)
  def get_numbers(start):
      "doc for get_numbers"
      return range(start, start+3)
  list(get_numbers(4))
  # [6, 5, 4]
  ```
  ```python
  get_numbers.__doc__
  # 'doc for get_numbers'
  ```
  """
  def decorator(func: AnyCallable) -> AnyCallable:
    # wraps() keeps the original name/doc on the composed callable.
    return functools.wraps(func)(compose(transform, func))
  return decorator
|
||||
|
||||
# suppress(FileNotFoundError) makes a missing file return None; apply(bool) coerces that to False.
@apply(bool)
@suppress(FileNotFoundError)
def _text_in_file(text: str, filename: Path) -> bool:
  """Return True when *text* occurs in any line of *filename* (False for missing files)."""
  return any(text in line for line in filename.open())
|
||||
|
||||
def in_docker() -> bool:
  """Is this current environment running in docker?

  ```python
  type(in_docker())
  ```
  """
  # Either Docker's marker file exists or our cgroup entry mentions "docker".
  return _dockerenv.exists() or _text_in_file("docker", _cgroup)
|
||||
|
||||
T, K = t.TypeVar("T"), t.TypeVar("K")
|
||||
|
||||
def resolve_filepath(path: str, ctx: str | None = None) -> str:
  """Resolve a file path to an absolute path, expand user and environment variables.

  Paths that cannot be resolved are returned unchanged instead of raising.
  """
  try:
    return resolve_user_filepath(path, ctx)
  except FileNotFoundError:
    return path
|
||||
|
||||
# True when the *directory component* of the resolved path exists — the file itself need not.
def validate_is_path(maybe_path: str) -> bool: return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
|
||||
def generate_context(framework_name: str) -> _ModelContext:
  """Build a BentoML ModelContext recording the versions of the ML frameworks importable here.

  Always records transformers; adds torch/tensorflow/flax (+jax, jaxlib) when available.
  """
  framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
  if openllm.utils.is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
  if openllm.utils.is_tf_available():
    from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
    framework_versions["tensorflow"] = get_tf_version()
  if openllm.utils.is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")})
  return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
|
||||
from openllm_core._typing_compat import LiteralRuntime
|
||||
import openllm
|
||||
|
||||
# Bento/model labels derived from an LLM: runtime, model name/architecture and serialisation format.
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]: return {"runtime": llm.runtime, "framework": "openllm", "model_name": llm.config["model_name"], "architecture": llm.config["architecture"], "serialisation_format": llm._serialisation_format}
|
||||
|
||||
_TOKENIZER_PREFIX = "_tokenizer_"


def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
  """Split keyword arguments into a (model_kwargs, tokenizer_kwargs) pair.

  Keys prefixed with ``_tokenizer_`` go to the tokenizer dict (prefix stripped);
  every other key stays in the model dict.
  """
  model_attrs: dict[str, t.Any] = {}
  tokenizer_attrs: dict[str, t.Any] = {}
  for key, value in attrs.items():
    if key.startswith(_TOKENIZER_PREFIX):
      tokenizer_attrs[key[len(_TOKENIZER_PREFIX):]] = value
    else:
      model_attrs[key] = value
  return model_attrs, tokenizer_attrs
|
||||
|
||||
def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | openllm.AutoTFLLM | openllm.AutoFlaxLLM | openllm.AutoVLLM]:
|
||||
import openllm
|
||||
if implementation == "tf": return openllm.AutoTFLLM
|
||||
@@ -250,62 +25,8 @@ def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | o
|
||||
elif implementation == "vllm": return openllm.AutoVLLM
|
||||
else: raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')")
|
||||
|
||||
# NOTE: The set marks contains a set of modules name
|
||||
# that are available above and are whitelisted
|
||||
# to be included in the extra_objects map.
|
||||
_whitelist_modules = {"pkg"}
|
||||
|
||||
# XXX: define all classes, functions import above this line
|
||||
# since _extras will be the locals() import from this file.
|
||||
_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_"))}
|
||||
_extras["__openllm_migration__"] = {"ModelEnv": "EnvVarMixin"}
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"analytics": [], "codegen": [], "dantic": [], "dummy_flax_objects": [], "dummy_pt_objects": [], "dummy_tf_objects": [], "dummy_vllm_objects": [], "representation": ["ReprMixin"], "lazy": ["LazyModule"],
|
||||
"import_utils": ["OPTIONAL_DEPENDENCIES", "ENV_VARS_TRUE_VALUES", "DummyMetaclass", "EnvVarMixin", "require_backends",
|
||||
"is_cpm_kernels_available", "is_einops_available", "is_flax_available", "is_tf_available", "is_vllm_available", "is_torch_available", "is_bitsandbytes_available", "is_peft_available", "is_datasets_available",
|
||||
"is_transformers_supports_kbit", "is_transformers_supports_agent", "is_jupyter_available", "is_jupytext_available", "is_notebook_available", "is_triton_available", "is_autogptq_available", "is_sentencepiece_available",
|
||||
"is_xformers_available", "is_fairscale_available"]}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# NOTE: The following exports useful utils from bentoml
|
||||
from . import (
|
||||
analytics as analytics,
|
||||
codegen as codegen,
|
||||
dantic as dantic,
|
||||
dummy_flax_objects as dummy_flax_objects,
|
||||
dummy_pt_objects as dummy_pt_objects,
|
||||
dummy_tf_objects as dummy_tf_objects,
|
||||
dummy_vllm_objects as dummy_vllm_objects,
|
||||
)
|
||||
from .import_utils import (
|
||||
ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES,
|
||||
OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES,
|
||||
DummyMetaclass as DummyMetaclass,
|
||||
EnvVarMixin as EnvVarMixin,
|
||||
is_autogptq_available as is_autogptq_available,
|
||||
is_bitsandbytes_available as is_bitsandbytes_available,
|
||||
is_cpm_kernels_available as is_cpm_kernels_available,
|
||||
is_datasets_available as is_datasets_available,
|
||||
is_einops_available as is_einops_available,
|
||||
is_fairscale_available as is_fairscale_available,
|
||||
is_flax_available as is_flax_available,
|
||||
is_jupyter_available as is_jupyter_available,
|
||||
is_jupytext_available as is_jupytext_available,
|
||||
is_notebook_available as is_notebook_available,
|
||||
is_peft_available as is_peft_available,
|
||||
is_sentencepiece_available as is_sentencepiece_available,
|
||||
is_tf_available as is_tf_available,
|
||||
is_torch_available as is_torch_available,
|
||||
is_transformers_supports_agent as is_transformers_supports_agent,
|
||||
is_transformers_supports_kbit as is_transformers_supports_kbit,
|
||||
is_triton_available as is_triton_available,
|
||||
is_vllm_available as is_vllm_available,
|
||||
is_xformers_available as is_xformers_available,
|
||||
require_backends as require_backends,
|
||||
)
|
||||
from .representation import ReprMixin as ReprMixin
|
||||
|
||||
__lazy = LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects=_extras)
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
__all__ = ["generate_labels", "infer_auto_class", "dummy_flax_objects", "dummy_pt_objects", "dummy_tf_objects", "dummy_vllm_objects"]
|
||||
def __dir__() -> t.Sequence[str]: return sorted(__all__)
|
||||
def __getattr__(it: str) -> t.Any:
|
||||
if hasattr(openllm_core.utils, it): return getattr(openllm_core.utils, it)
|
||||
else: raise AttributeError(f"module {__name__} has no attribute {it}")
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Telemetry related for OpenLLM tracking.
|
||||
|
||||
Users can disable this with OPENLLM_DO_NOT_TRACK envvar.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import contextlib, functools, logging, os, re, typing as t, importlib.metadata
|
||||
import attr, openllm
|
||||
from bentoml._internal.utils import analytics as _internal_analytics
|
||||
from openllm._typing_compat import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = t.TypeVar("T")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# This variable is a proxy that will control BENTOML_DO_NOT_TRACK
|
||||
OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK"
|
||||
DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper()
|
||||
|
||||
@functools.lru_cache(maxsize=1)
# Cached opt-out flag; DO_NOT_TRACK snapshots OPENLLM_DO_NOT_TRACK at import time.
def do_not_track() -> bool: return DO_NOT_TRACK in openllm.utils.ENV_VARS_TRUE_VALUES
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _usage_event_debugging() -> bool: return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true"
|
||||
|
||||
def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
  """Decorator that swallows any exception raised by *func*, logging it instead.

  On failure the wrapper returns None; tracking must never crash the caller.
  """
  @functools.wraps(func)
  def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
    try:
      return func(*args, **kwargs)
    except Exception as err:
      if not _usage_event_debugging():
        logger.debug("Tracking Error: %s", err)
      elif openllm.utils.get_debug_mode():
        logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3)
      else:
        logger.info("Tracking Error: %s", err)
  return wrapper
|
||||
|
||||
@silent
def track(event_properties: attr.AttrsInstance) -> None:
  """Forward an analytics event to BentoML's tracker unless the user opted out."""
  if do_not_track(): return
  _internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties))
|
||||
|
||||
@contextlib.contextmanager
def set_bentoml_tracking() -> t.Generator[None, None, None]:
  """Temporarily force BentoML's do-not-track flag to mirror OpenLLM's opt-out setting.

  On exit the previous state of ``BENTOML_DO_NOT_TRACK`` is restored exactly: if the
  variable was unset before entering, it is removed again. (The previous implementation
  defaulted the popped value to ``str(False)`` and so re-created the variable as
  ``"False"`` on exit, leaking state into the process environment.)
  """
  original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, None)
  try:
    os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track())
    yield
  finally:
    if original_value is None:
      os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, None)
    else:
      os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value
|
||||
|
||||
class EventMeta:
  """Base class that derives an analytics event name from the subclass' class name."""

  @property
  def event_name(self) -> str:
    """CamelCase class name converted to snake_case, with a trailing '_event' stripped."""
    snake = re.sub(r"(?<!^)(?=[A-Z])", "_", type(self).__name__).lower()
    suffix = "_event"
    if snake.endswith(suffix):
      snake = snake[:-len(suffix)]
    return snake
|
||||
|
||||
@attr.define
class ModelSaveEvent(EventMeta):
  """Analytics event for a model save (event name: 'model_save')."""
  module: str
  model_size_in_kb: float
|
||||
@attr.define
class OpenllmCliEvent(EventMeta):
  """Analytics event for a CLI invocation (event name: 'openllm_cli')."""
  cmd_group: str
  cmd_name: str
  # Resolved once at class-definition time from the installed distribution.
  openllm_version: str = importlib.metadata.version("openllm")
  # NOTE: reserved for the do_not_track logics
  duration_in_ms: t.Any = attr.field(default=None)
  error_type: str = attr.field(default=None)
  return_code: int = attr.field(default=None)
|
||||
@attr.define
class StartInitEvent(EventMeta):
  """Analytics event for server start initialization (event name: 'start_init')."""
  model_name: str
  llm_config: t.Dict[str, t.Any] = attr.field(default=None)

  @staticmethod
  def handler(llm_config: openllm.LLMConfig) -> StartInitEvent: return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump())
|
||||
|
||||
def track_start_init(llm_config: openllm.LLMConfig) -> None:
  """Emit a StartInitEvent for *llm_config* unless tracking is disabled."""
  if do_not_track(): return
  track(StartInitEvent.handler(llm_config))
|
||||
@@ -1,141 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import functools, inspect, linecache, os, logging, string, types, typing as t
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
import orjson
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
|
||||
import openllm
|
||||
from openllm._typing_compat import LiteralString, AnyCallable, DictStrAny, ListStr
|
||||
PartialAny = functools.partial[t.Any]
|
||||
|
||||
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
|
||||
logger = logging.getLogger(__name__)
|
||||
OPENLLM_MODEL_NAME = "# openllm: model name"
|
||||
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
|
||||
class ModelNameFormatter(string.Formatter):
  """Formatter that substitutes the ``__model_name__`` placeholder inside 'service.py' templates."""

  model_keyword: LiteralString = "__model_name__"

  def __init__(self, model_name: str):
    """The formatter that extends model_name to be formatted the 'service.py'."""
    super().__init__()
    self.model_name = model_name

  def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.Any:
    # Caller-supplied args/attrs are intentionally ignored: only the model keyword is substituted.
    return super().vformat(format_string, (), {self.model_keyword: self.model_name})

  def can_format(self, value: str) -> bool:
    """Return True when *value* is a well-formed template string.

    ``string.Formatter.parse`` is lazy — it returns an iterator and only raises
    ``ValueError`` for malformed templates (e.g. an unmatched '{') while being
    consumed. The previous implementation never consumed it, so this method
    always returned True; drain the iterator so errors actually surface.
    """
    try:
      for _ in self.parse(value):
        pass
      return True
    except ValueError:
      return False
|
||||
class ModelIdFormatter(ModelNameFormatter):
  """Substitutes the ``__model_id__`` placeholder instead of the model name."""
  model_keyword: LiteralString = "__model_id__"
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
  """Substitutes the ``__model_adapter_map__`` placeholder instead of the model name."""
  model_keyword: LiteralString = "__model_adapter_map__"
|
||||
|
||||
_service_file = Path(os.path.abspath(__file__)).parent.parent/"_service.py"
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
  """Render the generated ``service.py`` for *llm* into *llm_fs*.

  Substitutes the model-name and adapter-map marker lines from the ``_service.py``
  template (stripping the trailing marker comment) and writes the result under the
  LLM config's ``service_name``.
  """
  from openllm.utils import DEBUG
  model_name = llm.config["model_name"]
  logger.debug("Generating service file for %s at %s (dir=%s)", model_name, llm.config["service_name"], llm_fs.getsyspath("/"))
  with open(_service_file.__fspath__(), "r") as f:
    src_contents = f.readlines()
  # Replace by position: list.index(line) always returns the FIRST occurrence, which
  # rewrote the wrong entry when the template contained duplicate marker lines.
  for lineno, line in enumerate(src_contents):
    if OPENLLM_MODEL_NAME in line:
      # [:-(len(marker) + 3)] drops the " # <marker>" tail appended to the template line.
      src_contents[lineno] = ModelNameFormatter(model_name).vformat(line)[:-(len(OPENLLM_MODEL_NAME) + 3)] + "\n"
    elif OPENLLM_MODEL_ADAPTER_MAP in line:
      src_contents[lineno] = ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(line)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + "\n"
  script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
  if DEBUG: logger.info("Generated script:\n%s", script)
  llm_fs.writetext(llm.config["service_name"], script)
|
||||
|
||||
# sentinel object for unequivocal object() getattr
_sentinel = object()


def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
  """Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
  value = getattr(cls, attrib_name, _sentinel)
  if value is _sentinel:
    return False
  # If any base resolves to the very same object, the attribute is inherited, not own.
  return all(getattr(base, attrib_name, None) is not value for base in cls.__mro__[1:])
|
||||
def get_annotations(cls: type[t.Any]) -> DictStrAny:
  """Return the annotations *cls* itself defines (not inherited), or an empty dict."""
  if has_own_attribute(cls, "__annotations__"): return cls.__annotations__
  return t.cast("DictStrAny", {})
|
||||
|
||||
def is_class_var(annot: str | t.Any) -> bool:
  """Return True when the (possibly quoted) annotation denotes a ``ClassVar``."""
  text = str(annot)
  # Annotations may arrive as quoted forward references; unwrap one layer of quotes.
  if text[:1] in {"'", '"'} and text[-1:] in {"'", '"'}:
    text = text[1:-1]
  known_spellings = ("typing.ClassVar", "t.ClassVar", "ClassVar", "typing_extensions.ClassVar")
  return text.startswith(known_spellings)
|
||||
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
  """Stamp ``__module__``, ``__qualname__`` and ``__doc__`` from *cls* onto *method_or_cls*.

  Each assignment is best-effort: objects rejecting attribute writes are left untouched.
  Returns *method_or_cls* for chaining.
  """
  try:
    method_or_cls.__module__ = cls.__module__
  except AttributeError:
    pass
  try:
    method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
  except AttributeError:
    pass
  try:
    method_or_cls.__doc__ = _overwrite_doc or f"Generated by ``openllm.LLMConfig`` for class {cls.__qualname__}."
  except AttributeError:
    pass
  return method_or_cls
|
||||
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None: eval(compile(script, filename, "exec"), globs, locs) # noqa: S307
|
||||
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
  """Exec *script* (a single def named *name*) and return the resulting function.

  A synthetic linecache entry is registered under *filename* so debuggers can step
  through the generated source.
  """
  locs: DictStrAny = {}
  # In order of debuggers like PDB being able to step through the code, we add a fake linecache entry.
  count = 1
  base_filename = filename
  while True:
    linecache_tuple = (len(script), None, script.splitlines(True), filename)
    old_val = linecache.cache.setdefault(filename, linecache_tuple)
    if old_val == linecache_tuple: break
    # Filename already claimed by different source: derive "<base-1>", "<base-2>", ... until free.
    else:
      filename = f"{base_filename[:-1]}-{count}>"
      count += 1
  _compile_and_eval(script, globs, locs, filename)
  return locs[name]
|
||||
|
||||
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]:
  """Create a tuple subclass to hold class attributes.

  The subclass is a bare tuple with properties for names.

      class MyClassAttributes(tuple):
          __slots__ = ()
          x = property(itemgetter(0))
  """
  from . import SHOW_CODEGEN

  attr_class_name = f"{cls_name}Attributes"
  # Source of the generated class: the i-th named property reads tuple item i.
  attr_class_template = [f"class {attr_class_name}(tuple):", " __slots__ = ()",]
  if attr_names:
    for i, attr_name in enumerate(attr_names): attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
  else: attr_class_template.append(" pass")
  # The generated code resolves these two helper names from its globals.
  globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
  if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template))
  _compile_and_eval("\n".join(attr_class_template), globs)
  return globs[attr_class_name]
|
||||
|
||||
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str:
  """Synthesize a unique pseudo-filename for code generated on behalf of *cls*."""
  qualname = getattr(cls, "__qualname__", cls.__name__)
  return f"<{cls.__name__} generated {func_name} {cls.__module__}.{qualname}>"
|
||||
def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None) -> AnyCallable:
  """Compile a function *func_name* from body *lines* and *args*, evaluated against *globs*.

  An empty/None body becomes ``pass``; optional *annotations* are attached to the result.
  """
  from openllm.utils import SHOW_CODEGEN
  script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass")
  meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
  if annotations: meth.__annotations__ = annotations
  if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
  return meth
|
||||
|
||||
def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable:
  """Generate an attrs field-transformer that wires each field's default/metadata to OPENLLM_* env vars.

  The generated ``__auto_env`` evolves every field so its default comes from
  ``field_env_key(model_name, field, suffix)`` (decoded via ``dantic.env_converter``)
  and its metadata records the env-var name plus a description.
  """
  from openllm.utils import dantic, field_env_key
  # Identity callback keeps a field's declared default when no custom callback is given.
  def identity(_: str, x_value: t.Any) -> t.Any: return x_value
  default_callback = identity if default_callback is None else default_callback
  globs = {} if globs is None else globs
  # These dunder names are the globals the generated function body below refers to.
  globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,})
  lines: ListStr = ["__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", "return [", " f.evolve(", " default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", " metadata={", " 'env': f.metadata.get('env', __env(f.name)),", " 'description': f.metadata.get('description', '(not provided)'),", " },", " )", " for f in fields", "]"]
  fields_ann = "list[attr.Attribute[t.Any]]"
  return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann})
|
||||
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
  """Enhance sdk with nice repr that plays well with your brain.

  Wraps *func* in a dynamically created class deriving from ``functools.partial`` and
  ``ReprMixin`` so the returned callable pre-binds **attrs and reprs as
  ``<generated function <name> {...}>`` with its signature's annotations.
  """
  from openllm.utils import ReprMixin
  if name is None: name = func.__name__.strip("_")
  _signatures = inspect.signature(func).parameters
  def _repr(self: ReprMixin) -> str: return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
  def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
  if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
  else: doc = func.__doc__
  # types.new_class builds the partial+ReprMixin subclass; update_wrapper copies func's metadata
  # onto the resulting instance. __repr_keys__ omits underscore-prefixed parameters.
  return t.cast(_T, functools.update_wrapper(types.new_class(name, (t.cast("PartialAny", functools.partial), ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,))
|
||||
|
||||
__all__ = ["gen_sdk", "make_attr_tuple_class", "make_env_transformer", "generate_unique_filename", "generate_function", "OPENLLM_MODEL_NAME", "OPENLLM_MODEL_ADAPTER_MAP"]
|
||||
@@ -1,387 +0,0 @@
|
||||
"""An interface provides the best of pydantic and attrs."""
|
||||
from __future__ import annotations
|
||||
import functools, importlib, os, sys, typing as t
|
||||
from enum import Enum
|
||||
import attr, click, click_option_group as cog, inflection, orjson
|
||||
from click import (
|
||||
ParamType,
|
||||
shell_completion as sc,
|
||||
types as click_types,
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING: from attr import _ValidatorType
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command])
|
||||
|
||||
__all__ = ["FC", "attrs_to_options", "Field", "parse_type", "is_typing", "is_literal", "ModuleType", "EnumChoice", "LiteralChoice", "allows_multiple", "is_mapping", "is_container", "parse_container_args", "parse_single_arg", "CUDA", "JsonType", "BytesType"]
|
||||
def __dir__() -> list[str]: return sorted(__all__)
|
||||
|
||||
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]:
  """Convert one attrs field into a click-option-group option for the model CLI.

  The option is named ``--<dasherized-field>`` (with a ``/--no-`` variant for booleans),
  reads its default from the field and its env var from ``field.metadata['env']``.
  """
  # TODO: support parsing nested attrs class and Union
  envvar = field.metadata["env"]
  dasherized = inflection.dasherize(name)
  underscored = inflection.underscore(name)

  if typ in (None, attr.NOTHING):
    typ = field.type
    if typ is None: raise RuntimeError(f"Failed to parse type for {name}")

  full_option_name = f"--{dasherized}"
  # NOTE(review): identity compare against the type object — with string annotations attrs may
  # store field.type as the string "bool", which would skip this branch; confirm resolution here.
  if field.type is bool: full_option_name += f"/--no-{dasherized}"
  if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
  elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
  else: identifier = f"{model_name}_{underscored}"

  return cog.optgroup.option(identifier, full_option_name, type=parse_type(typ), required=field.default is attr.NOTHING, default=field.default if field.default not in (attr.NOTHING, None) else None, show_default=True, multiple=allows_multiple(typ) if typ else False, help=field.metadata.get("description", "(No description provided)"), show_envvar=True, envvar=envvar,)
|
||||
|
||||
def env_converter(value: t.Any, env: str | None = None) -> t.Any:
    # Resolve *value* from the environment variable *env* (when given), then
    # JSON-decode string values so "true"/"1"/'["a","b"]' become Python objects.
    if env is not None:
        value = os.environ.get(env, value)
    if value is not None and isinstance(value, str):
        # NOTE(review): .lower() is presumably here so "True"/"False" parse as
        # JSON booleans, but it also lowercases the entire payload (including
        # keys/values inside a JSON object) -- confirm this is intended.
        try: return orjson.loads(value.lower())
        except orjson.JSONDecodeError as err: raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
    return value
|
||||
|
||||
def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[t.Any] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any) -> t.Any:
    """A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.

    By default, if both validator and ge are provided, then ge will be
    piped into first, then all of the other validators will run afterwards.

    Args:
        default: The default value for ``dantic.Field``. Defaults to ``None``.
        ge: Greater than or equal to. Defaults to None.
        le: Less than or equal to. Defaults to None.
        validator: Optional attrs-compatible validator. Defaults to None.
        description: the documentation for the field. Defaults to None.
        env: the environment variable to read from. Defaults to None.
        auto_default: a bool indicating whether to use the default value as the environment.
            Defaults to False. If set to True, the behaviour of this Field also depends
            on kw_only. If kw_only=True, this field becomes 'Required' and the default
            value is omitted. If kw_only=False, the default value is used as before.
        use_default_converter: a bool indicating whether to use the default converter.
            Defaults to True. If set to False, the default converter is not used.
            The default converter reads this Field's value from its environment variable.
        **attrs: The rest of the arguments are passed to attr.field.

    Raises:
        RuntimeError: if both ``factory`` and ``default`` are provided.
    """
    metadata = attrs.pop("metadata", {})
    if description is None: description = "(No description provided)"
    metadata["description"] = description
    if env is not None: metadata["env"] = env

    piped: list[_ValidatorType[t.Any]] = []
    converter = attrs.pop("converter", None)
    # NOTE(review): when use_default_converter=True (the default) any
    # caller-supplied ``converter`` is discarded -- confirm this is intended.
    if use_default_converter: converter = functools.partial(env_converter, env=env)

    # ge/le are piped in first, then any user-provided validator.
    if ge is not None: piped.append(attr.validators.ge(ge))
    if le is not None: piped.append(attr.validators.le(le))
    if validator is not None: piped.append(validator)

    if len(piped) == 0: _validator = None
    elif len(piped) == 1: _validator = piped[0]
    else: _validator = attr.validators.and_(*piped)

    factory = attrs.pop("factory", None)
    if factory is not None and default is not None: raise RuntimeError("'factory' and 'default' are mutually exclusive.")
    # NOTE: the behaviour of this is we will respect factory over the default
    if factory is not None: attrs["factory"] = factory
    else: attrs["default"] = default

    kw_only = attrs.pop("kw_only", False)
    if auto_default and kw_only:
        # Bugfix: pop with a fallback -- when ``factory`` was supplied no
        # "default" key exists here and the bare pop() raised KeyError.
        attrs.pop("default", None)
    # Bugfix: forward kw_only to attr.field; it was previously popped and
    # silently dropped, so Field(kw_only=True) had no effect on the attribute.
    return attr.field(metadata=metadata, validator=_validator, converter=converter, kw_only=kw_only, **attrs)
|
||||
|
||||
def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]:
    """Map an annotated field type onto the closest click-compatible type.

    Args:
        field_type: the field's annotated type.

    Returns:
        A click ``ParamType`` (or tuple of them for fixed-shape containers),
        or the primitive type itself when click understands it natively.

    Raises:
        NotImplementedError: for ``Union`` annotations.
    """
    from . import lenient_issubclass

    if t.get_origin(field_type) is t.Union:
        raise NotImplementedError("Unions are not supported")
    # Enum derivatives become a case-sensitive choice over their member names.
    if lenient_issubclass(field_type, Enum):
        return EnumChoice(enum=field_type, case_sensitive=True)
    # Literals behave like a stripped-down enum.
    if is_literal(field_type):
        return LiteralChoice(value=field_type, case_sensitive=True)
    # Modules, classes and functions are imported from a dotted path.
    if is_typing(field_type):
        return ModuleType()
    # Dict-like fields are supplied as JSON strings and parsed up front.
    if is_mapping(field_type):
        return JsonType()
    # list, List[p], Tuple[p], Set[p] and friends.
    if is_container(field_type):
        return parse_container_args(field_type)
    # click has no native bytes support.
    if lenient_issubclass(field_type, bytes):
        return BytesType()
    # Anything left should be a primitive click already knows how to convert.
    return field_type
|
||||
|
||||
def is_typing(field_type: type) -> bool:
    """Whether *field_type* is a module-like annotation (``type`` / ``t.Type[...]``).

    Args:
        field_type: the field's annotated type.

    Returns:
        bool: True when the annotation's origin is ``type`` itself.
    """
    origin = t.get_origin(field_type)
    # Non-generic annotations have no origin and are never module-like.
    return origin is not None and (origin is type or origin is t.Type)
|
||||
|
||||
def is_literal(field_type: type) -> bool:
    """Whether the given field type is a ``Literal`` annotation.

    Literals are weird: isinstance and issubclass do not work on them, so the
    annotation's origin is compared against ``t.Literal`` itself.

    Args:
        field_type: the field's annotated type.

    Returns:
        bool: True for ``Literal[...]`` annotations, False otherwise.
    """
    # get_origin returns None for plain types, which can never be t.Literal.
    return t.get_origin(field_type) is t.Literal
|
||||
|
||||
class ModuleType(ParamType):
    # click type that resolves a dotted "pkg.mod.Attr" string into the object
    # it names by importing the module and fetching the attribute.
    name = "module"

    def _import_object(self, value: str) -> t.Any:
        # Split "pkg.mod.Attr" into the module path and the trailing attribute.
        module_name, class_name = value.rsplit(".", maxsplit=1)
        if not all(s.isidentifier() for s in module_name.split(".")): raise ValueError(f"'{value}' is not a valid module name")
        if not class_name.isidentifier(): raise ValueError(f"Variable '{class_name}' is not a valid identifier")

        # NOTE(review): ``importlib`` does not appear in this module's visible
        # import header -- confirm it is in scope (e.g. via a lazy module).
        module = importlib.import_module(module_name)
        # class_name is always truthy here: an empty segment fails the
        # isidentifier() check above.
        if class_name:
            try: return getattr(module, class_name)
            except AttributeError: raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None

    def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
        # Strings go through the importer; already-resolved objects pass through.
        try:
            if isinstance(value, str): return self._import_object(value)
            return value
        except Exception as exc: self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
|
||||
|
||||
class EnumChoice(click.Choice):
    name = "enum"

    def __init__(self, enum: Enum, case_sensitive: bool = False):
        """Enum type support for click that extends ``click.Choice``.

        Args:
            enum: an ``Enum`` subclass, or one of its members.
            case_sensitive: Whether this choice should be case sensitive.
        """
        self.mapping = enum
        # Bugfix: ``parse_type`` passes the Enum *class* itself, for which the
        # previous ``type(enum)`` / ``enum.__class__`` resolved to EnumMeta and
        # iterating that raised TypeError. Accept both a class and a member.
        self.internal_type = enum if isinstance(enum, type) else type(enum)
        choices: list[t.Any] = [e.name for e in self.internal_type]
        super().__init__(choices, case_sensitive)

    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
        # Already a member of the enum: nothing to convert.
        if isinstance(value, self.internal_type):
            return value
        # Validate against the member names, then look the member up by name.
        result = super().convert(value, param, ctx)
        if isinstance(result, str):
            result = self.internal_type[result]
        return result
|
||||
|
||||
class LiteralChoice(EnumChoice):
    name = "literal"

    def __init__(self, value: t.Any, case_sensitive: bool = False):
        """Literal support for click.

        Args:
            value: a ``typing.Literal[...]`` alias; all items must share one type.
            case_sensitive: Whether this choice should be case sensitive.

        Raises:
            ValueError: if the literal mixes items of different types.
        """
        # expect every literal value to belong to the same primitive type
        values = list(value.__args__)
        item_type = type(values[0])
        if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.")
        # Bugfix: keep the str->value mapping on the instance. It was previously
        # a discarded local, so non-string literals (e.g. Literal[1, 2]) could
        # not be mapped back to their original values on conversion.
        self._mapping = {str(v): v for v in values}
        # Deliberately skip EnumChoice.__init__ and go straight to click.Choice.
        super(EnumChoice, self).__init__(list(self._mapping), case_sensitive)
        self.internal_type = item_type

    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
        # Already the literal's member type: pass through unchanged.
        if isinstance(value, self.internal_type):
            return value
        # Validate the raw string against the declared choices, then map it
        # back to the original literal value (no-op for string literals).
        result = super(EnumChoice, self).convert(value, param, ctx)
        return self._mapping.get(result, result)
|
||||
|
||||
def allows_multiple(field_type: type[t.Any]) -> bool:
    """Whether a field should become a repeatable click option.

    Containers map onto click's ``multiple=True`` support, so the same option
    can be passed several times to build a list:
    ``python run.py --subsets train --subsets test`` becomes
    ``subsets: ["train", "test"]``.

    Args:
        field_type: the field's annotated type.

    Returns:
        bool: True for simple (non-composite) container fields, False otherwise.
    """
    # Mappings are better handled as JSON strings, never as repeated options.
    if is_mapping(field_type):
        return False
    if not is_container(field_type):
        return False
    # A homogeneous container such as List[int] parses to a single ParamType,
    # whereas a fixed-shape one such as Tuple[str, int, int] parses to a tuple
    # of them. For the moment, only the non-composite form may be repeated.
    return not isinstance(parse_container_args(field_type), tuple)
|
||||
|
||||
def is_mapping(field_type: type) -> bool:
    """Whether this field represents a dictionary or JSON object.

    Args:
        field_type: the field's annotated type.

    Returns:
        bool: True when the field is dict-like, False otherwise.
    """
    from . import lenient_issubclass

    # Concrete mapping classes (dict, Mapping subclasses) short-circuit here.
    if lenient_issubclass(field_type, t.Mapping):
        return True
    # Parameterized generics such as Dict[str, int] resolve via their origin;
    # plain non-generic objects have none.
    origin = t.get_origin(field_type)
    return origin is not None and lenient_issubclass(origin, t.Mapping)
|
||||
|
||||
def is_container(field_type: type) -> bool:
    """Whether the type 'contains' other types, like lists and tuples.

    Args:
        field_type: the field's annotated type.

    Returns:
        bool: True if a container, False otherwise.
    """
    from . import lenient_issubclass

    # str and bytes are technically containers but are treated as scalars.
    if field_type in (str, bytes):
        return False
    # Concrete container classes: list, tuple, range, ...
    if lenient_issubclass(field_type, t.Container):
        return True
    # Parameterized generics resolve via their origin; plain objects have none.
    origin = t.get_origin(field_type)
    return origin is not None and lenient_issubclass(origin, t.Container)
|
||||
|
||||
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType, ...]:
    """Parse the type arguments inside a container annotation.

    Args:
        field_type: a container annotation such as ``List[int]`` or ``Tuple[str, int]``.

    Returns:
        A single click-compatible type for homogeneous containers, or a tuple
        of them for fixed-length ones.

    Raises:
        ValueError: if ``field_type`` is not a container type.
    """
    if not is_container(field_type):
        raise ValueError("Field type is not a container type.")
    args = t.get_args(field_type)
    # Untyped containers (plain list, plain tuple): fall back to str so click
    # does not try to guess the element type.
    if not args:
        return click_types.convert_type(str)
    # Homogeneous containers -- Tuple[int], List[str] -- and homogeneous
    # tuples of indefinite length -- Tuple[int, ...].
    if len(args) == 1 or (len(args) == 2 and args[1] is Ellipsis):
        return parse_single_arg(args[0])
    # Fixed-length containers such as Tuple[str, int, int].
    return tuple(parse_single_arg(member) for member in args)
|
||||
|
||||
def parse_single_arg(arg: type) -> ParamType:
    """Return the click-compatible type for a single container argument.

    Unknown types fall back to string, nested containers and mappings to JSON,
    and bytes to a dedicated type (click has no native bytes support); every
    other argument (ints, floats and so on) converts directly.

    Args:
        arg (type): single type argument.

    Returns:
        ParamType: click-compatible type.
    """
    from . import lenient_issubclass

    # When the type cannot be inferred, choose 'str'.
    if arg is t.Any:
        return click_types.convert_type(str)
    # Containers and nested models travel as JSON.
    if is_container(arg):
        return JsonType()
    # bytes is a special case not natively handled by click.
    if lenient_issubclass(arg, bytes):
        return BytesType()
    return click_types.convert_type(arg)
|
||||
|
||||
class BytesType(ParamType):
    """click type accepting raw bytes, or encoding a string argument to bytes."""

    name = "bytes"

    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
        # Already bytes: pass through untouched.
        if isinstance(value, bytes):
            return value
        try:
            return str.encode(value)
        except Exception as exc:
            self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
|
||||
|
||||
# Platform flags used to pick the correct argv decoding strategy.
CYGWIN = sys.platform.startswith("cygwin")
WIN = sys.platform.startswith("win")
# Bugfix: the guard was ``sys.platform.startswith("win") and WIN`` -- both
# operands are the identical expression, so the duplicate test is dropped.
if WIN:
    def _get_argv_encoding() -> str:
        """Return the preferred locale encoding used for argv on Windows."""
        import locale
        return locale.getpreferredencoding()
else:
    def _get_argv_encoding() -> str:
        """Return stdin's encoding, falling back to the filesystem encoding."""
        return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
|
||||
|
||||
class CudaValueType(ParamType):
    # click type parsing CUDA device selections ("0,1,2") from a flag or a
    # comma-separated environment variable into a tuple of device-id strings.
    name = "cuda"
    envvar_list_splitter = ","
    is_composite = True
    # Each element is converted through click's plain string type.
    typ = click_types.convert_type(str)

    def split_envvar_value(self, rv: str) -> t.Sequence[str]:
        # Split the env var on commas; "-1" acts as a terminator, so everything
        # from the first "-1" onwards is discarded.
        var = tuple(i for i in rv.split(self.envvar_list_splitter))
        if "-1" in var:
            return var[:var.index("-1")]
        return var

    def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
        """Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.

        Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.

        Args:
            ctx: Invocation context for this command.
            param: The parameter that is requesting completion.
            incomplete: Value being completed. May be empty.
        """
        from openllm.utils import available_devices
        # With no partial input, offer every visible device index.
        mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
        return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]

    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
        # Mirror click's argv handling for raw bytes: try the argv encoding,
        # then the filesystem encoding, finally lossy UTF-8.
        if isinstance(value, bytes):
            enc = _get_argv_encoding()
            try: value = value.decode(enc)
            except UnicodeError:
                fs_enc = sys.getfilesystemencoding()
                if fs_enc != enc:
                    try: value = value.decode(fs_enc)
                    except UnicodeError: value = value.decode("utf-8", "replace")
                else: value = value.decode("utf-8", "replace")
        # Split on commas and run each device id through the string type.
        return tuple(self.typ(x, param, ctx) for x in value.split(","))

    def __repr__(self) -> str: return "STRING"

# Module-level singleton used as the CUDA option type.
CUDA = CudaValueType()
|
||||
|
||||
class JsonType(ParamType):
    name = "json"

    def __init__(self, should_load: bool = True) -> None:
        """Support JSON type for click.ParamType.

        Args:
            should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
        """
        super().__init__()
        self.should_load = should_load

    def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
        # Already-parsed dicts, or instances opted out of loading, pass through.
        if isinstance(value, dict) or not self.should_load:
            return value
        try:
            return orjson.loads(value)
        except orjson.JSONDecodeError as exc:
            self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)
|
||||
@@ -2,7 +2,7 @@
|
||||
# To update this, run ./tools/update-dummy.py
|
||||
from __future__ import annotations
|
||||
import typing as _t
|
||||
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
# Auto-generated placeholder: any public attribute access triggers
# require_backends, which raises ImportError when "flax" is unavailable.
class FlaxFlanT5(metaclass=_DummyMetaclass):
  _backends=["flax"]
  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["flax"])
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# To update this, run ./tools/update-dummy.py
|
||||
from __future__ import annotations
|
||||
import typing as _t
|
||||
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
# Auto-generated placeholder: any public attribute access triggers
# require_backends, which raises ImportError when a backend is unavailable.
class ChatGLM(metaclass=_DummyMetaclass):
  _backends=["torch","cpm_kernels","sentencepiece"]
  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["torch","cpm_kernels","sentencepiece"])
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# To update this, run ./tools/update-dummy.py
|
||||
from __future__ import annotations
|
||||
import typing as _t
|
||||
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
# Auto-generated placeholder: any public attribute access triggers
# require_backends, which raises ImportError when the backend is unavailable.
# NOTE(review): "tensorflow" is not a key of BACKENDS_MAPPING (which uses
# "tf") -- confirm the lookup in require_backends handles this.
class TFFlanT5(metaclass=_DummyMetaclass):
  _backends=["tensorflow"]
  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["tensorflow"])
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# To update this, run ./tools/update-dummy.py
|
||||
from __future__ import annotations
|
||||
import typing as _t
|
||||
from openllm.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends
|
||||
# Auto-generated placeholder: any public attribute access triggers
# require_backends, which raises ImportError when a backend is unavailable.
class VLLMBaichuan(metaclass=_DummyMetaclass):
  _backends=["vllm","cpm_kernels","sentencepiece"]
  def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,["vllm","cpm_kernels","sentencepiece"])
|
||||
|
||||
@@ -1,312 +0,0 @@
|
||||
"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons."""
|
||||
from __future__ import annotations
|
||||
import importlib, importlib.metadata, importlib.util, logging, os, abc, typing as t
|
||||
from collections import OrderedDict
|
||||
import inflection, packaging.version
|
||||
from bentoml._internal.utils import LazyLoader, pkg
|
||||
from openllm._typing_compat import overload, LiteralString
|
||||
|
||||
from .representation import ReprMixin
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
BackendOrderedDict = OrderedDict[str, t.Tuple[t.Callable[[], bool], str]]
|
||||
from openllm._typing_compat import LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
# Model/feature groups whose extra dependencies are optional installs.
OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq",}
# Accepted truthy spellings for the backend-selection environment variables.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
# Backend preference switches, normalised to upper case; "AUTO" defers to
# whatever is importable.
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
# NOTE(review): USE_JAX is sourced from the USE_FLAX env var -- presumably
# intentional (flax runs on jax), but confirm the naming.
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
|
||||
|
||||
def _is_package_available(package: str) -> bool:
    """Whether *package* is importable AND carries distribution metadata.

    Args:
        package: importable module/distribution name.

    Returns:
        bool: True when both the module spec and its version metadata exist.
    """
    # A missing spec means the module cannot be imported at all.
    if importlib.util.find_spec(package) is None:
        return False
    # An importable module without metadata still counts as unavailable.
    try:
        importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return False
    return True
|
||||
|
||||
# Presence-only probes: find_spec is cheap but does not verify metadata.
_torch_available = importlib.util.find_spec("torch") is not None
_tf_available = importlib.util.find_spec("tensorflow") is not None
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
_vllm_available = importlib.util.find_spec("vllm") is not None
# Metadata-verified probes for the remaining optional dependencies.
_peft_available = _is_package_available("peft")
_einops_available = _is_package_available("einops")
_cpm_kernel_available = _is_package_available("cpm_kernels")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_datasets_available = _is_package_available("datasets")
_triton_available = _is_package_available("triton")
_jupyter_available = _is_package_available("jupyter")
_jupytext_available = _is_package_available("jupytext")
_notebook_available = _is_package_available("notebook")
_autogptq_available = _is_package_available("auto_gptq")
_sentencepiece_available = _is_package_available("sentencepiece")
_xformers_available = _is_package_available("xformers")
_fairscale_available = _is_package_available("fairscale")
|
||||
|
||||
# transformers feature gates, checked by installed version.
def is_transformers_supports_kbit() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 30)
def is_transformers_supports_agent() -> bool: return pkg.pkg_version_info("transformers")[:2] >= (4, 29)
# Thin accessors over the module-level availability flags computed above.
def is_jupyter_available() -> bool: return _jupyter_available
def is_jupytext_available() -> bool: return _jupytext_available
def is_notebook_available() -> bool: return _notebook_available
def is_triton_available() -> bool: return _triton_available
def is_datasets_available() -> bool: return _datasets_available
def is_peft_available() -> bool: return _peft_available
def is_einops_available() -> bool: return _einops_available
def is_cpm_kernels_available() -> bool: return _cpm_kernel_available
def is_bitsandbytes_available() -> bool: return _bitsandbytes_available
def is_autogptq_available() -> bool: return _autogptq_available
def is_vllm_available() -> bool: return _vllm_available
def is_sentencepiece_available() -> bool: return _sentencepiece_available
def is_xformers_available() -> bool: return _xformers_available
def is_fairscale_available() -> bool: return _fairscale_available
|
||||
def is_torch_available() -> bool:
    # Refine the cached torch flag, honouring the USE_TORCH/USE_TF overrides.
    # Mutates the module-level flag so subsequent calls return the same answer.
    global _torch_available
    if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
        if _torch_available:
            # Importable but without distribution metadata counts as absent.
            try: importlib.metadata.version("torch")
            except importlib.metadata.PackageNotFoundError: _torch_available = False
    else:
        # USE_TF was explicitly enabled (and USE_TORCH was not): disable torch.
        logger.info("Disabling PyTorch because USE_TF is set")
        _torch_available = False
    return _torch_available
|
||||
def is_tf_available() -> bool:
    # Refine the cached TensorFlow flag. FORCE_TF_AVAILABLE short-circuits all
    # detection; otherwise honour USE_TF/USE_TORCH and require a TF 2.x install.
    # Mutates the module-level flag so subsequent calls return the same answer.
    global _tf_available
    if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: _tf_available = True
    else:
        _tf_version = None
        if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
            if _tf_available:
                # TF ships under many distribution names; accept any of them.
                candidates = ("tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos", "tensorflow-aarch64",)
                _tf_version = None
                # For the metadata, we have to look for both tensorflow and tensorflow-cpu
                for _pkg in candidates:
                    try:
                        _tf_version = importlib.metadata.version(_pkg)
                        break
                    except importlib.metadata.PackageNotFoundError: pass  # noqa: PERF203 # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
                _tf_available = _tf_version is not None
            if _tf_available:
                # Only TF 2.x is supported.
                if _tf_version and packaging.version.parse(_tf_version) < packaging.version.parse("2"):
                    logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
                    _tf_available = False
        else:
            # USE_TORCH was explicitly enabled (and USE_TF was not): disable TF.
            logger.info("Disabling Tensorflow because USE_TORCH is set")
            _tf_available = False
    return _tf_available
|
||||
def is_flax_available() -> bool:
    # Refine the cached flax flag; both jax and flax metadata must be present.
    # Mutates the module-level flag so subsequent calls return the same answer.
    global _flax_available
    if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
        if _flax_available:
            # Importable but without distribution metadata counts as absent.
            try:
                importlib.metadata.version("jax")
                importlib.metadata.version("flax")
            except importlib.metadata.PackageNotFoundError: _flax_available = False
    else:
        # USE_FLAX was explicitly disabled.
        _flax_available = False
    return _flax_available
|
||||
|
||||
VLLM_IMPORT_ERROR_WITH_PYTORCH = """\
|
||||
{0} requires the vLLM library but it was not found in your environment.
|
||||
However, we were able to find a PyTorch installation. PyTorch classes do not begin
|
||||
with "VLLM", but are otherwise identically named to our PyTorch classes.
|
||||
If you want to use PyTorch, please use those classes instead!
|
||||
|
||||
If you really do want to use vLLM, please follow the instructions on the
|
||||
installation page https://github.com/vllm-project/vllm that match your environment.
|
||||
"""
|
||||
VLLM_IMPORT_ERROR_WITH_TF = """\
|
||||
{0} requires the vLLM library but it was not found in your environment.
|
||||
However, we were able to find a TensorFlow installation. TensorFlow classes begin
|
||||
with "TF", but are otherwise identically named to the PyTorch classes. This
|
||||
means that the TF equivalent of the class you tried to import would be "TF{0}".
|
||||
If you want to use TensorFlow, please use TF classes instead!
|
||||
|
||||
If you really do want to use vLLM, please follow the instructions on the
|
||||
installation page https://github.com/vllm-project/vllm that match your environment.
|
||||
"""
|
||||
VLLM_IMPORT_ERROR_WITH_FLAX = """\
|
||||
{0} requires the vLLM library but it was not found in your environment.
|
||||
However, we were able to find a Flax installation. Flax classes begin
|
||||
with "Flax", but are otherwise identically named to the PyTorch classes. This
|
||||
means that the Flax equivalent of the class you tried to import would be "Flax{0}".
|
||||
If you want to use Flax, please use Flax classes instead!
|
||||
|
||||
If you really do want to use vLLM, please follow the instructions on the
|
||||
installation page https://github.com/vllm-project/vllm that match your environment.
|
||||
"""
|
||||
PYTORCH_IMPORT_ERROR_WITH_TF = """\
|
||||
{0} requires the PyTorch library but it was not found in your environment.
|
||||
However, we were able to find a TensorFlow installation. TensorFlow classes begin
|
||||
with "TF", but are otherwise identically named to the PyTorch classes. This
|
||||
means that the TF equivalent of the class you tried to import would be "TF{0}".
|
||||
If you want to use TensorFlow, please use TF classes instead!
|
||||
|
||||
If you really do want to use PyTorch please go to
|
||||
https://pytorch.org/get-started/locally/ and follow the instructions that
|
||||
match your environment.
|
||||
"""
|
||||
TF_IMPORT_ERROR_WITH_PYTORCH = """\
|
||||
{0} requires the TensorFlow library but it was not found in your environment.
|
||||
However, we were able to find a PyTorch installation. PyTorch classes do not begin
|
||||
with "TF", but are otherwise identically named to our TF classes.
|
||||
If you want to use PyTorch, please use those classes instead!
|
||||
|
||||
If you really do want to use TensorFlow, please follow the instructions on the
|
||||
installation page https://www.tensorflow.org/install that match your environment.
|
||||
"""
|
||||
TENSORFLOW_IMPORT_ERROR = """{0} requires the TensorFlow library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://www.tensorflow.org/install and follow the
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
FLAX_IMPORT_ERROR = """{0} requires the FLAX library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://github.com/google/flax and follow the
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
PYTORCH_IMPORT_ERROR = """{0} requires the PyTorch library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
VLLM_IMPORT_ERROR = """{0} requires the vLLM library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://github.com/vllm-project/vllm
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
CPM_KERNELS_IMPORT_ERROR = """{0} requires the cpm_kernels library but it was not found in your environment.
|
||||
You can install it with pip: `pip install cpm_kernels`. Please note that you may need to restart your
|
||||
runtime after installation.
|
||||
"""
|
||||
EINOPS_IMPORT_ERROR = """{0} requires the einops library but it was not found in your environment.
|
||||
You can install it with pip: `pip install einops`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
TRITON_IMPORT_ERROR = """{0} requires the triton library but it was not found in your environment.
|
||||
You can install it with pip: 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'.
|
||||
Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
DATASETS_IMPORT_ERROR = """{0} requires the datasets library but it was not found in your environment.
|
||||
You can install it with pip: `pip install datasets`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
PEFT_IMPORT_ERROR = """{0} requires the peft library but it was not found in your environment.
|
||||
You can install it with pip: `pip install peft`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
BITSANDBYTES_IMPORT_ERROR = """{0} requires the bitsandbytes library but it was not found in your environment.
|
||||
You can install it with pip: `pip install bitsandbytes`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
AUTOGPTQ_IMPORT_ERROR = """{0} requires the auto-gptq library but it was not found in your environment.
|
||||
You can install it with pip: `pip install auto-gptq`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
SENTENCEPIECE_IMPORT_ERROR = """{0} requires the sentencepiece library but it was not found in your environment.
|
||||
You can install it with pip: `pip install sentencepiece`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
XFORMERS_IMPORT_ERROR = """{0} requires the xformers library but it was not found in your environment.
|
||||
You can install it with pip: `pip install xformers`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
FAIRSCALE_IMPORT_ERROR = """{0} requires the fairscale library but it was not found in your environment.
|
||||
You can install it with pip: `pip install fairscale`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
|
||||
# Maps a backend key -> (availability predicate, import-error message template).
# NOTE(review): some generated dummy modules list backends as "tensorflow"
# while the key here is "tf" -- the lookup in require_backends would raise
# KeyError for those; confirm which spelling is canonical.
BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
                                                    ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)),
                                                    ("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
                                                    ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
                                                    ("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), ("fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))])
|
||||
|
||||
class DummyMetaclass(abc.ABCMeta):
  """Metaclass backing the dummy placeholder objects.

  Accessing any public attribute on a class built with this metaclass invokes
  ``require_backends``, which raises the ImportError describing the backend
  libraries that are missing from the environment.
  """

  # Backend names this dummy class stands in for (set on each dummy class).
  _backends: t.List[str]

  def __getattribute__(cls, key: str) -> t.Any:
    # Private/dunder lookups must resolve normally so Python's own machinery
    # (isinstance checks, repr, pickling, ...) keeps working.
    if not key.startswith("_"):
      require_backends(cls, cls._backends)
      return None
    return super().__getattribute__(key)
|
||||
|
||||
def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
  """Raise ``ImportError`` if any backend required by *o* is unavailable.

  Before the generic per-backend check, a few targeted errors steer users who
  have the *wrong* framework installed (e.g. accessing a torch-only symbol in a
  TF-only environment) toward a more helpful message.
  """
  if not isinstance(backends, (list, tuple)):
    backends = list(backends)
  owner_name = o.__name__ if hasattr(o, "__name__") else o.__class__.__name__
  # Raise an error for users who might not realize that classes without "TF" are torch-only.
  if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
    raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(owner_name))
  # Raise the inverse error for PyTorch users trying to load TF classes.
  if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
    raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(owner_name))
  # When vLLM is required but missing, suggest the installed alternative,
  # checking PyTorch first, then TensorFlow, then Flax.
  if "vllm" in backends:
    alternatives = (
      ("torch", is_torch_available, VLLM_IMPORT_ERROR_WITH_PYTORCH),
      ("tf", is_tf_available, VLLM_IMPORT_ERROR_WITH_TF),
      ("flax", is_flax_available, VLLM_IMPORT_ERROR_WITH_FLAX),
    )
    for alt_backend, alt_available, alt_message in alternatives:
      if alt_backend not in backends and alt_available() and not is_vllm_available():
        raise ImportError(alt_message.format(owner_name))
  # Generic path: collect one formatted message per unavailable backend.
  missing_messages = []
  for backend in backends:
    available, message = BACKENDS_MAPPING[backend]
    if not available():
      missing_messages.append(message.format(owner_name))
  if missing_messages:
    raise ImportError("".join(missing_messages))
|
||||
|
||||
class EnvVarMixin(ReprMixin):
  """Resolve per-model environment variable names and their current values.

  Indexing with a plain key (e.g. ``self['model_id']``) returns the environment
  variable *name* generated for this model, while the ``*_value`` keys
  (e.g. ``self['model_id_value']``) resolve the current value from
  ``os.environ``, falling back to the default given to the constructor.
  """

  model_name: str
  # The attributes below hold the generated environment variable *names*
  # (set in __init__ via field_env_key), not their values.
  config: str
  model_id: str
  quantize: str
  framework: str
  bettertransformer: str
  runtime: str

  @overload
  def __getitem__(self, item: t.Literal["config"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["model_id"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["quantize"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["framework"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["runtime"]) -> str: ...
  @overload
  def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ...
  @overload
  def __getitem__(self, item: t.Literal["quantize_value"]) -> t.Literal["int8", "int4", "gptq"] | None: ...
  @overload
  def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ...
  @overload
  def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> bool: ...
  @overload
  def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ...
  def __getitem__(self, item: str | t.Any) -> t.Any:
    """Dispatch: '<key>_value' calls the matching private resolver, a plain key returns the env var name.

    Raises:
      KeyError: if *item* matches neither a resolver nor a known attribute.
    """
    if item.endswith("_value") and hasattr(self, f"_{item}"): return object.__getattribute__(self, f"_{item}")()
    elif hasattr(self, item): return getattr(self, item)
    raise KeyError(f"Key {item} not found in {self}")

  def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None:
    """EnvVarMixin is a mixin class that returns the value extracted from environment variables.

    Args:
      model_name: model name; normalised to snake_case for env-var key generation.
      implementation: default runtime implementation when the framework env var is unset.
      model_id: default model id when the model-id env var is unset.
      bettertransformer: default BetterTransformer toggle when its env var is unset.
      quantize: default quantisation scheme when its env var is unset.
      runtime: default runtime ('ggml' or 'transformers') when its env var is unset.
    """
    from openllm._configuration import field_env_key
    self.model_name = inflection.underscore(model_name)
    self._implementation = implementation
    self._model_id = model_id
    self._bettertransformer = bettertransformer
    self._quantize = quantize
    self._runtime = runtime
    # Precompute the environment variable name for every supported key.
    for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: setattr(self, att, field_env_key(self.model_name, att.upper()))

  def _quantize_value(self) -> t.Literal["int8", "int4", "gptq"] | None:
    from . import first_not_none
    return t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], first_not_none(os.environ.get(self["quantize"]), default=self._quantize))

  def _framework_value(self) -> LiteralRuntime:
    from . import first_not_none
    return t.cast(t.Literal["pt", "tf", "flax", "vllm"], first_not_none(os.environ.get(self["framework"]), default=self._implementation))

  def _bettertransformer_value(self) -> bool:
    from . import first_not_none
    # BUGFIX: the previous implementation passed a bool (derived from a
    # defaulted env lookup) as the first argument to first_not_none, so the
    # constructor default could never apply. Consult the environment only when
    # the variable is actually set, mirroring the other _*_value resolvers.
    env = os.environ.get(self["bettertransformer"])
    return t.cast(bool, first_not_none(env.upper() in ENV_VARS_TRUE_VALUES if env is not None else None, default=self._bettertransformer))

  def _model_id_value(self) -> str | None:
    from . import first_not_none
    return first_not_none(os.environ.get(self["model_id"]), default=self._model_id)

  def _runtime_value(self) -> t.Literal["ggml", "transformers"]:
    from . import first_not_none
    return t.cast(t.Literal["ggml", "transformers"], first_not_none(os.environ.get(self["runtime"]), default=self._runtime))

  @property
  def __repr_keys__(self) -> set[str]:
    """Keys rendered by ReprMixin's __repr__/__str__."""
    return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}

  @property
  def start_docstring(self) -> str:
    """Docstring of the 'openllm start' command for this model, looked up lazily from its module."""
    return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")

  @property
  def module(self) -> LazyLoader:
    """Lazily-imported 'openllm.models.<model_name>' module."""
    return LazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")
|
||||
@@ -1,107 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import functools, importlib, importlib.machinery, importlib.metadata, importlib.util, itertools, os, time, types, warnings, typing as t
|
||||
import attr, openllm
|
||||
|
||||
# Explicit public surface of this lazy-loading utilities module.
__all__ = ["VersionInfo", "LazyModule"]
|
||||
# vendored from attrs
@functools.total_ordering
@attr.attrs(eq=False, order=False, slots=True, frozen=True, repr=False)
class VersionInfo:
  """Immutable (major, minor, micro, releaselevel) version record with total ordering."""

  major: int = attr.field()
  minor: int = attr.field()
  micro: int = attr.field()
  releaselevel: str = attr.field()

  @classmethod
  def from_version_string(cls, s: str) -> VersionInfo:
    """Parse 'X.Y.Z' or 'X.Y.Z.<releaselevel>' into a VersionInfo."""
    parts = s.split(".")
    if len(parts) == 3:
      # A bare 'X.Y.Z' string is treated as a final release.
      parts.append("final")
    return cls(major=int(parts[0]), minor=int(parts[1]), micro=int(parts[2]), releaselevel=parts[3])

  def _ensure_tuple(self, other: VersionInfo) -> tuple[tuple[int, int, int, str], tuple[int, int, int, str]]:
    """Coerce *other* into a tuple and trim our own fields to the same length for comparison."""
    if self.__class__ is other.__class__:
      cmp = attr.astuple(other)
    else:
      cmp = other
    if not isinstance(cmp, tuple):
      raise NotImplementedError
    if not (1 <= len(cmp) <= 4):
      raise NotImplementedError
    ours = t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[:len(cmp)])
    return ours, t.cast(t.Tuple[int, int, int, str], cmp)

  def __eq__(self, other: t.Any) -> bool:
    try:
      ours, theirs = self._ensure_tuple(other)
    except NotImplementedError:
      return NotImplemented
    return ours == theirs

  def __lt__(self, other: t.Any) -> bool:
    try:
      ours, theirs = self._ensure_tuple(other)
    except NotImplementedError:
      return NotImplemented
    # Since alphabetically "dev0" < "final" < "post1" < "post2", we don't have
    # to do anything special with releaselevel for now.
    return ours < theirs

  def __repr__(self) -> str:
    major, minor, micro = attr.astuple(self)[:3]
    return f"{major}.{minor}.{micro}"
|
||||
|
||||
# Sentinel for "no migration entry found" lookups, plus attribute names that
# user code must never read or write on a LazyModule.
_sentinel, _reserved_namespace = object(), {"__openllm_migration__"}
class LazyModule(types.ModuleType):
  """Module subclass that defers importing submodules/attributes until first access."""

  # Very heavily inspired by optuna.integration._IntegrationModule: https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
  def __init__(self, name: str, module_file: str, import_structure: dict[str, list[str]], module_spec: importlib.machinery.ModuleSpec | None = None, doc: str | None = None, extra_objects: dict[str, t.Any] | None = None):
    """Lazily load this module as an object.

    It does instantiate a __all__ and __dir__ for IDE support

    Args:
        name: module name
        module_file: the given file. Often default to 'globals()['__file__']'
        import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module'
        module_spec: __spec__ of the lazily loaded module
        doc: Optional docstring for this module.
        extra_objects: Any additional objects that this module can also be accessed. Useful for additional metadata as well as any locals() functions
    """
    super().__init__(name)
    self._modules = set(import_structure.keys())
    self._class_to_module: dict[str, str] = {}
    _extra_objects = {} if extra_objects is None else extra_objects
    # Invert the structure: attribute name -> submodule that defines it.
    for key, values in import_structure.items():
      for value in values: self._class_to_module[value] = key
    # Needed for autocompletion in an IDE
    self.__all__: list[str] = list(import_structure.keys()) + list(itertools.chain(*import_structure.values()))
    self.__file__ = module_file
    self.__spec__ = module_spec or importlib.util.find_spec(name)
    self.__path__ = [os.path.dirname(module_file)]
    self.__doc__ = doc
    self._name = name
    self._objects = _extra_objects
    self._import_structure = import_structure
  def __dir__(self) -> list[str]:
    """Return the default dir() plus every lazily-exposed name not yet materialised."""
    result = t.cast("list[str]", super().__dir__())
    # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
    # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
    return result + [i for i in self.__all__ if i not in result]
  def __getattr__(self, name: str) -> t.Any:
    """Equivocal __getattr__ implementation.

    It checks from _objects > _modules and does it recursively.

    It also contains a special case for all of the metadata information, such as __version__ and __version_info__.
    """
    if name in _reserved_namespace: raise openllm.exceptions.ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.")
    # Packaging dunders map to importlib.metadata fields; an empty value means
    # the dunder is handled by one of the special cases below instead.
    dunder_to_metadata = {"__title__": "Name", "__copyright__": "", "__version__": "version", "__version_info__": "version", "__description__": "summary", "__uri__": "", "__url__": "", "__author__": "", "__email__": "", "__license__": "license", "__homepage__": ""}
    if name in dunder_to_metadata:
      # Version/copyright lookups remain silently supported; all other
      # packaging dunders emit a deprecation warning.
      if name not in {"__version_info__", "__copyright__", "__version__"}: warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2)
      meta = importlib.metadata.metadata("openllm")
      project_url = dict(url.split(", ") for url in t.cast(t.List[str], meta.get_all("Project-URL")))
      if name == "__license__": return "Apache-2.0"
      elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
      elif name in ("__uri__", "__url__"): return project_url["GitHub"]
      elif name == "__homepage__": return project_url["Homepage"]
      elif name == "__version_info__": return VersionInfo.from_version_string(meta["version"]) # similar to how attrs handle __version_info__
      elif name == "__author__": return meta["Author-email"].rsplit(" ", 1)[0]
      elif name == "__email__": return meta["Author-email"].rsplit("<", 1)[1][:-1]
      return meta[dunder_to_metadata[name]]
    if "__openllm_migration__" in self._objects:
      # Renamed/relocated attributes: warn, then forward to the new name.
      cur_value = self._objects["__openllm_migration__"].get(name, _sentinel)
      if cur_value is not _sentinel:
        warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3)
        return getattr(self, cur_value)
    if name in self._objects: return self._objects.__getitem__(name)
    if name in self._modules: value = self._get_module(name)
    elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name)
    else: raise AttributeError(f"module {self.__name__} has no attribute {name}")
    # Cache the resolved value on the instance so later accesses skip __getattr__.
    setattr(self, name, value)
    return value
  def _get_module(self, module_name: str) -> types.ModuleType:
    """Import and return the named submodule, wrapping any failure with context."""
    try: return importlib.import_module("." + module_name, self.__name__)
    except Exception as e: raise RuntimeError(f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}") from e
  # make sure this module is picklable
  def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]: return (self.__class__, (self._name, self.__file__, self._import_structure))
|
||||
@@ -1,32 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
from abc import abstractmethod
|
||||
import attr, orjson
|
||||
from openllm import utils
|
||||
if t.TYPE_CHECKING: from openllm._typing_compat import TypeAlias
|
||||
|
||||
# Generator of (key, value) pairs yielded by __repr_args__; a None key renders
# the value bare (no 'key=' prefix) in __repr_str__.
ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None]

class ReprMixin:
  """Mixin providing a uniform ``__repr__``/``__str__`` driven by ``__repr_keys__``.

  Subclasses declare which attributes to render by overriding the abstract
  ``__repr_keys__`` property (and may override ``__repr_args__`` for full control).
  """

  @property
  @abstractmethod
  def __repr_keys__(self) -> set[str]:
    """Set of attribute names rendered by ``__repr__``/``__str__``.

    This must be overridden by the class using this mixin.
    """
    # FIX: this docstring (and those below) were previously bare string
    # statements placed *after* each def — dead code invisible to help()/IDEs.
    raise NotImplementedError

  def __repr__(self) -> str:
    """The `__repr__` for any subclass of Mixin.

    It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
    """
    return f"{self.__class__.__name__} {orjson.dumps({k: utils.bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}"

  def __str__(self) -> str:
    """The string representation of the given Mixin subclass.

    It will contain all of the attributes from __repr_keys__.
    """
    return self.__repr_str__(" ")

  def __repr_name__(self) -> str:
    """Name of the instance's class, used in __repr__."""
    return self.__class__.__name__

  def __repr_str__(self, join_str: str) -> str:
    """To be used with __str__: join each (key, value) pair, omitting 'key=' when the key is None."""
    return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())

  def __repr_args__(self) -> ReprArgs:
    """Yield (key, value) pairs to render.

    By default it does a getattr of the current object for each name in
    __repr_keys__; base classes using this mixin may override it.
    """
    return ((k, getattr(self, k)) for k in self.__repr_keys__)
|
||||
Reference in New Issue
Block a user