diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index 6c152927..952eb9a2 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -249,11 +249,13 @@ class LLM(t.Generic[M, T], ReprMixin):
     import transformers
 
     if not isinstance(self.__llm_torch_dtype__, torch.dtype):
-      config_dtype = getattr(
-        transformers.AutoConfig.from_pretrained(self.bentomodel.path, trust_remote_code=self.trust_remote_code),
-        'torch_dtype',
-        None,
-      )
+      try:
+        hf_config = transformers.AutoConfig.from_pretrained(
+          self.bentomodel.path, trust_remote_code=self.trust_remote_code
+        )
+      except OpenLLMException:
+        hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code)
+      config_dtype = getattr(hf_config, 'torch_dtype', None)
       if config_dtype is None:
         config_dtype = torch.float32
       if self.__llm_torch_dtype__ == 'auto':
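
For readers skimming the hunk, a minimal standalone sketch of the fallback pattern it introduces: try the config bundled with the local store snapshot first, then fall back to resolving it by hub model id. The function name `resolve_torch_dtype` and its parameters are illustrative, not the PR's actual attributes, and `OSError` (what `transformers` raises for a missing path or repo) stands in for the `OpenLLMException` the PR catches.

```python
import torch
import transformers


def resolve_torch_dtype(local_path: str, model_id: str, trust_remote_code: bool = False) -> torch.dtype:
  # Prefer the config shipped with the local model snapshot; if that lookup
  # fails (the PR catches OpenLLMException, approximated here with OSError),
  # fall back to fetching the config by hub model id.
  try:
    hf_config = transformers.AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
  except OSError:
    hf_config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
  # Configs may omit torch_dtype entirely; default to float32, as the hunk does.
  config_dtype = getattr(hf_config, 'torch_dtype', None)
  return config_dtype if config_dtype is not None else torch.float32
```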