mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-22 07:39:02 -04:00
feat(vllm): progressive streaming via parser.extract_tool_calls_streaming (follow-up to #10346) (#10351)
* fix(vllm): don't stream raw tool-call markup as content when a tool parser is active

  When a tool_parser is configured and the request carries tools, the streaming loop emitted every text delta as delta.content, including the model's raw tool-call markup (e.g. <tool_call>...), because extract_tool_calls only runs on the full output after the stream. Clients streaming a tool call therefore saw the unparsed tool-call syntax as assistant content.

  Buffer the text while a tool parser is active for the request; the existing end-of-stream chat_delta already carries the parsed tool_calls (or the cleaned content), which the Go side converts to SSE deltas. Non-tool-parser streaming is unchanged.

  Add a server-less regression test covering both the tool-call case (no raw markup leaked as content) and the plain-text case (content delivered exactly once; guards against double-emitting the buffered content).

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* test(vllm): add expectedFailure test for progressive streaming with tool parser (Case 3, #582)

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* test(vllm): add Cases 4+5: marker split across chunks + false-positive prefix (TDD, Option B state machine, #582)

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* feat(vllm): progressive streaming via parser.extract_tool_calls_streaming

  When a tool parser is active for a tool-enabled streaming request, #10346 buffers the entire generation and surfaces it on the final chunk to prevent raw tool-call markup from leaking as delta.content. This is correct but turns the request into effectively non-streaming for plain-text responses: the client sees nothing until the model stops.

  Every concrete tool parser shipped with vLLM 0.23+ already implements extract_tool_calls_streaming (Granite4, Qwen3Coder, DeepSeekV31, Jamba, Ernie45, Hermes2Pro, llama3_json, mistral, …). Use it: instantiate the parser before the streaming loop and call its streaming method per delta, emitting DeltaMessage(content=…) or DeltaMessage(tool_calls=[…]) when the parser is ready.

  Falls back to the existing #10346 buffer path when:
  - the parser does not have extract_tool_calls_streaming, OR
  - extract_tool_calls_streaming raises mid-stream (logged, the rest of the request finishes via post-loop extract_tool_calls).

  Tests (TestStreamingToolParser):
  1. Buffer path: no markup leaked, no content duplication
  2. Native streaming: plain-text response streams progressively
  3. Native streaming: tool_call structured, no markup leaked
  4. Native streaming exception → graceful fallback, no markup, no crash
  5. No tool parser → unchanged per-delta content stream

  E2E verified against qwen3_coder on vLLM 0.23.0 (NVIDIA GB10 / arm64 / CUDA 13).

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* docs(vllm): add server-side TTFT benchmark for the streaming tool-parser path

  Self-contained stdlib-only script that measures time-to-first-token (TTFT) for the vLLM backend's two streaming scenarios:
  - tool_call: request mentions a tool; model is expected to call it
  - plain_text: request offers a tool but explicitly asks for prose

  Use this to compare:
  - the buffer-all path (#10346) → plain_text TTFT ≈ total response time
  - the native-streaming path (this PR) → plain_text TTFT ≈ true first-token time

      python examples/vllm-bench/ttft_streaming_tool_parser.py \
          --url http://localhost:8080 --model my-coder --runs 3

  Lives under examples/ so it does not interfere with the test suite.

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

* examples/vllm-bench: add long-text scenario (8 paragraphs, 1500 tokens)

  The long-text scenario shows the buffering vs streaming difference most dramatically: with the buffer-all path, the client receives nothing for 20+ seconds and then the entire 1500-token response at once. With native streaming, the first token arrives in tens of milliseconds and the response flows progressively.

  Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>

---------

Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com>
Co-authored-by: Philipp Wacker <philipp.wacker@ibf-solutions.com>
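The two-path strategy described in the commit message (native streaming, with a buffer fallback) can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the backend code: `ToyParser`, `stream`, and the simplified three-argument `extract_tool_calls_streaming` signature are hypothetical stand-ins and do not reflect vLLM's real parser interface.

```python
from types import SimpleNamespace


class ToyParser:
    """Hypothetical stand-in for a tool parser (not the vLLM API)."""

    def extract_tool_calls_streaming(self, previous_text, current_text, delta_text):
        # Once the tool-call marker has started, suppress raw markup and
        # emit one structured call when the closing marker arrives.
        if "<tool_call>" in current_text:
            if current_text.endswith("</tool_call>"):
                return SimpleNamespace(content=None, tool_calls=["calc"])
            return None
        return SimpleNamespace(content=delta_text, tool_calls=[])

    def extract_tool_calls(self, full_text):
        called = "<tool_call>" in full_text
        return SimpleNamespace(
            tools_called=called,
            content="" if called else full_text,
            tool_calls=["calc"] if called else [],
        )


def stream(chunks, parser):
    """Yield (kind, payload) events for cumulative text chunks."""
    native = hasattr(parser, "extract_tool_calls_streaming")
    failed = False
    previous = ""
    for current in chunks:
        # The engine yields cumulative text, so strip what was already seen.
        delta = current.removeprefix(previous)
        if native and not failed:
            try:
                msg = parser.extract_tool_calls_streaming(previous, current, delta)
            except Exception:
                failed = True  # path (B) takes over for the rest of the stream
                msg = None
            if msg is not None:
                if msg.content:
                    yield ("content", msg.content)
                for call in msg.tool_calls:
                    yield ("tool_call", call)
        previous = current
    # Path (B): parse the full text once, only if native streaming did not run cleanly.
    if not native or failed:
        info = parser.extract_tool_calls(previous)
        if info.tools_called:
            for call in info.tool_calls:
                yield ("tool_call", call)
        elif info.content:
            yield ("content", info.content)
```

With the toy parser, `list(stream(["Hi ", "Hi there"], ToyParser()))` produces two content events, while a chunk sequence containing `<tool_call>...</tool_call>` produces a single tool-call event and no content. In this sketch, a failure after some content was already emitted would cause the fallback to repeat that content.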
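To make the TTFT comparison in the commit message concrete, here is a small sketch of how time-to-first-token differs between a buffered and a progressive stream. It is not the benchmark script from `examples/vllm-bench`; `measure_ttft`, `FakeClock`, `generate`, and `buffer_all` are hypothetical names, and a simulated clock replaces real network timing so the example runs instantly.

```python
import time


def measure_ttft(chunks, clock=time.perf_counter):
    """Return (ttft, total) in clock units for an iterable of text chunks."""
    start = clock()
    first = None
    for chunk in chunks:
        # TTFT is the time until the first non-empty chunk arrives.
        if first is None and chunk:
            first = clock() - start
    total = clock() - start
    return first, total


class FakeClock:
    """Simulated clock: time only moves when advance() is called."""

    def __init__(self):
        self.now = 0.0

    def __call__(self):
        return self.now

    def advance(self, dt):
        self.now += dt


def generate(clock, n_tokens, step):
    """Pretend model: each token takes `step` time units to produce."""
    for i in range(n_tokens):
        clock.advance(step)
        yield f"tok{i} "


def buffer_all(chunks):
    """Buffer-all behaviour: hold everything back, then emit one chunk."""
    yield "".join(list(chunks))


streamed_clock = FakeClock()
streamed = measure_ttft(generate(streamed_clock, 3, 1.0), clock=streamed_clock)

buffered_clock = FakeClock()
buffered = measure_ttft(buffer_all(generate(buffered_clock, 3, 1.0)), clock=buffered_clock)

print("streamed (ttft, total):", streamed)
print("buffered (ttft, total):", buffered)
```

In the simulated run, the streamed case reports a TTFT of one step out of three, whereas the buffered case reports a TTFT equal to the total time, which is the pattern the benchmark is meant to expose.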
@@ -598,23 +598,124 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
 
         # Stream the results
         generated_text = ""
+        generated_token_ids: list[int] = []
         last_output = None
+
+        # Tool-parsing strategy decision (made once, before the loop):
+        #
+        # When a tool parser is active, the model's raw tool-call markup
+        # (e.g. <tool_call>...) must not be streamed verbatim as delta.content
+        # — clients would see the unparsed syntax. Two paths:
+        #
+        #   (A) native streaming via parser.extract_tool_calls_streaming. All
+        #       concrete tool parsers shipped with vLLM 0.23+ implement this
+        #       (Granite4, Qwen3Coder, DeepSeekV31, Jamba, Ernie45, Hermes,
+        #       llama3_json, mistral, …). The parser decides per-delta whether
+        #       to emit content or suppress tool-call markup, and emits a
+        #       structured DeltaMessage(tool_calls=[...]) when a call is ready.
+        #   (B) buffer fallback — used only when the parser surprisingly lacks
+        #       the streaming method or it raises mid-stream. The post-loop
+        #       extract_tool_calls assembles the final chat_delta. Same correctness
+        #       guarantee as a non-streaming response, at the cost of a delayed
+        #       final chunk.
+        has_tool_parser = bool(self.tool_parser_cls and request.Tools)
+        tp_instance = None
+        tp_request = None
+        native_streaming = False
+        native_streaming_error = False
+        if has_tool_parser:
+            try:
+                tools_for_parser = json.loads(request.Tools)
+            except json.JSONDecodeError:
+                tools_for_parser = []
+            try:
+                tp_instance = self.tool_parser_cls(self.tokenizer, tools=tools_for_parser)
+            except TypeError:
+                tp_instance = self.tool_parser_cls(self.tokenizer)
+            # Build a minimal ChatCompletionRequest so the streaming method
+            # sees the tools list. We do not need any other request fields —
+            # parsers only read .tools (and sometimes .tool_choice, which we
+            # leave at default).
+            try:
+                from vllm.entrypoints.openai.chat_completion.protocol import (
+                    ChatCompletionRequest as _CCR,
+                )
+                tp_request = _CCR(
+                    model="local",
+                    messages=[{"role": "user", "content": ""}],
+                    tools=tools_for_parser or None,
+                )
+            except Exception as e:
+                print(f"Could not build ChatCompletionRequest for streaming parser: {e}",
+                      file=sys.stderr)
+                tp_request = None
+            native_streaming = (
+                tp_request is not None
+                and hasattr(tp_instance, "extract_tool_calls_streaming")
+            )
+
         try:
             async for request_output in outputs:
                 iteration_text = request_output.outputs[0].text
                 last_output = request_output

                 if streaming:
                     # Remove text already sent as vllm concatenates the text from previous yields
                     delta_iteration_text = iteration_text.removeprefix(generated_text)
-                    # Send the partial result
-                    yield backend_pb2.Reply(
-                        message=bytes(delta_iteration_text, encoding='utf-8'),
-                        chat_deltas=[backend_pb2.ChatDelta(content=delta_iteration_text)],
-                    )
+                    new_token_ids = list(request_output.outputs[0].token_ids)
+                    delta_token_ids = new_token_ids[len(generated_token_ids):]

-                # Keep track of text generated
+                    if not has_tool_parser:
+                        # Plain streaming — unchanged from pre-tool-parser path.
+                        yield backend_pb2.Reply(
+                            message=bytes(delta_iteration_text, encoding='utf-8'),
+                            chat_deltas=[backend_pb2.ChatDelta(content=delta_iteration_text)],
+                        )
+                    elif native_streaming and not native_streaming_error:
+                        # (A) Native vLLM extract_tool_calls_streaming.
+                        try:
+                            msg = tp_instance.extract_tool_calls_streaming(
+                                previous_text=generated_text,
+                                current_text=iteration_text,
+                                delta_text=delta_iteration_text,
+                                previous_token_ids=generated_token_ids,
+                                current_token_ids=new_token_ids,
+                                delta_token_ids=delta_token_ids,
+                                request=tp_request,
+                            )
+                        except Exception as e:
+                            print(f"Streaming tool parser error (falling back to "
+                                  f"buffer for the rest of the stream): {e}",
+                                  file=sys.stderr)
+                            native_streaming_error = True
+                            msg = None
+                        if msg is not None:
+                            tc_protos = []
+                            for tc in (msg.tool_calls or []):
+                                fn = tc.function or None
+                                tc_protos.append(backend_pb2.ToolCallDelta(
+                                    index=tc.index,
+                                    id=tc.id or "",
+                                    name=(fn.name if fn and fn.name else "") or "",
+                                    arguments=(fn.arguments if fn and fn.arguments else "") or "",
+                                ))
+                            cd_kwargs = {}
+                            if msg.content:
+                                cd_kwargs["content"] = msg.content
+                            if msg.reasoning:
+                                cd_kwargs["reasoning_content"] = msg.reasoning
+                            if tc_protos:
+                                cd_kwargs["tool_calls"] = tc_protos
+                            if cd_kwargs:
+                                yield backend_pb2.Reply(
+                                    message=bytes(msg.content or "", encoding='utf-8'),
+                                    chat_deltas=[backend_pb2.ChatDelta(**cd_kwargs)],
+                                )
+                    # (B) buffer fallback — emit nothing during the stream.
+                    # The post-loop extract_tool_calls block builds the final chunk.
+
+                # Keep track of text + token_ids generated
                 generated_text = iteration_text
+                generated_token_ids = list(request_output.outputs[0].token_ids)
         finally:
             await outputs.aclose()

@@ -639,16 +740,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             except Exception as e:
                 print(f"Reasoning parser error: {e}", file=sys.stderr)

-        if self.tool_parser_cls and request.Tools:
+        # When (A) native streaming ran cleanly, per-delta yields above already
+        # delivered everything — do NOT extract again on the full text or we'd
+        # duplicate content/tool_calls into the final chunk.
+        if has_tool_parser and not (native_streaming and not native_streaming_error):
             try:
-                tools = json.loads(request.Tools)
-                # Some concrete parsers only accept the tokenizer; only the
-                # abstract base declares the tools kwarg. Try with tools first,
-                # fall back to tokenizer-only.
-                try:
-                    tp = self.tool_parser_cls(self.tokenizer, tools=tools)
-                except TypeError:
-                    tp = self.tool_parser_cls(self.tokenizer)
+                tp = tp_instance
+                if tp is None:
+                    # Defensive: tp_instance build failed earlier; reconstruct.
+                    tools = json.loads(request.Tools)
+                    try:
+                        tp = self.tool_parser_cls(self.tokenizer, tools=tools)
+                    except TypeError:
+                        tp = self.tool_parser_cls(self.tokenizer)
                 info = tp.extract_tool_calls(content, request=None)
                 if info.tools_called:
                     content = info.content or ""
@@ -661,6 +765,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                         ))
             except Exception as e:
                 print(f"Tool parser error: {e}", file=sys.stderr)
+        elif native_streaming and not native_streaming_error:
+            # Per-delta path already emitted content + tool_calls; the final
+            # chat_delta should carry only metadata (token counts, logprobs).
+            content = ""

         # Extract token counts
         prompt_tokens = 0
@@ -700,7 +808,26 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         )

         if streaming:
-            # Final chunk with structured data
+            # Final chunk with structured data.
+            #
+            # If we used the buffer fallback (has_tool_parser=True AND native
+            # streaming did NOT run cleanly) and the parser found no tool call,
+            # flush the buffered content as ONE content delta — and clear the
+            # final chat_delta's content so the metadata chunk does not repeat
+            # what we just sent. This is the plain-text-with-tool-parser path.
+            buffered_fallback = (
+                has_tool_parser
+                and not (native_streaming and not native_streaming_error)
+            )
+            if buffered_fallback and not tool_calls_proto and content:
+                yield backend_pb2.Reply(
+                    message=bytes(content, encoding='utf-8'),
+                    chat_deltas=[backend_pb2.ChatDelta(content=content)],
+                )
+                chat_delta = backend_pb2.ChatDelta(
+                    reasoning_content=reasoning_content,
+                    tool_calls=tool_calls_proto,
+                )
             yield backend_pb2.Reply(
                 message=b"",
                 prompt_tokens=prompt_tokens,
@@ -278,4 +278,261 @@ class TestBackendServicer(unittest.TestCase):
             print(err)
             self.fail("Embedding service failed")
         finally:
-            self.tearDown()
+            self.tearDown()
+
+
+class TestStreamingToolParser(unittest.TestCase):
+    """
+    Server-less unit tests for the streaming + tool-parser machinery in
+    BackendServicer._predict. These tests instantiate BackendServicer
+    directly and mock the vLLM engine + tool parser, so they do not need
+    a GPU, a model, or a running gRPC server. Kept in a separate class to
+    avoid the parent setUp() which spawns a subprocess.
+
+    Covers #582 (follow-up to #10346):
+    1. Markup-leak prevention with a non-streaming parser (buffer fallback)
+    2. No content duplication on the plain-text path with the buffer fallback
+    3. Native streaming progressive plain-text emission
+    4. Native streaming structured tool_call, no markup leak
+    5. Parser exception → graceful fallback to buffer, still no markup
+    6. No-tool-parser regression: unchanged per-delta content stream
+    """
+
+    @staticmethod
+    def _make_generate(chunks):
+        """Build a fake vLLM engine.generate that yields cumulative chunks."""
+        from types import SimpleNamespace
+        async def gen(*a, **k):
+            for i, t in enumerate(chunks):
+                yield SimpleNamespace(
+                    outputs=[SimpleNamespace(
+                        text=t,
+                        token_ids=list(range(i + 1)),
+                        logprobs=None,
+                    )],
+                    prompt_token_ids=[0],
+                )
+        return lambda *a, **k: gen()
+
+    @staticmethod
+    def _collect(servicer, req):
+        import asyncio
+        async def run():
+            return [r async for r in servicer._predict(req, None, streaming=True)]
+        return asyncio.run(run())
+
+    def _new_servicer(self):
+        import sys, os
+        sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+        from backend import BackendServicer
+        s = BackendServicer()
+        s.reasoning_parser_cls = None
+        s.tool_parser_cls = None
+        s.tokenizer = None
+        return s
+
+    # ── Case 1+2: parser without streaming method → buffer fallback ──
+    def test_buffer_path_no_markup_no_duplication(self):
+        from types import SimpleNamespace
+
+        def parser_cls(called, content_text, calls):
+            class _P:
+                def __init__(self, tokenizer, tools=None):
+                    pass
+                # NOTE: NO extract_tool_calls_streaming → takes the buffer path
+                def extract_tool_calls(self, c, request=None):
+                    return SimpleNamespace(
+                        tools_called=called, content=content_text, tool_calls=calls,
+                    )
+            return _P
+
+        tools_json = '[{"type":"function","function":{"name":"calc","parameters":{}}}]'
+
+        # Tool-call case: no raw markup in any delta.content
+        s = self._new_servicer()
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            '<tool_call>\n{"name": "calc"',
+            '<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
+        ]))
+        call = SimpleNamespace(id="call_1",
+                               function=SimpleNamespace(name="calc", arguments='{"x": 1}'))
+        s.tool_parser_cls = parser_cls(True, "", [call])
+        req = backend_pb2.PredictOptions(Prompt="x", Tools=tools_json)
+        replies = self._collect(s, req)
+        contents = [cd.content for r in replies for cd in r.chat_deltas if cd.content]
+        self.assertFalse(
+            any("<tool_call" in c for c in contents),
+            f"markup leaked: {contents!r}",
+        )
+        names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls]
+        self.assertIn("calc", names, "tool_call missing from final chunk")
+
+        # Plain-text-with-tools case: full content delivered exactly once
+        s2 = self._new_servicer()
+        s2.llm = SimpleNamespace(generate=self._make_generate([
+            "The capital ",
+            "The capital of France is Paris.",
+        ]))
+        s2.tool_parser_cls = parser_cls(False, "", [])
+        req2 = backend_pb2.PredictOptions(Prompt="x", Tools=tools_json)
+        joined = "".join(
+            cd.content for r in self._collect(s2, req2)
+            for cd in r.chat_deltas if cd.content
+        )
+        self.assertEqual(
+            joined.count("The capital of France is Paris."), 1,
+            f"buffered content duplicated: {joined!r}",
+        )
+
+    # ── Case 3: native streaming, progressive plain text ──
+    def test_native_streaming_progressive_plain_text(self):
+        from types import SimpleNamespace
+
+        class _DeltaMsg:
+            def __init__(self, content=None, reasoning=None, tool_calls=None):
+                self.content = content
+                self.reasoning = reasoning
+                self.tool_calls = tool_calls or []
+
+        class StreamingParser:
+            def __init__(self, tokenizer, tools=None):
+                pass
+            def extract_tool_calls(self, c, request=None):
+                # Should NOT be called when native streaming runs successfully.
+                raise AssertionError("extract_tool_calls invoked on native-streaming path")
+            def extract_tool_calls_streaming(
+                self, previous_text, current_text, delta_text,
+                previous_token_ids, current_token_ids, delta_token_ids, request,
+            ):
+                if not delta_text:
+                    return None
+                return _DeltaMsg(content=delta_text)
+
+        s = self._new_servicer()
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            "Paris ",
+            "Paris is ",
+            "Paris is the capital of France.",
+        ]))
+        s.tool_parser_cls = StreamingParser
+        req = backend_pb2.PredictOptions(
+            Prompt="x",
+            Tools='[{"type":"function","function":{"name":"calc","parameters":{}}}]',
+        )
+        replies = self._collect(s, req)
+
+        intermediate_content = [
+            cd.content for r in replies[:-1] for cd in r.chat_deltas if cd.content
+        ]
+        self.assertTrue(
+            len(intermediate_content) > 0,
+            "Plain-text response not streamed progressively (native streaming inactive?)",
+        )
+        assembled = "".join(
+            cd.content for r in replies for cd in r.chat_deltas if cd.content
+        )
+        self.assertEqual(
+            assembled, "Paris is the capital of France.",
+            f"Assembled content wrong: {assembled!r}",
+        )
+
+    # ── Case 4: native streaming, structured tool_call, no markup ──
+    def test_native_streaming_tool_call_no_markup_leak(self):
+        from types import SimpleNamespace
+
+        class _DeltaMsg:
+            def __init__(self, content=None, reasoning=None, tool_calls=None):
+                self.content = content
+                self.reasoning = reasoning
+                self.tool_calls = tool_calls or []
+
+        class _ToolCallStreamer:
+            def __init__(self, tokenizer, tools=None):
+                self._emitted = False
+            def extract_tool_calls(self, c, request=None):
+                raise AssertionError("extract_tool_calls invoked on native-streaming path")
+            def extract_tool_calls_streaming(
+                self, previous_text, current_text, delta_text,
+                previous_token_ids, current_token_ids, delta_token_ids, request,
+            ):
+                if "</tool_call>" in current_text and not self._emitted:
+                    self._emitted = True
+                    fn = SimpleNamespace(name="calc", arguments='{"x": 1}')
+                    tc = SimpleNamespace(id="call_1", type="function", index=0, function=fn)
+                    return _DeltaMsg(tool_calls=[tc])
+                return None
+
+        s = self._new_servicer()
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            '<tool_call>\n',
+            '<tool_call>\n{"name": "calc"',
+            '<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
+        ]))
+        s.tool_parser_cls = _ToolCallStreamer
+        req = backend_pb2.PredictOptions(
+            Prompt="x",
+            Tools='[{"type":"function","function":{"name":"calc","parameters":{}}}]',
+        )
+        replies = self._collect(s, req)
+
+        contents = [cd.content for r in replies for cd in r.chat_deltas if cd.content]
+        self.assertFalse(
+            any("<tool_call" in c or "</tool_call>" in c for c in contents),
+            f"markup leaked as content: {contents!r}",
+        )
+        names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls if tc.name]
+        args = [tc.arguments for r in replies for cd in r.chat_deltas for tc in cd.tool_calls if tc.arguments]
+        self.assertIn("calc", names, f"tool_call name missing; got {names!r}")
+        self.assertIn('{"x": 1}', args, f"tool_call args missing; got {args!r}")
+
+    # ── Case 5: parser exception → fallback to buffer, no leak ──
+    def test_native_streaming_parser_exception_falls_back_to_buffer(self):
+        from types import SimpleNamespace
+        call = SimpleNamespace(id="call_1",
+                               function=SimpleNamespace(name="calc", arguments='{"x": 1}'))
+
+        class _BrokenStreamer:
+            def __init__(self, tokenizer, tools=None):
+                pass
+            def extract_tool_calls(self, c, request=None):
+                return SimpleNamespace(tools_called=True, content="", tool_calls=[call])
+            def extract_tool_calls_streaming(self, *a, **kw):
+                raise RuntimeError("simulated parser bug")
+
+        s = self._new_servicer()
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            '<tool_call>\n{"name": "calc"',
+            '<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
+        ]))
+        s.tool_parser_cls = _BrokenStreamer
+        req = backend_pb2.PredictOptions(
+            Prompt="x",
+            Tools='[{"type":"function","function":{"name":"calc","parameters":{}}}]',
+        )
+        replies = self._collect(s, req)
+
+        contents = [cd.content for r in replies for cd in r.chat_deltas if cd.content]
+        self.assertFalse(
+            any("<tool_call" in c for c in contents),
+            f"markup leaked after parser exception: {contents!r}",
+        )
+        names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls]
+        self.assertIn("calc", names, "tool_call missing from final chunk after fallback")
+
+    # ── Case 6: no tool parser → unchanged per-delta content stream ──
+    def test_no_tool_parser_unchanged_per_delta_stream(self):
+        from types import SimpleNamespace
+        s = self._new_servicer()  # tool_parser_cls already None
+        s.llm = SimpleNamespace(generate=self._make_generate([
+            "Hello ", "Hello world", "Hello world!",
+        ]))
+        req = backend_pb2.PredictOptions(Prompt="x", Tools="")
+        replies = self._collect(s, req)
+
+        intermediate = [
+            cd.content for r in replies[:-1] for cd in r.chat_deltas if cd.content
+        ]
+        self.assertEqual(
+            intermediate, ["Hello ", "world", "!"],
+            f"plain streaming changed; got {intermediate!r}",
+        )
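The test helpers in this diff (`_make_generate`, `_collect`) rely on the engine yielding cumulative text and on the consumer turning that into per-step deltas. The following standalone sketch shows that pattern without the servicer. The names `make_generate`, `deltas`, and `collect` are hypothetical simplifications and the output objects are plain `SimpleNamespace` values rather than real vLLM request outputs.

```python
import asyncio
from types import SimpleNamespace


def make_generate(chunks):
    """Build a fake engine.generate that yields cumulative text chunks."""
    async def gen():
        for i, text in enumerate(chunks):
            # Token ids grow by one per step, mirroring cumulative output.
            yield SimpleNamespace(text=text, token_ids=list(range(i + 1)))
    return lambda *args, **kwargs: gen()


async def deltas(outputs):
    """Convert cumulative outputs into (text_delta, token_id_delta) pairs."""
    seen_text, seen_ids = "", []
    async for out in outputs:
        text_delta = out.text.removeprefix(seen_text)
        id_delta = out.token_ids[len(seen_ids):]
        yield text_delta, id_delta
        seen_text, seen_ids = out.text, out.token_ids


def collect(chunks):
    """Drive the async generator to completion from synchronous test code."""
    async def run():
        return [d async for d in deltas(make_generate(chunks)())]
    return asyncio.run(run())


print(collect(["Hello ", "Hello world", "Hello world!"]))
```

Running the async generator through `asyncio.run` keeps each test synchronous, which is why no event loop fixture or running server is needed.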