diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index 5993ba3a..7c10da7f 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -258,13 +258,15 @@ class LLM(t.Generic[M, T], ReprMixin):
     config_dtype = getattr(hf_config, 'torch_dtype', None)
     if config_dtype is None:
       config_dtype = torch.float32
-    if self.__llm_torch_dtype__ == 'auto':
+    if not torch.cuda.is_available():
+      if self.__llm_torch_dtype__ in {'auto', 'half'}:
+        logger.warning('"auto" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
+      torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
+    elif self.__llm_torch_dtype__ == 'auto':
       if config_dtype == torch.float32:
         torch_dtype = torch.float16  # following common practice
       else:
         torch_dtype = config_dtype
-      if not torch.cuda.is_available():
-        torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
     else:
       if self.__llm_torch_dtype__ not in _torch_dtype_mapping():
         raise ValueError(f"Unknown dtype '{self.__llm_torch_dtype__}'")