fix: generation serde

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-06-02 07:05:16 +00:00
parent 07d42daaec
commit c3aeb43997
3 changed files with 49 additions and 37 deletions

View File

@@ -48,6 +48,7 @@ import click
import inflection
import orjson
from bentoml._internal.models.model import ModelSignature
from cattr.gen import make_dict_unstructure_fn, override
from click_option_group import optgroup
import openllm
@@ -367,20 +368,13 @@ class GenerationConfig:
self.__attrs_init__(**attrs)
def generation_config_dump(genconf: GenerationConfig) -> dict[str, t.Any]:
target: dict[str, t.Any] = {}
for k in attr.fields_dict(genconf.__class__):
v = getattr(genconf, k, None)
if v is None:
continue
if not k.startswith("_"):
target[k] = v
return target
bentoml_cattr.register_unstructure_hook_func(
lambda cls: lenient_issubclass(cls, GenerationConfig),
generation_config_dump,
bentoml_cattr.register_unstructure_hook_factory(
lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig),
lambda cls: make_dict_unstructure_fn(
cls,
bentoml_cattr,
**{k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)},
),
)
@@ -551,23 +545,33 @@ def _make_attr_tuple_class(cls_name: str, attr_names: t.Iterable[str]) -> type[t
def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConfig]:
attribs: DictStrAny = {}
_has_gen_class = _has_own_attribute(cls, "GenerationConfig")
for key, field in attr.fields_dict(GenerationConfig).items():
class_gen_value = cls.GenerationConfig.__dict__.get(key, attr.NOTHING) if _has_gen_class else attr.NOTHING
attribs[key] = cls.Field(
class_gen_value if class_gen_value is not attr.NOTHING else field.default,
description=field.metadata.get("description"),
env=f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{key.upper()}",
validator=field.validator,
)
if _has_gen_class:
delattr(cls, "GenerationConfig")
return attr.make_class(
cls.__name__.replace("Config", "GenerationConfig"), attribs, field_transformer=env_transformers
hints = t.get_type_hints(GenerationConfig)
def _evolve_with_base_default(
__cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]
) -> list[attr.Attribute[t.Any]]:
transformed: list[attr.Attribute[t.Any]] = []
for f in fields:
env = f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{f.name.upper()}"
_from_env = _populate_value_from_env_var(env, fallback=f.default)
default_value = f.default if not _has_gen_class else getattr(cls.GenerationConfig, f.name, _from_env)
transformed.append(f.evolve(default=default_value, metadata={"env": env}, type=hints[f.name]))
return transformed
generated_cls = attr.make_class(
cls.__name__.replace("Config", "GenerationConfig"),
[],
bases=(GenerationConfig,),
frozen=True,
slots=True,
repr=True,
field_transformer=_evolve_with_base_default,
)
return generated_cls
@attr.define
class LLMConfig:
@@ -655,11 +659,12 @@ class LLMConfig:
f"The following field doesn't have a type annotation: {missing_annotated}"
)
hints = t.get_type_hints(cls)
# NOTE: we know need to determine the list of the attrs
# by mro to at the very least support inheritance. Tho it is not recommended.
own_attrs: list[attr.Attribute[t.Any]] = []
for attr_name, ca in ca_list:
gen_attribute = attr.Attribute.from_counting_attr(name=attr_name, ca=ca, type=anns.get(attr_name))
gen_attribute = attr.Attribute.from_counting_attr(name=attr_name, ca=ca, type=hints.get(attr_name))
if attr_name in ca_names:
metadata = ca.metadata
metadata["env"] = field_env_key(attr_name)
@@ -747,15 +752,11 @@ class LLMConfig:
attrs = {k: v for k, v in attrs.items() if k not in generation_config}
extras = set(attrs).difference(set(attr.fields_dict(self.__class__)))
if len(extras) > 0:
# Set all of the keys in extras to the class
for k in extras:
setattr(self, k, attrs[k])
self.__attrs_init__(**{k: v for k, v in attrs.items() if k not in extras})
def __repr__(self) -> str:
bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={self.generation_config}"
bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={repr(self.generation_class())}"
if len(self.__openllm_attrs__) > 0:
bases += ", " + ", ".join([f"{k}={getattr(self, k)}" for k in self.__openllm_attrs__]) + ")"
else:
@@ -783,7 +784,7 @@ class LLMConfig:
raise GpuNotAvailableError(msg) from None
def model_dump(self, flatten: bool = False, **_: t.Any):
dumped = {k: getattr(self, k) for k in self.__openllm_attrs__}
dumped = bentoml_cattr.unstructure(self)
generation_config = bentoml_cattr.unstructure(self.generation_config)
if not flatten:
dumped["generation_config"] = generation_config
@@ -872,7 +873,10 @@ class LLMConfig:
return optgroup.group(f"{self.__class__.__name__} options")(f)
bentoml_cattr.register_unstructure_hook_func(lambda cls: lenient_issubclass(cls, LLMConfig), LLMConfig.model_dump)
bentoml_cattr.register_unstructure_hook_factory(
lambda cls: lenient_issubclass(cls, LLMConfig),
lambda cls: make_dict_unstructure_fn(cls, bentoml_cattr, _cattrs_omit_if_default=False),
)
def structure_llm_config(data: dict[str, t.Any], cls: type[LLMConfig]) -> LLMConfig:

View File

@@ -288,7 +288,6 @@ def start_model_command(
"""
from bentoml._internal.configuration import get_debug_mode
breakpoint()
ModelEnv = openllm.utils.ModelEnv(model_name)
model_command_decr: dict[str, t.Any] = {"name": ModelEnv.model_name, "context_settings": _context_settings or {}}

View File

@@ -85,7 +85,16 @@ def Field(
else:
_validator = attr.validators.and_(*piped)
return attr.field(default=default, metadata=metadata, validator=_validator, converter=converter, **attrs)
factory = attrs.pop("factory", None)
if factory is not None and default is not None:
raise RuntimeError("'factory' and 'default' are mutually exclusive.")
# NOTE: the behaviour of this is we will respect factory over the default
if factory is not None:
attrs["factory"] = factory
else:
attrs["default"] = default
return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
def allows_multiple(field_type: t.Any) -> bool: