From 873edc4121481cc3733dcaa8a675f8fac55b7186 Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Fri, 2 Jun 2023 01:01:00 -0700 Subject: [PATCH] chore(types): improve annotaiton for specified CLI improve: Faster CLI improvement, cached hints to __openllm_hints__ Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- pyproject.toml | 1 + src/openllm/_configuration.py | 62 ++++++++++++++++------ src/openllm/_types.py | 20 +++++++ typings/click_option_group/_decorators.pyi | 14 ++++- 4 files changed, 78 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e56f9cf6..6c38562c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -187,6 +187,7 @@ ban-relative-imports = "all" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "__init__.py" = ["E402", "F401", "F403", "F811"] +"src/openllm/_types.py" = ["E402"] "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.pyright] diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index ed1527e7..ba492b3b 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -44,7 +44,6 @@ import typing as t from operator import itemgetter import attr -import click import inflection import orjson from bentoml._internal.models.model import ModelSignature @@ -63,9 +62,7 @@ if t.TYPE_CHECKING: from attr import _CountingAttr, _make_init from transformers.generation.beam_constraints import Constraint - P = t.ParamSpec("P") - - F = t.Callable[P, t.Any] + from ._types import ClickFunctionWrapper, F, O_co, P ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]] @@ -85,17 +82,38 @@ else: __all__ = ["LLMConfig", "ModelSignature"] -_object_setattr = object.__setattr__ - logger = logging.getLogger(__name__) +@t.overload +def attrs_to_options( + name: str, + field: attr.Attribute[t.Any], + model_name: str, + typ: type[t.Any] | None = None, + suffix_generation: bool = False, +) -> F[..., F[..., openllm.LLMConfig]]: + ... + + +@t.overload +def attrs_to_options( # type: ignore (overlapping overload) + name: str, + field: attr.Attribute[O_co], + model_name: str, + typ: type[t.Any] | None = None, + suffix_generation: bool = False, +) -> F[..., F[P, O_co]]: + ... + + def attrs_to_options( name: str, field: attr.Attribute[t.Any], model_name: str, + typ: type[t.Any] | None = None, suffix_generation: bool = False, -) -> t.Callable[[F[P]], F[P]]: +) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]: # TODO: support parsing nested attrs class envvar = field.metadata["env"] dasherized = inflection.dasherize(name) @@ -112,7 +130,7 @@ def attrs_to_options( return optgroup.option( identifier, full_option_name, - type=field.type, + type=typ if typ is not None else field.type, required=field.default is attr.NOTHING, default=field.default if field.default not in (attr.NOTHING, None) else None, show_default=True, @@ -547,17 +565,15 @@ def _make_attr_tuple_class(cls_name: str, attr_names: t.Iterable[str]) -> type[t def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConfig]: _has_gen_class = _has_own_attribute(cls, "GenerationConfig") - hints = t.get_type_hints(GenerationConfig) - def _evolve_with_base_default( - __cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]] + _: type[GenerationConfig], fields: list[attr.Attribute[t.Any]] ) -> list[attr.Attribute[t.Any]]: transformed: list[attr.Attribute[t.Any]] = [] for f in fields: env = f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{f.name.upper()}" _from_env = _populate_value_from_env_var(env, fallback=f.default) default_value = f.default if not _has_gen_class else getattr(cls.GenerationConfig, f.name, _from_env) - transformed.append(f.evolve(default=default_value, metadata={"env": env}, type=hints[f.name])) + transformed.append(f.evolve(default=default_value, metadata={"env": env})) return transformed generated_cls = attr.make_class( @@ -598,6 +614,10 @@ class LLMConfig: __openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize" __openllm_env__: openllm.utils.ModelEnv = Field(None, init=False) + """A ModelEnv instance for this LLMConfig.""" + + __openllm_hints__: dict[str, t.Any] = Field(None, init=False) + """An internal cache of resolved types for this LLMConfig.""" generation_class: type[GenerationConfig] = Field(None, init=False) @@ -728,6 +748,9 @@ class LLMConfig: # NOTE: Finally, set the generation_class for this given config. cls.generation_class = _make_internal_generation_class(cls) + hints.update(t.get_type_hints(cls.generation_class)) + cls.__openllm_hints__ = hints + @property def name_type(self) -> t.Literal["dasherize", "lowercase"]: return self.__openllm_name_type__ @@ -846,10 +869,10 @@ class LLMConfig: ... def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | dict[str, t.Any]: - config = transformers.GenerationConfig(**self.generation_config.model_dump()) + config = transformers.GenerationConfig(**bentoml_cattr.unstructure(self.generation_config)) return config.to_dict() if return_as_dict else config - def to_click_options(self, f: F[P]) -> t.Callable[[F[P]], click.Command]: + def to_click_options(self, f: F[P, openllm.LLMConfig]) -> F[P, ClickFunctionWrapper[P, openllm.LLMConfig]]: """ Convert current model to click options. This can be used as a decorator for click commands. Note that the identifier for all LLMConfig will be prefixed with '_*', and the generation config @@ -859,16 +882,21 @@ class LLMConfig: if t.get_origin(field.type) is t.Union: # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. continue - f = attrs_to_options(name, field, self.__openllm_model_name__, suffix_generation=True)(f) + f = attrs_to_options( + name, field, self.__openllm_model_name__, typ=self.__openllm_hints__[name], suffix_generation=True + )(f) f = optgroup.group(f"{self.generation_class.__name__} generation options")(f) if len(self.__class__.__openllm_attrs__) == 0: - return f + # NOTE: in this case, the function is already a ClickFunctionWrapper + # hence the casting + return t.cast("F[P, ClickFunctionWrapper[P, openllm.LLMConfig]]", f) + for name, field in attr.fields_dict(self.__class__).items(): if t.get_origin(field.type) is t.Union: # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. continue - f = attrs_to_options(name, field, self.__openllm_model_name__)(f) + f = attrs_to_options(name, field, self.__openllm_model_name__, typ=self.__openllm_hints__[name])(f) return optgroup.group(f"{self.__class__.__name__} options")(f) diff --git a/src/openllm/_types.py b/src/openllm/_types.py index 1120a5bb..cb2d938b 100644 --- a/src/openllm/_types.py +++ b/src/openllm/_types.py @@ -23,6 +23,11 @@ import typing as t if not t.TYPE_CHECKING: raise RuntimeError(f"{__name__} should not be imported during runtime") +import click + +P = t.ParamSpec("P") +O_co = t.TypeVar("O_co", covariant=True) + import transformers from bentoml._internal.models.model import ModelSignatureDict as ModelSignatureDict from bentoml._internal.models.model import ModelSignaturesType as ModelSignaturesType @@ -31,3 +36,18 @@ LLMModel = transformers.PreTrainedModel | transformers.TFPreTrainedModel | trans LLMTokenizer = ( transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast | transformers.PreTrainedTokenizerBase ) + + +class ClickFunctionWrapper(t.Protocol[P, O_co]): + __click_params__: list[click.Option] + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co: + ... + + +# F is a t.Callable[P, O_co] with compatible to ClickFunctionWrapper +class F(t.Generic[P, O_co]): + __click_params__: list[click.Option] + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co: + ... diff --git a/typings/click_option_group/_decorators.pyi b/typings/click_option_group/_decorators.pyi index dbeef750..48196eb9 100644 --- a/typings/click_option_group/_decorators.pyi +++ b/typings/click_option_group/_decorators.pyi @@ -1,15 +1,25 @@ -from typing import Any, Dict, NamedTuple, Optional, Tuple, Type +from typing import Any, Callable, Dict, NamedTuple, Optional, ParamSpec, Protocol, Tuple, Type, TypeVar import click from ._core import FC, OptionGroup +P = ParamSpec("P") +O_co = TypeVar("O_co", covariant=True) + +F = Callable[P, O_co] + class OptionStackItem(NamedTuple): param_decls: Tuple[str, ...] attrs: Dict[str, Any] param_count: int ... +class ClickFunctionWrapper(Protocol[P, O_co]): + __click_params__: list[click.Option] + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co: ... + class _NotAttachedOption(click.Option): """The helper class to catch grouped options which were not attached to the group @@ -76,7 +86,7 @@ class _OptGroup: :param attrs: Additional parameters of option group class """ ... - def option(self, *param_decls: Any, **attrs: Any) -> FC: + def option(self, *param_decls: Any, **attrs: Any) -> F[P, ClickFunctionWrapper[P, O_co]]: """The decorator adds a new option to the group The decorator is lazy. It adds option decls and attrs.