fix: correct setup property for envvar instance

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-07-31 23:34:42 -04:00
parent 16f032417e
commit ca5e3c7ae5
4 changed files with 34 additions and 50 deletions

View File

@@ -585,10 +585,10 @@ class LLM(LLMInterface[M, T], ReprMixin):
**attrs: The kwargs to be passed to the model.
"""
cfg_cls = cls.config_class
model_id = first_not_none(model_id, cfg_cls.__openllm_env__["model_id_value"], cfg_cls.__openllm_default_id__)
model_id = first_not_none(model_id, os.getenv(cfg_cls.__openllm_env__["model_id"]), cfg_cls.__openllm_default_id__)
if model_id is None: raise RuntimeError("Failed to resolve a valid model_id.")
if validate_is_path(model_id): model_id = resolve_filepath(model_id)
quantize = first_not_none(quantize, cfg_cls.__openllm_env__["quantize_value"], default=None)
quantize = first_not_none(quantize, t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.getenv(cfg_cls.__openllm_env__["quantize"])), default=None)
# quantization setup
if quantization_config and quantize: raise ValueError("'quantization_config' and 'quantize' are mutually exclusive. Either customise your quantization_config or use the 'quantize' argument.")
@@ -613,7 +613,9 @@ class LLM(LLMInterface[M, T], ReprMixin):
except Exception as err: raise OpenLLMException(f"Failed to generate a valid tag for {cfg_cls.__openllm_start_name__} with 'model_id={model_id}' (lookup to see its traceback):\n{err}") from err
return cls(
*args, model_id=model_id, llm_config=llm_config, quantization_config=quantization_config, bettertransformer=str(first_not_none(bettertransformer, cfg_cls.__openllm_env__["bettertransformer_value"], default=None)).upper() in ENV_VARS_TRUE_VALUES, _runtime=first_not_none(runtime, cfg_cls.__openllm_env__["runtime_value"], default=cfg_cls.__openllm_runtime__),
*args, model_id=model_id, llm_config=llm_config, quantization_config=quantization_config,
bettertransformer=str(first_not_none(bettertransformer, os.getenv(cfg_cls.__openllm_env__["bettertransformer"]), default=None)).upper() in ENV_VARS_TRUE_VALUES,
_runtime=first_not_none(runtime, t.cast(t.Optional[t.Literal["ggml", "transformers"]], os.getenv(cfg_cls.__openllm_env__["runtime"])), default=cfg_cls.__openllm_runtime__),
_adapters_mapping=resolve_peft_config_type(adapter_map) if adapter_map is not None else None, _quantize_method=quantize, _model_version=_tag.version, _tag=_tag, _serialisation_format=serialisation, **attrs
)

View File

@@ -106,7 +106,7 @@ def __dir__() -> list[str]:
return sorted(__all__)
def __getattr__(name: str) -> t.Any:
    """Module-level attribute hook (PEP 562): lazily resolve registry helpers and submodules.

    Resolution order: "supported_registries" (cached list of registry names),
    "CONTAINER_NAMES" (the registry mapping itself), then any name in ``__all__``
    (imported lazily as a submodule). Anything else raises AttributeError.
    """
    # NOTE: the lambda is zero-arg — lru_cache-wrapped and invoked immediately with ().
    # (A one-arg `lambda _:` here would raise TypeError on the bare call.)
    if name == "supported_registries": return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
    elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY
    elif name in __all__: return importlib.import_module("." + name, __name__)
    else: raise AttributeError(f"{name} does not exists under {__name__}")

View File

@@ -721,9 +721,9 @@ def build_command(
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
try:
os.environ.update({"OPENLLM_MODEL": inflection.underscore(model_name), env.runtime: str(env.runtime_value), "OPENLLM_SERIALIZATION": serialisation_format})
if env.model_id_value: os.environ[env.model_id] = str(env.model_id_value)
if env.quantize_value: os.environ[env.quantize] = str(env.quantize_value)
if env.bettertransformer_value: os.environ[env.bettertransformer] = str(env.bettertransformer_value)
os.environ[env.model_id] = str(env.model_id_value)
os.environ[env.quantize] = str(env.quantize_value)
os.environ[env.bettertransformer] = str(env.bettertransformer_value)
llm = infer_auto_class(env.framework_value).for_model(model_name, llm_config=llm_config, ensure_available=not fast, model_version=model_version, serialisation=serialisation_format, **attrs)

View File

@@ -340,24 +340,13 @@ def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
class EnvVarMixin(ReprMixin):
    """Mixin exposing per-model environment-variable key names and their resolved values."""
    model_name: str
    if t.TYPE_CHECKING:
        # Env-var key names (e.g. for model_id, quantize, ...) — set dynamically in
        # __init__ via field_env_key, so declared here for the type checker only.
        config: str
        model_id: str
        quantize: str
        framework: str
        bettertransformer: str
        runtime: str
    # fmt: off
@overload
@@ -383,40 +372,33 @@ class EnvVarMixin(ReprMixin):
@overload
def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ...
# fmt: on
def __getitem__(self, item: str | t.Any) -> t.Any:
if hasattr(self, item): return getattr(self, item)
raise KeyError(f"Key {item} not found in {self}")
def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None,
             runtime: t.Literal["ggml", "transformers"] = "transformers") -> None:
    """EnvVarMixin is a mixin class that returns the value extracted from environment variables.

    Args:
        model_name: model name; normalised to snake_case via inflection.underscore.
        implementation: default framework value used when the env var is unset.
        model_id: default model id used when the env var is unset.
        bettertransformer: default bettertransformer flag used when the env var is unset.
        quantize: default quantisation scheme used when the env var is unset.
        runtime: default runtime used when the env var is unset.
    """
    from .._configuration import field_env_key
    self.model_name = inflection.underscore(model_name)
    # Fallback defaults consulted by the *_value properties when the env var is unset.
    self._implementation = implementation
    self._model_id = model_id
    self._bettertransformer = bettertransformer
    self._quantize = quantize
    self._runtime = runtime
    # Generate the env-var key attribute for each supported field (key format defined by field_env_key).
    for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: setattr(self, att, field_env_key(self.model_name, att.upper()))
@property
def __repr_keys__(self) -> set[str]: return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}
# Each *_value property reads the generated env-var key first and falls back to the
# default captured in __init__.
@property
def quantize_value(self) -> t.Literal["int8", "int4", "gptq"] | None: return t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.getenv(self["quantize"], self._quantize))
@property
def framework_value(self) -> LiteralRuntime: return t.cast(t.Literal["pt", "tf", "flax", "vllm"], os.getenv(self["framework"], self._implementation))
@property
def bettertransformer_value(self) -> bool: return os.getenv(self["bettertransformer"], str(self._bettertransformer)).upper() in ENV_VARS_TRUE_VALUES
@property
def model_id_value(self) -> str | None: return os.getenv(self["model_id"], self._model_id)
@property
def runtime_value(self) -> t.Literal["ggml", "transformers"]: return t.cast(t.Literal["ggml", "transformers"], os.getenv(self["runtime"], self._runtime))
@property
def start_docstring(self) -> str: return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")
@property
def module(self) -> _AnnotatedLazyLoader[t.LiteralString]: return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")