From 6726f6ae3ef48c60bbb474e69bc4db57e5043a58 Mon Sep 17 00:00:00 2001 From: paperspace <29749331+aarnphm@users.noreply.github.com> Date: Thu, 9 May 2024 00:06:10 +0000 Subject: [PATCH] fix: make sure to add cpu to number Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> --- openllm-core/src/openllm_core/_configuration.py | 15 ++++++++++++++- openllm-python/src/_openllm_tiny/_entrypoint.py | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py index 8f223db6..42708ed3 100644 --- a/openllm-core/src/openllm_core/_configuration.py +++ b/openllm-core/src/openllm_core/_configuration.py @@ -3,7 +3,17 @@ import abc, inspect, logging, os, typing as t import inflection, orjson, pydantic from deepmerge.merger import Merger -from ._typing_compat import DictStrAny, ListStr, LiteralSerialisation, NotRequired, Required, Self, TypedDict, overload +from ._typing_compat import ( + DictStrAny, + ListStr, + LiteralSerialisation, + NotRequired, + Required, + Self, + TypedDict, + overload, + Annotated, +) from .exceptions import ForbiddenAttributeError, MissingDependencyError from .utils import field_env_key, first_not_none, is_vllm_available, is_transformers_available @@ -223,6 +233,9 @@ class GenerationConfig(pydantic.BaseModel): None, description='Number of log probabilities to return per output token.' ) detokenize: bool = pydantic.Field(True, description='Whether to detokenize the output.') + truncate_prompt_tokens: t.Optional[Annotated[int, pydantic.Field(ge=1)]] = pydantic.Field( + None, description='Truncate the prompt tokens.' + ) prompt_logprobs: t.Optional[int] = pydantic.Field( None, description='Number of log probabilities to return per input token.' ) diff --git a/openllm-python/src/_openllm_tiny/_entrypoint.py b/openllm-python/src/_openllm_tiny/_entrypoint.py index 006edd67..72271720 100644 --- a/openllm-python/src/_openllm_tiny/_entrypoint.py +++ b/openllm-python/src/_openllm_tiny/_entrypoint.py @@ -426,7 +426,7 @@ def build_command( labels = {'library': 'vllm'} service_config = dict( resources={ - 'gpu' if device else 'cpu': len(device) if device else 'cpu_count', + 'gpu' if device else 'cpu': len(device) if device else '1', 'gpu_type': recommended_instance_type(model_id, bentomodel), }, traffic=dict(timeout=timeout),