fix(generate): Correct set batch output for generate from iterator

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-08-21 12:02:35 +00:00
parent ce20cb32f0
commit 9e371d2ead

View File

@@ -927,10 +927,10 @@ class LLM(LLMInterface[M, T], ReprMixin):
prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs)
return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs)
def generate(self, prompt: str, **attrs: t.Any) -> t.Any:
def generate(self, prompt: str, **attrs: t.Any) -> t.List[t.Any]:
# TODO: support different generation strategies, similar to self.model.generate
for it in self.generate_iterator(prompt, **attrs): pass
return it
return [it]
def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> str:
if isinstance(generation_result, dict): return generation_result["text"]