mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-22 07:39:02 -04:00
* fix(vllm): don't stream raw tool-call markup as content when a tool parser is active When a tool_parser is configured and the request carries tools, the streaming loop emitted every text delta as delta.content — including the model's raw tool-call markup (e.g. <tool_call>...) — because extract_tool_calls only runs on the full output after the stream. Clients streaming a tool call therefore saw the unparsed tool-call syntax as assistant content. Buffer the text while a tool parser is active for the request; the existing end-of-stream chat_delta already carries the parsed tool_calls (or the cleaned content), which the Go side converts to SSE deltas. Non-tool-parser streaming is unchanged. Add a server-less regression test covering both the tool-call case (no raw markup leaked as content) and the plain-text case (content delivered exactly once — guards against double-emitting the buffered content). Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com> * test(vllm): add expectedFailure test for progressive streaming with tool parser (Case 3, #582) Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com> * test(vllm): add Cases 4+5 — marker split across chunks + false-positive prefix (TDD, Option B state machine, #582) Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com> * feat(vllm): progressive streaming via parser.extract_tool_calls_streaming When a tool parser is active for a tool-enabled streaming request, #10346 buffers the entire generation and surfaces it on the final chunk to prevent raw tool-call markup from leaking as delta.content. This is correct but turns the request into effectively non-streaming for plain-text responses — the client sees nothing until the model stops. Every concrete tool parser shipped with vLLM 0.23+ already implements extract_tool_calls_streaming (Granite4, Qwen3Coder, DeepSeekV31, Jamba, Ernie45, Hermes2Pro, llama3_json, mistral, …). 
Use it: instantiate the parser before the streaming loop and call its streaming method per delta, emitting DeltaMessage(content=…) or DeltaMessage(tool_calls=[…]) when the parser is ready. Falls back to the existing #10346 buffer path when: - the parser does not have extract_tool_calls_streaming, OR - extract_tool_calls_streaming raises mid-stream (logged, the rest of the request finishes via post-loop extract_tool_calls). Tests (TestStreamingToolParser): 1. Buffer path: no markup leaked, no content duplication 2. Native streaming: plain-text response streams progressively 3. Native streaming: tool_call structured, no markup leaked 4. Native streaming exception → graceful fallback, no markup, no crash 5. No tool parser → unchanged per-delta content stream E2E verified against qwen3_coder on vLLM 0.23.0 (NVIDIA GB10 / arm64 / CUDA 13). Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com> * docs(vllm): add server-side TTFT benchmark for the streaming tool-parser path Self-contained stdlib-only script that measures time-to-first-token (TTFT) for the vLLM backend's two streaming scenarios: - tool_call: request mentions a tool; model is expected to call it - plain_text: request offers a tool but explicitly asks for prose Use this to compare: - the buffer-all path (#10346) → plain_text TTFT ≈ total response time - the native-streaming path (this PR) → plain_text TTFT ≈ true first-token time python examples/vllm-bench/ttft_streaming_tool_parser.py \\ --url http://localhost:8080 --model my-coder --runs 3 Lives under examples/ so it does not interfere with the test suite. Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com> * examples/vllm-bench: add long-text scenario (8 paragraphs, 1500 tokens) The long-text scenario shows the buffering vs streaming difference most dramatically: with the buffer-all path, the client receives nothing for 20+ seconds and then the entire 1500-token response at once. 
With native streaming, the first token arrives in tens of milliseconds and the response flows progressively. Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com> --------- Signed-off-by: pos-ei-don <1822533+pos-ei-don@users.noreply.github.com> Co-authored-by: Philipp Wacker <philipp.wacker@ibf-solutions.com>
539 lines
22 KiB
Python
539 lines
22 KiB
Python
import unittest
|
|
import subprocess
|
|
import time
|
|
import backend_pb2
|
|
import backend_pb2_grpc
|
|
|
|
import grpc
|
|
|
|
import unittest
|
|
import subprocess
|
|
import time
|
|
import grpc
|
|
import backend_pb2_grpc
|
|
import backend_pb2
|
|
|
|
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    setUp() launches backend.py as a subprocess on localhost:50051 and
    tearDown() stops it again. unittest runs both hooks around every test
    method, so the tests rely on that instead of invoking them by hand: a
    manual call would spawn a second server on the same port and overwrite
    (and therefore never terminate) the handle of the first one.
    """

    def setUp(self):
        """
        Spawn the backend server and wait a fixed grace period before the test connects.
        """
        self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        """
        Terminate the backend server and reap the child process.
        """
        self.service.terminate()
        self.service.wait()

    @staticmethod
    def _make_servicer():
        """
        Import BackendServicer from this file's directory and return a new instance.

        The directory is placed at the front of sys.path (only when it is not
        already there) so that ``backend`` resolves to the sibling module.
        """
        import os
        import sys
        here = os.path.dirname(os.path.abspath(__file__))
        if not sys.path or sys.path[0] != here:
            sys.path.insert(0, here)
        from backend import BackendServicer
        return BackendServicer()

    def test_server_startup(self):
        """
        This method tests if the server answers the Health RPC after startup
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")

    def test_text(self):
        """
        This method tests if text is generated successfully for a prompt
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)
                req = backend_pb2.PredictOptions(Prompt="The capital of France is")
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("text service failed")

    def test_sampling_params(self):
        """
        This method tests if all sampling parameters are correctly processed
        NOTE: this does NOT test for correctness, just that we received a compatible response
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)

                req = backend_pb2.PredictOptions(
                    Prompt="The capital of France is",
                    TopP=0.8,
                    Tokens=50,
                    Temperature=0.7,
                    TopK=40,
                    PresencePenalty=0.1,
                    FrequencyPenalty=0.2,
                    RepetitionPenalty=1.1,
                    MinP=0.05,
                    Seed=42,
                    StopPrompts=["\n"],
                    StopTokenIds=[50256],
                    BadWords=["badword"],
                    IncludeStopStrInOutput=True,
                    IgnoreEOS=True,
                    MinTokens=5,
                    Logprobs=5,
                    PromptLogprobs=5,
                    SkipSpecialTokens=True,
                    SpacesBetweenSpecialTokens=True,
                    TruncatePromptTokens=10,
                    GuidedDecoding=True,
                    N=2,
                )
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
                self.assertIsNotNone(resp.logprobs)
        except Exception as err:
            print(err)
            self.fail("sampling params service failed")

    def test_messages_to_dicts(self):
        """
        Tests _messages_to_dicts conversion of proto Messages to dicts.
        """
        servicer = self._make_servicer()
        msgs = [
            backend_pb2.Message(role="user", content="hello"),
            backend_pb2.Message(
                role="assistant",
                content="",
                tool_calls='[{"id":"call_1","type":"function","function":{"name":"foo","arguments":"{}"}}]',
                reasoning_content="thinking...",
            ),
            backend_pb2.Message(role="tool", content="result", name="foo", tool_call_id="call_1"),
        ]
        result = servicer._messages_to_dicts(msgs)
        self.assertEqual(len(result), 3)
        self.assertEqual(result[0], {"role": "user", "content": "hello"})
        self.assertEqual(result[1]["reasoning_content"], "thinking...")
        self.assertIsInstance(result[1]["tool_calls"], list)
        self.assertEqual(result[1]["tool_calls"][0]["id"], "call_1")
        self.assertEqual(result[2]["tool_call_id"], "call_1")
        self.assertEqual(result[2]["name"], "foo")

    def test_parse_options(self):
        """
        Tests _parse_options correctly parses key:value strings.
        """
        servicer = self._make_servicer()
        opts = servicer._parse_options([
            "tool_parser:hermes",
            "reasoning_parser:deepseek_r1",
            "invalid_no_colon",
            "key_with_colons:a:b:c",
        ])
        self.assertEqual(opts["tool_parser"], "hermes")
        self.assertEqual(opts["reasoning_parser"], "deepseek_r1")
        self.assertEqual(opts["key_with_colons"], "a:b:c")
        self.assertNotIn("invalid_no_colon", opts)

    def test_apply_engine_args_known_keys(self):
        """
        Tests _apply_engine_args overlays user-supplied JSON onto AsyncEngineArgs.
        """
        import json as _json
        servicer = self._make_servicer()
        from vllm.engine.arg_utils import AsyncEngineArgs

        base = AsyncEngineArgs(model="facebook/opt-125m")
        extras = _json.dumps({
            "trust_remote_code": True,
            "max_num_seqs": 32,
        })
        out = servicer._apply_engine_args(base, extras)
        self.assertTrue(out.trust_remote_code)
        self.assertEqual(out.max_num_seqs, 32)
        # untouched fields preserved
        self.assertEqual(out.model, "facebook/opt-125m")

    def test_apply_engine_args_unknown_key_raises(self):
        """
        Tests _apply_engine_args rejects unknown keys with a helpful suggestion.
        """
        import json as _json
        servicer = self._make_servicer()
        from vllm.engine.arg_utils import AsyncEngineArgs

        base = AsyncEngineArgs(model="facebook/opt-125m")
        with self.assertRaises(ValueError) as ctx:
            servicer._apply_engine_args(base, _json.dumps({"trustremotecode": True}))
        self.assertIn("trustremotecode", str(ctx.exception))
        # close-match hint for the typo
        self.assertIn("trust_remote_code", str(ctx.exception))

    def test_apply_engine_args_empty_passthrough(self):
        """
        Tests that empty engine_args returns the base unchanged.
        """
        servicer = self._make_servicer()
        from vllm.engine.arg_utils import AsyncEngineArgs

        base = AsyncEngineArgs(model="facebook/opt-125m")
        self.assertIs(servicer._apply_engine_args(base, ""), base)
        self.assertIs(servicer._apply_engine_args(base, None), base)

    def test_tokenize_string(self):
        """
        Tests the TokenizeString RPC returns valid tokens.
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)
                resp = stub.TokenizeString(backend_pb2.PredictOptions(Prompt="Hello world"))
                self.assertGreater(resp.length, 0)
                self.assertEqual(len(resp.tokens), resp.length)
        except Exception as err:
            print(err)
            self.fail("TokenizeString service failed")

    def test_free(self):
        """
        Tests the Free RPC doesn't crash.
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)
                free_resp = stub.Free(backend_pb2.HealthMessage())
                self.assertTrue(free_resp.success)
        except Exception as err:
            print(err)
            self.fail("Free service failed")

    def test_embedding(self):
        """
        This method tests if the embeddings are generated successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct"))
                self.assertTrue(response.success)
                embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
                embedding_response = stub.Embedding(embedding_request)
                self.assertIsNotNone(embedding_response.embeddings)
                # assert that it is a sequence of floats; a protobuf repeated
                # field is a sequence-like container rather than a ``list``
                # instance, so the element types are checked instead
                embeddings = list(embedding_response.embeddings)
                self.assertTrue(all(isinstance(value, float) for value in embeddings))
                # assert that the list is not empty
                self.assertTrue(len(embeddings) > 0)
        except Exception as err:
            print(err)
            self.fail("Embedding service failed")
|
|
|
|
|
|
class TestStreamingToolParser(unittest.TestCase):
    """
    Server-less unit tests for the streaming + tool-parser machinery in
    BackendServicer._predict. These tests instantiate BackendServicer
    directly and mock the vLLM engine + tool parser, so they do not need
    a GPU, a model, or a running gRPC server. Kept in a separate class to
    avoid the parent setUp() which spawns a subprocess.

    Covers #582 (follow-up to #10346):
    1. Markup-leak prevention with a non-streaming parser (buffer fallback)
    2. No content duplication on the plain-text path with the buffer fallback
    3. Native streaming progressive plain-text emission
    4. Native streaming structured tool_call, no markup leak
    5. Parser exception → graceful fallback to buffer, still no markup
    6. No-tool-parser regression: unchanged per-delta content stream
    """

    # Tool list sent with every tool-enabled request in this class.
    _TOOLS_JSON = '[{"type":"function","function":{"name":"calc","parameters":{}}}]'

    class _DeltaMsg:
        """Stub for the object the fake parsers return from extract_tool_calls_streaming."""

        def __init__(self, content=None, reasoning=None, tool_calls=None):
            self.content = content
            self.reasoning = reasoning
            self.tool_calls = tool_calls or []

    @staticmethod
    def _make_generate(chunks):
        """Build a fake vLLM engine.generate that yields cumulative chunks."""
        from types import SimpleNamespace

        async def gen(*a, **k):
            for i, t in enumerate(chunks):
                yield SimpleNamespace(
                    outputs=[SimpleNamespace(
                        text=t,
                        # one more token id per chunk, mirroring the growing text
                        token_ids=list(range(i + 1)),
                        logprobs=None,
                    )],
                    prompt_token_ids=[0],
                )
        # The returned callable ignores whatever arguments it is called with.
        return lambda *a, **k: gen()

    @staticmethod
    def _collect(servicer, req):
        """Run servicer._predict in streaming mode and return all replies as a list."""
        import asyncio

        async def run():
            return [r async for r in servicer._predict(req, None, streaming=True)]
        return asyncio.run(run())

    @staticmethod
    def _contents(replies):
        """Return every non-empty chat_delta content found in ``replies``, in order."""
        return [cd.content for r in replies for cd in r.chat_deltas if cd.content]

    @staticmethod
    def _calc_call(**extra):
        """Build the stub ``calc`` tool call; ``extra`` adds further attributes to it."""
        from types import SimpleNamespace
        fn = SimpleNamespace(name="calc", arguments='{"x": 1}')
        return SimpleNamespace(id="call_1", function=fn, **extra)

    @classmethod
    def _tool_request(cls):
        """Build a PredictOptions request that carries the shared tool list."""
        return backend_pb2.PredictOptions(Prompt="x", Tools=cls._TOOLS_JSON)

    def _new_servicer(self):
        """Return a BackendServicer with no parsers and no tokenizer configured."""
        import os
        import sys
        here = os.path.dirname(os.path.abspath(__file__))
        # Keep this directory first on sys.path without stacking duplicates.
        if not sys.path or sys.path[0] != here:
            sys.path.insert(0, here)
        from backend import BackendServicer
        s = BackendServicer()
        s.reasoning_parser_cls = None
        s.tool_parser_cls = None
        s.tokenizer = None
        return s

    # ── Case 1+2: parser without streaming method → buffer fallback ──
    def test_buffer_path_no_markup_no_duplication(self):
        """Buffer fallback: no raw markup as content, plain text delivered exactly once."""
        from types import SimpleNamespace

        def parser_cls(called, content_text, calls):
            class _P:
                def __init__(self, tokenizer, tools=None):
                    pass
                # NOTE: NO extract_tool_calls_streaming → takes the buffer path
                def extract_tool_calls(self, c, request=None):
                    return SimpleNamespace(
                        tools_called=called, content=content_text, tool_calls=calls,
                    )
            return _P

        # Tool-call case: no raw markup in any delta.content
        s = self._new_servicer()
        s.llm = SimpleNamespace(generate=self._make_generate([
            '<tool_call>\n{"name": "calc"',
            '<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
        ]))
        s.tool_parser_cls = parser_cls(True, "", [self._calc_call()])
        replies = self._collect(s, self._tool_request())
        contents = self._contents(replies)
        self.assertFalse(
            any("<tool_call" in c for c in contents),
            f"markup leaked: {contents!r}",
        )
        names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls]
        self.assertIn("calc", names, "tool_call missing from final chunk")

        # Plain-text-with-tools case: full content delivered exactly once
        s2 = self._new_servicer()
        s2.llm = SimpleNamespace(generate=self._make_generate([
            "The capital ",
            "The capital of France is Paris.",
        ]))
        s2.tool_parser_cls = parser_cls(False, "", [])
        joined = "".join(self._contents(self._collect(s2, self._tool_request())))
        self.assertEqual(
            joined.count("The capital of France is Paris."), 1,
            f"buffered content duplicated: {joined!r}",
        )

    # ── Case 3: native streaming, progressive plain text ──
    def test_native_streaming_progressive_plain_text(self):
        """Native streaming: plain text arrives before the final reply and assembles correctly."""
        from types import SimpleNamespace

        delta_cls = self._DeltaMsg

        class StreamingParser:
            def __init__(self, tokenizer, tools=None):
                pass

            def extract_tool_calls(self, c, request=None):
                # Should NOT be called when native streaming runs successfully.
                raise AssertionError("extract_tool_calls invoked on native-streaming path")

            def extract_tool_calls_streaming(
                self, previous_text, current_text, delta_text,
                previous_token_ids, current_token_ids, delta_token_ids, request,
            ):
                if not delta_text:
                    return None
                return delta_cls(content=delta_text)

        s = self._new_servicer()
        s.llm = SimpleNamespace(generate=self._make_generate([
            "Paris ",
            "Paris is ",
            "Paris is the capital of France.",
        ]))
        s.tool_parser_cls = StreamingParser
        replies = self._collect(s, self._tool_request())

        intermediate_content = self._contents(replies[:-1])
        self.assertTrue(
            len(intermediate_content) > 0,
            "Plain-text response not streamed progressively (native streaming inactive?)",
        )
        assembled = "".join(self._contents(replies))
        self.assertEqual(
            assembled, "Paris is the capital of France.",
            f"Assembled content wrong: {assembled!r}",
        )

    # ── Case 4: native streaming, structured tool_call, no markup ──
    def test_native_streaming_tool_call_no_markup_leak(self):
        """Native streaming: the tool call is structured and no markup appears as content."""
        from types import SimpleNamespace

        delta_cls = self._DeltaMsg
        make_call = self._calc_call

        class _ToolCallStreamer:
            def __init__(self, tokenizer, tools=None):
                self._emitted = False

            def extract_tool_calls(self, c, request=None):
                raise AssertionError("extract_tool_calls invoked on native-streaming path")

            def extract_tool_calls_streaming(
                self, previous_text, current_text, delta_text,
                previous_token_ids, current_token_ids, delta_token_ids, request,
            ):
                # Emit the structured call once, when the closing tag has arrived.
                if "</tool_call>" in current_text and not self._emitted:
                    self._emitted = True
                    tc = make_call(type="function", index=0)
                    return delta_cls(tool_calls=[tc])
                return None

        s = self._new_servicer()
        s.llm = SimpleNamespace(generate=self._make_generate([
            '<tool_call>\n',
            '<tool_call>\n{"name": "calc"',
            '<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
        ]))
        s.tool_parser_cls = _ToolCallStreamer
        replies = self._collect(s, self._tool_request())

        contents = self._contents(replies)
        self.assertFalse(
            any("<tool_call" in c or "</tool_call>" in c for c in contents),
            f"markup leaked as content: {contents!r}",
        )
        names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls if tc.name]
        args = [tc.arguments for r in replies for cd in r.chat_deltas for tc in cd.tool_calls if tc.arguments]
        self.assertIn("calc", names, f"tool_call name missing; got {names!r}")
        self.assertIn('{"x": 1}', args, f"tool_call args missing; got {args!r}")

    # ── Case 5: parser exception → fallback to buffer, no leak ──
    def test_native_streaming_parser_exception_falls_back_to_buffer(self):
        """A streaming-parser exception must not leak markup nor drop the tool call."""
        from types import SimpleNamespace
        call = self._calc_call()

        class _BrokenStreamer:
            def __init__(self, tokenizer, tools=None):
                pass

            def extract_tool_calls(self, c, request=None):
                return SimpleNamespace(tools_called=True, content="", tool_calls=[call])

            def extract_tool_calls_streaming(self, *a, **kw):
                raise RuntimeError("simulated parser bug")

        s = self._new_servicer()
        s.llm = SimpleNamespace(generate=self._make_generate([
            '<tool_call>\n{"name": "calc"',
            '<tool_call>\n{"name": "calc", "arguments": {"x": 1}}\n</tool_call>',
        ]))
        s.tool_parser_cls = _BrokenStreamer
        replies = self._collect(s, self._tool_request())

        contents = self._contents(replies)
        self.assertFalse(
            any("<tool_call" in c for c in contents),
            f"markup leaked after parser exception: {contents!r}",
        )
        names = [tc.name for r in replies for cd in r.chat_deltas for tc in cd.tool_calls]
        self.assertIn("calc", names, "tool_call missing from final chunk after fallback")

    # ── Case 6: no tool parser → unchanged per-delta content stream ──
    def test_no_tool_parser_unchanged_per_delta_stream(self):
        """Without a tool parser each text delta is streamed as content unchanged."""
        from types import SimpleNamespace
        s = self._new_servicer()  # tool_parser_cls already None
        s.llm = SimpleNamespace(generate=self._make_generate([
            "Hello ", "Hello world", "Hello world!",
        ]))
        req = backend_pb2.PredictOptions(Prompt="x", Tools="")
        replies = self._collect(s, req)

        intermediate = self._contents(replies[:-1])
        self.assertEqual(
            intermediate, ["Hello ", "world", "!"],
            f"plain streaming changed; got {intermediate!r}",
        )
|