From 2cc264aa7243bb5a1185ae8ce96da7f55b714796 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Thu, 3 Aug 2023 16:34:35 -0400 Subject: [PATCH] fix(vllm): correctly load given model id from envvar (#181) --- changelog.d/181.fix.md | 6 ++++ src/openllm/_configuration.py | 6 ++-- src/openllm/_llm.py | 47 ++++++++++++++------------- src/openllm/_strategies.py | 6 ++-- src/openllm/bundle/_package.py | 12 +++---- src/openllm/bundle/oci/Dockerfile | 19 ++++++++++- src/openllm/cli/_factory.py | 14 ++++---- src/openllm/cli/entrypoint.py | 26 +++++++-------- src/openllm/cli/termui.py | 2 +- src/openllm/models/auto/factory.py | 2 +- src/openllm/utils/__init__.py | 2 +- src/openllm/utils/import_utils.py | 52 ++++++++++++++++++++---------- 12 files changed, 119 insertions(+), 75 deletions(-) create mode 100644 changelog.d/181.fix.md diff --git a/changelog.d/181.fix.md b/changelog.d/181.fix.md new file mode 100644 index 00000000..6dc0069e --- /dev/null +++ b/changelog.d/181.fix.md @@ -0,0 +1,6 @@ +Fixes a bug with `EnvVarMixin` where it did not respect the environment variables for specific fields + +This inherently led to confusing behaviour with `--model-id`.
This has now been addressed on main + +The base docker image will now also include an installation of xformers from source, locked at a given hash, since the latest release of xformers +is too old and would fail with vLLM when running within k8s diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index 5a575717..16e6e2f6 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -622,7 +622,7 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]) -> _ _final_value_dct["env"] = env # bettertransformer support - if _settings_attr["bettertransformer"] is None: _final_value_dct["bettertransformer"] = str(env.bettertransformer_value).upper() in ENV_VARS_TRUE_VALUES + if _settings_attr["bettertransformer"] is None: _final_value_dct["bettertransformer"] = str(env["bettertransformer_value"]).upper() in ENV_VARS_TRUE_VALUES # if requires_gpu is True, then disable BetterTransformer for quantization. if _settings_attr["requires_gpu"]: _final_value_dct["bettertransformer"] = False _final_value_dct["service_name"] = f"generated_{model_name}_service.py" @@ -1485,4 +1485,6 @@ def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig: bentoml_cattr.register_structure_hook_func(lambda cls: lenient_issubclass(cls, LLMConfig), structure_llm_config) -openllm_home = os.path.expanduser(os.getenv("OPENLLM_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")), "openllm"))) +openllm_home = os.path.expanduser(os.environ.get("OPENLLM_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")), "openllm"))) + +__all__ = ["LLMConfig", "field_env_key"] diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 173bd908..786d91e0 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -419,6 +419,11 @@ def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]] else: func_call =
f"_impl_{cls.__name__}_{func}={cached_func_name} if {cached_func_name} is not _cached_LLMInterface_get('{func}') else __serialisation_{func}" lines.extend([f"{cached_func_name}=cls.{func}", func_call, _setattr_class(func, f"{impl_name}(_impl_{cls.__name__}_{func})"),]) + # assign vllm specific implementation + if cls.__llm_implementation__ == "vllm": + globs.update({"_vllm_generate": vllm_generate, "_vllm_postprocess_generate": vllm_postprocess_generate}) + lines.extend([_setattr_class(it, f"_vllm_{it}") for it in {"generate", "postprocess_generate"}]) + # cached attribute initialisation interface_anns = codegen.get_annotations(LLMInterface) for v in {"bentomodel", "model", "tokenizer", "adapter_map"}: @@ -432,6 +437,17 @@ def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]] anns[key] = interface_anns.get(key) return codegen.generate_function(cls, "__assign_llm_attr", lines, args=("cls", *args), globs=globs, annotations=anns) +def vllm_postprocess_generate(self: LLM["vllm.LLMEngine", T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str: + return generation_result[0]["outputs"][0]["text"] + +def vllm_generate(self: LLM["vllm.LLMEngine", T], prompt: str, **attrs: t.Any) -> list[dict[str, t.Any]]: + outputs: list[vllm.RequestOutput] = [] + # TODO: support prompt_token_ids + self.model.add_request(request_id=str(uuid.uuid4().hex), prompt=prompt, sampling_params=self.config.model_construct_env(**attrs).to_sampling_config()) + while self.model.has_unfinished_requests(): + outputs.extend([r for r in self.model.step() if r.finished]) + return [unmarshal_vllm_outputs(i) for i in outputs] + _AdaptersTuple: type[AdaptersTuple] = codegen.make_attr_tuple_class("AdaptersTuple", ["adapter_id", "name", "config"]) @attr.define(slots=True, repr=False, init=False) @@ -470,19 +486,6 @@ class LLM(LLMInterface[M, T], ReprMixin): _make_assignment_script(cls)(cls) if "tokenizer_id" not in cd and cls.__llm_implementation__ == "vllm": 
cls.tokenizer_id = _DEFAULT_TOKENIZER - if implementation == "vllm": - def vllm_postprocess_generate(self: LLM["vllm.LLMEngine", T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str: return generation_result[0]["outputs"][0]["text"] - def vllm_generate(self: LLM["vllm.LLMEngine", T], prompt: str, **attrs: t.Any) -> list[dict[str, t.Any]]: - outputs: list[vllm.RequestOutput] = [] - # TODO: support prompt_token_ids - self.model.add_request(request_id=str(uuid.uuid4().hex), prompt=prompt, sampling_params=self.config.model_construct_env(**attrs).to_sampling_config()) - while self.model.has_unfinished_requests(): - outputs.extend([r for r in self.model.step() if r.finished]) - return [unmarshal_vllm_outputs(i) for i in outputs] - - _object_setattr(cls, "postprocess_generate", vllm_postprocess_generate) - _object_setattr(cls, "generate", vllm_generate) - # fmt: off @overload def __getitem__(self, item: t.Literal["trust_remote_code"]) -> bool: ... @@ -586,10 +589,10 @@ class LLM(LLMInterface[M, T], ReprMixin): **attrs: The kwargs to be passed to the model. """ cfg_cls = cls.config_class - model_id = first_not_none(model_id, os.getenv(cfg_cls.__openllm_env__["model_id"]), cfg_cls.__openllm_default_id__) + model_id = first_not_none(model_id, os.environ.get(cfg_cls.__openllm_env__["model_id"]), cfg_cls.__openllm_default_id__) if model_id is None: raise RuntimeError("Failed to resolve a valid model_id.") if validate_is_path(model_id): model_id = resolve_filepath(model_id) - quantize = first_not_none(quantize, t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.getenv(cfg_cls.__openllm_env__["quantize"])), default=None) + quantize = first_not_none(quantize, t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.environ.get(cfg_cls.__openllm_env__["quantize"])), default=None) # quantization setup if quantization_config and quantize: raise ValueError("'quantization_config' and 'quantize' are mutually exclusive. 
Either customise your quantization_config or use the 'quantize' argument.") @@ -614,10 +617,9 @@ class LLM(LLMInterface[M, T], ReprMixin): except Exception as err: raise OpenLLMException(f"Failed to generate a valid tag for {cfg_cls.__openllm_start_name__} with 'model_id={model_id}' (lookup to see its traceback):\n{err}") from err return cls( - *args, model_id=model_id, llm_config=llm_config, quantization_config=quantization_config, - bettertransformer=str(first_not_none(bettertransformer, os.getenv(cfg_cls.__openllm_env__["bettertransformer"]), default=None)).upper() in ENV_VARS_TRUE_VALUES, - _runtime=first_not_none(runtime, t.cast(t.Optional[t.Literal["ggml", "transformers"]], os.getenv(cfg_cls.__openllm_env__["runtime"])), default=cfg_cls.__openllm_runtime__), - _adapters_mapping=resolve_peft_config_type(adapter_map) if adapter_map is not None else None, _quantize_method=quantize, _model_version=_tag.version, _tag=_tag, _serialisation_format=serialisation, **attrs + *args, model_id=model_id, llm_config=llm_config, quantization_config=quantization_config, bettertransformer=str(first_not_none(bettertransformer, os.environ.get(cfg_cls.__openllm_env__["bettertransformer"]), default=None)).upper() in ENV_VARS_TRUE_VALUES, + _runtime=first_not_none(runtime, t.cast(t.Optional[t.Literal["ggml", "transformers"]], os.environ.get(cfg_cls.__openllm_env__["runtime"])), default=cfg_cls.__openllm_runtime__), _adapters_mapping=resolve_peft_config_type(adapter_map) + if adapter_map is not None else None, _quantize_method=quantize, _model_version=_tag.version, _tag=_tag, _serialisation_format=serialisation, **attrs ) @classmethod @@ -640,7 +642,7 @@ class LLM(LLMInterface[M, T], ReprMixin): ``str``: Generated tag format that can be parsed by ``bentoml.Tag`` """ # specific branch for running in docker, this is very hacky, needs change upstream - if in_docker() and os.getenv("BENTO_PATH") is not None: return ":".join(fs.path.parts(model_id)[-2:]) + if in_docker() and 
os.environ.get("BENTO_PATH") is not None: return ":".join(fs.path.parts(model_id)[-2:]) model_name = normalise_model_name(model_id) model_id, *maybe_revision = model_id.rsplit(":") @@ -649,7 +651,7 @@ class LLM(LLMInterface[M, T], ReprMixin): return f"{cls.__llm_implementation__}-{model_name}:{maybe_revision[0]}" tag_name = f"{cls.__llm_implementation__}-{model_name}" - if os.getenv("OPENLLM_USE_LOCAL_LATEST", str(False)).upper() in ENV_VARS_TRUE_VALUES: return bentoml_cattr.unstructure(bentoml.models.get(f"{tag_name}{':'+model_version if model_version is not None else ''}").tag) + if os.environ.get("OPENLLM_USE_LOCAL_LATEST", str(False)).upper() in ENV_VARS_TRUE_VALUES: return bentoml_cattr.unstructure(bentoml.models.get(f"{tag_name}{':'+model_version if model_version is not None else ''}").tag) if validate_is_path(model_id): model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id)) else: _config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=cls.config_class.__openllm_trust_remote_code__, revision=first_not_none(model_version, default="main")) @@ -1015,8 +1017,7 @@ def Runner(model_name: str, ensure_available: bool | None = None, init_local: bo behaviour """ if llm_config is not None: - attrs.update({"model_id": llm_config["env"]["model_id_value"], "bettertransformer": llm_config["env"]["bettertransformer_value"], "quantize": llm_config["env"]["quantize_value"], "runtime": llm_config["env"]["runtime_value"], - "serialisation": first_not_none(os.getenv("OPENLLM_SERIALIZATION"), attrs.get("serialisation"), default="safetensors")}) + attrs.update({"model_id": llm_config["env"]["model_id_value"], "bettertransformer": llm_config["env"]["bettertransformer_value"], "quantize": llm_config["env"]["quantize_value"], "runtime": llm_config["env"]["runtime_value"], "serialisation": first_not_none(os.environ.get("OPENLLM_SERIALIZATION"), attrs.get("serialisation"), 
default="safetensors")}) default_implementation = llm_config.default_implementation() if llm_config is not None else "pt" implementation = first_not_none(implementation, default=EnvVarMixin(model_name, default_implementation)["framework_value"]) diff --git a/src/openllm/_strategies.py b/src/openllm/_strategies.py index 97a08b18..ff09841c 100644 --- a/src/openllm/_strategies.py +++ b/src/openllm/_strategies.py @@ -86,7 +86,7 @@ def _parse_visible_devices(default_var: str = ..., *, respect_env: t.Literal[Fal def _parse_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None: """CUDA_VISIBLE_DEVICES aware with default var for parsing spec.""" if respect_env: - spec = os.getenv("CUDA_VISIBLE_DEVICES", default_var) + spec = os.environ.get("CUDA_VISIBLE_DEVICES", default_var) if not spec: return None else: if default_var is None: raise ValueError("spec is required to be not None when parsing spec.") @@ -370,11 +370,11 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin): if runnable_class.SUPPORTS_CPU_MULTI_THREADING: thread_count = math.ceil(cpus) for thread_env in THREAD_ENVS: - environ[thread_env] = os.getenv(thread_env, str(thread_count)) + environ[thread_env] = os.environ.get(thread_env, str(thread_count)) logger.debug("Environ for worker %s: %s", worker_index, environ) return environ for thread_env in THREAD_ENVS: - environ[thread_env] = os.getenv(thread_env, "1") + environ[thread_env] = os.environ.get(thread_env, "1") return environ return environ diff --git a/src/openllm/bundle/_package.py b/src/openllm/bundle/_package.py index 97af3935..dbf88189 100644 --- a/src/openllm/bundle/_package.py +++ b/src/openllm/bundle/_package.py @@ -135,17 +135,17 @@ def construct_docker_options( _bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts) env: EnvVarMixin = llm.config["env"] env_dict = { - env.framework: env.framework_value, env.config: 
f"'{llm.config.model_dump_json().decode()}'", "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format, - "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", - env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}"} + env.framework: env["framework_value"], env.config: f"'{llm.config.model_dump_json().decode()}'", "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format, "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", + env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}" + } if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1") # We need to handle None separately here, as env from subprocess doesn't accept None value. 
_env = EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime) - if _env.bettertransformer_value is not None: env_dict[_env.bettertransformer] = str(_env.bettertransformer_value) - if _env.quantize_value is not None: env_dict[_env.quantize] = _env.quantize_value - env_dict[_env.runtime] = _env.runtime_value + env_dict[_env.bettertransformer] = str(_env["bettertransformer_value"]) + if _env["quantize_value"] is not None: env_dict[_env.quantize] = t.cast(str, _env["quantize_value"]) + env_dict[_env.runtime] = _env["runtime_value"] return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}", env=env_dict, dockerfile_template=dockerfile_template) @inject diff --git a/src/openllm/bundle/oci/Dockerfile b/src/openllm/bundle/oci/Dockerfile index f9560fe0..8b9eb4ad 100644 --- a/src/openllm/bundle/oci/Dockerfile +++ b/src/openllm/bundle/oci/Dockerfile @@ -118,6 +118,20 @@ git fetch && git checkout ${COMMIT_HASH} TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6+PTX;8.9;9.0" python setup.py build EOT +# NOTE: Build xformers from source since the latest xformers are too old +FROM kernel-builder as xformers-builder + +ENV COMMIT_HASH 2d3a2217c263419243b70c53f725213d1c386b0f +ARG COMMIT_HASH=${COMMIT_HASH} + +WORKDIR /usr/src + +RUN < LLMConfig | subprocess.Popen[bytes]: fast = str(fast).upper() in ENV_VARS_TRUE_VALUES - if serialisation_format == "safetensors" and quantize is not None and os.getenv("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in ENV_VARS_TRUE_VALUES: + if serialisation_format == "safetensors" and quantize is not None and os.environ.get("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in ENV_VARS_TRUE_VALUES: termui.echo(f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. 
To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.", fg="yellow") adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None) config, server_attrs = llm_config.model_validate_click(**attrs) @@ -173,14 +173,14 @@ Available official model_id(s): [default: {llm_config['default_id']}] start_env = parse_config_options(config, server_timeout, wpr, device, start_env) if fast: termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg="yellow") - start_env.update({"OPENLLM_MODEL": model, "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_HOME": os.getenv("BENTOML_HOME", BentoMLContainer.bentoml_home.get()), "OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(), "OPENLLM_SERIALIZATION": serialisation_format, env.runtime: env.runtime_value, env.framework: env.framework_value}) - if env.model_id_value: start_env[env.model_id] = str(env.model_id_value) + start_env.update({"OPENLLM_MODEL": model, "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_HOME": os.environ.get("BENTOML_HOME", BentoMLContainer.bentoml_home.get()), "OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(), "OPENLLM_SERIALIZATION": serialisation_format, env.runtime: env["runtime_value"], env.framework: env["framework_value"]}) + start_env[env.model_id] = str(env["model_id_value"]) # NOTE: quantize and bettertransformer value is already assigned within env - if bettertransformer is not None: start_env[env.bettertransformer] = str(env.bettertransformer_value) - if quantize is not None: start_env[env.quantize] = str(env.quantize_value) + if bettertransformer is not None: start_env[env.bettertransformer] = str(env["bettertransformer_value"]) + if quantize is not None: start_env[env.quantize] = str(t.cast(str, env["quantize_value"])) - llm = 
infer_auto_class(env.framework_value).for_model(model, model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format) - start_env.update({env.config: llm.config.model_dump_json().decode(), env.model_id: llm.model_id}) + llm = infer_auto_class(env["framework_value"]).for_model(model, model_id=start_env[env.model_id], model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format) + start_env.update({env.config: llm.config.model_dump_json().decode()}) server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs) analytics.track_start_init(llm.config) diff --git a/src/openllm/cli/entrypoint.py b/src/openllm/cli/entrypoint.py index 63c7730a..daaffbdc 100644 --- a/src/openllm/cli/entrypoint.py +++ b/src/openllm/cli/entrypoint.py @@ -455,15 +455,15 @@ def import_command(model_name: str, model_id: str | None, converter: str | None, """ llm_config = AutoConfig.for_model(model_name) env = EnvVarMixin(model_name, llm_config.default_implementation(), model_id=model_id, runtime=runtime, quantize=quantize) - impl: LiteralRuntime = first_not_none(implementation, default=env.framework_value) - llm = infer_auto_class(impl).for_model(model_name, llm_config=llm_config, model_version=model_version, ensure_available=False, serialisation=serialisation_format) + impl: LiteralRuntime = first_not_none(implementation, default=env["framework_value"]) + llm = infer_auto_class(impl).for_model(model_name, model_id=env["model_id_value"], llm_config=llm_config, model_version=model_version, ensure_available=False, serialisation=serialisation_format) _previously_saved = False try: _ref = serialisation.get(llm) _previously_saved = True except bentoml.exceptions.NotFound: if not machine and output == "pretty": - msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not 
None else ''} does not exists in local store. Saving to BENTOML_HOME{' (path=' + os.getenv('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..." + msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not None else ''} does not exists in local store. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..." termui.echo(msg, fg="yellow", nl=True) _ref = serialisation.get(llm, auto_import=True) if impl == "pt" and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache() @@ -518,16 +518,16 @@ def _start( framework: The framework to use for this LLM. By default, this is set to ``pt``. additional_args: Additional arguments to pass to ``openllm start``. """ - fast = os.getenv("OPENLLM_FAST", str(fast)).upper() in ENV_VARS_TRUE_VALUES + fast = os.environ.get("OPENLLM_FAST", str(fast)).upper() in ENV_VARS_TRUE_VALUES llm_config = AutoConfig.for_model(model_name) _ModelEnv = EnvVarMixin(model_name, first_not_none(framework, default=llm_config.default_implementation()), model_id=model_id, bettertransformer=bettertransformer, quantize=quantize, runtime=runtime) - os.environ[_ModelEnv.framework] = _ModelEnv.framework_value + os.environ[_ModelEnv.framework] = _ModelEnv["framework_value"] args: ListStr = ["--runtime", runtime] if model_id: args.extend(["--model-id", model_id]) if timeout: args.extend(["--server-timeout", str(timeout)]) if workers_per_resource: args.extend(["--workers-per-resource", str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource]) - if device and not os.getenv("CUDA_VISIBLE_DEVICES"): args.extend(["--device", ",".join(device)]) + if device and not os.environ.get("CUDA_VISIBLE_DEVICES"): args.extend(["--device", ",".join(device)]) if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.") if 
quantize: args.extend(["--quantize", str(quantize)]) elif bettertransformer: args.append("--bettertransformer") @@ -722,15 +722,15 @@ def build_command( # NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError # during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path try: - os.environ.update({"OPENLLM_MODEL": inflection.underscore(model_name), env.runtime: str(env.runtime_value), "OPENLLM_SERIALIZATION": serialisation_format}) - os.environ[env.model_id] = str(env.model_id_value) - os.environ[env.quantize] = str(env.quantize_value) - os.environ[env.bettertransformer] = str(env.bettertransformer_value) + os.environ.update({"OPENLLM_MODEL": inflection.underscore(model_name), env.runtime: str(env["runtime_value"]), "OPENLLM_SERIALIZATION": serialisation_format}) + os.environ[env.model_id] = str(env["model_id_value"]) + os.environ[env.quantize] = str(env["quantize_value"]) + os.environ[env.bettertransformer] = str(env["bettertransformer_value"]) - llm = infer_auto_class(env.framework_value).for_model(model_name, llm_config=llm_config, ensure_available=not fast, model_version=model_version, serialisation=serialisation_format, **attrs) + llm = infer_auto_class(env["framework_value"]).for_model(model_name, model_id=env["model_id_value"], llm_config=llm_config, ensure_available=not fast, model_version=model_version, serialisation=serialisation_format, **attrs) labels = dict(llm.identifying_params) - labels.update({"_type": llm.llm_type, "_framework": env.framework_value}) + labels.update({"_type": llm.llm_type, "_framework": env["framework_value"]}) workers_per_resource = first_not_none(workers_per_resource, default=llm_config["workers_per_resource"]) with fs.open_fs(f"temp://llm_{llm_config['model_name']}") as llm_fs: @@ -796,7 +796,7 @@ def build_command( if push: BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, 
force=True) elif containerize: - backend = t.cast("DefaultBuilder", os.getenv("BENTOML_CONTAINERIZE_BACKEND", "docker")) + backend = t.cast("DefaultBuilder", os.environ.get("BENTOML_CONTAINERIZE_BACKEND", "docker")) try: bentoml.container.health(backend) except subprocess.CalledProcessError: diff --git a/src/openllm/cli/termui.py b/src/openllm/cli/termui.py index 9db513ed..24a6f434 100644 --- a/src/openllm/cli/termui.py +++ b/src/openllm/cli/termui.py @@ -26,7 +26,7 @@ def echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.An attrs["fg"], call = fg if not get_debug_mode() else None, click.echo if not _with_style else click.secho if not get_quiet_mode(): call(text, **attrs) -COLUMNS = int(os.getenv("COLUMNS", str(120))) +COLUMNS: int = int(os.environ.get("COLUMNS", str(120))) CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"], "max_content_width": COLUMNS, "token_normalize_func": inflection.underscore} diff --git a/src/openllm/models/auto/factory.py b/src/openllm/models/auto/factory.py index 9603aa54..8dfa4409 100644 --- a/src/openllm/models/auto/factory.py +++ b/src/openllm/models/auto/factory.py @@ -48,7 +48,7 @@ class BaseAutoLLMClass: >>> llm = openllm.AutoLLM.for_model("flan-t5") ``` """ - llm = cls.infer_class_from_name(model).from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs) + llm = cls.infer_class_from_name(model).from_pretrained(model_id=model_id, model_version=model_version, llm_config=llm_config, **attrs) if ensure_available: llm.ensure_model_id_exists() return llm diff --git a/src/openllm/utils/__init__.py b/src/openllm/utils/__init__.py index 0294e04f..426e76cd 100644 --- a/src/openllm/utils/__init__.py +++ b/src/openllm/utils/__init__.py @@ -124,7 +124,7 @@ def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key]))) # Special debug flag controled via 
OPENLLMDEVDEBUG -DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.getenv(DEV_DEBUG_VAR))) +DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get(DEV_DEBUG_VAR))) # MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins MYPY = False SHOW_CODEGEN = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3 diff --git a/src/openllm/utils/import_utils.py b/src/openllm/utils/import_utils.py index ccd91e90..86bdea8f 100644 --- a/src/openllm/utils/import_utils.py +++ b/src/openllm/utils/import_utils.py @@ -373,10 +373,11 @@ class EnvVarMixin(ReprMixin): def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ... # fmt: on def __getitem__(self, item: str | t.Any) -> t.Any: - if hasattr(self, item): return getattr(self, item) + if item.endswith("_value") and hasattr(self, f"_{item}"): return object.__getattribute__(self, f"_{item}")() + elif hasattr(self, item): return getattr(self, item) raise KeyError(f"Key {item} not found in {self}") - def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None, - runtime: t.Literal["ggml", "transformers"] = "transformers") -> None: + + def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None: """EnvVarMixin is a mixin class that returns the value extracted from environment variables.""" from .._configuration import field_env_key self.model_name = inflection.underscore(model_name) @@ -385,20 +386,37 @@ class EnvVarMixin(ReprMixin): self._bettertransformer = bettertransformer self._quantize = quantize self._runtime = runtime - for att in {"config", "model_id", "quantize", "framework", "bettertransformer", 
"runtime"}: setattr(self, att, field_env_key(self.model_name, att.upper())) + for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}: + setattr(self, att, field_env_key(self.model_name, att.upper())) + @property - def __repr_keys__(self) -> set[str]: return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"} + def __repr_keys__(self) -> set[str]: + return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"} + + def _quantize_value(self) -> t.Literal["int8", "int4", "gptq"] | None: + from . import first_not_none + return t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], first_not_none(os.environ.get(self["quantize"]), default=self._quantize)) + + def _framework_value(self) -> LiteralRuntime: + from . import first_not_none + return t.cast(t.Literal["pt", "tf", "flax", "vllm"], first_not_none(os.environ.get(self["framework"]), default=self._implementation)) + + def _bettertransformer_value(self) -> bool: + from . import first_not_none + return t.cast(bool, first_not_none(os.environ.get(self["bettertransformer"], str(False)).upper() in ENV_VARS_TRUE_VALUES, default=self._bettertransformer)) + + def _model_id_value(self) -> str | None: + from . import first_not_none + return first_not_none(os.environ.get(self["model_id"]), default=self._model_id) + + def _runtime_value(self) -> t.Literal["ggml", "transformers"]: + from . 
import first_not_none + return t.cast(t.Literal["ggml", "transformers"], first_not_none(os.environ.get(self["runtime"]), default=self._runtime)) + @property - def quantize_value(self) -> t.Literal["int8", "int4", "gptq"] | None: return t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.getenv(self["quantize"], self._quantize)) + def start_docstring(self) -> str: + return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING") + @property - def framework_value(self) -> LiteralRuntime: return t.cast(t.Literal["pt", "tf", "flax", "vllm"], os.getenv(self["framework"], self._implementation)) - @property - def bettertransformer_value(self) -> bool: return os.getenv(self["bettertransformer"], str(self._bettertransformer)).upper() in ENV_VARS_TRUE_VALUES - @property - def model_id_value(self) -> str | None: return os.getenv(self["model_id"], self._model_id) - @property - def runtime_value(self) -> t.Literal["ggml", "transformers"]: return t.cast(t.Literal["ggml", "transformers"], os.getenv(self["runtime"], self._runtime)) - @property - def start_docstring(self) -> str: return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING") - @property - def module(self) -> _AnnotatedLazyLoader[t.LiteralString]: return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}") + def module(self) -> _AnnotatedLazyLoader[t.LiteralString]: + return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")