Mirror of https://github.com/exo-explore/exo.git
Synced 2026-02-19 23:36:30 -05:00

Compare commits: leo/add-ol ... session-id
1 commit: 932878c9e7
@@ -250,11 +250,6 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// Disk usage per node
nodeDisk?: Record<
string,
{ total: { inBytes: number }; available: { inBytes: number } }
>;
}

export interface MessageAttachment {
@@ -1657,12 +1652,11 @@ class AppStore {
if (!reader) throw new Error("No response body");

let fullContent = prefixText;
let streamedThinking = "";
const collectedTokens: TokenData[] = [...tokensToKeep];

interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string; reasoning_content?: string };
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
@@ -1683,7 +1677,6 @@ class AppStore {
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
const thinkingDelta = choice?.delta?.reasoning_content;

// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
@@ -1702,11 +1695,7 @@ class AppStore {
}
}

if (thinkingDelta) {
streamedThinking += thinkingDelta;
}

if (delta || thinkingDelta) {
if (delta) {
if (firstTokenTime === null) {
firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime;
@@ -1720,14 +1709,9 @@ class AppStore {
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
}

if (delta) {
fullContent += delta;
}
const { displayContent, thinkingContent: tagThinking } =
fullContent += delta;
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");

if (this.activeConversationId === targetConversationId) {
this.currentResponse = displayContent;
@@ -1739,7 +1723,7 @@ class AppStore {
messageId,
(m) => {
m.content = displayContent;
m.thinking = combinedThinking || undefined;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
},
);
@@ -1751,14 +1735,11 @@ class AppStore {

// Final update
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent: tagThinking } =
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = displayContent;
m.thinking = finalThinking || undefined;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
if (this.tps !== null) m.tps = this.tps;
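For reference, a minimal TypeScript sketch of the streaming merge the hunks above are changing: visible text arrives in delta.content, model reasoning in delta.reasoning_content, and any <think>...</think> tags left in the visible stream are split out as a fallback and joined with the streamed reasoning. The ChunkDelta and applyDelta names and the regex-based stripThinkingTags body are illustrative assumptions, not the repository's implementation.

// Sketch only; mirrors the combinedThinking pattern shown in the diff above.
interface ChunkDelta {
  content?: string;
  reasoning_content?: string;
}

// Hypothetical fallback: pull <think>...</think> spans out of the visible text.
function stripThinkingTags(text: string): { displayContent: string; thinkingContent: string } {
  const parts: string[] = [];
  const displayContent = text.replace(/<think>([\s\S]*?)<\/think>/g, (_m, inner: string) => {
    parts.push(inner.trim());
    return "";
  });
  return { displayContent, thinkingContent: parts.join("\n\n") };
}

function applyDelta(
  state: { fullContent: string; streamedThinking: string },
  delta: ChunkDelta,
): { display: string; thinking: string } {
  if (delta.content) state.fullContent += delta.content;
  if (delta.reasoning_content) state.streamedThinking += delta.reasoning_content;

  const { displayContent, thinkingContent: tagThinking } = stripThinkingTags(state.fullContent);
  // Combine API-provided reasoning with tag-extracted reasoning, as in the hunk above.
  const thinking = [state.streamedThinking, tagThinking].filter(Boolean).join("\n\n");
  return { display: displayContent, thinking };
}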
@@ -1866,12 +1847,11 @@ class AppStore {
|
||||
}
|
||||
|
||||
let streamedContent = "";
|
||||
let streamedThinking = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string; reasoning_content?: string };
|
||||
delta?: { content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
@@ -1892,7 +1872,6 @@ class AppStore {
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const delta = choice?.delta?.content;
|
||||
const thinkingDelta = choice?.delta?.reasoning_content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
@@ -1911,19 +1890,10 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
if (thinkingDelta) {
|
||||
streamedThinking += thinkingDelta;
|
||||
}
|
||||
|
||||
if (delta || thinkingDelta) {
|
||||
if (delta) {
|
||||
streamedContent += delta;
|
||||
}
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
if (delta) {
|
||||
streamedContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const combinedThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
|
||||
// Only update currentResponse if target conversation is active
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
@@ -1936,7 +1906,7 @@ class AppStore {
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = combinedThinking || undefined;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -1948,17 +1918,14 @@ class AppStore {
|
||||
|
||||
// Final cleanup of the message (if conversation still exists)
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const finalThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = finalThinking || undefined;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -2350,11 +2317,10 @@ class AppStore {
|
||||
}
|
||||
|
||||
let streamedContent = "";
|
||||
let streamedThinking = "";
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string; reasoning_content?: string };
|
||||
delta?: { content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
@@ -2382,7 +2348,6 @@ class AppStore {
|
||||
|
||||
const choice = parsed.choices?.[0];
|
||||
const tokenContent = choice?.delta?.content;
|
||||
const thinkingContent = choice?.delta?.reasoning_content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
@@ -2401,11 +2366,7 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
if (thinkingContent) {
|
||||
streamedThinking += thinkingContent;
|
||||
}
|
||||
|
||||
if (tokenContent || thinkingContent) {
|
||||
if (tokenContent) {
|
||||
// Track first token for TTFT
|
||||
if (firstTokenTime === null) {
|
||||
firstTokenTime = performance.now();
|
||||
@@ -2422,16 +2383,11 @@ class AppStore {
|
||||
this.tps = (tokenCount / elapsed) * 1000;
|
||||
}
|
||||
|
||||
if (tokenContent) {
|
||||
streamedContent += tokenContent;
|
||||
}
|
||||
streamedContent += tokenContent;
|
||||
|
||||
// Use stripThinkingTags as fallback for any <think> tags still in content
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
// Strip thinking tags for display and extract thinking content
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const combinedThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
|
||||
// Only update currentResponse if target conversation is active
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
@@ -2444,7 +2400,7 @@ class AppStore {
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = combinedThinking || undefined;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -2480,17 +2436,14 @@ class AppStore {
|
||||
|
||||
// Final cleanup of the message (if conversation still exists)
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const finalThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = finalThinking || undefined;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
// Store performance metrics on the message
|
||||
if (this.ttftMs !== null) {
|
||||
|
||||
@@ -114,74 +114,6 @@
|
||||
});
|
||||
let tb5InfoDismissed = $state(false);
|
||||
|
||||
// Detect Mac Studio nodes using RDMA on en2 (the port next to ethernet — RDMA doesn't work there)
|
||||
const macStudioEn2RdmaWarning = $derived.by(() => {
|
||||
const edges = data?.edges;
|
||||
const ids = tbIdentifiers;
|
||||
const rdmaCtl = rdmaCtlData;
|
||||
if (!edges || !ids || !rdmaCtl) return null;
|
||||
|
||||
const affectedConnections: Array<{
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
peerNodeId: string;
|
||||
peerNodeName: string;
|
||||
rdmaIface: string;
|
||||
}> = [];
|
||||
|
||||
const isMacStudio = (node: (typeof data.nodes)[string] | undefined) =>
|
||||
node?.system_info?.model_id === "Mac Studio";
|
||||
|
||||
for (const edge of edges) {
|
||||
if (!edge.sourceRdmaIface && !edge.sinkRdmaIface) continue;
|
||||
|
||||
const sourceNode = data?.nodes?.[edge.source];
|
||||
if (
|
||||
isMacStudio(sourceNode) &&
|
||||
edge.sourceRdmaIface === "rdma_en2" &&
|
||||
rdmaCtl[edge.source]?.enabled
|
||||
) {
|
||||
affectedConnections.push({
|
||||
nodeId: edge.source,
|
||||
nodeName:
|
||||
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
|
||||
peerNodeId: edge.target,
|
||||
peerNodeName:
|
||||
data?.nodes?.[edge.target]?.friendly_name ||
|
||||
edge.target.slice(0, 8) + "...",
|
||||
rdmaIface: "en2",
|
||||
});
|
||||
}
|
||||
|
||||
const sinkNode = data?.nodes?.[edge.target];
|
||||
if (
|
||||
isMacStudio(sinkNode) &&
|
||||
edge.sinkRdmaIface === "rdma_en2" &&
|
||||
rdmaCtl[edge.target]?.enabled
|
||||
) {
|
||||
affectedConnections.push({
|
||||
nodeId: edge.target,
|
||||
nodeName: sinkNode?.friendly_name || edge.target.slice(0, 8) + "...",
|
||||
peerNodeId: edge.source,
|
||||
peerNodeName:
|
||||
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
|
||||
rdmaIface: "en2",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate by nodeId
|
||||
const seen = new Set<string>();
|
||||
const unique = affectedConnections.filter((c) => {
|
||||
if (seen.has(c.nodeId)) return false;
|
||||
seen.add(c.nodeId);
|
||||
return true;
|
||||
});
|
||||
|
||||
return unique.length > 0 ? unique : null;
|
||||
});
|
||||
let macStudioEn2Dismissed = $state(false);
|
||||
|
||||
// Helper to get friendly node name from node ID
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -858,8 +790,10 @@
|
||||
if (!progress || typeof progress !== "object") return null;
|
||||
|
||||
const prog = progress as Record<string, unknown>;
|
||||
const totalBytes = getBytes(prog.total);
|
||||
const downloadedBytes = getBytes(prog.downloaded);
|
||||
const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
|
||||
const downloadedBytes = getBytes(
|
||||
prog.downloaded_bytes ?? prog.downloadedBytes,
|
||||
);
|
||||
const speed = (prog.speed as number) ?? 0;
|
||||
const completedFiles =
|
||||
(prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
|
||||
@@ -872,8 +806,8 @@
|
||||
for (const [fileName, fileData] of Object.entries(filesObj)) {
|
||||
if (!fileData || typeof fileData !== "object") continue;
|
||||
const fd = fileData as Record<string, unknown>;
|
||||
const fTotal = getBytes(fd.total);
|
||||
const fDownloaded = getBytes(fd.downloaded);
|
||||
const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes);
|
||||
const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes);
|
||||
files.push({
|
||||
name: fileName,
|
||||
totalBytes: fTotal,
|
||||
@@ -1262,6 +1196,7 @@
|
||||
if (typeof value === "number") return value;
|
||||
if (value && typeof value === "object") {
|
||||
const v = value as Record<string, unknown>;
|
||||
if (typeof v.in_bytes === "number") return v.in_bytes;
|
||||
if (typeof v.inBytes === "number") return v.inBytes;
|
||||
}
|
||||
return 0;
|
||||
@@ -1823,7 +1758,7 @@
|
||||
</script>
|
||||
|
||||
{#snippet clusterWarnings()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
<div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
@@ -1988,260 +1923,12 @@
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
|
||||
<div class="group relative" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-red-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-red-200">
|
||||
RDMA INCOMPATIBLE PORT
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => (macStudioEn2Dismissed = true)}
|
||||
class="ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer"
|
||||
title="Dismiss"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Expanded tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-96 p-4 rounded border border-red-500/30 bg-[#1a1a1a]/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-3">
|
||||
The Thunderbolt 5 port next to the Ethernet port on Mac Studio
|
||||
does
|
||||
<span class="text-red-400 font-semibold">not support RDMA</span>.
|
||||
Move the cable to one of the other three TB5 ports.
|
||||
</p>
|
||||
|
||||
<div class="text-xs text-white/60 mb-3">
|
||||
<span class="text-red-300">Affected:</span>
|
||||
{#each macStudioEn2RdmaWarning as conn}
|
||||
<div class="ml-2 mt-0.5">
|
||||
<span class="text-white/80">{conn.nodeName}</span>
|
||||
<span class="text-white/30">→</span>
|
||||
<span class="text-white/60">{conn.peerNodeName}</span>
|
||||
<span class="text-white/30 ml-1">(en2)</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<!-- Mac Studio back panel illustration -->
|
||||
<div class="bg-black/40 rounded p-3 mb-3">
|
||||
<p
|
||||
class="text-[10px] font-mono text-white/30 uppercase tracking-wider mb-2"
|
||||
>
|
||||
Mac Studio — Rear Panel
|
||||
</p>
|
||||
<svg
|
||||
viewBox="0 0 320 72"
|
||||
class="w-full"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<rect
|
||||
x="1"
|
||||
y="1"
|
||||
width="318"
|
||||
height="70"
|
||||
rx="6"
|
||||
ry="6"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.12)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<!-- TB5 port 1 -->
|
||||
<rect
|
||||
x="24"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.3)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<text
|
||||
x="38"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>TB5</text
|
||||
>
|
||||
<!-- TB5 port 2 -->
|
||||
<rect
|
||||
x="62"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.3)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<text
|
||||
x="76"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>TB5</text
|
||||
>
|
||||
<!-- TB5 port 3 -->
|
||||
<rect
|
||||
x="100"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.3)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<text
|
||||
x="114"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>TB5</text
|
||||
>
|
||||
<!-- TB5 port 4: INCOMPATIBLE (en2) — equally spaced with ports 1-3 -->
|
||||
<rect
|
||||
x="138"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="rgba(239,68,68,0.1)"
|
||||
stroke="rgba(239,68,68,0.7)"
|
||||
stroke-width="1.5"
|
||||
/>
|
||||
<line
|
||||
x1="142"
|
||||
y1="25"
|
||||
x2="162"
|
||||
y2="33"
|
||||
stroke="rgba(239,68,68,0.8)"
|
||||
stroke-width="1.5"
|
||||
stroke-linecap="round"
|
||||
/>
|
||||
<line
|
||||
x1="162"
|
||||
y1="25"
|
||||
x2="142"
|
||||
y2="33"
|
||||
stroke="rgba(239,68,68,0.8)"
|
||||
stroke-width="1.5"
|
||||
stroke-linecap="round"
|
||||
/>
|
||||
<text
|
||||
x="152"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(239,68,68,0.6)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;font-weight:600;"
|
||||
>en2</text
|
||||
>
|
||||
<!-- Ethernet port -->
|
||||
<rect
|
||||
x="196"
|
||||
y="19"
|
||||
width="24"
|
||||
height="20"
|
||||
rx="2"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.2)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<rect
|
||||
x="200"
|
||||
y="23"
|
||||
width="16"
|
||||
height="12"
|
||||
rx="1"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.12)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
<text
|
||||
x="208"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>ETH</text
|
||||
>
|
||||
<!-- Green checkmarks on working ports -->
|
||||
<circle
|
||||
cx="38"
|
||||
cy="62"
|
||||
r="3"
|
||||
fill="none"
|
||||
stroke="rgba(74,222,128,0.5)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
<circle
|
||||
cx="76"
|
||||
cy="62"
|
||||
r="3"
|
||||
fill="none"
|
||||
stroke="rgba(74,222,128,0.5)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
<circle
|
||||
cx="114"
|
||||
cy="62"
|
||||
r="3"
|
||||
fill="none"
|
||||
stroke="rgba(74,222,128,0.5)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
|
||||
<p class="text-xs text-white/50">
|
||||
<span class="text-green-400">Fix:</span> Move the Thunderbolt cable
|
||||
to any of the three leftmost ports (all support RDMA).
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/snippet}
|
||||
|
||||
{#snippet clusterWarningsCompact()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
<div class="absolute top-2 left-2 flex flex-col gap-1">
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
@@ -2309,27 +1996,6 @@
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
|
||||
<div
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm"
|
||||
title="Mac Studio RDMA incompatible port (en2) — move cable to another TB5 port"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-red-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-red-200">BAD RDMA PORT</span>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/snippet}
|
||||
|
||||
@@ -74,6 +74,7 @@
if (typeof value === "number") return value;
if (value && typeof value === "object") {
const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes;
}
return 0;
@@ -230,14 +231,23 @@
undefined;
let cell: CellStatus;
if (tag === "DownloadCompleted") {
const totalBytes = getBytes(payload.total);
const totalBytes = getBytes(
payload.total_bytes ?? payload.totalBytes,
);
cell = { kind: "completed", totalBytes, modelDirectory };
} else if (tag === "DownloadOngoing") {
const rawProgress =
payload.download_progress ?? payload.downloadProgress ?? {};
const prog = rawProgress as Record<string, unknown>;
const totalBytes = getBytes(prog.total ?? payload.total);
const downloadedBytes = getBytes(prog.downloaded);
const totalBytes = getBytes(
prog.total_bytes ??
prog.totalBytes ??
payload.total_bytes ??
payload.totalBytes,
);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const speed = (prog.speed as number) ?? 0;
const etaMs =
(prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;

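For reference, a minimal TypeScript sketch of the byte-field normalization the hunks above switch to: sizes may arrive as plain numbers or as { in_bytes } / { inBytes } objects, under either snake_case or camelCase keys. getBytes is copied from the hunk; the sample payload and the percentage calculation are illustrative assumptions.

function getBytes(value: unknown): number {
  if (typeof value === "number") return value;
  if (value && typeof value === "object") {
    const v = value as Record<string, unknown>;
    if (typeof v.in_bytes === "number") return v.in_bytes;
    if (typeof v.inBytes === "number") return v.inBytes;
  }
  return 0;
}

// Hypothetical usage against a download-progress record, accepting either key style.
const prog: Record<string, unknown> = {
  total_bytes: { in_bytes: 5_000_000_000 },
  downloadedBytes: 1_250_000_000,
};
const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes);
const downloadedBytes = getBytes(prog.downloaded_bytes ?? prog.downloadedBytes);
const pct = totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0;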
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import socket
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
@@ -22,10 +21,10 @@ from exo.shared.types.commands import (
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
LocalForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
@@ -45,9 +44,9 @@ class DownloadCoordinator:
|
||||
session_id: SessionId
|
||||
shard_downloader: ShardDownloader
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[ForwarderEvent]
|
||||
event_index_counter: Iterator[int]
|
||||
local_event_sender: Sender[LocalForwarderEvent]
|
||||
offline: bool = False
|
||||
_system_id: SystemId = field(default_factory=SystemId)
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
@@ -80,7 +79,7 @@ class DownloadCoordinator:
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=callback_shard,
|
||||
node_id=self.node_id,
|
||||
total=progress.total,
|
||||
total_bytes=progress.total_bytes,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
@@ -203,7 +202,7 @@ class DownloadCoordinator:
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total=initial_progress.total,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
@@ -298,15 +297,16 @@ class DownloadCoordinator:
del self.download_status[model_id]

async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self.node_id,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
@@ -332,13 +332,13 @@ class DownloadCoordinator:
|
||||
status: DownloadProgress = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total=progress.total,
|
||||
total_bytes=progress.total_bytes,
|
||||
model_directory=self._model_dir(
|
||||
progress.shard.model_card.model_id
|
||||
),
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_this_session.in_bytes == 0:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
|
||||
@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
|
||||
repo_file_download_progress: RepoFileDownloadProgress,
|
||||
) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
downloaded=repo_file_download_progress.downloaded,
|
||||
downloaded_this_session=repo_file_download_progress.downloaded_this_session,
|
||||
total=repo_file_download_progress.total,
|
||||
downloaded_bytes=repo_file_download_progress.downloaded,
|
||||
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session,
|
||||
total_bytes=repo_file_download_progress.total,
|
||||
completed_files=1 if repo_file_download_progress.status == "complete" else 0,
|
||||
total_files=1,
|
||||
speed=repo_file_download_progress.speed,
|
||||
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
|
||||
repo_download_progress: RepoDownloadProgress,
|
||||
) -> DownloadProgressData:
|
||||
return DownloadProgressData(
|
||||
total=repo_download_progress.total,
|
||||
downloaded=repo_download_progress.downloaded,
|
||||
downloaded_this_session=repo_download_progress.downloaded_this_session,
|
||||
total_bytes=repo_download_progress.total_bytes,
|
||||
downloaded_bytes=repo_download_progress.downloaded_bytes,
|
||||
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session,
|
||||
completed_files=repo_download_progress.completed_files,
|
||||
total_files=repo_download_progress.total_files,
|
||||
speed=repo_download_progress.overall_speed,
|
||||
@@ -578,20 +578,19 @@ def calculate_repo_progress(
|
||||
file_progress: dict[str, RepoFileDownloadProgress],
|
||||
all_start_time: float,
|
||||
) -> RepoDownloadProgress:
|
||||
all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
|
||||
all_downloaded = sum(
|
||||
(p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
|
||||
all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
|
||||
all_downloaded_bytes = sum(
|
||||
(p.downloaded.in_bytes for p in file_progress.values()), 0
|
||||
)
|
||||
all_downloaded_this_session = sum(
|
||||
(p.downloaded_this_session for p in file_progress.values()),
|
||||
Memory.from_bytes(0),
|
||||
all_downloaded_bytes_this_session = sum(
|
||||
(p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
|
||||
)
|
||||
elapsed_time = time.time() - all_start_time
|
||||
all_speed = (
|
||||
all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
|
||||
all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
|
||||
)
|
||||
all_eta = (
|
||||
timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
|
||||
timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
|
||||
if all_speed > 0
|
||||
else timedelta(seconds=0)
|
||||
)
|
||||
@@ -610,9 +609,11 @@ def calculate_repo_progress(
|
||||
[p for p in file_progress.values() if p.downloaded == p.total]
|
||||
),
|
||||
total_files=len(file_progress),
|
||||
downloaded=all_downloaded,
|
||||
downloaded_this_session=all_downloaded_this_session,
|
||||
total=all_total,
|
||||
downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(
|
||||
all_downloaded_bytes_this_session
|
||||
),
|
||||
total_bytes=Memory.from_bytes(all_total_bytes),
|
||||
overall_speed=all_speed,
|
||||
overall_eta=all_eta,
|
||||
status=status,
|
||||
|
||||
@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
|
||||
),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded=Memory.from_bytes(0),
|
||||
downloaded_this_session=Memory.from_bytes(0),
|
||||
total=Memory.from_bytes(0),
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import resource
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator, Self
|
||||
from typing import Self
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -38,12 +37,11 @@ class Node:
|
||||
api: API | None
|
||||
|
||||
node_id: NodeId
|
||||
event_index_counter: Iterator[int]
|
||||
offline: bool
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
async def create(cls, args: "Args") -> Self:
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
@@ -57,9 +55,6 @@ class Node:
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
|
||||
# Create shared event index counter for Worker and DownloadCoordinator
|
||||
event_index_counter = itertools.count()
|
||||
|
||||
# Create DownloadCoordinator (unless --no-downloads)
|
||||
if not args.no_downloads:
|
||||
download_coordinator = DownloadCoordinator(
|
||||
@@ -68,7 +63,6 @@ class Node:
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=event_index_counter,
|
||||
offline=args.offline,
|
||||
)
|
||||
else:
|
||||
@@ -95,7 +89,6 @@ class Node:
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
event_index_counter=event_index_counter,
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
@@ -133,7 +126,6 @@ class Node:
|
||||
master,
|
||||
api,
|
||||
node_id,
|
||||
event_index_counter,
|
||||
args.offline,
|
||||
)
|
||||
|
||||
@@ -212,8 +204,6 @@ class Node:
|
||||
)
|
||||
if result.is_new_master:
|
||||
await anyio.sleep(0)
|
||||
# Fresh counter for new session (buffer expects indices from 0)
|
||||
self.event_index_counter = itertools.count()
|
||||
if self.download_coordinator:
|
||||
self.download_coordinator.shutdown()
|
||||
self.download_coordinator = DownloadCoordinator(
|
||||
@@ -224,7 +214,6 @@ class Node:
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=self.event_index_counter,
|
||||
offline=self.offline,
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
@@ -242,7 +231,6 @@ class Node:
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.api:
|
||||
|
||||
@@ -59,11 +59,7 @@ def chat_request_to_text_generation(
|
||||
chat_template_messages.append({"role": "system", "content": content})
|
||||
else:
|
||||
# Skip messages with no meaningful content
|
||||
if (
|
||||
msg.content is None
|
||||
and msg.reasoning_content is None
|
||||
and msg.tool_calls is None
|
||||
):
|
||||
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
|
||||
continue
|
||||
|
||||
if msg.role in ("user", "assistant", "developer"):
|
||||
@@ -115,11 +111,6 @@ def chunk_to_response(
|
||||
]
|
||||
)
|
||||
|
||||
if chunk.is_thinking:
|
||||
delta = ChatCompletionMessage(role="assistant", reasoning_content=chunk.text)
|
||||
else:
|
||||
delta = ChatCompletionMessage(role="assistant", content=chunk.text)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -127,7 +118,7 @@ def chunk_to_response(
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=delta,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
@@ -217,7 +208,6 @@ async def collect_chat_response(
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
logprobs_content: list[LogprobsContentItem] = []
|
||||
model: str | None = None
|
||||
@@ -238,10 +228,7 @@ async def collect_chat_response(
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
last_usage = chunk.usage or last_usage
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
else:
|
||||
text_parts.append(chunk.text)
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
logprobs_content.append(
|
||||
LogprobsContentItem(
|
||||
@@ -271,7 +258,6 @@ async def collect_chat_response(
|
||||
raise ValueError(error_message)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_thinking = "".join(thinking_parts) if thinking_parts else None
|
||||
assert model is not None
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
@@ -284,7 +270,6 @@ async def collect_chat_response(
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
reasoning_content=combined_thinking,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
logprobs=Logprobs(content=logprobs_content)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Claude Messages API adapter for converting requests/responses."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
@@ -29,8 +28,6 @@ from exo.shared.types.claude_api import (
|
||||
ClaudeStopReason,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeThinkingBlock,
|
||||
ClaudeThinkingDelta,
|
||||
ClaudeToolResultBlock,
|
||||
ClaudeToolUseBlock,
|
||||
ClaudeUsage,
|
||||
@@ -64,22 +61,6 @@ def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
return "".join(sub_block.text for sub_block in block.content)


# Matches "x-anthropic-billing-header: ...;" (with optional trailing newline)
# or similar telemetry headers that change every request and break KV prefix caching.
_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)


def _strip_volatile_headers(text: str) -> str:
"""Remove Anthropic billing/telemetry headers from system prompt text.

Claude Code prepends headers like 'x-anthropic-billing-header: cc_version=...;
cc_entrypoint=...; cch=...;' that contain per-request content hashes. These
change every request and break KV prefix caching (the prefix diverges at ~20
tokens instead of matching thousands of conversation tokens).
"""
return _VOLATILE_HEADER_RE.sub("", text)

def claude_request_to_text_generation(
|
||||
request: ClaudeMessagesRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
@@ -92,8 +73,6 @@ def claude_request_to_text_generation(
|
||||
instructions = request.system
|
||||
else:
|
||||
instructions = "".join(block.text for block in request.system)
|
||||
|
||||
instructions = _strip_volatile_headers(instructions)
|
||||
chat_template_messages.append({"role": "system", "content": instructions})
|
||||
|
||||
# Convert messages to input
|
||||
@@ -106,15 +85,12 @@ def claude_request_to_text_generation(
|
||||
|
||||
# Process structured content blocks
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
tool_results: list[ClaudeToolResultBlock] = []
|
||||
|
||||
for block in msg.content:
|
||||
if isinstance(block, ClaudeTextBlock):
|
||||
text_parts.append(block.text)
|
||||
elif isinstance(block, ClaudeThinkingBlock):
|
||||
thinking_parts.append(block.thinking)
|
||||
elif isinstance(block, ClaudeToolUseBlock):
|
||||
tool_calls.append(
|
||||
{
|
||||
@@ -130,7 +106,6 @@ def claude_request_to_text_generation(
|
||||
tool_results.append(block)
|
||||
|
||||
content = "".join(text_parts)
|
||||
reasoning_content = "".join(thinking_parts) if thinking_parts else None
|
||||
|
||||
# Build InputMessage from text content
|
||||
if msg.role in ("user", "assistant"):
|
||||
@@ -138,14 +113,9 @@ def claude_request_to_text_generation(
|
||||
|
||||
# Build chat_template_messages preserving tool structure
|
||||
if tool_calls:
|
||||
chat_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
if reasoning_content:
|
||||
chat_msg["reasoning_content"] = reasoning_content
|
||||
chat_template_messages.append(chat_msg)
|
||||
chat_template_messages.append(
|
||||
{"role": "assistant", "content": content, "tool_calls": tool_calls}
|
||||
)
|
||||
elif tool_results:
|
||||
for tr in tool_results:
|
||||
chat_template_messages.append(
|
||||
@@ -156,10 +126,7 @@ def claude_request_to_text_generation(
|
||||
}
|
||||
)
|
||||
else:
|
||||
chat_msg = {"role": msg.role, "content": content}
|
||||
if reasoning_content:
|
||||
chat_msg["reasoning_content"] = reasoning_content
|
||||
chat_template_messages.append(chat_msg)
|
||||
chat_template_messages.append({"role": msg.role, "content": content})
|
||||
|
||||
# Convert Claude tool definitions to OpenAI-style function tools
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
@@ -176,10 +143,6 @@ def claude_request_to_text_generation(
|
||||
for tool in request.tools
|
||||
]
|
||||
|
||||
enable_thinking: bool | None = None
|
||||
if request.thinking is not None:
|
||||
enable_thinking = request.thinking.type in ("enabled", "adaptive")
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages
|
||||
@@ -193,7 +156,6 @@ def claude_request_to_text_generation(
|
||||
stop=request.stop_sequences,
|
||||
stream=request.stream,
|
||||
tools=tools,
|
||||
enable_thinking=enable_thinking,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
@@ -211,7 +173,6 @@ async def collect_claude_response(
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_use_blocks: list[ClaudeToolUseBlock] = []
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_usage: Usage | None = None
|
||||
@@ -239,10 +200,7 @@ async def collect_claude_response(
|
||||
stop_reason = "tool_use"
|
||||
continue
|
||||
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
else:
|
||||
text_parts.append(chunk.text)
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -251,12 +209,9 @@ async def collect_claude_response(
|
||||
raise ValueError(error_message)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_thinking = "".join(thinking_parts)
|
||||
|
||||
# Build content blocks
|
||||
content: list[ClaudeContentBlock] = []
|
||||
if combined_thinking:
|
||||
content.append(ClaudeThinkingBlock(thinking=combined_thinking))
|
||||
if combined_text:
|
||||
content.append(ClaudeTextBlock(text=combined_text))
|
||||
content.extend(tool_use_blocks)
|
||||
@@ -301,16 +256,16 @@ async def generate_claude_stream(
|
||||
start_event = ClaudeMessageStartEvent(message=initial_message)
|
||||
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_start for text block at index 0
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_usage: Usage | None = None
|
||||
next_block_index = 0
|
||||
|
||||
# Track whether we've started thinking/text blocks
|
||||
thinking_block_started = False
|
||||
thinking_block_index = -1
|
||||
text_block_started = False
|
||||
text_block_index = -1
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
@@ -355,45 +310,12 @@ async def generate_claude_stream(
|
||||
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
|
||||
if chunk.is_thinking:
|
||||
# Start thinking block on first thinking token
|
||||
if not thinking_block_started:
|
||||
thinking_block_started = True
|
||||
thinking_block_index = next_block_index
|
||||
next_block_index += 1
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=thinking_block_index,
|
||||
content_block=ClaudeThinkingBlock(thinking=""),
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=thinking_block_index,
|
||||
delta=ClaudeThinkingDelta(thinking=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
else:
|
||||
# Close thinking block when transitioning to text
|
||||
if thinking_block_started and text_block_index == -1:
|
||||
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
# Start text block on first text token
|
||||
if not text_block_started:
|
||||
text_block_started = True
|
||||
text_block_index = next_block_index
|
||||
next_block_index += 1
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=text_block_index,
|
||||
content_block=ClaudeTextBlock(text=""),
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=text_block_index,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -402,22 +324,9 @@ async def generate_claude_stream(
|
||||
if last_usage is not None:
|
||||
output_tokens = last_usage.completion_tokens
|
||||
|
||||
# Close any open blocks
|
||||
if thinking_block_started and text_block_index == -1:
|
||||
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
if text_block_started:
|
||||
block_stop = ClaudeContentBlockStopEvent(index=text_block_index)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
if not thinking_block_started and not text_block_started:
|
||||
empty_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {empty_start.model_dump_json()}\n\n"
|
||||
empty_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {empty_stop.model_dump_json()}\n\n"
|
||||
# content_block_stop for text block
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
# message_delta
|
||||
message_delta = ClaudeMessageDeltaEvent(
|
||||
|
||||
@@ -1,456 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.ollama_api import (
|
||||
OllamaChatRequest,
|
||||
OllamaChatResponse,
|
||||
OllamaDoneReason,
|
||||
OllamaGenerateRequest,
|
||||
OllamaGenerateResponse,
|
||||
OllamaMessage,
|
||||
OllamaToolCall,
|
||||
OllamaToolFunction,
|
||||
)
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def _map_done_reason(
|
||||
finish_reason: str | None,
|
||||
) -> OllamaDoneReason | None:
|
||||
if finish_reason is None:
|
||||
return None
|
||||
if finish_reason == "stop":
|
||||
return "stop"
|
||||
if finish_reason == "length":
|
||||
return "length"
|
||||
if finish_reason in ("tool_calls", "function_call"):
|
||||
return "tool_call"
|
||||
if finish_reason == "error":
|
||||
return "error"
|
||||
return "stop"
|
||||
|
||||
|
||||
def _try_parse_json(value: str) -> dict[str, Any] | str:
|
||||
try:
|
||||
return json.loads(value) # type: ignore
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
|
||||
def _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:
|
||||
tool_calls: list[OllamaToolCall] = []
|
||||
for index, tool in enumerate(chunk.tool_calls):
|
||||
# tool.arguments is always str; try to parse as JSON dict for Ollama format
|
||||
arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)
|
||||
tool_calls.append(
|
||||
OllamaToolCall(
|
||||
id=tool.id,
|
||||
type="function",
|
||||
function=OllamaToolFunction(
|
||||
name=tool.name, arguments=arguments, index=index
|
||||
),
|
||||
)
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _get_usage(
|
||||
chunk: TokenChunk | ToolCallChunk,
|
||||
) -> tuple[int | None, int | None]:
|
||||
"""Extract (prompt_eval_count, eval_count) from a chunk."""
|
||||
if chunk.usage is not None:
|
||||
return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
|
||||
if chunk.stats is not None:
|
||||
return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)
|
||||
return (None, None)
|
||||
|
||||
|
||||
def ollama_request_to_text_generation(
|
||||
request: OllamaChatRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
"""Convert Ollama chat request to exo's internal text generation format."""
|
||||
instructions: str | None = None
|
||||
input_messages: list[InputMessage] = []
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
tool_message_index = 0
|
||||
|
||||
for msg in request.messages:
|
||||
content = msg.content or ""
|
||||
|
||||
if msg.role == "system":
|
||||
if instructions is None:
|
||||
instructions = content
|
||||
else:
|
||||
instructions = f"{instructions}\n{content}"
|
||||
chat_template_messages.append({"role": "system", "content": content})
|
||||
continue
|
||||
|
||||
if msg.role in ("user", "assistant") and (
|
||||
msg.content is not None or msg.thinking is not None or msg.tool_calls
|
||||
):
|
||||
input_messages.append(InputMessage(role=msg.role, content=content))
|
||||
|
||||
dumped: dict[str, Any] = {"role": msg.role, "content": content}
|
||||
if msg.thinking is not None:
|
||||
dumped["thinking"] = msg.thinking
|
||||
if msg.tool_calls is not None:
|
||||
tool_calls_list: list[dict[str, Any]] = []
|
||||
for tc in msg.tool_calls:
|
||||
function: dict[str, Any] = {
|
||||
"name": tc.function.name,
|
||||
"arguments": (
|
||||
json.dumps(tc.function.arguments)
|
||||
if isinstance(tc.function.arguments, dict)
|
||||
else tc.function.arguments
|
||||
),
|
||||
}
|
||||
if tc.function.index is not None:
|
||||
function["index"] = tc.function.index
|
||||
tool_call: dict[str, Any] = {"function": function}
|
||||
if tc.id is not None:
|
||||
tool_call["id"] = tc.id
|
||||
if tc.type is not None:
|
||||
tool_call["type"] = tc.type
|
||||
tool_calls_list.append(tool_call)
|
||||
dumped["tool_calls"] = tool_calls_list
|
||||
if msg.name is not None:
|
||||
dumped["name"] = msg.name
|
||||
if msg.role == "tool":
|
||||
tool_message_index += 1
|
||||
tool_call_id = msg.tool_name or msg.name or f"tool_{tool_message_index}"
|
||||
dumped["tool_call_id"] = tool_call_id
|
||||
if msg.tool_name is not None:
|
||||
dumped["tool_name"] = msg.tool_name
|
||||
chat_template_messages.append(dumped)
|
||||
|
||||
options = request.options
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages
|
||||
if input_messages
|
||||
else [InputMessage(role="user", content="")],
|
||||
instructions=instructions,
|
||||
max_output_tokens=options.num_predict if options else None,
|
||||
temperature=options.temperature if options else None,
|
||||
top_p=options.top_p if options else None,
|
||||
top_k=options.top_k if options else None,
|
||||
stop=options.stop if options else None,
|
||||
seed=options.seed if options else None,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
enable_thinking=request.think,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
async def generate_ollama_chat_stream(
|
||||
_command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate streaming responses in Ollama format (newline-delimited JSON)."""
|
||||
thinking_parts: list[str] = []
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
match chunk:
|
||||
case PrefillProgressChunk():
|
||||
continue
|
||||
|
||||
case ErrorChunk():
|
||||
error_response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(
|
||||
role="assistant", content=chunk.error_message
|
||||
),
|
||||
done=True,
|
||||
done_reason="error",
|
||||
)
|
||||
yield f"{error_response.model_dump_json(exclude_none=True)}\n"
|
||||
return
|
||||
|
||||
case ToolCallChunk():
|
||||
prompt_eval, eval_count = _get_usage(chunk)
|
||||
response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=_build_tool_calls(chunk),
|
||||
thinking="".join(thinking_parts) if thinking_parts else None,
|
||||
),
|
||||
done=True,
|
||||
done_reason="tool_call",
|
||||
prompt_eval_count=prompt_eval,
|
||||
eval_count=eval_count,
|
||||
)
|
||||
yield f"{response.model_dump_json(exclude_none=True)}\n"
|
||||
return
|
||||
|
||||
case TokenChunk():
|
||||
done = chunk.finish_reason is not None
|
||||
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(
|
||||
role="assistant", content="", thinking=chunk.text
|
||||
),
|
||||
done=False,
|
||||
)
|
||||
yield f"{response.model_dump_json(exclude_none=True)}\n"
|
||||
elif done:
|
||||
prompt_eval, eval_count = _get_usage(chunk)
|
||||
response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(
|
||||
role="assistant",
|
||||
content=chunk.text,
|
||||
),
|
||||
done=True,
|
||||
done_reason=_map_done_reason(chunk.finish_reason),
|
||||
prompt_eval_count=prompt_eval,
|
||||
eval_count=eval_count,
|
||||
)
|
||||
yield f"{response.model_dump_json(exclude_none=True)}\n"
|
||||
else:
|
||||
response = OllamaChatResponse(
|
||||
model=str(chunk.model),
|
||||
message=OllamaMessage(role="assistant", content=chunk.text),
|
||||
done=False,
|
||||
)
|
||||
yield f"{response.model_dump_json(exclude_none=True)}\n"
|
||||
|
||||
if done:
|
||||
return
|
||||
|
||||
|
||||
async def collect_ollama_chat_response(
|
||||
_command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[
|
||||
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
|
||||
],
|
||||
) -> AsyncGenerator[str]:
|
||||
"""Collect streaming chunks into a single non-streaming Ollama response.
|
||||
|
||||
Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI
|
||||
StreamingResponse cancellation handling.
|
||||
"""
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_calls: list[OllamaToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: str | None = None
|
||||
prompt_eval_count: int | None = None
|
||||
eval_count: int | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
match chunk:
|
||||
case PrefillProgressChunk():
|
||||
continue
|
||||
|
||||
case ErrorChunk():
|
||||
raise ValueError(chunk.error_message or "Internal server error")
|
||||
|
||||
case TokenChunk():
|
||||
if model is None:
|
||||
model = str(chunk.model)
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
else:
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
prompt_eval_count, eval_count = _get_usage(chunk)
|
||||
|
||||
case ToolCallChunk():
|
||||
if model is None:
|
||||
model = str(chunk.model)
|
||||
tool_calls.extend(_build_tool_calls(chunk))
|
||||
finish_reason = chunk.finish_reason
|
||||
prompt_eval_count, eval_count = _get_usage(chunk)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_thinking = "".join(thinking_parts) if thinking_parts else None
|
||||
assert model is not None
|
||||
|
||||
yield OllamaChatResponse(
|
||||
model=model,
|
||||
message=OllamaMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
thinking=combined_thinking,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
done=True,
|
||||
done_reason=_map_done_reason(finish_reason),
|
||||
prompt_eval_count=prompt_eval_count,
|
||||
eval_count=eval_count,
|
||||
).model_dump_json(exclude_none=True)
|
||||
return
|
||||
|
||||
|
||||
# ── /api/generate ──
|
||||
|
||||
|
||||
def ollama_generate_request_to_text_generation(
|
||||
request: OllamaGenerateRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
"""Convert Ollama generate request to exo's internal text generation format."""
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
if request.system:
|
||||
chat_template_messages.append({"role": "system", "content": request.system})
|
||||
chat_template_messages.append({"role": "user", "content": request.prompt})
|
||||
|
||||
options = request.options
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=[InputMessage(role="user", content=request.prompt)],
|
||||
instructions=request.system,
|
||||
max_output_tokens=options.num_predict if options else None,
|
||||
temperature=options.temperature if options else None,
|
||||
top_p=options.top_p if options else None,
|
||||
top_k=options.top_k if options else None,
|
||||
stop=options.stop if options else None,
|
||||
seed=options.seed if options else None,
|
||||
stream=request.stream,
|
||||
enable_thinking=request.think,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
)


async def generate_ollama_generate_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses for /api/generate in Ollama NDJSON format."""
thinking_parts: list[str] = []

async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue

case ErrorChunk():
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="error",
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return

case ToolCallChunk():
# generate endpoint doesn't support tools; emit as done
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="stop",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return

case TokenChunk():
done = chunk.finish_reason is not None

if chunk.is_thinking:
thinking_parts.append(chunk.text)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
thinking=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
else:
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"

if done:
return


async def collect_ollama_generate_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect chunks into a single non-streaming /api/generate response."""
text_parts: list[str] = []
thinking_parts: list[str] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None

async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)

assert model is not None
yield OllamaGenerateResponse(
model=model,
response="".join(text_parts),
thinking="".join(thinking_parts) if thinking_parts else None,
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return
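# Illustrative sketch (not part of the diff): consuming the NDJSON stream that
# generate_ollama_generate_stream produces, here via httpx against a locally
# running exo API. The URL and port are assumptions for the example.
import json
import httpx

async def stream_generate() -> None:
    payload = {"model": "llama3", "prompt": "Why is the sky blue?", "stream": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:52415/api/generate", json=payload
        ) as r:
            async for line in r.aiter_lines():
                if not line:
                    continue
                chunk = json.loads(line)  # one OllamaGenerateResponse per line
                print(chunk.get("thinking") or chunk.get("response", ""), end="")
                if chunk.get("done"):
                    break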
|
||||
@@ -29,12 +29,6 @@ from exo.shared.types.openai_responses import (
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
ResponseReasoningSummaryPartAddedEvent,
|
||||
ResponseReasoningSummaryPartDoneEvent,
|
||||
ResponseReasoningSummaryText,
|
||||
ResponseReasoningSummaryTextDeltaEvent,
|
||||
ResponseReasoningSummaryTextDoneEvent,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponsesStreamEvent,
|
||||
@@ -147,9 +141,7 @@ async def collect_responses_response(
|
||||
"""Collect all token chunks and return a single ResponsesResponse."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
reasoning_id = f"rs_{command_id}"
|
||||
accumulated_text = ""
|
||||
thinking_parts: list[str] = []
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_usage: Usage | None = None
|
||||
error_message: str | None = None
|
||||
@@ -176,10 +168,6 @@ async def collect_responses_response(
|
||||
)
|
||||
continue
|
||||
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
|
||||
if error_message is not None:
|
||||
@@ -194,21 +182,13 @@ async def collect_responses_response(
|
||||
total_tokens=last_usage.total_tokens,
|
||||
)
|
||||
|
||||
output: list[ResponseItem] = []
|
||||
if thinking_parts:
|
||||
output.append(
|
||||
ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text="".join(thinking_parts))],
|
||||
)
|
||||
)
|
||||
output.append(
|
||||
output: list[ResponseItem] = [
|
||||
ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
)
|
||||
]
|
||||
output.extend(function_call_items)
|
||||
|
||||
yield ResponsesResponse(
|
||||
@@ -232,7 +212,6 @@ async def generate_responses_stream(
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
reasoning_id = f"rs_{command_id}"
|
||||
seq = count(1)
|
||||
|
||||
# response.created
|
||||
@@ -254,17 +233,32 @@ async def generate_responses_stream(
|
||||
)
|
||||
yield _format_sse(in_progress_event)
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq), output_index=0, item=initial_item
|
||||
)
|
||||
yield _format_sse(item_added)
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield _format_sse(part_added)
|
||||
|
||||
accumulated_text = ""
|
||||
accumulated_thinking = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_usage: Usage | None = None
|
||||
next_output_index = 0
|
||||
|
||||
# Track dynamic block creation
|
||||
reasoning_started = False
|
||||
reasoning_output_index = -1
|
||||
message_started = False
|
||||
message_output_index = -1
|
||||
next_output_index = 1 # message item is at 0
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
@@ -333,184 +327,23 @@ async def generate_responses_stream(
|
||||
next_output_index += 1
|
||||
continue
|
||||
|
||||
if chunk.is_thinking:
|
||||
# Start reasoning block on first thinking token
|
||||
if not reasoning_started:
|
||||
reasoning_started = True
|
||||
reasoning_output_index = next_output_index
|
||||
next_output_index += 1
|
||||
|
||||
# response.output_item.added for reasoning
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[],
|
||||
status="in_progress",
|
||||
)
|
||||
rs_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=reasoning_output_index,
|
||||
item=reasoning_item,
|
||||
)
|
||||
yield _format_sse(rs_added)
|
||||
|
||||
# response.reasoning_summary_part.added
|
||||
part_added = ResponseReasoningSummaryPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
part=ResponseReasoningSummaryText(text=""),
|
||||
)
|
||||
yield _format_sse(part_added)
|
||||
|
||||
accumulated_thinking += chunk.text
|
||||
|
||||
# response.reasoning_summary_text.delta
|
||||
rs_delta = ResponseReasoningSummaryTextDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield _format_sse(rs_delta)
|
||||
continue
|
||||
|
||||
# Close reasoning block when transitioning to text
|
||||
if reasoning_started and not message_started:
|
||||
# response.reasoning_summary_text.done
|
||||
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
text=accumulated_thinking,
|
||||
)
|
||||
yield _format_sse(rs_text_done)
|
||||
|
||||
# response.reasoning_summary_part.done
|
||||
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
part=ResponseReasoningSummaryText(text=accumulated_thinking),
|
||||
)
|
||||
yield _format_sse(rs_part_done)
|
||||
|
||||
# response.output_item.done for reasoning
|
||||
rs_item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=reasoning_output_index,
|
||||
item=ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||
),
|
||||
)
|
||||
yield _format_sse(rs_item_done)
|
||||
|
||||
# Start message block on first text token
|
||||
if not message_started:
|
||||
message_started = True
|
||||
message_output_index = next_output_index
|
||||
next_output_index += 1
|
||||
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=message_output_index,
|
||||
item=initial_item,
|
||||
)
|
||||
yield _format_sse(item_added)
|
||||
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield _format_sse(part_added)
|
||||
|
||||
accumulated_text += chunk.text
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=message_output_index,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield _format_sse(delta_event)
|
||||
|
||||
# Close reasoning block if it was never followed by text
|
||||
if reasoning_started and not message_started:
|
||||
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
text=accumulated_thinking,
|
||||
)
|
||||
yield _format_sse(rs_text_done)
|
||||
|
||||
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
part=ResponseReasoningSummaryText(text=accumulated_thinking),
|
||||
)
|
||||
yield _format_sse(rs_part_done)
|
||||
|
||||
rs_item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=reasoning_output_index,
|
||||
item=ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||
),
|
||||
)
|
||||
yield _format_sse(rs_item_done)
|
||||
|
||||
# If no message block was started, create one now (empty text)
|
||||
if not message_started:
|
||||
message_output_index = next_output_index
|
||||
next_output_index += 1
|
||||
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=message_output_index,
|
||||
item=initial_item,
|
||||
)
|
||||
yield _format_sse(item_added)
|
||||
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added_evt = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield _format_sse(part_added_evt)
|
||||
|
||||
# response.output_text.done
|
||||
text_done = ResponseTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=message_output_index,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
text=accumulated_text,
|
||||
)
|
||||
@@ -521,7 +354,7 @@ async def generate_responses_stream(
|
||||
part_done = ResponseContentPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=message_output_index,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=final_part,
|
||||
)
|
||||
@@ -534,9 +367,7 @@ async def generate_responses_stream(
|
||||
status="completed",
|
||||
)
|
||||
item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=message_output_index,
|
||||
item=final_message_item,
|
||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
||||
)
|
||||
yield _format_sse(item_done)
|
||||
|
||||
@@ -550,15 +381,7 @@ async def generate_responses_stream(
|
||||
)
|
||||
|
||||
# response.completed
|
||||
output: list[ResponseItem] = []
|
||||
if reasoning_started:
|
||||
output.append(
|
||||
ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||
)
|
||||
)
|
||||
output.append(final_message_item)
|
||||
output: list[ResponseItem] = [final_message_item]
|
||||
output.extend(function_call_items)
|
||||
final_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
|
||||
@@ -32,14 +32,6 @@ from exo.master.adapters.claude import (
|
||||
collect_claude_response,
|
||||
generate_claude_stream,
|
||||
)
|
||||
from exo.master.adapters.ollama import (
|
||||
collect_ollama_chat_response,
|
||||
collect_ollama_generate_response,
|
||||
generate_ollama_chat_stream,
|
||||
generate_ollama_generate_stream,
|
||||
ollama_generate_request_to_text_generation,
|
||||
ollama_request_to_text_generation,
|
||||
)
|
||||
from exo.master.adapters.responses import (
|
||||
collect_responses_response,
|
||||
generate_responses_stream,
|
||||
@@ -140,28 +132,16 @@ from exo.shared.types.commands import (
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
GlobalForwarderEvent,
|
||||
IndexedEvent,
|
||||
PrefillProgress,
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.ollama_api import (
|
||||
OllamaChatRequest,
|
||||
OllamaChatResponse,
|
||||
OllamaGenerateRequest,
|
||||
OllamaGenerateResponse,
|
||||
OllamaModelDetails,
|
||||
OllamaModelTag,
|
||||
OllamaPsModel,
|
||||
OllamaPsResponse,
|
||||
OllamaShowRequest,
|
||||
OllamaShowResponse,
|
||||
OllamaTagsResponse,
|
||||
)
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
@@ -197,8 +177,7 @@ class API:
|
||||
session_id: SessionId,
|
||||
*,
|
||||
port: int,
|
||||
# Ideally this would be a MasterForwarderEvent but type system says no :(
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
global_event_receiver: Receiver[GlobalForwarderEvent],
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
# This lets us pause the API if an election is running
|
||||
@@ -206,6 +185,7 @@ class API:
|
||||
) -> None:
|
||||
self.state = State()
|
||||
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
||||
self._system_id = SystemId()
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
@@ -257,6 +237,7 @@ class API:
|
||||
self._event_log.close()
|
||||
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
|
||||
self.state = State()
|
||||
self._system_id = SystemId()
|
||||
self.session_id = new_session_id
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._text_generation_queues = {}
|
||||
@@ -321,20 +302,6 @@ class API:
|
||||
self.app.get("/images/{image_id}")(self.get_image)
|
||||
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
|
||||
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
|
||||
# Ollama API — health checks (must be before static files mount)
|
||||
self.app.head("/")(self._ollama_root)
|
||||
self.app.head("/api/version")(self.ollama_version)
|
||||
# Ollama API
|
||||
self.app.post("/api/chat", response_model=None)(self.ollama_chat)
|
||||
self.app.post("/api/api/chat", response_model=None)(self.ollama_chat)
|
||||
self.app.post("/api/v1/chat", response_model=None)(self.ollama_chat)
|
||||
self.app.post("/api/generate", response_model=None)(self.ollama_generate)
|
||||
self.app.get("/api/tags")(self.ollama_tags)
|
||||
self.app.get("/api/api/tags")(self.ollama_tags)
|
||||
self.app.get("/api/v1/tags")(self.ollama_tags)
|
||||
self.app.post("/api/show")(self.ollama_show)
|
||||
self.app.get("/api/ps")(self.ollama_ps)
|
||||
self.app.get("/api/version")(self.ollama_version)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(self.stream_events)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
@@ -588,7 +555,7 @@ class API:
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
ForwarderCommand(origin=self._system_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -936,7 +903,7 @@ class API:
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
ForwarderCommand(origin=self._system_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -1022,7 +989,7 @@ class API:
|
||||
command = TaskCancelled(cancelled_command_id=command_id)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
ForwarderCommand(origin=self._system_id, command=command)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
@@ -1328,158 +1295,6 @@ class API:
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def _ollama_root(self) -> JSONResponse:
|
||||
"""Respond to HEAD / from Ollama CLI connectivity checks."""
|
||||
return JSONResponse(content="Ollama is running")
|
||||
|
||||
async def ollama_chat(
|
||||
self, request: Request
|
||||
) -> OllamaChatResponse | StreamingResponse:
|
||||
"""Ollama Chat API — accepts JSON regardless of Content-Type."""
|
||||
body = await request.body()
|
||||
payload = OllamaChatRequest.model_validate_json(body)
|
||||
task_params = ollama_request_to_text_generation(payload)
|
||||
resolved_model = await self._resolve_and_validate_text_model(
|
||||
ModelId(task_params.model)
|
||||
)
|
||||
task_params = task_params.model_copy(update={"model": resolved_model})
|
||||
|
||||
command = TextGeneration(task_params=task_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_ollama_chat_stream(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "close",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_ollama_chat_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def ollama_generate(
|
||||
self, request: Request
|
||||
) -> OllamaGenerateResponse | StreamingResponse:
|
||||
"""Ollama Generate API — accepts JSON regardless of Content-Type."""
|
||||
body = await request.body()
|
||||
payload = OllamaGenerateRequest.model_validate_json(body)
|
||||
task_params = ollama_generate_request_to_text_generation(payload)
|
||||
resolved_model = await self._resolve_and_validate_text_model(
|
||||
ModelId(task_params.model)
|
||||
)
|
||||
task_params = task_params.model_copy(update={"model": resolved_model})
|
||||
|
||||
command = TextGeneration(task_params=task_params)
|
||||
await self._send(command)
|
||||
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
generate_ollama_generate_stream(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "close",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
collect_ollama_generate_response(
|
||||
command.command_id,
|
||||
self._token_chunk_stream(command.command_id),
|
||||
),
|
||||
media_type="application/json",
|
||||
)
|
||||
|
||||
async def ollama_tags(self) -> OllamaTagsResponse:
|
||||
"""Returns list of models in Ollama tags format. We return the downloaded ones only."""
|
||||
|
||||
def none_if_empty(value: str) -> str | None:
|
||||
return value or None
|
||||
|
||||
downloaded_model_ids: set[str] = set()
|
||||
for node_downloads in self.state.downloads.values():
|
||||
for dl in node_downloads:
|
||||
if isinstance(dl, DownloadCompleted):
|
||||
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
|
||||
|
||||
cards = [
|
||||
c for c in await get_model_cards() if c.model_id in downloaded_model_ids
|
||||
]
|
||||
|
||||
now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
return OllamaTagsResponse(
|
||||
models=[
|
||||
OllamaModelTag(
|
||||
name=str(card.model_id),
|
||||
model=str(card.model_id),
|
||||
modified_at=now,
|
||||
size=card.storage_size.in_bytes,
|
||||
digest="sha256:000000000000",
|
||||
details=OllamaModelDetails(
|
||||
family=none_if_empty(card.family),
|
||||
quantization_level=none_if_empty(card.quantization),
|
||||
),
|
||||
)
|
||||
for card in cards
|
||||
]
|
||||
)
|
||||
|
||||
async def ollama_show(self, request: Request) -> OllamaShowResponse:
|
||||
"""Returns model information in Ollama show format."""
|
||||
body = await request.body()
|
||||
payload = OllamaShowRequest.model_validate_json(body)
|
||||
try:
|
||||
card = await ModelCard.load(ModelId(payload.name))
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model not found: {payload.name}"
|
||||
) from exc
|
||||
|
||||
return OllamaShowResponse(
|
||||
details=OllamaModelDetails(
|
||||
family=card.family or None,
|
||||
quantization_level=card.quantization or None,
|
||||
),
|
||||
)
|
||||
|
||||
async def ollama_ps(self) -> OllamaPsResponse:
|
||||
"""Returns list of running models (active instances)."""
|
||||
models: list[OllamaPsModel] = []
|
||||
seen: set[str] = set()
|
||||
for instance in self.state.instances.values():
|
||||
model_id = str(instance.shard_assignments.model_id)
|
||||
if model_id in seen:
|
||||
continue
|
||||
seen.add(model_id)
|
||||
models.append(
|
||||
OllamaPsModel(
|
||||
name=model_id,
|
||||
model=model_id,
|
||||
size=0,
|
||||
)
|
||||
)
|
||||
return OllamaPsResponse(models=models)
|
||||
|
||||
async def ollama_version(self) -> dict[str, str]:
|
||||
"""Returns version information for Ollama API compatibility."""
|
||||
return {"version": "exo v1.0"}
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
total_available = Memory()
|
||||
@@ -1509,7 +1324,7 @@ class API:
|
||||
name=card.model_id.short(),
|
||||
description="",
|
||||
tags=[],
|
||||
storage_size_megabytes=card.storage_size.in_mb,
|
||||
storage_size_megabytes=int(card.storage_size.in_mb),
|
||||
supports_tensor=card.supports_tensor,
|
||||
tasks=[task.value for task in card.tasks],
|
||||
is_custom=is_custom_card(card.model_id),
|
||||
@@ -1615,6 +1430,8 @@ class API:
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.session != self.session_id:
|
||||
continue
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
continue
|
||||
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
||||
@@ -1641,6 +1458,22 @@ class API:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
|
||||
elif isinstance(event, PrefillProgress):
|
||||
if queue := self._text_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
try:
|
||||
await queue.send(
|
||||
PrefillProgressChunk(
|
||||
model=event.model,
|
||||
processed_tokens=event.processed_tokens,
|
||||
total_tokens=event.total_tokens,
|
||||
)
|
||||
)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
|
||||
if isinstance(event, TracesMerged):
|
||||
self._save_merged_trace(event)
|
||||
|
||||
@@ -1678,12 +1511,12 @@ class API:
|
||||
while self.paused:
|
||||
await self.paused_ev.wait()
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
ForwarderCommand(origin=self._system_id, command=command)
|
||||
)
|
||||
|
||||
async def _send_download(self, command: DownloadCommand):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(origin=self.node_id, command=command)
|
||||
ForwarderDownloadCommand(origin=self._system_id, command=command)
|
||||
)
|
||||
|
||||
async def start_download(
|
||||
|
||||
@@ -29,13 +29,14 @@ from exo.shared.types.commands import (
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
ForwarderEvent,
GlobalForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
LocalForwarderEvent,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
@@ -71,8 +72,8 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
local_event_receiver: Receiver[ForwarderEvent],
global_event_sender: Sender[ForwarderEvent],
local_event_receiver: Receiver[LocalForwarderEvent],
global_event_sender: Sender[GlobalForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
@@ -87,10 +88,11 @@ class Master:
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[ForwarderEvent] = (
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
@@ -288,7 +290,7 @@ class Master:
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
origin=self._system_id, command=cmd
)
)
generated_events.extend(transition_events)
@@ -414,8 +416,8 @@ class Master:
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
LocalForwarderEvent(
origin=self._system_id,
origin_idx=local_index,
session=self.session_id,
event=event,
@@ -427,7 +429,7 @@ class Master:
async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly
await self.global_event_sender.send(
ForwarderEvent(
GlobalForwarderEvent(
origin=self.node_id,
origin_idx=event.idx,
session=self.session_id,

@@ -102,21 +102,22 @@ def _allocate_and_validate_layers(
layer_allocations = allocate_layers_proportionally(
total_layers=model_card.n_layers,
memory_fractions=[
node_memory[node_id].ram_available / total_memory for node_id in node_ids
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
],
)

total_storage = model_card.storage_size
total_storage_bytes = model_card.storage_size.in_bytes
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available
required_memory = (total_storage_bytes * node_layers) // total_layers
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
f"but only has {available_memory.in_gb:.2f} GB available"
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)

return layer_allocations
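# Illustrative sketch (not part of the diff): the per-node memory check above
# with made-up numbers. allocate_layers_proportionally is assumed to split
# total_layers roughly in proportion to the memory fractions.
total_layers = 32
total_storage_bytes = 16 * 1024**3          # 16 GiB model
layer_allocations = [20, 12]                # e.g. a 62.5% / 37.5% split
available = [12 * 1024**3, 5 * 1024**3]     # per-node RAM available, in bytes
for node_layers, avail in zip(layer_allocations, available):
    required = (total_storage_bytes * node_layers) // total_layers
    # node 0: 10 GiB required <= 12 GiB available -> ok
    # node 1:  6 GiB required >   5 GiB available -> would raise ValueError
    print(required <= avail)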
@@ -341,7 +342,6 @@ def _find_ip_prioritised(
other_node_id: NodeId,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
ring: bool,
) -> str | None:
"""Find an IP address between nodes with prioritization.

@@ -354,27 +354,13 @@ def _find_ip_prioritised(
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}

# Ring should prioritise fastest connection. As a best-effort, we prioritise TB.
# TODO: Profile and get actual connection speeds.
if ring:
priority = {
"thunderbolt": 0,
"maybe_ethernet": 1,
"ethernet": 2,
"wifi": 3,
"unknown": 4,
}

# RDMA prefers ethernet coordinator
else:
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
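# Illustrative sketch (not part of the diff): the min-by-priority selection
# above with made-up addresses and interface types.
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "maybe_ethernet": 3, "thunderbolt": 4}
ip_to_type = {"10.0.0.5": "wifi", "192.168.1.7": "ethernet", "169.254.3.2": "thunderbolt"}
ips = list(ip_to_type)
best = min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
assert best == "192.168.1.7"  # ethernet wins under the non-ring ordering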


@@ -414,7 +400,7 @@ def get_mlx_ring_hosts_by_node(
continue

connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_network, ring=True
node_id, other_node_id, cycle_digraph, node_network
)
if connection_ip is None:
raise ValueError(
@@ -445,9 +431,7 @@ def get_mlx_jaccl_coordinators(
if n == coordinator:
return "0.0.0.0"

ip = _find_ip_prioritised(
n, coordinator, cycle_digraph, node_network, ring=False
)
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
if ip is not None:
return ip


@@ -261,7 +261,7 @@ class TestGenerateClaudeStreamToolUse:

parsed = _parse_sse_events(events)

# Two tool block starts (at indices 0 and 1 — no text block when only tools)
# Two tool block starts (at indices 1 and 2)
tool_starts = [
e
for e in parsed
@@ -270,11 +270,12 @@ class TestGenerateClaudeStreamToolUse:
== "tool_use"
]
assert len(tool_starts) == 2
assert tool_starts[0]["index"] == 0
assert tool_starts[1]["index"] == 1
assert tool_starts[0]["index"] == 1
assert tool_starts[1]["index"] == 2

# Two tool block stops (at indices 0 and 1)
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
stop_indices = [e["index"] for e in block_stops]
assert 0 in stop_indices
assert 1 in stop_indices
assert 2 in stop_indices

@@ -15,11 +15,12 @@ from exo.shared.types.commands import (
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ForwarderEvent,
GlobalForwarderEvent,
IndexedEvent,
InstanceCreated,
LocalForwarderEvent,
NodeGatheredInfo,
TaskCreated,
)
@@ -45,9 +46,9 @@ async def test_master():
node_id = NodeId(keypair.to_node_id())
session_id = SessionId(master_node_id=node_id, election_clock=0)

ge_sender, global_event_receiver = channel[ForwarderEvent]()
ge_sender, global_event_receiver = channel[GlobalForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
local_event_sender, le_receiver = channel[LocalForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()

all_events: list[IndexedEvent] = []
@@ -75,13 +76,12 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)

sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(
ForwarderEvent(
LocalForwarderEvent(
origin_idx=0,
origin=sender_node_id,
origin=SystemId("Worker"),
session=session_id,
event=(
NodeGatheredInfo(
@@ -108,7 +108,7 @@ async def test_master():
logger.info("inject a CreateInstance Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
origin=SystemId("API"),
command=(
PlaceInstance(
command_id=CommandId(),
@@ -133,7 +133,7 @@ async def test_master():
logger.info("inject a TextGeneration Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
origin=SystemId("API"),
command=(
TextGeneration(
command_id=CommandId(),

@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size = Memory.from_bytes(
sum(available_memory)
model_card.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()

@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size = Memory.from_bytes(1500)
model_card.storage_size.in_bytes = 1500

node_a = NodeId()
node_b = NodeId()

@@ -5,7 +5,8 @@ from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
ForwarderEvent,
GlobalForwarderEvent,
LocalForwarderEvent,
)
from exo.utils.pydantic_ext import CamelCaseModel

@@ -36,8 +37,8 @@ class TypedTopic[T: CamelCaseModel]:
return self.model_type.model_validate_json(b.decode("utf-8"))


GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, ForwarderEvent)
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, ForwarderEvent)
GLOBAL_EVENTS = TypedTopic("global_events", PublishPolicy.Always, GlobalForwarderEvent)
LOCAL_EVENTS = TypedTopic("local_events", PublishPolicy.Always, LocalForwarderEvent)
COMMANDS = TypedTopic("commands", PublishPolicy.Always, ForwarderCommand)
ELECTION_MESSAGES = TypedTopic(
"election_messages", PublishPolicy.Always, ElectionMessage

@@ -15,6 +15,7 @@ from exo.shared.types.events import (
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -64,6 +65,7 @@ def event_apply(event: Event, state: State) -> State:
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
| TracesCollected()
| TracesMerged()
): # Pass-through events that don't modify state

@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total=Memory(),
total_bytes=Memory(),
)

new_state = apply_node_download_progress(
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total=Memory(),
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total=Memory(),
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})


@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.utils.channels import channel

# ======= #
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
# Pump local commands so our commands_seen is high before the round starts
for _ in range(50):
await co_tx.send(
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
)

# Trigger a round at clock=1 with a peer of equal seniority but fewer commands

@@ -77,7 +77,7 @@ class ChatCompletionMessage(BaseModel):
content: (
str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
) = None
reasoning_content: str | None = None
thinking: str | None = None # Added for GPT-OSS harmony format support
name: str | None = None
tool_calls: list[ToolCall] | None = None
tool_call_id: str | None = None

@@ -27,7 +27,6 @@ class TokenChunk(BaseChunk):
stats: GenerationStats | None = None
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
is_thinking: bool = False


class ErrorChunk(BaseChunk):

@@ -47,14 +47,6 @@ class ClaudeImageBlock(BaseModel, frozen=True):
|
||||
source: ClaudeImageSource
|
||||
|
||||
|
||||
class ClaudeThinkingBlock(BaseModel, frozen=True):
|
||||
"""Thinking content block in Claude Messages API."""
|
||||
|
||||
type: Literal["thinking"] = "thinking"
|
||||
thinking: str
|
||||
signature: str | None = None
|
||||
|
||||
|
||||
class ClaudeToolUseBlock(BaseModel, frozen=True):
|
||||
"""Tool use content block in Claude Messages API."""
|
||||
|
||||
@@ -74,17 +66,11 @@ class ClaudeToolResultBlock(BaseModel, frozen=True):
|
||||
cache_control: dict[str, str] | None = None
|
||||
|
||||
|
||||
ClaudeContentBlock = (
|
||||
ClaudeTextBlock | ClaudeImageBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
|
||||
)
|
||||
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock
|
||||
|
||||
# Input content blocks can also include tool_result (sent by user after tool_use)
|
||||
ClaudeInputContentBlock = (
|
||||
ClaudeTextBlock
|
||||
| ClaudeImageBlock
|
||||
| ClaudeThinkingBlock
|
||||
| ClaudeToolUseBlock
|
||||
| ClaudeToolResultBlock
|
||||
ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock | ClaudeToolResultBlock
|
||||
)
|
||||
|
||||
|
||||
@@ -96,11 +82,6 @@ class ClaudeMessage(BaseModel, frozen=True):
|
||||
content: str | list[ClaudeInputContentBlock]
|
||||
|
||||
|
||||
class ClaudeThinkingConfig(BaseModel, frozen=True):
|
||||
type: Literal["enabled", "disabled", "adaptive"]
|
||||
budget_tokens: int | None = None
|
||||
|
||||
|
||||
class ClaudeMessagesRequest(BaseModel):
|
||||
"""Request body for Claude Messages API."""
|
||||
|
||||
@@ -115,7 +96,6 @@ class ClaudeMessagesRequest(BaseModel):
|
||||
top_k: int | None = None
|
||||
tools: list[ClaudeToolDefinition] | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
thinking: ClaudeThinkingConfig | None = None
|
||||
|
||||
|
||||
# Response types
|
||||
@@ -165,7 +145,7 @@ class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
|
||||
|
||||
type: Literal["content_block_start"] = "content_block_start"
|
||||
index: int
|
||||
content_block: ClaudeTextBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
|
||||
content_block: ClaudeTextBlock | ClaudeToolUseBlock
|
||||
|
||||
|
||||
class ClaudeTextDelta(BaseModel, frozen=True):
|
||||
@@ -175,13 +155,6 @@ class ClaudeTextDelta(BaseModel, frozen=True):
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeThinkingDelta(BaseModel, frozen=True):
|
||||
"""Delta for thinking content block."""
|
||||
|
||||
type: Literal["thinking_delta"] = "thinking_delta"
|
||||
thinking: str
|
||||
|
||||
|
||||
class ClaudeInputJsonDelta(BaseModel, frozen=True):
|
||||
"""Delta for tool use input JSON content block."""
|
||||
|
||||
@@ -194,7 +167,7 @@ class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
|
||||
|
||||
type: Literal["content_block_delta"] = "content_block_delta"
|
||||
index: int
|
||||
delta: ClaudeTextDelta | ClaudeThinkingDelta | ClaudeInputJsonDelta
|
||||
delta: ClaudeTextDelta | ClaudeInputJsonDelta
|
||||
|
||||
|
||||
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
|
||||
|
||||
@@ -6,7 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -100,10 +100,10 @@ Command = (


class ForwarderCommand(CamelCaseModel):
origin: NodeId
origin: SystemId
command: Command


class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
origin: SystemId
command: DownloadCommand

@@ -25,6 +25,10 @@ class NodeId(Id):
pass


class SystemId(Id):
pass


class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")

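# Illustrative sketch (not part of the diff): normalize() turns a Hugging Face
# style model id into a filesystem-safe name. ModelId is assumed to behave as a
# str subclass, as the NodeId("node-1") usages elsewhere in this comparison suggest.
assert ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit").normalize() == "mlx-community--Llama-3.2-1B-Instruct-4bit"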
@@ -5,7 +5,7 @@ from pydantic import Field

from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId, SystemId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -102,6 +102,13 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk


class PrefillProgress(BaseEvent):
command_id: CommandId
model: ModelId
processed_tokens: int
total_tokens: int


class TopologyEdgeCreated(BaseEvent):
conn: Connection

@@ -148,6 +155,7 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected
@@ -162,10 +170,19 @@ class IndexedEvent(CamelCaseModel):
event: Event


class ForwarderEvent(CamelCaseModel):
class GlobalForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""

origin_idx: int = Field(ge=0)
origin: NodeId
session: SessionId
event: Event


class LocalForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""

origin_idx: int = Field(ge=0)
origin: SystemId
session: SessionId
event: Event

@@ -1,10 +1,10 @@
|
||||
from math import ceil
|
||||
from typing import Self, overload
|
||||
from typing import Self
|
||||
|
||||
from exo.utils.pydantic_ext import FrozenModel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
class Memory(FrozenModel):
|
||||
class Memory(CamelCaseModel):
|
||||
in_bytes: int = 0
|
||||
|
||||
@classmethod
|
||||
@@ -33,22 +33,12 @@ class Memory(FrozenModel):
|
||||
return cls(in_bytes=round(val * 1024))
|
||||
|
||||
@property
|
||||
def in_mb(self) -> int:
|
||||
"""The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
|
||||
return round(self.in_bytes / (1024**2))
|
||||
|
||||
@in_mb.setter
|
||||
def in_mb(self, val: int):
|
||||
"""Set the megabytes for this memory."""
|
||||
self.in_bytes = val * (1024**2)
|
||||
|
||||
@property
|
||||
def in_float_mb(self) -> float:
|
||||
"""The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
|
||||
def in_mb(self) -> float:
|
||||
"""The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
|
||||
return self.in_bytes / (1024**2)
|
||||
|
||||
@in_float_mb.setter
|
||||
def in_float_mb(self, val: float):
|
||||
@in_mb.setter
|
||||
def in_mb(self, val: float):
|
||||
"""Set the megabytes for this memory, rounded to the nearest byte."""
|
||||
self.in_bytes = round(val * (1024**2))
|
||||
|
||||
@@ -67,85 +57,17 @@ class Memory(FrozenModel):
|
||||
"""The approximate gigabytes this memory represents."""
|
||||
return self.in_bytes / (1024**3)
|
||||
|
||||
def __add__(self, other: object) -> "Memory":
|
||||
if isinstance(other, Memory):
|
||||
return Memory.from_bytes(self.in_bytes + other.in_bytes)
|
||||
return NotImplemented
|
||||
def __add__(self, other: "Memory") -> "Memory":
|
||||
return Memory.from_bytes(self.in_bytes + other.in_bytes)
|
||||
|
||||
def __radd__(self, other: object) -> "Memory":
|
||||
if other == 0:
|
||||
return self
|
||||
return NotImplemented
|
||||
def __lt__(self, other: Self) -> bool:
|
||||
return self.in_bytes < other.in_bytes
|
||||
|
||||
def __sub__(self, other: object) -> "Memory":
|
||||
if isinstance(other, Memory):
|
||||
return Memory.from_bytes(self.in_bytes - other.in_bytes)
|
||||
return NotImplemented
|
||||
def __le__(self, other: Self) -> bool:
|
||||
return self.in_bytes <= other.in_bytes
|
||||
|
||||
def __mul__(self, other: int | float):
|
||||
return Memory.from_bytes(round(self.in_bytes * other))
|
||||
def __gt__(self, other: Self) -> bool:
|
||||
return self.in_bytes > other.in_bytes
|
||||
|
||||
def __rmul__(self, other: int | float):
|
||||
return self * other
|
||||
|
||||
@overload
|
||||
def __truediv__(self, other: "Memory") -> float: ...
|
||||
@overload
|
||||
def __truediv__(self, other: int) -> "Memory": ...
|
||||
@overload
|
||||
def __truediv__(self, other: float) -> "Memory": ...
|
||||
def __truediv__(self, other: object) -> "Memory | float":
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes / other.in_bytes
|
||||
if isinstance(other, (int, float)):
|
||||
return Memory.from_bytes(round(self.in_bytes / other))
|
||||
return NotImplemented
|
||||
|
||||
def __floordiv__(self, other: object) -> "Memory":
|
||||
if isinstance(other, (int, float)):
|
||||
return Memory.from_bytes(int(self.in_bytes // other))
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes < other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes <= other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes > other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes >= other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, Memory):
|
||||
return self.in_bytes == other.in_bytes
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Memory.from_bytes({self.in_bytes})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.in_gb > 2:
|
||||
val = self.in_gb
|
||||
unit = "GiB"
|
||||
elif self.in_mb > 2:
|
||||
val = self.in_mb
|
||||
unit = "MiB"
|
||||
elif self.in_kb > 3:
|
||||
val = self.in_kb
|
||||
unit = "KiB"
|
||||
else:
|
||||
val = self.in_bytes
|
||||
unit = "B"
|
||||
|
||||
return f"{val:.2f}".rstrip("0").rstrip(".") + f" {unit}"
|
||||
def __ge__(self, other: Self) -> bool:
|
||||
return self.in_bytes >= other.in_bytes
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
|
||||
# https://github.com/ollama/ollama/blob/main/docs/api.md
|
||||
|
||||
OllamaRole = Literal["system", "user", "assistant", "tool"]
|
||||
OllamaDoneReason = Literal["stop", "length", "tool_call", "error"]
|
||||
|
||||
|
||||
class OllamaToolFunction(BaseModel, frozen=True):
|
||||
name: str
|
||||
arguments: dict[str, Any] | str
|
||||
index: int | None = None
|
||||
|
||||
|
||||
class OllamaToolCall(BaseModel, frozen=True):
|
||||
id: str | None = None
|
||||
type: Literal["function"] | None = None
|
||||
function: OllamaToolFunction
|
||||
|
||||
|
||||
class OllamaMessage(BaseModel, frozen=True):
|
||||
role: OllamaRole
|
||||
content: str | None = None
|
||||
thinking: str | None = None
|
||||
tool_calls: list[OllamaToolCall] | None = None
|
||||
name: str | None = None
|
||||
tool_name: str | None = None
|
||||
images: list[str] | None = None
|
||||
|
||||
|
||||
class OllamaOptions(BaseModel, frozen=True):
|
||||
num_predict: int | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
seed: int | None = None
|
||||
|
||||
|
||||
class OllamaChatRequest(BaseModel, frozen=True):
|
||||
model: ModelId
|
||||
messages: list[OllamaMessage]
|
||||
stream: bool = True
|
||||
options: OllamaOptions | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
format: Literal["json"] | dict[str, Any] | None = None
|
||||
keep_alive: str | int | None = None
|
||||
think: bool | None = None
|
||||
|
||||
|
||||
class OllamaGenerateRequest(BaseModel, frozen=True):
|
||||
model: ModelId
|
||||
prompt: str = ""
|
||||
system: str | None = None
|
||||
stream: bool = True
|
||||
options: OllamaOptions | None = None
|
||||
format: Literal["json"] | dict[str, Any] | None = None
|
||||
keep_alive: str | int | None = None
|
||||
think: bool | None = None
|
||||
raw: bool = False
|
||||
|
||||
|
||||
class OllamaGenerateResponse(BaseModel, frozen=True, strict=True):
|
||||
model: str
|
||||
created_at: str = Field(
|
||||
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
)
|
||||
response: str
|
||||
thinking: str | None = None
|
||||
done: bool
|
||||
done_reason: OllamaDoneReason | None = None
|
||||
total_duration: int | None = None
|
||||
load_duration: int | None = None
|
||||
prompt_eval_count: int | None = None
|
||||
prompt_eval_duration: int | None = None
|
||||
eval_count: int | None = None
|
||||
eval_duration: int | None = None
|
||||
|
||||
|
||||
class OllamaShowRequest(BaseModel, frozen=True):
|
||||
name: str
|
||||
verbose: bool | None = None
|
||||
|
||||
|
||||
class OllamaChatResponse(BaseModel, frozen=True, strict=True):
|
||||
model: str
|
||||
created_at: str = Field(
|
||||
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
)
|
||||
message: OllamaMessage
|
||||
done: bool
|
||||
done_reason: OllamaDoneReason | None = None
|
||||
total_duration: int | None = None
|
||||
load_duration: int | None = None
|
||||
prompt_eval_count: int | None = None
|
||||
prompt_eval_duration: int | None = None
|
||||
eval_count: int | None = None
|
||||
eval_duration: int | None = None
|
||||
|
||||
|
||||
class OllamaModelDetails(BaseModel, frozen=True, strict=True):
|
||||
format: str | None = None
|
||||
family: str | None = None
|
||||
parameter_size: str | None = None
|
||||
quantization_level: str | None = None
|
||||
|
||||
|
||||
class OllamaModelTag(BaseModel, frozen=True, strict=True):
|
||||
name: str
|
||||
model: str | None = None
|
||||
modified_at: str | None = None
|
||||
size: int | None = None
|
||||
digest: str | None = None
|
||||
details: OllamaModelDetails | None = None
|
||||
|
||||
|
||||
class OllamaTagsResponse(BaseModel, frozen=True, strict=True):
|
||||
models: list[OllamaModelTag]
|
||||
|
||||
|
||||
class OllamaShowResponse(BaseModel, frozen=True, strict=True):
|
||||
modelfile: str | None = None
|
||||
parameters: str | None = None
|
||||
template: str | None = None
|
||||
details: OllamaModelDetails | None = None
|
||||
model_info: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class OllamaPsModel(BaseModel, frozen=True, strict=True):
|
||||
name: str
|
||||
model: str
|
||||
size: int
|
||||
digest: str | None = None
|
||||
details: OllamaModelDetails | None = None
|
||||
expires_at: str | None = None
|
||||
size_vram: int | None = None
|
||||
|
||||
|
||||
class OllamaPsResponse(BaseModel, frozen=True, strict=True):
|
||||
models: list[OllamaPsModel]
|
||||
@@ -145,23 +145,7 @@ class ResponseFunctionCallItem(BaseModel, frozen=True):
    status: ResponseStatus = "completed"


class ResponseReasoningSummaryText(BaseModel, frozen=True):
    """Summary text part in a reasoning output item."""

    type: Literal["summary_text"] = "summary_text"
    text: str


class ResponseReasoningItem(BaseModel, frozen=True):
    """Reasoning output item in response output array."""

    type: Literal["reasoning"] = "reasoning"
    id: str
    summary: list[ResponseReasoningSummaryText] = Field(default_factory=list)
    status: ResponseStatus = "completed"


ResponseItem = ResponseMessageItem | ResponseFunctionCallItem | ResponseReasoningItem
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem


class ResponseUsage(BaseModel, frozen=True):
@@ -289,58 +273,6 @@ class ResponseFunctionCallArgumentsDoneEvent(BaseModel, frozen=True):
    arguments: str


class ResponseReasoningSummaryPartAddedEvent(BaseModel, frozen=True):
    """Event sent when a reasoning summary part is added."""

    type: Literal["response.reasoning_summary_part.added"] = (
        "response.reasoning_summary_part.added"
    )
    sequence_number: int
    item_id: str
    output_index: int
    summary_index: int
    part: ResponseReasoningSummaryText


class ResponseReasoningSummaryTextDeltaEvent(BaseModel, frozen=True):
    """Event sent for reasoning summary text delta during streaming."""

    type: Literal["response.reasoning_summary_text.delta"] = (
        "response.reasoning_summary_text.delta"
    )
    sequence_number: int
    item_id: str
    output_index: int
    summary_index: int
    delta: str


class ResponseReasoningSummaryTextDoneEvent(BaseModel, frozen=True):
    """Event sent when reasoning summary text is done."""

    type: Literal["response.reasoning_summary_text.done"] = (
        "response.reasoning_summary_text.done"
    )
    sequence_number: int
    item_id: str
    output_index: int
    summary_index: int
    text: str


class ResponseReasoningSummaryPartDoneEvent(BaseModel, frozen=True):
    """Event sent when a reasoning summary part is done."""

    type: Literal["response.reasoning_summary_part.done"] = (
        "response.reasoning_summary_part.done"
    )
    sequence_number: int
    item_id: str
    output_index: int
    summary_index: int
    part: ResponseReasoningSummaryText


class ResponseCompletedEvent(BaseModel, frozen=True):
    """Event sent when response is completed."""

@@ -360,9 +292,5 @@ ResponsesStreamEvent = (
    | ResponseOutputItemDoneEvent
    | ResponseFunctionCallArgumentsDeltaEvent
    | ResponseFunctionCallArgumentsDoneEvent
    | ResponseReasoningSummaryPartAddedEvent
    | ResponseReasoningSummaryTextDeltaEvent
    | ResponseReasoningSummaryTextDoneEvent
    | ResponseReasoningSummaryPartDoneEvent
    | ResponseCompletedEvent
)
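# --- Illustrative sketch (not part of the diff): one of the reasoning summary
# events above rendered as a server-sent-events line. Field values are made up.
delta_event = ResponseReasoningSummaryTextDeltaEvent(
    sequence_number=7,
    item_id="rs_0",
    output_index=0,
    summary_index=0,
    delta="Weighing the two options...",
)
sse_line = f"data: {delta_event.model_dump_json()}\n\n"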
@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


class DownloadProgressData(CamelCaseModel):
    total: Memory
    downloaded: Memory
    downloaded_this_session: Memory
    total_bytes: Memory
    downloaded_bytes: Memory
    downloaded_bytes_this_session: Memory

    completed_files: int
    total_files: int
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):


class DownloadCompleted(BaseDownloadProgress):
    total: Memory
    total_bytes: Memory


class DownloadFailed(BaseDownloadProgress):
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
    shard: ShardMetadata
    completed_files: int
    total_files: int
    downloaded: Memory
    downloaded_this_session: Memory
    total: Memory
    downloaded_bytes: Memory
    downloaded_bytes_this_session: Memory
    total_bytes: Memory
    overall_speed: float
    overall_eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]

@@ -28,7 +28,6 @@ class GenerationResponse(BaseRunnerResponse):
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None
    usage: Usage | None
    is_thinking: bool = False


class ImageGenerationResponse(BaseRunnerResponse):
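# --- Illustrative only (not part of the diff): with the is_thinking flag added
# above, a reasoning token can be tagged without wrapping it in <think> text.
# Constructor fields mirror those used by the tests later in this diff.
thinking_chunk = GenerationResponse(
    text="Let me work through this...",
    token=0,
    usage=None,
    is_thinking=True,
)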
@@ -192,13 +192,7 @@ class MpReceiver[T]:
        try:
            return self.receive_nowait()
        except WouldBlock:
            try:
                item = self._state.buffer.get()
            except (TypeError, OSError):
                # Queue pipe can get closed while we are blocked on get().
                # The underlying connection._handle becomes None, causing
                # TypeError in read(handle, remaining).
                raise ClosedResourceError from None
            item = self._state.buffer.get()
        if isinstance(item, _MpEndOfStream):
            self.close()
            raise EndOfStream from None
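# --- Sketch (not part of the diff) of the failure mode the added except clause
# above guards against: a blocking Queue.get() whose pipe handle was closed
# raises TypeError/OSError, which is surfaced to callers as ClosedResourceError.
# `receiver` is a hypothetical MpReceiver instance.
try:
    item = receiver.receive()
except ClosedResourceError:
    ...  # peer closed the pipe while we were blocked
except EndOfStream:
    ...  # sender finished cleanly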
@@ -108,7 +108,7 @@ async def check_reachable(
        await send.send((target_ip, expected_node_id))

    async with (
        httpx.AsyncClient(timeout=timeout, limits=limits, verify=False) as client,
        httpx.AsyncClient(timeout=timeout, limits=limits) as client,
        create_task_group() as tg,
    ):
        for node_id in topology.list_nodes():
@@ -166,7 +166,7 @@ def generate_image(
        else 0.0
    )

    peak_memory = Memory.from_bytes(mx.get_peak_memory())
    peak_memory_gb = mx.get_peak_memory() / (1024**3)

    stats = ImageGenerationStats(
        seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@ def generate_image(
        num_images=num_images,
        image_width=width,
        image_height=height,
        peak_memory_usage=peak_memory,
        peak_memory_usage=Memory.from_gb(peak_memory_gb),
    )

    buffer = io.BytesIO()

@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
    total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
    total_gb = psutil.virtual_memory().total / (1024**3)
    if total_gb >= 128:
        return 0.85
    if total_gb >= 64:
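# --- Illustrative only: both variants above compute total RAM in GB, assuming
# Memory.from_bytes(n).in_gb is equivalent to n / 1024**3. For a 128 GiB machine:
total_gb_example = 137438953472 / (1024**3)  # -> 128.0, so the 0.85 threshold branch applies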
@@ -1,72 +0,0 @@
import json
import re
from typing import Any

from mlx_lm.chat_templates import deepseek_v32

from exo.shared.types.api import ToolCallItem

BOS_TOKEN: str = deepseek_v32.bos_token
EOS_TOKEN: str = deepseek_v32.eos_token
DSML_TOKEN: str = deepseek_v32.dsml_token
THINKING_START: str = deepseek_v32.thinking_start_token
THINKING_END: str = deepseek_v32.thinking_end_token
USER_TOKEN = "<\uff5cUser\uff5c>"
ASSISTANT_TOKEN = "<\uff5cAssistant\uff5c>"
TOOL_CALLS_START = f"<{DSML_TOKEN}function_calls>"
TOOL_CALLS_END = f"</{DSML_TOKEN}function_calls>"
encode_messages = deepseek_v32.encode_messages

_INVOKE_PATTERN = re.compile(
    rf"<{re.escape(DSML_TOKEN)}invoke\s+name=\"([^\"]+)\">"
    rf"(.*?)"
    rf"</{re.escape(DSML_TOKEN)}invoke>",
    re.DOTALL,
)

_PARAM_PATTERN = re.compile(
    rf"<{re.escape(DSML_TOKEN)}parameter\s+name=\"([^\"]+)\"\s+string=\"(true|false)\">"
    rf"(.*?)"
    rf"</{re.escape(DSML_TOKEN)}parameter>",
    re.DOTALL,
)


def parse_dsml_output(text: str) -> list[ToolCallItem] | None:
    """Parse DSML function_calls block from model output text.

    Args:
        text: The text containing the DSML function_calls block
            (including the start/end markers).

    Returns:
        List of ToolCallItem, or None if parsing fails.
    """
    tool_calls: list[ToolCallItem] = []

    for invoke_match in _INVOKE_PATTERN.finditer(text):
        func_name = invoke_match.group(1)
        invoke_body = invoke_match.group(2)

        args: dict[str, Any] = {}
        for param_match in _PARAM_PATTERN.finditer(invoke_body):
            param_name = param_match.group(1)
            is_string = param_match.group(2) == "true"
            param_value = param_match.group(3)

            if is_string:
                args[param_name] = param_value
            else:
                try:
                    args[param_name] = json.loads(param_value)
                except (json.JSONDecodeError, ValueError):
                    args[param_name] = param_value

        tool_calls.append(
            ToolCallItem(
                name=func_name,
                arguments=json.dumps(args),
            )
        )

    return tool_calls if tool_calls else None
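# --- Illustrative usage (not part of the diff) of parse_dsml_output above.
# DSML_TOKEN resolves to the model's fullwidth-pipe token; the block mirrors the
# shapes exercised by the tests later in this diff.
block = (
    f"<{DSML_TOKEN}function_calls>\n"
    f'<{DSML_TOKEN}invoke name="get_weather">\n'
    f'<{DSML_TOKEN}parameter name="city" string="true">Paris</{DSML_TOKEN}parameter>\n'
    f"</{DSML_TOKEN}invoke>\n"
    f"</{DSML_TOKEN}function_calls>"
)
calls = parse_dsml_output(block)
# -> one ToolCallItem with name "get_weather" and arguments '{"city": "Paris"}'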
@@ -232,11 +232,11 @@ def shard_and_load(

    # Estimate timeout based on model size (5x default for large queued workloads)
    base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
    model_size = get_weights_size(shard_metadata)
    timeout_seconds = base_timeout + model_size.in_gb
    model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
    timeout_seconds = base_timeout + model_size_gb
    logger.info(
        f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
        f"(model size: {model_size.in_gb:.1f}GB)"
        f"(model size: {model_size_gb:.1f}GB)"
    )

    match shard_metadata:
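# --- Worked example (values hypothetical) of the timeout formula above: with the
# default EXO_MODEL_LOAD_TIMEOUT of 300 seconds and a 70 GB model, the evaluation
# timeout becomes 300 + 70 = 370 seconds.
example_timeout = 300.0 + 70.0  # -> 370.0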
@@ -458,19 +458,6 @@ def _patch_lossy_chat_template(template: str) -> str | None:
    return patched if n > 0 else None


def _needs_dsml_encoding(task_params: TextGenerationTaskParams) -> bool:
    if "deepseek-v3.2" not in task_params.model.lower():
        return False
    # Use DSML encoding when tools are provided or tool results are in the conversation
    if task_params.tools:
        return True
    if task_params.chat_template_messages:
        return any(
            msg.get("role") == "tool" for msg in task_params.chat_template_messages
        )
    return False


def apply_chat_template(
    tokenizer: TokenizerWrapper,
    task_params: TextGenerationTaskParams,
@@ -482,6 +469,7 @@ def apply_chat_template(

    When chat_template_messages is available (from Chat Completions API),
    uses those directly to preserve tool_calls, thinking, and other fields.
    Otherwise builds messages from the task params input/instructions.
    """
    formatted_messages: list[dict[str, Any]] = []
    if task_params.chat_template_messages is not None:
@@ -509,19 +497,6 @@ def apply_chat_template(
        partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
        formatted_messages = formatted_messages[:-1]

    if _needs_dsml_encoding(task_params):
        from exo.worker.engines.mlx.dsml_encoding import encode_messages

        prompt = encode_messages(
            messages=formatted_messages,
            thinking_mode="thinking" if task_params.enable_thinking else "chat",
            tools=task_params.tools,
        )
        if partial_assistant_content:
            prompt += partial_assistant_content
        logger.info(prompt)
        return prompt

    extra_kwargs: dict[str, Any] = {}
    if task_params.enable_thinking is not None:
        # Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".
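# --- Illustrative routing check (not part of the diff) for _needs_dsml_encoding
# above: a deepseek-v3.2 model with tool results in the conversation takes the
# DSML path even when no tools are passed on the request. The message list is made up.
example_messages = [
    {"role": "user", "content": "What's the weather?"},
    {"role": "tool", "content": '{"temperature": 72}'},
]
has_tool_results = any(m.get("role") == "tool" for m in example_messages)  # True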
@@ -642,17 +617,18 @@ def set_wired_limit_for_model(model_size: Memory):
    if not mx.metal.is_available():
        return

    max_rec_size = Memory.from_bytes(
        int(mx.metal.device_info()["max_recommended_working_set_size"])
    )
    if model_size > 0.9 * max_rec_size:
    model_bytes = model_size.in_bytes
    max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        logger.warning(
            f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
            f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
            f"Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
        )
    mx.set_wired_limit(max_rec_size.in_bytes)
    mx.set_wired_limit(max_rec_size)
    logger.info(f"Wired limit set to {max_rec_size}.")

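# --- Worked example (illustrative numbers) of the byte-to-MB reporting above:
model_bytes_example = 40 * 2**30                   # a hypothetical 40 GiB model
model_mb_example = model_bytes_example // 2**20    # -> 40960, the MB figure that gets logged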
@@ -1,7 +1,6 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from random import random
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import CancelScope, create_task_group, fail_after
|
||||
@@ -17,13 +16,14 @@ from exo.shared.types.commands import (
|
||||
RequestEventLog,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
ForwarderEvent,
|
||||
GlobalForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
LocalForwarderEvent,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
@@ -58,24 +58,22 @@ class Worker:
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
*,
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
local_event_sender: Sender[ForwarderEvent],
|
||||
global_event_receiver: Receiver[GlobalForwarderEvent],
|
||||
local_event_sender: Sender[LocalForwarderEvent],
|
||||
# This is for requesting updates. It doesn't need to be a general command sender right now,
|
||||
# but I think it's the correct way to be thinking about commands
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
event_index_counter: Iterator[int],
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.local_event_sender = local_event_sender
|
||||
self.event_index_counter = event_index_counter
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
||||
self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}
|
||||
|
||||
self.state: State = State()
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
@@ -86,6 +84,8 @@ class Worker:
|
||||
self._nack_base_seconds: float = 0.5
|
||||
self._nack_cap_seconds: float = 10.0
|
||||
|
||||
self._system_id = SystemId()
|
||||
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
|
||||
# Buffer for input image chunks (for image editing)
|
||||
@@ -132,6 +132,8 @@ class Worker:
|
||||
async def _event_applier(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.session != self.session_id:
|
||||
continue
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
continue
|
||||
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
|
||||
@@ -212,7 +214,7 @@ class Worker:
|
||||
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id,
|
||||
origin=self._system_id,
|
||||
command=StartDownload(
|
||||
target_node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
@@ -317,7 +319,7 @@ class Worker:
|
||||
)
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=self.node_id,
|
||||
origin=self._system_id,
|
||||
command=RequestEventLog(since_idx=since_idx),
|
||||
)
|
||||
)
|
||||
@@ -344,15 +346,16 @@ class Worker:
|
||||
return runner
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
idx = 0
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
fe = LocalForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self.node_id,
|
||||
origin=self._system_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
idx += 1
|
||||
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
|
||||
await self.local_event_sender.send(fe)
|
||||
self.out_for_delivery[event.event_id] = fe
|
||||
|
||||
@@ -4,10 +4,9 @@ import resource
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
@@ -22,17 +21,12 @@ from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
|
||||
from exo.shared.types.api import ImageGenerationStats
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
ImageChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
PrefillProgress,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
@@ -321,13 +315,11 @@ def main(
|
||||
) -> None:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
PrefillProgress(
|
||||
command_id=command_id,
|
||||
chunk=PrefillProgressChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
processed_tokens=processed,
|
||||
total_tokens=total,
|
||||
),
|
||||
model=shard_metadata.model_card.model_id,
|
||||
processed_tokens=processed,
|
||||
total_tokens=total,
|
||||
)
|
||||
)
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
@@ -354,22 +346,16 @@ def main(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if tokenizer.has_thinking:
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
if detect_thinking_prompt_suffix(prompt, tokenizer):
|
||||
mlx_generator = parse_thinking_models(
|
||||
mlx_generator,
|
||||
tokenizer,
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
starts_in_thinking=detect_thinking_prompt_suffix(
|
||||
prompt, tokenizer
|
||||
),
|
||||
mlx_generator, tokenizer
|
||||
)
|
||||
|
||||
# Model-specific output parsing for tool calls.
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(inference_model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
elif isinstance(inference_model, DeepseekV32Model):
|
||||
mlx_generator = parse_deepseek_v32(mlx_generator)
|
||||
elif tool_parser:
|
||||
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
|
||||
|
||||
@@ -421,7 +407,6 @@ def main(
|
||||
stats=response.stats,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
is_thinking=response.is_thinking,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -588,13 +573,6 @@ def main(
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
if not TYPE_CHECKING:
|
||||
del inference_model, image_model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
@@ -619,8 +597,12 @@ def main(
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del inference_model, image_model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
break
|
||||
|
||||
|
||||
@@ -686,208 +668,44 @@ def parse_gpt_oss(
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield response.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield response.model_copy(update={"text": delta, "is_thinking": thinking})
|
||||
yield response.model_copy(update={"text": delta})
|
||||
|
||||
if response.finish_reason is not None:
|
||||
if thinking:
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
yield response
|
||||
|
||||
|
||||
def parse_deepseek_v32(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
|
||||
|
||||
Uses accumulated-text matching (not per-token marker checks) because
|
||||
DSML markers like <|DSML|function_calls> may span multiple tokens.
|
||||
Also handles <think>...</think> blocks for thinking mode.
|
||||
"""
|
||||
from exo.worker.engines.mlx.dsml_encoding import (
|
||||
THINKING_END,
|
||||
THINKING_START,
|
||||
TOOL_CALLS_END,
|
||||
TOOL_CALLS_START,
|
||||
parse_dsml_output,
|
||||
)
|
||||
|
||||
accumulated = ""
|
||||
in_tool_call = False
|
||||
thinking = False
|
||||
# Tokens buffered while we detect the start of a DSML block
|
||||
pending_buffer: list[GenerationResponse] = []
|
||||
# Text accumulated during a tool call block
|
||||
tool_call_text = ""
|
||||
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
|
||||
# ── Handle thinking tags ──
|
||||
if not thinking and THINKING_START in response.text:
|
||||
thinking = True
|
||||
# Yield any text before the <think> tag
|
||||
before = response.text[: response.text.index(THINKING_START)]
|
||||
if before:
|
||||
yield response.model_copy(update={"text": before})
|
||||
continue
|
||||
|
||||
if thinking and THINKING_END in response.text:
|
||||
thinking = False
|
||||
# Yield any text after the </think> tag
|
||||
after = response.text[
|
||||
response.text.index(THINKING_END) + len(THINKING_END) :
|
||||
]
|
||||
if after:
|
||||
yield response.model_copy(update={"text": after, "is_thinking": False})
|
||||
continue
|
||||
|
||||
if thinking:
|
||||
yield response.model_copy(update={"is_thinking": True})
|
||||
continue
|
||||
|
||||
# ── Handle tool call accumulation ──
|
||||
if in_tool_call:
|
||||
tool_call_text += response.text
|
||||
if TOOL_CALLS_END in tool_call_text:
|
||||
# Parse the accumulated DSML block
|
||||
parsed = parse_dsml_output(tool_call_text)
|
||||
if parsed is not None:
|
||||
logger.info(f"parsed DSML tool calls: {parsed}")
|
||||
yield ToolCallResponse(
|
||||
tool_calls=parsed,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"DSML tool call parsing failed for: {tool_call_text}"
|
||||
)
|
||||
yield response.model_copy(update={"text": tool_call_text})
|
||||
in_tool_call = False
|
||||
tool_call_text = ""
|
||||
continue
|
||||
|
||||
# EOS reached before end marker — yield buffered text as-is
|
||||
if response.finish_reason is not None:
|
||||
logger.info("DSML tool call parsing interrupted by EOS")
|
||||
yield response.model_copy(update={"text": tool_call_text})
|
||||
in_tool_call = False
|
||||
tool_call_text = ""
|
||||
continue
|
||||
|
||||
# ── Detect start of tool call block ──
|
||||
accumulated += response.text
|
||||
|
||||
if TOOL_CALLS_START in accumulated:
|
||||
# The start marker might be split across pending_buffer + current token
|
||||
start_idx = accumulated.index(TOOL_CALLS_START)
|
||||
# Yield any pending tokens that are purely before the marker
|
||||
pre_text = accumulated[:start_idx]
|
||||
if pre_text:
|
||||
# Flush pending buffer tokens that contributed text before the marker
|
||||
for buf_resp in pending_buffer:
|
||||
if pre_text:
|
||||
chunk = buf_resp.text
|
||||
if len(chunk) <= len(pre_text):
|
||||
yield buf_resp
|
||||
pre_text = pre_text[len(chunk) :]
|
||||
else:
|
||||
yield buf_resp.model_copy(update={"text": pre_text})
|
||||
pre_text = ""
|
||||
pending_buffer = []
|
||||
tool_call_text = accumulated[start_idx:]
|
||||
accumulated = ""
|
||||
|
||||
# Check if the end marker is already present (entire tool call in one token)
|
||||
if TOOL_CALLS_END in tool_call_text:
|
||||
parsed = parse_dsml_output(tool_call_text)
|
||||
if parsed is not None:
|
||||
logger.info(f"parsed DSML tool calls: {parsed}")
|
||||
yield ToolCallResponse(
|
||||
tool_calls=parsed,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"DSML tool call parsing failed for: {tool_call_text}"
|
||||
)
|
||||
yield response.model_copy(update={"text": tool_call_text})
|
||||
tool_call_text = ""
|
||||
else:
|
||||
in_tool_call = True
|
||||
continue
|
||||
|
||||
# Check if accumulated text might be the start of a DSML marker
|
||||
# Buffer tokens if we see a partial match at the end
|
||||
if _could_be_dsml_prefix(accumulated):
|
||||
pending_buffer.append(response)
|
||||
continue
|
||||
|
||||
# No partial match — flush all pending tokens and the current one
|
||||
for buf_resp in pending_buffer:
|
||||
yield buf_resp
|
||||
pending_buffer = []
|
||||
accumulated = ""
|
||||
yield response
|
||||
|
||||
# Flush any remaining pending buffer at generator end
|
||||
for buf_resp in pending_buffer:
|
||||
yield buf_resp
|
||||
|
||||
|
||||
def _could_be_dsml_prefix(text: str) -> bool:
|
||||
"""Check if the end of text could be the start of a DSML function_calls marker.
|
||||
|
||||
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
|
||||
This allows us to buffer tokens until we can determine if a tool call is starting.
|
||||
"""
|
||||
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
|
||||
|
||||
# Only check the last portion of text that could overlap with the marker
|
||||
max_check = len(TOOL_CALLS_START)
|
||||
tail = text[-max_check:] if len(text) > max_check else text
|
||||
|
||||
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
|
||||
for i in range(len(tail)):
|
||||
suffix = tail[i:]
|
||||
if TOOL_CALLS_START.startswith(suffix):
|
||||
return True
|
||||
return False
|
||||
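# --- Standalone sketch (not part of the diff) of the suffix/prefix check that
# _could_be_dsml_prefix above performs. The marker spelling is taken from the
# tokens exercised by the tests below (fullwidth pipes around "DSML").
_EXAMPLE_MARKER = "<\uff5cDSML\uff5cfunction_calls>"

def _example_could_be_prefix(text: str, marker: str = _EXAMPLE_MARKER) -> bool:
    tail = text[-len(marker):]
    return any(marker.startswith(tail[i:]) for i in range(len(tail)))

assert _example_could_be_prefix("Sure, let me check. <\uff5cDS")   # partial marker -> keep buffering
assert not _example_could_be_prefix("x > 0 and y < 100.")          # plain text -> flush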
|
||||
|
||||
def parse_thinking_models(
|
||||
responses: Generator[GenerationResponse],
|
||||
tokenizer: TokenizerWrapper,
|
||||
starts_in_thinking: bool = True,
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Route thinking tokens via is_thinking flag.
|
||||
|
||||
Swallows think tag tokens, sets is_thinking on all others.
|
||||
Always yields tokens with finish_reason to avoid hanging the chunk stream.
|
||||
"""
|
||||
in_thinking = starts_in_thinking
|
||||
For models that inject thinking tags in the prompt (like GLM-4.7),
|
||||
prepend the thinking tag to the output stream so the frontend
|
||||
can properly parse thinking content.
|
||||
"""
|
||||
first = True
|
||||
for response in responses:
|
||||
if isinstance(response, ToolCallResponse):
|
||||
yield response
|
||||
continue
|
||||
|
||||
is_think_tag = (
|
||||
tokenizer.think_end is not None and response.text == tokenizer.think_end
|
||||
) or (
|
||||
tokenizer.think_start is not None and response.text == tokenizer.think_start
|
||||
)
|
||||
|
||||
if is_think_tag:
|
||||
in_thinking = response.text != tokenizer.think_end
|
||||
# Never swallow finish_reason — the chunk stream needs it to terminate.
|
||||
if response.finish_reason is not None:
|
||||
yield response.model_copy(update={"text": "", "is_thinking": False})
|
||||
continue
|
||||
yield response.model_copy(update={"is_thinking": in_thinking})
|
||||
if first:
|
||||
first = False
|
||||
yield response.model_copy(
|
||||
update={
|
||||
"text": tokenizer.think_start,
|
||||
"token": tokenizer.think_start_id,
|
||||
}
|
||||
)
|
||||
yield response
|
||||
|
||||
|
||||
def _send_image_chunk(
|
||||
|
||||
@@ -100,8 +100,8 @@ class RunnerSupervisor:
        logger.info("Runner supervisor shutting down")
        self._ev_recv.close()
        self._task_sender.close()
        with contextlib.suppress(ClosedResourceError):
            self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
        self._event_sender.close()
        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
        self._cancel_sender.close()
        self.runner_process.join(5)
        if not self.runner_process.is_alive():
@@ -180,7 +180,6 @@ class RunnerSupervisor:
            await self._check_runner(e)
            for tid in self.pending:
                self.pending[tid].set()
            self._event_sender.close()

    def __del__(self) -> None:
        if self.runner_process.is_alive():
@@ -209,15 +208,10 @@ class RunnerSupervisor:

        logger.opt(exception=e).error(f"Runner terminated ({cause})")

        try:
            await self._event_sender.send(
                RunnerStatusUpdated(
                    runner_id=self.bound_instance.bound_runner_id,
                    runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
                )
            )
        except (ClosedResourceError, BrokenResourceError):
            logger.warning(
                "Event sender already closed, unable to report runner failure"
        await self._event_sender.send(
            RunnerStatusUpdated(
                runner_id=self.bound_instance.bound_runner_id,
                runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
            )
        )
        self.shutdown()
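# --- The guarded variant in the hunk above boils down to this best-effort
# pattern (exception types as named in the hunk; argument names are placeholders
# for the supervisor's attributes):
async def _report_failure_best_effort(event_sender, failure_status_event) -> None:
    try:
        await event_sender.send(failure_status_event)
    except (ClosedResourceError, BrokenResourceError):
        logger.warning("Event sender already closed, unable to report runner failure")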
@@ -90,10 +90,14 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [
|
||||
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
@@ -134,7 +138,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
# Global state shows shard is downloaded for NODE_A
|
||||
global_download_status: dict[NodeId, list[DownloadProgress]] = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [],
|
||||
}
|
||||
@@ -181,7 +187,9 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [], # NODE_B has no downloads completed yet
|
||||
}
|
||||
@@ -199,10 +207,14 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [
|
||||
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
|
||||
)
|
||||
], # NODE_B has no downloads completed yet
|
||||
}
|
||||
|
||||
|
||||
@@ -1,967 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ToolCallResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx.dsml_encoding import (
|
||||
ASSISTANT_TOKEN,
|
||||
BOS_TOKEN,
|
||||
DSML_TOKEN,
|
||||
EOS_TOKEN,
|
||||
THINKING_END,
|
||||
THINKING_START,
|
||||
TOOL_CALLS_END,
|
||||
TOOL_CALLS_START,
|
||||
USER_TOKEN,
|
||||
encode_messages,
|
||||
parse_dsml_output,
|
||||
)
|
||||
from exo.worker.runner.runner import parse_deepseek_v32
|
||||
|
||||
# ── Shared fixtures ──────────────────────────────────────────────
|
||||
|
||||
_WEATHER_TOOLS: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Temperature units",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"description": "Get the current time in a timezone",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {"type": "string"},
|
||||
},
|
||||
"required": ["timezone"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _simulate_tokens(
|
||||
texts: list[str],
|
||||
finish_on_last: bool = True,
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Simulate a model producing tokens from a list of text strings."""
|
||||
for i, text in enumerate(texts):
|
||||
is_last = i == len(texts) - 1
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=i,
|
||||
finish_reason="stop" if (is_last and finish_on_last) else None,
|
||||
usage=None,
|
||||
)
|
||||
|
||||
|
||||
# ── Test: Standard text response (no tool calls) ────────────────
|
||||
|
||||
|
||||
class TestE2EStandardResponse:
|
||||
"""Model generates a plain text response — no tool calling involved."""
|
||||
|
||||
def test_plain_text_passthrough(self):
|
||||
"""Simulate model producing: 'The weather in NYC is 72°F and sunny.'"""
|
||||
# Step 1: Encode the prompt (with tools available)
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather in NYC?"},
|
||||
]
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Verify prompt structure
|
||||
assert BOS_TOKEN in prompt
|
||||
assert "## Tools" in prompt
|
||||
assert "get_weather" in prompt
|
||||
assert f"{USER_TOKEN}What's the weather in NYC?{ASSISTANT_TOKEN}" in prompt
|
||||
|
||||
# Step 2: Simulate model response — plain text tokens (no DSML)
|
||||
model_tokens = [
|
||||
"The weather",
|
||||
" in NYC",
|
||||
" is 72",
|
||||
"°F",
|
||||
" and sunny",
|
||||
".",
|
||||
]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
# Step 3: Verify all tokens pass through as GenerationResponse
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 0
|
||||
assert len(gen_results) == 6
|
||||
full_text = "".join(r.text for r in gen_results)
|
||||
assert full_text == "The weather in NYC is 72°F and sunny."
|
||||
assert gen_results[-1].finish_reason == "stop"
|
||||
|
||||
|
||||
# ── Test: Tool call response ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestE2EToolCallResponse:
|
||||
"""Model generates a DSML tool call — realistic token boundaries."""
|
||||
|
||||
def test_realistic_tool_call_tokens(self):
|
||||
"""Simulate model generating a get_weather tool call with realistic token splits.
|
||||
|
||||
Real models split DSML markers across tokens unpredictably.
|
||||
This simulates how DeepSeek V3.2 actually tokenizes DSML output.
|
||||
"""
|
||||
# Step 1: Encode prompt
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather in San Francisco?"},
|
||||
]
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
assert "get_weather" in prompt
|
||||
|
||||
# Step 2: Simulate realistic token-by-token model output
|
||||
# The model first produces some text, then a DSML tool call block
|
||||
model_tokens = [
|
||||
"I'll check the weather for you.",
|
||||
"\n\n",
|
||||
f"<{DSML_TOKEN}", # marker split across tokens
|
||||
"function_calls>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">',
|
||||
"San Francisco",
|
||||
f"</{DSML_TOKEN}parameter>\n",
|
||||
f'<{DSML_TOKEN}parameter name="units" string="false">',
|
||||
'"celsius"',
|
||||
f"</{DSML_TOKEN}parameter>\n",
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
f"</{DSML_TOKEN}function_calls>",
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
# Step 3: Verify
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
# Should have text tokens before tool call + one ToolCallResponse
|
||||
assert len(tool_results) == 1
|
||||
assert len(tool_results[0].tool_calls) == 1
|
||||
|
||||
tc = tool_results[0].tool_calls[0]
|
||||
assert tc.name == "get_weather"
|
||||
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
|
||||
assert args["city"] == "San Francisco"
|
||||
assert args["units"] == "celsius"
|
||||
|
||||
# The text before the tool call should still be yielded
|
||||
text_before = "".join(r.text for r in gen_results if not r.is_thinking)
|
||||
assert "check the weather" in text_before
|
||||
|
||||
def test_multiple_tool_calls_in_one_block(self):
|
||||
"""Model generates two tool calls in a single function_calls block."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Weather in NYC and time in EST?"},
|
||||
]
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
assert "get_weather" in prompt
|
||||
assert "get_time" in prompt
|
||||
|
||||
# Simulate model output with two invocations
|
||||
model_tokens = [
|
||||
"Let me check both.\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_time">\n',
|
||||
f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert len(tool_results[0].tool_calls) == 2
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
assert tool_results[0].tool_calls[1].name == "get_time"
|
||||
|
||||
args0 = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
args1 = json.loads(tool_results[0].tool_calls[1].arguments) # pyright: ignore[reportAny]
|
||||
assert args0 == {"city": "NYC"}
|
||||
assert args1 == {"timezone": "EST"}
|
||||
|
||||
|
||||
# ── Test: Multi-turn tool use flow ───────────────────────────────
|
||||
|
||||
|
||||
class TestE2EMultiTurnToolUse:
|
||||
"""Full multi-turn: user asks → model calls tool → tool result → model answers."""
|
||||
|
||||
def test_encode_multi_turn_with_tool_results(self):
|
||||
"""Verify the prompt for turn 2 (after tool results) is correctly encoded."""
|
||||
# Turn 1: user asks, model calls tool
|
||||
# Turn 2: tool result provided, model answers
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a weather assistant."},
|
||||
{"role": "user", "content": "What's the weather in NYC?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "NYC"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": '{"temperature": 72, "condition": "sunny"}'},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Verify multi-turn structure
|
||||
assert BOS_TOKEN in prompt
|
||||
assert "You are a weather assistant." in prompt
|
||||
assert "## Tools" in prompt
|
||||
|
||||
# The assistant's tool call should be encoded as DSML
|
||||
assert TOOL_CALLS_START in prompt
|
||||
assert f'<{DSML_TOKEN}invoke name="get_weather">' in prompt
|
||||
assert EOS_TOKEN in prompt
|
||||
|
||||
# The tool result should be wrapped in function_results
|
||||
assert "<function_results>" in prompt
|
||||
assert "<result>" in prompt
|
||||
assert "72" in prompt
|
||||
assert "</function_results>" in prompt
|
||||
|
||||
# Now simulate model answering after seeing the tool result
|
||||
model_tokens = [
|
||||
"The current",
|
||||
" weather in NYC",
|
||||
" is 72°F",
|
||||
" and sunny.",
|
||||
]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 0
|
||||
full_text = "".join(r.text for r in gen_results)
|
||||
assert full_text == "The current weather in NYC is 72°F and sunny."
|
||||
|
||||
def test_multi_tool_results_encoding(self):
|
||||
"""Verify encoding when model called two tools and both return results."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Weather and time?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "LA"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "PST"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "85F, clear skies"},
|
||||
{"role": "tool", "content": "3:42 PM PST"},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Should have one function_results block with two results
|
||||
assert prompt.count("<function_results>") == 1
|
||||
assert prompt.count("</function_results>") == 1
|
||||
assert "<result>85F, clear skies</result>" in prompt
|
||||
assert "<result>3:42 PM PST</result>" in prompt
|
||||
|
||||
|
||||
# ── Test: Thinking + tool call ───────────────────────────────────
|
||||
|
||||
|
||||
class TestE2EThinkingAndToolCall:
|
||||
"""Model uses thinking mode, reasons, then makes a tool call."""
|
||||
|
||||
def test_thinking_then_tool_call(self):
|
||||
"""Model thinks first, then produces a DSML tool call block."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
]
|
||||
prompt = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
# Thinking mode: prompt should end with <think>
|
||||
assert prompt.endswith(THINKING_START)
|
||||
|
||||
# Simulate: model outputs <think>, thinks, closes thinking, then tool call.
|
||||
# In the full pipeline, parse_thinking_models handles the case where
|
||||
# <think> is in the prompt. Here we test parse_deepseek_v32 directly,
|
||||
# which detects <think>/<think> markers in the stream.
|
||||
model_tokens = [
|
||||
THINKING_START,
|
||||
"The user wants weather",
|
||||
" information. I should use",
|
||||
" the get_weather tool.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">',
|
||||
"San Francisco",
|
||||
f"</{DSML_TOKEN}parameter>\n",
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
# Should have thinking tokens + tool call
|
||||
thinking_results = [r for r in gen_results if r.is_thinking]
|
||||
|
||||
assert len(thinking_results) >= 1
|
||||
thinking_text = "".join(r.text for r in thinking_results)
|
||||
assert "get_weather tool" in thinking_text
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
assert args["city"] == "San Francisco"
|
||||
|
||||
def test_thinking_prompt_encoding(self):
|
||||
"""Verify thinking mode affects prompt encoding correctly."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "Be thorough."},
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
]
|
||||
|
||||
# With thinking enabled
|
||||
prompt_think = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_think.endswith(THINKING_START)
|
||||
|
||||
# With thinking disabled
|
||||
prompt_no_think = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="chat"
|
||||
)
|
||||
assert prompt_no_think.endswith(THINKING_END)
|
||||
|
||||
# Both should have the same tool definitions
|
||||
assert "get_weather" in prompt_think
|
||||
assert "get_weather" in prompt_no_think
|
||||
|
||||
|
||||
# ── Test: Round-trip encode → parse ──────────────────────────────
|
||||
|
||||
|
||||
class TestE2ERoundTrip:
|
||||
"""Verify that DSML we encode can be parsed back correctly."""
|
||||
|
||||
def test_encoded_tool_call_is_parseable(self):
|
||||
"""Encode an assistant tool call message, then parse the DSML output."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Tokyo", "units": "celsius"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Extract the DSML function_calls block from the prompt
|
||||
start = prompt.index(TOOL_CALLS_START)
|
||||
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
|
||||
dsml_block = prompt[start:end]
|
||||
|
||||
# Parse it back
|
||||
parsed = parse_dsml_output(dsml_block)
|
||||
assert parsed is not None
|
||||
assert len(parsed) == 1
|
||||
assert parsed[0].name == "get_weather"
|
||||
args = json.loads(parsed[0].arguments) # pyright: ignore[reportAny]
|
||||
assert args["city"] == "Tokyo"
|
||||
assert args["units"] == "celsius"
|
||||
|
||||
def test_encoded_multi_tool_call_round_trips(self):
|
||||
"""Encode multiple tool calls, verify they parse back correctly."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Both please"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Paris"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "CET"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
start = prompt.index(TOOL_CALLS_START)
|
||||
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
|
||||
dsml_block = prompt[start:end]
|
||||
|
||||
parsed = parse_dsml_output(dsml_block)
|
||||
assert parsed is not None
|
||||
assert len(parsed) == 2
|
||||
assert parsed[0].name == "get_weather"
|
||||
assert parsed[1].name == "get_time"
|
||||
assert json.loads(parsed[0].arguments) == {"city": "Paris"}
|
||||
assert json.loads(parsed[1].arguments) == {"timezone": "CET"}
|
||||
|
||||
|
||||
# ── Test: Edge cases with realistic token boundaries ─────────────
|
||||
|
||||
|
||||
class TestE2EEdgeCases:
|
||||
"""Edge cases that occur in real model inference."""
|
||||
|
||||
def test_dsml_marker_split_at_fullwidth_pipe(self):
|
||||
"""The fullwidth pipe character | might be its own token."""
|
||||
# This is a realistic tokenization: the DSML marker is split at the | chars
|
||||
model_tokens = [
|
||||
"Let me help.\n\n",
|
||||
"<\uff5c", # start of |DSML|
|
||||
"DSML\uff5c", # rest of DSML token
|
||||
"function_calls>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
|
||||
def test_tool_call_with_nested_json_object(self):
|
||||
"""Model passes a complex JSON object as a non-string parameter."""
|
||||
dsml_block = (
|
||||
f"{TOOL_CALLS_START}\n"
|
||||
f'<{DSML_TOKEN}invoke name="create_event">\n'
|
||||
f'<{DSML_TOKEN}parameter name="title" string="true">Team Standup</{DSML_TOKEN}parameter>\n'
|
||||
f'<{DSML_TOKEN}parameter name="config" string="false">'
|
||||
f'{{"recurring": true, "days": ["mon", "wed", "fri"], "time": "09:00"}}'
|
||||
f"</{DSML_TOKEN}parameter>\n"
|
||||
f"</{DSML_TOKEN}invoke>\n"
|
||||
f"{TOOL_CALLS_END}"
|
||||
)
|
||||
|
||||
# Feed as single token (model might produce it all at once after prefill)
|
||||
results = list(parse_deepseek_v32(_simulate_tokens([dsml_block])))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
tc = tool_results[0].tool_calls[0]
|
||||
assert tc.name == "create_event"
|
||||
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
|
||||
assert args["title"] == "Team Standup"
|
||||
assert args["config"]["recurring"] is True
|
||||
assert args["config"]["days"] == ["mon", "wed", "fri"]
|
||||
|
||||
def test_text_with_angle_brackets_not_mistaken_for_dsml(self):
|
||||
"""Angle brackets in normal text should not trigger DSML buffering."""
|
||||
model_tokens = [
|
||||
"The formula is ",
|
||||
"<x, y>",
|
||||
" where x > 0",
|
||||
" and y < 100.",
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 0
|
||||
full_text = "".join(r.text for r in gen_results)
|
||||
assert "formula" in full_text
|
||||
assert "<x, y>" in full_text
|
||||
|
||||
def test_empty_model_response(self):
|
||||
"""Model produces only EOS (empty response)."""
|
||||
model_tokens = [""]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
assert len(gen_results) == 1
|
||||
assert gen_results[0].text == ""
|
||||
assert gen_results[0].finish_reason == "stop"
|
||||
|
||||
|
||||
# ── Test: Full EPDP spec round-trip ──────────────────────────────
|
||||
|
||||
|
||||
class TestE2EFullRoundTrip:
|
||||
"""Full round-trip matching the vLLM EPDP spec.
|
||||
|
||||
Simulates the complete multi-turn flow:
|
||||
Turn 1: user asks → think → tool call → tool result → think → answer
|
||||
Turn 2: user asks again → old reasoning stripped → think → answer
|
||||
"""
|
||||
|
||||
def test_single_tool_full_flow_with_thinking(self):
|
||||
"""Complete flow: user → think → tool call → tool result → think → answer.
|
||||
|
||||
This is the core EPDP flow from the vLLM spec.
|
||||
"""
|
||||
# ── Turn 1.1: User asks, encode prompt ──
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a weather assistant."},
|
||||
{"role": "user", "content": "How's the weather in Hangzhou?"},
|
||||
]
|
||||
prompt_1 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_1.endswith(THINKING_START)
|
||||
assert "## Tools" in prompt_1
|
||||
assert "get_weather" in prompt_1
|
||||
|
||||
# ── Turn 1.1: Model thinks, then calls tool ──
|
||||
model_tokens_1 = [
|
||||
THINKING_START,
|
||||
"The user wants to know the weather in Hangzhou.",
|
||||
" I need to use the get_weather tool.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
|
||||
|
||||
# Verify: thinking tokens + tool call
|
||||
gen_1 = [r for r in results_1 if isinstance(r, GenerationResponse)]
|
||||
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
|
||||
thinking_1 = [r for r in gen_1 if r.is_thinking]
|
||||
|
||||
assert len(thinking_1) >= 1
|
||||
assert "get_weather tool" in "".join(r.text for r in thinking_1)
|
||||
assert len(tool_1) == 1
|
||||
assert tool_1[0].tool_calls[0].name == "get_weather"
|
||||
tc_args = json.loads(tool_1[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
assert tc_args == {"city": "Hangzhou"}
|
||||
|
||||
# ── Turn 1.2: Add assistant response + tool result to messages ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "The user wants to know the weather in Hangzhou. I need to use the get_weather tool.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Hangzhou"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": '{"temperature": "7~13°C", "condition": "Cloudy"}',
|
||||
}
|
||||
)

        # Encode prompt for turn 1.2
        prompt_2 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )

        # Verify: prompt has the full conversation structure
        assert TOOL_CALLS_START in prompt_2  # assistant's encoded tool call
        assert EOS_TOKEN in prompt_2  # assistant turn ends with EOS
        assert "<function_results>" in prompt_2
        assert "<result>" in prompt_2
        assert "Cloudy" in prompt_2
        assert "</function_results>" in prompt_2
        # After tool results with thinking enabled → <think> appended
        assert prompt_2.endswith(THINKING_START)
        # The assistant's reasoning_content should appear (it's after last_user_idx)
        assert "get_weather tool" in prompt_2

        # ── Turn 1.2: Model thinks about results, then answers ──
        model_tokens_2 = [
            THINKING_START,
            "The weather in Hangzhou is Cloudy, 7~13°C.",
            " I'll tell the user.",
            THINKING_END,
            "The weather in Hangzhou is currently cloudy with temperatures between 7°C and 13°C.",
        ]
        results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))

        gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
        tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
        thinking_2 = [r for r in gen_2 if r.is_thinking]
        non_thinking_2 = [r for r in gen_2 if not r.is_thinking]

        assert len(tool_2) == 0  # No more tool calls
        assert len(thinking_2) >= 1
        assert "Cloudy" in "".join(r.text for r in thinking_2)
        assert len(non_thinking_2) >= 1
        final_text = "".join(r.text for r in non_thinking_2)
        assert "7°C" in final_text
        assert "13°C" in final_text

    def test_multi_tool_full_flow(self):
        """Flow with two tools: user → think → 2 tool calls → 2 results → think → answer."""
        # ── Initial prompt ──
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You help with weather and time."},
            {"role": "user", "content": "Weather in NYC and time in EST?"},
        ]
        prompt_1 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        assert prompt_1.endswith(THINKING_START)

        # ── Model thinks, calls both tools ──
        model_tokens_1 = [
            THINKING_START,
            "Two requests: weather and time. I'll call both.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            f'<{DSML_TOKEN}invoke name="get_time">\n',
            f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]
        results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
        tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]

        assert len(tool_1) == 1
        assert len(tool_1[0].tool_calls) == 2
        assert tool_1[0].tool_calls[0].name == "get_weather"
        assert tool_1[0].tool_calls[1].name == "get_time"

        # ── Add assistant + both tool results ──
        messages.append(
            {
                "role": "assistant",
                "content": "",
                "reasoning_content": "Two requests: weather and time. I'll call both.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "NYC"}',
                        },
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": '{"timezone": "EST"}',
                        },
                    },
                ],
            }
        )
        messages.append({"role": "tool", "content": "72°F, sunny"})
        messages.append({"role": "tool", "content": "2:30 PM EST"})

        prompt_2 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )

        # Verify multi-tool result encoding
        # Count is 2: 1 in _TOOLS_SYSTEM_TEMPLATE example + 1 in conversation
        assert prompt_2.count("<function_results>") == 2
        assert prompt_2.count("</function_results>") == 2
        assert "<result>72°F, sunny</result>" in prompt_2
        assert "<result>2:30 PM EST</result>" in prompt_2
        assert prompt_2.endswith(THINKING_START)

        # ── Model thinks about results, answers ──
        model_tokens_2 = [
            THINKING_START,
            "Got both results. Weather is 72°F sunny, time is 2:30 PM.",
            THINKING_END,
            "In NYC it's currently 72°F and sunny. The time in EST is 2:30 PM.",
        ]
        results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))

        tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
        gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
        non_thinking_2 = [r for r in gen_2 if not r.is_thinking]

        assert len(tool_2) == 0
        final_text = "".join(r.text for r in non_thinking_2)
        assert "72°F" in final_text
        assert "2:30 PM" in final_text

    def test_two_user_turns_reasoning_stripped(self):
        """Turn 2: old reasoning_content is stripped from history.

        Per the vLLM spec, clear_reasoning_content is called between user turns
        to save bandwidth. Our _drop_old_thinking handles this.
        """
        # Full turn 1 conversation (already completed)
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Weather in Hangzhou?"},
            {
                "role": "assistant",
                "content": "",
                "reasoning_content": "I need to call get_weather for Hangzhou.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Hangzhou"}',
                        },
                    }
                ],
            },
            {"role": "tool", "content": "Cloudy 7~13°C"},
            {
                "role": "assistant",
                "content": "The weather in Hangzhou is cloudy, 7-13°C.",
                "reasoning_content": "The tool returned cloudy weather. I'll summarize.",
            },
            # Turn 2: user asks again
            {"role": "user", "content": "What about Beijing?"},
        ]

        prompt = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )

        # Old reasoning_content from turn 1 assistants should be STRIPPED
        # (they're before the last user message at index 5)
        assert "I need to call get_weather" not in prompt
        assert "tool returned cloudy" not in prompt

        # But the assistant's content and tool calls should still be there
        assert "cloudy, 7-13°C" in prompt
        assert TOOL_CALLS_START in prompt

        # Prompt ends with <think> for the new turn
        assert prompt.endswith(THINKING_START)

        # ── Turn 2: Model thinks, calls tool for Beijing ──
        model_tokens = [
            THINKING_START,
            "Now the user wants Beijing weather.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">Beijing</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]
        results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
        tool_results = [r for r in results if isinstance(r, ToolCallResponse)]

        assert len(tool_results) == 1
        assert tool_results[0].tool_calls[0].name == "get_weather"
        args = json.loads(tool_results[0].tool_calls[0].arguments)  # pyright: ignore[reportAny]
        assert args == {"city": "Beijing"}

    def test_chained_tool_calls_loop(self):
        """Model calls tool, gets result, calls another tool, gets result, answers.

        This simulates the inner while loop from the vLLM spec where the model
        may need multiple sub-turns of tool calling before it has enough info.
        """
        # ── Sub-turn 1: user asks, model calls get_time ──
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "What's the weather in Hangzhou tomorrow?"},
        ]

        prompt_1 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        assert prompt_1.endswith(THINKING_START)

        # Model first calls get_time to figure out the date
        model_tokens_1 = [
            THINKING_START,
            "I need the current date first to calculate tomorrow.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_time">\n',
            f'<{DSML_TOKEN}parameter name="timezone" string="true">Asia/Shanghai</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]
        results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
        tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
        assert len(tool_1) == 1
        assert tool_1[0].tool_calls[0].name == "get_time"

        # ── Sub-turn 2: add tool result, model calls get_weather ──
        messages.append(
            {
                "role": "assistant",
                "content": "",
                "reasoning_content": "I need the current date first to calculate tomorrow.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": '{"timezone": "Asia/Shanghai"}',
                        },
                    }
                ],
            }
        )
        messages.append({"role": "tool", "content": "2025-12-01 14:30 CST"})

        prompt_2 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        assert "<result>2025-12-01 14:30 CST</result>" in prompt_2
        assert prompt_2.endswith(THINKING_START)

        # Model now knows the date, calls get_weather
        model_tokens_2 = [
            THINKING_START,
            "Today is 2025-12-01, so tomorrow is 2025-12-02.",
            " Now I can check weather for Hangzhou.",
            THINKING_END,
            "\n\n",
            TOOL_CALLS_START,
            "\n",
            f'<{DSML_TOKEN}invoke name="get_weather">\n',
            f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
            f"</{DSML_TOKEN}invoke>\n",
            TOOL_CALLS_END,
        ]
        results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
        tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
        assert len(tool_2) == 1
        assert tool_2[0].tool_calls[0].name == "get_weather"

        # ── Sub-turn 3: add weather result, model answers ──
        messages.append(
            {
                "role": "assistant",
                "content": "",
                "reasoning_content": "Today is 2025-12-01, so tomorrow is 2025-12-02. Now I can check weather for Hangzhou.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": '{"city": "Hangzhou"}',
                        },
                    }
                ],
            }
        )
        messages.append({"role": "tool", "content": "Sunny, 5~12°C"})

        prompt_3 = encode_messages(
            messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
        )
        # Should have both function_results blocks (one per tool round)
        # Count is 3: 1 in _TOOLS_SYSTEM_TEMPLATE example + 2 in conversation
        assert prompt_3.count("<function_results>") == 3
        assert prompt_3.count("</function_results>") == 3
        assert "<result>2025-12-01 14:30 CST</result>" in prompt_3
        assert "<result>Sunny, 5~12°C</result>" in prompt_3
        assert prompt_3.endswith(THINKING_START)

        # Model finally answers
        model_tokens_3 = [
            THINKING_START,
            "I have the weather for tomorrow in Hangzhou.",
            THINKING_END,
            "Tomorrow in Hangzhou will be sunny with temperatures between 5°C and 12°C.",
        ]
        results_3 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_3)))

        tool_3 = [r for r in results_3 if isinstance(r, ToolCallResponse)]
        gen_3 = [r for r in results_3 if isinstance(r, GenerationResponse)]
        non_thinking_3 = [r for r in gen_3 if not r.is_thinking]

        assert len(tool_3) == 0  # No more tool calls — loop ends
        final_text = "".join(r.text for r in non_thinking_3)
        assert "sunny" in final_text.lower()
        assert "5°C" in final_text
        assert "12°C" in final_text

@@ -148,7 +148,6 @@ class MockTokenizer:
    tool_call_start = None
    tool_call_end = None
    has_tool_calling = False
    has_thinking = False


class MockGroup:

@@ -149,23 +149,12 @@ class TestParseGptOssThinkingThenToolCall:
    def test_thinking_then_tool_call(self):
        results = _collect(THINKING_THEN_TOOL_TOKENS)

        # Thinking tokens should have is_thinking=True and no <think> tags
        thinking_responses = [
            r for r in results if isinstance(r, GenerationResponse) and r.is_thinking
        ]
        thinking_text = "".join(r.text for r in thinking_responses)
        assert "Let me think about this." in thinking_text
        assert "<think>" not in thinking_text
        assert "</think>" not in thinking_text

        # Non-thinking tokens should have is_thinking=False
        non_thinking = [
            r
            for r in results
            if isinstance(r, GenerationResponse) and not r.is_thinking
        ]
        non_thinking_text = "".join(r.text for r in non_thinking)
        assert "<think>" not in non_thinking_text
        # Should have thinking tags + content + tool call
        text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
        combined = "".join(text_parts)
        assert "<think>" in combined
        assert "</think>" in combined
        assert "Let me think about this." in combined

        # And the tool call
        tc = _get_tool_call(results)

@@ -1,8 +0,0 @@
#!/bin/bash
# Run Claude Code against a local exo cluster! (Here, GPT OSS 120B)
ANTHROPIC_BASE_URL="http://localhost:52415/" \
ANTHROPIC_AUTH_TOKEN="dummy" \
ANTHROPIC_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
ANTHROPIC_SMALL_FAST_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \
claude