mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-05-04 22:02:45 -04:00
chore(types): improve annotaiton for specified CLI
improve: Faster CLI improvement, cached hints to __openllm_hints__ Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -187,6 +187,7 @@ ban-relative-imports = "all"
|
||||
[tool.ruff.per-file-ignores]
|
||||
# Tests can use magic values, assertions, and relative imports
|
||||
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
||||
"src/openllm/_types.py" = ["E402"]
|
||||
"tests/**/*" = ["PLR2004", "S101", "TID252"]
|
||||
|
||||
[tool.pyright]
|
||||
|
||||
@@ -44,7 +44,6 @@ import typing as t
|
||||
from operator import itemgetter
|
||||
|
||||
import attr
|
||||
import click
|
||||
import inflection
|
||||
import orjson
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
@@ -63,9 +62,7 @@ if t.TYPE_CHECKING:
|
||||
from attr import _CountingAttr, _make_init
|
||||
from transformers.generation.beam_constraints import Constraint
|
||||
|
||||
P = t.ParamSpec("P")
|
||||
|
||||
F = t.Callable[P, t.Any]
|
||||
from ._types import ClickFunctionWrapper, F, O_co, P
|
||||
|
||||
ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]]
|
||||
|
||||
@@ -85,17 +82,38 @@ else:
|
||||
|
||||
__all__ = ["LLMConfig", "ModelSignature"]
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@t.overload
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> F[..., F[..., openllm.LLMConfig]]:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def attrs_to_options( # type: ignore (overlapping overload)
|
||||
name: str,
|
||||
field: attr.Attribute[O_co],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> F[..., F[P, O_co]]:
|
||||
...
|
||||
|
||||
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> t.Callable[[F[P]], F[P]]:
|
||||
) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]:
|
||||
# TODO: support parsing nested attrs class
|
||||
envvar = field.metadata["env"]
|
||||
dasherized = inflection.dasherize(name)
|
||||
@@ -112,7 +130,7 @@ def attrs_to_options(
|
||||
return optgroup.option(
|
||||
identifier,
|
||||
full_option_name,
|
||||
type=field.type,
|
||||
type=typ if typ is not None else field.type,
|
||||
required=field.default is attr.NOTHING,
|
||||
default=field.default if field.default not in (attr.NOTHING, None) else None,
|
||||
show_default=True,
|
||||
@@ -547,17 +565,15 @@ def _make_attr_tuple_class(cls_name: str, attr_names: t.Iterable[str]) -> type[t
|
||||
def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConfig]:
|
||||
_has_gen_class = _has_own_attribute(cls, "GenerationConfig")
|
||||
|
||||
hints = t.get_type_hints(GenerationConfig)
|
||||
|
||||
def _evolve_with_base_default(
|
||||
__cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]
|
||||
_: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]
|
||||
) -> list[attr.Attribute[t.Any]]:
|
||||
transformed: list[attr.Attribute[t.Any]] = []
|
||||
for f in fields:
|
||||
env = f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{f.name.upper()}"
|
||||
_from_env = _populate_value_from_env_var(env, fallback=f.default)
|
||||
default_value = f.default if not _has_gen_class else getattr(cls.GenerationConfig, f.name, _from_env)
|
||||
transformed.append(f.evolve(default=default_value, metadata={"env": env}, type=hints[f.name]))
|
||||
transformed.append(f.evolve(default=default_value, metadata={"env": env}))
|
||||
return transformed
|
||||
|
||||
generated_cls = attr.make_class(
|
||||
@@ -598,6 +614,10 @@ class LLMConfig:
|
||||
__openllm_name_type__: t.Literal["dasherize", "lowercase"] = "dasherize"
|
||||
|
||||
__openllm_env__: openllm.utils.ModelEnv = Field(None, init=False)
|
||||
"""A ModelEnv instance for this LLMConfig."""
|
||||
|
||||
__openllm_hints__: dict[str, t.Any] = Field(None, init=False)
|
||||
"""An internal cache of resolved types for this LLMConfig."""
|
||||
|
||||
generation_class: type[GenerationConfig] = Field(None, init=False)
|
||||
|
||||
@@ -728,6 +748,9 @@ class LLMConfig:
|
||||
# NOTE: Finally, set the generation_class for this given config.
|
||||
cls.generation_class = _make_internal_generation_class(cls)
|
||||
|
||||
hints.update(t.get_type_hints(cls.generation_class))
|
||||
cls.__openllm_hints__ = hints
|
||||
|
||||
@property
|
||||
def name_type(self) -> t.Literal["dasherize", "lowercase"]:
|
||||
return self.__openllm_name_type__
|
||||
@@ -846,10 +869,10 @@ class LLMConfig:
|
||||
...
|
||||
|
||||
def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | dict[str, t.Any]:
|
||||
config = transformers.GenerationConfig(**self.generation_config.model_dump())
|
||||
config = transformers.GenerationConfig(**bentoml_cattr.unstructure(self.generation_config))
|
||||
return config.to_dict() if return_as_dict else config
|
||||
|
||||
def to_click_options(self, f: F[P]) -> t.Callable[[F[P]], click.Command]:
|
||||
def to_click_options(self, f: F[P, openllm.LLMConfig]) -> F[P, ClickFunctionWrapper[P, openllm.LLMConfig]]:
|
||||
"""
|
||||
Convert current model to click options. This can be used as a decorator for click commands.
|
||||
Note that the identifier for all LLMConfig will be prefixed with '<model_name>_*', and the generation config
|
||||
@@ -859,16 +882,21 @@ class LLMConfig:
|
||||
if t.get_origin(field.type) is t.Union:
|
||||
# NOTE: Union type is currently not yet supported, we probably just need to use environment instead.
|
||||
continue
|
||||
f = attrs_to_options(name, field, self.__openllm_model_name__, suffix_generation=True)(f)
|
||||
f = attrs_to_options(
|
||||
name, field, self.__openllm_model_name__, typ=self.__openllm_hints__[name], suffix_generation=True
|
||||
)(f)
|
||||
f = optgroup.group(f"{self.generation_class.__name__} generation options")(f)
|
||||
|
||||
if len(self.__class__.__openllm_attrs__) == 0:
|
||||
return f
|
||||
# NOTE: in this case, the function is already a ClickFunctionWrapper
|
||||
# hence the casting
|
||||
return t.cast("F[P, ClickFunctionWrapper[P, openllm.LLMConfig]]", f)
|
||||
|
||||
for name, field in attr.fields_dict(self.__class__).items():
|
||||
if t.get_origin(field.type) is t.Union:
|
||||
# NOTE: Union type is currently not yet supported, we probably just need to use environment instead.
|
||||
continue
|
||||
f = attrs_to_options(name, field, self.__openllm_model_name__)(f)
|
||||
f = attrs_to_options(name, field, self.__openllm_model_name__, typ=self.__openllm_hints__[name])(f)
|
||||
|
||||
return optgroup.group(f"{self.__class__.__name__} options")(f)
|
||||
|
||||
|
||||
@@ -23,6 +23,11 @@ import typing as t
|
||||
if not t.TYPE_CHECKING:
|
||||
raise RuntimeError(f"{__name__} should not be imported during runtime")
|
||||
|
||||
import click
|
||||
|
||||
P = t.ParamSpec("P")
|
||||
O_co = t.TypeVar("O_co", covariant=True)
|
||||
|
||||
import transformers
|
||||
from bentoml._internal.models.model import ModelSignatureDict as ModelSignatureDict
|
||||
from bentoml._internal.models.model import ModelSignaturesType as ModelSignaturesType
|
||||
@@ -31,3 +36,18 @@ LLMModel = transformers.PreTrainedModel | transformers.TFPreTrainedModel | trans
|
||||
LLMTokenizer = (
|
||||
transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast | transformers.PreTrainedTokenizerBase
|
||||
)
|
||||
|
||||
|
||||
class ClickFunctionWrapper(t.Protocol[P, O_co]):
|
||||
__click_params__: list[click.Option]
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co:
|
||||
...
|
||||
|
||||
|
||||
# F is a t.Callable[P, O_co] with compatible to ClickFunctionWrapper
|
||||
class F(t.Generic[P, O_co]):
|
||||
__click_params__: list[click.Option]
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co:
|
||||
...
|
||||
|
||||
@@ -1,15 +1,25 @@
|
||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Type
|
||||
from typing import Any, Callable, Dict, NamedTuple, Optional, ParamSpec, Protocol, Tuple, Type, TypeVar
|
||||
|
||||
import click
|
||||
|
||||
from ._core import FC, OptionGroup
|
||||
|
||||
P = ParamSpec("P")
|
||||
O_co = TypeVar("O_co", covariant=True)
|
||||
|
||||
F = Callable[P, O_co]
|
||||
|
||||
class OptionStackItem(NamedTuple):
|
||||
param_decls: Tuple[str, ...]
|
||||
attrs: Dict[str, Any]
|
||||
param_count: int
|
||||
...
|
||||
|
||||
class ClickFunctionWrapper(Protocol[P, O_co]):
|
||||
__click_params__: list[click.Option]
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co: ...
|
||||
|
||||
class _NotAttachedOption(click.Option):
|
||||
"""The helper class to catch grouped options which were not attached to the group
|
||||
|
||||
@@ -76,7 +86,7 @@ class _OptGroup:
|
||||
:param attrs: Additional parameters of option group class
|
||||
"""
|
||||
...
|
||||
def option(self, *param_decls: Any, **attrs: Any) -> FC:
|
||||
def option(self, *param_decls: Any, **attrs: Any) -> F[P, ClickFunctionWrapper[P, O_co]]:
|
||||
"""The decorator adds a new option to the group
|
||||
|
||||
The decorator is lazy. It adds option decls and attrs.
|
||||
|
||||
Reference in New Issue
Block a user