From 099c0dc31b5a2d08c553dc0bad30d61e91269c46 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Mon, 13 Nov 2023 05:25:50 -0500 Subject: [PATCH] feat(cli): `--dtype` arguments (#627) Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- openllm-python/src/openllm/_llm.py | 3 ++- openllm-python/src/openllm/bundle/_package.py | 1 + openllm-python/src/openllm_cli/_factory.py | 12 ++++++++++++ openllm-python/src/openllm_cli/entrypoint.py | 17 ++++++++++++----- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 887ed3f1..6c152927 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -188,6 +188,7 @@ class LLM(t.Generic[M, T], ReprMixin): backend = first_not_none( backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt' ) + torch_dtype = first_not_none(os.getenv('TORCH_DTYPE'), torch_dtype, default='auto') quantize = first_not_none(quantize, os.getenv('OPENLLM_QUANTIZE'), default=None) # elif quantization_config is None and quantize is not None: # quantization_config, attrs = infer_quantisation_config(self, quantize, **attrs) @@ -254,7 +255,7 @@ class LLM(t.Generic[M, T], ReprMixin): None, ) if config_dtype is None: - config_type = torch.float32 + config_dtype = torch.float32 if self.__llm_torch_dtype__ == 'auto': if config_dtype == torch.float32: torch_dtype = torch.float16 # following common practice diff --git a/openllm-python/src/openllm/bundle/_package.py b/openllm-python/src/openllm/bundle/_package.py index 4ebd3df9..f689fbcb 100644 --- a/openllm-python/src/openllm/bundle/_package.py +++ b/openllm-python/src/openllm/bundle/_package.py @@ -127,6 +127,7 @@ def construct_docker_options( environ = parse_config_options(llm.config, llm.config['timeout'], 1.0, None, True, os.environ.copy()) env_dict = { + 'TORCH_DTYPE': str(llm._torch_dtype).split('.')[-1], 'OPENLLM_BACKEND': llm.__llm_backend__, 'OPENLLM_CONFIG': f"'{llm.config.model_dump_json(flatten=True).decode()}'", 'OPENLLM_SERIALIZATION': serialisation, diff --git a/openllm-python/src/openllm_cli/_factory.py b/openllm-python/src/openllm_cli/_factory.py index 03051dfd..40e06bf7 100644 --- a/openllm-python/src/openllm_cli/_factory.py +++ b/openllm-python/src/openllm_cli/_factory.py @@ -140,6 +140,7 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC _OpenLLM_GenericInternalConfig().to_click_options, _http_server_args if not serve_grpc else _grpc_server_args, cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'), + dtype_option(factory=cog.optgroup), model_version_option(factory=cog.optgroup), system_message_option(factory=cog.optgroup), prompt_template_file_option(factory=cog.optgroup), @@ -300,6 +301,17 @@ def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[ return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f) +def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_option( + '--dtype', + type=click.Choice(['float16', 'float32', 'bfloat16']), + envvar='TORCH_DTYPE', + default='float16', + help='Optional dtype for casting tensors for running inference.', + **attrs, + )(f) + + def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( '--model-id', diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py index f4f738d0..59369822 100644 --- a/openllm-python/src/openllm_cli/entrypoint.py +++ b/openllm-python/src/openllm_cli/entrypoint.py @@ -56,6 +56,7 @@ from openllm_core._typing_compat import ( Concatenate, DictStrAny, LiteralBackend, + LiteralDtype, LiteralQuantise, LiteralSerialisation, LiteralString, @@ -92,6 +93,7 @@ from ._factory import ( _AnyCallable, backend_option, container_registry_option, + dtype_option, machine_option, model_name_argument, model_version_option, @@ -403,6 +405,7 @@ def start_command( cors: bool, adapter_id: str | None, return_process: bool, + dtype: LiteralDtype, deprecated_model_id: str | None, **attrs: t.Any, ) -> LLMConfig | subprocess.Popen[bytes]: @@ -452,6 +455,7 @@ def start_command( adapter_map=adapter_map, quantize=quantize, serialisation=serialisation, + torch_dtype=dtype, trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'), ) backend_warning(llm.__llm_backend__) @@ -522,6 +526,7 @@ def start_grpc_command( backend: LiteralBackend | None, serialisation: LiteralSerialisation | None, cors: bool, + dtype: LiteralDtype, adapter_id: str | None, return_process: bool, deprecated_model_id: str | None, @@ -573,6 +578,7 @@ def start_grpc_command( adapter_map=adapter_map, quantize=quantize, serialisation=serialisation, + torch_dtype=dtype, trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'), ) backend_warning(llm.__llm_backend__) @@ -872,6 +878,7 @@ class BuildBentoOutput(t.TypedDict): metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', help='Deprecated. Use positional argument instead.', ) +@dtype_option @backend_option @system_message_option @prompt_template_file_option @@ -943,6 +950,7 @@ def build_command( overwrite: bool, quantize: LiteralQuantise | None, machine: bool, + dtype: LiteralDtype, enable_features: tuple[str, ...] | None, adapter_id: tuple[str, ...], build_ctx: str | None, @@ -1003,17 +1011,16 @@ def build_command( system_message=system_message, backend=backend, quantize=quantize, - serialisation=t.cast( - LiteralSerialisation, - first_not_none( - serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy' - ), + torch_dtype=dtype, + serialisation=first_not_none( + serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy' ), ) backend_warning(llm.__llm_backend__, build=True) os.environ.update( { + 'TORCH_DTYPE': dtype, 'OPENLLM_BACKEND': llm.__llm_backend__, 'OPENLLM_SERIALIZATION': llm._serialisation, 'OPENLLM_MODEL_ID': llm.model_id,