feat: support continuous batching on generate (#375)

* feat: support continuous batching on `generate`

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: add changelog

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-09-19 03:04:59 -04:00
committed by GitHub
parent 9a206d8184
commit 3b2ac1cd59
4 changed files with 25 additions and 19 deletions

View File

@@ -1,7 +1,6 @@
# mypy: disable-error-code="name-defined,attr-defined"
from __future__ import annotations
import abc
import asyncio
import gc
import inspect
import logging
@@ -49,7 +48,6 @@ from openllm_core.utils import ReprMixin
from openllm_core.utils import apply
from openllm_core.utils import bentoml_cattr
from openllm_core.utils import codegen
from openllm_core.utils import ensure_exec_coro
from openllm_core.utils import first_not_none
from openllm_core.utils import generate_hash_from_file
from openllm_core.utils import is_peft_available
@@ -1213,8 +1211,9 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
return ' '.join(output_text) + ' '
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
async def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[list[t.Any], None]:
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
echo = attrs.pop('echo', False)
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
adapter_name = attrs.pop('adapter_name', None)
if adapter_name is not None: __self.set_adapter(adapter_name)
@@ -1234,18 +1233,14 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
sampling_params = config.to_sampling_config()
async def loop() -> list[str]:
async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id):
pass
return [output.text for output in request_output.outputs]
try:
return asyncio.run(loop())
except RuntimeError:
try:
return ensure_exec_coro(loop())
except Exception:
raise
final_output = None
async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id):
final_output = request_output
if final_output is None: raise ValueError("'output' should not be None")
prompt = final_output.prompt
if echo: text_outputs = [prompt + output.text for output in final_output.outputs]
else: text_outputs = [output.text for output in final_output.outputs]
yield text_outputs
@bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore
async def vllm_generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[str, None]:

View File

@@ -45,10 +45,13 @@ _JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config
@svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
echo = input_dict.pop('echo', False)
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
config = qa_inputs.llm_config.model_dump()
if runner.backend == 'vllm':
responses = await runner.vllm_generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, request_id=openllm_core.utils.gen_random_uuid(), **config)
async for output in runner.vllm_generate.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, request_id=openllm_core.utils.gen_random_uuid(), **config):
responses = output
if responses is None: raise ValueError("'responses' should not be None.")
else:
responses = await runner.generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **config)
return openllm.GenerationOutput(responses=responses, configuration=config)