mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-14 21:16:16 -04:00
perf: faster LLM loading
Use attrs for faster class creation, as opposed to a metaclass.

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -66,7 +66,7 @@ from deepmerge.merger import Merger
|
||||
import openllm
|
||||
|
||||
from .exceptions import ForbiddenAttributeError, GpuNotAvailableError, OpenLLMException
|
||||
from .utils import DEBUG, LazyType, bentoml_cattr, dantic, first_not_none, lenient_issubclass
|
||||
from .utils import DEBUG, LazyType, bentoml_cattr, codegen, dantic, first_not_none, lenient_issubclass
|
||||
|
||||
if hasattr(t, "Required"):
|
||||
from typing import Required
|
||||
@@ -85,7 +85,7 @@ if t.TYPE_CHECKING:
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import transformers
|
||||
from attr import _CountingAttr, _make_init, _make_method, _make_repr, _transform_attrs
|
||||
from attr import _CountingAttr, _make_init, _make_repr, _transform_attrs # type: ignore
|
||||
from transformers.generation.beam_constraints import Constraint
|
||||
|
||||
from ._types import ClickFunctionWrapper, F, O_co, P
|
||||
@@ -103,7 +103,7 @@ else:
|
||||
ItemgetterAny = itemgetter
|
||||
# NOTE: Using internal API from attr here, since we are actually
|
||||
# allowing subclass of openllm.LLMConfig to become 'attrs'-ish
|
||||
from attr._make import _CountingAttr, _make_init, _make_method, _make_repr, _transform_attrs
|
||||
from attr._make import _CountingAttr, _make_init, _make_repr, _transform_attrs
|
||||
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
@@ -389,89 +389,10 @@ def _populate_value_from_env_var(
|
||||
return os.environ.get(key, fallback)
|
||||
|
||||
|
||||
# sentinel object for unequivocal object() getattr
|
||||
_sentinel = object()
|
||||
|
||||
|
||||
def _field_env_key(model_name: str, key: str, suffix: str | None = None) -> str:
    """Build the environment variable name for a configuration field.

    The name is assembled from ``OPENLLM``, *model_name*, the optional *suffix*
    (with leading and trailing underscores stripped) and *key*. Every component
    is upper-cased, empty components are dropped, and the rest are joined
    with underscores.
    """
    trimmed_suffix = suffix.strip("_") if suffix else ""
    uppercased = [str.upper(part) for part in ("OPENLLM", model_name, trimmed_suffix, key)]
    return "_".join(part for part in uppercased if part)
|
||||
|
||||
|
||||
def _has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
    """Check whether *cls* defines *attrib_name* (and doesn't just inherit it).

    The attribute counts as inherited when any base class in the MRO exposes
    the very same object under that name.

    Args:
        cls: The class to inspect.
        attrib_name: Name of the attribute to look up.

    Returns:
        ``True`` when the attribute resolves on *cls* and is not identical to
        the attribute found on any of its bases; ``False`` otherwise.
    """
    # One sentinel is used for every lookup, so an attribute whose value is
    # ``None`` is never mistaken for "this base has no such attribute".
    missing = object()
    own = getattr(cls, attrib_name, missing)
    if own is missing:
        return False

    return all(getattr(base_cls, attrib_name, missing) is not own for base_cls in cls.__mro__[1:])
|
||||
|
||||
|
||||
def _get_annotations(cls: type[t.Any]) -> DictStrAny:
    """Return the annotations declared on *cls* itself.

    When ``_has_own_attribute`` reports that ``__annotations__`` is not the
    class's own, a fresh empty ``DictStrAny`` is returned instead.
    """
    if not _has_own_attribute(cls, "__annotations__"):
        return DictStrAny()

    return cls.__annotations__
|
||||
|
||||
|
||||
# Textual prefixes under which a ClassVar annotation may be spelled.
_classvar_prefixes = (
    "typing.ClassVar",
    "t.ClassVar",
    "ClassVar",
    "typing_extensions.ClassVar",
)


def _is_class_var(annot: str | t.Any) -> bool:
    """Tell whether *annot* spells a ``ClassVar`` annotation.

    The check is purely textual: the annotation is converted with ``str()``,
    one pair of surrounding quote characters is removed if present, and the
    result is matched against the known ``ClassVar`` prefixes. This avoids
    evaluating string annotations.
    """
    text = str(annot)
    quote_chars = ("'", '"')

    # A quoted annotation carries its quotes in the text; peel one layer off.
    is_quoted = text.startswith(quote_chars) and text.endswith(quote_chars)
    if is_quoted:
        text = text[1:-1]

    return text.startswith(_classvar_prefixes)
|
||||
|
||||
|
||||
def _add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
    """Stamp ``__module__``, ``__qualname__`` and ``__doc__`` onto *method_or_cls*.

    Each attribute is set independently. An ``AttributeError`` raised while
    computing or assigning one of them is ignored, and the remaining
    attributes are still attempted.

    Args:
        cls: The class whose module and qualified name are used.
        method_or_cls: The object to decorate; it is returned unchanged otherwise.
        _overwrite_doc: Docstring to use instead of the generated one.

    Returns:
        The same *method_or_cls* object that was passed in.
    """
    # Values are computed lazily so that a failure to compute one of them is
    # swallowed together with a failure to assign it.
    dunder_factories = (
        ("__module__", lambda: cls.__module__),
        ("__qualname__", lambda: ".".join((cls.__qualname__, method_or_cls.__name__))),
        (
            "__doc__",
            lambda: _overwrite_doc or f"Method or class generated by LLMConfig for class {cls.__qualname__}.",
        ),
    )

    for dunder_name, make_value in dunder_factories:
        try:
            setattr(method_or_cls, dunder_name, make_value())
        except AttributeError:
            pass

    return method_or_cls
|
||||
|
||||
|
||||
# cached it here to save one lookup per assignment
|
||||
_object_getattribute = object.__getattribute__
|
||||
|
||||
@@ -506,8 +427,8 @@ class ModelSettings(t.TypedDict, total=False):
|
||||
generation_class: t.Type[GenerationConfig]
|
||||
|
||||
|
||||
_ModelSettings: type[attr.AttrsInstance] = _add_method_dunders(
|
||||
type("__internal__", (ModelSettings,), {"__module__": "openllm._configuration"}),
|
||||
_ModelSettings: type[attr.AttrsInstance] = codegen.add_method_dunders(
|
||||
type("__openllm_internal__", (ModelSettings,), {"__module__": "openllm._configuration"}),
|
||||
attr.make_class(
|
||||
"ModelSettings",
|
||||
{
|
||||
@@ -563,7 +484,7 @@ def structure_settings(cl_: type[LLMConfig], cls: type[t.Any]):
|
||||
partialed = functools.partial(_field_env_key, model_name=model_name, suffix="generation")
|
||||
|
||||
def auto_env_transformers(_: t.Any, fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]:
|
||||
_has_own_gen = _has_own_attribute(cl_, "GenerationConfig")
|
||||
_has_own_gen = codegen.has_own_attribute(cl_, "GenerationConfig")
|
||||
return [
|
||||
f.evolve(
|
||||
default=_populate_value_from_env_var(
|
||||
@@ -576,20 +497,20 @@ def structure_settings(cl_: type[LLMConfig], cls: type[t.Any]):
|
||||
for f in fields
|
||||
]
|
||||
|
||||
_target: DictStrAny = {
|
||||
"default_id": settings["default_id"],
|
||||
"model_ids": settings["model_ids"],
|
||||
"url": settings.get("url", ""),
|
||||
"requires_gpu": settings.get("requires_gpu", False),
|
||||
"trust_remote_code": settings.get("trust_remote_code", False),
|
||||
"requirements": settings.get("requirements", None),
|
||||
"name_type": name_type,
|
||||
"model_name": model_name,
|
||||
"start_name": start_name,
|
||||
"env": openllm.utils.ModelEnv(model_name),
|
||||
"timeout": settings.get("timeout", 3600),
|
||||
"workers_per_resource": settings.get("workers_per_resource", 1),
|
||||
"generation_class": attr.make_class(
|
||||
return cls(
|
||||
default_id=settings["default_id"],
|
||||
model_ids=settings["model_ids"],
|
||||
url=settings.get("url", ""),
|
||||
requires_gpu=settings.get("requires_gpu", False),
|
||||
trust_remote_code=settings.get("trust_remote_code", False),
|
||||
requirements=settings.get("requirements", None),
|
||||
name_type=name_type,
|
||||
model_name=model_name,
|
||||
start_name=start_name,
|
||||
env=openllm.utils.ModelEnv(model_name),
|
||||
timeout=settings.get("timeout", 3600),
|
||||
workers_per_resource=settings.get("workers_per_resource", 1),
|
||||
generation_class=attr.make_class(
|
||||
f"{_cl_name}GenerationConfig",
|
||||
[],
|
||||
bases=(GenerationConfig,),
|
||||
@@ -599,18 +520,12 @@ def structure_settings(cl_: type[LLMConfig], cls: type[t.Any]):
|
||||
repr=True,
|
||||
field_transformer=auto_env_transformers,
|
||||
),
|
||||
}
|
||||
|
||||
return cls(**_target)
|
||||
)
|
||||
|
||||
|
||||
bentoml_cattr.register_structure_hook(_ModelSettings, structure_settings)
|
||||
|
||||
|
||||
def _generate_unique_filename(cls: type[t.Any], func_name: str):
    """Return a pseudo filename identifying *func_name* as generated for *cls*.

    The result embeds the class's module and its ``__qualname__`` (falling
    back to ``__name__`` when no qualified name is available).
    """
    module_name = cls.__module__
    qualified_name = getattr(cls, "__qualname__", cls.__name__)
    return "<LLMConfig generated {} {}.{}>".format(func_name, module_name, qualified_name)
|
||||
|
||||
|
||||
def _setattr_class(attr_name: str, value_var: t.Any, add_dunder: bool = False):
|
||||
"""
|
||||
Use the builtin setattr to set *attr_name* to *value_var*.
|
||||
@@ -632,7 +547,7 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance
|
||||
"cls": cls,
|
||||
"_cached_attribute": attributes,
|
||||
"_cached_getattribute_get": _object_getattribute.__get__,
|
||||
"__add_dunder": _add_method_dunders,
|
||||
"__add_dunder": codegen.add_method_dunders,
|
||||
}
|
||||
annotations: DictStrAny = {"return": None}
|
||||
|
||||
@@ -643,19 +558,9 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance
|
||||
lines.append(_setattr_class(arg_name, attr_name, add_dunder=attr_name in _dunder_add))
|
||||
annotations[attr_name] = field.type
|
||||
|
||||
script = "def __assign_attr(cls, %s):\n %s\n" % (", ".join(args), "\n ".join(lines) if lines else "pass")
|
||||
assign_method = _make_method(
|
||||
"__assign_attr",
|
||||
script,
|
||||
_generate_unique_filename(cls, "__assign_attr"),
|
||||
globs,
|
||||
return codegen.generate_function(
|
||||
cls, "__assign_attr", lines, args=("cls", *args), globs=globs, annotations=annotations
|
||||
)
|
||||
assign_method.__annotations__ = annotations
|
||||
|
||||
if DEBUG:
|
||||
logger.info("Generated script:\n%s", script)
|
||||
|
||||
return assign_method
|
||||
|
||||
|
||||
_reserved_namespace = {"__config__", "GenerationConfig"}
|
||||
@@ -841,7 +746,7 @@ class LLMConfig:
|
||||
_make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettings))(cls)
|
||||
# process a fields under cls.__dict__ and auto convert them with dantic.Field
|
||||
cd = cls.__dict__
|
||||
anns = _get_annotations(cls)
|
||||
anns = codegen.get_annotations(cls)
|
||||
partialed = functools.partial(_field_env_key, model_name=cls.__openllm_model_name__)
|
||||
|
||||
def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]:
|
||||
@@ -861,7 +766,7 @@ class LLMConfig:
|
||||
these: dict[str, _CountingAttr[t.Any]] = {}
|
||||
annotated_names: set[str] = set()
|
||||
for attr_name, typ in anns.items():
|
||||
if _is_class_var(typ):
|
||||
if codegen.is_class_var(typ):
|
||||
continue
|
||||
annotated_names.add(attr_name)
|
||||
val = cd.get(attr_name, attr.NOTHING)
|
||||
@@ -907,7 +812,7 @@ class LLMConfig:
|
||||
cls.__attrs_attrs__ = attrs
|
||||
# generate a __attrs_init__ for the subclass, since we will
|
||||
# implement a custom __init__
|
||||
cls.__attrs_init__ = _add_method_dunders(
|
||||
cls.__attrs_init__ = codegen.add_method_dunders(
|
||||
cls,
|
||||
_make_init(
|
||||
cls, # cls (the attrs-decorated class)
|
||||
@@ -924,7 +829,7 @@ class LLMConfig:
|
||||
),
|
||||
)
|
||||
# __repr__ function with the updated fields.
|
||||
cls.__repr__ = _add_method_dunders(cls, _make_repr(cls.__attrs_attrs__, None, cls))
|
||||
cls.__repr__ = codegen.add_method_dunders(cls, _make_repr(cls.__attrs_attrs__, None, cls))
|
||||
# Traverse the MRO to collect existing slots
|
||||
# and check for an existing __weakref__.
|
||||
existing_slots: DictStrAny = dict()
|
||||
|
||||
@@ -19,11 +19,11 @@ import functools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing as t
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import attr
|
||||
import bentoml
|
||||
import inflection
|
||||
import orjson
|
||||
@@ -33,7 +33,6 @@ from bentoml._internal.types import ModelSignatureDict
|
||||
import openllm
|
||||
|
||||
from .exceptions import ForbiddenAttributeError, OpenLLMException
|
||||
from .models.auto import AutoConfig
|
||||
from .utils import ENV_VARS_TRUE_VALUES, LazyLoader, bentoml_cattr
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -64,6 +63,8 @@ FRAMEWORK_TO_AUTOCLASS_MAPPING = {
|
||||
"flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
|
||||
}
|
||||
|
||||
TOKENIZER_PREFIX = "_tokenizer_"
|
||||
|
||||
|
||||
def convert_transformers_model_name(name: str) -> str:
|
||||
if os.path.exists(os.path.dirname(name)):
|
||||
@@ -178,9 +179,12 @@ class LLMInterface(ABC):
|
||||
"""The config class to use for this LLM. If you are creating a custom LLM, you must specify this class."""
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> dict[str, t.Any] | None:
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]] | None:
|
||||
"""The default import kwargs to used when importing the model.
|
||||
This will be passed into 'openllm.LLM.import_model'."""
|
||||
This will be passed into 'openllm.LLM.import_model'.
|
||||
It returns two dictionaries: one for model kwargs and one for tokenizer kwargs.
|
||||
"""
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any:
|
||||
@@ -254,132 +258,87 @@ _M = t.TypeVar("_M")
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
class LLMMetaclass(ABCMeta):
    """Metaclass that prepares the namespace of LLM classes before they are created."""

    def __new__(
        mcls, cls_name: str, bases: tuple[type[t.Any], ...], namespace: dict[str, t.Any], **attrs: t.Any
    ) -> type:
        """Create a new class, augmenting *namespace* unless ``LLMInterface`` is a direct base.

        For classes that do not list ``LLMInterface`` directly in *bases*, this:

        * derives the framework (``flax``/``tf``/``pt``) from the class-name prefix,
        * infers the config class through ``AutoConfig.infer_class_from_name``,
        * requires ``config_class`` in the namespace unless ``__openllm_internal__`` is present,
        * installs ``llm_post_init`` and ``import_model`` defaults,
        * replaces ``__annotations__`` with forward references built here.

        Raises:
            RuntimeError: ``config_class`` is missing from a class that does not
                declare ``__openllm_internal__``.
        """
        if LLMInterface in bases:
            # The class deriving directly from LLMInterface needs no setup.
            return super().__new__(mcls, cls_name, bases, namespace, **attrs)

        # The empty prefix always matches, which makes "pt" the fallback.
        framework_table = (
            ("Flax", "flax", "transformers.FlaxPreTrainedModel"),
            ("TF", "tf", "transformers.TFPreTrainedModel"),
            ("", "pt", "transformers.PreTrainedModel"),
        )
        for prefix, implementation, model_type in framework_table:
            if cls_name.startswith(prefix):
                break
        prefix_class_name_config = cls_name[len(prefix) :]

        annotations_dict: dict[str, t.Any] = {"__llm_model__": model_type, "model": model_type}
        namespace["__llm_implementation__"] = implementation

        config_class = AutoConfig.infer_class_from_name(prefix_class_name_config)
        annotations_dict["config_class"] = f"type[openllm.{config_class.__qualname__}]"
        annotations_dict["config"] = f"openllm.{config_class.__qualname__}"

        is_internal = "__openllm_internal__" in namespace
        if "config_class" in namespace:
            if is_internal:
                logger.debug(f"Using config class {namespace['config_class']} for {cls_name}.")
        elif is_internal:
            # Internal classes get the inferred config class automatically.
            namespace["config_class"] = config_class
        else:
            raise RuntimeError(
                "Missing required key 'config_class'. Make sure to define it within the LLM subclass."
            )

        if "llm_post_init" not in namespace:
            namespace["llm_post_init"] = _default_post_init
        else:
            user_post_init = namespace["llm_post_init"]

            def wrapped_llm_post_init(self: LLM[t.Any, t.Any]) -> None:
                """Run the default post-init, then the hook defined in the class body."""
                _default_post_init(self)
                user_post_init(self)

            namespace["llm_post_init"] = wrapped_llm_post_init

        if "import_model" in namespace:
            logger.debug("Using custom 'import_model' for %s", cls_name)
        else:
            # Fall back to the module-level importer bound to the detected framework.
            namespace["import_model"] = functools.partial(import_model, _model_framework=implementation)

        namespace.update(dict.fromkeys({"__llm_model__", "__llm_tokenizer__", "__llm_bentomodel__"}))

        tokenizer_type = "typing.Optional[transformers.PreTrainedTokenizerFast]"
        annotations_dict["load_in_mha"] = "bool"
        annotations_dict["tokenizer"] = tokenizer_type
        annotations_dict["__llm_tokenizer__"] = tokenizer_type
        annotations_dict["__llm_bentomodel__"] = "typing.Optional[bentoml.Model]"

        # 'is_class' is only passed on the interpreter versions selected below.
        forward_ref_kwds: dict[str, t.Any] = {"is_argument": False}
        if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1):
            forward_ref_kwds["is_class"] = True

        namespace["__annotations__"] = {
            key: t.ForwardRef(value, **forward_ref_kwds) for key, value in annotations_dict.items()
        }

        cls: type[LLM[t.Any, t.Any]] = super().__new__(
            t.cast("type[type[LLM[t.Any, t.Any]]]", mcls), cls_name, bases, namespace, **attrs
        )

        cls.__openllm_post_init__ = None if cls.llm_post_init is LLMInterface.llm_post_init else cls.llm_post_init
        cls.__openllm_custom_load__ = None if cls.load_model is LLMInterface.load_model else cls.load_model
        return cls
|
||||
|
||||
|
||||
class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
@attr.define(slots=True, repr=False)
|
||||
class LLM(LLMInterface, t.Generic[_M, _T]):
|
||||
if t.TYPE_CHECKING:
|
||||
# NOTE: the following will be populated by metaclass
|
||||
# The following will be populated by metaclass
|
||||
__llm_trust_remote_code__: bool
|
||||
__llm_implementation__: t.Literal["pt", "tf", "flax"]
|
||||
__llm_model__: _M | None
|
||||
__llm_tokenizer__: _T | None
|
||||
__llm_tag__: bentoml.Tag | None
|
||||
__llm_bentomodel__: bentoml.Model | None
|
||||
|
||||
__openllm_post_init__: t.Callable[[t.Self], None] | None
|
||||
__openllm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None
|
||||
__openllm_custom_import_kwargs__: property | None
|
||||
|
||||
load_in_mha: bool
|
||||
_llm_attrs: dict[str, t.Any]
|
||||
_llm_args: tuple[t.Any, ...]
|
||||
_model_args: tuple[t.Any, ...]
|
||||
_model_attrs: dict[str, t.Any]
|
||||
_tokenizer_attrs: dict[str, t.Any]
|
||||
|
||||
# NOTE: the following is the similar interface to HuggingFace pretrained protocol.
|
||||
def __init_subclass__(cls, **kwargs: t.Any):
    """Configure a subclass at class-creation time.

    The hook:

    * derives ``__llm_implementation__`` (``flax``/``tf``/``pt``) from the class-name prefix,
    * infers the config class via ``openllm.AutoConfig.infer_class_from_name``,
    * requires ``config_class`` in the class body unless ``__openllm_internal__`` is set,
    * wraps a ``llm_post_init`` defined in the class body so the default
      post-init runs first, or installs the default when none is defined,
    * installs the default ``import_model`` bound to the detected framework,
    * records the ``__openllm_*__`` hook markers and resets the ``__llm_*__`` caches.

    Args:
        **kwargs: Class keyword arguments, forwarded to the next
            ``__init_subclass__`` in the MRO.

    Raises:
        RuntimeError: ``config_class`` is missing from a class that does not
            declare ``__openllm_internal__``.
    """
    # Cooperate with the other bases' ``__init_subclass__`` implementations.
    super().__init_subclass__(**kwargs)

    cd = cls.__dict__

    prefix_class_name_config = cls.__name__
    if prefix_class_name_config.startswith("Flax"):
        implementation = "flax"
        prefix_class_name_config = prefix_class_name_config[4:]
    elif prefix_class_name_config.startswith("TF"):
        implementation = "tf"
        prefix_class_name_config = prefix_class_name_config[2:]
    else:
        implementation = "pt"
    cls.__llm_implementation__ = implementation
    config_class = openllm.AutoConfig.infer_class_from_name(prefix_class_name_config)

    if "config_class" in cd:
        if "__openllm_internal__" in cd:
            logger.debug("Using config class %s for %s.", cd["config_class"], cls.__name__)
    elif "__openllm_internal__" in cd:
        # Internal classes get the inferred config class automatically.
        cls.config_class = config_class
    else:
        raise RuntimeError("Missing required key 'config_class'. Make sure to define it within the LLM subclass.")

    if "llm_post_init" in cd and cls.llm_post_init is not LLMInterface.llm_post_init:
        original_llm_post_init = cd["llm_post_init"]

        def wrapped_llm_post_init(self: t.Self) -> None:
            """Run the default post-init, then the ``llm_post_init`` defined by the class."""
            _default_post_init(self)
            original_llm_post_init(self)

        cls.llm_post_init = wrapped_llm_post_init
    elif cls.llm_post_init is LLMInterface.llm_post_init:
        cls.llm_post_init = _default_post_init
    # Otherwise the hook is inherited from an intermediate subclass that this
    # method already configured (wrapped or default), so it is kept as is.
    # Indexing ``cd`` unconditionally here raised KeyError for such classes.

    if cls.import_model is LLMInterface.import_model:
        # using the default import model
        cls.import_model = functools.partial(import_model, _model_framework=implementation)
    else:
        # NOTE(review): a class inheriting the partial installed on a parent
        # also lands here and keeps the parent's framework binding.
        logger.debug("Using custom 'import_model' for %s", cls.__name__)

    cls.__openllm_post_init__ = None if cls.llm_post_init is LLMInterface.llm_post_init else cls.llm_post_init
    cls.__openllm_custom_load__ = None if cls.load_model is LLMInterface.load_model else cls.load_model
    cls.__openllm_custom_import_kwargs__ = (
        None if cls.import_kwargs is LLMInterface.import_kwargs else cls.import_kwargs
    )

    # Reset the per-class caches.
    for at in ("bentomodel", "tag", "model", "tokenizer"):
        setattr(cls, f"__llm_{at}__", None)
|
||||
|
||||
# The following is the similar interface to HuggingFace pretrained protocol.
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, model_id: str | None = None, llm_config: openllm.LLMConfig | None = None, *args: t.Any, **attrs: t.Any
|
||||
cls,
|
||||
model_id: str | None = None,
|
||||
llm_config: openllm.LLMConfig | None = None,
|
||||
*args: t.Any,
|
||||
**attrs: t.Any,
|
||||
) -> LLM[_M, _T]:
|
||||
return cls(model_id=model_id, llm_config=llm_config, *args, **attrs)
|
||||
|
||||
@@ -474,6 +433,7 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
"""
|
||||
|
||||
load_in_mha = attrs.pop("load_in_mha", False)
|
||||
openllm_model_version = attrs.pop("openllm_model_version", None)
|
||||
|
||||
if llm_config is not None:
|
||||
logger.debug("Using given 'llm_config=(%s)' to initialize LLM", llm_config)
|
||||
@@ -490,17 +450,37 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
# NOTE: This is the actual given path or pretrained weight for this LLM.
|
||||
self._model_id = model_id
|
||||
|
||||
# parsing tokenizer and model kwargs
|
||||
tokenizer_kwds = {k[len(TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(TOKENIZER_PREFIX)}
|
||||
for k in attrs:
|
||||
if k in tokenizer_kwds:
|
||||
del attrs[k]
|
||||
# the rest of kwargs should be model kwargs
|
||||
model_kwds = attrs
|
||||
|
||||
if self.import_kwargs:
|
||||
_custom_model_kwds, _custom_tokenizer_kwds = self.import_kwargs
|
||||
model_kwds.update(_custom_model_kwds)
|
||||
tokenizer_kwds.update(_custom_tokenizer_kwds)
|
||||
|
||||
# handle trust_remote_code
|
||||
self.__llm_trust_remote_code__ = model_kwds.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
|
||||
|
||||
# NOTE: Save the args and kwargs for latter load
|
||||
self._llm_args = args
|
||||
self._llm_attrs = attrs
|
||||
self._model_args = args
|
||||
self._model_attrs = model_kwds
|
||||
self._tokenizer_attrs = tokenizer_kwds
|
||||
|
||||
if self.__openllm_post_init__:
|
||||
self.llm_post_init()
|
||||
|
||||
# finally: we allow users to overwrite the load_in_mha defined by the LLM subclass.
|
||||
if load_in_mha:
|
||||
logger.debug("Overwriting 'load_in_mha=%s' (base load_in_mha=%s)", load_in_mha, self.load_in_mha)
|
||||
self.load_in_mha = load_in_mha
|
||||
|
||||
self._openllm_model_version = openllm_model_version
|
||||
|
||||
def __setattr__(self, attr: str, value: t.Any):
|
||||
if attr in _reserved_namespace:
|
||||
raise ForbiddenAttributeError(
|
||||
@@ -511,10 +491,18 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
def __repr__(self) -> str:
    """Return ``ClassName(model_id=..., runner_name=..., llm_type=..., config=...)``.

    Each value is rendered with ``repr()``.
    """
    # A tuple, not a set: string hashing is randomised per process, so
    # iterating a set would shuffle the field order between runs.
    fields = ("model_id", "runner_name", "llm_type", "config")
    rendered = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
    return f"{self.__class__.__name__}({rendered})"
|
||||
|
||||
@property
def model_id(self) -> str:
    """Read-only view of the value stored in ``self._model_id``."""
    stored_model_id = self._model_id
    return stored_model_id
|
||||
|
||||
@property
def runner_name(self) -> str:
    """Runner name built as ``llm-<start name>-runner`` from the config's ``__openllm_start_name__``."""
    start_name = self.config.__openllm_start_name__
    return "llm-{}-runner".format(start_name)
|
||||
|
||||
# NOTE: The section below defines a loose contract with langchain's LLM interface.
|
||||
@property
|
||||
def llm_type(self) -> str:
|
||||
@@ -527,42 +515,14 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
"model_ids": orjson.dumps(self.config.__openllm_model_ids__).decode(),
|
||||
}
|
||||
|
||||
@t.overload
|
||||
def make_tag(
|
||||
self,
|
||||
model_id: str | None = None,
|
||||
return_unused_kwargs: t.Literal[False] = ...,
|
||||
trust_remote_code: bool = ...,
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Tag:
|
||||
...
|
||||
|
||||
@t.overload
|
||||
def make_tag(
|
||||
self,
|
||||
model_id: str | None = None,
|
||||
return_unused_kwargs: t.Literal[True] = ...,
|
||||
trust_remote_code: bool = ...,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[bentoml.Tag, dict[str, t.Any]]:
|
||||
...
|
||||
|
||||
def make_tag(
|
||||
self,
|
||||
model_id: str | None = None,
|
||||
return_unused_kwargs: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Tag | tuple[bentoml.Tag, dict[str, t.Any]]:
|
||||
def make_tag(self, model_id: str | None = None) -> bentoml.Tag:
|
||||
"""Generate a ``bentoml.Tag`` from a given transformers model name.
|
||||
|
||||
Note that this depends on your model to have a config class available.
|
||||
|
||||
Args:
|
||||
model_name_or_path: The transformers model name or path to load the model from.
|
||||
If it is a path, then `openllm_model_version` must be passed in as a kwarg.
|
||||
return_unused_kwargs: Whether to return unused kwargs. Defaults to False. If set, it will return a tuple
|
||||
of ``bentoml.Tag`` and a dict of unused kwargs.
|
||||
model_id: The transformers model name or path to load the model from.
|
||||
If it is a path, then `openllm_model_version` must be passed in as a kwarg.
|
||||
trust_remote_code: Whether to trust the remote code. Defaults to False.
|
||||
**attrs: Additional kwargs to pass to the ``transformers.AutoConfig`` constructor.
|
||||
If your pass ``return_unused_kwargs=True``, it will be ignored.
|
||||
@@ -572,88 +532,43 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
"""
|
||||
if model_id is None:
|
||||
model_id = self._model_id
|
||||
|
||||
if "return_unused_kwargs" in attrs:
|
||||
logger.debug("Ignoring 'return_unused_kwargs' in 'generate_tag_from_model_name'.")
|
||||
attrs.pop("return_unused_kwargs", None)
|
||||
|
||||
config, attrs = t.cast(
|
||||
"tuple[transformers.PretrainedConfig, dict[str, t.Any]]",
|
||||
transformers.AutoConfig.from_pretrained(
|
||||
model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **attrs
|
||||
),
|
||||
)
|
||||
name = convert_transformers_model_name(model_id)
|
||||
|
||||
if os.path.exists(os.path.dirname(model_id)):
|
||||
# If the model_name_or_path is a path, we assume it's a local path,
|
||||
# then users must pass a version for this.
|
||||
model_version = attrs.pop("openllm_model_version", None)
|
||||
if model_version is None:
|
||||
config = t.cast(
|
||||
"transformers.PretrainedConfig",
|
||||
transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=self.__llm_trust_remote_code__),
|
||||
)
|
||||
|
||||
model_version = getattr(config, "_commit_hash", None)
|
||||
if model_version is None:
|
||||
if self._openllm_model_version is not None:
|
||||
logger.warning(
|
||||
"""\
|
||||
When passing a path, it is recommended to also pass 'openllm_model_version'
|
||||
into Runner/AutoLLM intialization.
|
||||
|
||||
For example:
|
||||
|
||||
>>> import openllm
|
||||
>>> runner = openllm.Runner('/path/to/fine-tuning/model', openllm_model_version='lora-version')
|
||||
|
||||
Example with AutoLLM:
|
||||
|
||||
>>> import openllm
|
||||
>>> model = openllm.AutoLLM.for_model('/path/to/fine-tuning/model', openllm_model_version='lora-version')
|
||||
|
||||
No worries, OpenLLM will generate one for you. But for your own convenience, make sure to
|
||||
specify 'openllm_model_version'.
|
||||
"""
|
||||
)
|
||||
model_version = bentoml.Tag.from_taglike(name).make_new_version().version
|
||||
else:
|
||||
model_version = getattr(config, "_commit_hash", None)
|
||||
if model_version is None:
|
||||
logger.warning(
|
||||
"Given %s from '%s' doesn't contain a commit hash. We will generate"
|
||||
" the tag without specific version.",
|
||||
t.cast("type[transformers.PretrainedConfig]", config.__class__),
|
||||
"Given %s from '%s' doesn't contain a commit hash, and 'openllm_model_version' is not given. "
|
||||
"We will generate the tag without specific version.",
|
||||
config.__class__,
|
||||
model_id,
|
||||
)
|
||||
tag = bentoml.Tag.from_taglike(f"{self.__llm_implementation__}-{name}:{model_version}")
|
||||
model_version = bentoml.Tag.from_taglike(name).make_new_version().version
|
||||
else:
|
||||
logger.debug("Using %s for '%s' as model version", self._openllm_model_version, model_id)
|
||||
model_version = self._openllm_model_version
|
||||
|
||||
if return_unused_kwargs:
|
||||
return tag, attrs
|
||||
return tag
|
||||
return bentoml.Tag.from_taglike(f"{self.__llm_implementation__}-{name}:{model_version}")
|
||||
|
||||
def ensure_model_id_exists(self) -> bentoml.Model:
|
||||
trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
|
||||
tag, kwds = self.make_tag(return_unused_kwargs=True, trust_remote_code=trust_remote_code, **self._llm_attrs)
|
||||
try:
|
||||
return bentoml.transformers.get(tag)
|
||||
return bentoml.transformers.get(self.tag)
|
||||
except bentoml.exceptions.NotFound:
|
||||
logger.info("'%s' with tag (%s) not found, importing from HuggingFace Hub.", self.__class__.__name__, tag)
|
||||
tokenizer_kwds = {k[len("_tokenizer_") :]: v for k, v in kwds.items() if k.startswith("_tokenizer_")}
|
||||
kwds = {k: v for k, v in kwds.items() if not k.startswith("_tokenizer_")}
|
||||
|
||||
if self.import_kwargs:
|
||||
tokenizer_kwds = {
|
||||
**{
|
||||
k[len("_tokenizer_") :]: v for k, v in self.import_kwargs.items() if k.startswith("_tokenizer_")
|
||||
},
|
||||
**tokenizer_kwds,
|
||||
}
|
||||
kwds = {
|
||||
**{k: v for k, v in self.import_kwargs.items() if not k.startswith("_tokenizer_")},
|
||||
**kwds,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"'%s' with tag (%s) not found, importing from HuggingFace Hub.", self.__class__.__name__, self.tag
|
||||
)
|
||||
return self.import_model(
|
||||
self._model_id,
|
||||
tag,
|
||||
*self._llm_args,
|
||||
tokenizer_kwds=tokenizer_kwds,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwds,
|
||||
self.tag,
|
||||
*self._model_args,
|
||||
tokenizer_kwds=self._tokenizer_attrs,
|
||||
trust_remote_code=self.__llm_trust_remote_code__,
|
||||
**self._model_attrs,
|
||||
)
|
||||
finally:
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available():
|
||||
@@ -667,31 +582,30 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
|
||||
@property
|
||||
def tag(self) -> bentoml.Tag:
|
||||
return self._bentomodel.tag
|
||||
if self.__llm_tag__ is None:
|
||||
self.__llm_tag__ = self.make_tag()
|
||||
return self.__llm_tag__
|
||||
|
||||
@property
|
||||
def model(self) -> _M:
|
||||
"""The model to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it."""
|
||||
# Run check for GPU
|
||||
trust_remote_code = self._llm_attrs.pop("trust_remote_code", self.config.__openllm_trust_remote_code__)
|
||||
self.config.check_if_gpu_is_available(implementation=self.__llm_implementation__)
|
||||
|
||||
kwds = {k: v for k, v in self._llm_attrs.items() if not k.startswith("_tokenizer_")}
|
||||
kwds = self._model_attrs
|
||||
kwds["trust_remote_code"] = self.__llm_trust_remote_code__
|
||||
|
||||
if self.import_kwargs:
|
||||
kwds = {**{k: v for k, v in self.import_kwargs.items() if not k.startswith("_tokenizer_")}, **kwds}
|
||||
|
||||
kwds["trust_remote_code"] = trust_remote_code
|
||||
if self.load_in_mha and "_pretrained_class" not in self._bentomodel.info.metadata:
|
||||
# This is a pipeline, provide a accelerator args
|
||||
kwds["accelerator"] = "bettertransformer"
|
||||
|
||||
if self.__llm_model__ is None:
|
||||
if self.__openllm_custom_load__:
|
||||
self.__llm_model__ = self.load_model(self.tag, *self._llm_args, **kwds)
|
||||
self.__llm_model__ = self.load_model(self.tag, *self._model_args, **kwds)
|
||||
else:
|
||||
self.__llm_model__ = self._bentomodel.load_model(*self._llm_args, **kwds)
|
||||
self.__llm_model__ = self._bentomodel.load_model(*self._model_args, **kwds)
|
||||
|
||||
# TODO: self.config.__openllm_runtime__: t.Literal['python', 'cpp']
|
||||
if (
|
||||
self.load_in_mha
|
||||
and all(i in self._bentomodel.info.metadata for i in ("_framework", "_pretrained_class"))
|
||||
@@ -743,11 +657,7 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
- 'name': will be generated by OpenLLM, hence users shouldn't worry about this.
|
||||
The generated name will be 'llm-<model-start-name>-runner' (ex: llm-dolly-v2-runner, llm-chatglm-runner)
|
||||
"""
|
||||
|
||||
name = f"llm-{self.config.__openllm_start_name__}-runner"
|
||||
models = models if models is not None else []
|
||||
|
||||
# NOTE: The side effect of this is that will load the imported model during runner creation.
|
||||
models.append(self._bentomodel)
|
||||
|
||||
if scheduling_strategy is None:
|
||||
@@ -776,6 +686,12 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
cls.llm_type = llm_type
|
||||
cls.identifying_params = identifying_params
|
||||
|
||||
def __init__(__self: _Runnable) -> None:
    # The instance being initialised is bound to ``__self``; the bare ``self`` used below is
    # therefore not this method's parameter and must come from an enclosing scope
    # (not visible in this excerpt -- presumably the surrounding LLM instance; confirm).
    # NOTE: The side effect of this line
    # is that it will load the imported model during
    # runner creation. So don't remove it!!
    self.model
|
||||
|
||||
@bentoml.Runnable.method(
|
||||
batchable=generate_sig.batchable,
|
||||
batch_dim=generate_sig.batch_dim,
|
||||
@@ -803,7 +719,7 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
def generate_iterator(__self, prompt: str, **attrs: t.Any) -> t.Iterator[t.Any]:
|
||||
yield self.generate_iterator(prompt, **attrs)
|
||||
|
||||
def _wrapped_generate_run(__self: bentoml.Runner, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
def _wrapped_generate_run(__self: LLMRunner, prompt: str, **kwargs: t.Any) -> t.Any:
|
||||
"""Wrapper for runner.generate.run() to handle the prompt and postprocessing.
|
||||
|
||||
This will be used for LangChain API.
|
||||
@@ -814,10 +730,6 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
runner("What is the meaning of life?")
|
||||
```
|
||||
"""
|
||||
if len(args) > 1:
|
||||
raise RuntimeError("Only one positional argument is allowed for generate()")
|
||||
prompt = args[0] if len(args) == 1 else kwargs.pop("prompt", "")
|
||||
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **kwargs)
|
||||
generated_result = __self.generate.run(prompt, **generate_kwargs)
|
||||
return self.postprocess_generate(prompt, generated_result, **postprocess_kwargs)
|
||||
@@ -849,7 +761,7 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
"identifying_params": self.identifying_params,
|
||||
},
|
||||
),
|
||||
name=name,
|
||||
name=self.runner_name,
|
||||
models=models,
|
||||
max_batch_size=max_batch_size,
|
||||
max_latency_ms=max_latency_ms,
|
||||
@@ -868,6 +780,13 @@ class LLM(LLMInterface, t.Generic[_M, _T], metaclass=LLMMetaclass):
|
||||
First, it runs `self.sanitize_parameters` to sanitize the parameters.
|
||||
Then the sanitized prompt and kwargs will be passed into self.generate.
|
||||
Finally, run self.postprocess_generate to postprocess the generated result.
|
||||
|
||||
This allows users to do the following:
|
||||
|
||||
```python
|
||||
llm = openllm.AutoLLM.for_model("dolly-v2")
|
||||
llm("What is the meaning of life?")
|
||||
```
|
||||
"""
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs)
|
||||
generated_result = self.generate(prompt, **generate_kwargs)
|
||||
|
||||
@@ -884,11 +884,9 @@ def cli_factory() -> click.Group:
|
||||
else:
|
||||
model = openllm.AutoLLM.for_model(model_name, model_id=model_id, llm_config=config)
|
||||
|
||||
tag = model.make_tag(trust_remote_code=config.__openllm_trust_remote_code__)
|
||||
|
||||
if len(bentoml.models.list(tag)) == 0:
|
||||
if len(bentoml.models.list(model.tag)) == 0:
|
||||
if output == "pretty":
|
||||
_echo(f"{tag} does not exists yet!. Downloading...", fg="yellow", nl=True)
|
||||
_echo(f"{model.tag} does not exists yet!. Downloading...", fg="yellow", nl=True)
|
||||
m = model.ensure_model_id_exists()
|
||||
if output == "pretty":
|
||||
_echo(f"Saved model: {m.tag}")
|
||||
@@ -899,9 +897,9 @@ def cli_factory() -> click.Group:
|
||||
).decode()
|
||||
)
|
||||
else:
|
||||
_echo(tag)
|
||||
_echo(model.tag)
|
||||
else:
|
||||
m = bentoml.transformers.get(tag)
|
||||
m = bentoml.transformers.get(model.tag)
|
||||
if output == "pretty":
|
||||
_echo(f"{model_name} is already setup for framework '{env}': {str(m.tag)}", nl=True, fg="yellow")
|
||||
elif output == "json":
|
||||
|
||||
@@ -37,11 +37,12 @@ class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedToken
|
||||
|
||||
@property
|
||||
def import_kwargs(self):
|
||||
return {
|
||||
model_kwds = {
|
||||
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"_tokenizer_padding_side": "left",
|
||||
}
|
||||
tokenizer_kwds = {"padding_side": "left"}
|
||||
return model_kwds, tokenizer_kwds
|
||||
|
||||
def llm_post_init(self):
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@@ -36,10 +36,12 @@ class Falcon(openllm.LLM["transformers.TextGenerationPipeline", "transformers.Pr
|
||||
|
||||
@property
|
||||
def import_kwargs(self):
|
||||
return {
|
||||
model_kwds = {
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
|
||||
}
|
||||
tokenizer_kwds: dict[str, t.Any] = {}
|
||||
return model_kwds, tokenizer_kwds
|
||||
|
||||
def import_model(
|
||||
self, model_id: str, tag: bentoml.Tag, *model_args: t.Any, tokenizer_kwds: dict[str, t.Any], **attrs: t.Any
|
||||
|
||||
@@ -51,11 +51,13 @@ class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTN
|
||||
|
||||
@property
|
||||
def import_kwargs(self):
|
||||
return {
|
||||
model_kwds = {
|
||||
"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||
"load_in_8bit": False,
|
||||
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
|
||||
}
|
||||
tokenizer_kwds: dict[str, t.Any] = {}
|
||||
return model_kwds, tokenizer_kwds
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
|
||||
@@ -45,12 +45,13 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
|
||||
|
||||
@property
|
||||
def import_kwargs(self):
|
||||
return {
|
||||
"_tokenizer_padding_side": "left",
|
||||
model_kwds = {
|
||||
"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
|
||||
"load_in_8bit": True if torch.cuda.device_count() > 1 else False,
|
||||
"torch_dtype": torch.float16,
|
||||
}
|
||||
tokenizer_kwds = {"padding_side": "left"}
|
||||
return model_kwds, tokenizer_kwds
|
||||
|
||||
def import_model(
|
||||
self,
|
||||
|
||||
@@ -33,10 +33,21 @@ if t.TYPE_CHECKING:
|
||||
from black.lines import LinesBlock
|
||||
from fs.base import FS
|
||||
|
||||
DictStrAny = dict[str, t.Any]
|
||||
|
||||
class ModifyNodeProtocol(t.Protocol):
    """Structural type for callables invoked as ``fn(node, model_name, formatter)`` that return ``None``."""

    def __call__(self, node: Node, model_name: str, formatter: type[ModelNameFormatter]) -> None:
        ...
|
||||
|
||||
from attr import _make_method
|
||||
else:
|
||||
# NOTE: Using internal API from attr here, since we are actually
|
||||
# allowing subclass of openllm.LLMConfig to become 'attrs'-ish
|
||||
from attr._make import _make_method
|
||||
|
||||
DictStrAny = dict
|
||||
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -165,3 +176,113 @@ def _parse_service_file(src_contents: str, model_name: str, model_id: str, mode:
|
||||
return newline
|
||||
return ""
|
||||
return "".join(dst_contents)
|
||||
|
||||
|
||||
# NOTE: The following is extracted from attrs internal APIs
|
||||
|
||||
# Sentinel object for unequivocal getattr(): lets "attribute is missing" be told
# apart from every real attribute value, including ``None``.
_sentinel = object()


def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
    """Check whether *cls* defines *attrib_name* (and doesn't just inherit it).

    Args:
        cls: The class to inspect.
        attrib_name: Name of the attribute to look up.

    Returns:
        ``False`` when the attribute cannot be found on *cls*, or when the object
        found on *cls* is the very same object exposed by one of its bases;
        ``True`` otherwise.
    """
    attr = getattr(cls, attrib_name, _sentinel)
    if attr is _sentinel:
        return False

    # Bases are probed with the sentinel as fallback too: a ``None`` fallback would
    # make an attribute whose own value is ``None`` look inherited from every base
    # that does not define it at all.
    return all(getattr(base_cls, attrib_name, _sentinel) is not attr for base_cls in cls.__mro__[1:])
|
||||
|
||||
|
||||
def get_annotations(cls: type[t.Any]) -> DictStrAny:
    """Return the annotations declared directly on *cls*.

    When *cls* has no ``__annotations__`` of its own (missing, or merely inherited),
    a fresh empty ``DictStrAny`` is returned instead.
    """
    if not has_own_attribute(cls, "__annotations__"):
        return DictStrAny()
    return cls.__annotations__
|
||||
|
||||
|
||||
_classvar_prefixes = (
|
||||
"typing.ClassVar",
|
||||
"t.ClassVar",
|
||||
"ClassVar",
|
||||
"typing_extensions.ClassVar",
|
||||
)
|
||||
|
||||
|
||||
def is_class_var(annot: str | t.Any) -> bool:
|
||||
"""
|
||||
Check whether *annot* is a typing.ClassVar.
|
||||
|
||||
The string comparison hack is used to avoid evaluating all string
|
||||
annotations which would put attrs-based classes at a performance
|
||||
disadvantage compared to plain old classes.
|
||||
"""
|
||||
annot = str(annot)
|
||||
|
||||
# Annotation can be quoted.
|
||||
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')):
|
||||
annot = annot[1:-1]
|
||||
|
||||
return annot.startswith(_classvar_prefixes)
|
||||
|
||||
|
||||
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
    """Stamp *method_or_cls* with ``__module__``, ``__qualname__`` and ``__doc__`` derived from *cls*.

    Every attribute is set on a best-effort basis: an ``AttributeError`` raised while
    computing or assigning one of them is ignored and the remaining ones are still attempted.

    Args:
        cls: Class whose ``__module__`` and ``__qualname__`` are used.
        method_or_cls: Object to stamp; it is modified in place.
        _overwrite_doc: Docstring used instead of the generic generated one when truthy.

    Returns:
        The same *method_or_cls* object.
    """
    # Values are wrapped in callables so that a failure while computing one of them
    # is caught by the same handler as a failure while assigning it.
    dunder_sources = (
        ("__module__", lambda: cls.__module__),
        ("__qualname__", lambda: ".".join((cls.__qualname__, method_or_cls.__name__))),
        (
            "__doc__",
            lambda: _overwrite_doc or "Method or class generated by LLMConfig for class " f"{cls.__qualname__}.",
        ),
    )
    for dunder, compute in dunder_sources:
        try:
            setattr(method_or_cls, dunder, compute())
        except AttributeError:
            pass

    return method_or_cls
|
||||
|
||||
|
||||
def generate_unique_filename(cls: type[t.Any], func_name: str):
    """Build the pseudo-filename label for the generated function *func_name* of *cls*.

    The label embeds the class name, the function name, the class module and the class
    qualified name (the plain name is used when ``__qualname__`` is absent).
    """
    qualified = getattr(cls, "__qualname__", cls.__name__)
    return f"<{cls.__name__} generated {func_name} {cls.__module__}.{qualified}>"
|
||||
|
||||
|
||||
def generate_function(
    typ: type[t.Any],
    func_name: str,
    lines: list[str] | None,
    args: tuple[str, ...] | None,
    globs: dict[str, t.Any],
    annotations: dict[str, t.Any] | None = None,
):
    """Assemble the source of a function named *func_name* and build it via attrs' ``_make_method``.

    Args:
        typ: Class the function is generated for; used for the pseudo-filename and the debug log.
        func_name: Name of the generated function.
        lines: Body statements, one per entry; a lone ``pass`` is used when empty or ``None``.
        args: Parameter names for the signature; no parameters when ``None``.
        globs: Passed through to ``_make_method`` (presumably the globals the generated
            code runs with -- confirm against attrs).
        annotations: When truthy, assigned to the result's ``__annotations__``.

    Returns:
        Whatever ``_make_method`` returns for the assembled script.
    """
    from . import DEBUG

    signature = "" if args is None else ", ".join(args)
    body = "\n ".join(lines) if lines else "pass"
    script = "def %s(%s):\n %s\n" % (func_name, signature, body)

    filename = generate_unique_filename(typ, func_name)
    meth = _make_method(func_name, script, filename, globs)
    if annotations:
        meth.__annotations__ = annotations

    # The generated source is only logged when the package runs in debug mode.
    if DEBUG:
        logger.info("Generated script for %s:\n\n%s", typ, script)

    return meth
|
||||
|
||||
Reference in New Issue
Block a user