diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py
index 4625d943..de3d6c25 100644
--- a/openllm-python/src/openllm/_llm.py
+++ b/openllm-python/src/openllm/_llm.py
@@ -221,12 +221,8 @@ class LLM(t.Generic[M, T], ReprMixin):
     config_dtype = getattr(hf_config, 'torch_dtype', None)
     if config_dtype is None:
       config_dtype = torch.float32
-    if not torch.cuda.is_available():
-      if self.__llm_torch_dtype__ in {'auto', 'half'} and not get_disable_warnings() and not get_quiet_mode():
-        logger.warning('"auto" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
-      torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
-    elif self.__llm_torch_dtype__ == 'auto':
-      if config_dtype == torch.float32:
+    if self.__llm_torch_dtype__ == 'auto':
+      if config_dtype == torch.float32 and torch.cuda.is_available():
         torch_dtype = torch.float16  # following common practice
       else:
         torch_dtype = config_dtype
diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py
index 91991b92..ea388da1 100644
--- a/openllm-python/src/openllm_cli/entrypoint.py
+++ b/openllm-python/src/openllm_cli/entrypoint.py
@@ -448,6 +448,14 @@ def start_command(
     if not get_debug_mode():
       termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
 
+  import torch
+
+  if not torch.cuda.is_available():
+    if dtype == 'auto':
+      dtype = 'float'
+    elif dtype not in {'float', 'float32'} and not get_disable_warnings() and not get_quiet_mode():
+      termui.warning('"bfloat16" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
+    dtype = 'float'  # we need to cast back to full precision if cuda is not available
   llm = openllm.LLM[t.Any, t.Any](
     model_id=model_id,
     model_version=model_version,
@@ -570,6 +578,14 @@ def start_grpc_command(
     if not get_debug_mode():
       termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
 
+  import torch
+
+  if not torch.cuda.is_available():
+    if dtype == 'auto':
+      dtype = 'float'
+    elif dtype not in {'float', 'float32'} and not get_disable_warnings() and not get_quiet_mode():
+      termui.warning('"bfloat16" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
+    dtype = 'float'  # we need to cast back to full precision if cuda is not available
   llm = openllm.LLM[t.Any, t.Any](
     model_id=model_id,
     model_version=model_version,