mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-01-30 02:12:00 -05:00
fix: generation serde
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -48,6 +48,7 @@ import click
|
||||
import inflection
|
||||
import orjson
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
from cattr.gen import make_dict_unstructure_fn, override
|
||||
from click_option_group import optgroup
|
||||
|
||||
import openllm
|
||||
@@ -367,20 +368,13 @@ class GenerationConfig:
|
||||
self.__attrs_init__(**attrs)
|
||||
|
||||
|
||||
def generation_config_dump(genconf: GenerationConfig) -> dict[str, t.Any]:
    """Collect the public, non-``None`` attrs fields of ``genconf`` into a dict.

    Every field declared on the instance's class is read with ``getattr``
    (falling back to ``None`` when the attribute is absent). A field is kept
    only if its value is not ``None`` and its name does not start with an
    underscore.

    Args:
        genconf: The generation config instance to dump.

    Returns:
        A mapping of field name to the current value on ``genconf``.
    """
    dumped: dict[str, t.Any] = {}
    for name in attr.fields_dict(genconf.__class__):
        value = getattr(genconf, name, None)
        # Unset values and private fields are not part of the dump.
        if value is not None and not name.startswith("_"):
            dumped[name] = value
    return dumped
|
||||
|
||||
|
||||
bentoml_cattr.register_unstructure_hook_func(
|
||||
lambda cls: lenient_issubclass(cls, GenerationConfig),
|
||||
generation_config_dump,
|
||||
bentoml_cattr.register_unstructure_hook_factory(
|
||||
lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig),
|
||||
lambda cls: make_dict_unstructure_fn(
|
||||
cls,
|
||||
bentoml_cattr,
|
||||
**{k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -551,23 +545,33 @@ def _make_attr_tuple_class(cls_name: str, attr_names: t.Iterable[str]) -> type[t
|
||||
|
||||
|
||||
def _make_internal_generation_class(cls: type[LLMConfig]) -> type[GenerationConfig]:
|
||||
attribs: DictStrAny = {}
|
||||
_has_gen_class = _has_own_attribute(cls, "GenerationConfig")
|
||||
for key, field in attr.fields_dict(GenerationConfig).items():
|
||||
class_gen_value = cls.GenerationConfig.__dict__.get(key, attr.NOTHING) if _has_gen_class else attr.NOTHING
|
||||
attribs[key] = cls.Field(
|
||||
class_gen_value if class_gen_value is not attr.NOTHING else field.default,
|
||||
description=field.metadata.get("description"),
|
||||
env=f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{key.upper()}",
|
||||
validator=field.validator,
|
||||
)
|
||||
if _has_gen_class:
|
||||
delattr(cls, "GenerationConfig")
|
||||
|
||||
return attr.make_class(
|
||||
cls.__name__.replace("Config", "GenerationConfig"), attribs, field_transformer=env_transformers
|
||||
hints = t.get_type_hints(GenerationConfig)
|
||||
|
||||
def _evolve_with_base_default(
|
||||
__cls: type[GenerationConfig], fields: list[attr.Attribute[t.Any]]
|
||||
) -> list[attr.Attribute[t.Any]]:
|
||||
transformed: list[attr.Attribute[t.Any]] = []
|
||||
for f in fields:
|
||||
env = f"OPENLLM_{cls.__openllm_model_name__.upper()}_GENERATION_{f.name.upper()}"
|
||||
_from_env = _populate_value_from_env_var(env, fallback=f.default)
|
||||
default_value = f.default if not _has_gen_class else getattr(cls.GenerationConfig, f.name, _from_env)
|
||||
transformed.append(f.evolve(default=default_value, metadata={"env": env}, type=hints[f.name]))
|
||||
return transformed
|
||||
|
||||
generated_cls = attr.make_class(
|
||||
cls.__name__.replace("Config", "GenerationConfig"),
|
||||
[],
|
||||
bases=(GenerationConfig,),
|
||||
frozen=True,
|
||||
slots=True,
|
||||
repr=True,
|
||||
field_transformer=_evolve_with_base_default,
|
||||
)
|
||||
|
||||
return generated_cls
|
||||
|
||||
|
||||
@attr.define
|
||||
class LLMConfig:
|
||||
@@ -655,11 +659,12 @@ class LLMConfig:
|
||||
f"The following field doesn't have a type annotation: {missing_annotated}"
|
||||
)
|
||||
|
||||
hints = t.get_type_hints(cls)
|
||||
# NOTE: we now need to determine the list of the attrs
|
||||
# by mro to at the very least support inheritance, though it is not recommended.
|
||||
own_attrs: list[attr.Attribute[t.Any]] = []
|
||||
for attr_name, ca in ca_list:
|
||||
gen_attribute = attr.Attribute.from_counting_attr(name=attr_name, ca=ca, type=anns.get(attr_name))
|
||||
gen_attribute = attr.Attribute.from_counting_attr(name=attr_name, ca=ca, type=hints.get(attr_name))
|
||||
if attr_name in ca_names:
|
||||
metadata = ca.metadata
|
||||
metadata["env"] = field_env_key(attr_name)
|
||||
@@ -747,15 +752,11 @@ class LLMConfig:
|
||||
attrs = {k: v for k, v in attrs.items() if k not in generation_config}
|
||||
|
||||
extras = set(attrs).difference(set(attr.fields_dict(self.__class__)))
|
||||
if len(extras) > 0:
|
||||
# Set all of the keys in extras to the class
|
||||
for k in extras:
|
||||
setattr(self, k, attrs[k])
|
||||
|
||||
self.__attrs_init__(**{k: v for k, v in attrs.items() if k not in extras})
|
||||
|
||||
def __repr__(self) -> str:
|
||||
bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={self.generation_config}"
|
||||
bases = f"{self.__class__.__qualname__.rsplit('>.', 1)[-1]}(generation_config={repr(self.generation_class())}"
|
||||
if len(self.__openllm_attrs__) > 0:
|
||||
bases += ", " + ", ".join([f"{k}={getattr(self, k)}" for k in self.__openllm_attrs__]) + ")"
|
||||
else:
|
||||
@@ -783,7 +784,7 @@ class LLMConfig:
|
||||
raise GpuNotAvailableError(msg) from None
|
||||
|
||||
def model_dump(self, flatten: bool = False, **_: t.Any):
|
||||
dumped = {k: getattr(self, k) for k in self.__openllm_attrs__}
|
||||
dumped = bentoml_cattr.unstructure(self)
|
||||
generation_config = bentoml_cattr.unstructure(self.generation_config)
|
||||
if not flatten:
|
||||
dumped["generation_config"] = generation_config
|
||||
@@ -872,7 +873,10 @@ class LLMConfig:
|
||||
return optgroup.group(f"{self.__class__.__name__} options")(f)
|
||||
|
||||
|
||||
bentoml_cattr.register_unstructure_hook_func(lambda cls: lenient_issubclass(cls, LLMConfig), LLMConfig.model_dump)
|
||||
bentoml_cattr.register_unstructure_hook_factory(
|
||||
lambda cls: lenient_issubclass(cls, LLMConfig),
|
||||
lambda cls: make_dict_unstructure_fn(cls, bentoml_cattr, _cattrs_omit_if_default=False),
|
||||
)
|
||||
|
||||
|
||||
def structure_llm_config(data: dict[str, t.Any], cls: type[LLMConfig]) -> LLMConfig:
|
||||
|
||||
@@ -288,7 +288,6 @@ def start_model_command(
|
||||
"""
|
||||
from bentoml._internal.configuration import get_debug_mode
|
||||
|
||||
breakpoint()
|
||||
ModelEnv = openllm.utils.ModelEnv(model_name)
|
||||
model_command_decr: dict[str, t.Any] = {"name": ModelEnv.model_name, "context_settings": _context_settings or {}}
|
||||
|
||||
|
||||
@@ -85,7 +85,16 @@ def Field(
|
||||
else:
|
||||
_validator = attr.validators.and_(*piped)
|
||||
|
||||
return attr.field(default=default, metadata=metadata, validator=_validator, converter=converter, **attrs)
|
||||
factory = attrs.pop("factory", None)
|
||||
if factory is not None and default is not None:
|
||||
raise RuntimeError("'factory' and 'default' are mutually exclusive.")
|
||||
# NOTE: when 'factory' is provided it takes precedence over 'default'
|
||||
if factory is not None:
|
||||
attrs["factory"] = factory
|
||||
else:
|
||||
attrs["default"] = default
|
||||
|
||||
return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
|
||||
|
||||
|
||||
def allows_multiple(field_type: t.Any) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user