fix(config): __getitem__ to get the value instead of member of class

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-06-19 05:34:49 -04:00
parent 622a2fb37d
commit 2244cce5bd
2 changed files with 23 additions and 8 deletions

View File

@@ -97,8 +97,6 @@ _T = t.TypeVar("_T")
if t.TYPE_CHECKING:
import tensorflow as tf
import torch
from attr import _CountingAttr # type: ignore
from attr import _make_init # type: ignore
from attr import _make_repr # type: ignore
@@ -131,8 +129,6 @@ else:
from attr._make import _transform_attrs
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
tf = openllm.utils.LazyLoader("tf", globals(), "tensorflow")
__all__ = ["LLMConfig"]
@@ -999,7 +995,9 @@ class LLMConfig:
# that are defined in parent classes.
# As their descriptors may be overridden by a child class,
# we collect them here and update the class dict
reused_slots = {slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names}
reused_slots = {
slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names
}
# __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots
slot_names = [name for name in slot_names if name not in reused_slots] + ["__openllm_extras__"]
cls.__slots__ = tuple(slot_names)
@@ -1090,7 +1088,7 @@ class LLMConfig:
elif hasattr(self, item):
return getattr(self, item)
elif hasattr(self.__openllm_generation_class__, item):
return getattr(self.__openllm_generation_class__, item)
return getattr(self.generation_config, item)
elif item in self.__openllm_extras__:
return self.__openllm_extras__[item]
else:
@@ -1240,7 +1238,9 @@ def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig:
raise RuntimeError(f"Expected a dictionary, but got {type(data)}")
generation_cls_fields = attr.fields_dict(cls.__openllm_generation_class__)
cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_accepted_keys__ and k not in generation_cls_fields}
cls_attrs = {
k: v for k, v in data.items() if k in cls.__openllm_accepted_keys__ and k not in generation_cls_fields
}
if "generation_config" in data:
generation_config = data.pop("generation_config")
if not LazyType(DictStrAny).isinstance(generation_config):