fix(cpu): more verbose definition for dtype casting (#639)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Author: Aaron Pham (committed by GitHub)
Date: 2023-11-13 20:40:50 -05:00
Parent: b20c7d1c1d
Commit: 2d428f12da


@@ -258,13 +258,15 @@ class LLM(t.Generic[M, T], ReprMixin):
     config_dtype = getattr(hf_config, 'torch_dtype', None)
     if config_dtype is None:
       config_dtype = torch.float32
-    if self.__llm_torch_dtype__ == 'auto':
+    if not torch.cuda.is_available():
+      if self.__llm_torch_dtype__ in {'auto', 'half'}:
+        logger.warning('"auto" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
+      torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
+    elif self.__llm_torch_dtype__ == 'auto':
       if config_dtype == torch.float32:
         torch_dtype = torch.float16  # following common practice
       else:
         torch_dtype = config_dtype
-      if not torch.cuda.is_available():
-        torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
     else:
       if self.__llm_torch_dtype__ not in _torch_dtype_mapping():
         raise ValueError(f"Unknown dtype '{self.__llm_torch_dtype__}'")
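
For context, a minimal standalone sketch of the dtype resolution this hunk implements. The function resolve_torch_dtype, its parameters, and the _TORCH_DTYPE_MAPPING dict are hypothetical stand-ins for self.__llm_torch_dtype__, the HF config's torch_dtype, and OpenLLM's internal _torch_dtype_mapping(); the final mapping lookup in the explicit-dtype branch is an assumption, since the hunk only shows the validation step.

import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)

# Hypothetical stand-in for OpenLLM's _torch_dtype_mapping().
_TORCH_DTYPE_MAPPING = {
  'float32': torch.float32,
  'float16': torch.float16,
  'bfloat16': torch.bfloat16,
  'half': torch.float16,
}

def resolve_torch_dtype(requested_dtype: str, config_dtype: Optional[torch.dtype]) -> torch.dtype:
  # Default to float32 when the HF config carries no torch_dtype.
  if config_dtype is None:
    config_dtype = torch.float32
  if not torch.cuda.is_available():
    # CPU path: half precision is not supported, so always fall back to full precision.
    if requested_dtype in {'auto', 'half'}:
      logger.warning('"auto" and "half" are not supported on CPU. Falling back to "float32".')
    return torch.float32
  if requested_dtype == 'auto':
    # On GPU, prefer float16 when the checkpoint was saved in float32 (common practice);
    # otherwise honour the dtype the config declares.
    return torch.float16 if config_dtype == torch.float32 else config_dtype
  if requested_dtype not in _TORCH_DTYPE_MAPPING:
    raise ValueError(f"Unknown dtype '{requested_dtype}'")
  # Assumed: the explicit dtype is resolved through the mapping after validation.
  return _TORCH_DTYPE_MAPPING[requested_dtype]

# Example: resolve_torch_dtype('auto', torch.float32) gives torch.float16 on a CUDA machine
# and torch.float32 on a CPU-only machine.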