diff --git a/openllm-python/src/openllm/_runners.py b/openllm-python/src/openllm/_runners.py
index 5c43efe5..9b069f06 100644
--- a/openllm-python/src/openllm/_runners.py
+++ b/openllm-python/src/openllm/_runners.py
@@ -93,7 +93,7 @@ class CTranslateRunnable(bentoml.Runnable):
         spaces_between_special_tokens=False,
         clean_up_tokenization_spaces=True,  #
       )
-      yield GenerationOutput(
+      out = GenerationOutput(
         prompt_token_ids=prompt_token_ids,  #
         prompt='',
         finished=request_output.is_last,
@@ -109,7 +109,7 @@ class CTranslateRunnable(bentoml.Runnable):
           )
         ],
       ).model_dump_json()
-
+      yield bentoml.io.SSE(out).marshal()
 
 @registry
 class vLLMRunnable(bentoml.Runnable):
@@ -286,7 +286,7 @@ class PyTorchRunnable(bentoml.Runnable):
       if config['logprobs']:
         sample_logprobs.append({token: token_logprobs})
-      yield GenerationOutput(
+      out = GenerationOutput(
         prompt='',
         finished=False,
         outputs=[
@@ -303,13 +303,14 @@ class PyTorchRunnable(bentoml.Runnable):
         prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
         request_id=request_id,
       ).model_dump_json()
+      yield bentoml.io.SSE(out).marshal()
       if stopped:
         break
     else:
       finish_reason = 'length'
     if stopped:
       finish_reason = 'stop'
-    yield GenerationOutput(
+    out = GenerationOutput(
       prompt='',
       finished=True,
       outputs=[
@@ -326,6 +327,7 @@ class PyTorchRunnable(bentoml.Runnable):
       prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None,
       request_id=request_id,
     ).model_dump_json()
+    yield bentoml.io.SSE(out).marshal()
     # Clean
     del past_key_values, out
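
In effect, each GenerationOutput is now serialized to JSON once, bound to `out`, and yielded as a server-sent event frame via `bentoml.io.SSE(out).marshal()` rather than as a bare JSON string. A minimal sketch of that framing, assuming `marshal()` emits standard `data: <json>` SSE frames; `marshal_sse` and `parse_sse` are hypothetical stand-ins written for illustration, not BentoML APIs:

import json

def marshal_sse(payload: str) -> str:
  # Hypothetical stand-in for bentoml.io.SSE(payload).marshal(): wrap a
  # JSON string in the standard SSE wire format (assumed behavior).
  return f'data: {payload}\n\n'

def parse_sse(frame: str) -> dict:
  # Client-side counterpart: strip the 'data: ' prefix and the trailing
  # blank line, then decode the GenerationOutput JSON.
  return json.loads(frame.removeprefix('data: ').rstrip('\n'))

frame = marshal_sse(json.dumps({'request_id': 'abc123', 'finished': True}))
assert parse_sse(frame)['finished'] is True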