feat(cli): --dtype arguments (#627)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-13 05:25:50 -05:00
committed by GitHub
parent f288b8a276
commit 099c0dc31b
4 changed files with 27 additions and 6 deletions

View File

@@ -188,6 +188,7 @@ class LLM(t.Generic[M, T], ReprMixin):
backend = first_not_none(
backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'
)
torch_dtype = first_not_none(os.getenv('TORCH_DTYPE'), torch_dtype, default='auto')
quantize = first_not_none(quantize, os.getenv('OPENLLM_QUANTIZE'), default=None)
# elif quantization_config is None and quantize is not None:
# quantization_config, attrs = infer_quantisation_config(self, quantize, **attrs)
@@ -254,7 +255,7 @@ class LLM(t.Generic[M, T], ReprMixin):
None,
)
if config_dtype is None:
-config_type = torch.float32
+config_dtype = torch.float32
if self.__llm_torch_dtype__ == 'auto':
if config_dtype == torch.float32:
torch_dtype = torch.float16 # following common practice

View File

@@ -127,6 +127,7 @@ def construct_docker_options(
environ = parse_config_options(llm.config, llm.config['timeout'], 1.0, None, True, os.environ.copy())
env_dict = {
'TORCH_DTYPE': str(llm._torch_dtype).split('.')[-1],
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_CONFIG': f"'{llm.config.model_dump_json(flatten=True).decode()}'",
'OPENLLM_SERIALIZATION': serialisation,

View File

@@ -140,6 +140,7 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
_OpenLLM_GenericInternalConfig().to_click_options,
_http_server_args if not serve_grpc else _grpc_server_args,
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'),
dtype_option(factory=cog.optgroup),
model_version_option(factory=cog.optgroup),
system_message_option(factory=cog.optgroup),
prompt_template_file_option(factory=cog.optgroup),
@@ -300,6 +301,17 @@ def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
  """Attach the ``--dtype`` CLI option to a command.

  The option accepts one of the supported torch dtypes and also reads from the
  ``TORCH_DTYPE`` environment variable, defaulting to ``float16``. Extra keyword
  arguments are forwarded verbatim to ``cli_option``.
  """
  supported_dtypes = ['float16', 'float32', 'bfloat16']
  decorator = cli_option(
    '--dtype',
    type=click.Choice(supported_dtypes),
    envvar='TORCH_DTYPE',
    default='float16',
    help='Optional dtype for casting tensors for running inference.',
    **attrs,
  )
  return decorator(f)
def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--model-id',

View File

@@ -56,6 +56,7 @@ from openllm_core._typing_compat import (
Concatenate,
DictStrAny,
LiteralBackend,
LiteralDtype,
LiteralQuantise,
LiteralSerialisation,
LiteralString,
@@ -92,6 +93,7 @@ from ._factory import (
_AnyCallable,
backend_option,
container_registry_option,
dtype_option,
machine_option,
model_name_argument,
model_version_option,
@@ -403,6 +405,7 @@ def start_command(
cors: bool,
adapter_id: str | None,
return_process: bool,
dtype: LiteralDtype,
deprecated_model_id: str | None,
**attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
@@ -452,6 +455,7 @@ def start_command(
adapter_map=adapter_map,
quantize=quantize,
serialisation=serialisation,
torch_dtype=dtype,
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'),
)
backend_warning(llm.__llm_backend__)
@@ -522,6 +526,7 @@ def start_grpc_command(
backend: LiteralBackend | None,
serialisation: LiteralSerialisation | None,
cors: bool,
dtype: LiteralDtype,
adapter_id: str | None,
return_process: bool,
deprecated_model_id: str | None,
@@ -573,6 +578,7 @@ def start_grpc_command(
adapter_map=adapter_map,
quantize=quantize,
serialisation=serialisation,
torch_dtype=dtype,
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'),
)
backend_warning(llm.__llm_backend__)
@@ -872,6 +878,7 @@ class BuildBentoOutput(t.TypedDict):
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
help='Deprecated. Use positional argument instead.',
)
@dtype_option
@backend_option
@system_message_option
@prompt_template_file_option
@@ -943,6 +950,7 @@ def build_command(
overwrite: bool,
quantize: LiteralQuantise | None,
machine: bool,
dtype: LiteralDtype,
enable_features: tuple[str, ...] | None,
adapter_id: tuple[str, ...],
build_ctx: str | None,
@@ -1003,17 +1011,16 @@ def build_command(
system_message=system_message,
backend=backend,
quantize=quantize,
-serialisation=t.cast(
-LiteralSerialisation,
-first_not_none(
-serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
-),
+torch_dtype=dtype,
+serialisation=first_not_none(
+serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
+),
)
backend_warning(llm.__llm_backend__, build=True)
os.environ.update(
{
'TORCH_DTYPE': dtype,
'OPENLLM_BACKEND': llm.__llm_backend__,
'OPENLLM_SERIALIZATION': llm._serialisation,
'OPENLLM_MODEL_ID': llm.model_id,