From c3aeb43997aa9505debb8bf53bd5607b7027b0cd Mon Sep 17 00:00:00 2001
From: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
Date: Fri, 2 Jun 2023 07:05:16 +0000
Subject: [PATCH] fix: generation serde

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
---
 src/openllm/_configuration.py | 74 ++++++++++++++++++-----------------
 src/openllm/cli.py            |  1 -
 src/openllm/utils/dantic.py   | 11 +++++-
 3 files changed, 49 insertions(+), 37 deletions(-)

diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py
index 1f322fca..ed1527e7 100644
--- a/src/openllm/_configuration.py
+++ b/src/openllm/_configuration.py
@@ -48,6 +48,7 @@ import click
 import inflection
 import orjson
 from bentoml._internal.models.model import ModelSignature
+from cattr.gen import make_dict_unstructure_fn, override
 from click_option_group import optgroup
 
 import openllm
@@ -367,20 +368,13 @@ class GenerationConfig:
         self.__attrs_init__(**attrs)
 
 
-def generation_config_dump(genconf: GenerationConfig) -> dict[str, t.Any]:
-    target: dict[str, t.Any] = {}
-    for k in attr.fields_dict(genconf.__class__):
-        v = getattr(genconf, k, None)
-        if v is None:
-            continue
-        if not k.startswith("_"):
-            target[k] = v
-    return target
-
-
-bentoml_cattr.register_unstructure_hook_func(
-    lambda cls: lenient_issubclass(cls, GenerationConfig),
-    generation_config_dump,
+bentoml_cattr.register_unstructure_hook_factory(
+    lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig),
+    lambda cls: make_dict_unstructure_fn(
+        cls,
+        bentoml_cattr,
+        **{k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)},
+    ),
 )
 
 
@@ -551,23 +545,33 @@ def _make_attr_tuple_class(cls_name: str, attr_names: t.Iterable[str]) -> type[t
 
 
 def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConfig]:
-    attribs: DictStrAny = {}
     _has_gen_class = _has_own_attribute(cls, "GenerationConfig")
 
-    for key, field in attr.fields_dict(GenerationConfig).items():
-        class_gen_value = cls.GenerationConfig.__dict__.get(key, attr.NOTHING) if _has_gen_class else attr.NOTHING
-        attribs[key] = cls.Field(
-            class_gen_value if class_gen_value is not attr.NOTHING else field.default,
-            description=field.metadata.get("description"),
-            env=f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{key.upper()}",
-            validator=field.validator,
-        )
-    if _has_gen_class:
-        delattr(cls, "GenerationConfig")
-    return attr.make_class(
-        cls.__name__.replace("Config", "GenerationConfig"), attribs, field_transformer=env_transformers
+    hints = t.get_type_hints(GenerationConfig)
+
+    def _evolve_with_base_default(
+        __cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]
+    ) -> list[attr.Attribute[t.Any]]:
+        transformed: list[attr.Attribute[t.Any]] = []
+        for f in fields:
+            env = f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{f.name.upper()}"
+            _from_env = _populate_value_from_env_var(env, fallback=f.default)
+            default_value = f.default if not _has_gen_class else getattr(cls.GenerationConfig, f.name, _from_env)
+            transformed.append(f.evolve(default=default_value, metadata={"env": env}, type=hints[f.name]))
+        return transformed
+
+    generated_cls = attr.make_class(
+        cls.__name__.replace("Config", "GenerationConfig"),
+        [],
+        bases=(GenerationConfig,),
+        frozen=True,
+        slots=True,
+        repr=True,
+        field_transformer=_evolve_with_base_default,
     )
 
+    return generated_cls
+
 
 @attr.define
 class LLMConfig:
@@ -655,11 +659,12 @@ class LLMConfig:
                 f"The following field doesn't have a type annotation: {missing_annotated}"
             )
 
+        hints = t.get_type_hints(cls)
         # NOTE: we know need to determine the list of the attrs
         # by mro to at the very least support inheritance. Tho it is not recommended.
         own_attrs: list[attr.Attribute[t.Any]] = []
         for attr_name, ca in ca_list:
-            gen_attribute = attr.Attribute.from_counting_attr(name=attr_name, ca=ca, type=anns.get(attr_name))
+            gen_attribute = attr.Attribute.from_counting_attr(name=attr_name, ca=ca, type=hints.get(attr_name))
             if attr_name in ca_names:
                 metadata = ca.metadata
                 metadata["env"] = field_env_key(attr_name)
@@ -747,15 +752,11 @@ class LLMConfig:
 
         attrs = {k: v for k, v in attrs.items() if k not in generation_config}
         extras = set(attrs).difference(set(attr.fields_dict(self.__class__)))
-        if len(extras) > 0:
-            # Set all of the keys in extras to the class
-            for k in extras:
-                setattr(self, k, attrs[k])
 
         self.__attrs_init__(**{k: v for k, v in attrs.items() if k not in extras})
 
     def __repr__(self) -> str:
-        bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={self.generation_config}"
+        bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={repr(self.generation_class())}"
         if len(self.__openllm_attrs__) > 0:
             bases += ", " + ", ".join([f"{k}={getattr(self, k)}" for k in self.__openllm_attrs__]) + ")"
         else:
@@ -783,7 +784,7 @@ class LLMConfig:
             raise GpuNotAvailableError(msg) from None
 
     def model_dump(self, flatten: bool = False, **_: t.Any):
-        dumped = {k: getattr(self, k) for k in self.__openllm_attrs__}
+        dumped = bentoml_cattr.unstructure(self)
         generation_config = bentoml_cattr.unstructure(self.generation_config)
         if not flatten:
             dumped["generation_config"] = generation_config
@@ -872,7 +873,10 @@ class LLMConfig:
         return optgroup.group(f"{self.__class__.__name__} options")(f)
 
 
-bentoml_cattr.register_unstructure_hook_func(lambda cls: lenient_issubclass(cls, LLMConfig), LLMConfig.model_dump)
+bentoml_cattr.register_unstructure_hook_factory(
+    lambda cls: lenient_issubclass(cls, LLMConfig),
+    lambda cls: make_dict_unstructure_fn(cls, bentoml_cattr, _cattrs_omit_if_default=False),
+)
 
 
 def structure_llm_config(data: dict[str, t.Any], cls: type[LLMConfig]) -> LLMConfig:
diff --git a/src/openllm/cli.py b/src/openllm/cli.py
index 15cc8641..5a9664fc 100644
--- a/src/openllm/cli.py
+++ b/src/openllm/cli.py
@@ -288,7 +288,6 @@ def start_model_command(
     """
    from bentoml._internal.configuration import get_debug_mode
 
-    breakpoint()
     ModelEnv = openllm.utils.ModelEnv(model_name)
 
     model_command_decr: dict[str, t.Any] = {"name": ModelEnv.model_name, "context_settings": _context_settings or {}}
diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py
index f83f9198..061553d8 100644
--- a/src/openllm/utils/dantic.py
+++ b/src/openllm/utils/dantic.py
@@ -85,7 +85,16 @@ def Field(
     else:
         _validator = attr.validators.and_(*piped)
 
-    return attr.field(default=default, metadata=metadata, validator=_validator, converter=converter, **attrs)
+    factory = attrs.pop("factory", None)
+    if factory is not None and default is not None:
+        raise RuntimeError("'factory' and 'default' are mutually exclusive.")
+    # NOTE: the behaviour of this is we will respect factory over the default
+    if factory is not None:
+        attrs["factory"] = factory
+    else:
+        attrs["default"] = default
+
+    return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
 
 
 def allows_multiple(field_type: t.Any) -> bool:
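
Note (illustrative sketch, not part of the patch): the hook factory registered for GenerationConfig in src/openllm/_configuration.py builds its serializer with cattrs' make_dict_unstructure_fn plus override(omit=True), so any field whose declared default is None or attr.NOTHING is dropped from the serialized dict. The standalone example below reproduces that behaviour under a few assumptions: SampleGenerationConfig is a made-up stand-in for openllm's GenerationConfig, a plain cattrs.Converter stands in for bentoml_cattr, and a cattrs release recent enough to support override(omit=...) is installed.

from typing import Optional

import attr
import cattrs
from cattrs.gen import make_dict_unstructure_fn, override


@attr.define
class SampleGenerationConfig:
    """Hypothetical stand-in for openllm's GenerationConfig."""

    max_new_tokens: int = 256
    temperature: float = 0.75
    top_k: Optional[int] = None
    top_p: Optional[float] = None


converter = cattrs.Converter()  # stands in for bentoml_cattr
converter.register_unstructure_hook_factory(
    lambda cls: attr.has(cls) and issubclass(cls, SampleGenerationConfig),
    lambda cls: make_dict_unstructure_fn(
        cls,
        converter,
        # Omit every field whose default is None or attr.NOTHING, mirroring the
        # override(omit=True) mapping built in the patched hook registration.
        **{
            k: override(omit=True)
            for k, v in attr.fields_dict(cls).items()
            if v.default in (None, attr.NOTHING)
        },
    ),
)

print(converter.unstructure(SampleGenerationConfig()))
# {'max_new_tokens': 256, 'temperature': 0.75}  (top_k/top_p omitted)

Letting cattrs generate the serializer per class takes over from the hand-written generation_config_dump loop removed above, and the same omission rule carries over to every model-specific GenerationConfig subclass produced by _make_internal_generation_class.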