fix(torch_dtype): correctly infer based on options (#682)

Users should be able to set the dtype during build, as it doesn't affect start time

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-17 10:52:05 -05:00
committed by GitHub
parent 7402408c5f
commit 14b3ceb436
2 changed files with 18 additions and 6 deletions

View File

@@ -221,12 +221,8 @@ class LLM(t.Generic[M, T], ReprMixin):
config_dtype = getattr(hf_config, 'torch_dtype', None)
if config_dtype is None:
config_dtype = torch.float32
if not torch.cuda.is_available():
if self.__llm_torch_dtype__ in {'auto', 'half'} and not get_disable_warnings() and not get_quiet_mode():
logger.warning('"auto" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
torch_dtype = torch.float32 # we need to cast back to full precision if cuda is not available
elif self.__llm_torch_dtype__ == 'auto':
if config_dtype == torch.float32:
if self.__llm_torch_dtype__ == 'auto':
if config_dtype == torch.float32 and torch.cuda.is_available():
torch_dtype = torch.float16 # following common practice
else:
torch_dtype = config_dtype

View File

@@ -448,6 +448,14 @@ def start_command(
if not get_debug_mode():
termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
import torch
if not torch.cuda.is_available():
if dtype == 'auto':
dtype = 'float'
elif dtype not in {'float', 'float32'} and not get_disable_warnings() and not get_quiet_mode():
termui.warning('"bfloat16" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
dtype = 'float' # we need to cast back to full precision if cuda is not available
llm = openllm.LLM[t.Any, t.Any](
model_id=model_id,
model_version=model_version,
@@ -570,6 +578,14 @@ def start_grpc_command(
if not get_debug_mode():
termui.info("To disable these warnings, set 'OPENLLM_DISABLE_WARNING=True'")
import torch
if not torch.cuda.is_available():
if dtype == 'auto':
dtype = 'float'
elif dtype not in {'float', 'float32'} and not get_disable_warnings() and not get_quiet_mode():
termui.warning('"bfloat16" and "half" are not supported on CPU. OpenLLM will default fallback to "float32".')
dtype = 'float' # we need to cast back to full precision if cuda is not available
llm = openllm.LLM[t.Any, t.Any](
model_id=model_id,
model_version=model_version,