diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 5abe4041..ba0d057e 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -585,10 +585,10 @@ class LLM(LLMInterface[M, T], ReprMixin): **attrs: The kwargs to be passed to the model. """ cfg_cls = cls.config_class - model_id = first_not_none(model_id, cfg_cls.__openllm_env__["model_id_value"], cfg_cls.__openllm_default_id__) + model_id = first_not_none(model_id, os.getenv(cfg_cls.__openllm_env__["model_id"]), cfg_cls.__openllm_default_id__) if model_id is None: raise RuntimeError("Failed to resolve a valid model_id.") if validate_is_path(model_id): model_id = resolve_filepath(model_id) - quantize = first_not_none(quantize, cfg_cls.__openllm_env__["quantize_value"], default=None) + quantize = first_not_none(quantize, t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.getenv(cfg_cls.__openllm_env__["quantize"])), default=None) # quantization setup if quantization_config and quantize: raise ValueError("'quantization_config' and 'quantize' are mutually exclusive. 
Either customise your quantization_config or use the 'quantize' argument.") @@ -613,7 +613,9 @@ class LLM(LLMInterface[M, T], ReprMixin): except Exception as err: raise OpenLLMException(f"Failed to generate a valid tag for {cfg_cls.__openllm_start_name__} with 'model_id={model_id}' (lookup to see its traceback):\n{err}") from err return cls( - *args, model_id=model_id, llm_config=llm_config, quantization_config=quantization_config, bettertransformer=str(first_not_none(bettertransformer, cfg_cls.__openllm_env__["bettertransformer_value"], default=None)).upper() in ENV_VARS_TRUE_VALUES, _runtime=first_not_none(runtime, cfg_cls.__openllm_env__["runtime_value"], default=cfg_cls.__openllm_runtime__), + *args, model_id=model_id, llm_config=llm_config, quantization_config=quantization_config, + bettertransformer=str(first_not_none(bettertransformer, os.getenv(cfg_cls.__openllm_env__["bettertransformer"]), default=None)).upper() in ENV_VARS_TRUE_VALUES, + _runtime=first_not_none(runtime, t.cast(t.Optional[t.Literal["ggml", "transformers"]], os.getenv(cfg_cls.__openllm_env__["runtime"])), default=cfg_cls.__openllm_runtime__), _adapters_mapping=resolve_peft_config_type(adapter_map) if adapter_map is not None else None, _quantize_method=quantize, _model_version=_tag.version, _tag=_tag, _serialisation_format=serialisation, **attrs ) diff --git a/src/openllm/bundle/oci/__init__.py b/src/openllm/bundle/oci/__init__.py index f806ad10..e54cfe4e 100644 --- a/src/openllm/bundle/oci/__init__.py +++ b/src/openllm/bundle/oci/__init__.py @@ -106,7 +106,7 @@ def __dir__() -> list[str]: return sorted(__all__) def __getattr__(name: str) -> t.Any: - if name == "supported_registries": return functools.lru_cache(1)(lambda _: list(_CONTAINER_REGISTRY))() + if name == "supported_registries": return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))() elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY elif name in __all__: return importlib.import_module("." 
+ name, __name__) else: raise AttributeError(f"{name} does not exists under {__name__}") diff --git a/src/openllm/cli/entrypoint.py b/src/openllm/cli/entrypoint.py index bb903cc6..40001410 100644 --- a/src/openllm/cli/entrypoint.py +++ b/src/openllm/cli/entrypoint.py @@ -721,9 +721,9 @@ def build_command( # during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path try: os.environ.update({"OPENLLM_MODEL": inflection.underscore(model_name), env.runtime: str(env.runtime_value), "OPENLLM_SERIALIZATION": serialisation_format}) - if env.model_id_value: os.environ[env.model_id] = str(env.model_id_value) - if env.quantize_value: os.environ[env.quantize] = str(env.quantize_value) - if env.bettertransformer_value: os.environ[env.bettertransformer] = str(env.bettertransformer_value) + if env.model_id_value is not None: os.environ[env.model_id] = str(env.model_id_value) + if env.quantize_value is not None: os.environ[env.quantize] = str(env.quantize_value) + if env.bettertransformer_value is not None: os.environ[env.bettertransformer] = str(env.bettertransformer_value) llm = infer_auto_class(env.framework_value).for_model(model_name, llm_config=llm_config, ensure_available=not fast, model_version=model_version, serialisation=serialisation_format, **attrs) diff --git a/src/openllm/utils/import_utils.py b/src/openllm/utils/import_utils.py index 9bf4c18a..a99eaf11 100644 --- a/src/openllm/utils/import_utils.py +++ b/src/openllm/utils/import_utils.py @@ -340,24 +340,13 @@ def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None: class EnvVarMixin(ReprMixin): model_name: str - - @property - def __repr_keys__(self) -> set[str]: - return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"} - if t.TYPE_CHECKING: config: str model_id: str quantize: str framework: str bettertransformer: str - runtime: t.Literal["ggml", "transformers"] - - framework_value: LiteralRuntime - quantize_value: t.Literal["int8", "int4", "gptq"] | None - bettertransformer_value: bool | None - model_id_value: str | None - 
runtime_value: t.Literal["ggml", "transformers"] + runtime: str # fmt: off @overload @@ -383,40 +372,33 @@ class EnvVarMixin(ReprMixin): @overload def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ... # fmt: on - def __getitem__(self, item: str | t.Any) -> t.Any: if hasattr(self, item): return getattr(self, item) raise KeyError(f"Key {item} not found in {self}") - - def __new__(cls, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers",) -> t.Self: - from . import codegen + def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None, + runtime: t.Literal["ggml", "transformers"] = "transformers") -> None: + """EnvVarMixin is a mixin class that returns the value extracted from environment variables.""" from .._configuration import field_env_key - model_name = inflection.underscore(model_name) - - res = super().__new__(cls) - res.model_name = model_name - - # gen properties env key - for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: - setattr(res, att, field_env_key(model_name, att.upper())) - - # gen properties env value - attributes_with_values = {"framework": (str, implementation), "quantize": (str, quantize), "bettertransformer": (bool, bettertransformer), "model_id": (str, model_id), "runtime": (str, runtime),} - globs: dict[str, t.Any] = {"__bool_vars_value": ENV_VARS_TRUE_VALUES, "__env_get": os.getenv, "self": res} - - for attribute, (default_type, default_value) in attributes_with_values.items(): - lines: list[str] = [] - if default_type is bool: lines.append(f"return str(__env_get(self['{attribute}'], str(__env_default)).upper() in __bool_vars_value)") - else: lines.append(f"return 
__env_get(self['{attribute}'], __env_default)") - - setattr(res, f"{attribute}_value", codegen.generate_function(cls, "_env_get_" + attribute, lines, ("__env_default",), globs)(default_value)) - - return res - + self.model_name = inflection.underscore(model_name) + self._implementation = implementation + self._model_id = model_id + self._bettertransformer = bettertransformer + self._quantize = quantize + self._runtime = runtime + for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: setattr(self, att, field_env_key(self.model_name, att.upper())) @property - def start_docstring(self) -> str: - return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING") - + def __repr_keys__(self) -> set[str]: return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"} @property - def module(self) -> _AnnotatedLazyLoader[t.LiteralString]: - return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}") + def quantize_value(self) -> t.Literal["int8", "int4", "gptq"] | None: return t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.getenv(self["quantize"], self._quantize)) + @property + def framework_value(self) -> LiteralRuntime: return t.cast(t.Literal["pt", "tf", "flax", "vllm"], os.getenv(self["framework"], self._implementation)) + @property + def bettertransformer_value(self) -> bool: return os.getenv(self["bettertransformer"], str(self._bettertransformer)).upper() in ENV_VARS_TRUE_VALUES + @property + def model_id_value(self) -> str | None: return os.getenv(self["model_id"], self._model_id) + @property + def runtime_value(self) -> t.Literal["ggml", "transformers"]: return t.cast(t.Literal["ggml", "transformers"], os.getenv(self["runtime"], self._runtime)) + @property + def start_docstring(self) -> str: return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING") + @property + def module(self) -> _AnnotatedLazyLoader[t.LiteralString]: 
return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")