fix(torch_dtype): correctly infer based on options (#682)

Users should be able to set the dtype during build, as we it doesn't effect start time

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-17 10:52:05 -05:00
committed by GitHub
parent 7402408c5f
commit 14b3ceb436
2 changed files with 18 additions and 6 deletions

View File

@@ -221,12 +221,8 @@ class LLM(t.Generic[M, T], ReprMixin):
config_dtype = getattr(hf_config, 'torch_dtype', None)
if config_dtype is None:
config_dtype = torch.float32
if not torch.cuda.is_available():
if self.__llm_torch_dtype__ in {'auto', 'half'} and not get_disable_warnings() and not get_quiet_mode():
logger.warning('"auto" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
torch_dtype = torch.float32 # we need to cast back to full precision if cuda is not available
elif self.__llm_torch_dtype__ == 'auto':
if config_dtype == torch.float32:
if self.__llm_torch_dtype__ == 'auto':
if config_dtype == torch.float32 and torch.cuda.is_available():
torch_dtype = torch.float16 # following common practice
else:
torch_dtype = config_dtype