From f753662ae6d4bbb4dc3bc3d9af6d1e7c35e3d63b Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Mon, 20 Nov 2023 17:06:25 -0500 Subject: [PATCH] fix(build): only load model when eager is True Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- openllm-python/src/openllm/_llm.py | 16 ++++++++++------ openllm-python/src/openllm_cli/entrypoint.py | 7 +++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 414403b5..a11a6d0c 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -163,6 +163,7 @@ class LLM(t.Generic[M, T], ReprMixin): embedded=False, dtype='auto', low_cpu_mem_usage=True, + _eager=True, **attrs, ): # fmt: off @@ -201,12 +202,15 @@ class LLM(t.Generic[M, T], ReprMixin): llm_trust_remote_code__=trust_remote_code, ) - try: - model = bentoml.models.get(self.tag) - except bentoml.exceptions.NotFound: - model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code) - # resolve the tag - self._tag = model.tag + if _eager: + try: + model = bentoml.models.get(self.tag) + except bentoml.exceptions.NotFound: + model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code) + # resolve the tag + self._tag = model.tag + if not _eager and embedded: + raise RuntimeError("Embedded mode is not supported when '_eager' is False.") if embedded and not get_disable_warnings() and not get_quiet_mode(): logger.warning( 'You are using embedded mode, which means the models will be loaded into memory. This is often not recommended in production and should only be used for local development only.' diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py index 49348038..9a1a6594 100644 --- a/openllm-python/src/openllm_cli/entrypoint.py +++ b/openllm-python/src/openllm_cli/entrypoint.py @@ -1047,10 +1047,17 @@ def build_command( serialisation=first_not_none( serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy' ), + _eager=False, ) if llm.__llm_backend__ not in llm.config['backend']: raise click.ClickException(f"'{backend}' is not supported with {model_id}") backend_warning(llm.__llm_backend__, build=True) + try: + model = bentoml.models.get(llm.tag) + except bentoml.exceptions.NotFound: + model = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code) + llm._tag = model.tag + os.environ.update( **process_environ( llm.config,