chore: upgrade to new vLLM schema

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Author: paperspace
Date: 2024-06-02 15:52:45 +00:00
Parent: 2e7592cd45
Commit: a93da12084
90 changed files with 34 additions and 1391 deletions


@@ -4,6 +4,8 @@ ARG NVIDIA_DRIVER_CAPABILITIES=compute,utility
 ENV NVIDIA_DRIVER_CAPABILITIES=$NVIDIA_DRIVER_CAPABILITIES
 ARG HF_HUB_DISABLE_PROGRESS_BARS=TRUE
 ENV HF_HUB_DISABLE_PROGRESS_BARS=$HF_HUB_DISABLE_PROGRESS_BARS
-{% call common.RUN(__enable_buildkit__) -%} {{ __pip_cache__ }} {% endcall -%} bash -c 'pip install --no-color --progress-bar off "vllm==0.4.2" || true'
+ARG VLLM_NO_USAGE_STATS=1
+ENV VLLM_NO_USAGE_STATS=$VLLM_NO_USAGE_STATS
+{% call common.RUN(__enable_buildkit__) -%} {{ __pip_cache__ }} {% endcall -%} bash -c 'pip install --no-color --progress-bar off "vllm==0.4.3" || true'
 {{ super() }}
 {% endblock %}
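
As a hedged sanity check outside the image build (the version assertion is an assumption about how strictly the pin above resolves; the imports are grounded in the ones added later in this commit), the installed wheel can be verified to expose the new prompt types:

# Hedged sanity check, not part of the generated Dockerfile: confirms the
# pinned vLLM wheel exposes the prompt types this commit migrates to.
import vllm
from vllm import TextPrompt, TokensPrompt  # new-style prompt inputs used below

assert vllm.__version__.startswith('0.4.3'), vllm.__version__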


@@ -263,6 +263,7 @@ def start_command(
     'OPENLLM_CONFIG': llm_config.model_dump_json(),
     'DTYPE': dtype,
     'TRUST_REMOTE_CODE': str(trust_remote_code),
+    'VLLM_NO_USAGE_STATS': str(1),
     'GPU_MEMORY_UTILIZATION': orjson.dumps(gpu_memory_utilization).decode(),
     'SERVICES_CONFIG': orjson.dumps(
       # XXX: right now we just enable GPU by default. will revisit this if we decide to support TPU later.


@@ -202,7 +202,14 @@ class LLM:
         "'generate_iterator' is reserved only for online serving. For batch inference use 'LLM.batch' instead."
       )
-    from vllm import SamplingParams
+    from vllm import SamplingParams, TextPrompt, TokensPrompt
 
+    if prompt_token_ids is not None:
+      inputs = TokensPrompt(prompt_token_ids=prompt_token_ids)
+    else:
+      if prompt is None:
+        raise ValueError('Either "prompt" or "prompt_token_ids" must be passed.')
+      inputs = TextPrompt(prompt=prompt)
+
     config = self.config.model_construct_env(**dict_filter_none(attrs))
@@ -229,12 +236,11 @@ class LLM:
     try:
       async for generations in self._model.generate(
-        prompt,
+        inputs=inputs,
         sampling_params=SamplingParams(**{
          k: config.__getitem__(k) for k in set(inspect.signature(SamplingParams).parameters.keys())
         }),
         request_id=request_id,
-        prompt_token_ids=prompt_token_ids if prompt_token_ids else None,
       ):
         yield generations
     except Exception as err:
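
For reference, a minimal sketch of the streaming path this hunk targets, driving a bare AsyncLLMEngine; the engine construction, model name, sampling settings, and request id are illustrative assumptions, while the inputs=/sampling_params=/request_id= keywords mirror the call above.

# Minimal sketch of the vLLM 0.4.3 streaming call shape used above.
# AsyncEngineArgs settings and the model name are placeholders.
import asyncio

from vllm import SamplingParams, TextPrompt
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main() -> None:
  engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model='facebook/opt-125m'))
  final = None
  # Text prompts are wrapped in TextPrompt; pre-tokenized ones would use TokensPrompt.
  async for output in engine.generate(
    inputs=TextPrompt(prompt='Hello, my name is'),
    sampling_params=SamplingParams(max_tokens=16),
    request_id='example-0',
  ):
    final = output  # each yield is the request's latest RequestOutput
  if final is not None:
    print(final.outputs[0].text)

asyncio.run(main())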
@@ -277,7 +283,7 @@ class LLM:
         "'batch' is reserved for offline batch inference. For online serving use 'LLM.generate' or 'LLM.generate_iterator' instead."
       )
-    from vllm import SamplingParams
+    from vllm import SamplingParams, TextPrompt
 
     if isinstance(prompts, str):
       prompts = [prompts]
@@ -295,11 +301,10 @@ class LLM:
     for i in range(num_requests):
       request_id = str(next(self._counter))
-      prompt = prompts[i]
       config = configs[i]
       self._model.add_request(
         request_id,
-        prompt,
+        TextPrompt(prompt=prompts[i]),
         SamplingParams(**{k: config.__getitem__(k) for k in set(inspect.signature(SamplingParams).parameters.keys())}),
       )
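
And a hedged sketch of the equivalent offline batching loop against a bare LLMEngine, mirroring the add_request call above; the model name, prompts, and the step/drain loop are illustrative assumptions, not code from this repository.

# Illustrative only: drives LLMEngine directly with the new TextPrompt inputs,
# the same argument shape LLM.batch now builds above. Placeholders throughout.
from vllm import EngineArgs, LLMEngine, SamplingParams, TextPrompt

engine = LLMEngine.from_engine_args(EngineArgs(model='facebook/opt-125m'))
for i, prompt in enumerate(['Hello, my name is', 'The capital of France is']):
  engine.add_request(str(i), TextPrompt(prompt=prompt), SamplingParams(max_tokens=16))

# Step the engine until every queued request has finished.
while engine.has_unfinished_requests():
  for output in engine.step():
    if output.finished:
      print(output.request_id, output.outputs[0].text)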