From d358e6853941d480554b8e40b22ebc1a46e0967f Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Mon, 13 Nov 2023 13:48:04 -0500 Subject: [PATCH] fix(torch_dtype): load eagerly (#631) Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- openllm-python/src/openllm/_llm.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 6c152927..952eb9a2 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -249,11 +249,13 @@ class LLM(t.Generic[M, T], ReprMixin): import transformers if not isinstance(self.__llm_torch_dtype__, torch.dtype): - config_dtype = getattr( - transformers.AutoConfig.from_pretrained(self.bentomodel.path, trust_remote_code=self.trust_remote_code), - 'torch_dtype', - None, - ) + try: + hf_config = transformers.AutoConfig.from_pretrained( + self.bentomodel.path, trust_remote_code=self.trust_remote_code + ) + except OpenLLMException: + hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code) + config_dtype = getattr(hf_config, 'torch_dtype', None) if config_dtype is None: config_dtype = torch.float32 if self.__llm_torch_dtype__ == 'auto':