From 2d428f12dadf7f72f6cae108b2748675a20aac17 Mon Sep 17 00:00:00 2001
From: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Date: Mon, 13 Nov 2023 20:40:50 -0500
Subject: [PATCH] fix(cpu): more verbose definition for dtype casting (#639)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
---
 openllm-python/src/openllm/_llm.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index 5993ba3a..7c10da7f 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -258,13 +258,15 @@ class LLM(t.Generic[M, T], ReprMixin):
     config_dtype = getattr(hf_config, 'torch_dtype', None)
     if config_dtype is None:
       config_dtype = torch.float32
-    if self.__llm_torch_dtype__ == 'auto':
+    if not torch.cuda.is_available():
+      if self.__llm_torch_dtype__ in {'auto', 'half'}:
+        logger.warning('"auto" and "half" are not supported on CPU. OpenLLM will fall back to "float32".')
+      torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
+    elif self.__llm_torch_dtype__ == 'auto':
       if config_dtype == torch.float32:
         torch_dtype = torch.float16  # following common practice
       else:
         torch_dtype = config_dtype
-      if not torch.cuda.is_available():
-        torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
     else:
       if self.__llm_torch_dtype__ not in _torch_dtype_mapping():
         raise ValueError(f"Unknown dtype '{self.__llm_torch_dtype__}'")
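
For readers of the diff, below is a minimal standalone sketch of the dtype-resolution order this patch establishes: on CPU everything is forced back to float32 (with a warning when 'auto' or 'half' was requested), on GPU 'auto' promotes float32 checkpoints to float16, and explicit names are validated against the dtype mapping. The function name resolve_torch_dtype and the DTYPE_MAPPING table are hypothetical stand-ins for the instance attribute __llm_torch_dtype__ and the _torch_dtype_mapping() helper in _llm.py; they are not part of the OpenLLM API.

# Illustrative sketch only: resolve_torch_dtype and DTYPE_MAPPING are hypothetical
# names standing in for the __llm_torch_dtype__ handling and _torch_dtype_mapping()
# shown in the patch; they are not OpenLLM APIs.
import logging
import typing as t

import torch

logger = logging.getLogger(__name__)

# Assumed string-to-dtype table, a stand-in for _torch_dtype_mapping().
DTYPE_MAPPING: t.Dict[str, torch.dtype] = {
  'half': torch.float16,
  'float16': torch.float16,
  'bfloat16': torch.bfloat16,
  'float32': torch.float32,
}


def resolve_torch_dtype(requested: str, config_dtype: t.Optional[torch.dtype]) -> torch.dtype:
  if config_dtype is None:
    config_dtype = torch.float32
  if not torch.cuda.is_available():
    # Half precision is generally unusable on CPU, so everything is forced back to float32.
    if requested in {'auto', 'half'}:
      logger.warning('"auto" and "half" are not supported on CPU. Falling back to "float32".')
    return torch.float32
  if requested == 'auto':
    # Common practice: promote float32 checkpoints to float16 when running on GPU.
    return torch.float16 if config_dtype == torch.float32 else config_dtype
  if requested not in DTYPE_MAPPING:
    raise ValueError(f"Unknown dtype '{requested}'")
  return DTYPE_MAPPING[requested]


# Example: on a CPU-only machine this logs the warning and returns torch.float32;
# on a CUDA machine it returns torch.float16 because the config dtype defaults to float32.
print(resolve_torch_dtype('auto', None))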