Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-02-20 07:33:55 -05:00)
fix(cpu): more verbose definition for dtype casting (#639)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
@@ -258,13 +258,15 @@ class LLM(t.Generic[M, T], ReprMixin):
     config_dtype = getattr(hf_config, 'torch_dtype', None)
     if config_dtype is None:
       config_dtype = torch.float32
-    if self.__llm_torch_dtype__ == 'auto':
+    if not torch.cuda.is_available():
+      if self.__llm_torch_dtype__ in {'auto', 'half'}:
+        logger.warning('"auto" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
+      torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
+    elif self.__llm_torch_dtype__ == 'auto':
       if config_dtype == torch.float32:
         torch_dtype = torch.float16  # following common practice
       else:
         torch_dtype = config_dtype
-      if not torch.cuda.is_available():
-        torch_dtype = torch.float32  # we need to cast back to full precision if cuda is not available
     else:
       if self.__llm_torch_dtype__ not in _torch_dtype_mapping():
         raise ValueError(f"Unknown dtype '{self.__llm_torch_dtype__}'")
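To make the new control flow easier to follow, below is a minimal, self-contained sketch of the dtype-resolution order this patch introduces: check for CUDA first, then handle 'auto', then fall through to an explicit dtype. The names resolve_torch_dtype and _TORCH_DTYPE_MAPPING are hypothetical, used only for illustration; the real code works with self.__llm_torch_dtype__ and _torch_dtype_mapping() inside the LLM class.

# Illustrative sketch only; not part of the OpenLLM API.
from typing import Optional

import torch

_TORCH_DTYPE_MAPPING = {
  'float32': torch.float32,
  'float16': torch.float16,
  'bfloat16': torch.bfloat16,
  'half': torch.float16,
}

def resolve_torch_dtype(requested: str, config_dtype: Optional[torch.dtype], cuda_available: bool) -> torch.dtype:
  if config_dtype is None:
    config_dtype = torch.float32  # HF config declares no torch_dtype, so assume full precision
  if not cuda_available:
    # CPU path: always fall back to float32 ('auto' and 'half' would also log a warning in the patch).
    return torch.float32
  if requested == 'auto':
    # Common practice on GPU: downcast float32 checkpoints to float16,
    # otherwise keep the dtype declared by the checkpoint config.
    return torch.float16 if config_dtype == torch.float32 else config_dtype
  if requested not in _TORCH_DTYPE_MAPPING:
    raise ValueError(f"Unknown dtype '{requested}'")
  return _TORCH_DTYPE_MAPPING[requested]

# Example: on a CUDA-less machine even an explicit 'half' request resolves to float32.
assert resolve_torch_dtype('half', torch.float16, cuda_available=False) == torch.float32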