Compare commits


7 Commits

Author SHA1 Message Date
ciaranbor
12ace705fc Test event dropping 2026-02-19 18:55:52 +00:00
ciaranbor
20ccf097bb Route DownloadCoordinator events through worker's event channel 2026-02-19 18:55:52 +00:00
ciaranbor
94848bd5bd Use n-strike ping tolerance 2026-02-19 18:55:52 +00:00
rltakashige
cf648a53b8 Add thinking in thinking blocks, and fix DeepSeek interleaved tool calls (#1548)
## Motivation

OpenCode shows raw <think> tags instead of proper thinking blocks because we aren't
following the API specs correctly.

Claude was also getting terrible prefix cache hit rates because it sends
per-request headers.

## Changes

Handle thinking tokens properly by routing them into the appropriate thinking
blocks or fields for each of the API endpoints.
Also support DeepSeekV3.2 tool calling properly as a minor feature.
Strip Claude headers at the API level.

## Test Plan

### Manual Testing
Tested OpenCode manually.
Needs testing with Claude.

### Automated Testing
All CI and tests passing - added a new e2e test for DeepSeekV32 tool
parsing.
2026-02-19 18:44:49 +00:00
Alex Cheema
94b2ce6922 feat: Mac Studio en2 RDMA port warning v2 (#1551)
Rebuilt from scratch (replaces PR #1543). Detects when a Mac Studio uses
RDMA over en2 (the TB5 port next to the Ethernet port), which does not support
RDMA. Shows a dismissible warning banner with a hover tooltip listing affected
devices, an SVG rear-panel illustration, and fix instructions. 205 lines in
+page.svelte.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-19 18:39:17 +00:00
rltakashige
423ed0f07f Strip Claude headers to improve prefix cache hit rates (#1552)
## Motivation
Our prefix cache hit rate is really bad at the moment (0.2%). This PR raises it
to 98.5% on average.

## Changes

Strips the volatile Claude headers at the API level. Also adds an example
showing how to run Claude with Exo.

## Why It Works
Claude sends billing and session headers that change with every message, so the
cached prefix diverges after only a few tokens instead of matching the whole
conversation.
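
A minimal sketch of the same idea (the real helper is `_strip_volatile_headers`
in the Claude adapter diff further down; the constant name here is
illustrative):

```python
import re

# Matches per-request lines like "x-anthropic-billing-header: cc_version=...;"
# (same pattern as the _VOLATILE_HEADER_RE added in the diff below).
VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)


def strip_volatile_headers(text: str) -> str:
    """Drop volatile Anthropic telemetry lines so the cached prefix stays stable."""
    return VOLATILE_HEADER_RE.sub("", text)


system = "x-anthropic-billing-header: cc_version=1.2.3; cch=abc123;\nYou are a helpful assistant."
assert strip_volatile_headers(system) == "You are a helpful assistant."
```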

## Test Plan

### Manual Testing
Works in manual testing.
2026-02-19 18:29:34 +00:00
Evan Quiney
ed001f2409 remove prefillprogress event (#1550)
This should never have been a separate event, but I didn't quite communicate
that well when it was merged. Convert PrefillProgress to a chunk like the rest
of the runner responses.

Tested with Llama-3.3-70B; prefill progress events still show up in the
dashboard as usual.
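
For illustration only, a hedged sketch (hypothetical simplified types, modeled
on the `PrefillProgressChunk` handling visible in the API diffs below) of what
"a chunk like the rest of the runner responses" means for consumers:

```python
from dataclasses import dataclass


@dataclass
class TokenChunk:
    # Hypothetical stand-in for a normal runner token chunk.
    text: str


@dataclass
class PrefillProgressChunk:
    # Hypothetical shape; the real type is handled via isinstance checks in the adapters.
    processed_tokens: int
    total_tokens: int


def render(chunk: TokenChunk | PrefillProgressChunk) -> str:
    """One consumer loop: progress and tokens arrive on the same chunk stream."""
    if isinstance(chunk, PrefillProgressChunk):
        return f"prefill {chunk.processed_tokens}/{chunk.total_tokens}"
    return chunk.text


assert render(PrefillProgressChunk(10, 100)) == "prefill 10/100"
assert render(TokenChunk("hi")) == "hi"
```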
2026-02-19 18:23:28 +00:00
32 changed files with 2535 additions and 285 deletions

View File

@@ -1652,11 +1652,12 @@ class AppStore {
if (!reader) throw new Error("No response body");
let fullContent = prefixText;
let streamedThinking = "";
const collectedTokens: TokenData[] = [...tokensToKeep];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
delta?: { content?: string; reasoning_content?: string };
logprobs?: {
content?: Array<{
token: string;
@@ -1677,6 +1678,7 @@ class AppStore {
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
const thinkingDelta = choice?.delta?.reasoning_content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
@@ -1695,7 +1697,11 @@ class AppStore {
}
}
if (delta) {
if (thinkingDelta) {
streamedThinking += thinkingDelta;
}
if (delta || thinkingDelta) {
if (firstTokenTime === null) {
firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime;
@@ -1709,9 +1715,14 @@ class AppStore {
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
}
fullContent += delta;
const { displayContent, thinkingContent } =
if (delta) {
fullContent += delta;
}
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(fullContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
if (this.activeConversationId === targetConversationId) {
this.currentResponse = displayContent;
@@ -1723,7 +1734,7 @@ class AppStore {
messageId,
(m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.thinking = combinedThinking || undefined;
m.tokens = [...collectedTokens];
},
);
@@ -1735,11 +1746,14 @@ class AppStore {
// Final update
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(fullContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.thinking = finalThinking || undefined;
m.tokens = [...collectedTokens];
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
if (this.tps !== null) m.tps = this.tps;
@@ -1847,11 +1861,12 @@ class AppStore {
}
let streamedContent = "";
let streamedThinking = "";
const collectedTokens: TokenData[] = [];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
delta?: { content?: string; reasoning_content?: string };
logprobs?: {
content?: Array<{
token: string;
@@ -1872,6 +1887,7 @@ class AppStore {
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
const thinkingDelta = choice?.delta?.reasoning_content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
@@ -1890,10 +1906,19 @@ class AppStore {
}
}
if (delta) {
streamedContent += delta;
const { displayContent, thinkingContent } =
if (thinkingDelta) {
streamedThinking += thinkingDelta;
}
if (delta || thinkingDelta) {
if (delta) {
streamedContent += delta;
}
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
// Only update currentResponse if target conversation is active
if (this.activeConversationId === targetConversationId) {
@@ -1906,7 +1931,7 @@ class AppStore {
assistantMessage.id,
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.thinking = combinedThinking || undefined;
msg.tokens = [...collectedTokens];
},
);
@@ -1918,14 +1943,17 @@ class AppStore {
// Final cleanup of the message (if conversation still exists)
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.thinking = finalThinking || undefined;
msg.tokens = [...collectedTokens];
},
);
@@ -2317,10 +2345,11 @@ class AppStore {
}
let streamedContent = "";
let streamedThinking = "";
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
delta?: { content?: string; reasoning_content?: string };
logprobs?: {
content?: Array<{
token: string;
@@ -2348,6 +2377,7 @@ class AppStore {
const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content;
const thinkingContent = choice?.delta?.reasoning_content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
@@ -2366,7 +2396,11 @@ class AppStore {
}
}
if (tokenContent) {
if (thinkingContent) {
streamedThinking += thinkingContent;
}
if (tokenContent || thinkingContent) {
// Track first token for TTFT
if (firstTokenTime === null) {
firstTokenTime = performance.now();
@@ -2383,11 +2417,16 @@ class AppStore {
this.tps = (tokenCount / elapsed) * 1000;
}
streamedContent += tokenContent;
if (tokenContent) {
streamedContent += tokenContent;
}
// Strip thinking tags for display and extract thinking content
const { displayContent, thinkingContent } =
// Use stripThinkingTags as fallback for any <think> tags still in content
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
// Only update currentResponse if target conversation is active
if (this.activeConversationId === targetConversationId) {
@@ -2400,7 +2439,7 @@ class AppStore {
assistantMessage.id,
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.thinking = combinedThinking || undefined;
msg.tokens = [...collectedTokens];
},
);
@@ -2436,14 +2475,17 @@ class AppStore {
// Final cleanup of the message (if conversation still exists)
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.thinking = finalThinking || undefined;
msg.tokens = [...collectedTokens];
// Store performance metrics on the message
if (this.ttftMs !== null) {

View File

@@ -114,6 +114,74 @@
});
let tb5InfoDismissed = $state(false);
// Detect Mac Studio nodes using RDMA on en2 (the port next to ethernet — RDMA doesn't work there)
const macStudioEn2RdmaWarning = $derived.by(() => {
const edges = data?.edges;
const ids = tbIdentifiers;
const rdmaCtl = rdmaCtlData;
if (!edges || !ids || !rdmaCtl) return null;
const affectedConnections: Array<{
nodeId: string;
nodeName: string;
peerNodeId: string;
peerNodeName: string;
rdmaIface: string;
}> = [];
const isMacStudio = (node: (typeof data.nodes)[string] | undefined) =>
node?.system_info?.model_id === "Mac Studio";
for (const edge of edges) {
if (!edge.sourceRdmaIface && !edge.sinkRdmaIface) continue;
const sourceNode = data?.nodes?.[edge.source];
if (
isMacStudio(sourceNode) &&
edge.sourceRdmaIface === "rdma_en2" &&
rdmaCtl[edge.source]?.enabled
) {
affectedConnections.push({
nodeId: edge.source,
nodeName:
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
peerNodeId: edge.target,
peerNodeName:
data?.nodes?.[edge.target]?.friendly_name ||
edge.target.slice(0, 8) + "...",
rdmaIface: "en2",
});
}
const sinkNode = data?.nodes?.[edge.target];
if (
isMacStudio(sinkNode) &&
edge.sinkRdmaIface === "rdma_en2" &&
rdmaCtl[edge.target]?.enabled
) {
affectedConnections.push({
nodeId: edge.target,
nodeName: sinkNode?.friendly_name || edge.target.slice(0, 8) + "...",
peerNodeId: edge.source,
peerNodeName:
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
rdmaIface: "en2",
});
}
}
// Deduplicate by nodeId
const seen = new Set<string>();
const unique = affectedConnections.filter((c) => {
if (seen.has(c.nodeId)) return false;
seen.add(c.nodeId);
return true;
});
return unique.length > 0 ? unique : null;
});
let macStudioEn2Dismissed = $state(false);
// Helper to get friendly node name from node ID
function getNodeName(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -1758,7 +1826,7 @@
</script>
{#snippet clusterWarnings()}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
<div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
@@ -1923,12 +1991,260 @@
</button>
</div>
{/if}
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
<div class="group relative" role="alert">
<div
class="flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help"
>
<svg
class="w-5 h-5 text-red-400 flex-shrink-0"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d={warningIconPath}
/>
</svg>
<span class="text-sm font-mono text-red-200">
RDMA INCOMPATIBLE PORT
</span>
<button
type="button"
onclick={() => (macStudioEn2Dismissed = true)}
class="ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer"
title="Dismiss"
>
<svg
class="w-4 h-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
</div>
<!-- Expanded tooltip on hover -->
<div
class="absolute top-full left-0 mt-2 w-96 p-4 rounded border border-red-500/30 bg-[#1a1a1a]/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
>
<p class="text-xs text-white/80 mb-3">
The Thunderbolt 5 port next to the Ethernet port on Mac Studio
does
<span class="text-red-400 font-semibold">not support RDMA</span>.
Move the cable to one of the other three TB5 ports.
</p>
<div class="text-xs text-white/60 mb-3">
<span class="text-red-300">Affected:</span>
{#each macStudioEn2RdmaWarning as conn}
<div class="ml-2 mt-0.5">
<span class="text-white/80">{conn.nodeName}</span>
<span class="text-white/30">&rarr;</span>
<span class="text-white/60">{conn.peerNodeName}</span>
<span class="text-white/30 ml-1">(en2)</span>
</div>
{/each}
</div>
<!-- Mac Studio back panel illustration -->
<div class="bg-black/40 rounded p-3 mb-3">
<p
class="text-[10px] font-mono text-white/30 uppercase tracking-wider mb-2"
>
Mac Studio Rear Panel
</p>
<svg
viewBox="0 0 320 72"
class="w-full"
xmlns="http://www.w3.org/2000/svg"
>
<rect
x="1"
y="1"
width="318"
height="70"
rx="6"
ry="6"
fill="none"
stroke="rgba(255,255,255,0.12)"
stroke-width="1"
/>
<!-- TB5 port 1 -->
<rect
x="24"
y="22"
width="28"
height="14"
rx="4"
fill="none"
stroke="rgba(255,255,255,0.3)"
stroke-width="1"
/>
<text
x="38"
y="52"
text-anchor="middle"
fill="rgba(255,255,255,0.25)"
style="font-size:7px;font-family:ui-monospace,monospace;"
>TB5</text
>
<!-- TB5 port 2 -->
<rect
x="62"
y="22"
width="28"
height="14"
rx="4"
fill="none"
stroke="rgba(255,255,255,0.3)"
stroke-width="1"
/>
<text
x="76"
y="52"
text-anchor="middle"
fill="rgba(255,255,255,0.25)"
style="font-size:7px;font-family:ui-monospace,monospace;"
>TB5</text
>
<!-- TB5 port 3 -->
<rect
x="100"
y="22"
width="28"
height="14"
rx="4"
fill="none"
stroke="rgba(255,255,255,0.3)"
stroke-width="1"
/>
<text
x="114"
y="52"
text-anchor="middle"
fill="rgba(255,255,255,0.25)"
style="font-size:7px;font-family:ui-monospace,monospace;"
>TB5</text
>
<!-- TB5 port 4: INCOMPATIBLE (en2) — equally spaced with ports 1-3 -->
<rect
x="138"
y="22"
width="28"
height="14"
rx="4"
fill="rgba(239,68,68,0.1)"
stroke="rgba(239,68,68,0.7)"
stroke-width="1.5"
/>
<line
x1="142"
y1="25"
x2="162"
y2="33"
stroke="rgba(239,68,68,0.8)"
stroke-width="1.5"
stroke-linecap="round"
/>
<line
x1="162"
y1="25"
x2="142"
y2="33"
stroke="rgba(239,68,68,0.8)"
stroke-width="1.5"
stroke-linecap="round"
/>
<text
x="152"
y="52"
text-anchor="middle"
fill="rgba(239,68,68,0.6)"
style="font-size:7px;font-family:ui-monospace,monospace;font-weight:600;"
>en2</text
>
<!-- Ethernet port -->
<rect
x="196"
y="19"
width="24"
height="20"
rx="2"
fill="none"
stroke="rgba(255,255,255,0.2)"
stroke-width="1"
/>
<rect
x="200"
y="23"
width="16"
height="12"
rx="1"
fill="none"
stroke="rgba(255,255,255,0.12)"
stroke-width="0.75"
/>
<text
x="208"
y="52"
text-anchor="middle"
fill="rgba(255,255,255,0.25)"
style="font-size:7px;font-family:ui-monospace,monospace;"
>ETH</text
>
<!-- Green checkmarks on working ports -->
<circle
cx="38"
cy="62"
r="3"
fill="none"
stroke="rgba(74,222,128,0.5)"
stroke-width="0.75"
/>
<circle
cx="76"
cy="62"
r="3"
fill="none"
stroke="rgba(74,222,128,0.5)"
stroke-width="0.75"
/>
<circle
cx="114"
cy="62"
r="3"
fill="none"
stroke="rgba(74,222,128,0.5)"
stroke-width="0.75"
/>
</svg>
</div>
<p class="text-xs text-white/50">
<span class="text-green-400">Fix:</span> Move the Thunderbolt cable
to any of the three leftmost ports (all support RDMA).
</p>
</div>
</div>
{/if}
</div>
{/if}
{/snippet}
{#snippet clusterWarningsCompact()}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
<div class="absolute top-2 left-2 flex flex-col gap-1">
{#if tbBridgeCycles.length > 0}
<div
@@ -1996,6 +2312,27 @@
>
</div>
{/if}
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
<div
class="flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm"
title="Mac Studio RDMA incompatible port (en2) — move cable to another TB5 port"
>
<svg
class="w-3.5 h-3.5 text-red-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d={warningIconPath}
/>
</svg>
<span class="text-[10px] font-mono text-red-200">BAD RDMA PORT</span>
</div>
{/if}
</div>
{/if}
{/snippet}

View File

@@ -23,6 +23,8 @@ use util::wakerdeque::WakerDeque;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
const MAX_PING_FAILURES: u32 = 3;
mod managed {
use libp2p::swarm::NetworkBehaviour;
use libp2p::{identity, mdns, ping};
@@ -31,8 +33,8 @@ mod managed {
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
const PING_TIMEOUT: Duration = Duration::from_secs(10);
const PING_INTERVAL: Duration = Duration::from_secs(5);
#[derive(NetworkBehaviour)]
pub struct Behaviour {
@@ -109,6 +111,9 @@ pub struct Behaviour {
// pending events to emmit => waker-backed Deque to control polling
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
// track consecutive ping failures per connection for N-strike tolerance
ping_failures: HashMap<ConnectionId, u32>,
}
impl Behaviour {
@@ -118,6 +123,7 @@ impl Behaviour {
mdns_discovered: HashMap::new(),
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
ping_failures: HashMap::new(),
})
}
@@ -308,6 +314,7 @@ impl NetworkBehaviour for Behaviour {
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
self.ping_failures.remove(&connection_id);
// handle connection closed event which is filtered correctly
self.on_connection_closed(peer_id, connection_id, ip, port)
}
@@ -337,10 +344,41 @@ impl NetworkBehaviour for Behaviour {
}
},
// handle ping events => if error then disconnect
// handle ping events => disconnect after N consecutive failures
managed::BehaviourEvent::Ping(e) => {
if let Err(_) = e.result {
self.close_connection(e.peer, e.connection.clone())
match &e.result {
Err(err) => {
let count = self.ping_failures.entry(e.connection).or_insert(0);
*count += 1;
log::warn!(
"Ping failed for peer {:?} (connection {:?}): {:?} — failure {}/{}",
e.peer,
e.connection,
err,
count,
MAX_PING_FAILURES
);
if *count >= MAX_PING_FAILURES {
log::warn!(
"Closing connection to peer {:?} after {} consecutive ping failures",
e.peer,
MAX_PING_FAILURES
);
self.ping_failures.remove(&e.connection);
self.close_connection(e.peer, e.connection);
}
}
Ok(rtt) => {
// Reset failure counter on successful ping
if self.ping_failures.remove(&e.connection).is_some() {
log::debug!(
"Ping recovered for peer {:?} (rtt={:?}), reset failure counter",
e.peer,
rtt
);
}
log::trace!("Ping OK for peer {:?}: rtt={:?}", e.peer, rtt);
}
}
}
}
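
The same N-strike idea, reduced to a minimal sketch (Python here purely for
illustration; the actual implementation above is the Rust behaviour): count
consecutive failures per connection, reset on success, and only close the
connection once the threshold is reached.

```python
MAX_PING_FAILURES = 3  # mirrors the constant added in the diff above


class PingTracker:
    """Per-connection consecutive-failure counter with N-strike tolerance."""

    def __init__(self) -> None:
        self.failures: dict[str, int] = {}

    def on_ping(self, connection_id: str, ok: bool) -> bool:
        """Record a ping result; return True when the connection should be closed."""
        if ok:
            # A successful ping resets the strike counter for this connection.
            self.failures.pop(connection_id, None)
            return False
        count = self.failures.get(connection_id, 0) + 1
        self.failures[connection_id] = count
        if count >= MAX_PING_FAILURES:
            del self.failures[connection_id]
            return True
        return False


tracker = PingTracker()
assert [tracker.on_ping("c1", ok=False) for _ in range(3)] == [False, False, True]
assert tracker.on_ping("c1", ok=False) is False  # counter was reset after the close
```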

View File

@@ -21,10 +21,9 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
Event,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -35,33 +34,27 @@ from exo.shared.types.worker.downloads import (
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.channels import Receiver, Sender
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[LocalForwarderEvent]
event_sender: Sender[Event]
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
# Per-model throttle for download progress events
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
@@ -116,7 +109,6 @@ class DownloadCoordinator:
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
if not self.offline:
tg.start_soon(self._check_internet_connection)
@@ -296,22 +288,6 @@ class DownloadCoordinator:
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:

View File

@@ -1,10 +1,11 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
from typing import Iterator, Self
import anyio
from anyio.abc import TaskGroup
@@ -37,11 +38,12 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
async def create(cls, args: "Args") -> Self:
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -55,18 +57,7 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
offline=args.offline,
)
else:
download_coordinator = None
event_index_counter = itertools.count()
if args.spawn_api:
api = API(
@@ -89,10 +80,25 @@ class Node:
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
# DownloadCoordinator sends events through the Worker's event channel
# so they get the same index sequence and retry mechanism
if not args.no_downloads:
assert worker is not None, "DownloadCoordinator requires a Worker"
download_coordinator = DownloadCoordinator(
node_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
event_sender=worker.event_sender.clone(),
offline=args.offline,
)
else:
download_coordinator = None
# We start every node with a master
master = Master(
node_id,
@@ -126,6 +132,7 @@ class Node:
master,
api,
node_id,
event_index_counter,
args.offline,
)
@@ -204,19 +211,8 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
@@ -231,8 +227,22 @@ class Node:
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.download_coordinator:
self.download_coordinator.shutdown()
assert self.worker is not None
self.download_coordinator = DownloadCoordinator(
self.node_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
event_sender=self.worker.event_sender.clone(),
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.api:
self.api.reset(result.session_id, result.won_clock)
else:

View File

@@ -59,7 +59,11 @@ def chat_request_to_text_generation(
chat_template_messages.append({"role": "system", "content": content})
else:
# Skip messages with no meaningful content
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
if (
msg.content is None
and msg.reasoning_content is None
and msg.tool_calls is None
):
continue
if msg.role in ("user", "assistant", "developer"):
@@ -111,6 +115,11 @@ def chunk_to_response(
]
)
if chunk.is_thinking:
delta = ChatCompletionMessage(role="assistant", reasoning_content=chunk.text)
else:
delta = ChatCompletionMessage(role="assistant", content=chunk.text)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -118,7 +127,7 @@ def chunk_to_response(
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
delta=delta,
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
@@ -208,6 +217,7 @@ async def collect_chat_response(
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[ToolCall] = []
logprobs_content: list[LogprobsContentItem] = []
model: str | None = None
@@ -228,7 +238,10 @@ async def collect_chat_response(
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
text_parts.append(chunk.text)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
@@ -258,6 +271,7 @@ async def collect_chat_response(
raise ValueError(error_message)
combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts) if thinking_parts else None
assert model is not None
yield ChatCompletionResponse(
@@ -270,6 +284,7 @@ async def collect_chat_response(
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
reasoning_content=combined_thinking,
tool_calls=tool_calls if tool_calls else None,
),
logprobs=Logprobs(content=logprobs_content)
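
As a standalone sketch of the routing rule the hunks above add to
`chunk_to_response` and `collect_chat_response` (simplified, hypothetical
types): a chunk flagged as thinking is emitted as `reasoning_content`, while
everything else stays in `content`.

```python
from dataclasses import dataclass


@dataclass
class Chunk:
    # Hypothetical stand-in for the runner's token chunk.
    text: str
    is_thinking: bool = False


def route_delta(chunk: Chunk) -> dict[str, str]:
    """Thinking tokens go to reasoning_content; normal tokens stay in content."""
    if chunk.is_thinking:
        return {"role": "assistant", "reasoning_content": chunk.text}
    return {"role": "assistant", "content": chunk.text}


deltas = [route_delta(c) for c in (Chunk("Let me check...", is_thinking=True), Chunk("4."))]
assert "reasoning_content" in deltas[0] and "content" in deltas[1]
```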

View File

@@ -1,6 +1,7 @@
"""Claude Messages API adapter for converting requests/responses."""
import json
import re
from collections.abc import AsyncGenerator
from typing import Any
@@ -28,6 +29,8 @@ from exo.shared.types.claude_api import (
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeThinkingBlock,
ClaudeThinkingDelta,
ClaudeToolResultBlock,
ClaudeToolUseBlock,
ClaudeUsage,
@@ -61,6 +64,22 @@ def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
return "".join(sub_block.text for sub_block in block.content)
# Matches "x-anthropic-billing-header: ...;" (with optional trailing newline)
# or similar telemetry headers that change every request and break KV prefix caching.
_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)
def _strip_volatile_headers(text: str) -> str:
"""Remove Anthropic billing/telemetry headers from system prompt text.
Claude Code prepends headers like 'x-anthropic-billing-header: cc_version=...;
cc_entrypoint=...; cch=...;' that contain per-request content hashes. These
change every request and break KV prefix caching (the prefix diverges at ~20
tokens instead of matching thousands of conversation tokens).
"""
return _VOLATILE_HEADER_RE.sub("", text)
def claude_request_to_text_generation(
request: ClaudeMessagesRequest,
) -> TextGenerationTaskParams:
@@ -73,6 +92,8 @@ def claude_request_to_text_generation(
instructions = request.system
else:
instructions = "".join(block.text for block in request.system)
instructions = _strip_volatile_headers(instructions)
chat_template_messages.append({"role": "system", "content": instructions})
# Convert messages to input
@@ -85,12 +106,15 @@ def claude_request_to_text_generation(
# Process structured content blocks
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
tool_results: list[ClaudeToolResultBlock] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
elif isinstance(block, ClaudeThinkingBlock):
thinking_parts.append(block.thinking)
elif isinstance(block, ClaudeToolUseBlock):
tool_calls.append(
{
@@ -106,6 +130,7 @@ def claude_request_to_text_generation(
tool_results.append(block)
content = "".join(text_parts)
reasoning_content = "".join(thinking_parts) if thinking_parts else None
# Build InputMessage from text content
if msg.role in ("user", "assistant"):
@@ -113,9 +138,14 @@ def claude_request_to_text_generation(
# Build chat_template_messages preserving tool structure
if tool_calls:
chat_template_messages.append(
{"role": "assistant", "content": content, "tool_calls": tool_calls}
)
chat_msg: dict[str, Any] = {
"role": "assistant",
"content": content,
"tool_calls": tool_calls,
}
if reasoning_content:
chat_msg["reasoning_content"] = reasoning_content
chat_template_messages.append(chat_msg)
elif tool_results:
for tr in tool_results:
chat_template_messages.append(
@@ -126,7 +156,10 @@ def claude_request_to_text_generation(
}
)
else:
chat_template_messages.append({"role": msg.role, "content": content})
chat_msg = {"role": msg.role, "content": content}
if reasoning_content:
chat_msg["reasoning_content"] = reasoning_content
chat_template_messages.append(chat_msg)
# Convert Claude tool definitions to OpenAI-style function tools
tools: list[dict[str, Any]] | None = None
@@ -143,6 +176,10 @@ def claude_request_to_text_generation(
for tool in request.tools
]
enable_thinking: bool | None = None
if request.thinking is not None:
enable_thinking = request.thinking.type in ("enabled", "adaptive")
return TextGenerationTaskParams(
model=request.model,
input=input_messages
@@ -156,6 +193,7 @@ def claude_request_to_text_generation(
stop=request.stop_sequences,
stream=request.stream,
tools=tools,
enable_thinking=enable_thinking,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
@@ -173,6 +211,7 @@ async def collect_claude_response(
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_usage: Usage | None = None
@@ -200,7 +239,10 @@ async def collect_claude_response(
stop_reason = "tool_use"
continue
text_parts.append(chunk.text)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -209,9 +251,12 @@ async def collect_claude_response(
raise ValueError(error_message)
combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts)
# Build content blocks
content: list[ClaudeContentBlock] = []
if combined_thinking:
content.append(ClaudeThinkingBlock(thinking=combined_thinking))
if combined_text:
content.append(ClaudeTextBlock(text=combined_text))
content.extend(tool_use_blocks)
@@ -256,16 +301,16 @@ async def generate_claude_stream(
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start for text block at index 0
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_usage: Usage | None = None
next_block_index = 1 # text block is 0, tool blocks start at 1
next_block_index = 0
# Track whether we've started thinking/text blocks
thinking_block_started = False
thinking_block_index = -1
text_block_started = False
text_block_index = -1
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
@@ -310,12 +355,45 @@ async def generate_claude_stream(
output_tokens += 1 # Count each chunk as one token
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.is_thinking:
# Start thinking block on first thinking token
if not thinking_block_started:
thinking_block_started = True
thinking_block_index = next_block_index
next_block_index += 1
block_start = ClaudeContentBlockStartEvent(
index=thinking_block_index,
content_block=ClaudeThinkingBlock(thinking=""),
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
delta_event = ClaudeContentBlockDeltaEvent(
index=thinking_block_index,
delta=ClaudeThinkingDelta(thinking=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
else:
# Close thinking block when transitioning to text
if thinking_block_started and text_block_index == -1:
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# Start text block on first text token
if not text_block_started:
text_block_started = True
text_block_index = next_block_index
next_block_index += 1
block_start = ClaudeContentBlockStartEvent(
index=text_block_index,
content_block=ClaudeTextBlock(text=""),
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
delta_event = ClaudeContentBlockDeltaEvent(
index=text_block_index,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -324,9 +402,22 @@ async def generate_claude_stream(
if last_usage is not None:
output_tokens = last_usage.completion_tokens
# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# Close any open blocks
if thinking_block_started and text_block_index == -1:
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
if text_block_started:
block_stop = ClaudeContentBlockStopEvent(index=text_block_index)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
if not thinking_block_started and not text_block_started:
empty_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {empty_start.model_dump_json()}\n\n"
empty_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {empty_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(

View File

@@ -29,6 +29,12 @@ from exo.shared.types.openai_responses import (
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponseReasoningItem,
ResponseReasoningSummaryPartAddedEvent,
ResponseReasoningSummaryPartDoneEvent,
ResponseReasoningSummaryText,
ResponseReasoningSummaryTextDeltaEvent,
ResponseReasoningSummaryTextDoneEvent,
ResponsesRequest,
ResponsesResponse,
ResponsesStreamEvent,
@@ -141,7 +147,9 @@ async def collect_responses_response(
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
reasoning_id = f"rs_{command_id}"
accumulated_text = ""
thinking_parts: list[str] = []
function_call_items: list[ResponseFunctionCallItem] = []
last_usage: Usage | None = None
error_message: str | None = None
@@ -168,6 +176,10 @@ async def collect_responses_response(
)
continue
if chunk.is_thinking:
thinking_parts.append(chunk.text)
continue
accumulated_text += chunk.text
if error_message is not None:
@@ -182,13 +194,21 @@ async def collect_responses_response(
total_tokens=last_usage.total_tokens,
)
output: list[ResponseItem] = [
output: list[ResponseItem] = []
if thinking_parts:
output.append(
ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text="".join(thinking_parts))],
)
)
output.append(
ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
]
)
output.extend(function_call_items)
yield ResponsesResponse(
@@ -212,6 +232,7 @@ async def generate_responses_stream(
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
reasoning_id = f"rs_{command_id}"
seq = count(1)
# response.created
@@ -233,32 +254,17 @@ async def generate_responses_stream(
)
yield _format_sse(in_progress_event)
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield _format_sse(item_added)
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
part=initial_part,
)
yield _format_sse(part_added)
accumulated_text = ""
accumulated_thinking = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_usage: Usage | None = None
next_output_index = 1 # message item is at 0
next_output_index = 0
# Track dynamic block creation
reasoning_started = False
reasoning_output_index = -1
message_started = False
message_output_index = -1
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
@@ -327,23 +333,184 @@ async def generate_responses_stream(
next_output_index += 1
continue
if chunk.is_thinking:
# Start reasoning block on first thinking token
if not reasoning_started:
reasoning_started = True
reasoning_output_index = next_output_index
next_output_index += 1
# response.output_item.added for reasoning
reasoning_item = ResponseReasoningItem(
id=reasoning_id,
summary=[],
status="in_progress",
)
rs_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=reasoning_item,
)
yield _format_sse(rs_added)
# response.reasoning_summary_part.added
part_added = ResponseReasoningSummaryPartAddedEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=""),
)
yield _format_sse(part_added)
accumulated_thinking += chunk.text
# response.reasoning_summary_text.delta
rs_delta = ResponseReasoningSummaryTextDeltaEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
delta=chunk.text,
)
yield _format_sse(rs_delta)
continue
# Close reasoning block when transitioning to text
if reasoning_started and not message_started:
# response.reasoning_summary_text.done
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
text=accumulated_thinking,
)
yield _format_sse(rs_text_done)
# response.reasoning_summary_part.done
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=accumulated_thinking),
)
yield _format_sse(rs_part_done)
# response.output_item.done for reasoning
rs_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
),
)
yield _format_sse(rs_item_done)
# Start message block on first text token
if not message_started:
message_started = True
message_output_index = next_output_index
next_output_index += 1
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=message_output_index,
item=initial_item,
)
yield _format_sse(item_added)
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=message_output_index,
content_index=0,
part=initial_part,
)
yield _format_sse(part_added)
accumulated_text += chunk.text
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
output_index=message_output_index,
content_index=0,
delta=chunk.text,
)
yield _format_sse(delta_event)
# Close reasoning block if it was never followed by text
if reasoning_started and not message_started:
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
text=accumulated_thinking,
)
yield _format_sse(rs_text_done)
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=accumulated_thinking),
)
yield _format_sse(rs_part_done)
rs_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
),
)
yield _format_sse(rs_item_done)
# If no message block was started, create one now (empty text)
if not message_started:
message_output_index = next_output_index
next_output_index += 1
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=message_output_index,
item=initial_item,
)
yield _format_sse(item_added)
initial_part = ResponseOutputText(text="")
part_added_evt = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=message_output_index,
content_index=0,
part=initial_part,
)
yield _format_sse(part_added_evt)
# response.output_text.done
text_done = ResponseTextDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
output_index=message_output_index,
content_index=0,
text=accumulated_text,
)
@@ -354,7 +521,7 @@ async def generate_responses_stream(
part_done = ResponseContentPartDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
output_index=message_output_index,
content_index=0,
part=final_part,
)
@@ -367,7 +534,9 @@ async def generate_responses_stream(
status="completed",
)
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
sequence_number=next(seq),
output_index=message_output_index,
item=final_message_item,
)
yield _format_sse(item_done)
@@ -381,7 +550,15 @@ async def generate_responses_stream(
)
# response.completed
output: list[ResponseItem] = [final_message_item]
output: list[ResponseItem] = []
if reasoning_started:
output.append(
ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
)
)
output.append(final_message_item)
output.extend(function_call_items)
final_response = ResponsesResponse(
id=response_id,

View File

@@ -132,13 +132,12 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
GlobalForwarderEvent,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
TracesMerged,
)
from exo.shared.types.memory import Memory
@@ -177,7 +176,8 @@ class API:
session_id: SessionId,
*,
port: int,
global_event_receiver: Receiver[GlobalForwarderEvent],
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
@@ -185,7 +185,6 @@ class API:
) -> None:
self.state = State()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
@@ -237,7 +236,6 @@ class API:
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
@@ -555,7 +553,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -903,7 +901,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -989,7 +987,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -1430,8 +1428,6 @@ class API:
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
@@ -1458,22 +1454,6 @@ class API:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
elif isinstance(event, PrefillProgress):
if queue := self._text_generation_queues.get(
event.command_id, None
):
try:
await queue.send(
PrefillProgressChunk(
model=event.model,
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
@@ -1511,12 +1491,12 @@ class API:
while self.paused:
await self.paused_ev.wait()
await self.command_sender.send(
ForwarderCommand(origin=self._system_id, command=command)
ForwarderCommand(origin=self.node_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self._system_id, command=command)
ForwarderDownloadCommand(origin=self.node_id, command=command)
)
async def start_download(

View File

@@ -29,14 +29,13 @@ from exo.shared.types.commands import (
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
GlobalForwarderEvent,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
LocalForwarderEvent,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
@@ -72,8 +71,8 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
local_event_receiver: Receiver[LocalForwarderEvent],
global_event_sender: Sender[GlobalForwarderEvent],
local_event_receiver: Receiver[ForwarderEvent],
global_event_sender: Sender[ForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
@@ -88,11 +87,10 @@ class Master:
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
@@ -290,7 +288,7 @@ class Master:
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self._system_id, command=cmd
origin=self.node_id, command=cmd
)
)
generated_events.extend(transition_events)
@@ -416,8 +414,8 @@ class Master:
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
LocalForwarderEvent(
origin=self._system_id,
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
@@ -429,7 +427,7 @@ class Master:
async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly
await self.global_event_sender.send(
GlobalForwarderEvent(
ForwarderEvent(
origin=self.node_id,
origin_idx=event.idx,
session=self.session_id,

View File

@@ -261,7 +261,7 @@ class TestGenerateClaudeStreamToolUse:
parsed = _parse_sse_events(events)
# Two tool block starts (at indices 1 and 2)
# Two tool block starts (at indices 0 and 1 — no text block when only tools)
tool_starts = [
e
for e in parsed
@@ -270,12 +270,11 @@ class TestGenerateClaudeStreamToolUse:
== "tool_use"
]
assert len(tool_starts) == 2
assert tool_starts[0]["index"] == 1
assert tool_starts[1]["index"] == 2
assert tool_starts[0]["index"] == 0
assert tool_starts[1]["index"] == 1
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
# Two tool block stops (at indices 0 and 1)
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
stop_indices = [e["index"] for e in block_stops]
assert 0 in stop_indices
assert 1 in stop_indices
assert 2 in stop_indices

View File

@@ -15,12 +15,11 @@ from exo.shared.types.commands import (
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
GlobalForwarderEvent,
ForwarderEvent,
IndexedEvent,
InstanceCreated,
LocalForwarderEvent,
NodeGatheredInfo,
TaskCreated,
)
@@ -46,9 +45,9 @@ async def test_master():
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[GlobalForwarderEvent]()
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[LocalForwarderEvent]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
all_events: list[IndexedEvent] = []
@@ -76,12 +75,13 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(
LocalForwarderEvent(
ForwarderEvent(
origin_idx=0,
origin=SystemId("Worker"),
origin=sender_node_id,
session=session_id,
event=(
NodeGatheredInfo(
@@ -108,7 +108,7 @@ async def test_master():
logger.info("inject a CreateInstance Command")
await command_sender.send(
ForwarderCommand(
origin=SystemId("API"),
origin=node_id,
command=(
PlaceInstance(
command_id=CommandId(),
@@ -133,7 +133,7 @@ async def test_master():
logger.info("inject a TextGeneration Command")
await command_sender.send(
ForwarderCommand(
origin=SystemId("API"),
origin=node_id,
command=(
TextGeneration(
command_id=CommandId(),

View File

@@ -0,0 +1,250 @@
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from itertools import count
from pathlib import Path
from typing import AsyncIterator
import anyio
import pytest
from exo.download.coordinator import DownloadCoordinator
from exo.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.master.main import Master
from exo.master.tests.conftest import create_node_memory
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
NodeDownloadProgress,
NodeGatheredInfo,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.main import Worker
def _complete_progress(shard: ShardMetadata) -> RepoDownloadProgress:
return RepoDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_revision="test",
shard=shard,
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)
class _TestShardDownloader(ShardDownloader):
"""Shard downloader that reports every shard as already complete."""
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
return Path("/tmp/test_shard")
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
) -> None:
pass
async def get_shard_download_status(
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
# Yield nothing — no pre-existing downloads
return
yield # make this an async generator
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata
) -> RepoDownloadProgress:
return _complete_progress(shard)
def _make_heartbeat(node_id: NodeId) -> NodeGatheredInfo:
return NodeGatheredInfo(
node_id=node_id,
when=str(datetime.now(tz=timezone.utc)),
info=create_node_memory(500),
)
class _PartitionSwitch:
"""Mutable boolean flag shared with the partition proxy coroutine."""
def __init__(self) -> None:
self.connected = True
async def _partition_proxy(
source: Receiver[ForwarderEvent],
dest: Sender[ForwarderEvent],
switch: _PartitionSwitch,
) -> None:
"""Forward events when ``switch.connected`` is True; drop otherwise."""
with source as events:
async for event in events:
if switch.connected:
await dest.send(event)
async def _wait_until(
predicate: Callable[[], object], *, timeout: float = 5.0, poll: float = 0.02
) -> None:
"""Poll *predicate* until truthy, raising on timeout."""
with anyio.fail_after(timeout):
while not predicate():
await anyio.sleep(poll)
# ---------------------------------------------------------------------------
# Test 1 same master: Worker + DC retry recovers lost events
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_partition_recovery_same_master() -> None:
"""Worker's out_for_delivery retry fills the Master's buffer gap after a
partition heals, even when DownloadCoordinator events are interleaved."""
master_node = NodeId("master-node")
worker_node = NodeId("worker-node")
session = SessionId(master_node_id=master_node, election_clock=1)
switch = _PartitionSwitch()
# --- channels --------------------------------------------------------
# Worker → proxy → Master (local events)
worker_local_send, proxy_local_recv = channel[ForwarderEvent]()
proxy_local_send, master_local_recv = channel[ForwarderEvent]()
# Master → proxy → Worker (global events)
master_global_send, proxy_global_recv = channel[ForwarderEvent]()
proxy_global_send, worker_global_recv = channel[ForwarderEvent]()
# Commands (required by constructors)
cmd_send, cmd_recv = channel[ForwarderCommand]()
dl_cmd_send, dl_cmd_recv = channel[ForwarderDownloadCommand]()
# --- components ------------------------------------------------------
worker = Worker(
worker_node,
session,
global_event_receiver=worker_global_recv,
local_event_sender=worker_local_send,
command_sender=cmd_send.clone(),
download_command_sender=dl_cmd_send.clone(),
event_index_counter=count(),
)
dc = DownloadCoordinator(
node_id=worker_node,
shard_downloader=_TestShardDownloader(),
download_command_receiver=dl_cmd_recv,
event_sender=worker.event_sender.clone(),
offline=True,
)
master = Master(
master_node,
session,
command_receiver=cmd_recv,
local_event_receiver=master_local_recv,
global_event_sender=master_global_send,
download_command_sender=dl_cmd_send.clone(),
)
async with anyio.create_task_group() as tg:
tg.start_soon(_partition_proxy, proxy_local_recv, proxy_local_send, switch)
tg.start_soon(_partition_proxy, proxy_global_recv, proxy_global_send, switch)
tg.start_soon(master.run)
tg.start_soon(dc.run)
tg.start_soon(worker.run)
# 1. Pre-partition: heartbeat reaches master
await worker.event_sender.send(_make_heartbeat(worker_node))
await _wait_until(lambda: worker_node in master.state.last_seen)
initial_last_seen = master.state.last_seen[worker_node]
# 2. Partition — proxy drops everything
switch.connected = False
# Worker heartbeat during partition — lost at proxy, kept in
# out_for_delivery.
await worker.event_sender.send(_make_heartbeat(worker_node))
# Trigger a download via DC's command channel. _TestShardDownloader
# reports status="complete" for any shard, so _start_download emits
# NodeDownloadProgress(DownloadPending) then
# NodeDownloadProgress(DownloadCompleted) through worker.event_sender.
# These go through _forward_events → proxy (dropped) → out_for_delivery.
# Use a unique model ID so the DC doesn't skip it as already completed
# (it pre-emits progress for the default "noop" model at startup).
test_shard = PipelineShardMetadata(
model_card=ModelCard(
model_id=ModelId("test-partition-model"),
n_layers=1,
storage_size=Memory.from_bytes(0),
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
)
await dl_cmd_send.send(
ForwarderDownloadCommand(
origin=worker_node,
command=StartDownload(
target_node_id=worker_node,
shard_metadata=test_shard,
),
)
)
# Wait for DC events to flow through worker's _forward_events
# (poll instead of sleeping a fixed duration to avoid flakiness on slow CI)
await _wait_until(lambda: len(worker.out_for_delivery) >= 3)
# Verify at least one is a download progress event
has_download_event = any(
isinstance(fe.event, NodeDownloadProgress)
for fe in worker.out_for_delivery.values()
)
assert has_download_event, (
"out_for_delivery should contain DC-originated download events"
)
# 3. Heal partition
switch.connected = True
# Worker's _resend_out_for_delivery runs every ~1-2s.
await _wait_until(
lambda: master.state.last_seen.get(worker_node, initial_last_seen)
> initial_last_seen,
timeout=8.0,
)
# 4. All events recovered — both worker heartbeats and DC download
# progress events were retried and accepted by master.
await _wait_until(lambda: len(worker.out_for_delivery) == 0, timeout=8.0)
# Master state reflects the download
assert worker_node in master.state.downloads
await master.shutdown()
worker.shutdown()
dc.shutdown()
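The recovery path exercised above relies on the worker periodically re-publishing whatever is still in out_for_delivery until the master has seen it. That loop lives in the worker, not in this test file, so the following is only a minimal sketch of such an at-least-once resend loop, with the name, period and acknowledgement path assumed from the comments above:
import anyio
from random import random
from exo.worker.main import Worker
async def resend_out_for_delivery(worker: Worker) -> None:
    # Sketch only, not the real implementation. Re-publish every event the
    # master has not yet acknowledged; entries are presumably removed from
    # out_for_delivery elsewhere once the master's global stream echoes them.
    while True:
        await anyio.sleep(1.0 + random())  # jittered ~1-2 s period
        for fe in list(worker.out_for_delivery.values()):
            await worker.local_event_sender.send(fe)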

View File

@@ -5,8 +5,7 @@ from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
GlobalForwarderEvent,
LocalForwarderEvent,
ForwarderEvent,
)
from exo.utils.pydantic_ext import CamelCaseModel
@@ -37,8 +36,8 @@ class TypedTopic[T: CamelCaseModel]:
return self.model_type.model_validate_json(b.decode("utf-8"))
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, GlobalForwarderEvent)
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, LocalForwarderEvent)
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, ForwarderEvent)
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, ForwarderEvent)
COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand)
ELECTION_MESSAGES = TypedTopic(
"election_messages", PublishPolicy.Always, ElectionMessage

View File

@@ -15,7 +15,6 @@ from exo.shared.types.events import (
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -65,7 +64,6 @@ def event_apply(event: Event, state: State) -> State:
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
| TracesCollected()
| TracesMerged()
): # Pass-through events that don't modify state

View File

@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import channel
# ======= #
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
# Pump local commands so our commands_seen is high before the round starts
for _ in range(50):
await co_tx.send(
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
)
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands

View File

@@ -77,7 +77,7 @@ class ChatCompletionMessage(BaseModel):
content: (
str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
) = None
thinking: str | None = None # Added for GPT-OSS harmony format support
reasoning_content: str | None = None
name: str | None = None
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None

View File

@@ -27,6 +27,7 @@ class TokenChunk(BaseChunk):
stats: GenerationStats | None = None
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
is_thinking: bool = False
class ErrorChunk(BaseChunk):

View File

@@ -47,6 +47,14 @@ class ClaudeImageBlock(BaseModel, frozen=True):
source: ClaudeImageSource
class ClaudeThinkingBlock(BaseModel, frozen=True):
"""Thinking content block in Claude Messages API."""
type: Literal["thinking"] = "thinking"
thinking: str
signature: str | None = None
class ClaudeToolUseBlock(BaseModel, frozen=True):
"""Tool use content block in Claude Messages API."""
@@ -66,11 +74,17 @@ class ClaudeToolResultBlock(BaseModel, frozen=True):
cache_control: dict[str, str] | None = None
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock
ClaudeContentBlock = (
ClaudeTextBlock | ClaudeImageBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
)
# Input content blocks can also include tool_result (sent by user after tool_use)
ClaudeInputContentBlock = (
ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock | ClaudeToolResultBlock
ClaudeTextBlock
| ClaudeImageBlock
| ClaudeThinkingBlock
| ClaudeToolUseBlock
| ClaudeToolResultBlock
)
@@ -82,6 +96,11 @@ class ClaudeMessage(BaseModel, frozen=True):
content: str | list[ClaudeInputContentBlock]
class ClaudeThinkingConfig(BaseModel, frozen=True):
type: Literal["enabled", "disabled", "adaptive"]
budget_tokens: int | None = None
class ClaudeMessagesRequest(BaseModel):
"""Request body for Claude Messages API."""
@@ -96,6 +115,7 @@ class ClaudeMessagesRequest(BaseModel):
top_k: int | None = None
tools: list[ClaudeToolDefinition] | None = None
metadata: dict[str, str] | None = None
thinking: ClaudeThinkingConfig | None = None
# Response types
@@ -145,7 +165,7 @@ class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
type: Literal["content_block_start"] = "content_block_start"
index: int
content_block: ClaudeTextBlock | ClaudeToolUseBlock
content_block: ClaudeTextBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
class ClaudeTextDelta(BaseModel, frozen=True):
@@ -155,6 +175,13 @@ class ClaudeTextDelta(BaseModel, frozen=True):
text: str
class ClaudeThinkingDelta(BaseModel, frozen=True):
"""Delta for thinking content block."""
type: Literal["thinking_delta"] = "thinking_delta"
thinking: str
class ClaudeInputJsonDelta(BaseModel, frozen=True):
"""Delta for tool use input JSON content block."""
@@ -167,7 +194,7 @@ class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: ClaudeTextDelta | ClaudeInputJsonDelta
delta: ClaudeTextDelta | ClaudeThinkingDelta | ClaudeInputJsonDelta
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
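Together, ClaudeThinkingConfig on the request side and ClaudeThinkingBlock / ClaudeThinkingDelta on the response side model the Anthropic Messages API's extended-thinking shape. A minimal request body that would exercise these models might look like the following sketch (the model id and token budget are placeholders, not values from this patch):
payload = {
    "model": "example-claude-model",  # placeholder id
    "max_tokens": 1024,
    # parsed into ClaudeThinkingConfig(type="enabled", budget_tokens=2048)
    "thinking": {"type": "enabled", "budget_tokens": 2048},
    "messages": [{"role": "user", "content": "What's 27 * 43?"}],
}
When streamed, the thinking text would then typically arrive as content_block_delta events carrying ClaudeThinkingDelta before the ordinary ClaudeTextDelta content.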

View File

@@ -6,7 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -100,10 +100,10 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: SystemId
origin: NodeId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: SystemId
origin: NodeId
command: DownloadCommand

View File

@@ -25,10 +25,6 @@ class NodeId(Id):
pass
class SystemId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -102,13 +102,6 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
model: ModelId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
conn: Connection
@@ -155,7 +148,6 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected
@@ -170,19 +162,10 @@ class IndexedEvent(CamelCaseModel):
event: Event
class GlobalForwarderEvent(CamelCaseModel):
class ForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0)
origin: NodeId
session: SessionId
event: Event
class LocalForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0)
origin: SystemId
session: SessionId
event: Event

View File

@@ -145,7 +145,23 @@ class ResponseFunctionCallItem(BaseModel, frozen=True):
status: ResponseStatus = "completed"
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem
class ResponseReasoningSummaryText(BaseModel, frozen=True):
"""Summary text part in a reasoning output item."""
type: Literal["summary_text"] = "summary_text"
text: str
class ResponseReasoningItem(BaseModel, frozen=True):
"""Reasoning output item in response output array."""
type: Literal["reasoning"] = "reasoning"
id: str
summary: list[ResponseReasoningSummaryText] = Field(default_factory=list)
status: ResponseStatus = "completed"
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem | ResponseReasoningItem
class ResponseUsage(BaseModel, frozen=True):
@@ -273,6 +289,58 @@ class ResponseFunctionCallArgumentsDoneEvent(BaseModel, frozen=True):
arguments: str
class ResponseReasoningSummaryPartAddedEvent(BaseModel, frozen=True):
"""Event sent when a reasoning summary part is added."""
type: Literal["response.reasoning_summary_part.added"] = (
"response.reasoning_summary_part.added"
)
sequence_number: int
item_id: str
output_index: int
summary_index: int
part: ResponseReasoningSummaryText
class ResponseReasoningSummaryTextDeltaEvent(BaseModel, frozen=True):
"""Event sent for reasoning summary text delta during streaming."""
type: Literal["response.reasoning_summary_text.delta"] = (
"response.reasoning_summary_text.delta"
)
sequence_number: int
item_id: str
output_index: int
summary_index: int
delta: str
class ResponseReasoningSummaryTextDoneEvent(BaseModel, frozen=True):
"""Event sent when reasoning summary text is done."""
type: Literal["response.reasoning_summary_text.done"] = (
"response.reasoning_summary_text.done"
)
sequence_number: int
item_id: str
output_index: int
summary_index: int
text: str
class ResponseReasoningSummaryPartDoneEvent(BaseModel, frozen=True):
"""Event sent when a reasoning summary part is done."""
type: Literal["response.reasoning_summary_part.done"] = (
"response.reasoning_summary_part.done"
)
sequence_number: int
item_id: str
output_index: int
summary_index: int
part: ResponseReasoningSummaryText
class ResponseCompletedEvent(BaseModel, frozen=True):
"""Event sent when response is completed."""
@@ -292,5 +360,9 @@ ResponsesStreamEvent = (
| ResponseOutputItemDoneEvent
| ResponseFunctionCallArgumentsDeltaEvent
| ResponseFunctionCallArgumentsDoneEvent
| ResponseReasoningSummaryPartAddedEvent
| ResponseReasoningSummaryTextDeltaEvent
| ResponseReasoningSummaryTextDoneEvent
| ResponseReasoningSummaryPartDoneEvent
| ResponseCompletedEvent
)
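For orientation, these four reasoning-summary events presumably follow the added → delta → done → part-done ordering used by the other Responses streaming events; a small illustrative sequence built from the models above (the item id and sequence numbers are made up):
summary = "Considered the request and chose a tool."
events = [
    ResponseReasoningSummaryPartAddedEvent(
        sequence_number=3, item_id="rs_0", output_index=0, summary_index=0,
        part=ResponseReasoningSummaryText(text=""),
    ),
    ResponseReasoningSummaryTextDeltaEvent(
        sequence_number=4, item_id="rs_0", output_index=0, summary_index=0,
        delta=summary,
    ),
    ResponseReasoningSummaryTextDoneEvent(
        sequence_number=5, item_id="rs_0", output_index=0, summary_index=0,
        text=summary,
    ),
    ResponseReasoningSummaryPartDoneEvent(
        sequence_number=6, item_id="rs_0", output_index=0, summary_index=0,
        part=ResponseReasoningSummaryText(text=summary),
    ),
]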

View File

@@ -28,6 +28,7 @@ class GenerationResponse(BaseRunnerResponse):
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
usage: Usage | None
is_thinking: bool = False
class ImageGenerationResponse(BaseRunnerResponse):

View File

@@ -0,0 +1,72 @@
import json
import re
from typing import Any
from mlx_lm.chat_templates import deepseek_v32
from exo.shared.types.api import ToolCallItem
BOS_TOKEN: str = deepseek_v32.bos_token
EOS_TOKEN: str = deepseek_v32.eos_token
DSML_TOKEN: str = deepseek_v32.dsml_token
THINKING_START: str = deepseek_v32.thinking_start_token
THINKING_END: str = deepseek_v32.thinking_end_token
USER_TOKEN = "<\uff5cUser\uff5c>"
ASSISTANT_TOKEN = "<\uff5cAssistant\uff5c>"
TOOL_CALLS_START = f"<{DSML_TOKEN}function_calls>"
TOOL_CALLS_END = f"</{DSML_TOKEN}function_calls>"
encode_messages = deepseek_v32.encode_messages
_INVOKE_PATTERN = re.compile(
rf"<{re.escape(DSML_TOKEN)}invoke\s+name=\"([^\"]+)\">"
rf"(.*?)"
rf"</{re.escape(DSML_TOKEN)}invoke>",
re.DOTALL,
)
_PARAM_PATTERN = re.compile(
rf"<{re.escape(DSML_TOKEN)}parameter\s+name=\"([^\"]+)\"\s+string=\"(true|false)\">"
rf"(.*?)"
rf"</{re.escape(DSML_TOKEN)}parameter>",
re.DOTALL,
)
def parse_dsml_output(text: str) -> list[ToolCallItem] | None:
"""Parse DSML function_calls block from model output text.
Args:
text: The text containing the DSML function_calls block
(including the start/end markers).
Returns:
List of ToolCallItem, or None if parsing fails.
"""
tool_calls: list[ToolCallItem] = []
for invoke_match in _INVOKE_PATTERN.finditer(text):
func_name = invoke_match.group(1)
invoke_body = invoke_match.group(2)
args: dict[str, Any] = {}
for param_match in _PARAM_PATTERN.finditer(invoke_body):
param_name = param_match.group(1)
is_string = param_match.group(2) == "true"
param_value = param_match.group(3)
if is_string:
args[param_name] = param_value
else:
try:
args[param_name] = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
args[param_name] = param_value
tool_calls.append(
ToolCallItem(
name=func_name,
arguments=json.dumps(args),
)
)
return tool_calls if tool_calls else None
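A quick worked example of the block shape this parser accepts, using a hypothetical get_weather call (built from the imported DSML constants rather than literal delimiter characters):
block = (
    f"{TOOL_CALLS_START}\n"
    f'<{DSML_TOKEN}invoke name="get_weather">\n'
    f'<{DSML_TOKEN}parameter name="city" string="true">Paris</{DSML_TOKEN}parameter>\n'
    f'<{DSML_TOKEN}parameter name="units" string="false">"celsius"</{DSML_TOKEN}parameter>\n'
    f"</{DSML_TOKEN}invoke>\n"
    f"{TOOL_CALLS_END}"
)
calls = parse_dsml_output(block)
# -> one ToolCallItem: name="get_weather", arguments='{"city": "Paris", "units": "celsius"}'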

View File

@@ -458,6 +458,19 @@ def _patch_lossy_chat_template(template: str) -> str | None:
return patched if n > 0 else None
def _needs_dsml_encoding(task_params: TextGenerationTaskParams) -> bool:
if "deepseek-v3.2" not in task_params.model.lower():
return False
# Use DSML encoding when tools are provided or tool results are in the conversation
if task_params.tools:
return True
if task_params.chat_template_messages:
return any(
msg.get("role") == "tool" for msg in task_params.chat_template_messages
)
return False
def apply_chat_template(
tokenizer: TokenizerWrapper,
task_params: TextGenerationTaskParams,
@@ -469,7 +482,6 @@ def apply_chat_template(
When chat_template_messages is available (from Chat Completions API),
uses those directly to preserve tool_calls, thinking, and other fields.
Otherwise builds messages from the task params input/instructions.
"""
formatted_messages: list[dict[str, Any]] = []
if task_params.chat_template_messages is not None:
@@ -497,6 +509,19 @@ def apply_chat_template(
partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
formatted_messages = formatted_messages[:-1]
if _needs_dsml_encoding(task_params):
from exo.worker.engines.mlx.dsml_encoding import encode_messages
prompt = encode_messages(
messages=formatted_messages,
thinking_mode="thinking" if task_params.enable_thinking else "chat",
tools=task_params.tools,
)
if partial_assistant_content:
prompt += partial_assistant_content
logger.info(prompt)
return prompt
extra_kwargs: dict[str, Any] = {}
if task_params.enable_thinking is not None:
# Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".

View File

@@ -1,6 +1,7 @@
from collections import defaultdict
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, fail_after
@@ -16,14 +17,13 @@ from exo.shared.types.commands import (
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
GlobalForwarderEvent,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
LocalForwarderEvent,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -58,22 +58,24 @@ class Worker:
node_id: NodeId,
session_id: SessionId,
*,
global_event_receiver: Receiver[GlobalForwarderEvent],
local_event_sender: Sender[LocalForwarderEvent],
global_event_receiver: Receiver[ForwarderEvent],
local_event_sender: Sender[ForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
@@ -84,8 +86,6 @@ class Worker:
self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
@@ -132,8 +132,6 @@ class Worker:
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
@@ -214,7 +212,7 @@ class Worker:
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self._system_id,
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
@@ -319,7 +317,7 @@ class Worker:
)
await self.command_sender.send(
ForwarderCommand(
origin=self._system_id,
origin=self.node_id,
command=RequestEventLog(since_idx=since_idx),
)
)
@@ -346,16 +344,15 @@ class Worker:
return runner
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
fe = LocalForwarderEvent(
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self._system_id,
origin=self.node_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe

View File

@@ -7,6 +7,7 @@ from functools import cache
from typing import Literal
import mlx.core as mx
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
@@ -21,12 +22,17 @@ from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
from exo.shared.types.api import ImageGenerationStats
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
PrefillProgress,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
@@ -315,11 +321,13 @@ def main(
) -> None:
if device_rank == 0:
event_sender.send(
PrefillProgress(
ChunkGenerated(
command_id=command_id,
model=shard_metadata.model_card.model_id,
processed_tokens=processed,
total_tokens=total,
chunk=PrefillProgressChunk(
model=shard_metadata.model_card.model_id,
processed_tokens=processed,
total_tokens=total,
),
)
)
cancelled_tasks.update(cancel_receiver.collect())
@@ -346,16 +354,22 @@ def main(
group=group,
)
# For other thinking models (GLM, etc.), check if we need to
# prepend the thinking tag that was consumed by the chat template
if detect_thinking_prompt_suffix(prompt, tokenizer):
if tokenizer.has_thinking:
mlx_generator = parse_thinking_models(
mlx_generator, tokenizer
mlx_generator,
tokenizer,
# For other thinking models (GLM, etc.), check if we need to
# prepend the thinking tag that was consumed by the chat template
starts_in_thinking=detect_thinking_prompt_suffix(
prompt, tokenizer
),
)
# GPT-OSS specific parsing to match other model formats.
# Model-specific output parsing for tool calls.
if isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
elif isinstance(inference_model, DeepseekV32Model):
mlx_generator = parse_deepseek_v32(mlx_generator)
elif tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
@@ -407,6 +421,7 @@ def main(
stats=response.stats,
logprob=response.logprob,
top_logprobs=response.top_logprobs,
is_thinking=response.is_thinking,
),
)
)
@@ -668,44 +683,208 @@ def parse_gpt_oss(
if ch == "analysis" and not thinking:
thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta:
yield response.model_copy(update={"text": delta})
yield response.model_copy(update={"text": delta, "is_thinking": thinking})
if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
def parse_deepseek_v32(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
Uses accumulated-text matching (not per-token marker checks) because
DSML markers like <DSMLfunction_calls> may span multiple tokens.
Also handles <think>...</think> blocks for thinking mode.
"""
from exo.worker.engines.mlx.dsml_encoding import (
THINKING_END,
THINKING_START,
TOOL_CALLS_END,
TOOL_CALLS_START,
parse_dsml_output,
)
accumulated = ""
in_tool_call = False
thinking = False
# Tokens buffered while we detect the start of a DSML block
pending_buffer: list[GenerationResponse] = []
# Text accumulated during a tool call block
tool_call_text = ""
for response in responses:
assert isinstance(response, GenerationResponse)
# ── Handle thinking tags ──
if not thinking and THINKING_START in response.text:
thinking = True
# Yield any text before the <think> tag
before = response.text[: response.text.index(THINKING_START)]
if before:
yield response.model_copy(update={"text": before})
continue
if thinking and THINKING_END in response.text:
thinking = False
# Yield any text after the </think> tag
after = response.text[
response.text.index(THINKING_END) + len(THINKING_END) :
]
if after:
yield response.model_copy(update={"text": after, "is_thinking": False})
continue
if thinking:
yield response.model_copy(update={"is_thinking": True})
continue
# ── Handle tool call accumulation ──
if in_tool_call:
tool_call_text += response.text
if TOOL_CALLS_END in tool_call_text:
# Parse the accumulated DSML block
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# EOS reached before end marker — yield buffered text as-is
if response.finish_reason is not None:
logger.info("DSML tool call parsing interrupted by EOS")
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# ── Detect start of tool call block ──
accumulated += response.text
if TOOL_CALLS_START in accumulated:
# The start marker might be split across pending_buffer + current token
start_idx = accumulated.index(TOOL_CALLS_START)
# Yield any pending tokens that are purely before the marker
pre_text = accumulated[:start_idx]
if pre_text:
# Flush pending buffer tokens that contributed text before the marker
for buf_resp in pending_buffer:
if pre_text:
chunk = buf_resp.text
if len(chunk) <= len(pre_text):
yield buf_resp
pre_text = pre_text[len(chunk) :]
else:
yield buf_resp.model_copy(update={"text": pre_text})
pre_text = ""
pending_buffer = []
tool_call_text = accumulated[start_idx:]
accumulated = ""
# Check if the end marker is already present (entire tool call in one token)
if TOOL_CALLS_END in tool_call_text:
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
tool_call_text = ""
else:
in_tool_call = True
continue
# Check if accumulated text might be the start of a DSML marker
# Buffer tokens if we see a partial match at the end
if _could_be_dsml_prefix(accumulated):
pending_buffer.append(response)
continue
# No partial match — flush all pending tokens and the current one
for buf_resp in pending_buffer:
yield buf_resp
pending_buffer = []
accumulated = ""
yield response
# Flush any remaining pending buffer at generator end
for buf_resp in pending_buffer:
yield buf_resp
def _could_be_dsml_prefix(text: str) -> bool:
"""Check if the end of text could be the start of a DSML function_calls marker.
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
This allows us to buffer tokens until we can determine if a tool call is starting.
"""
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
# Only check the last portion of text that could overlap with the marker
max_check = len(TOOL_CALLS_START)
tail = text[-max_check:] if len(text) > max_check else text
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
for i in range(len(tail)):
suffix = tail[i:]
if TOOL_CALLS_START.startswith(suffix):
return True
return False
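Put differently, parse_deepseek_v32 keeps buffering only while the accumulated tail could still grow into TOOL_CALLS_START ("<" + DSML_TOKEN + "function_calls>"). A couple of hypothetical checks, assuming the helpers are imported into a REPL:
from exo.worker.engines.mlx.dsml_encoding import DSML_TOKEN
assert _could_be_dsml_prefix("Sure. <")  # a lone "<" might start the marker
assert _could_be_dsml_prefix("Sure. <" + DSML_TOKEN + "function")  # still a viable prefix
assert not _could_be_dsml_prefix("x > 0 and y < 100.")  # "< 100." can never extend to the marker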
def parse_thinking_models(
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
starts_in_thinking: bool = True,
) -> Generator[GenerationResponse]:
"""Route thinking tokens via is_thinking flag.
Swallows think tag tokens, sets is_thinking on all others.
Always yields tokens with finish_reason to avoid hanging the chunk stream.
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
can properly parse thinking content.
"""
first = True
in_thinking = starts_in_thinking
for response in responses:
if isinstance(response, ToolCallResponse):
yield response
continue
if first:
first = False
yield response.model_copy(
update={
"text": tokenizer.think_start,
"token": tokenizer.think_start_id,
}
)
yield response
is_think_tag = (
tokenizer.think_end is not None and response.text == tokenizer.think_end
) or (
tokenizer.think_start is not None and response.text == tokenizer.think_start
)
if is_think_tag:
in_thinking = response.text != tokenizer.think_end
# Never swallow finish_reason — the chunk stream needs it to terminate.
if response.finish_reason is not None:
yield response.model_copy(update={"text": "", "is_thinking": False})
continue
yield response.model_copy(update={"is_thinking": in_thinking})
def _send_image_chunk(

View File

@@ -0,0 +1,967 @@
import json
from collections.abc import Generator
from typing import Any
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
)
from exo.worker.engines.mlx.dsml_encoding import (
ASSISTANT_TOKEN,
BOS_TOKEN,
DSML_TOKEN,
EOS_TOKEN,
THINKING_END,
THINKING_START,
TOOL_CALLS_END,
TOOL_CALLS_START,
USER_TOKEN,
encode_messages,
parse_dsml_output,
)
from exo.worker.runner.runner import parse_deepseek_v32
# ── Shared fixtures ──────────────────────────────────────────────
_WEATHER_TOOLS: list[dict[str, Any]] = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "The city name"},
"units": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature units",
},
},
"required": ["city"],
},
},
},
{
"type": "function",
"function": {
"name": "get_time",
"description": "Get the current time in a timezone",
"parameters": {
"type": "object",
"properties": {
"timezone": {"type": "string"},
},
"required": ["timezone"],
},
},
},
]
def _simulate_tokens(
texts: list[str],
finish_on_last: bool = True,
) -> Generator[GenerationResponse]:
"""Simulate a model producing tokens from a list of text strings."""
for i, text in enumerate(texts):
is_last = i == len(texts) - 1
yield GenerationResponse(
text=text,
token=i,
finish_reason="stop" if (is_last and finish_on_last) else None,
usage=None,
)
# ── Test: Standard text response (no tool calls) ────────────────
class TestE2EStandardResponse:
"""Model generates a plain text response — no tool calling involved."""
def test_plain_text_passthrough(self):
"""Simulate model producing: 'The weather in NYC is 72°F and sunny.'"""
# Step 1: Encode the prompt (with tools available)
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather in NYC?"},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
# Verify prompt structure
assert BOS_TOKEN in prompt
assert "## Tools" in prompt
assert "get_weather" in prompt
assert f"{USER_TOKEN}What's the weather in NYC?{ASSISTANT_TOKEN}" in prompt
# Step 2: Simulate model response — plain text tokens (no DSML)
model_tokens = [
"The weather",
" in NYC",
" is 72",
"°F",
" and sunny",
".",
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
# Step 3: Verify all tokens pass through as GenerationResponse
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 0
assert len(gen_results) == 6
full_text = "".join(r.text for r in gen_results)
assert full_text == "The weather in NYC is 72°F and sunny."
assert gen_results[-1].finish_reason == "stop"
# ── Test: Tool call response ─────────────────────────────────────
class TestE2EToolCallResponse:
"""Model generates a DSML tool call — realistic token boundaries."""
def test_realistic_tool_call_tokens(self):
"""Simulate model generating a get_weather tool call with realistic token splits.
Real models split DSML markers across tokens unpredictably.
This simulates how DeepSeek V3.2 actually tokenizes DSML output.
"""
# Step 1: Encode prompt
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather in San Francisco?"},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
assert "get_weather" in prompt
# Step 2: Simulate realistic token-by-token model output
# The model first produces some text, then a DSML tool call block
model_tokens = [
"I'll check the weather for you.",
"\n\n",
f"<{DSML_TOKEN}", # marker split across tokens
"function_calls>\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">',
"San Francisco",
f"</{DSML_TOKEN}parameter>\n",
f'<{DSML_TOKEN}parameter name="units" string="false">',
'"celsius"',
f"</{DSML_TOKEN}parameter>\n",
f"</{DSML_TOKEN}invoke>\n",
f"</{DSML_TOKEN}function_calls>",
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
# Step 3: Verify
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
# Should have text tokens before tool call + one ToolCallResponse
assert len(tool_results) == 1
assert len(tool_results[0].tool_calls) == 1
tc = tool_results[0].tool_calls[0]
assert tc.name == "get_weather"
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
assert args["city"] == "San Francisco"
assert args["units"] == "celsius"
# The text before the tool call should still be yielded
text_before = "".join(r.text for r in gen_results if not r.is_thinking)
assert "check the weather" in text_before
def test_multiple_tool_calls_in_one_block(self):
"""Model generates two tool calls in a single function_calls block."""
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Weather in NYC and time in EST?"},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
assert "get_weather" in prompt
assert "get_time" in prompt
# Simulate model output with two invocations
model_tokens = [
"Let me check both.\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
f'<{DSML_TOKEN}invoke name="get_time">\n',
f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 1
assert len(tool_results[0].tool_calls) == 2
assert tool_results[0].tool_calls[0].name == "get_weather"
assert tool_results[0].tool_calls[1].name == "get_time"
args0 = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
args1 = json.loads(tool_results[0].tool_calls[1].arguments) # pyright: ignore[reportAny]
assert args0 == {"city": "NYC"}
assert args1 == {"timezone": "EST"}
# ── Test: Multi-turn tool use flow ───────────────────────────────
class TestE2EMultiTurnToolUse:
"""Full multi-turn: user asks → model calls tool → tool result → model answers."""
def test_encode_multi_turn_with_tool_results(self):
"""Verify the prompt for turn 2 (after tool results) is correctly encoded."""
# Turn 1: user asks, model calls tool
# Turn 2: tool result provided, model answers
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are a weather assistant."},
{"role": "user", "content": "What's the weather in NYC?"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "NYC"}',
},
}
],
},
{"role": "tool", "content": '{"temperature": 72, "condition": "sunny"}'},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
# Verify multi-turn structure
assert BOS_TOKEN in prompt
assert "You are a weather assistant." in prompt
assert "## Tools" in prompt
# The assistant's tool call should be encoded as DSML
assert TOOL_CALLS_START in prompt
assert f'<{DSML_TOKEN}invoke name="get_weather">' in prompt
assert EOS_TOKEN in prompt
# The tool result should be wrapped in function_results
assert "<function_results>" in prompt
assert "<result>" in prompt
assert "72" in prompt
assert "</function_results>" in prompt
# Now simulate model answering after seeing the tool result
model_tokens = [
"The current",
" weather in NYC",
" is 72°F",
" and sunny.",
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 0
full_text = "".join(r.text for r in gen_results)
assert full_text == "The current weather in NYC is 72°F and sunny."
def test_multi_tool_results_encoding(self):
"""Verify encoding when model called two tools and both return results."""
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Weather and time?"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "LA"}',
},
},
{
"type": "function",
"function": {
"name": "get_time",
"arguments": '{"timezone": "PST"}',
},
},
],
},
{"role": "tool", "content": "85F, clear skies"},
{"role": "tool", "content": "3:42 PM PST"},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
# Should have one function_results block with two results
assert prompt.count("<function_results>") == 1
assert prompt.count("</function_results>") == 1
assert "<result>85F, clear skies</result>" in prompt
assert "<result>3:42 PM PST</result>" in prompt
# ── Test: Thinking + tool call ───────────────────────────────────
class TestE2EThinkingAndToolCall:
"""Model uses thinking mode, reasons, then makes a tool call."""
def test_thinking_then_tool_call(self):
"""Model thinks first, then produces a DSML tool call block."""
messages: list[dict[str, Any]] = [
{"role": "user", "content": "What's the weather?"},
]
prompt = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Thinking mode: prompt should end with <think>
assert prompt.endswith(THINKING_START)
# Simulate: model outputs <think>, thinks, closes thinking, then tool call.
# In the full pipeline, parse_thinking_models handles the case where
# <think> is in the prompt. Here we test parse_deepseek_v32 directly,
# which detects <think>/</think> markers in the stream.
model_tokens = [
THINKING_START,
"The user wants weather",
" information. I should use",
" the get_weather tool.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">',
"San Francisco",
f"</{DSML_TOKEN}parameter>\n",
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
# Should have thinking tokens + tool call
thinking_results = [r for r in gen_results if r.is_thinking]
assert len(thinking_results) >= 1
thinking_text = "".join(r.text for r in thinking_results)
assert "get_weather tool" in thinking_text
assert len(tool_results) == 1
assert tool_results[0].tool_calls[0].name == "get_weather"
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
assert args["city"] == "San Francisco"
def test_thinking_prompt_encoding(self):
"""Verify thinking mode affects prompt encoding correctly."""
messages: list[dict[str, Any]] = [
{"role": "system", "content": "Be thorough."},
{"role": "user", "content": "What's the weather?"},
]
# With thinking enabled
prompt_think = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert prompt_think.endswith(THINKING_START)
# With thinking disabled
prompt_no_think = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="chat"
)
assert prompt_no_think.endswith(THINKING_END)
# Both should have the same tool definitions
assert "get_weather" in prompt_think
assert "get_weather" in prompt_no_think
# ── Test: Round-trip encode → parse ──────────────────────────────
class TestE2ERoundTrip:
"""Verify that DSML we encode can be parsed back correctly."""
def test_encoded_tool_call_is_parseable(self):
"""Encode an assistant tool call message, then parse the DSML output."""
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Weather?"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Tokyo", "units": "celsius"}',
},
}
],
},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
# Extract the DSML function_calls block from the prompt
start = prompt.index(TOOL_CALLS_START)
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
dsml_block = prompt[start:end]
# Parse it back
parsed = parse_dsml_output(dsml_block)
assert parsed is not None
assert len(parsed) == 1
assert parsed[0].name == "get_weather"
args = json.loads(parsed[0].arguments) # pyright: ignore[reportAny]
assert args["city"] == "Tokyo"
assert args["units"] == "celsius"
def test_encoded_multi_tool_call_round_trips(self):
"""Encode multiple tool calls, verify they parse back correctly."""
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Both please"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
},
{
"type": "function",
"function": {
"name": "get_time",
"arguments": '{"timezone": "CET"}',
},
},
],
},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
start = prompt.index(TOOL_CALLS_START)
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
dsml_block = prompt[start:end]
parsed = parse_dsml_output(dsml_block)
assert parsed is not None
assert len(parsed) == 2
assert parsed[0].name == "get_weather"
assert parsed[1].name == "get_time"
assert json.loads(parsed[0].arguments) == {"city": "Paris"}
assert json.loads(parsed[1].arguments) == {"timezone": "CET"}
# ── Test: Edge cases with realistic token boundaries ─────────────
class TestE2EEdgeCases:
"""Edge cases that occur in real model inference."""
def test_dsml_marker_split_at_fullwidth_pipe(self):
"""The fullwidth pipe character might be its own token."""
# This is a realistic tokenization: the DSML marker is split at the fullwidth pipe characters
model_tokens = [
"Let me help.\n\n",
"<\uff5c", # start of DSML
"DSML\uff5c", # rest of DSML token
"function_calls>\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 1
assert tool_results[0].tool_calls[0].name == "get_weather"
def test_tool_call_with_nested_json_object(self):
"""Model passes a complex JSON object as a non-string parameter."""
dsml_block = (
f"{TOOL_CALLS_START}\n"
f'<{DSML_TOKEN}invoke name="create_event">\n'
f'<{DSML_TOKEN}parameter name="title" string="true">Team Standup</{DSML_TOKEN}parameter>\n'
f'<{DSML_TOKEN}parameter name="config" string="false">'
f'{{"recurring": true, "days": ["mon", "wed", "fri"], "time": "09:00"}}'
f"</{DSML_TOKEN}parameter>\n"
f"</{DSML_TOKEN}invoke>\n"
f"{TOOL_CALLS_END}"
)
# Feed as single token (model might produce it all at once after prefill)
results = list(parse_deepseek_v32(_simulate_tokens([dsml_block])))
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 1
tc = tool_results[0].tool_calls[0]
assert tc.name == "create_event"
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
assert args["title"] == "Team Standup"
assert args["config"]["recurring"] is True
assert args["config"]["days"] == ["mon", "wed", "fri"]
def test_text_with_angle_brackets_not_mistaken_for_dsml(self):
"""Angle brackets in normal text should not trigger DSML buffering."""
model_tokens = [
"The formula is ",
"<x, y>",
" where x > 0",
" and y < 100.",
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 0
full_text = "".join(r.text for r in gen_results)
assert "formula" in full_text
assert "<x, y>" in full_text
def test_empty_model_response(self):
"""Model produces only EOS (empty response)."""
model_tokens = [""]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
assert len(gen_results) == 1
assert gen_results[0].text == ""
assert gen_results[0].finish_reason == "stop"
# ── Test: Full EPDP spec round-trip ──────────────────────────────
class TestE2EFullRoundTrip:
"""Full round-trip matching the vLLM EPDP spec.
Simulates the complete multi-turn flow:
Turn 1: user asks → think → tool call → tool result → think → answer
Turn 2: user asks again → old reasoning stripped → think → answer
"""
def test_single_tool_full_flow_with_thinking(self):
"""Complete flow: user → think → tool call → tool result → think → answer.
This is the core EPDP flow from the vLLM spec.
"""
# ── Turn 1.1: User asks, encode prompt ──
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are a weather assistant."},
{"role": "user", "content": "How's the weather in Hangzhou?"},
]
prompt_1 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert prompt_1.endswith(THINKING_START)
assert "## Tools" in prompt_1
assert "get_weather" in prompt_1
# ── Turn 1.1: Model thinks, then calls tool ──
model_tokens_1 = [
THINKING_START,
"The user wants to know the weather in Hangzhou.",
" I need to use the get_weather tool.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
# Verify: thinking tokens + tool call
gen_1 = [r for r in results_1 if isinstance(r, GenerationResponse)]
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
thinking_1 = [r for r in gen_1 if r.is_thinking]
assert len(thinking_1) >= 1
assert "get_weather tool" in "".join(r.text for r in thinking_1)
assert len(tool_1) == 1
assert tool_1[0].tool_calls[0].name == "get_weather"
tc_args = json.loads(tool_1[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
assert tc_args == {"city": "Hangzhou"}
# ── Turn 1.2: Add assistant response + tool result to messages ──
messages.append(
{
"role": "assistant",
"content": "",
"reasoning_content": "The user wants to know the weather in Hangzhou. I need to use the get_weather tool.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Hangzhou"}',
},
}
],
}
)
messages.append(
{
"role": "tool",
"content": '{"temperature": "7~13°C", "condition": "Cloudy"}',
}
)
# Encode prompt for turn 1.2
prompt_2 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Verify: prompt has the full conversation structure
assert TOOL_CALLS_START in prompt_2 # assistant's encoded tool call
assert EOS_TOKEN in prompt_2 # assistant turn ends with EOS
assert "<function_results>" in prompt_2
assert "<result>" in prompt_2
assert "Cloudy" in prompt_2
assert "</function_results>" in prompt_2
# After tool results with thinking enabled → <think> appended
assert prompt_2.endswith(THINKING_START)
# The assistant's reasoning_content should appear (it's after last_user_idx)
assert "get_weather tool" in prompt_2
# ── Turn 1.2: Model thinks about results, then answers ──
model_tokens_2 = [
THINKING_START,
"The weather in Hangzhou is Cloudy, 7~13°C.",
" I'll tell the user.",
THINKING_END,
"The weather in Hangzhou is currently cloudy with temperatures between 7°C and 13°C.",
]
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
thinking_2 = [r for r in gen_2 if r.is_thinking]
non_thinking_2 = [r for r in gen_2 if not r.is_thinking]
assert len(tool_2) == 0 # No more tool calls
assert len(thinking_2) >= 1
assert "Cloudy" in "".join(r.text for r in thinking_2)
assert len(non_thinking_2) >= 1
final_text = "".join(r.text for r in non_thinking_2)
assert "7°C" in final_text
assert "13°C" in final_text
def test_multi_tool_full_flow(self):
"""Flow with two tools: user → think → 2 tool calls → 2 results → think → answer."""
# ── Initial prompt ──
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You help with weather and time."},
{"role": "user", "content": "Weather in NYC and time in EST?"},
]
prompt_1 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert prompt_1.endswith(THINKING_START)
# ── Model thinks, calls both tools ──
model_tokens_1 = [
THINKING_START,
"Two requests: weather and time. I'll call both.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
f'<{DSML_TOKEN}invoke name="get_time">\n',
f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
assert len(tool_1) == 1
assert len(tool_1[0].tool_calls) == 2
assert tool_1[0].tool_calls[0].name == "get_weather"
assert tool_1[0].tool_calls[1].name == "get_time"
# ── Add assistant + both tool results ──
messages.append(
{
"role": "assistant",
"content": "",
"reasoning_content": "Two requests: weather and time. I'll call both.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "NYC"}',
},
},
{
"type": "function",
"function": {
"name": "get_time",
"arguments": '{"timezone": "EST"}',
},
},
],
}
)
messages.append({"role": "tool", "content": "72°F, sunny"})
messages.append({"role": "tool", "content": "2:30 PM EST"})
prompt_2 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Verify multi-tool result encoding
# Count is 2: 1 in _TOOLS_SYSTEM_TEMPLATE example + 1 in conversation
assert prompt_2.count("<function_results>") == 2
assert prompt_2.count("</function_results>") == 2
assert "<result>72°F, sunny</result>" in prompt_2
assert "<result>2:30 PM EST</result>" in prompt_2
assert prompt_2.endswith(THINKING_START)
# ── Model thinks about results, answers ──
model_tokens_2 = [
THINKING_START,
"Got both results. Weather is 72°F sunny, time is 2:30 PM.",
THINKING_END,
"In NYC it's currently 72°F and sunny. The time in EST is 2:30 PM.",
]
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
non_thinking_2 = [r for r in gen_2 if not r.is_thinking]
assert len(tool_2) == 0
final_text = "".join(r.text for r in non_thinking_2)
assert "72°F" in final_text
assert "2:30 PM" in final_text
def test_two_user_turns_reasoning_stripped(self):
"""Turn 2: old reasoning_content is stripped from history.
Per the vLLM spec, clear_reasoning_content is called between user turns
to save bandwidth. Our _drop_old_thinking handles this.
"""
# Full turn 1 conversation (already completed)
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Weather in Hangzhou?"},
{
"role": "assistant",
"content": "",
"reasoning_content": "I need to call get_weather for Hangzhou.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Hangzhou"}',
},
}
],
},
{"role": "tool", "content": "Cloudy 7~13°C"},
{
"role": "assistant",
"content": "The weather in Hangzhou is cloudy, 7-13°C.",
"reasoning_content": "The tool returned cloudy weather. I'll summarize.",
},
# Turn 2: user asks again
{"role": "user", "content": "What about Beijing?"},
]
prompt = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Old reasoning_content from turn 1 assistants should be STRIPPED
# (they're before the last user message at index 5)
assert "I need to call get_weather" not in prompt
assert "tool returned cloudy" not in prompt
# But the assistant's content and tool calls should still be there
assert "cloudy, 7-13°C" in prompt
assert TOOL_CALLS_START in prompt
# Prompt ends with <think> for the new turn
assert prompt.endswith(THINKING_START)
# ── Turn 2: Model thinks, calls tool for Beijing ──
model_tokens = [
THINKING_START,
"Now the user wants Beijing weather.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">Beijing</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 1
assert tool_results[0].tool_calls[0].name == "get_weather"
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
assert args == {"city": "Beijing"}
def test_chained_tool_calls_loop(self):
"""Model calls tool, gets result, calls another tool, gets result, answers.
This simulates the inner while loop from the vLLM spec where the model
may need multiple sub-turns of tool calling before it has enough info.
"""
# ── Sub-turn 1: user asks, model calls get_time ──
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "What's the weather in Hangzhou tomorrow?"},
]
prompt_1 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert prompt_1.endswith(THINKING_START)
# Model first calls get_time to figure out the date
model_tokens_1 = [
THINKING_START,
"I need the current date first to calculate tomorrow.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_time">\n',
f'<{DSML_TOKEN}parameter name="timezone" string="true">Asia/Shanghai</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
assert len(tool_1) == 1
assert tool_1[0].tool_calls[0].name == "get_time"
# ── Sub-turn 2: add tool result, model calls get_weather ──
messages.append(
{
"role": "assistant",
"content": "",
"reasoning_content": "I need the current date first to calculate tomorrow.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_time",
"arguments": '{"timezone": "Asia/Shanghai"}',
},
}
],
}
)
messages.append({"role": "tool", "content": "2025-12-01 14:30 CST"})
prompt_2 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert "<result>2025-12-01 14:30 CST</result>" in prompt_2
assert prompt_2.endswith(THINKING_START)
# Model now knows the date, calls get_weather
model_tokens_2 = [
THINKING_START,
"Today is 2025-12-01, so tomorrow is 2025-12-02.",
" Now I can check weather for Hangzhou.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
assert len(tool_2) == 1
assert tool_2[0].tool_calls[0].name == "get_weather"
# ── Sub-turn 3: add weather result, model answers ──
messages.append(
{
"role": "assistant",
"content": "",
"reasoning_content": "Today is 2025-12-01, so tomorrow is 2025-12-02. Now I can check weather for Hangzhou.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Hangzhou"}',
},
}
],
}
)
messages.append({"role": "tool", "content": "Sunny, 5~12°C"})
prompt_3 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Should have both function_results blocks (one per tool round)
# Count is 3: 1 in _TOOLS_SYSTEM_TEMPLATE example + 2 in conversation
assert prompt_3.count("<function_results>") == 3
assert prompt_3.count("</function_results>") == 3
assert "<result>2025-12-01 14:30 CST</result>" in prompt_3
assert "<result>Sunny, 5~12°C</result>" in prompt_3
assert prompt_3.endswith(THINKING_START)
# Model finally answers
model_tokens_3 = [
THINKING_START,
"I have the weather for tomorrow in Hangzhou.",
THINKING_END,
"Tomorrow in Hangzhou will be sunny with temperatures between 5°C and 12°C.",
]
results_3 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_3)))
tool_3 = [r for r in results_3 if isinstance(r, ToolCallResponse)]
gen_3 = [r for r in results_3 if isinstance(r, GenerationResponse)]
non_thinking_3 = [r for r in gen_3 if not r.is_thinking]
assert len(tool_3) == 0 # No more tool calls — loop ends
final_text = "".join(r.text for r in non_thinking_3)
assert "sunny" in final_text.lower()
assert "5°C" in final_text
assert "12°C" in final_text

View File

@@ -148,6 +148,7 @@ class MockTokenizer:
tool_call_start = None
tool_call_end = None
has_tool_calling = False
has_thinking = False
class MockGroup:

View File

@@ -149,12 +149,23 @@ class TestParseGptOssThinkingThenToolCall:
def test_thinking_then_tool_call(self):
results = _collect(THINKING_THEN_TOOL_TOKENS)
# Should have thinking tags + content + tool call
text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
combined = "".join(text_parts)
assert "<think>" in combined
assert "</think>" in combined
assert "Let me think about this." in combined
# Thinking tokens should have is_thinking=True and no <think> tags
thinking_responses = [
r for r in results if isinstance(r, GenerationResponse) and r.is_thinking
]
thinking_text = "".join(r.text for r in thinking_responses)
assert "Let me think about this." in thinking_text
assert "<think>" not in thinking_text
assert "</think>" not in thinking_text
# Non-thinking tokens should have is_thinking=False
non_thinking = [
r
for r in results
if isinstance(r, GenerationResponse) and not r.is_thinking
]
non_thinking_text = "".join(r.text for r in non_thinking)
assert "<think>" not in non_thinking_text
# And the tool call
tc = _get_tool_call(results)

View File

@@ -0,0 +1,8 @@
#!/bin/bash
# Run Claude Code against a local exo cluster! (Here, GPT OSS 120B)
ANTHROPIC_BASE_URL="http://localhost:52415/" \
ANTHROPIC_AUTH_TOKEN="dummy" \
ANTHROPIC_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
ANTHROPIC_SMALL_FAST_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \
claude
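
For a quick smoke test without Claude Code, the same cluster can also be queried directly. The sketch below is an assumption-laden example, not part of the diff: it assumes exo serves an OpenAI-compatible streaming endpoint at `/v1/chat/completions` on the same port and emits thinking deltas as `reasoning_content` alongside `content`; the model id is the one used in the script above.

```python
# Hedged sketch: stream a chat completion from a local exo cluster and
# separate thinking deltas (reasoning_content) from visible text (content).
# The endpoint path and chunk shape are assumptions, not guarantees.
import json
import urllib.request

payload = {
    "model": "mlx-community/gpt-oss-120b-MXFP4-Q8",
    "messages": [{"role": "user", "content": "Weather in Hangzhou?"}],
    "stream": True,
}
req = urllib.request.Request(
    "http://localhost:52415/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    for raw_line in resp:
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data: ") or line == "data: [DONE]":
            continue  # skip keep-alives and the terminal sentinel
        chunk = json.loads(line[len("data: "):])
        delta = (chunk.get("choices") or [{}])[0].get("delta", {})
        if delta.get("reasoning_content"):
            print(f"[thinking] {delta['reasoning_content']}", end="", flush=True)
        if delta.get("content"):
            print(delta["content"], end="", flush=True)
```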