mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-05-19 14:16:22 -04:00
feat: support continuous batching on generate (#375)
* feat: support continuous batching on `generate` Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: add changelog Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# mypy: disable-error-code="name-defined,attr-defined"
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import asyncio
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
@@ -49,7 +48,6 @@ from openllm_core.utils import ReprMixin
|
||||
from openllm_core.utils import apply
|
||||
from openllm_core.utils import bentoml_cattr
|
||||
from openllm_core.utils import codegen
|
||||
from openllm_core.utils import ensure_exec_coro
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import generate_hash_from_file
|
||||
from openllm_core.utils import is_peft_available
|
||||
@@ -1213,8 +1211,9 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
|
||||
return ' '.join(output_text) + ' '
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
|
||||
def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
|
||||
async def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[list[t.Any], None]:
|
||||
stop: str | t.Iterable[str] | None = attrs.pop('stop', None)
|
||||
echo = attrs.pop('echo', False)
|
||||
stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None)
|
||||
adapter_name = attrs.pop('adapter_name', None)
|
||||
if adapter_name is not None: __self.set_adapter(adapter_name)
|
||||
@@ -1234,18 +1233,14 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
|
||||
config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
|
||||
sampling_params = config.to_sampling_config()
|
||||
|
||||
async def loop() -> list[str]:
|
||||
async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id):
|
||||
pass
|
||||
return [output.text for output in request_output.outputs]
|
||||
|
||||
try:
|
||||
return asyncio.run(loop())
|
||||
except RuntimeError:
|
||||
try:
|
||||
return ensure_exec_coro(loop())
|
||||
except Exception:
|
||||
raise
|
||||
final_output = None
|
||||
async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id):
|
||||
final_output = request_output
|
||||
if final_output is None: raise ValueError("'output' should not be None")
|
||||
prompt = final_output.prompt
|
||||
if echo: text_outputs = [prompt + output.text for output in final_output.outputs]
|
||||
else: text_outputs = [output.text for output in final_output.outputs]
|
||||
yield text_outputs
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore
|
||||
async def vllm_generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[str, None]:
|
||||
|
||||
@@ -45,10 +45,13 @@ _JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config
|
||||
|
||||
@svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
echo = input_dict.pop('echo', False)
|
||||
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
|
||||
config = qa_inputs.llm_config.model_dump()
|
||||
if runner.backend == 'vllm':
|
||||
responses = await runner.vllm_generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, request_id=openllm_core.utils.gen_random_uuid(), **config)
|
||||
async for output in runner.vllm_generate.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, request_id=openllm_core.utils.gen_random_uuid(), **config):
|
||||
responses = output
|
||||
if responses is None: raise ValueError("'responses' should not be None.")
|
||||
else:
|
||||
responses = await runner.generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **config)
|
||||
return openllm.GenerationOutput(responses=responses, configuration=config)
|
||||
|
||||
Reference in New Issue
Block a user