fix(api-server): correct set generation from LLM class

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-21 10:38:36 +00:00
parent 2821e172ef
commit d53cf234bd

View File

@@ -28,7 +28,7 @@ class GenerateInput(TypedDict):
async def generate(request: GenerateInput) -> Union[AsyncGenerator[str, None], str]:
n = request['sampling_params'].pop('n', 1)
request_id = f'tinyllm-{uuid.uuid4().hex}'
previous_texts = [''] * n
previous_texts = [[]] * n
generator = llm.generate_iterator(request['prompt'], request_id=request_id, n=n, **request['sampling_params'])
@@ -36,15 +36,11 @@ async def generate(request: GenerateInput) -> Union[AsyncGenerator[str, None], s
async for request_output in generator:
for output in request_output.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]) :]
previous_texts[i] = output.text
yield delta_text
previous_texts[i].append(output.text)
yield output.text
if request['stream']:
return streamer()
final_output = None
async for request_output in generator:
final_output = request_output
assert final_output is not None
return final_output.outputs[0].text
async for _ in streamer(): pass
return ''.join(previous_texts[0])