mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-03-06 08:08:03 -05:00
feat(cli): --dtype arguments (#627)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -188,6 +188,7 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
backend = first_not_none(
|
||||
backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'
|
||||
)
|
||||
torch_dtype = first_not_none(os.getenv('TORCH_DTYPE'), torch_dtype, default='auto')
|
||||
quantize = first_not_none(quantize, os.getenv('OPENLLM_QUANTIZE'), default=None)
|
||||
# elif quantization_config is None and quantize is not None:
|
||||
# quantization_config, attrs = infer_quantisation_config(self, quantize, **attrs)
|
||||
@@ -254,7 +255,7 @@ class LLM(t.Generic[M, T], ReprMixin):
|
||||
None,
|
||||
)
|
||||
if config_dtype is None:
|
||||
config_type = torch.float32
|
||||
config_dtype = torch.float32
|
||||
if self.__llm_torch_dtype__ == 'auto':
|
||||
if config_dtype == torch.float32:
|
||||
torch_dtype = torch.float16 # following common practice
|
||||
|
||||
@@ -127,6 +127,7 @@ def construct_docker_options(
|
||||
|
||||
environ = parse_config_options(llm.config, llm.config['timeout'], 1.0, None, True, os.environ.copy())
|
||||
env_dict = {
|
||||
'TORCH_DTYPE': str(llm._torch_dtype).split('.')[-1],
|
||||
'OPENLLM_BACKEND': llm.__llm_backend__,
|
||||
'OPENLLM_CONFIG': f"'{llm.config.model_dump_json(flatten=True).decode()}'",
|
||||
'OPENLLM_SERIALIZATION': serialisation,
|
||||
|
||||
@@ -140,6 +140,7 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
|
||||
_OpenLLM_GenericInternalConfig().to_click_options,
|
||||
_http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'),
|
||||
dtype_option(factory=cog.optgroup),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
system_message_option(factory=cog.optgroup),
|
||||
prompt_template_file_option(factory=cog.optgroup),
|
||||
@@ -300,6 +301,17 @@ def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[
|
||||
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
|
||||
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
    """Attach a ``--dtype`` CLI option to ``f``.

    The option accepts one of ``float16``, ``float32`` or ``bfloat16``, can
    also be supplied through the ``TORCH_DTYPE`` environment variable, and
    falls back to ``float16`` when neither is given.

    Args:
        f: The callable handed to the decorator produced by ``cli_option``.
        **attrs: Extra keyword arguments forwarded unchanged to ``cli_option``.

    Returns:
        The result of applying the ``cli_option`` decorator to ``f``.
    """
    # NOTE(review): how a ``None`` value for ``f`` is handled depends on
    # ``cli_option``, which is defined elsewhere -- confirm against it.
    supported_dtypes = ['float16', 'float32', 'bfloat16']
    decorate = cli_option(
        '--dtype',
        type=click.Choice(supported_dtypes),
        envvar='TORCH_DTYPE',
        default='float16',
        help='Optional dtype for casting tensors for running inference.',
        **attrs,
    )
    return decorate(f)
|
||||
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--model-id',
|
||||
|
||||
@@ -56,6 +56,7 @@ from openllm_core._typing_compat import (
|
||||
Concatenate,
|
||||
DictStrAny,
|
||||
LiteralBackend,
|
||||
LiteralDtype,
|
||||
LiteralQuantise,
|
||||
LiteralSerialisation,
|
||||
LiteralString,
|
||||
@@ -92,6 +93,7 @@ from ._factory import (
|
||||
_AnyCallable,
|
||||
backend_option,
|
||||
container_registry_option,
|
||||
dtype_option,
|
||||
machine_option,
|
||||
model_name_argument,
|
||||
model_version_option,
|
||||
@@ -403,6 +405,7 @@ def start_command(
|
||||
cors: bool,
|
||||
adapter_id: str | None,
|
||||
return_process: bool,
|
||||
dtype: LiteralDtype,
|
||||
deprecated_model_id: str | None,
|
||||
**attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
@@ -452,6 +455,7 @@ def start_command(
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
serialisation=serialisation,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'),
|
||||
)
|
||||
backend_warning(llm.__llm_backend__)
|
||||
@@ -522,6 +526,7 @@ def start_grpc_command(
|
||||
backend: LiteralBackend | None,
|
||||
serialisation: LiteralSerialisation | None,
|
||||
cors: bool,
|
||||
dtype: LiteralDtype,
|
||||
adapter_id: str | None,
|
||||
return_process: bool,
|
||||
deprecated_model_id: str | None,
|
||||
@@ -573,6 +578,7 @@ def start_grpc_command(
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
serialisation=serialisation,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'),
|
||||
)
|
||||
backend_warning(llm.__llm_backend__)
|
||||
@@ -872,6 +878,7 @@ class BuildBentoOutput(t.TypedDict):
|
||||
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
|
||||
help='Deprecated. Use positional argument instead.',
|
||||
)
|
||||
@dtype_option
|
||||
@backend_option
|
||||
@system_message_option
|
||||
@prompt_template_file_option
|
||||
@@ -943,6 +950,7 @@ def build_command(
|
||||
overwrite: bool,
|
||||
quantize: LiteralQuantise | None,
|
||||
machine: bool,
|
||||
dtype: LiteralDtype,
|
||||
enable_features: tuple[str, ...] | None,
|
||||
adapter_id: tuple[str, ...],
|
||||
build_ctx: str | None,
|
||||
@@ -1003,17 +1011,16 @@ def build_command(
|
||||
system_message=system_message,
|
||||
backend=backend,
|
||||
quantize=quantize,
|
||||
serialisation=t.cast(
|
||||
LiteralSerialisation,
|
||||
first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
torch_dtype=dtype,
|
||||
serialisation=first_not_none(
|
||||
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
|
||||
),
|
||||
)
|
||||
backend_warning(llm.__llm_backend__, build=True)
|
||||
|
||||
os.environ.update(
|
||||
{
|
||||
'TORCH_DTYPE': dtype,
|
||||
'OPENLLM_BACKEND': llm.__llm_backend__,
|
||||
'OPENLLM_SERIALIZATION': llm._serialisation,
|
||||
'OPENLLM_MODEL_ID': llm.model_id,
|
||||
|
||||
Reference in New Issue
Block a user