Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-19 15:27:02 -05:00)

Compare commits: memory-tid...feat/prefi (6 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 340aa36877 | |
| | 32669ae82d | |
| | cf648a53b8 | |
| | 94b2ce6922 | |
| | 423ed0f07f | |
| | ed001f2409 | |
@@ -14,6 +14,21 @@
      : 0,
  );

  const etaText = $derived.by(() => {
    if (progress.processed <= 0 || progress.total <= 0) return null;
    const elapsedMs = performance.now() - progress.startedAt;
    if (elapsedMs < 200) return null; // need a minimum sample window
    const tokensPerMs = progress.processed / elapsedMs;
    const remainingTokens = progress.total - progress.processed;
    const remainingMs = remainingTokens / tokensPerMs;
    const remainingSec = Math.ceil(remainingMs / 1000);
    if (remainingSec <= 0) return null;
    if (remainingSec < 60) return `~${remainingSec}s remaining`;
    const mins = Math.floor(remainingSec / 60);
    const secs = remainingSec % 60;
    return `~${mins}m ${secs}s remaining`;
  });

  function formatTokenCount(count: number | undefined): string {
    if (count == null) return "0";
    if (count >= 1000) {
@@ -40,8 +55,11 @@
        style="width: {percentage}%"
      ></div>
    </div>
    <div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
      {percentage}%
    <div
      class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
    >
      <span>{etaText ?? ""}</span>
      <span>{percentage}%</span>
    </div>
  </div>

@@ -276,6 +276,8 @@ export interface TokenData {
export interface PrefillProgress {
  processed: number;
  total: number;
  /** Timestamp (performance.now()) when prefill started. */
  startedAt: number;
}

export interface Message {
@@ -1652,11 +1654,12 @@ class AppStore {
|
||||
if (!reader) throw new Error("No response body");
|
||||
|
||||
let fullContent = prefixText;
|
||||
let streamedThinking = "";
|
||||
const collectedTokens: TokenData[] = [...tokensToKeep];
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
delta?: { content?: string; reasoning_content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
@@ -1677,6 +1680,7 @@ class AppStore {
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const delta = choice?.delta?.content;
|
||||
const thinkingDelta = choice?.delta?.reasoning_content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
@@ -1695,7 +1699,11 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
if (delta) {
|
||||
if (thinkingDelta) {
|
||||
streamedThinking += thinkingDelta;
|
||||
}
|
||||
|
||||
if (delta || thinkingDelta) {
|
||||
if (firstTokenTime === null) {
|
||||
firstTokenTime = performance.now();
|
||||
this.ttftMs = firstTokenTime - requestStartTime;
|
||||
@@ -1709,9 +1717,14 @@ class AppStore {
|
||||
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
|
||||
}
|
||||
|
||||
fullContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
if (delta) {
|
||||
fullContent += delta;
|
||||
}
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
const combinedThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
this.currentResponse = displayContent;
|
||||
@@ -1723,7 +1736,7 @@ class AppStore {
|
||||
messageId,
|
||||
(m) => {
|
||||
m.content = displayContent;
|
||||
m.thinking = thinkingContent || undefined;
|
||||
m.thinking = combinedThinking || undefined;
|
||||
m.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -1735,11 +1748,14 @@ class AppStore {
|
||||
|
||||
// Final update
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent } =
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
const finalThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
this.updateConversationMessage(targetConversationId, messageId, (m) => {
|
||||
m.content = displayContent;
|
||||
m.thinking = thinkingContent || undefined;
|
||||
m.thinking = finalThinking || undefined;
|
||||
m.tokens = [...collectedTokens];
|
||||
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
|
||||
if (this.tps !== null) m.tps = this.tps;
|
||||
@@ -1847,11 +1863,12 @@ class AppStore {
|
||||
}
|
||||
|
||||
let streamedContent = "";
|
||||
let streamedThinking = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
delta?: { content?: string; reasoning_content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
@@ -1872,6 +1889,7 @@ class AppStore {
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const delta = choice?.delta?.content;
|
||||
const thinkingDelta = choice?.delta?.reasoning_content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
@@ -1890,10 +1908,19 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
if (delta) {
|
||||
streamedContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
if (thinkingDelta) {
|
||||
streamedThinking += thinkingDelta;
|
||||
}
|
||||
|
||||
if (delta || thinkingDelta) {
|
||||
if (delta) {
|
||||
streamedContent += delta;
|
||||
}
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const combinedThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
|
||||
// Only update currentResponse if target conversation is active
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
@@ -1906,7 +1933,7 @@ class AppStore {
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.thinking = combinedThinking || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -1918,14 +1945,17 @@ class AppStore {
|
||||
|
||||
// Final cleanup of the message (if conversation still exists)
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent } =
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const finalThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.thinking = finalThinking || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -2317,10 +2347,11 @@ class AppStore {
|
||||
}
|
||||
|
||||
let streamedContent = "";
|
||||
let streamedThinking = "";
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
delta?: { content?: string; reasoning_content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
@@ -2348,6 +2379,7 @@ class AppStore {
|
||||
|
||||
const choice = parsed.choices?.[0];
|
||||
const tokenContent = choice?.delta?.content;
|
||||
const thinkingContent = choice?.delta?.reasoning_content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
@@ -2366,7 +2398,11 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
if (tokenContent) {
|
||||
if (thinkingContent) {
|
||||
streamedThinking += thinkingContent;
|
||||
}
|
||||
|
||||
if (tokenContent || thinkingContent) {
|
||||
// Track first token for TTFT
|
||||
if (firstTokenTime === null) {
|
||||
firstTokenTime = performance.now();
|
||||
@@ -2383,11 +2419,16 @@ class AppStore {
|
||||
this.tps = (tokenCount / elapsed) * 1000;
|
||||
}
|
||||
|
||||
streamedContent += tokenContent;
|
||||
if (tokenContent) {
|
||||
streamedContent += tokenContent;
|
||||
}
|
||||
|
||||
// Strip thinking tags for display and extract thinking content
|
||||
const { displayContent, thinkingContent } =
|
||||
// Use stripThinkingTags as fallback for any <think> tags still in content
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const combinedThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
|
||||
// Only update currentResponse if target conversation is active
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
@@ -2400,7 +2441,7 @@ class AppStore {
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.thinking = combinedThinking || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -2420,6 +2461,7 @@ class AppStore {
|
||||
this.prefillProgress = {
|
||||
processed: inner.processed_tokens,
|
||||
total: inner.total_tokens,
|
||||
startedAt: this.prefillProgress?.startedAt ?? performance.now(),
|
||||
};
|
||||
},
|
||||
},
|
||||
@@ -2436,14 +2478,17 @@ class AppStore {
|
||||
|
||||
// Final cleanup of the message (if conversation still exists)
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent } =
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const finalThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.thinking = finalThinking || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
// Store performance metrics on the message
|
||||
if (this.ttftMs !== null) {
|
||||
|
||||
@@ -114,6 +114,74 @@
|
||||
});
|
||||
let tb5InfoDismissed = $state(false);
|
||||
|
||||
// Detect Mac Studio nodes using RDMA on en2 (the port next to ethernet — RDMA doesn't work there)
|
||||
const macStudioEn2RdmaWarning = $derived.by(() => {
|
||||
const edges = data?.edges;
|
||||
const ids = tbIdentifiers;
|
||||
const rdmaCtl = rdmaCtlData;
|
||||
if (!edges || !ids || !rdmaCtl) return null;
|
||||
|
||||
const affectedConnections: Array<{
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
peerNodeId: string;
|
||||
peerNodeName: string;
|
||||
rdmaIface: string;
|
||||
}> = [];
|
||||
|
||||
const isMacStudio = (node: (typeof data.nodes)[string] | undefined) =>
|
||||
node?.system_info?.model_id === "Mac Studio";
|
||||
|
||||
for (const edge of edges) {
|
||||
if (!edge.sourceRdmaIface && !edge.sinkRdmaIface) continue;
|
||||
|
||||
const sourceNode = data?.nodes?.[edge.source];
|
||||
if (
|
||||
isMacStudio(sourceNode) &&
|
||||
edge.sourceRdmaIface === "rdma_en2" &&
|
||||
rdmaCtl[edge.source]?.enabled
|
||||
) {
|
||||
affectedConnections.push({
|
||||
nodeId: edge.source,
|
||||
nodeName:
|
||||
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
|
||||
peerNodeId: edge.target,
|
||||
peerNodeName:
|
||||
data?.nodes?.[edge.target]?.friendly_name ||
|
||||
edge.target.slice(0, 8) + "...",
|
||||
rdmaIface: "en2",
|
||||
});
|
||||
}
|
||||
|
||||
const sinkNode = data?.nodes?.[edge.target];
|
||||
if (
|
||||
isMacStudio(sinkNode) &&
|
||||
edge.sinkRdmaIface === "rdma_en2" &&
|
||||
rdmaCtl[edge.target]?.enabled
|
||||
) {
|
||||
affectedConnections.push({
|
||||
nodeId: edge.target,
|
||||
nodeName: sinkNode?.friendly_name || edge.target.slice(0, 8) + "...",
|
||||
peerNodeId: edge.source,
|
||||
peerNodeName:
|
||||
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
|
||||
rdmaIface: "en2",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate by nodeId
|
||||
const seen = new Set<string>();
|
||||
const unique = affectedConnections.filter((c) => {
|
||||
if (seen.has(c.nodeId)) return false;
|
||||
seen.add(c.nodeId);
|
||||
return true;
|
||||
});
|
||||
|
||||
return unique.length > 0 ? unique : null;
|
||||
});
|
||||
let macStudioEn2Dismissed = $state(false);
|
||||
|
||||
// Helper to get friendly node name from node ID
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -1758,7 +1826,7 @@
|
||||
</script>
|
||||
|
||||
{#snippet clusterWarnings()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
|
||||
<div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
@@ -1923,12 +1991,260 @@
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
|
||||
<div class="group relative" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-red-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-red-200">
|
||||
RDMA INCOMPATIBLE PORT
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => (macStudioEn2Dismissed = true)}
|
||||
class="ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer"
|
||||
title="Dismiss"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Expanded tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-96 p-4 rounded border border-red-500/30 bg-[#1a1a1a]/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-3">
|
||||
The Thunderbolt 5 port next to the Ethernet port on Mac Studio
|
||||
does
|
||||
<span class="text-red-400 font-semibold">not support RDMA</span>.
|
||||
Move the cable to one of the other three TB5 ports.
|
||||
</p>
|
||||
|
||||
<div class="text-xs text-white/60 mb-3">
|
||||
<span class="text-red-300">Affected:</span>
|
||||
{#each macStudioEn2RdmaWarning as conn}
|
||||
<div class="ml-2 mt-0.5">
|
||||
<span class="text-white/80">{conn.nodeName}</span>
|
||||
<span class="text-white/30">→</span>
|
||||
<span class="text-white/60">{conn.peerNodeName}</span>
|
||||
<span class="text-white/30 ml-1">(en2)</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<!-- Mac Studio back panel illustration -->
|
||||
<div class="bg-black/40 rounded p-3 mb-3">
|
||||
<p
|
||||
class="text-[10px] font-mono text-white/30 uppercase tracking-wider mb-2"
|
||||
>
|
||||
Mac Studio — Rear Panel
|
||||
</p>
|
||||
<svg
|
||||
viewBox="0 0 320 72"
|
||||
class="w-full"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<rect
|
||||
x="1"
|
||||
y="1"
|
||||
width="318"
|
||||
height="70"
|
||||
rx="6"
|
||||
ry="6"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.12)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<!-- TB5 port 1 -->
|
||||
<rect
|
||||
x="24"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.3)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<text
|
||||
x="38"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>TB5</text
|
||||
>
|
||||
<!-- TB5 port 2 -->
|
||||
<rect
|
||||
x="62"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.3)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<text
|
||||
x="76"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>TB5</text
|
||||
>
|
||||
<!-- TB5 port 3 -->
|
||||
<rect
|
||||
x="100"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.3)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<text
|
||||
x="114"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>TB5</text
|
||||
>
|
||||
<!-- TB5 port 4: INCOMPATIBLE (en2) — equally spaced with ports 1-3 -->
|
||||
<rect
|
||||
x="138"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="rgba(239,68,68,0.1)"
|
||||
stroke="rgba(239,68,68,0.7)"
|
||||
stroke-width="1.5"
|
||||
/>
|
||||
<line
|
||||
x1="142"
|
||||
y1="25"
|
||||
x2="162"
|
||||
y2="33"
|
||||
stroke="rgba(239,68,68,0.8)"
|
||||
stroke-width="1.5"
|
||||
stroke-linecap="round"
|
||||
/>
|
||||
<line
|
||||
x1="162"
|
||||
y1="25"
|
||||
x2="142"
|
||||
y2="33"
|
||||
stroke="rgba(239,68,68,0.8)"
|
||||
stroke-width="1.5"
|
||||
stroke-linecap="round"
|
||||
/>
|
||||
<text
|
||||
x="152"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(239,68,68,0.6)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;font-weight:600;"
|
||||
>en2</text
|
||||
>
|
||||
<!-- Ethernet port -->
|
||||
<rect
|
||||
x="196"
|
||||
y="19"
|
||||
width="24"
|
||||
height="20"
|
||||
rx="2"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.2)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<rect
|
||||
x="200"
|
||||
y="23"
|
||||
width="16"
|
||||
height="12"
|
||||
rx="1"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.12)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
<text
|
||||
x="208"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>ETH</text
|
||||
>
|
||||
<!-- Green checkmarks on working ports -->
|
||||
<circle
|
||||
cx="38"
|
||||
cy="62"
|
||||
r="3"
|
||||
fill="none"
|
||||
stroke="rgba(74,222,128,0.5)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
<circle
|
||||
cx="76"
|
||||
cy="62"
|
||||
r="3"
|
||||
fill="none"
|
||||
stroke="rgba(74,222,128,0.5)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
<circle
|
||||
cx="114"
|
||||
cy="62"
|
||||
r="3"
|
||||
fill="none"
|
||||
stroke="rgba(74,222,128,0.5)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
|
||||
<p class="text-xs text-white/50">
|
||||
<span class="text-green-400">Fix:</span> Move the Thunderbolt cable
|
||||
to any of the three leftmost ports (all support RDMA).
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/snippet}
|
||||
|
||||
{#snippet clusterWarningsCompact()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
|
||||
<div class="absolute top-2 left-2 flex flex-col gap-1">
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
@@ -1996,6 +2312,27 @@
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
|
||||
<div
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm"
|
||||
title="Mac Studio RDMA incompatible port (en2) — move cable to another TB5 port"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-red-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-red-200">BAD RDMA PORT</span>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/snippet}
|
||||
|
||||
BIN  prefill-eta-screenshot.png  (new file, 110 KiB; binary file not shown)
@@ -59,7 +59,11 @@ def chat_request_to_text_generation(
            chat_template_messages.append({"role": "system", "content": content})
        else:
            # Skip messages with no meaningful content
            if msg.content is None and msg.thinking is None and msg.tool_calls is None:
            if (
                msg.content is None
                and msg.reasoning_content is None
                and msg.tool_calls is None
            ):
                continue

            if msg.role in ("user", "assistant", "developer"):
@@ -111,6 +115,11 @@ def chunk_to_response(
            ]
        )

    if chunk.is_thinking:
        delta = ChatCompletionMessage(role="assistant", reasoning_content=chunk.text)
    else:
        delta = ChatCompletionMessage(role="assistant", content=chunk.text)

    return ChatCompletionResponse(
        id=command_id,
        created=int(time.time()),
@@ -118,7 +127,7 @@
        choices=[
            StreamingChoiceResponse(
                index=0,
                delta=ChatCompletionMessage(role="assistant", content=chunk.text),
                delta=delta,
                logprobs=logprobs,
                finish_reason=chunk.finish_reason,
            )
@@ -208,6 +217,7 @@ async def collect_chat_response(
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
    """Collect all token chunks and return a single ChatCompletionResponse."""
    text_parts: list[str] = []
    thinking_parts: list[str] = []
    tool_calls: list[ToolCall] = []
    logprobs_content: list[LogprobsContentItem] = []
    model: str | None = None
@@ -228,7 +238,10 @@ async def collect_chat_response(
        if model is None:
            model = chunk.model
        last_usage = chunk.usage or last_usage
        text_parts.append(chunk.text)
        if chunk.is_thinking:
            thinking_parts.append(chunk.text)
        else:
            text_parts.append(chunk.text)
        if chunk.logprob is not None:
            logprobs_content.append(
                LogprobsContentItem(
@@ -258,6 +271,7 @@ async def collect_chat_response(
        raise ValueError(error_message)

    combined_text = "".join(text_parts)
    combined_thinking = "".join(thinking_parts) if thinking_parts else None
    assert model is not None

    yield ChatCompletionResponse(
@@ -270,6 +284,7 @@ async def collect_chat_response(
        message=ChatCompletionMessage(
            role="assistant",
            content=combined_text,
            reasoning_content=combined_thinking,
            tool_calls=tool_calls if tool_calls else None,
        ),
        logprobs=Logprobs(content=logprobs_content)

@@ -1,6 +1,7 @@
"""Claude Messages API adapter for converting requests/responses."""

import json
import re
from collections.abc import AsyncGenerator
from typing import Any

@@ -28,6 +29,8 @@ from exo.shared.types.claude_api import (
    ClaudeStopReason,
    ClaudeTextBlock,
    ClaudeTextDelta,
    ClaudeThinkingBlock,
    ClaudeThinkingDelta,
    ClaudeToolResultBlock,
    ClaudeToolUseBlock,
    ClaudeUsage,
@@ -61,6 +64,22 @@ def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
    return "".join(sub_block.text for sub_block in block.content)


# Matches "x-anthropic-billing-header: ...;" (with optional trailing newline)
# or similar telemetry headers that change every request and break KV prefix caching.
_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)


def _strip_volatile_headers(text: str) -> str:
    """Remove Anthropic billing/telemetry headers from system prompt text.

    Claude Code prepends headers like 'x-anthropic-billing-header: cc_version=...;
    cc_entrypoint=...; cch=...;' that contain per-request content hashes. These
    change every request and break KV prefix caching (the prefix diverges at ~20
    tokens instead of matching thousands of conversation tokens).
    """
    return _VOLATILE_HEADER_RE.sub("", text)


def claude_request_to_text_generation(
    request: ClaudeMessagesRequest,
) -> TextGenerationTaskParams:
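As a quick illustration of the stripping above (the header value here is invented; the regex is the one from this hunk), the per-request prefix is removed so the cached system-prompt tokens stay identical across requests:

```python
import re

# Same pattern as _VOLATILE_HEADER_RE above.
volatile_header_re = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)

# Hypothetical system prompt as Claude Code might send it (header values made up).
system_text = (
    "x-anthropic-billing-header: cc_version=2.0.1; cc_entrypoint=cli; cch=deadbeef;\n"
    "You are a coding assistant working inside the user's repository."
)

print(volatile_header_re.sub("", system_text))
# -> "You are a coding assistant working inside the user's repository."
```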
@@ -73,6 +92,8 @@ def claude_request_to_text_generation(
|
||||
instructions = request.system
|
||||
else:
|
||||
instructions = "".join(block.text for block in request.system)
|
||||
|
||||
instructions = _strip_volatile_headers(instructions)
|
||||
chat_template_messages.append({"role": "system", "content": instructions})
|
||||
|
||||
# Convert messages to input
|
||||
@@ -85,12 +106,15 @@ def claude_request_to_text_generation(
|
||||
|
||||
# Process structured content blocks
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
tool_results: list[ClaudeToolResultBlock] = []
|
||||
|
||||
for block in msg.content:
|
||||
if isinstance(block, ClaudeTextBlock):
|
||||
text_parts.append(block.text)
|
||||
elif isinstance(block, ClaudeThinkingBlock):
|
||||
thinking_parts.append(block.thinking)
|
||||
elif isinstance(block, ClaudeToolUseBlock):
|
||||
tool_calls.append(
|
||||
{
|
||||
@@ -106,6 +130,7 @@ def claude_request_to_text_generation(
|
||||
tool_results.append(block)
|
||||
|
||||
content = "".join(text_parts)
|
||||
reasoning_content = "".join(thinking_parts) if thinking_parts else None
|
||||
|
||||
# Build InputMessage from text content
|
||||
if msg.role in ("user", "assistant"):
|
||||
@@ -113,9 +138,14 @@ def claude_request_to_text_generation(
|
||||
|
||||
# Build chat_template_messages preserving tool structure
|
||||
if tool_calls:
|
||||
chat_template_messages.append(
|
||||
{"role": "assistant", "content": content, "tool_calls": tool_calls}
|
||||
)
|
||||
chat_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
if reasoning_content:
|
||||
chat_msg["reasoning_content"] = reasoning_content
|
||||
chat_template_messages.append(chat_msg)
|
||||
elif tool_results:
|
||||
for tr in tool_results:
|
||||
chat_template_messages.append(
|
||||
@@ -126,7 +156,10 @@ def claude_request_to_text_generation(
|
||||
}
|
||||
)
|
||||
else:
|
||||
chat_template_messages.append({"role": msg.role, "content": content})
|
||||
chat_msg = {"role": msg.role, "content": content}
|
||||
if reasoning_content:
|
||||
chat_msg["reasoning_content"] = reasoning_content
|
||||
chat_template_messages.append(chat_msg)
|
||||
|
||||
# Convert Claude tool definitions to OpenAI-style function tools
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
@@ -143,6 +176,10 @@ def claude_request_to_text_generation(
|
||||
for tool in request.tools
|
||||
]
|
||||
|
||||
enable_thinking: bool | None = None
|
||||
if request.thinking is not None:
|
||||
enable_thinking = request.thinking.type in ("enabled", "adaptive")
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages
|
||||
@@ -156,6 +193,7 @@ def claude_request_to_text_generation(
|
||||
stop=request.stop_sequences,
|
||||
stream=request.stream,
|
||||
tools=tools,
|
||||
enable_thinking=enable_thinking,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
@@ -173,6 +211,7 @@ async def collect_claude_response(
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_use_blocks: list[ClaudeToolUseBlock] = []
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_usage: Usage | None = None
|
||||
@@ -200,7 +239,10 @@ async def collect_claude_response(
|
||||
stop_reason = "tool_use"
|
||||
continue
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
else:
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -209,9 +251,12 @@ async def collect_claude_response(
|
||||
raise ValueError(error_message)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_thinking = "".join(thinking_parts)
|
||||
|
||||
# Build content blocks
|
||||
content: list[ClaudeContentBlock] = []
|
||||
if combined_thinking:
|
||||
content.append(ClaudeThinkingBlock(thinking=combined_thinking))
|
||||
if combined_text:
|
||||
content.append(ClaudeTextBlock(text=combined_text))
|
||||
content.extend(tool_use_blocks)
|
||||
@@ -256,16 +301,16 @@ async def generate_claude_stream(
|
||||
start_event = ClaudeMessageStartEvent(message=initial_message)
|
||||
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_start for text block at index 0
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_usage: Usage | None = None
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
next_block_index = 0
|
||||
|
||||
# Track whether we've started thinking/text blocks
|
||||
thinking_block_started = False
|
||||
thinking_block_index = -1
|
||||
text_block_started = False
|
||||
text_block_index = -1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
@@ -310,12 +355,45 @@ async def generate_claude_stream(
|
||||
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
if chunk.is_thinking:
|
||||
# Start thinking block on first thinking token
|
||||
if not thinking_block_started:
|
||||
thinking_block_started = True
|
||||
thinking_block_index = next_block_index
|
||||
next_block_index += 1
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=thinking_block_index,
|
||||
content_block=ClaudeThinkingBlock(thinking=""),
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=thinking_block_index,
|
||||
delta=ClaudeThinkingDelta(thinking=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
else:
|
||||
# Close thinking block when transitioning to text
|
||||
if thinking_block_started and text_block_index == -1:
|
||||
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
# Start text block on first text token
|
||||
if not text_block_started:
|
||||
text_block_started = True
|
||||
text_block_index = next_block_index
|
||||
next_block_index += 1
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=text_block_index,
|
||||
content_block=ClaudeTextBlock(text=""),
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=text_block_index,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -324,9 +402,22 @@ async def generate_claude_stream(
|
||||
if last_usage is not None:
|
||||
output_tokens = last_usage.completion_tokens
|
||||
|
||||
# content_block_stop for text block
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
# Close any open blocks
|
||||
if thinking_block_started and text_block_index == -1:
|
||||
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
if text_block_started:
|
||||
block_stop = ClaudeContentBlockStopEvent(index=text_block_index)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
if not thinking_block_started and not text_block_started:
|
||||
empty_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {empty_start.model_dump_json()}\n\n"
|
||||
empty_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {empty_stop.model_dump_json()}\n\n"
|
||||
|
||||
# message_delta
|
||||
message_delta = ClaudeMessageDeltaEvent(
|
||||
|
||||
@@ -29,6 +29,12 @@ from exo.shared.types.openai_responses import (
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
ResponseReasoningSummaryPartAddedEvent,
|
||||
ResponseReasoningSummaryPartDoneEvent,
|
||||
ResponseReasoningSummaryText,
|
||||
ResponseReasoningSummaryTextDeltaEvent,
|
||||
ResponseReasoningSummaryTextDoneEvent,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponsesStreamEvent,
|
||||
@@ -141,7 +147,9 @@ async def collect_responses_response(
|
||||
"""Collect all token chunks and return a single ResponsesResponse."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
reasoning_id = f"rs_{command_id}"
|
||||
accumulated_text = ""
|
||||
thinking_parts: list[str] = []
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_usage: Usage | None = None
|
||||
error_message: str | None = None
|
||||
@@ -168,6 +176,10 @@ async def collect_responses_response(
|
||||
)
|
||||
continue
|
||||
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
|
||||
if error_message is not None:
|
||||
@@ -182,13 +194,21 @@ async def collect_responses_response(
|
||||
total_tokens=last_usage.total_tokens,
|
||||
)
|
||||
|
||||
output: list[ResponseItem] = [
|
||||
output: list[ResponseItem] = []
|
||||
if thinking_parts:
|
||||
output.append(
|
||||
ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text="".join(thinking_parts))],
|
||||
)
|
||||
)
|
||||
output.append(
|
||||
ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
]
|
||||
)
|
||||
output.extend(function_call_items)
|
||||
|
||||
yield ResponsesResponse(
|
||||
@@ -212,6 +232,7 @@ async def generate_responses_stream(
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
reasoning_id = f"rs_{command_id}"
|
||||
seq = count(1)
|
||||
|
||||
# response.created
|
||||
@@ -233,32 +254,17 @@ async def generate_responses_stream(
|
||||
)
|
||||
yield _format_sse(in_progress_event)
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq), output_index=0, item=initial_item
|
||||
)
|
||||
yield _format_sse(item_added)
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield _format_sse(part_added)
|
||||
|
||||
accumulated_text = ""
|
||||
accumulated_thinking = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_usage: Usage | None = None
|
||||
next_output_index = 1 # message item is at 0
|
||||
next_output_index = 0
|
||||
|
||||
# Track dynamic block creation
|
||||
reasoning_started = False
|
||||
reasoning_output_index = -1
|
||||
message_started = False
|
||||
message_output_index = -1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
@@ -327,23 +333,184 @@ async def generate_responses_stream(
|
||||
next_output_index += 1
|
||||
continue
|
||||
|
||||
if chunk.is_thinking:
|
||||
# Start reasoning block on first thinking token
|
||||
if not reasoning_started:
|
||||
reasoning_started = True
|
||||
reasoning_output_index = next_output_index
|
||||
next_output_index += 1
|
||||
|
||||
# response.output_item.added for reasoning
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[],
|
||||
status="in_progress",
|
||||
)
|
||||
rs_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=reasoning_output_index,
|
||||
item=reasoning_item,
|
||||
)
|
||||
yield _format_sse(rs_added)
|
||||
|
||||
# response.reasoning_summary_part.added
|
||||
part_added = ResponseReasoningSummaryPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
part=ResponseReasoningSummaryText(text=""),
|
||||
)
|
||||
yield _format_sse(part_added)
|
||||
|
||||
accumulated_thinking += chunk.text
|
||||
|
||||
# response.reasoning_summary_text.delta
|
||||
rs_delta = ResponseReasoningSummaryTextDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield _format_sse(rs_delta)
|
||||
continue
|
||||
|
||||
# Close reasoning block when transitioning to text
|
||||
if reasoning_started and not message_started:
|
||||
# response.reasoning_summary_text.done
|
||||
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
text=accumulated_thinking,
|
||||
)
|
||||
yield _format_sse(rs_text_done)
|
||||
|
||||
# response.reasoning_summary_part.done
|
||||
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
part=ResponseReasoningSummaryText(text=accumulated_thinking),
|
||||
)
|
||||
yield _format_sse(rs_part_done)
|
||||
|
||||
# response.output_item.done for reasoning
|
||||
rs_item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=reasoning_output_index,
|
||||
item=ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||
),
|
||||
)
|
||||
yield _format_sse(rs_item_done)
|
||||
|
||||
# Start message block on first text token
|
||||
if not message_started:
|
||||
message_started = True
|
||||
message_output_index = next_output_index
|
||||
next_output_index += 1
|
||||
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=message_output_index,
|
||||
item=initial_item,
|
||||
)
|
||||
yield _format_sse(item_added)
|
||||
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield _format_sse(part_added)
|
||||
|
||||
accumulated_text += chunk.text
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield _format_sse(delta_event)
|
||||
|
||||
# Close reasoning block if it was never followed by text
|
||||
if reasoning_started and not message_started:
|
||||
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
text=accumulated_thinking,
|
||||
)
|
||||
yield _format_sse(rs_text_done)
|
||||
|
||||
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
part=ResponseReasoningSummaryText(text=accumulated_thinking),
|
||||
)
|
||||
yield _format_sse(rs_part_done)
|
||||
|
||||
rs_item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=reasoning_output_index,
|
||||
item=ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||
),
|
||||
)
|
||||
yield _format_sse(rs_item_done)
|
||||
|
||||
# If no message block was started, create one now (empty text)
|
||||
if not message_started:
|
||||
message_output_index = next_output_index
|
||||
next_output_index += 1
|
||||
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=message_output_index,
|
||||
item=initial_item,
|
||||
)
|
||||
yield _format_sse(item_added)
|
||||
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added_evt = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield _format_sse(part_added_evt)
|
||||
|
||||
# response.output_text.done
|
||||
text_done = ResponseTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
text=accumulated_text,
|
||||
)
|
||||
@@ -354,7 +521,7 @@ async def generate_responses_stream(
|
||||
part_done = ResponseContentPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
part=final_part,
|
||||
)
|
||||
@@ -367,7 +534,9 @@ async def generate_responses_stream(
|
||||
status="completed",
|
||||
)
|
||||
item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
||||
sequence_number=next(seq),
|
||||
output_index=message_output_index,
|
||||
item=final_message_item,
|
||||
)
|
||||
yield _format_sse(item_done)
|
||||
|
||||
@@ -381,7 +550,15 @@ async def generate_responses_stream(
|
||||
)
|
||||
|
||||
# response.completed
|
||||
output: list[ResponseItem] = [final_message_item]
|
||||
output: list[ResponseItem] = []
|
||||
if reasoning_started:
|
||||
output.append(
|
||||
ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||
)
|
||||
)
|
||||
output.append(final_message_item)
|
||||
output.extend(function_call_items)
|
||||
final_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
|
||||
@@ -138,7 +138,6 @@ from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
PrefillProgress,
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -1455,22 +1454,6 @@ class API:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
|
||||
elif isinstance(event, PrefillProgress):
|
||||
if queue := self._text_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
try:
|
||||
await queue.send(
|
||||
PrefillProgressChunk(
|
||||
model=event.model,
|
||||
processed_tokens=event.processed_tokens,
|
||||
total_tokens=event.total_tokens,
|
||||
)
|
||||
)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
|
||||
if isinstance(event, TracesMerged):
|
||||
self._save_merged_trace(event)
|
||||
|
||||
|
||||
@@ -261,7 +261,7 @@ class TestGenerateClaudeStreamToolUse:
|
||||
|
||||
parsed = _parse_sse_events(events)
|
||||
|
||||
# Two tool block starts (at indices 1 and 2)
|
||||
# Two tool block starts (at indices 0 and 1 — no text block when only tools)
|
||||
tool_starts = [
|
||||
e
|
||||
for e in parsed
|
||||
@@ -270,12 +270,11 @@ class TestGenerateClaudeStreamToolUse:
|
||||
== "tool_use"
|
||||
]
|
||||
assert len(tool_starts) == 2
|
||||
assert tool_starts[0]["index"] == 1
|
||||
assert tool_starts[1]["index"] == 2
|
||||
assert tool_starts[0]["index"] == 0
|
||||
assert tool_starts[1]["index"] == 1
|
||||
|
||||
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
|
||||
# Two tool block stops (at indices 0 and 1)
|
||||
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
|
||||
stop_indices = [e["index"] for e in block_stops]
|
||||
assert 0 in stop_indices
|
||||
assert 1 in stop_indices
|
||||
assert 2 in stop_indices
|
||||
|
||||
@@ -15,7 +15,6 @@ from exo.shared.types.events import (
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
PrefillProgress,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
@@ -65,7 +64,6 @@ def event_apply(event: Event, state: State) -> State:
|
||||
| ChunkGenerated()
|
||||
| TaskAcknowledged()
|
||||
| InputChunkReceived()
|
||||
| PrefillProgress()
|
||||
| TracesCollected()
|
||||
| TracesMerged()
|
||||
): # Pass-through events that don't modify state
|
||||
|
||||
@@ -77,7 +77,7 @@ class ChatCompletionMessage(BaseModel):
    content: (
        str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
    ) = None
    thinking: str | None = None  # Added for GPT-OSS harmony format support
    reasoning_content: str | None = None
    name: str | None = None
    tool_calls: list[ToolCall] | None = None
    tool_call_id: str | None = None

@@ -27,6 +27,7 @@ class TokenChunk(BaseChunk):
    stats: GenerationStats | None = None
    logprob: float | None = None
    top_logprobs: list[TopLogprobItem] | None = None
    is_thinking: bool = False


class ErrorChunk(BaseChunk):

@@ -47,6 +47,14 @@ class ClaudeImageBlock(BaseModel, frozen=True):
|
||||
source: ClaudeImageSource
|
||||
|
||||
|
||||
class ClaudeThinkingBlock(BaseModel, frozen=True):
|
||||
"""Thinking content block in Claude Messages API."""
|
||||
|
||||
type: Literal["thinking"] = "thinking"
|
||||
thinking: str
|
||||
signature: str | None = None
|
||||
|
||||
|
||||
class ClaudeToolUseBlock(BaseModel, frozen=True):
|
||||
"""Tool use content block in Claude Messages API."""
|
||||
|
||||
@@ -66,11 +74,17 @@ class ClaudeToolResultBlock(BaseModel, frozen=True):
|
||||
cache_control: dict[str, str] | None = None
|
||||
|
||||
|
||||
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock
|
||||
ClaudeContentBlock = (
|
||||
ClaudeTextBlock | ClaudeImageBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
|
||||
)
|
||||
|
||||
# Input content blocks can also include tool_result (sent by user after tool_use)
|
||||
ClaudeInputContentBlock = (
|
||||
ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock | ClaudeToolResultBlock
|
||||
ClaudeTextBlock
|
||||
| ClaudeImageBlock
|
||||
| ClaudeThinkingBlock
|
||||
| ClaudeToolUseBlock
|
||||
| ClaudeToolResultBlock
|
||||
)
|
||||
|
||||
|
||||
@@ -82,6 +96,11 @@ class ClaudeMessage(BaseModel, frozen=True):
|
||||
content: str | list[ClaudeInputContentBlock]
|
||||
|
||||
|
||||
class ClaudeThinkingConfig(BaseModel, frozen=True):
|
||||
type: Literal["enabled", "disabled", "adaptive"]
|
||||
budget_tokens: int | None = None
|
||||
|
||||
|
||||
class ClaudeMessagesRequest(BaseModel):
|
||||
"""Request body for Claude Messages API."""
|
||||
|
||||
@@ -96,6 +115,7 @@ class ClaudeMessagesRequest(BaseModel):
|
||||
top_k: int | None = None
|
||||
tools: list[ClaudeToolDefinition] | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
thinking: ClaudeThinkingConfig | None = None
|
||||
|
||||
|
||||
# Response types
|
||||
@@ -145,7 +165,7 @@ class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
|
||||
|
||||
type: Literal["content_block_start"] = "content_block_start"
|
||||
index: int
|
||||
content_block: ClaudeTextBlock | ClaudeToolUseBlock
|
||||
content_block: ClaudeTextBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
|
||||
|
||||
|
||||
class ClaudeTextDelta(BaseModel, frozen=True):
|
||||
@@ -155,6 +175,13 @@ class ClaudeTextDelta(BaseModel, frozen=True):
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeThinkingDelta(BaseModel, frozen=True):
|
||||
"""Delta for thinking content block."""
|
||||
|
||||
type: Literal["thinking_delta"] = "thinking_delta"
|
||||
thinking: str
|
||||
|
||||
|
||||
class ClaudeInputJsonDelta(BaseModel, frozen=True):
|
||||
"""Delta for tool use input JSON content block."""
|
||||
|
||||
@@ -167,7 +194,7 @@ class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
|
||||
|
||||
type: Literal["content_block_delta"] = "content_block_delta"
|
||||
index: int
|
||||
delta: ClaudeTextDelta | ClaudeInputJsonDelta
|
||||
delta: ClaudeTextDelta | ClaudeThinkingDelta | ClaudeInputJsonDelta
|
||||
|
||||
|
||||
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import Field

from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -102,13 +102,6 @@ class InputChunkReceived(BaseEvent):
    chunk: InputImageChunk


class PrefillProgress(BaseEvent):
    command_id: CommandId
    model: ModelId
    processed_tokens: int
    total_tokens: int


class TopologyEdgeCreated(BaseEvent):
    conn: Connection

@@ -155,7 +148,6 @@ Event = (
    | NodeDownloadProgress
    | ChunkGenerated
    | InputChunkReceived
    | PrefillProgress
    | TopologyEdgeCreated
    | TopologyEdgeDeleted
    | TracesCollected

@@ -145,7 +145,23 @@ class ResponseFunctionCallItem(BaseModel, frozen=True):
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
|
||||
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem
|
||||
class ResponseReasoningSummaryText(BaseModel, frozen=True):
|
||||
"""Summary text part in a reasoning output item."""
|
||||
|
||||
type: Literal["summary_text"] = "summary_text"
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseReasoningItem(BaseModel, frozen=True):
|
||||
"""Reasoning output item in response output array."""
|
||||
|
||||
type: Literal["reasoning"] = "reasoning"
|
||||
id: str
|
||||
summary: list[ResponseReasoningSummaryText] = Field(default_factory=list)
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
|
||||
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem | ResponseReasoningItem
|
||||
|
||||
|
||||
class ResponseUsage(BaseModel, frozen=True):
|
||||
@@ -273,6 +289,58 @@ class ResponseFunctionCallArgumentsDoneEvent(BaseModel, frozen=True):
|
||||
arguments: str
|
||||
|
||||
|
||||
class ResponseReasoningSummaryPartAddedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a reasoning summary part is added."""
|
||||
|
||||
type: Literal["response.reasoning_summary_part.added"] = (
|
||||
"response.reasoning_summary_part.added"
|
||||
)
|
||||
sequence_number: int
|
||||
item_id: str
|
||||
output_index: int
|
||||
summary_index: int
|
||||
part: ResponseReasoningSummaryText
|
||||
|
||||
|
||||
class ResponseReasoningSummaryTextDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent for reasoning summary text delta during streaming."""
|
||||
|
||||
type: Literal["response.reasoning_summary_text.delta"] = (
|
||||
"response.reasoning_summary_text.delta"
|
||||
)
|
||||
sequence_number: int
|
||||
item_id: str
|
||||
output_index: int
|
||||
summary_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseReasoningSummaryTextDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when reasoning summary text is done."""
|
||||
|
||||
type: Literal["response.reasoning_summary_text.done"] = (
|
||||
"response.reasoning_summary_text.done"
|
||||
)
|
||||
sequence_number: int
|
||||
item_id: str
|
||||
output_index: int
|
||||
summary_index: int
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseReasoningSummaryPartDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a reasoning summary part is done."""
|
||||
|
||||
type: Literal["response.reasoning_summary_part.done"] = (
|
||||
"response.reasoning_summary_part.done"
|
||||
)
|
||||
sequence_number: int
|
||||
item_id: str
|
||||
output_index: int
|
||||
summary_index: int
|
||||
part: ResponseReasoningSummaryText
|
||||
|
||||
|
||||
class ResponseCompletedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response is completed."""
|
||||
|
||||
@@ -292,5 +360,9 @@ ResponsesStreamEvent = (
|
||||
| ResponseOutputItemDoneEvent
|
||||
| ResponseFunctionCallArgumentsDeltaEvent
|
||||
| ResponseFunctionCallArgumentsDoneEvent
|
||||
| ResponseReasoningSummaryPartAddedEvent
|
||||
| ResponseReasoningSummaryTextDeltaEvent
|
||||
| ResponseReasoningSummaryTextDoneEvent
|
||||
| ResponseReasoningSummaryPartDoneEvent
|
||||
| ResponseCompletedEvent
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ class GenerationResponse(BaseRunnerResponse):
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None
    usage: Usage | None
    is_thinking: bool = False


class ImageGenerationResponse(BaseRunnerResponse):

72
src/exo/worker/engines/mlx/dsml_encoding.py
Normal file
72
src/exo/worker/engines/mlx/dsml_encoding.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from mlx_lm.chat_templates import deepseek_v32
|
||||
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
|
||||
BOS_TOKEN: str = deepseek_v32.bos_token
|
||||
EOS_TOKEN: str = deepseek_v32.eos_token
|
||||
DSML_TOKEN: str = deepseek_v32.dsml_token
|
||||
THINKING_START: str = deepseek_v32.thinking_start_token
|
||||
THINKING_END: str = deepseek_v32.thinking_end_token
|
||||
USER_TOKEN = "<\uff5cUser\uff5c>"
|
||||
ASSISTANT_TOKEN = "<\uff5cAssistant\uff5c>"
|
||||
TOOL_CALLS_START = f"<{DSML_TOKEN}function_calls>"
|
||||
TOOL_CALLS_END = f"</{DSML_TOKEN}function_calls>"
|
||||
encode_messages = deepseek_v32.encode_messages
|
||||
|
||||
_INVOKE_PATTERN = re.compile(
|
||||
rf"<{re.escape(DSML_TOKEN)}invoke\s+name=\"([^\"]+)\">"
|
||||
rf"(.*?)"
|
||||
rf"</{re.escape(DSML_TOKEN)}invoke>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
_PARAM_PATTERN = re.compile(
|
||||
rf"<{re.escape(DSML_TOKEN)}parameter\s+name=\"([^\"]+)\"\s+string=\"(true|false)\">"
|
||||
rf"(.*?)"
|
||||
rf"</{re.escape(DSML_TOKEN)}parameter>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def parse_dsml_output(text: str) -> list[ToolCallItem] | None:
|
||||
"""Parse DSML function_calls block from model output text.
|
||||
|
||||
Args:
|
||||
text: The text containing the DSML function_calls block
|
||||
(including the start/end markers).
|
||||
|
||||
Returns:
|
||||
List of ToolCallItem, or None if parsing fails.
|
||||
"""
|
||||
tool_calls: list[ToolCallItem] = []
|
||||
|
||||
for invoke_match in _INVOKE_PATTERN.finditer(text):
|
||||
func_name = invoke_match.group(1)
|
||||
invoke_body = invoke_match.group(2)
|
||||
|
||||
args: dict[str, Any] = {}
|
||||
for param_match in _PARAM_PATTERN.finditer(invoke_body):
|
||||
param_name = param_match.group(1)
|
||||
is_string = param_match.group(2) == "true"
|
||||
param_value = param_match.group(3)
|
||||
|
||||
if is_string:
|
||||
args[param_name] = param_value
|
||||
else:
|
||||
try:
|
||||
args[param_name] = json.loads(param_value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
args[param_name] = param_value
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallItem(
|
||||
name=func_name,
|
||||
arguments=json.dumps(args),
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls if tool_calls else None
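For reference, a small round-trip sketch of the parser above (assumes only the constants and function defined in this file):

from exo.worker.engines.mlx.dsml_encoding import (
    DSML_TOKEN,
    TOOL_CALLS_END,
    TOOL_CALLS_START,
    parse_dsml_output,
)

block = (
    f"{TOOL_CALLS_START}\n"
    f'<{DSML_TOKEN}invoke name="get_weather">\n'
    f'<{DSML_TOKEN}parameter name="city" string="true">Tokyo</{DSML_TOKEN}parameter>\n'
    f'<{DSML_TOKEN}parameter name="units" string="false">"celsius"</{DSML_TOKEN}parameter>\n'
    f"</{DSML_TOKEN}invoke>\n"
    f"{TOOL_CALLS_END}"
)

calls = parse_dsml_output(block)
assert calls is not None and calls[0].name == "get_weather"
# string="false" values are JSON-decoded, so '"celsius"' becomes the str "celsius".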
|
||||
@@ -458,6 +458,19 @@ def _patch_lossy_chat_template(template: str) -> str | None:
|
||||
return patched if n > 0 else None
|
||||
|
||||
|
||||
def _needs_dsml_encoding(task_params: TextGenerationTaskParams) -> bool:
|
||||
if "deepseek-v3.2" not in task_params.model.lower():
|
||||
return False
|
||||
# Use DSML encoding when tools are provided or tool results are in the conversation
|
||||
if task_params.tools:
|
||||
return True
|
||||
if task_params.chat_template_messages:
|
||||
return any(
|
||||
msg.get("role") == "tool" for msg in task_params.chat_template_messages
|
||||
)
|
||||
return False
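The same decision rule as a standalone sketch over plain dicts (illustrative only; the real code reads TextGenerationTaskParams, and the model names below are stand-ins):

def needs_dsml(model: str, tools: list | None, messages: list[dict] | None) -> bool:
    if "deepseek-v3.2" not in model.lower():
        return False
    if tools:
        return True
    return bool(messages) and any(m.get("role") == "tool" for m in messages)

assert needs_dsml("some-org/deepseek-v3.2-4bit", tools=[{"type": "function"}], messages=None)
assert not needs_dsml("some-org/other-model", tools=[{"type": "function"}], messages=None)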
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: TokenizerWrapper,
|
||||
task_params: TextGenerationTaskParams,
|
||||
@@ -469,7 +482,6 @@ def apply_chat_template(
|
||||
|
||||
When chat_template_messages is available (from Chat Completions API),
|
||||
uses those directly to preserve tool_calls, thinking, and other fields.
|
||||
Otherwise builds messages from the task params input/instructions.
|
||||
"""
|
||||
formatted_messages: list[dict[str, Any]] = []
|
||||
if task_params.chat_template_messages is not None:
|
||||
@@ -497,6 +509,19 @@ def apply_chat_template(
|
||||
partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
|
||||
formatted_messages = formatted_messages[:-1]
|
||||
|
||||
if _needs_dsml_encoding(task_params):
|
||||
from exo.worker.engines.mlx.dsml_encoding import encode_messages
|
||||
|
||||
prompt = encode_messages(
|
||||
messages=formatted_messages,
|
||||
thinking_mode="thinking" if task_params.enable_thinking else "chat",
|
||||
tools=task_params.tools,
|
||||
)
|
||||
if partial_assistant_content:
|
||||
prompt += partial_assistant_content
|
||||
logger.info(prompt)
|
||||
return prompt
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
if task_params.enable_thinking is not None:
|
||||
# Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".
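The naming split that comment describes, as a tiny standalone sketch (illustrative only; the actual branch is outside this hunk):

def thinking_kwarg_name(model: str) -> str:
    # DeepSeek chat templates expect "thinking"; Qwen3/GLM templates expect "enable_thinking".
    return "thinking" if "deepseek" in model.lower() else "enable_thinking"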
|
||||
|
||||
@@ -7,6 +7,7 @@ from functools import cache
|
||||
from typing import Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
@@ -21,12 +22,17 @@ from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
|
||||
from exo.shared.types.api import ImageGenerationStats
|
||||
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
ImageChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
PrefillProgress,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
@@ -315,11 +321,13 @@ def main(
|
||||
) -> None:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
PrefillProgress(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
processed_tokens=processed,
|
||||
total_tokens=total,
|
||||
chunk=PrefillProgressChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
processed_tokens=processed,
|
||||
total_tokens=total,
|
||||
),
|
||||
)
|
||||
)
|
||||
cancelled_tasks.update(cancel_receiver.collect())
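On the consuming side, a PrefillProgressChunk reduces to a simple percentage (sketch of a hypothetical handler, not code from this diff):

def prefill_percent(processed_tokens: int, total_tokens: int) -> int:
    if total_tokens <= 0:
        return 0
    return min(100, round(100 * processed_tokens / total_tokens))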
|
||||
@@ -346,16 +354,22 @@ def main(
|
||||
group=group,
|
||||
)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
if detect_thinking_prompt_suffix(prompt, tokenizer):
|
||||
if tokenizer.has_thinking:
|
||||
mlx_generator = parse_thinking_models(
|
||||
mlx_generator, tokenizer
|
||||
mlx_generator,
|
||||
tokenizer,
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
starts_in_thinking=detect_thinking_prompt_suffix(
|
||||
prompt, tokenizer
|
||||
),
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
# Model-specific output parsing for tool calls.
|
||||
if isinstance(inference_model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
elif isinstance(inference_model, DeepseekV32Model):
|
||||
mlx_generator = parse_deepseek_v32(mlx_generator)
|
||||
elif tool_parser:
|
||||
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
|
||||
|
||||
@@ -407,6 +421,7 @@ def main(
|
||||
stats=response.stats,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
is_thinking=response.is_thinking,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -668,44 +683,208 @@ def parse_gpt_oss(
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield response.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield response.model_copy(update={"text": delta})
|
||||
yield response.model_copy(update={"text": delta, "is_thinking": thinking})
|
||||
|
||||
if response.finish_reason is not None:
|
||||
if thinking:
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
yield response
|
||||
|
||||
|
||||
def parse_deepseek_v32(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
|
||||
|
||||
Uses accumulated-text matching (not per-token marker checks) because
|
||||
DSML markers like <|DSML|function_calls> may span multiple tokens.
|
||||
Also handles <think>...</think> blocks for thinking mode.
|
||||
"""
|
||||
from exo.worker.engines.mlx.dsml_encoding import (
|
||||
THINKING_END,
|
||||
THINKING_START,
|
||||
TOOL_CALLS_END,
|
||||
TOOL_CALLS_START,
|
||||
parse_dsml_output,
|
||||
)
|
||||
|
||||
accumulated = ""
|
||||
in_tool_call = False
|
||||
thinking = False
|
||||
# Tokens buffered while we detect the start of a DSML block
|
||||
pending_buffer: list[GenerationResponse] = []
|
||||
# Text accumulated during a tool call block
|
||||
tool_call_text = ""
|
||||
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
|
||||
# ── Handle thinking tags ──
|
||||
if not thinking and THINKING_START in response.text:
|
||||
thinking = True
|
||||
# Yield any text before the <think> tag
|
||||
before = response.text[: response.text.index(THINKING_START)]
|
||||
if before:
|
||||
yield response.model_copy(update={"text": before})
|
||||
continue
|
||||
|
||||
if thinking and THINKING_END in response.text:
|
||||
thinking = False
|
||||
# Yield any text after the </think> tag
|
||||
after = response.text[
|
||||
response.text.index(THINKING_END) + len(THINKING_END) :
|
||||
]
|
||||
if after:
|
||||
yield response.model_copy(update={"text": after, "is_thinking": False})
|
||||
continue
|
||||
|
||||
if thinking:
|
||||
yield response.model_copy(update={"is_thinking": True})
|
||||
continue
|
||||
|
||||
# ── Handle tool call accumulation ──
|
||||
if in_tool_call:
|
||||
tool_call_text += response.text
|
||||
if TOOL_CALLS_END in tool_call_text:
|
||||
# Parse the accumulated DSML block
|
||||
parsed = parse_dsml_output(tool_call_text)
|
||||
if parsed is not None:
|
||||
logger.info(f"parsed DSML tool calls: {parsed}")
|
||||
yield ToolCallResponse(
|
||||
tool_calls=parsed,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"DSML tool call parsing failed for: {tool_call_text}"
|
||||
)
|
||||
yield response.model_copy(update={"text": tool_call_text})
|
||||
in_tool_call = False
|
||||
tool_call_text = ""
|
||||
continue
|
||||
|
||||
# EOS reached before end marker — yield buffered text as-is
|
||||
if response.finish_reason is not None:
|
||||
logger.info("DSML tool call parsing interrupted by EOS")
|
||||
yield response.model_copy(update={"text": tool_call_text})
|
||||
in_tool_call = False
|
||||
tool_call_text = ""
|
||||
continue
|
||||
|
||||
# ── Detect start of tool call block ──
|
||||
accumulated += response.text
|
||||
|
||||
if TOOL_CALLS_START in accumulated:
|
||||
# The start marker might be split across pending_buffer + current token
|
||||
start_idx = accumulated.index(TOOL_CALLS_START)
|
||||
# Yield any pending tokens that are purely before the marker
|
||||
pre_text = accumulated[:start_idx]
|
||||
if pre_text:
|
||||
# Flush pending buffer tokens that contributed text before the marker
|
||||
for buf_resp in pending_buffer:
|
||||
if pre_text:
|
||||
chunk = buf_resp.text
|
||||
if len(chunk) <= len(pre_text):
|
||||
yield buf_resp
|
||||
pre_text = pre_text[len(chunk) :]
|
||||
else:
|
||||
yield buf_resp.model_copy(update={"text": pre_text})
|
||||
pre_text = ""
|
||||
pending_buffer = []
|
||||
tool_call_text = accumulated[start_idx:]
|
||||
accumulated = ""
|
||||
|
||||
# Check if the end marker is already present (entire tool call in one token)
|
||||
if TOOL_CALLS_END in tool_call_text:
|
||||
parsed = parse_dsml_output(tool_call_text)
|
||||
if parsed is not None:
|
||||
logger.info(f"parsed DSML tool calls: {parsed}")
|
||||
yield ToolCallResponse(
|
||||
tool_calls=parsed,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"DSML tool call parsing failed for: {tool_call_text}"
|
||||
)
|
||||
yield response.model_copy(update={"text": tool_call_text})
|
||||
tool_call_text = ""
|
||||
else:
|
||||
in_tool_call = True
|
||||
continue
|
||||
|
||||
# Check if accumulated text might be the start of a DSML marker
|
||||
# Buffer tokens if we see a partial match at the end
|
||||
if _could_be_dsml_prefix(accumulated):
|
||||
pending_buffer.append(response)
|
||||
continue
|
||||
|
||||
# No partial match — flush all pending tokens and the current one
|
||||
for buf_resp in pending_buffer:
|
||||
yield buf_resp
|
||||
pending_buffer = []
|
||||
accumulated = ""
|
||||
yield response
|
||||
|
||||
# Flush any remaining pending buffer at generator end
|
||||
for buf_resp in pending_buffer:
|
||||
yield buf_resp
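Minimal driver sketch for the parser above (the new test file further down exercises it in depth):

def _fake_stream(texts: list[str]):
    for i, t in enumerate(texts):
        yield GenerationResponse(
            text=t,
            token=i,
            finish_reason="stop" if i == len(texts) - 1 else None,
            usage=None,
        )

out = list(parse_deepseek_v32(_fake_stream(["Hello", " world"])))
assert all(isinstance(r, GenerationResponse) for r in out)
assert "".join(r.text for r in out) == "Hello world"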
|
||||
|
||||
|
||||
def _could_be_dsml_prefix(text: str) -> bool:
|
||||
"""Check if the end of text could be the start of a DSML function_calls marker.
|
||||
|
||||
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
|
||||
This allows us to buffer tokens until we can determine if a tool call is starting.
|
||||
"""
|
||||
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
|
||||
|
||||
# Only check the last portion of text that could overlap with the marker
|
||||
max_check = len(TOOL_CALLS_START)
|
||||
tail = text[-max_check:] if len(text) > max_check else text
|
||||
|
||||
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
|
||||
for i in range(len(tail)):
|
||||
suffix = tail[i:]
|
||||
if TOOL_CALLS_START.startswith(suffix):
|
||||
return True
|
||||
return False
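Quick sanity check of the prefix heuristic (sketch):

from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START

assert _could_be_dsml_prefix("Sure, let me call <")           # "<" could start the marker -> buffer
assert _could_be_dsml_prefix("text " + TOOL_CALLS_START[:5])  # partial marker -> buffer
assert not _could_be_dsml_prefix("x > 0 and y < 100.")        # plain text -> flush immediately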
|
||||
|
||||
|
||||
def parse_thinking_models(
|
||||
responses: Generator[GenerationResponse],
|
||||
tokenizer: TokenizerWrapper,
|
||||
starts_in_thinking: bool = True,
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Route thinking tokens via is_thinking flag.
|
||||
|
||||
Swallows think tag tokens, sets is_thinking on all others.
|
||||
Always yields tokens with finish_reason to avoid hanging the chunk stream.
|
||||
"""
|
||||
For models that inject thinking tags in the prompt (like GLM-4.7),
|
||||
prepend the thinking tag to the output stream so the frontend
|
||||
can properly parse thinking content.
|
||||
"""
|
||||
first = True
|
||||
in_thinking = starts_in_thinking
|
||||
for response in responses:
|
||||
if isinstance(response, ToolCallResponse):
|
||||
yield response
|
||||
continue
|
||||
if first:
|
||||
first = False
|
||||
yield response.model_copy(
|
||||
update={
|
||||
"text": tokenizer.think_start,
|
||||
"token": tokenizer.think_start_id,
|
||||
}
|
||||
)
|
||||
yield response
|
||||
|
||||
is_think_tag = (
|
||||
tokenizer.think_end is not None and response.text == tokenizer.think_end
|
||||
) or (
|
||||
tokenizer.think_start is not None and response.text == tokenizer.think_start
|
||||
)
|
||||
|
||||
if is_think_tag:
|
||||
in_thinking = response.text != tokenizer.think_end
|
||||
# Never swallow finish_reason — the chunk stream needs it to terminate.
|
||||
if response.finish_reason is not None:
|
||||
yield response.model_copy(update={"text": "", "is_thinking": False})
|
||||
continue
|
||||
yield response.model_copy(update={"is_thinking": in_thinking})
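Sketch of the is_thinking routing above with a stub tokenizer (attribute names match how the function uses the wrapper; everything else is illustrative):

class _StubTokenizer:
    think_start = "<think>"
    think_end = "</think>"

def _gen(texts: list[str]):
    for i, t in enumerate(texts):
        yield GenerationResponse(
            text=t,
            token=i,
            finish_reason="stop" if i == len(texts) - 1 else None,
            usage=None,
        )

out = list(
    parse_thinking_models(
        _gen(["plan the answer", "</think>", "Final answer."]),
        _StubTokenizer(),
        starts_in_thinking=True,
    )
)
# "plan the answer" is yielded with is_thinking=True, the bare "</think>" token
# is swallowed, and "Final answer." is yielded with is_thinking=False.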
|
||||
|
||||
|
||||
def _send_image_chunk(
|
||||
|
||||
967
src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py
Normal file
@@ -0,0 +1,967 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ToolCallResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx.dsml_encoding import (
|
||||
ASSISTANT_TOKEN,
|
||||
BOS_TOKEN,
|
||||
DSML_TOKEN,
|
||||
EOS_TOKEN,
|
||||
THINKING_END,
|
||||
THINKING_START,
|
||||
TOOL_CALLS_END,
|
||||
TOOL_CALLS_START,
|
||||
USER_TOKEN,
|
||||
encode_messages,
|
||||
parse_dsml_output,
|
||||
)
|
||||
from exo.worker.runner.runner import parse_deepseek_v32
|
||||
|
||||
# ── Shared fixtures ──────────────────────────────────────────────
|
||||
|
||||
_WEATHER_TOOLS: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Temperature units",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"description": "Get the current time in a timezone",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {"type": "string"},
|
||||
},
|
||||
"required": ["timezone"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _simulate_tokens(
|
||||
texts: list[str],
|
||||
finish_on_last: bool = True,
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Simulate a model producing tokens from a list of text strings."""
|
||||
for i, text in enumerate(texts):
|
||||
is_last = i == len(texts) - 1
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=i,
|
||||
finish_reason="stop" if (is_last and finish_on_last) else None,
|
||||
usage=None,
|
||||
)
|
||||
|
||||
|
||||
# ── Test: Standard text response (no tool calls) ────────────────
|
||||
|
||||
|
||||
class TestE2EStandardResponse:
|
||||
"""Model generates a plain text response — no tool calling involved."""
|
||||
|
||||
def test_plain_text_passthrough(self):
|
||||
"""Simulate model producing: 'The weather in NYC is 72°F and sunny.'"""
|
||||
# Step 1: Encode the prompt (with tools available)
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather in NYC?"},
|
||||
]
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Verify prompt structure
|
||||
assert BOS_TOKEN in prompt
|
||||
assert "## Tools" in prompt
|
||||
assert "get_weather" in prompt
|
||||
assert f"{USER_TOKEN}What's the weather in NYC?{ASSISTANT_TOKEN}" in prompt
|
||||
|
||||
# Step 2: Simulate model response — plain text tokens (no DSML)
|
||||
model_tokens = [
|
||||
"The weather",
|
||||
" in NYC",
|
||||
" is 72",
|
||||
"°F",
|
||||
" and sunny",
|
||||
".",
|
||||
]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
# Step 3: Verify all tokens pass through as GenerationResponse
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 0
|
||||
assert len(gen_results) == 6
|
||||
full_text = "".join(r.text for r in gen_results)
|
||||
assert full_text == "The weather in NYC is 72°F and sunny."
|
||||
assert gen_results[-1].finish_reason == "stop"
|
||||
|
||||
|
||||
# ── Test: Tool call response ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestE2EToolCallResponse:
|
||||
"""Model generates a DSML tool call — realistic token boundaries."""
|
||||
|
||||
def test_realistic_tool_call_tokens(self):
|
||||
"""Simulate model generating a get_weather tool call with realistic token splits.
|
||||
|
||||
Real models split DSML markers across tokens unpredictably.
|
||||
This simulates how DeepSeek V3.2 actually tokenizes DSML output.
|
||||
"""
|
||||
# Step 1: Encode prompt
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather in San Francisco?"},
|
||||
]
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
assert "get_weather" in prompt
|
||||
|
||||
# Step 2: Simulate realistic token-by-token model output
|
||||
# The model first produces some text, then a DSML tool call block
|
||||
model_tokens = [
|
||||
"I'll check the weather for you.",
|
||||
"\n\n",
|
||||
f"<{DSML_TOKEN}", # marker split across tokens
|
||||
"function_calls>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">',
|
||||
"San Francisco",
|
||||
f"</{DSML_TOKEN}parameter>\n",
|
||||
f'<{DSML_TOKEN}parameter name="units" string="false">',
|
||||
'"celsius"',
|
||||
f"</{DSML_TOKEN}parameter>\n",
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
f"</{DSML_TOKEN}function_calls>",
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
# Step 3: Verify
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
# Should have text tokens before tool call + one ToolCallResponse
|
||||
assert len(tool_results) == 1
|
||||
assert len(tool_results[0].tool_calls) == 1
|
||||
|
||||
tc = tool_results[0].tool_calls[0]
|
||||
assert tc.name == "get_weather"
|
||||
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
|
||||
assert args["city"] == "San Francisco"
|
||||
assert args["units"] == "celsius"
|
||||
|
||||
# The text before the tool call should still be yielded
|
||||
text_before = "".join(r.text for r in gen_results if not r.is_thinking)
|
||||
assert "check the weather" in text_before
|
||||
|
||||
def test_multiple_tool_calls_in_one_block(self):
|
||||
"""Model generates two tool calls in a single function_calls block."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Weather in NYC and time in EST?"},
|
||||
]
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
assert "get_weather" in prompt
|
||||
assert "get_time" in prompt
|
||||
|
||||
# Simulate model output with two invocations
|
||||
model_tokens = [
|
||||
"Let me check both.\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_time">\n',
|
||||
f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert len(tool_results[0].tool_calls) == 2
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
assert tool_results[0].tool_calls[1].name == "get_time"
|
||||
|
||||
args0 = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
args1 = json.loads(tool_results[0].tool_calls[1].arguments) # pyright: ignore[reportAny]
|
||||
assert args0 == {"city": "NYC"}
|
||||
assert args1 == {"timezone": "EST"}
|
||||
|
||||
|
||||
# ── Test: Multi-turn tool use flow ───────────────────────────────
|
||||
|
||||
|
||||
class TestE2EMultiTurnToolUse:
|
||||
"""Full multi-turn: user asks → model calls tool → tool result → model answers."""
|
||||
|
||||
def test_encode_multi_turn_with_tool_results(self):
|
||||
"""Verify the prompt for turn 2 (after tool results) is correctly encoded."""
|
||||
# Turn 1: user asks, model calls tool
|
||||
# Turn 2: tool result provided, model answers
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a weather assistant."},
|
||||
{"role": "user", "content": "What's the weather in NYC?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "NYC"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": '{"temperature": 72, "condition": "sunny"}'},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Verify multi-turn structure
|
||||
assert BOS_TOKEN in prompt
|
||||
assert "You are a weather assistant." in prompt
|
||||
assert "## Tools" in prompt
|
||||
|
||||
# The assistant's tool call should be encoded as DSML
|
||||
assert TOOL_CALLS_START in prompt
|
||||
assert f'<{DSML_TOKEN}invoke name="get_weather">' in prompt
|
||||
assert EOS_TOKEN in prompt
|
||||
|
||||
# The tool result should be wrapped in function_results
|
||||
assert "<function_results>" in prompt
|
||||
assert "<result>" in prompt
|
||||
assert "72" in prompt
|
||||
assert "</function_results>" in prompt
|
||||
|
||||
# Now simulate model answering after seeing the tool result
|
||||
model_tokens = [
|
||||
"The current",
|
||||
" weather in NYC",
|
||||
" is 72°F",
|
||||
" and sunny.",
|
||||
]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 0
|
||||
full_text = "".join(r.text for r in gen_results)
|
||||
assert full_text == "The current weather in NYC is 72°F and sunny."
|
||||
|
||||
def test_multi_tool_results_encoding(self):
|
||||
"""Verify encoding when model called two tools and both return results."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Weather and time?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "LA"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "PST"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "85F, clear skies"},
|
||||
{"role": "tool", "content": "3:42 PM PST"},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Should have one function_results block with two results
|
||||
assert prompt.count("<function_results>") == 1
|
||||
assert prompt.count("</function_results>") == 1
|
||||
assert "<result>85F, clear skies</result>" in prompt
|
||||
assert "<result>3:42 PM PST</result>" in prompt
|
||||
|
||||
|
||||
# ── Test: Thinking + tool call ───────────────────────────────────
|
||||
|
||||
|
||||
class TestE2EThinkingAndToolCall:
|
||||
"""Model uses thinking mode, reasons, then makes a tool call."""
|
||||
|
||||
def test_thinking_then_tool_call(self):
|
||||
"""Model thinks first, then produces a DSML tool call block."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
]
|
||||
prompt = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
# Thinking mode: prompt should end with <think>
|
||||
assert prompt.endswith(THINKING_START)
|
||||
|
||||
# Simulate: model outputs <think>, thinks, closes thinking, then tool call.
|
||||
# In the full pipeline, parse_thinking_models handles the case where
|
||||
# <think> is in the prompt. Here we test parse_deepseek_v32 directly,
|
||||
# which detects <think>/</think> markers in the stream.
|
||||
model_tokens = [
|
||||
THINKING_START,
|
||||
"The user wants weather",
|
||||
" information. I should use",
|
||||
" the get_weather tool.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">',
|
||||
"San Francisco",
|
||||
f"</{DSML_TOKEN}parameter>\n",
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
# Should have thinking tokens + tool call
|
||||
thinking_results = [r for r in gen_results if r.is_thinking]
|
||||
|
||||
assert len(thinking_results) >= 1
|
||||
thinking_text = "".join(r.text for r in thinking_results)
|
||||
assert "get_weather tool" in thinking_text
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
assert args["city"] == "San Francisco"
|
||||
|
||||
def test_thinking_prompt_encoding(self):
|
||||
"""Verify thinking mode affects prompt encoding correctly."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "Be thorough."},
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
]
|
||||
|
||||
# With thinking enabled
|
||||
prompt_think = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_think.endswith(THINKING_START)
|
||||
|
||||
# With thinking disabled
|
||||
prompt_no_think = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="chat"
|
||||
)
|
||||
assert prompt_no_think.endswith(THINKING_END)
|
||||
|
||||
# Both should have the same tool definitions
|
||||
assert "get_weather" in prompt_think
|
||||
assert "get_weather" in prompt_no_think
|
||||
|
||||
|
||||
# ── Test: Round-trip encode → parse ──────────────────────────────
|
||||
|
||||
|
||||
class TestE2ERoundTrip:
|
||||
"""Verify that DSML we encode can be parsed back correctly."""
|
||||
|
||||
def test_encoded_tool_call_is_parseable(self):
|
||||
"""Encode an assistant tool call message, then parse the DSML output."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Tokyo", "units": "celsius"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Extract the DSML function_calls block from the prompt
|
||||
start = prompt.index(TOOL_CALLS_START)
|
||||
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
|
||||
dsml_block = prompt[start:end]
|
||||
|
||||
# Parse it back
|
||||
parsed = parse_dsml_output(dsml_block)
|
||||
assert parsed is not None
|
||||
assert len(parsed) == 1
|
||||
assert parsed[0].name == "get_weather"
|
||||
args = json.loads(parsed[0].arguments) # pyright: ignore[reportAny]
|
||||
assert args["city"] == "Tokyo"
|
||||
assert args["units"] == "celsius"
|
||||
|
||||
def test_encoded_multi_tool_call_round_trips(self):
|
||||
"""Encode multiple tool calls, verify they parse back correctly."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Both please"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Paris"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "CET"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
start = prompt.index(TOOL_CALLS_START)
|
||||
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
|
||||
dsml_block = prompt[start:end]
|
||||
|
||||
parsed = parse_dsml_output(dsml_block)
|
||||
assert parsed is not None
|
||||
assert len(parsed) == 2
|
||||
assert parsed[0].name == "get_weather"
|
||||
assert parsed[1].name == "get_time"
|
||||
assert json.loads(parsed[0].arguments) == {"city": "Paris"}
|
||||
assert json.loads(parsed[1].arguments) == {"timezone": "CET"}
|
||||
|
||||
|
||||
# ── Test: Edge cases with realistic token boundaries ─────────────
|
||||
|
||||
|
||||
class TestE2EEdgeCases:
|
||||
"""Edge cases that occur in real model inference."""
|
||||
|
||||
def test_dsml_marker_split_at_fullwidth_pipe(self):
|
||||
"""The fullwidth pipe character | might be its own token."""
|
||||
# This is a realistic tokenization: the DSML marker is split at the | chars
|
||||
model_tokens = [
|
||||
"Let me help.\n\n",
|
||||
"<\uff5c", # start of |DSML|
|
||||
"DSML\uff5c", # rest of DSML token
|
||||
"function_calls>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
|
||||
def test_tool_call_with_nested_json_object(self):
|
||||
"""Model passes a complex JSON object as a non-string parameter."""
|
||||
dsml_block = (
|
||||
f"{TOOL_CALLS_START}\n"
|
||||
f'<{DSML_TOKEN}invoke name="create_event">\n'
|
||||
f'<{DSML_TOKEN}parameter name="title" string="true">Team Standup</{DSML_TOKEN}parameter>\n'
|
||||
f'<{DSML_TOKEN}parameter name="config" string="false">'
|
||||
f'{{"recurring": true, "days": ["mon", "wed", "fri"], "time": "09:00"}}'
|
||||
f"</{DSML_TOKEN}parameter>\n"
|
||||
f"</{DSML_TOKEN}invoke>\n"
|
||||
f"{TOOL_CALLS_END}"
|
||||
)
|
||||
|
||||
# Feed as single token (model might produce it all at once after prefill)
|
||||
results = list(parse_deepseek_v32(_simulate_tokens([dsml_block])))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
tc = tool_results[0].tool_calls[0]
|
||||
assert tc.name == "create_event"
|
||||
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
|
||||
assert args["title"] == "Team Standup"
|
||||
assert args["config"]["recurring"] is True
|
||||
assert args["config"]["days"] == ["mon", "wed", "fri"]
|
||||
|
||||
def test_text_with_angle_brackets_not_mistaken_for_dsml(self):
|
||||
"""Angle brackets in normal text should not trigger DSML buffering."""
|
||||
model_tokens = [
|
||||
"The formula is ",
|
||||
"<x, y>",
|
||||
" where x > 0",
|
||||
" and y < 100.",
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 0
|
||||
full_text = "".join(r.text for r in gen_results)
|
||||
assert "formula" in full_text
|
||||
assert "<x, y>" in full_text
|
||||
|
||||
def test_empty_model_response(self):
|
||||
"""Model produces only EOS (empty response)."""
|
||||
model_tokens = [""]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
assert len(gen_results) == 1
|
||||
assert gen_results[0].text == ""
|
||||
assert gen_results[0].finish_reason == "stop"
|
||||
|
||||
|
||||
# ── Test: Full EPDP spec round-trip ──────────────────────────────
|
||||
|
||||
|
||||
class TestE2EFullRoundTrip:
|
||||
"""Full round-trip matching the vLLM EPDP spec.
|
||||
|
||||
Simulates the complete multi-turn flow:
|
||||
Turn 1: user asks → think → tool call → tool result → think → answer
|
||||
Turn 2: user asks again → old reasoning stripped → think → answer
|
||||
"""
|
||||
|
||||
def test_single_tool_full_flow_with_thinking(self):
|
||||
"""Complete flow: user → think → tool call → tool result → think → answer.
|
||||
|
||||
This is the core EPDP flow from the vLLM spec.
|
||||
"""
|
||||
# ── Turn 1.1: User asks, encode prompt ──
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a weather assistant."},
|
||||
{"role": "user", "content": "How's the weather in Hangzhou?"},
|
||||
]
|
||||
prompt_1 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_1.endswith(THINKING_START)
|
||||
assert "## Tools" in prompt_1
|
||||
assert "get_weather" in prompt_1
|
||||
|
||||
# ── Turn 1.1: Model thinks, then calls tool ──
|
||||
model_tokens_1 = [
|
||||
THINKING_START,
|
||||
"The user wants to know the weather in Hangzhou.",
|
||||
" I need to use the get_weather tool.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
|
||||
|
||||
# Verify: thinking tokens + tool call
|
||||
gen_1 = [r for r in results_1 if isinstance(r, GenerationResponse)]
|
||||
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
|
||||
thinking_1 = [r for r in gen_1 if r.is_thinking]
|
||||
|
||||
assert len(thinking_1) >= 1
|
||||
assert "get_weather tool" in "".join(r.text for r in thinking_1)
|
||||
assert len(tool_1) == 1
|
||||
assert tool_1[0].tool_calls[0].name == "get_weather"
|
||||
tc_args = json.loads(tool_1[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
assert tc_args == {"city": "Hangzhou"}
|
||||
|
||||
# ── Turn 1.2: Add assistant response + tool result to messages ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "The user wants to know the weather in Hangzhou. I need to use the get_weather tool.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Hangzhou"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": '{"temperature": "7~13°C", "condition": "Cloudy"}',
|
||||
}
|
||||
)
|
||||
|
||||
# Encode prompt for turn 1.2
|
||||
prompt_2 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
|
||||
# Verify: prompt has the full conversation structure
|
||||
assert TOOL_CALLS_START in prompt_2 # assistant's encoded tool call
|
||||
assert EOS_TOKEN in prompt_2 # assistant turn ends with EOS
|
||||
assert "<function_results>" in prompt_2
|
||||
assert "<result>" in prompt_2
|
||||
assert "Cloudy" in prompt_2
|
||||
assert "</function_results>" in prompt_2
|
||||
# After tool results with thinking enabled → <think> appended
|
||||
assert prompt_2.endswith(THINKING_START)
|
||||
# The assistant's reasoning_content should appear (it's after last_user_idx)
|
||||
assert "get_weather tool" in prompt_2
|
||||
|
||||
# ── Turn 1.2: Model thinks about results, then answers ──
|
||||
model_tokens_2 = [
|
||||
THINKING_START,
|
||||
"The weather in Hangzhou is Cloudy, 7~13°C.",
|
||||
" I'll tell the user.",
|
||||
THINKING_END,
|
||||
"The weather in Hangzhou is currently cloudy with temperatures between 7°C and 13°C.",
|
||||
]
|
||||
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
|
||||
|
||||
gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
|
||||
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
|
||||
thinking_2 = [r for r in gen_2 if r.is_thinking]
|
||||
non_thinking_2 = [r for r in gen_2 if not r.is_thinking]
|
||||
|
||||
assert len(tool_2) == 0 # No more tool calls
|
||||
assert len(thinking_2) >= 1
|
||||
assert "Cloudy" in "".join(r.text for r in thinking_2)
|
||||
assert len(non_thinking_2) >= 1
|
||||
final_text = "".join(r.text for r in non_thinking_2)
|
||||
assert "7°C" in final_text
|
||||
assert "13°C" in final_text
|
||||
|
||||
def test_multi_tool_full_flow(self):
|
||||
"""Flow with two tools: user → think → 2 tool calls → 2 results → think → answer."""
|
||||
# ── Initial prompt ──
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You help with weather and time."},
|
||||
{"role": "user", "content": "Weather in NYC and time in EST?"},
|
||||
]
|
||||
prompt_1 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_1.endswith(THINKING_START)
|
||||
|
||||
# ── Model thinks, calls both tools ──
|
||||
model_tokens_1 = [
|
||||
THINKING_START,
|
||||
"Two requests: weather and time. I'll call both.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_time">\n',
|
||||
f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
|
||||
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_1) == 1
|
||||
assert len(tool_1[0].tool_calls) == 2
|
||||
assert tool_1[0].tool_calls[0].name == "get_weather"
|
||||
assert tool_1[0].tool_calls[1].name == "get_time"
|
||||
|
||||
# ── Add assistant + both tool results ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Two requests: weather and time. I'll call both.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "NYC"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "EST"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append({"role": "tool", "content": "72°F, sunny"})
|
||||
messages.append({"role": "tool", "content": "2:30 PM EST"})
|
||||
|
||||
prompt_2 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
|
||||
# Verify multi-tool result encoding
|
||||
# Count is 2: 1 in _TOOLS_SYSTEM_TEMPLATE example + 1 in conversation
|
||||
assert prompt_2.count("<function_results>") == 2
|
||||
assert prompt_2.count("</function_results>") == 2
|
||||
assert "<result>72°F, sunny</result>" in prompt_2
|
||||
assert "<result>2:30 PM EST</result>" in prompt_2
|
||||
assert prompt_2.endswith(THINKING_START)
|
||||
|
||||
# ── Model thinks about results, answers ──
|
||||
model_tokens_2 = [
|
||||
THINKING_START,
|
||||
"Got both results. Weather is 72°F sunny, time is 2:30 PM.",
|
||||
THINKING_END,
|
||||
"In NYC it's currently 72°F and sunny. The time in EST is 2:30 PM.",
|
||||
]
|
||||
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
|
||||
|
||||
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
|
||||
gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
|
||||
non_thinking_2 = [r for r in gen_2 if not r.is_thinking]
|
||||
|
||||
assert len(tool_2) == 0
|
||||
final_text = "".join(r.text for r in non_thinking_2)
|
||||
assert "72°F" in final_text
|
||||
assert "2:30 PM" in final_text
|
||||
|
||||
def test_two_user_turns_reasoning_stripped(self):
|
||||
"""Turn 2: old reasoning_content is stripped from history.
|
||||
|
||||
Per the vLLM spec, clear_reasoning_content is called between user turns
|
||||
to save bandwidth. Our _drop_old_thinking handles this.
|
||||
"""
|
||||
# Full turn 1 conversation (already completed)
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Weather in Hangzhou?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "I need to call get_weather for Hangzhou.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Hangzhou"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "Cloudy 7~13°C"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The weather in Hangzhou is cloudy, 7-13°C.",
|
||||
"reasoning_content": "The tool returned cloudy weather. I'll summarize.",
|
||||
},
|
||||
# Turn 2: user asks again
|
||||
{"role": "user", "content": "What about Beijing?"},
|
||||
]
|
||||
|
||||
prompt = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
|
||||
# Old reasoning_content from turn 1 assistants should be STRIPPED
|
||||
# (they're before the last user message at index 5)
|
||||
assert "I need to call get_weather" not in prompt
|
||||
assert "tool returned cloudy" not in prompt
|
||||
|
||||
# But the assistant's content and tool calls should still be there
|
||||
assert "cloudy, 7-13°C" in prompt
|
||||
assert TOOL_CALLS_START in prompt
|
||||
|
||||
# Prompt ends with <think> for the new turn
|
||||
assert prompt.endswith(THINKING_START)
|
||||
|
||||
# ── Turn 2: Model thinks, calls tool for Beijing ──
|
||||
model_tokens = [
|
||||
THINKING_START,
|
||||
"Now the user wants Beijing weather.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">Beijing</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
assert args == {"city": "Beijing"}
|
||||
|
||||
def test_chained_tool_calls_loop(self):
|
||||
"""Model calls tool, gets result, calls another tool, gets result, answers.
|
||||
|
||||
This simulates the inner while loop from the vLLM spec where the model
|
||||
may need multiple sub-turns of tool calling before it has enough info.
|
||||
"""
|
||||
# ── Sub-turn 1: user asks, model calls get_time ──
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "What's the weather in Hangzhou tomorrow?"},
|
||||
]
|
||||
|
||||
prompt_1 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_1.endswith(THINKING_START)
|
||||
|
||||
# Model first calls get_time to figure out the date
|
||||
model_tokens_1 = [
|
||||
THINKING_START,
|
||||
"I need the current date first to calculate tomorrow.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_time">\n',
|
||||
f'<{DSML_TOKEN}parameter name="timezone" string="true">Asia/Shanghai</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
|
||||
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
|
||||
assert len(tool_1) == 1
|
||||
assert tool_1[0].tool_calls[0].name == "get_time"
|
||||
|
||||
# ── Sub-turn 2: add tool result, model calls get_weather ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "I need the current date first to calculate tomorrow.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "Asia/Shanghai"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append({"role": "tool", "content": "2025-12-01 14:30 CST"})
|
||||
|
||||
prompt_2 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert "<result>2025-12-01 14:30 CST</result>" in prompt_2
|
||||
assert prompt_2.endswith(THINKING_START)
|
||||
|
||||
# Model now knows the date, calls get_weather
|
||||
model_tokens_2 = [
|
||||
THINKING_START,
|
||||
"Today is 2025-12-01, so tomorrow is 2025-12-02.",
|
||||
" Now I can check weather for Hangzhou.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
|
||||
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
|
||||
assert len(tool_2) == 1
|
||||
assert tool_2[0].tool_calls[0].name == "get_weather"
|
||||
|
||||
# ── Sub-turn 3: add weather result, model answers ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Today is 2025-12-01, so tomorrow is 2025-12-02. Now I can check weather for Hangzhou.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Hangzhou"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append({"role": "tool", "content": "Sunny, 5~12°C"})
|
||||
|
||||
prompt_3 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
# Should have both function_results blocks (one per tool round)
|
||||
# Count is 3: 1 in _TOOLS_SYSTEM_TEMPLATE example + 2 in conversation
|
||||
assert prompt_3.count("<function_results>") == 3
|
||||
assert prompt_3.count("</function_results>") == 3
|
||||
assert "<result>2025-12-01 14:30 CST</result>" in prompt_3
|
||||
assert "<result>Sunny, 5~12°C</result>" in prompt_3
|
||||
assert prompt_3.endswith(THINKING_START)
|
||||
|
||||
# Model finally answers
|
||||
model_tokens_3 = [
|
||||
THINKING_START,
|
||||
"I have the weather for tomorrow in Hangzhou.",
|
||||
THINKING_END,
|
||||
"Tomorrow in Hangzhou will be sunny with temperatures between 5°C and 12°C.",
|
||||
]
|
||||
results_3 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_3)))
|
||||
|
||||
tool_3 = [r for r in results_3 if isinstance(r, ToolCallResponse)]
|
||||
gen_3 = [r for r in results_3 if isinstance(r, GenerationResponse)]
|
||||
non_thinking_3 = [r for r in gen_3 if not r.is_thinking]
|
||||
|
||||
assert len(tool_3) == 0 # No more tool calls — loop ends
|
||||
final_text = "".join(r.text for r in non_thinking_3)
|
||||
assert "sunny" in final_text.lower()
|
||||
assert "5°C" in final_text
|
||||
assert "12°C" in final_text
|
||||
@@ -148,6 +148,7 @@ class MockTokenizer:
|
||||
tool_call_start = None
|
||||
tool_call_end = None
|
||||
has_tool_calling = False
|
||||
has_thinking = False
|
||||
|
||||
|
||||
class MockGroup:
|
||||
|
||||
@@ -149,12 +149,23 @@ class TestParseGptOssThinkingThenToolCall:
|
||||
def test_thinking_then_tool_call(self):
|
||||
results = _collect(THINKING_THEN_TOOL_TOKENS)
|
||||
|
||||
# Should have thinking tags + content + tool call
|
||||
text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
|
||||
combined = "".join(text_parts)
|
||||
assert "<think>" in combined
|
||||
assert "</think>" in combined
|
||||
assert "Let me think about this." in combined
|
||||
# Thinking tokens should have is_thinking=True and no <think> tags
|
||||
thinking_responses = [
|
||||
r for r in results if isinstance(r, GenerationResponse) and r.is_thinking
|
||||
]
|
||||
thinking_text = "".join(r.text for r in thinking_responses)
|
||||
assert "Let me think about this." in thinking_text
|
||||
assert "<think>" not in thinking_text
|
||||
assert "</think>" not in thinking_text
|
||||
|
||||
# Non-thinking tokens should have is_thinking=False
|
||||
non_thinking = [
|
||||
r
|
||||
for r in results
|
||||
if isinstance(r, GenerationResponse) and not r.is_thinking
|
||||
]
|
||||
non_thinking_text = "".join(r.text for r in non_thinking)
|
||||
assert "<think>" not in non_thinking_text
|
||||
|
||||
# And the tool call
|
||||
tc = _get_tool_call(results)
|
||||
|
||||
8
tmp/config_examples/claude_code.sh
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/bin/bash
|
||||
# Run Claude Code against a local exo cluster! (Here, GPT OSS 120B)
|
||||
ANTHROPIC_BASE_URL="http://localhost:52415/" \
|
||||
ANTHROPIC_AUTH_TOKEN="dummy" \
|
||||
ANTHROPIC_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
|
||||
ANTHROPIC_SMALL_FAST_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
|
||||
CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \
|
||||
claude
|
||||