From 806308eed65ea2af6c6b9b8a333747acf627f837 Mon Sep 17 00:00:00 2001
From: paperspace <29749331+aarnphm@users.noreply.github.com>
Date: Sun, 12 May 2024 04:21:06 +0000
Subject: [PATCH] fix: one-shot generation not to concatenate duplicates

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
---
 openllm-python/src/_openllm_tiny/_llm.py | 16 ++--------------
 1 file changed, 2 insertions(+), 14 deletions(-)

diff --git a/openllm-python/src/_openllm_tiny/_llm.py b/openllm-python/src/_openllm_tiny/_llm.py
index eab5635e..7df3a2e4 100644
--- a/openllm-python/src/_openllm_tiny/_llm.py
+++ b/openllm-python/src/_openllm_tiny/_llm.py
@@ -200,7 +200,6 @@ class LLM:
     if stop_token_ids is not None:
       attrs.update({'stop_token_ids': stop_token_ids})
     config = self.config.model_construct_env(**attrs)
-    texts, token_ids = [[]] * config['n'], [[]] * config['n']
     async for result in self.generate_iterator(
       prompt,
       prompt_token_ids=prompt_token_ids,
@@ -208,18 +207,7 @@ class LLM:
       adapter_name=adapter_name,
       **config.model_dump(),
     ):
-      for output in result.outputs:
-        texts[output.index].append(output.text)
-        token_ids[output.index].extend(output.token_ids)
+      pass
     if (final_result := result) is None:
       raise RuntimeError('No result is returned.')
-    converted = GenerationOutput.from_vllm(final_result)
-    return converted.model_copy(
-      update=dict(
-        prompt=prompt,
-        outputs=[
-          output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
-          for output in converted.outputs
-        ],
-      )
-    )
+    return GenerationOutput.from_vllm(final_result).model_copy(update=dict(prompt=prompt))
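
Why the removed loop produced duplicated output: vLLM's streaming iterator
yields cumulative RequestOutput snapshots, so each yielded output.text
already contains all text generated so far, and appending every snapshot
before joining repeats the same prefix over and over. The final snapshot
alone is complete, which is what the patch now converts and returns. The
removed buffers also carried a second, latent bug: [[]] * config['n']
creates n references to one shared list, so every sequence index wrote into
the same buffer. A minimal standalone sketch of both problems, using
hypothetical snapshot strings in place of a real model stream:

    # Aliasing: [[]] * n yields n references to the same list object.
    buffers = [[]] * 2            # the old `[[]] * config['n']` pattern
    buffers[0].append('x')
    print(buffers)                # [['x'], ['x']]

    # Duplication: cumulative snapshots already include earlier text,
    # so appending each one and joining repeats every prefix.
    snapshots = ['Hello', 'Hello wor', 'Hello world!']  # hypothetical stream
    texts = [[]]                  # one buffer, i.e. n=1
    for text in snapshots:
      texts[0].append(text)
    print(''.join(texts[0]))      # HelloHello worHello world!
    print(snapshots[-1])          # Hello world! (the final snapshot suffices)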