diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 414403b5..a11a6d0c 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -163,6 +163,7 @@ class LLM(t.Generic[M, T], ReprMixin): embedded=False, dtype='auto', low_cpu_mem_usage=True, + _eager=True, **attrs, ): # fmt: off @@ -201,12 +202,15 @@ class LLM(t.Generic[M, T], ReprMixin): llm_trust_remote_code__=trust_remote_code, ) - try: - model = bentoml.models.get(self.tag) - except bentoml.exceptions.NotFound: - model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code) - # resolve the tag - self._tag = model.tag + if _eager: + try: + model = bentoml.models.get(self.tag) + except bentoml.exceptions.NotFound: + model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code) + # resolve the tag + self._tag = model.tag + if not _eager and embedded: + raise RuntimeError("Embedded mode is not supported when '_eager' is False.") if embedded and not get_disable_warnings() and not get_quiet_mode(): logger.warning( 'You are using embedded mode, which means the models will be loaded into memory. This is often not recommended in production and should only be used for local development only.' diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py index 49348038..9a1a6594 100644 --- a/openllm-python/src/openllm_cli/entrypoint.py +++ b/openllm-python/src/openllm_cli/entrypoint.py @@ -1047,10 +1047,17 @@ def build_command( serialisation=first_not_none( serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy' ), + _eager=False, ) if llm.__llm_backend__ not in llm.config['backend']: raise click.ClickException(f"'{backend}' is not supported with {model_id}") backend_warning(llm.__llm_backend__, build=True) + try: + model = bentoml.models.get(llm.tag) + except bentoml.exceptions.NotFound: + model = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code) + llm._tag = model.tag + os.environ.update( **process_environ( llm.config,