fix: prevent one-shot generation from concatenating duplicate outputs

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
paperspace
2024-05-12 04:21:06 +00:00
parent 1d2e554a94
commit 806308eed6

View File

@@ -200,7 +200,6 @@ class LLM:
if stop_token_ids is not None:
attrs.update({'stop_token_ids': stop_token_ids})
config = self.config.model_construct_env(**attrs)
texts, token_ids = [[]] * config['n'], [[]] * config['n']
async for result in self.generate_iterator(
prompt,
prompt_token_ids=prompt_token_ids,
@@ -208,18 +207,7 @@ class LLM:
adapter_name=adapter_name,
**config.model_dump(),
):
for output in result.outputs:
texts[output.index].append(output.text)
token_ids[output.index].extend(output.token_ids)
pass
if (final_result := result) is None:
raise RuntimeError('No result is returned.')
converted = GenerationOutput.from_vllm(final_result)
return converted.model_copy(
update=dict(
prompt=prompt,
outputs=[
output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
for output in converted.outputs
],
)
)
return GenerationOutput.from_vllm(final_result).model_copy(update=dict(prompt=prompt))