From 3b2ac1cd596541e9ac7c0ca0f93291582bd02d85 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Tue, 19 Sep 2023 03:04:59 -0400 Subject: [PATCH] feat: support continuous batching on `generate` (#375) * feat: support continuous batching on `generate` Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * chore: add changelog Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- bench.py | 13 ++++++++++--- changelog.d/375.feature.md | 1 + openllm-python/src/openllm/_llm.py | 25 ++++++++++--------------- openllm-python/src/openllm/_service.py | 5 ++++- 4 files changed, 25 insertions(+), 19 deletions(-) create mode 100644 changelog.d/375.feature.md diff --git a/bench.py b/bench.py index dd2c670f..e65ed45c 100755 --- a/bench.py +++ b/bench.py @@ -1,4 +1,5 @@ from __future__ import annotations +import argparse import asyncio import json @@ -14,8 +15,9 @@ async def send_request(url, prompt, session, model, **attrs): result = await response.text() print('-' * 10 + '\n\n prompt:', prompt, '\nGeneration:', result, '\n\n' + '-' * 10) -async def main(): - url = 'http://localhost:3000/v1/generate_stream' +async def main(args: argparse.Namespace) -> int: + endpoint = 'generate' if args.generate else 'generate_stream' + url = f'http://localhost:3000/v1/{endpoint}' # len=100 prompts = [ 'What is the meaning of life?', @@ -135,5 +137,10 @@ async def main(): ] async with aiohttp.ClientSession() as session: await asyncio.gather(*[send_request(url, prompt, session, 'llama', max_new_tokens=4096, top_p=0.21) for _, prompt in enumerate(prompts)]) + return 0 -if __name__ == '__main__': asyncio.run(main()) +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--generate', default=False, action='store_true', help='Whether to test with stream endpoint.') + args = parser.parse_args() + raise SystemExit(asyncio.run(main(args))) diff --git a/changelog.d/375.feature.md b/changelog.d/375.feature.md new file mode 100644 index 00000000..7c02896c --- /dev/null +++ b/changelog.d/375.feature.md @@ -0,0 +1 @@ +Added support for continuous batching on `/v1/generate` diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 8f53bd25..3c318cf3 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -1,7 +1,6 @@ # mypy: disable-error-code="name-defined,attr-defined" from __future__ import annotations import abc -import asyncio import gc import inspect import logging @@ -49,7 +48,6 @@ from openllm_core.utils import ReprMixin from openllm_core.utils import apply from openllm_core.utils import bentoml_cattr from openllm_core.utils import codegen -from openllm_core.utils import ensure_exec_coro from openllm_core.utils import first_not_none from openllm_core.utils import generate_hash_from_file from openllm_core.utils import is_peft_available @@ -1213,8 +1211,9 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate return ' '.join(output_text) + ' ' @bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore - def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: + async def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[list[t.Any], None]: stop: str | t.Iterable[str] | None = attrs.pop('stop', None) + echo = attrs.pop('echo', False) stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None) adapter_name = attrs.pop('adapter_name', None) if adapter_name is not None: __self.set_adapter(adapter_name) @@ -1234,18 +1233,14 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs) sampling_params = config.to_sampling_config() - async def loop() -> list[str]: - async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id): - pass - return [output.text for output in request_output.outputs] - - try: - return asyncio.run(loop()) - except RuntimeError: - try: - return ensure_exec_coro(loop()) - except Exception: - raise + final_output = None + async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id): + final_output = request_output + if final_output is None: raise ValueError("'output' should not be None") + prompt = final_output.prompt + if echo: text_outputs = [prompt + output.text for output in final_output.outputs] + else: text_outputs = [output.text for output in final_output.outputs] + yield text_outputs @bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore async def vllm_generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[str, None]: diff --git a/openllm-python/src/openllm/_service.py b/openllm-python/src/openllm/_service.py index c19e4adc..1f538075 100644 --- a/openllm-python/src/openllm/_service.py +++ b/openllm-python/src/openllm/_service.py @@ -45,10 +45,13 @@ _JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config @svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)})) async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput: + echo = input_dict.pop('echo', False) qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict) config = qa_inputs.llm_config.model_dump() if runner.backend == 'vllm': - responses = await runner.vllm_generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, request_id=openllm_core.utils.gen_random_uuid(), **config) + async for output in runner.vllm_generate.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, request_id=openllm_core.utils.gen_random_uuid(), **config): + responses = output + if responses is None: raise ValueError("'responses' should not be None.") else: responses = await runner.generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **config) return openllm.GenerationOutput(responses=responses, configuration=config)