refactor: update runner helpers and add max_model_len (#712)

* chore(runner): cleanup unnecessary checks for runnable backend

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: saving llm reference to runner

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: correct inject item

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update support for max_seq_len

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* fix: correct max_model_len

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: update and warn about backward compatibility

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: remove unused sets

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-20 20:37:15 -05:00
committed by GitHub
parent 8fc5f1f70c
commit ad4f388c98
6 changed files with 220 additions and 204 deletions

View File

@@ -392,6 +392,13 @@ def cli() -> None:
metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
help='Deprecated. Use positional argument instead.',
)
@click.option(
'--max-model-len',
'--max_model_len',
'max_model_len',
default=None,
help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.',
)
@start_decorator(serve_grpc=False)
def start_command(
model_id: str,
@@ -409,6 +416,7 @@ def start_command(
return_process: bool,
dtype: LiteralDtype,
deprecated_model_id: str | None,
max_model_len: int | None,
**attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
"""Start any LLM as a REST server.
@@ -466,6 +474,7 @@ def start_command(
quantize=quantize,
serialisation=serialisation,
dtype=dtype,
max_model_len=max_model_len,
)
backend_warning(llm.__llm_backend__)
@@ -521,9 +530,14 @@ def start_command(
help='Deprecated. Use positional argument instead.',
)
@start_decorator(serve_grpc=True)
@click.pass_context
@click.option(
'--max-model-len',
'--max_model_len',
'max_model_len',
default=None,
help='Maximum sequence length for the model. If not specified, we will use the default value from the model config.',
)
def start_grpc_command(
ctx: click.Context,
model_id: str,
server_timeout: int,
model_version: str | None,
@@ -539,6 +553,7 @@ def start_grpc_command(
adapter_id: str | None,
return_process: bool,
deprecated_model_id: str | None,
max_model_len: int | None,
**attrs: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
"""Start any LLM as a gRPC server.
@@ -596,6 +611,7 @@ def start_grpc_command(
quantize=quantize,
serialisation=serialisation,
dtype=dtype,
max_model_len=max_model_len,
trust_remote_code=check_bool_env('TRUST_REMOTE_CODE'),
)
backend_warning(llm.__llm_backend__)