fix(cli): set default dtype to auto infer (#642)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-13 23:05:27 -05:00
committed by GitHub
parent 99a5d26527
commit b30a412398

View File

@@ -304,9 +304,9 @@ def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--dtype',
type=click.Choice(['float16', 'float32', 'bfloat16']),
type=click.Choice(['float16', 'float32', 'bfloat16', 'auto']),
envvar='TORCH_DTYPE',
default='float16',
default='auto',
help='Optional dtype for casting tensors for running inference.',
**attrs,
)(f)