feat(vllm): progressive streaming via parser.extract_tool_calls_streaming (follow-up to #10346) (#10351)

* fix(vllm): don't stream raw tool-call markup as content when a tool parser is active

When a tool_parser is configured and the request carries tools, the streaming
loop emitted every text delta as delta.content — including the model's raw
tool-call markup (e.g. <tool_call>...) — because extract_tool_calls only runs
on the full output after the stream. Clients streaming a tool call therefore
saw the unparsed tool-call syntax as assistant content.

Buffer the text while a tool parser is active for the request; the existing
end-of-stream chat_delta already carries the parsed tool_calls (or the cleaned
content), which the Go side converts to SSE deltas. Non-tool-parser streaming
is unchanged.

Add a server-less regression test covering both the tool-call case (no raw
markup leaked as content) and the plain-text case (content delivered exactly
once — guards against double-emitting the buffered content).

Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* test(vllm): add expectedFailure test for progressive streaming with tool parser (Case 3, #582)

Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* test(vllm): add Cases 4+5 — marker split across chunks + false-positive prefix (TDD, Option B state machine, #582)

Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* feat(vllm): progressive streaming via parser.extract_tool_calls_streaming

When a tool parser is active for a tool-enabled streaming request,
#10346 buffers the entire generation and surfaces it on the final
chunk to prevent raw tool-call markup from leaking as delta.content.
This is correct but turns the request into effectively non-streaming
for plain-text responses — the client sees nothing until the model
stops.

Every concrete tool parser shipped with vLLM 0.23+ already implements
extract_tool_calls_streaming (Granite4, Qwen3Coder, DeepSeekV31, Jamba,
Ernie45, Hermes2Pro, llama3_json, mistral, …). Use it: instantiate
the parser before the streaming loop and call its streaming method per
delta, emitting DeltaMessage(content=…) or DeltaMessage(tool_calls=[…])
when the parser is ready.

Falls back to the existing #10346 buffer path when:
  - the parser does not have extract_tool_calls_streaming, OR
  - extract_tool_calls_streaming raises mid-stream (the error is logged and
    the rest of the request finishes via the post-loop extract_tool_calls).
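
The dispatch described above can be sketched in isolation roughly as follows. This is a simplified illustration and not the backend code: stream_with_parser, EchoParser and BufferOnlyParser are hypothetical names, the events are plain tuples instead of gRPC replies, and the parser methods here take and return plain strings, whereas the real extract_tool_calls_streaming also receives token ids and a request object and returns a DeltaMessage.

```python
from typing import Iterable, Iterator, Tuple


def stream_with_parser(chunks: Iterable[str], parser) -> Iterator[Tuple[str, str]]:
    """Yield ("delta", text) while streaming, or one ("final", text) when buffering.

    `chunks` are cumulative texts (each chunk contains everything generated so
    far). `parser` is a hypothetical, string-only stand-in for a tool parser.
    """
    native = hasattr(parser, "extract_tool_calls_streaming")
    failed = False
    previous = ""
    for current in chunks:
        # Chunks are cumulative, so strip what was already seen.
        delta = current.removeprefix(previous)
        if native and not failed:
            try:
                content = parser.extract_tool_calls_streaming(previous, current, delta)
            except Exception:
                # Stop calling the streaming method; finish via the buffer path.
                failed = True
                content = None
            if content:
                yield ("delta", content)
        previous = current
    if not native or failed:
        # Buffer path: parse the full text once, after the stream has ended.
        yield ("final", parser.extract_tool_calls(previous))


class EchoParser:
    """Hypothetical parser that streams every delta through as content."""

    def extract_tool_calls_streaming(self, previous_text, current_text, delta_text):
        return delta_text or None

    def extract_tool_calls(self, text):
        return text


class BufferOnlyParser:
    """Hypothetical parser without a streaming method."""

    def extract_tool_calls(self, text):
        return text


print(list(stream_with_parser(["Par", "Paris"], EchoParser())))
# [('delta', 'Par'), ('delta', 'is')]
print(list(stream_with_parser(["Par", "Paris"], BufferOnlyParser())))
# [('final', 'Paris')]
```

In this sketch, a failure after some deltas were already emitted produces a final event that carries the full text, so a caller would have to reconcile it with what was already sent.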

Tests (TestStreamingToolParser):
  1. Buffer path: no markup leaked, no content duplication
  2. Native streaming: plain-text response streams progressively
  3. Native streaming: tool_call structured, no markup leaked
  4. Native streaming exception → graceful fallback, no markup, no crash
  5. No tool parser → unchanged per-delta content stream

E2E verified against qwen3_coder on vLLM 0.23.0 (NVIDIA GB10 / arm64 / CUDA 13).

Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* docs(vllm): add server-side TTFT benchmark for the streaming tool-parser path

Self-contained stdlib-only script that measures time-to-first-token (TTFT)
for the vLLM backend's two streaming scenarios:

  - tool_call:  request mentions a tool; model is expected to call it
  - plain_text: request offers a tool but explicitly asks for prose

Use this to compare:
  - the buffer-all path (#10346)         → plain_text TTFT ≈ total response time
  - the native-streaming path (this PR)  → plain_text TTFT ≈ true first-token time

  python examples/vllm-bench/ttft_streaming_tool_parser.py \
      --url http://localhost:8080 --model my-coder --runs 3

Lives under examples/ so it does not interfere with the test suite.
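
As a rough illustration of what a TTFT measurement over a streamed response involves, here is a minimal stdlib-only sketch. It is not the benchmark script itself: measure_ttft and fake_stream are hypothetical names, and the stream is any iterable of text chunks rather than a real HTTP/SSE response.

```python
import time
from typing import Iterable, Optional, Tuple


def measure_ttft(stream: Iterable[str]) -> Tuple[Optional[float], float]:
    """Return (seconds until the first non-empty chunk, total seconds).

    The first value is None if the stream never produced a non-empty chunk.
    """
    start = time.perf_counter()
    ttft = None
    for chunk in stream:
        if ttft is None and chunk:
            ttft = time.perf_counter() - start
    total = time.perf_counter() - start
    return ttft, total


def fake_stream():
    """Hypothetical stand-in for a streamed response."""
    time.sleep(0.01)
    yield "first"
    time.sleep(0.02)
    yield " second"


ttft, total = measure_ttft(fake_stream())
print(f"ttft={ttft:.3f}s total={total:.3f}s")
```

With a buffer-all path the two numbers are close to each other, because the first non-empty chunk only arrives when generation is finished. With progressive streaming the first number is much smaller than the second.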

Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* examples/vllm-bench: add long-text scenario (8 paragraphs, 1500 tokens)

The long-text scenario shows the buffering vs streaming difference most
dramatically: with the buffer-all path, the client receives nothing for
20+ seconds and then the entire 1500-token response at once. With native
streaming, the first token arrives in tens of milliseconds and the
response flows progressively.

Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

---------

Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>
Co-authored-by: Philipp Wacker <philipp.wacker@ibf-solutions.com>
pos-ei-don
2026-06-21 17:07:15 +02:00
committed by GitHub
parent 01fa12e0de
commit b4c0dc67fe
4 changed files with 631 additions and 18 deletions


@@ -598,23 +598,124 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# Stream the results
generated_text = ""
generated_token_ids: list[int] = []
last_output = None
# Tool-parsing strategy decision (made once, before the loop):
#
# When a tool parser is active, the model's raw tool-call markup
# (e.g. <tool_call>...) must not be streamed verbatim as delta.content
# — clients would see the unparsed syntax. Two paths:
#
# (A) native streaming via parser.extract_tool_calls_streaming. All
# concrete tool parsers shipped with vLLM 0.23+ implement this
# (Granite4, Qwen3Coder, DeepSeekV31, Jamba, Ernie45, Hermes,
# llama3_json, mistral, …). The parser decides per-delta whether
# to emit content or suppress tool-call markup, and emits a
# structured DeltaMessage(tool_calls=[...]) when a call is ready.
# (B) buffer fallback — used only when the parser surprisingly lacks
# the streaming method or it raises mid-stream. The post-loop
# extract_tool_calls assembles the final chat_delta. Same correctness
# guarantee as a non-streaming response, at the cost of a delayed
# final chunk.
has_tool_parser = bool(self.tool_parser_cls and request.Tools)
tp_instance = None
tp_request = None
native_streaming = False
native_streaming_error = False
if has_tool_parser:
try:
tools_for_parser = json.loads(request.Tools)
except json.JSONDecodeError:
tools_for_parser = []
try:
tp_instance = self.tool_parser_cls(self.tokenizer, tools=tools_for_parser)
except TypeError:
tp_instance = self.tool_parser_cls(self.tokenizer)
# Build a minimal ChatCompletionRequest so the streaming method
# sees the tools list. We do not need any other request fields —
# parsers only read .tools (and sometimes .tool_choice, which we
# leave at default).
try:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest as _CCR,
)
tp_request = _CCR(
model="local",
messages=[{"role": "user", "content": ""}],
tools=tools_for_parser or None,
)
except Exception as e:
print(f"Could not build ChatCompletionRequest for streaming parser: {e}",
file=sys.stderr)
tp_request = None
native_streaming = (
tp_request is not None
and hasattr(tp_instance, "extract_tool_calls_streaming")
)
try:
async for request_output in outputs:
iteration_text = request_output.outputs[0].text
last_output = request_output
if streaming:
# Remove text already sent as vllm concatenates the text from previous yields
delta_iteration_text = iteration_text.removeprefix(generated_text)
# Send the partial result
yield backend_pb2.Reply(
message=bytes(delta_iteration_text, encoding='utf-8'),
chat_deltas=[backend_pb2.ChatDelta(content=delta_iteration_text)],
)
new_token_ids = list(request_output.outputs[0].token_ids)
delta_token_ids = new_token_ids[len(generated_token_ids):]
# Keep track of text generated
if not has_tool_parser:
# Plain streaming — unchanged from pre-tool-parser path.
yield backend_pb2.Reply(
message=bytes(delta_iteration_text, encoding='utf-8'),
chat_deltas=[backend_pb2.ChatDelta(content=delta_iteration_text)],
)
elif native_streaming and not native_streaming_error:
# (A) Native vLLM extract_tool_calls_streaming.
try:
msg = tp_instance.extract_tool_calls_streaming(
previous_text=generated_text,
current_text=iteration_text,
delta_text=delta_iteration_text,
previous_token_ids=generated_token_ids,
current_token_ids=new_token_ids,
delta_token_ids=delta_token_ids,
request=tp_request,
)
except Exception as e:
print(f"Streaming tool parser error (falling back to "
f"buffer for the rest of the stream): {e}",
file=sys.stderr)
native_streaming_error = True
msg = None
if msg is not None:
tc_protos = []
for tc in (msg.tool_calls or []):
fn = tc.function or None
tc_protos.append(backend_pb2.ToolCallDelta(
index=tc.index,
id=tc.id or "",
name=(fn.name if fn and fn.name else "") or "",
arguments=(fn.arguments if fn and fn.arguments else "") or "",
))
cd_kwargs = {}
if msg.content:
cd_kwargs["content"] = msg.content
if msg.reasoning:
cd_kwargs["reasoning_content"] = msg.reasoning
if tc_protos:
cd_kwargs["tool_calls"] = tc_protos
if cd_kwargs:
yield backend_pb2.Reply(
message=bytes(msg.content or "", encoding='utf-8'),
chat_deltas=[backend_pb2.ChatDelta(**cd_kwargs)],
)
# (B) buffer fallback — emit nothing during the stream.
# The post-loop extract_tool_calls block builds the final chunk.
# Keep track of text + token_ids generated
generated_text = iteration_text
generated_token_ids = list(request_output.outputs[0].token_ids)
finally:
await outputs.aclose()
@@ -639,16 +740,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
except Exception as e:
print(f"Reasoning parser error: {e}", file=sys.stderr)
if self.tool_parser_cls and request.Tools:
# When (A) native streaming ran cleanly, per-delta yields above already
# delivered everything — do NOT extract again on the full text or we'd
# duplicate content/tool_calls into the final chunk.
if has_tool_parser and not (native_streaming and not native_streaming_error):
try:
tools = json.loads(request.Tools)
# Some concrete parsers only accept the tokenizer; only the
# abstract base declares the tools kwarg. Try with tools first,
# fall back to tokenizer-only.
try:
tp = self.tool_parser_cls(self.tokenizer, tools=tools)
except TypeError:
tp = self.tool_parser_cls(self.tokenizer)
tp = tp_instance
if tp is None:
# Defensive: tp_instance build failed earlier; reconstruct.
tools = json.loads(request.Tools)
try:
tp = self.tool_parser_cls(self.tokenizer, tools=tools)
except TypeError:
tp = self.tool_parser_cls(self.tokenizer)
info = tp.extract_tool_calls(content, request=None)
if info.tools_called:
content = info.content or ""
@@ -661,6 +765,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
))
except Exception as e:
print(f"Tool parser error: {e}", file=sys.stderr)
elif native_streaming and not native_streaming_error:
# Per-delta path already emitted content + tool_calls; the final
# chat_delta should carry only metadata (token counts, logprobs).
content = ""
# Extract token counts
prompt_tokens = 0
@@ -700,7 +808,26 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
)
if streaming:
# Final chunk with structured data
# Final chunk with structured data.
#
# If we used the buffer fallback (has_tool_parser=True AND native
# streaming did NOT run cleanly) and the parser found no tool call,
# flush the buffered content as ONE content delta — and clear the
# final chat_delta's content so the metadata chunk does not repeat
# what we just sent. This is the plain-text-with-tool-parser path.
buffered_fallback = (
has_tool_parser
and not (native_streaming and not native_streaming_error)
)
if buffered_fallback and not tool_calls_proto and content:
yield backend_pb2.Reply(
message=bytes(content, encoding='utf-8'),
chat_deltas=[backend_pb2.ChatDelta(content=content)],
)
chat_delta = backend_pb2.ChatDelta(
reasoning_content=reasoning_content,
tool_calls=tool_calls_proto,
)
yield backend_pb2.Reply(
message=b"",
prompt_tokens=prompt_tokens,


@@ -278,4 +278,261 @@ class TestBackendServicer(unittest.TestCase):
print(err)
self.fail("Embedding service failed")
finally:
self.tearDown()
self.tearDown()
class TestStreamingToolParser(unittest.TestCase):
"""
Server-less unit tests for the streaming + tool-parser machinery in
BackendServicer._predict. These tests instantiate BackendServicer
directly and mock the vLLM engine + tool parser, so they do not need
a GPU, a model, or a running gRPC server. Kept in a separate class to
avoid the parent setUp() which spawns a subprocess.
Covers #582 (follow-up to #10346):
1. Markup-leak prevention with a non-streaming parser (buffer fallback)
2. No content duplication on the plain-text path with the buffer fallback
3. Native streaming progressive plain-text emission
4. Native streaming structured tool_call, no markup leak
5. Parser exception → graceful fallback to buffer, still no markup
6. No-tool-parser regression: unchanged per-delta content stream
"""
@staticmethod
def _make_generate(chunks):
"""Build a fake vLLM engine.generate that yields cumulative chunks."""
from types import SimpleNamespace
async def gen(*a, **k):
for i, t in enumerate(chunks):
yield SimpleNamespace(
outputs=[SimpleNamespace(
text=t,
token_ids=list(range(i + 1)),
logprobs=None,
)],
prompt_token_ids=[0],
)
return lambda *a, **k: gen()
@staticmethod
def _collect(servicer, req):
import asyncio
async def run():
return [r async for r in servicer._predict(req, None, streaming=True)]
return asyncio.run(run())
def _new_servicer(self):
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from backend import BackendServicer
s = BackendServicer()
s.reasoning_parser_cls = None
s.tool_parser_cls = None
s.tokenizer = None
return s
# ── Case 1+2: parser without streaming method → buffer fallback ──
def test_buffer_path_no_markup_no_duplication(self):
from types import SimpleNamespace
def parser_cls(called, content_text, calls):
class _P:
def __init__(self, tokenizer, tools=None):
pass
# NOTE: NO extract_tool_calls_streaming → takes the buffer path
def extract_tool_calls(self, c, request=None):
return SimpleNamespace(
tools_called=called, content=content_text, tool_calls=calls,
)
return _P
tools_json = '[{"type":"function","function":{"name":"calc","parameters":{}}}]'
# Tool-call case: no raw markup in any delta.content
s = self._new_servicer()
s.llm = SimpleNamespace(generate=self._make_generate([
'<tool_call>\n{"name": "calc"',
'<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
]))
call = SimpleNamespace(id="call_1",
function=SimpleNamespace(name="calc", arguments='{"x": 1}'))
s.tool_parser_cls = parser_cls(True, "", [call])
req = backend_pb2.PredictOptions(Prompt="x", Tools=tools_json)
replies = self._collect(s, req)
contents = [cd.content for r in replies for cd in r.chat_deltas if cd.content]
self.assertFalse(
any("<tool_call" in c for c in contents),
f"markup leaked: {contents!r}",
)
names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls]
self.assertIn("calc", names, "tool_call missing from final chunk")
# Plain-text-with-tools case: full content delivered exactly once
s2 = self._new_servicer()
s2.llm = SimpleNamespace(generate=self._make_generate([
"The capital ",
"The capital of France is Paris.",
]))
s2.tool_parser_cls = parser_cls(False, "", [])
req2 = backend_pb2.PredictOptions(Prompt="x", Tools=tools_json)
joined = "".join(
cd.content for r in self._collect(s2, req2)
for cd in r.chat_deltas if cd.content
)
self.assertEqual(
joined.count("The capital of France is Paris."), 1,
f"buffered content duplicated: {joined!r}",
)
# ── Case 3: native streaming, progressive plain text ──
def test_native_streaming_progressive_plain_text(self):
from types import SimpleNamespace
class _DeltaMsg:
def __init__(self, content=None, reasoning=None, tool_calls=None):
self.content = content
self.reasoning = reasoning
self.tool_calls = tool_calls or []
class StreamingParser:
def __init__(self, tokenizer, tools=None):
pass
def extract_tool_calls(self, c, request=None):
# Should NOT be called when native streaming runs successfully.
raise AssertionError("extract_tool_calls invoked on native-streaming path")
def extract_tool_calls_streaming(
self, previous_text, current_text, delta_text,
previous_token_ids, current_token_ids, delta_token_ids, request,
):
if not delta_text:
return None
return _DeltaMsg(content=delta_text)
s = self._new_servicer()
s.llm = SimpleNamespace(generate=self._make_generate([
"Paris ",
"Paris is ",
"Paris is the capital of France.",
]))
s.tool_parser_cls = StreamingParser
req = backend_pb2.PredictOptions(
Prompt="x",
Tools='[{"type":"function","function":{"name":"calc","parameters":{}}}]',
)
replies = self._collect(s, req)
intermediate_content = [
cd.content for r in replies[:-1] for cd in r.chat_deltas if cd.content
]
self.assertTrue(
len(intermediate_content) > 0,
"Plain-text response not streamed progressively (native streaming inactive?)",
)
assembled = "".join(
cd.content for r in replies for cd in r.chat_deltas if cd.content
)
self.assertEqual(
assembled, "Paris is the capital of France.",
f"Assembled content wrong: {assembled!r}",
)
# ── Case 4: native streaming, structured tool_call, no markup ──
def test_native_streaming_tool_call_no_markup_leak(self):
from types import SimpleNamespace
class _DeltaMsg:
def __init__(self, content=None, reasoning=None, tool_calls=None):
self.content = content
self.reasoning = reasoning
self.tool_calls = tool_calls or []
class _ToolCallStreamer:
def __init__(self, tokenizer, tools=None):
self._emitted = False
def extract_tool_calls(self, c, request=None):
raise AssertionError("extract_tool_calls invoked on native-streaming path")
def extract_tool_calls_streaming(
self, previous_text, current_text, delta_text,
previous_token_ids, current_token_ids, delta_token_ids, request,
):
if "</tool_call>" in current_text and not self._emitted:
self._emitted = True
fn = SimpleNamespace(name="calc", arguments='{"x": 1}')
tc = SimpleNamespace(id="call_1", type="function", index=0, function=fn)
return _DeltaMsg(tool_calls=[tc])
return None
s = self._new_servicer()
s.llm = SimpleNamespace(generate=self._make_generate([
'<tool_call>\n',
'<tool_call>\n{"name": "calc"',
'<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
]))
s.tool_parser_cls = _ToolCallStreamer
req = backend_pb2.PredictOptions(
Prompt="x",
Tools='[{"type":"function","function":{"name":"calc","parameters":{}}}]',
)
replies = self._collect(s, req)
contents = [cd.content for r in replies for cd in r.chat_deltas if cd.content]
self.assertFalse(
any("<tool_call" in c or "</tool_call>" in c for c in contents),
f"markup leaked as content: {contents!r}",
)
names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls if tc.name]
args = [tc.arguments for r in replies for cd in r.chat_deltas for tc in cd.tool_calls if tc.arguments]
self.assertIn("calc", names, f"tool_call name missing; got {names!r}")
self.assertIn('{"x": 1}', args, f"tool_call args missing; got {args!r}")
# ── Case 5: parser exception → fallback to buffer, no leak ──
def test_native_streaming_parser_exception_falls_back_to_buffer(self):
from types import SimpleNamespace
call = SimpleNamespace(id="call_1",
function=SimpleNamespace(name="calc", arguments='{"x": 1}'))
class _BrokenStreamer:
def __init__(self, tokenizer, tools=None):
pass
def extract_tool_calls(self, c, request=None):
return SimpleNamespace(tools_called=True, content="", tool_calls=[call])
def extract_tool_calls_streaming(self, *a, **kw):
raise RuntimeError("simulated parser bug")
s = self._new_servicer()
s.llm = SimpleNamespace(generate=self._make_generate([
'<tool_call>\n{"name": "calc"',
'<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
]))
s.tool_parser_cls = _BrokenStreamer
req = backend_pb2.PredictOptions(
Prompt="x",
Tools='[{"type":"function","function":{"name":"calc","parameters":{}}}]',
)
replies = self._collect(s, req)
contents = [cd.content for r in replies for cd in r.chat_deltas if cd.content]
self.assertFalse(
any("<tool_call" in c for c in contents),
f"markup leaked after parser exception: {contents!r}",
)
names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls]
self.assertIn("calc", names, "tool_call missing from final chunk after fallback")
# ── Case 6: no tool parser → unchanged per-delta content stream ──
def test_no_tool_parser_unchanged_per_delta_stream(self):
from types import SimpleNamespace
s = self._new_servicer() # tool_parser_cls already None
s.llm = SimpleNamespace(generate=self._make_generate([
"Hello ", "Hello world", "Hello world!",
]))
req = backend_pb2.PredictOptions(Prompt="x", Tools="")
replies = self._collect(s, req)
intermediate = [
cd.content for r in replies[:-1] for cd in r.chat_deltas if cd.content
]
self.assertEqual(
intermediate, ["Hello ", "world", "!"],
f"plain streaming changed; got {intermediate!r}",
)