mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-21 23:29:04 -04:00
feat(vllm): progressive streaming via parser.extract_tool_calls_streaming (follow-up to #10346) (#10351)
* fix(vllm): don't stream raw tool-call markup as content when a tool parser is active

  When a tool_parser is configured and the request carries tools, the streaming loop emitted every text delta as delta.content, including the model's raw tool-call markup (e.g. <tool_call>...), because extract_tool_calls only runs on the full output after the stream. Clients streaming a tool call therefore saw the unparsed tool-call syntax as assistant content.

  Buffer the text while a tool parser is active for the request; the existing end-of-stream chat_delta already carries the parsed tool_calls (or the cleaned content), which the Go side converts to SSE deltas. Non-tool-parser streaming is unchanged.

  Add a server-less regression test covering both the tool-call case (no raw markup leaked as content) and the plain-text case (content delivered exactly once, which guards against double-emitting the buffered content).

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* test(vllm): add expectedFailure test for progressive streaming with tool parser (Case 3, #582)

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* test(vllm): add Cases 4+5: marker split across chunks + false-positive prefix (TDD, Option B state machine, #582)

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* feat(vllm): progressive streaming via parser.extract_tool_calls_streaming

  When a tool parser is active for a tool-enabled streaming request, #10346 buffers the entire generation and surfaces it on the final chunk to prevent raw tool-call markup from leaking as delta.content. This is correct but turns the request into an effectively non-streaming one for plain-text responses: the client sees nothing until the model stops.

  Every concrete tool parser shipped with vLLM 0.23+ already implements extract_tool_calls_streaming (Granite4, Qwen3Coder, DeepSeekV31, Jamba, Ernie45, Hermes2Pro, llama3_json, mistral, …). Use it: instantiate the parser before the streaming loop and call its streaming method per delta, emitting DeltaMessage(content=…) or DeltaMessage(tool_calls=[…]) when the parser is ready.

  Falls back to the existing #10346 buffer path when:
  - the parser does not have extract_tool_calls_streaming, OR
  - extract_tool_calls_streaming raises mid-stream (logged, the rest of the request finishes via post-loop extract_tool_calls).

  Tests (TestStreamingToolParser):
  1. Buffer path: no markup leaked, no content duplication
  2. Native streaming: plain-text response streams progressively
  3. Native streaming: tool_call structured, no markup leaked
  4. Native streaming exception → graceful fallback, no markup, no crash
  5. No tool parser → unchanged per-delta content stream

  E2E verified against qwen3_coder on vLLM 0.23.0 (NVIDIA GB10 / arm64 / CUDA 13).

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* docs(vllm): add server-side TTFT benchmark for the streaming tool-parser path

  Self-contained stdlib-only script that measures time-to-first-token (TTFT) for the vLLM backend's two streaming scenarios:
  - tool_call: request mentions a tool; model is expected to call it
  - plain_text: request offers a tool but explicitly asks for prose

  Use this to compare:
  - the buffer-all path (#10346) → plain_text TTFT ≈ total response time
  - the native-streaming path (this PR) → plain_text TTFT ≈ true first-token time

      python examples/vllm-bench/ttft_streaming_tool_parser.py \
          --url http://localhost:8080 --model my-coder --runs 3

  Lives under examples/ so it does not interfere with the test suite.

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* examples/vllm-bench: add long-text scenario (8 paragraphs, 1500 tokens)

  The long-text scenario shows the buffering vs streaming difference most dramatically: with the buffer-all path, the client receives nothing for 20+ seconds and then the entire 1500-token response at once. With native streaming, the first token arrives in tens of milliseconds and the response flows progressively.

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

---------

Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>
Co-authored-by: Philipp Wacker <philipp.wacker@ibf-solutions.com>
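The fallback rule in the commit message can be sketched in a few lines. This is an illustration only, not the backend's code: `choose_strategy`, `WithStreaming`, and `WithoutStreaming` are hypothetical names, and the only detail taken from the commit is the check for an `extract_tool_calls_streaming` attribute.

```python
def choose_strategy(parser, has_tools: bool) -> str:
    """Pick how a streaming request is delivered.

    Returns "plain" (no tool parser involved), "native" (per-delta parsing),
    or "buffer" (hold text back and parse once at the end).
    """
    if parser is None or not has_tools:
        return "plain"
    if hasattr(parser, "extract_tool_calls_streaming"):
        return "native"
    return "buffer"


class WithStreaming:  # hypothetical parser exposing both methods
    def extract_tool_calls(self, text):
        return None

    def extract_tool_calls_streaming(self, *args, **kwargs):
        return None


class WithoutStreaming:  # hypothetical parser lacking the streaming method
    def extract_tool_calls(self, text):
        return None


print(choose_strategy(WithStreaming(), True))
print(choose_strategy(WithoutStreaming(), True))
print(choose_strategy(WithStreaming(), False))
```

A mid-stream exception is handled separately in the commit: the request flips from the native path to the buffer path for the remainder of the stream.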
@@ -598,23 +598,124 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

         # Stream the results
         generated_text = ""
+        generated_token_ids: list[int] = []
         last_output = None
+
+        # Tool-parsing strategy decision (made once, before the loop):
+        #
+        # When a tool parser is active, the model's raw tool-call markup
+        # (e.g. <tool_call>...) must not be streamed verbatim as delta.content
+        # — clients would see the unparsed syntax. Two paths:
+        #
+        # (A) native streaming via parser.extract_tool_calls_streaming. All
+        #     concrete tool parsers shipped with vLLM 0.23+ implement this
+        #     (Granite4, Qwen3Coder, DeepSeekV31, Jamba, Ernie45, Hermes,
+        #     llama3_json, mistral, …). The parser decides per-delta whether
+        #     to emit content or suppress tool-call markup, and emits a
+        #     structured DeltaMessage(tool_calls=[...]) when a call is ready.
+        # (B) buffer fallback — used only when the parser surprisingly lacks
+        #     the streaming method or it raises mid-stream. The post-loop
+        #     extract_tool_calls assembles the final chat_delta. Same correctness
+        #     guarantee as a non-streaming response, at the cost of a delayed
+        #     final chunk.
+        has_tool_parser = bool(self.tool_parser_cls and request.Tools)
+        tp_instance = None
+        tp_request = None
+        native_streaming = False
+        native_streaming_error = False
+        if has_tool_parser:
+            try:
+                tools_for_parser = json.loads(request.Tools)
+            except json.JSONDecodeError:
+                tools_for_parser = []
+            try:
+                tp_instance = self.tool_parser_cls(self.tokenizer, tools=tools_for_parser)
+            except TypeError:
+                tp_instance = self.tool_parser_cls(self.tokenizer)
+            # Build a minimal ChatCompletionRequest so the streaming method
+            # sees the tools list. We do not need any other request fields —
+            # parsers only read .tools (and sometimes .tool_choice, which we
+            # leave at default).
+            try:
+                from vllm.entrypoints.openai.chat_completion.protocol import (
+                    ChatCompletionRequest as _CCR,
+                )
+                tp_request = _CCR(
+                    model="local",
+                    messages=[{"role": "user", "content": ""}],
+                    tools=tools_for_parser or None,
+                )
+            except Exception as e:
+                print(f"Could not build ChatCompletionRequest for streaming parser: {e}",
+                      file=sys.stderr)
+                tp_request = None
+            native_streaming = (
+                tp_request is not None
+                and hasattr(tp_instance, "extract_tool_calls_streaming")
+            )
+
         try:
             async for request_output in outputs:
                 iteration_text = request_output.outputs[0].text
                 last_output = request_output

                 if streaming:
                     # Remove text already sent as vllm concatenates the text from previous yields
                     delta_iteration_text = iteration_text.removeprefix(generated_text)
-                    # Send the partial result
-                    yield backend_pb2.Reply(
-                        message=bytes(delta_iteration_text, encoding='utf-8'),
-                        chat_deltas=[backend_pb2.ChatDelta(content=delta_iteration_text)],
-                    )
+                    new_token_ids = list(request_output.outputs[0].token_ids)
+                    delta_token_ids = new_token_ids[len(generated_token_ids):]

-                # Keep track of text generated
+                    if not has_tool_parser:
+                        # Plain streaming — unchanged from pre-tool-parser path.
+                        yield backend_pb2.Reply(
+                            message=bytes(delta_iteration_text, encoding='utf-8'),
+                            chat_deltas=[backend_pb2.ChatDelta(content=delta_iteration_text)],
+                        )
+                    elif native_streaming and not native_streaming_error:
+                        # (A) Native vLLM extract_tool_calls_streaming.
+                        try:
+                            msg = tp_instance.extract_tool_calls_streaming(
+                                previous_text=generated_text,
+                                current_text=iteration_text,
+                                delta_text=delta_iteration_text,
+                                previous_token_ids=generated_token_ids,
+                                current_token_ids=new_token_ids,
+                                delta_token_ids=delta_token_ids,
+                                request=tp_request,
+                            )
+                        except Exception as e:
+                            print(f"Streaming tool parser error (falling back to "
+                                  f"buffer for the rest of the stream): {e}",
+                                  file=sys.stderr)
+                            native_streaming_error = True
+                            msg = None
+                        if msg is not None:
+                            tc_protos = []
+                            for tc in (msg.tool_calls or []):
+                                fn = tc.function or None
+                                tc_protos.append(backend_pb2.ToolCallDelta(
+                                    index=tc.index,
+                                    id=tc.id or "",
+                                    name=(fn.name if fn and fn.name else "") or "",
+                                    arguments=(fn.arguments if fn and fn.arguments else "") or "",
+                                ))
+                            cd_kwargs = {}
+                            if msg.content:
+                                cd_kwargs["content"] = msg.content
+                            if msg.reasoning:
+                                cd_kwargs["reasoning_content"] = msg.reasoning
+                            if tc_protos:
+                                cd_kwargs["tool_calls"] = tc_protos
+                            if cd_kwargs:
+                                yield backend_pb2.Reply(
+                                    message=bytes(msg.content or "", encoding='utf-8'),
+                                    chat_deltas=[backend_pb2.ChatDelta(**cd_kwargs)],
+                                )
+                    # (B) buffer fallback — emit nothing during the stream.
+                    # The post-loop extract_tool_calls block builds the final chunk.
+
+                # Keep track of text + token_ids generated
                 generated_text = iteration_text
+                generated_token_ids = list(request_output.outputs[0].token_ids)
         finally:
             await outputs.aclose()

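The hunk above depends on vLLM yielding cumulative snapshots, so each iteration has to derive both a text delta and a token-id delta before handing them to the parser. A minimal sketch of that bookkeeping, assuming plain `(text, token_ids)` tuples in place of vLLM's `RequestOutput` objects (`iter_deltas` is a hypothetical helper, not part of the backend):

```python
def iter_deltas(cumulative_outputs):
    """Yield (delta_text, delta_token_ids) from cumulative snapshots.

    Each snapshot is (full_text_so_far, all_token_ids_so_far).
    """
    prev_text = ""
    prev_ids = []
    for text, token_ids in cumulative_outputs:
        # str.removeprefix (Python 3.9+) returns the text unchanged when
        # prev_text is not actually a prefix of it.
        delta_text = text.removeprefix(prev_text)
        # Token ids only ever grow, so the delta is the tail of the list.
        delta_ids = list(token_ids[len(prev_ids):])
        yield delta_text, delta_ids
        prev_text, prev_ids = text, list(token_ids)


snapshots = [
    ("Hello ", [1]),
    ("Hello world", [1, 2]),
    ("Hello world!", [1, 2, 3]),
]
deltas = list(iter_deltas(snapshots))
print(deltas)
```

Updating the previous text and previous ids together at the end of each iteration keeps the two deltas aligned, which is why the hunk moves the bookkeeping to a single spot after the `if streaming:` block.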
@@ -639,16 +740,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             except Exception as e:
                 print(f"Reasoning parser error: {e}", file=sys.stderr)

-        if self.tool_parser_cls and request.Tools:
+        # When (A) native streaming ran cleanly, per-delta yields above already
+        # delivered everything — do NOT extract again on the full text or we'd
+        # duplicate content/tool_calls into the final chunk.
+        if has_tool_parser and not (native_streaming and not native_streaming_error):
             try:
-                tools = json.loads(request.Tools)
-                # Some concrete parsers only accept the tokenizer; only the
-                # abstract base declares the tools kwarg. Try with tools first,
-                # fall back to tokenizer-only.
-                try:
-                    tp = self.tool_parser_cls(self.tokenizer, tools=tools)
-                except TypeError:
-                    tp = self.tool_parser_cls(self.tokenizer)
+                tp = tp_instance
+                if tp is None:
+                    # Defensive: tp_instance build failed earlier; reconstruct.
+                    tools = json.loads(request.Tools)
+                    try:
+                        tp = self.tool_parser_cls(self.tokenizer, tools=tools)
+                    except TypeError:
+                        tp = self.tool_parser_cls(self.tokenizer)
                 info = tp.extract_tool_calls(content, request=None)
                 if info.tools_called:
                     content = info.content or ""
@@ -661,6 +765,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         ))
             except Exception as e:
                 print(f"Tool parser error: {e}", file=sys.stderr)
+        elif native_streaming and not native_streaming_error:
+            # Per-delta path already emitted content + tool_calls; the final
+            # chat_delta should carry only metadata (token counts, logprobs).
+            content = ""

         # Extract token counts
         prompt_tokens = 0
@@ -700,7 +808,26 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         )

         if streaming:
-            # Final chunk with structured data
+            # Final chunk with structured data.
+            #
+            # If we used the buffer fallback (has_tool_parser=True AND native
+            # streaming did NOT run cleanly) and the parser found no tool call,
+            # flush the buffered content as ONE content delta — and clear the
+            # final chat_delta's content so the metadata chunk does not repeat
+            # what we just sent. This is the plain-text-with-tool-parser path.
+            buffered_fallback = (
+                has_tool_parser
+                and not (native_streaming and not native_streaming_error)
+            )
+            if buffered_fallback and not tool_calls_proto and content:
+                yield backend_pb2.Reply(
+                    message=bytes(content, encoding='utf-8'),
+                    chat_deltas=[backend_pb2.ChatDelta(content=content)],
+                )
+            chat_delta = backend_pb2.ChatDelta(
+                reasoning_content=reasoning_content,
+                tool_calls=tool_calls_proto,
+            )
             yield backend_pb2.Reply(
                 message=b"",
                 prompt_tokens=prompt_tokens,
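The end-of-stream rule in the last hunk (flush buffered text once, then send a final chunk without content) can be modelled without gRPC. The sketch below uses plain dicts in place of `backend_pb2.Reply`; `final_chunks` and its dict keys are hypothetical and only mirror the conditions visible in the diff.

```python
def final_chunks(has_tool_parser, native_ok, tool_calls, content):
    """Return the chunks emitted after the generation loop has finished.

    native_ok means native streaming was active and never raised.
    """
    chunks = []
    buffered_fallback = has_tool_parser and not native_ok
    if buffered_fallback and not tool_calls and content:
        # Plain text that was held back: deliver it exactly once.
        chunks.append({"content": content})
    # The final chunk carries tool calls and metadata, never content,
    # so buffered text cannot be repeated.
    chunks.append({"tool_calls": list(tool_calls), "final": True})
    return chunks


print(final_chunks(True, False, [], "Paris is the capital of France."))
print(final_chunks(True, False, ["calc"], ""))
print(final_chunks(True, True, [], ""))
```

In the buffer path with a detected tool call, the parsed calls travel on the final chunk; in the native path that chunk is metadata only because everything was already streamed.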
@@ -278,4 +278,261 @@ class TestBackendServicer(unittest.TestCase):
             print(err)
             self.fail("Embedding service failed")
         finally:
-            self.tearDown()
+            self.tearDown()
+
+
+class TestStreamingToolParser(unittest.TestCase):
+    """
+    Server-less unit tests for the streaming + tool-parser machinery in
+    BackendServicer._predict. These tests instantiate BackendServicer
+    directly and mock the vLLM engine + tool parser, so they do not need
+    a GPU, a model, or a running gRPC server. Kept in a separate class to
+    avoid the parent setUp() which spawns a subprocess.
+
+    Covers #582 (follow-up to #10346):
+      1. Markup-leak prevention with a non-streaming parser (buffer fallback)
+      2. No content duplication on the plain-text path with the buffer fallback
+      3. Native streaming progressive plain-text emission
+      4. Native streaming structured tool_call, no markup leak
+      5. Parser exception → graceful fallback to buffer, still no markup
+      6. No-tool-parser regression: unchanged per-delta content stream
+    """
+
+    @staticmethod
+    def _make_generate(chunks):
+        """Build a fake vLLM engine.generate that yields cumulative chunks."""
+        from types import SimpleNamespace
+        async def gen(*a, **k):
+            for i, t in enumerate(chunks):
+                yield SimpleNamespace(
+                    outputs=[SimpleNamespace(
+                        text=t,
+                        token_ids=list(range(i + 1)),
+                        logprobs=None,
+                    )],
+                    prompt_token_ids=[0],
+                )
+        return lambda *a, **k: gen()
+
+    @staticmethod
+    def _collect(servicer, req):
+        import asyncio
+        async def run():
+            return [r async for r in servicer._predict(req, None, streaming=True)]
+        return asyncio.run(run())
+
+    def _new_servicer(self):
+        import sys, os
+        sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+        from backend import BackendServicer
+        s = BackendServicer()
+        s.reasoning_parser_cls = None
+        s.tool_parser_cls = None
+        s.tokenizer = None
+        return s
+
+    # ── Case 1+2: parser without streaming method → buffer fallback ──
+    def test_buffer_path_no_markup_no_duplication(self):
+        from types import SimpleNamespace
+
+        def parser_cls(called, content_text, calls):
+            class _P:
+                def __init__(self, tokenizer, tools=None):
+                    pass
+                # NOTE: NO extract_tool_calls_streaming → takes the buffer path
+                def extract_tool_calls(self, c, request=None):
+                    return SimpleNamespace(
+                        tools_called=called, content=content_text, tool_calls=calls,
+                    )
+            return _P
+
+        tools_json = '[{"type":"function","function":{"name":"calc","parameters":{}}}]'
+
+        # Tool-call case: no raw markup in any delta.content
+        s = self._new_servicer()
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            '<tool_call>\n{"name": "calc"',
+            '<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
+        ]))
+        call = SimpleNamespace(id="call_1",
+                               function=SimpleNamespace(name="calc", arguments='{"x": 1}'))
+        s.tool_parser_cls = parser_cls(True, "", [call])
+        req = backend_pb2.PredictOptions(Prompt="x", Tools=tools_json)
+        replies = self._collect(s, req)
+        contents = [cd.content for r in replies for cd in r.chat_deltas if cd.content]
+        self.assertFalse(
+            any("<tool_call" in c for c in contents),
+            f"markup leaked: {contents!r}",
+        )
+        names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls]
+        self.assertIn("calc", names, "tool_call missing from final chunk")
+
+        # Plain-text-with-tools case: full content delivered exactly once
+        s2 = self._new_servicer()
+        s2.llm = SimpleNamespace(generate=self._make_generate([
+            "The capital ",
+            "The capital of France is Paris.",
+        ]))
+        s2.tool_parser_cls = parser_cls(False, "", [])
+        req2 = backend_pb2.PredictOptions(Prompt="x", Tools=tools_json)
+        joined = "".join(
+            cd.content for r in self._collect(s2, req2)
+            for cd in r.chat_deltas if cd.content
+        )
+        self.assertEqual(
+            joined.count("The capital of France is Paris."), 1,
+            f"buffered content duplicated: {joined!r}",
+        )
+
+    # ── Case 3: native streaming, progressive plain text ──
+    def test_native_streaming_progressive_plain_text(self):
+        from types import SimpleNamespace
+
+        class _DeltaMsg:
+            def __init__(self, content=None, reasoning=None, tool_calls=None):
+                self.content = content
+                self.reasoning = reasoning
+                self.tool_calls = tool_calls or []
+
+        class StreamingParser:
+            def __init__(self, tokenizer, tools=None):
+                pass
+            def extract_tool_calls(self, c, request=None):
+                # Should NOT be called when native streaming runs successfully.
+                raise AssertionError("extract_tool_calls invoked on native-streaming path")
+            def extract_tool_calls_streaming(
+                self, previous_text, current_text, delta_text,
+                previous_token_ids, current_token_ids, delta_token_ids, request,
+            ):
+                if not delta_text:
+                    return None
+                return _DeltaMsg(content=delta_text)
+
+        s = self._new_servicer()
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            "Paris ",
+            "Paris is ",
+            "Paris is the capital of France.",
+        ]))
+        s.tool_parser_cls = StreamingParser
+        req = backend_pb2.PredictOptions(
+            Prompt="x",
+            Tools='[{"type":"function","function":{"name":"calc","parameters":{}}}]',
+        )
+        replies = self._collect(s, req)
+
+        intermediate_content = [
+            cd.content for r in replies[:-1] for cd in r.chat_deltas if cd.content
+        ]
+        self.assertTrue(
+            len(intermediate_content) > 0,
+            "Plain-text response not streamed progressively (native streaming inactive?)",
+        )
+        assembled = "".join(
+            cd.content for r in replies for cd in r.chat_deltas if cd.content
+        )
+        self.assertEqual(
+            assembled, "Paris is the capital of France.",
+            f"Assembled content wrong: {assembled!r}",
+        )
+
+    # ── Case 4: native streaming, structured tool_call, no markup ──
+    def test_native_streaming_tool_call_no_markup_leak(self):
+        from types import SimpleNamespace
+
+        class _DeltaMsg:
+            def __init__(self, content=None, reasoning=None, tool_calls=None):
+                self.content = content
+                self.reasoning = reasoning
+                self.tool_calls = tool_calls or []
+
+        class _ToolCallStreamer:
+            def __init__(self, tokenizer, tools=None):
+                self._emitted = False
+            def extract_tool_calls(self, c, request=None):
+                raise AssertionError("extract_tool_calls invoked on native-streaming path")
+            def extract_tool_calls_streaming(
+                self, previous_text, current_text, delta_text,
+                previous_token_ids, current_token_ids, delta_token_ids, request,
+            ):
+                if "</tool_call>" in current_text and not self._emitted:
+                    self._emitted = True
+                    fn = SimpleNamespace(name="calc", arguments='{"x": 1}')
+                    tc = SimpleNamespace(id="call_1", type="function", index=0, function=fn)
+                    return _DeltaMsg(tool_calls=[tc])
+                return None
+
+        s = self._new_servicer()
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            '<tool_call>\n',
+            '<tool_call>\n{"name": "calc"',
+            '<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
+        ]))
+        s.tool_parser_cls = _ToolCallStreamer
+        req = backend_pb2.PredictOptions(
+            Prompt="x",
+            Tools='[{"type":"function","function":{"name":"calc","parameters":{}}}]',
+        )
+        replies = self._collect(s, req)
+
+        contents = [cd.content for r in replies for cd in r.chat_deltas if cd.content]
+        self.assertFalse(
+            any("<tool_call" in c or "</tool_call>" in c for c in contents),
+            f"markup leaked as content: {contents!r}",
+        )
+        names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls if tc.name]
+        args = [tc.arguments for r in replies for cd in r.chat_deltas for tc in cd.tool_calls if tc.arguments]
+        self.assertIn("calc", names, f"tool_call name missing; got {names!r}")
+        self.assertIn('{"x": 1}', args, f"tool_call args missing; got {args!r}")
+
+    # ── Case 5: parser exception → fallback to buffer, no leak ──
+    def test_native_streaming_parser_exception_falls_back_to_buffer(self):
+        from types import SimpleNamespace
+        call = SimpleNamespace(id="call_1",
+                               function=SimpleNamespace(name="calc", arguments='{"x": 1}'))
+
+        class _BrokenStreamer:
+            def __init__(self, tokenizer, tools=None):
+                pass
+            def extract_tool_calls(self, c, request=None):
+                return SimpleNamespace(tools_called=True, content="", tool_calls=[call])
+            def extract_tool_calls_streaming(self, *a, **kw):
+                raise RuntimeError("simulated parser bug")
+
+        s = self._new_servicer()
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            '<tool_call>\n{"name": "calc"',
+            '<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
+        ]))
+        s.tool_parser_cls = _BrokenStreamer
+        req = backend_pb2.PredictOptions(
+            Prompt="x",
+            Tools='[{"type":"function","function":{"name":"calc","parameters":{}}}]',
+        )
+        replies = self._collect(s, req)
+
+        contents = [cd.content for r in replies for cd in r.chat_deltas if cd.content]
+        self.assertFalse(
+            any("<tool_call" in c for c in contents),
+            f"markup leaked after parser exception: {contents!r}",
+        )
+        names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls]
+        self.assertIn("calc", names, "tool_call missing from final chunk after fallback")
+
+    # ── Case 6: no tool parser → unchanged per-delta content stream ──
+    def test_no_tool_parser_unchanged_per_delta_stream(self):
+        from types import SimpleNamespace
+        s = self._new_servicer()  # tool_parser_cls already None
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            "Hello ", "Hello world", "Hello world!",
+        ]))
+        req = backend_pb2.PredictOptions(Prompt="x", Tools="")
+        replies = self._collect(s, req)
+
+        intermediate = [
+            cd.content for r in replies[:-1] for cd in r.chat_deltas if cd.content
+        ]
+        self.assertEqual(
+            intermediate, ["Hello ", "world", "!"],
+            f"plain streaming changed; got {intermediate!r}",
+        )
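The tests rely on two small tricks: a fake `generate` that returns an async generator of cumulative snapshots, and `asyncio.run` to drain an async generator from synchronous test code. A stripped-down sketch of the same pattern, with hypothetical names (`make_generate`, `consume`) and a simplified snapshot shape rather than the full vLLM output object:

```python
import asyncio
from types import SimpleNamespace


def make_generate(chunks):
    """Return a stand-in for engine.generate yielding cumulative snapshots."""
    async def gen():
        for i, text in enumerate(chunks):
            # One fake token id per snapshot, as in the test helper.
            yield SimpleNamespace(text=text, token_ids=list(range(i + 1)))
    # Accept and ignore whatever arguments the caller passes.
    return lambda *args, **kwargs: gen()


async def consume(generate):
    """Drain the async generator into a list of texts."""
    return [out.text async for out in generate("prompt")]


texts = asyncio.run(consume(make_generate(["a", "ab", "abc"])))
print(texts)
```

Because the fake returns a fresh generator on every call, each test gets an independent stream and no event loop or server has to outlive the test.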
54
examples/vllm-bench/README.md
Normal file
@@ -0,0 +1,54 @@
+# vLLM streaming + tool-parser benchmark
+
+A small, self-contained Python script (stdlib only) that measures
+time-to-first-token (TTFT) for the vLLM backend's streaming path with
+a tool parser configured.
+
+## Why this exists
+
+When a vLLM tool parser is active and a streaming chat completion is requested,
+LocalAI used to buffer the full generation to prevent raw tool-call markup
+(e.g. `<tool_call>...`) from leaking as `delta.content`. That was correct
+for tool-call responses, but it made plain-text responses effectively
+non-streaming: the client received nothing until the model finished.
+
+With native parser-side streaming (`parser.extract_tool_calls_streaming`,
+implemented by every concrete vLLM 0.23+ tool parser), each delta can be
+classified as it arrives: emit as content, emit as a structured tool_call, or
+suppress.
+
+## Three scenarios
+
+| Scenario | Request | Expected outcome |
+|---|---|---|
+| `tool_call` | "What is the weather in Paris? Please use the tool." | Model calls `get_weather`. `delta.tool_calls` chunks; no content leak. |
+| `plain_text_short` | "Explain in 3 short sentences what a hash table is. Do NOT call any tool." | Model writes ~3 sentences. |
+| `plain_text_long` | "Write a thorough 8-paragraph explanation of how Python's GIL works…" | Model writes ~1500 tokens of prose. |
+
+The **long scenario** is where the streaming/buffering difference is most
+dramatic: with the buffer-all path, the client sees nothing for 20+ seconds
+and then everything at once; with native streaming, the first token arrives
+in under 100 ms and the response flows progressively.
+
+## What the script reports
+
+For each scenario, across N runs:
+
+- `ttf_content_s`: time until the first `delta.content` chunk
+- `ttf_tool_s`: time until the first `delta.tool_calls` chunk
+- `n_content_chunks`: total content deltas (1 = bundled, >>1 = streamed)
+- `n_tool_chunks`: total tool_call deltas
+- `total_s`: total wall-clock time until `[DONE]`
+- `finish_reason`: `tool_calls` / `stop` / `length`
+
+The clearest signal is **how `n_content_chunks` and `ttf_content_s` compare with `total_s`**:
+- Buffer-all: `n_content_chunks` ≈ 1, `ttf_content_s` ≈ `total_s` (one chunk at end)
+- Streaming: `n_content_chunks` ≈ token count, `ttf_content_s` ≈ first-token latency
+
+## Usage
+
+```bash
+python ttft_streaming_tool_parser.py --url http://localhost:8080 --model my-coder --runs 3
+```
+
+JSON results are written to `ttft_bench_<label>.json` (default label: `run`).
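The README's rule of thumb for reading the results can be turned into a small check over one result row. This is a heuristic sketch, not part of the benchmark script: `classify_run` and the 10% `tolerance` are assumptions, and a genuinely streamed one-chunk answer would be misread as buffered.

```python
def classify_run(row, tolerance=0.1):
    """Label one benchmark row as "buffered", "streamed", or "no_content".

    row uses the keys the script reports: n_content_chunks,
    ttf_content_s, and total_s.
    """
    n = row["n_content_chunks"]
    ttf = row["ttf_content_s"]
    total = row["total_s"]
    if ttf is None or n == 0:
        return "no_content"  # e.g. a pure tool-call response
    # Buffer-all: a single content chunk that arrives near the very end.
    if n == 1 and total > 0 and (total - ttf) / total <= tolerance:
        return "buffered"
    return "streamed"


buffered = {"n_content_chunks": 1, "ttf_content_s": 20.0, "total_s": 20.1}
streamed = {"n_content_chunks": 300, "ttf_content_s": 0.05, "total_s": 20.0}
tool_only = {"n_content_chunks": 0, "ttf_content_s": None, "total_s": 1.2}
print(classify_run(buffered), classify_run(streamed), classify_run(tool_only))
```

The sample rows are made-up values chosen to match the two patterns the README describes, not measurements.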
175
examples/vllm-bench/ttft_streaming_tool_parser.py
Executable file
@@ -0,0 +1,175 @@
+#!/usr/bin/env python3
+"""
+TTFT benchmark for the vLLM backend's streaming + tool-parser path.
+
+Three scenarios:
+  1. tool_call — request mentions a tool; model is expected to call it
+  2. plain_text_short — request offers a tool but explicitly asks for ~3 sentences
+  3. plain_text_long — same as above but asks for ~8 paragraphs (1500 tokens)
+
+The long scenario shows the dramatic difference between buffering and
+streaming most clearly: with buffer-all, the client sees nothing for
+20+ seconds; with native streaming, the first token arrives in <100 ms.
+
+Usage:
+    python ttft_streaming_tool_parser.py \\
+        --url http://localhost:8080 --model my-coder --runs 3
+
+The script is self-contained (stdlib only — urllib, json, time, argparse).
+"""
+import argparse
+import json
+import sys
+import time
+import urllib.request
+
+DEFAULT_TOOLS = [{
+    "type": "function",
+    "function": {
+        "name": "get_weather",
+        "description": "Get current weather for a city",
+        "parameters": {
+            "type": "object",
+            "properties": {"city": {"type": "string"}},
+            "required": ["city"],
+        },
+    },
+}]
+
+SCENARIOS = [
+    {
+        "label": "tool_call",
+        "messages": [{"role": "user",
+                      "content": "What is the weather in Paris? Please use the tool."}],
+        "max_tokens": 80,
+    },
+    {
+        "label": "plain_text_short",
+        "messages": [{"role": "user",
+                      "content": "Explain in 3 short sentences what a hash table is. "
+                                 "Do NOT call any tool."}],
+        "max_tokens": 200,
+    },
+    {
+        "label": "plain_text_long",
+        "messages": [{"role": "user",
+                      "content": "Write a thorough 8-paragraph explanation of how "
+                                 "Python's GIL works, including history, current "
+                                 "state, no-GIL build, and alternatives. Be "
+                                 "detailed. Do NOT call any tool."}],
+        "max_tokens": 1500,
+    },
+]
+
+
+def bench_one(url, model, messages, tools, max_tokens, timeout):
+    body = json.dumps({
+        "model": model,
+        "stream": True,
+        "tools": tools,
+        "messages": messages,
+        "max_tokens": max_tokens,
+    }).encode()
+    req = urllib.request.Request(
+        f"{url.rstrip('/')}/v1/chat/completions",
+        data=body, headers={"Content-Type": "application/json"},
+    )
+
+    t0 = time.perf_counter()
+    first_content = None
+    first_tool = None
+    n_content = 0
+    n_tool = 0
+    last = None
+    finish = None
+    with urllib.request.urlopen(req, timeout=timeout) as resp:
+        for line in resp:
+            line = line.decode("utf-8", "replace").strip()
+            if not line.startswith("data: "):
+                continue
+            payload = line[6:]
+            if payload == "[DONE]":
+                break
+            try:
+                chunk = json.loads(payload)
+            except Exception:
+                continue
+            if not chunk.get("choices"):
+                continue
+            ch = chunk["choices"][0]
+            delta = ch.get("delta") or {}
+            now = time.perf_counter() - t0
+            if delta.get("content"):
+                if first_content is None:
+                    first_content = now
+                n_content += 1
+            if delta.get("tool_calls"):
+                if first_tool is None:
+                    first_tool = now
+                n_tool += 1
+            if ch.get("finish_reason"):
+                finish = ch["finish_reason"]
+            last = now
+    return {
+        "ttf_content_s": first_content,
+        "ttf_tool_s": first_tool,
+        "n_content_chunks": n_content,
+        "n_tool_chunks": n_tool,
+        "total_s": last,
+        "finish_reason": finish,
+    }
+
+
+def stats(values):
+    values = [v for v in values if v is not None]
+    if not values:
+        return "n/a"
+    return f"min={min(values):.3f} avg={sum(values)/len(values):.3f} max={max(values):.3f}"
+
+
+def main():
+    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
+    p.add_argument("--url", default="http://localhost:8080",
+                   help="LocalAI base URL (default: %(default)s)")
+    p.add_argument("--model", default="coder", help="Model name (default: %(default)s)")
+    p.add_argument("--runs", type=int, default=3, help="Repetitions per scenario (default: %(default)s)")
+    p.add_argument("--timeout", type=int, default=180, help="Per-request timeout in seconds")
+    p.add_argument("--label", default="run",
+                   help="Tag for the JSON output file (default: %(default)s)")
+    args = p.parse_args()
+
+    print(f"=== TTFT Bench — {args.url} model={args.model} runs={args.runs} ===")
+    summary = {}
+    for sc in SCENARIOS:
+        print(f"\nScenario: {sc['label']}")
+        rows = []
+        for run in range(args.runs):
+            r = bench_one(args.url, args.model,
+                          sc["messages"], DEFAULT_TOOLS, sc["max_tokens"], args.timeout)
+            rows.append(r)
+            ttf_c = f"{r['ttf_content_s']:.3f}" if r["ttf_content_s"] is not None else "—"
+            ttf_t = f"{r['ttf_tool_s']:.3f}" if r["ttf_tool_s"] is not None else "—"
+            print(f"  run {run+1}/{args.runs}: "
+                  f"ttf_content={ttf_c}s ttf_tool={ttf_t}s "
+                  f"n_content={r['n_content_chunks']} n_tool={r['n_tool_chunks']} "
+                  f"total={r['total_s']:.2f}s finish={r['finish_reason']}")
+        summary[sc["label"]] = rows
+
+    print("\n=== Summary (per scenario) ===")
+    for label, rows in summary.items():
+        print(f"[{label}]")
+        print(f"  ttf_content_s: {stats(r['ttf_content_s'] for r in rows)}")
+        print(f"  ttf_tool_s: {stats(r['ttf_tool_s'] for r in rows)}")
+        print(f"  n_content_chunks: {stats(r['n_content_chunks'] for r in rows)}")
+        print(f"  n_tool_chunks: {stats(r['n_tool_chunks'] for r in rows)}")
+        print(f"  total_s: {stats(r['total_s'] for r in rows)}")
+
+    out = f"ttft_bench_{args.label}.json"
+    with open(out, "w") as f:
+        json.dump(summary, f, indent=2)
+    print(f"\nSaved to {out}")
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
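The core of `bench_one` is its handling of server-sent-event lines: skip anything that is not a `data: ` line, stop on `[DONE]`, and ignore payloads that are not valid JSON. That step can be isolated and tested without a server. In the sketch below, `parse_sse_line` and the `DONE` sentinel are hypothetical names; the line format follows what the script itself expects.

```python
import json

DONE = object()  # sentinel for the "[DONE]" terminator


def parse_sse_line(raw: bytes):
    """Parse one raw SSE line.

    Returns the decoded JSON chunk, DONE for the terminator, or None for
    lines that carry no usable payload (comments, blanks, bad JSON).
    """
    line = raw.decode("utf-8", "replace").strip()
    if not line.startswith("data: "):
        return None
    payload = line[len("data: "):]
    if payload == "[DONE]":
        return DONE
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        return None


chunk = parse_sse_line(b'data: {"choices": [{"delta": {"content": "Hi"}}]}\n')
print(chunk["choices"][0]["delta"]["content"])
print(parse_sse_line(b"data: [DONE]\n") is DONE)
print(parse_sse_line(b": keep-alive\n"))
```

Keeping the parser separate from the timing loop would let the counting logic be unit-tested with canned byte strings, in the same spirit as the server-less tests in this commit.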