chore(types): improve annotaiton for specified CLI

improve: Faster CLI improvement, cached hints to __openllm_hints__

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-06-02 01:01:00 -07:00
parent c3aeb43997
commit 873edc4121
4 changed files with 78 additions and 19 deletions

View File

@@ -187,6 +187,7 @@ ban-relative-imports = "all"
[tool.ruff.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"__init__.py" = ["E402", "F401", "F403", "F811"]
"src/openllm/_types.py" = ["E402"]
"tests/**/*" = ["PLR2004", "S101", "TID252"]
[tool.pyright]

View File

@@ -44,7 +44,6 @@ import typing as t
from operator import itemgetter
import attr
import click
import inflection
import orjson
from bentoml._internal.models.model import ModelSignature
@@ -63,9 +62,7 @@ if t.TYPE_CHECKING:
from attr import _CountingAttr, _make_init
from transformers.generation.beam_constraints import Constraint
P = t.ParamSpec("P")
F = t.Callable[P, t.Any]
from ._types import ClickFunctionWrapper, F, O_co, P
ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
@@ -85,17 +82,38 @@ else:
__all__ = ["LLMConfig", "ModelSignature"]
_object_setattr = object.__setattr__
logger = logging.getLogger(__name__)
@t.overload
def attrs_to_options(
name: str,
field: attr.Attribute[t.Any],
model_name: str,
typ: type[t.Any] | None = None,
suffix_generation: bool = False,
) -> F[..., F[..., openllm.LLMConfig]]:
...
@t.overload
def attrs_to_options( # type: ignore (overlapping overload)
name: str,
field: attr.Attribute[O_co],
model_name: str,
typ: type[t.Any] | None = None,
suffix_generation: bool = False,
) -> F[..., F[P, O_co]]:
...
def attrs_to_options(
name: str,
field: attr.Attribute[t.Any],
model_name: str,
typ: type[t.Any] | None = None,
suffix_generation: bool = False,
) -> t.Callable[[F[P]], F[P]]:
) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]:
# TODO: support parsing nested attrs class
envvar = field.metadata["env"]
dasherized = inflection.dasherize(name)
@@ -112,7 +130,7 @@ def attrs_to_options(
return optgroup.option(
identifier,
full_option_name,
type=field.type,
type=typ if typ is not None else field.type,
required=field.default is attr.NOTHING,
default=field.default if field.default not in (attr.NOTHING, None) else None,
show_default=True,
@@ -547,17 +565,15 @@ def _make_attr_tuple_class(cls_name: str, attr_names: t.Iterable[str]) -> type[t
def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConfig]:
_has_gen_class = _has_own_attribute(cls, "GenerationConfig")
hints = t.get_type_hints(GenerationConfig)
def _evolve_with_base_default(
__cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]
_: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]
) -> list[attr.Attribute[t.Any]]:
transformed: list[attr.Attribute[t.Any]] = []
for f in fields:
env = f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{f.name.upper()}"
_from_env = _populate_value_from_env_var(env, fallback=f.default)
default_value = f.default if not _has_gen_class else getattr(cls.GenerationConfig, f.name, _from_env)
transformed.append(f.evolve(default=default_value, metadata={"env": env}, type=hints[f.name]))
transformed.append(f.evolve(default=default_value, metadata={"env": env}))
return transformed
generated_cls = attr.make_class(
@@ -598,6 +614,10 @@ class LLMConfig:
__openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize"
__openllm_env__: openllm.utils.ModelEnv = Field(None, init=False)
"""A ModelEnv instance for this LLMConfig."""
__openllm_hints__: dict[str, t.Any] = Field(None, init=False)
"""An internal cache of resolved types for this LLMConfig."""
generation_class: type[GenerationConfig] = Field(None, init=False)
@@ -728,6 +748,9 @@ class LLMConfig:
# NOTE: Finally, set the generation_class for this given config.
cls.generation_class = _make_internal_generation_class(cls)
hints.update(t.get_type_hints(cls.generation_class))
cls.__openllm_hints__ = hints
@property
def name_type(self) -> t.Literal["dasherize", "lowercase"]:
return self.__openllm_name_type__
@@ -846,10 +869,10 @@ class LLMConfig:
...
def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | dict[str, t.Any]:
config = transformers.GenerationConfig(**self.generation_config.model_dump())
config = transformers.GenerationConfig(**bentoml_cattr.unstructure(self.generation_config))
return config.to_dict() if return_as_dict else config
def to_click_options(self, f: F[P]) -> t.Callable[[F[P]], click.Command]:
def to_click_options(self, f: F[P, openllm.LLMConfig]) -> F[P, ClickFunctionWrapper[P, openllm.LLMConfig]]:
"""
Convert current model to click options. This can be used as a decorator for click commands.
Note that the identifier for all LLMConfig will be prefixed with '<model_name>_*', and the generation config
@@ -859,16 +882,21 @@ class LLMConfig:
if t.get_origin(field.type) is t.Union:
# NOTE: Union type is currently not yet supported, we probably just need to use environment instead.
continue
f = attrs_to_options(name, field, self.__openllm_model_name__, suffix_generation=True)(f)
f = attrs_to_options(
name, field, self.__openllm_model_name__, typ=self.__openllm_hints__[name], suffix_generation=True
)(f)
f = optgroup.group(f"{self.generation_class.__name__} generation options")(f)
if len(self.__class__.__openllm_attrs__) == 0:
return f
# NOTE: in this case, the function is already a ClickFunctionWrapper
# hence the casting
return t.cast("F[P, ClickFunctionWrapper[P, openllm.LLMConfig]]", f)
for name, field in attr.fields_dict(self.__class__).items():
if t.get_origin(field.type) is t.Union:
# NOTE: Union type is currently not yet supported, we probably just need to use environment instead.
continue
f = attrs_to_options(name, field, self.__openllm_model_name__)(f)
f = attrs_to_options(name, field, self.__openllm_model_name__, typ=self.__openllm_hints__[name])(f)
return optgroup.group(f"{self.__class__.__name__} options")(f)

View File

@@ -23,6 +23,11 @@ import typing as t
if not t.TYPE_CHECKING:
raise RuntimeError(f"{__name__} should not be imported during runtime")
import click
P = t.ParamSpec("P")
O_co = t.TypeVar("O_co", covariant=True)
import transformers
from bentoml._internal.models.model import ModelSignatureDict as ModelSignatureDict
from bentoml._internal.models.model import ModelSignaturesType as ModelSignaturesType
@@ -31,3 +36,18 @@ LLMModel = transformers.PreTrainedModel | transformers.TFPreTrainedModel | trans
LLMTokenizer = (
transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast | transformers.PreTrainedTokenizerBase
)
class ClickFunctionWrapper(t.Protocol[P, O_co]):
__click_params__: list[click.Option]
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co:
...
# F is a t.Callable[P, O_co] with compatible to ClickFunctionWrapper
class F(t.Generic[P, O_co]):
__click_params__: list[click.Option]
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co:
...

View File

@@ -1,15 +1,25 @@
from typing import Any, Dict, NamedTuple, Optional, Tuple, Type
from typing import Any, Callable, Dict, NamedTuple, Optional, ParamSpec, Protocol, Tuple, Type, TypeVar
import click
from ._core import FC, OptionGroup
P = ParamSpec("P")
O_co = TypeVar("O_co", covariant=True)
F = Callable[P, O_co]
class OptionStackItem(NamedTuple):
param_decls: Tuple[str, ...]
attrs: Dict[str, Any]
param_count: int
...
class ClickFunctionWrapper(Protocol[P, O_co]):
__click_params__: list[click.Option]
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co: ...
class _NotAttachedOption(click.Option):
"""The helper class to catch grouped options which were not attached to the group
@@ -76,7 +86,7 @@ class _OptGroup:
:param attrs: Additional parameters of option group class
"""
...
def option(self, *param_decls: Any, **attrs: Any) -> FC:
def option(self, *param_decls: Any, **attrs: Any) -> F[P, ClickFunctionWrapper[P, O_co]]:
"""The decorator adds a new option to the group
The decorator is lazy. It adds option decls and attrs.