diff --git a/examples/api_server.py b/examples/api_server.py
index 459faa27..6e1ecb47 100644
--- a/examples/api_server.py
+++ b/examples/api_server.py
@@ -28,7 +28,9 @@ class GenerateInput(TypedDict):
 async def generate(request: GenerateInput) -> Union[AsyncGenerator[str, None], str]:
     n = request['sampling_params'].pop('n', 1)
     request_id = f'tinyllm-{uuid.uuid4().hex}'
-    previous_texts = [''] * n
+    # One independent chunk list per sample. `[[]] * n` would alias a single
+    # list n times, mixing every sample's text into previous_texts[0].
+    previous_texts = [[] for _ in range(n)]
     generator = llm.generate_iterator(request['prompt'], request_id=request_id, n=n,
                                       **request['sampling_params'])
 
@@ -36,15 +38,14 @@ async def generate(request: GenerateInput) -> Union[AsyncGenerator[str, None], s
         async for request_output in generator:
             for output in request_output.outputs:
                 i = output.index
-                delta_text = output.text[len(previous_texts[i]) :]
-                previous_texts[i] = output.text
-                yield delta_text
+                previous_texts[i].append(output.text)
+                yield output.text
 
     if request['stream']:
         return streamer()
 
-    final_output = None
-    async for request_output in generator:
-        final_output = request_output
-    assert final_output is not None
-    return final_output.outputs[0].text
+    # Non-streaming: drain the stream so every chunk is recorded, then return
+    # the text collected for the first sample (assumes output.text is a delta).
+    async for _ in streamer():
+        pass
+    return ''.join(previous_texts[0])