From 9e371d2eade0365faca02f3359125df3cea1b736 Mon Sep 17 00:00:00 2001 From: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> Date: Mon, 21 Aug 2023 12:02:35 +0000 Subject: [PATCH] fix(generate): Correctly set batch output for generate from iterator Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> --- openllm-python/src/openllm/_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index a09fda96..deb6fd45 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -927,10 +927,10 @@ class LLM(LLMInterface[M, T], ReprMixin): prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs) return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs) - def generate(self, prompt: str, **attrs: t.Any) -> t.Any: + def generate(self, prompt: str, **attrs: t.Any) -> t.List[t.Any]: # TODO: support different generation strategies, similar to self.model.generate for it in self.generate_iterator(prompt, **attrs): pass - return it + return [it] def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> str: if isinstance(generation_result, dict): return generation_result["text"]