From 78358dbb8d5eb54456d5037a4e5c248a454e2d12 Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Sun, 28 May 2023 06:01:11 -0700 Subject: [PATCH] fix(type): configuration and dependencies Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- examples/play.py | 5 +++++ pyproject.toml | 2 +- src/openllm/_configuration.py | 13 +++++++------ src/openllm/_llm.py | 10 +++++++++- src/openllm/models/auto/factory.py | 23 +++++++++-------------- src/openllm/utils/__init__.py | 7 +++++++ 6 files changed, 38 insertions(+), 22 deletions(-) create mode 100644 examples/play.py diff --git a/examples/play.py b/examples/play.py new file mode 100644 index 00000000..3216a80a --- /dev/null +++ b/examples/play.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +import openllm + +model = openllm.AutoLLM.for_model("flan-t5") diff --git a/pyproject.toml b/pyproject.toml index 4828de67..8f10f88a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ all = [ 'openllm[flan-t5]', 'openllm[starcoder]', ] -fine-tune = ["peft", "bitsandbytes"] +fine-tune = ["peft", "bitsandbytes", "datasets"] chatglm = ['cpm_kernels', 'sentencepiece'] falcon = ['einops'] flan-t5 = ['flax', 'jax', 'jaxlib', 'tensorflow'] diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index 28dc39c3..994b102e 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -370,8 +370,8 @@ class GenerationConfig(pydantic.BaseModel): if t.TYPE_CHECKING: # The following is handled via __pydantic_init_subclass__ - __openllm_env_name__: str __openllm_model_name__: str + __openllm_env__: openllm.utils.ModelEnv def __init_subclass__(cls, *, _internal: bool = False, **attrs: t.Any) -> None: if not _internal: @@ -383,7 +383,7 @@ class GenerationConfig(pydantic.BaseModel): if model_name is None: raise RuntimeError("Failed to initialize GenerationConfig subclass (missing model_name)") cls.__openllm_model_name__ = inflection.underscore(model_name) - cls.__openllm_env_name__ = cls.__openllm_model_name__.upper() + cls.__openllm_env__ = openllm.utils.ModelEnv(cls.__openllm_model_name__) @classmethod def construct_from_llm_config(cls, llm_config: type[LLMConfig]) -> GenerationConfig: @@ -404,7 +404,7 @@ class GenerationConfig(pydantic.BaseModel): field.json_schema_extra = {} if "env" in field.json_schema_extra: continue - field.json_schema_extra["env"] = f"OPENLLM_{self.__openllm_env_name__}_GENERATION_{key.upper()}" + field.json_schema_extra["env"] = self.__openllm_env__.gen_env_key(f"GENERATION_{key.upper()}") class LLMConfig(pydantic.BaseModel, ABC): @@ -425,6 +425,7 @@ class LLMConfig(pydantic.BaseModel, ABC): __openllm_trust_remote_code__: bool = False __openllm_requires_gpu__: bool = False __openllm_env__: openllm.utils.ModelEnv + GenerationConfig: type[t.Any] = GenerationConfig def __init_subclass__( @@ -468,8 +469,6 @@ class LLMConfig(pydantic.BaseModel, ABC): cls.__openllm_model_name__ = cls.__name__.replace("Config", "").lower() cls.__openllm_start_name__ = cls.__openllm_model_name__ - cls.__openllm_env__ = openllm.utils.ModelEnv(cls.__openllm_model_name__) - if hasattr(cls, "GenerationConfig"): cls.generation_config = t.cast( "type[GenerationConfig]", @@ -481,12 +480,14 @@ class LLMConfig(pydantic.BaseModel, ABC): ).construct_from_llm_config(cls) delattr(cls, "GenerationConfig") + cls.__openllm_env__ = cls.generation_config.__openllm_env__ + for key, field in cls.model_fields.items(): if not field.json_schema_extra: field.json_schema_extra = {} if "env" in field.json_schema_extra: continue - field.json_schema_extra["env"] = f"OPENLLM_{cls.__openllm_model_name__.upper()}_{key.upper()}" + field.json_schema_extra["env"] = cls.__openllm_env__.gen_env_key(key) def model_post_init(self, _: t.Any): if self.__pydantic_extra__: diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 5b8044a9..f6872279 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -329,6 +329,14 @@ class LLM(LLMInterface, metaclass=LLMMetaclass): # NOTE: the following will be populated by __init__ config: openllm.LLMConfig + # NOTE: the following is the similar interface to HuggingFace pretrained protocol. + + @classmethod + def from_pretrained( + cls, pretrained: str | None = None, llm_config: openllm.LLMConfig | None = None, *args: t.Any, **attrs: t.Any + ) -> LLM: + return cls(pretrained=pretrained, llm_config=llm_config, *args, **attrs) + def __init__( self, pretrained: str | None = None, @@ -412,7 +420,7 @@ class LLM(LLMInterface, metaclass=LLMMetaclass): attrs = copy.deepcopy(self.config.__pydantic_extra__) if pretrained is None: - pretrained = os.environ.get(f"OPENLLM_{self.config.__openllm_model_name__.upper()}_PRETRAINED", None) + pretrained = os.environ.get(self.config.__openllm_env__.pretrained, None) if not pretrained: assert self.default_model, "A default model is required for any LLM." pretrained = self.default_model diff --git a/src/openllm/models/auto/factory.py b/src/openllm/models/auto/factory.py index 40a575cd..c86fbecf 100644 --- a/src/openllm/models/auto/factory.py +++ b/src/openllm/models/auto/factory.py @@ -33,13 +33,6 @@ else: ConfigModelOrderedDict = OrderedDict -def _get_llm_class(config: openllm.LLMConfig, llm_mapping: _LazyAutoMapping) -> type[openllm.LLM]: - supported_llm = llm_mapping[type(config)] - if not isinstance(supported_llm, (list, tuple)): - return supported_llm - return supported_llm[0] - - class _BaseAutoLLMClass: _model_mapping: _LazyAutoMapping @@ -56,6 +49,7 @@ class _BaseAutoLLMClass: model_name: str, pretrained: str | None = None, return_runner_kwargs: t.Literal[False] = ..., + llm_config: openllm.LLMConfig | None = ..., **attrs: t.Any, ) -> openllm.LLM: ... @@ -67,6 +61,7 @@ class _BaseAutoLLMClass: model_name: str, pretrained: str | None = None, return_runner_kwargs: t.Literal[True] = ..., + llm_config: openllm.LLMConfig | None = ..., **attrs: t.Any, ) -> tuple[openllm.LLM, dict[str, t.Any]]: ... @@ -77,9 +72,9 @@ class _BaseAutoLLMClass: model_name: str, pretrained: str | None = None, return_runner_kwargs: bool = False, + llm_config: openllm.LLMConfig | None = ..., **attrs: t.Any, ) -> openllm.LLM | tuple[openllm.LLM, dict[str, t.Any]]: - config = attrs.pop("llm_config", None) runner_kwargs_name = [ "name", "models", @@ -90,16 +85,16 @@ class _BaseAutoLLMClass: "scheduling_strategy", ] to_runner_attrs = {k: v for k, v in attrs.items() if k in runner_kwargs_name} - if not isinstance(config, openllm.LLMConfig): + if not isinstance(llm_config, openllm.LLMConfig): # The rest of kwargs is now passed to config - config = AutoConfig.for_model(model_name, **attrs) - if type(config) in cls._model_mapping.keys(): - llm = _get_llm_class(config, cls._model_mapping)(pretrained=pretrained, llm_config=config, **attrs) + llm_config = AutoConfig.for_model(model_name, **attrs) + if type(llm_config) in cls._model_mapping.keys(): + llm = cls._model_mapping[type(llm_config)](pretrained, llm_config=llm_config, **attrs) if not return_runner_kwargs: return llm return llm, to_runner_attrs raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoRunner: {cls.__name__}.\n" + f"Unrecognized configuration class {llm_config.__class__} for this kind of AutoRunner: {cls.__name__}.\n" f"Runnable type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." ) @@ -171,7 +166,7 @@ class _LazyAutoMapping(ConfigModelOrderedDict): common_keys = set(self._config_mapping.keys()).intersection(self._model_mapping.keys()) return len(common_keys) + len(self._extra_content) - def __getitem__(self, key: openllm.LLMConfig) -> openllm.LLM: + def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM]: if key in self._extra_content: return self._extra_content[key] model_type = self._reverse_config_mapping[key.__name__] diff --git a/src/openllm/utils/__init__.py b/src/openllm/utils/__init__.py index ea32d0b1..d3dc4959 100644 --- a/src/openllm/utils/__init__.py +++ b/src/openllm/utils/__init__.py @@ -75,6 +75,13 @@ class ModelEnv: def model_config(self) -> str: return f"OPENLLM_{self.model_name.upper()}_CONFIG" + @property + def pretrained(self) -> str: + return f"OPENLLM_{self.model_name.upper()}_PRETRAINED" + + def gen_env_key(self, key: str) -> str: + return f"OPENLLM_{self.model_name.upper()}_{key.upper()}" + @property def start_docstring(self) -> str: return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")