Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-19 15:27:02 -05:00)

Compare commits: cancel-tas...feat/prefi (9 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 340aa36877 | |
| | 32669ae82d | |
| | cf648a53b8 | |
| | 94b2ce6922 | |
| | 423ed0f07f | |
| | ed001f2409 | |
| | 4c4c6ce99f | |
| | 42e1e7322b | |
| | aa3f106fb9 | |
@@ -20,6 +20,7 @@ from harness import (
    instance_id_from_instance,
    nodes_used_in_instance,
    resolve_model_short_id,
    run_planning_phase,
    settle_and_fetch_placements,
    wait_for_instance_gone,
    wait_for_instance_ready,
@@ -962,6 +963,21 @@ Examples:

    selected.sort(key=_placement_sort_key)
    preview = selected[0]

    settle_deadline = (
        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
    )

    print("Planning phase: checking downloads...", file=log)
    run_planning_phase(
        exo,
        full_model_id,
        preview,
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )

    instance = preview["instance"]
    instance_id = instance_id_from_instance(instance)
    sharding = str(preview["sharding"])
@@ -35,6 +35,7 @@ from harness import (
    instance_id_from_instance,
    nodes_used_in_instance,
    resolve_model_short_id,
    run_planning_phase,
    settle_and_fetch_placements,
    wait_for_instance_gone,
    wait_for_instance_ready,
@@ -332,6 +333,20 @@ def main() -> int:
    if args.dry_run:
        return 0

    settle_deadline = (
        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
    )

    logger.info("Planning phase: checking downloads...")
    run_planning_phase(
        client,
        full_model_id,
        selected[0],
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )

    all_rows: list[dict[str, Any]] = []

    for preview in selected:
bench/harness.py (+150 lines)

@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
    return selected


def run_planning_phase(
    client: ExoClient,
    full_model_id: str,
    preview: dict[str, Any],
    danger_delete: bool,
    timeout: float,
    settle_deadline: float | None,
) -> None:
    """Check disk space and ensure model is downloaded before benchmarking."""
    # Get model size from /models
    models = client.request_json("GET", "/models") or {}
    model_bytes = 0
    for m in models.get("data", []):
        if m.get("hugging_face_id") == full_model_id:
            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
            break

    if not model_bytes:
        logger.warning(
            f"Could not determine size for {full_model_id}, skipping disk check"
        )
        return

    # Get nodes from preview
    inner = unwrap_instance(preview["instance"])
    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]

    state = client.request_json("GET", "/state")
    downloads = state.get("downloads", {})
    node_disk = state.get("nodeDisk", {})

    for node_id in node_ids:
        node_downloads = downloads.get(node_id, [])

        # Check if model already downloaded on this node
        already_downloaded = any(
            "DownloadCompleted" in p
            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                "modelId"
            ]
            == full_model_id
            for p in node_downloads
        )
        if already_downloaded:
            continue

        # Wait for disk info if settle_deadline is set
        disk_info = node_disk.get(node_id, {})
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
            remaining = settle_deadline - time.monotonic()
            logger.info(
                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            state = client.request_json("GET", "/state")
            node_disk = state.get("nodeDisk", {})
            disk_info = node_disk.get(node_id, {})

        if not disk_info:
            logger.warning(f"No disk info for {node_id}, skipping space check")
            continue

        avail = disk_info.get("available", {}).get("inBytes", 0)
        if avail >= model_bytes:
            continue

        if not danger_delete:
            raise RuntimeError(
                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
            )

        # Delete from smallest to largest
        completed = [
            (
                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                    "modelId"
                ],
                p["DownloadCompleted"]["totalBytes"]["inBytes"],
            )
            for p in node_downloads
            if "DownloadCompleted" in p
        ]
        for del_model, size in sorted(completed, key=lambda x: x[1]):
            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
            avail += size
            if avail >= model_bytes:
                break

        if avail < model_bytes:
            raise RuntimeError(f"Could not free enough space on {node_id}")

    # Start downloads (idempotent)
    for node_id in node_ids:
        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
        shard = runner_to_shard[runner_id]
        client.request_json(
            "POST",
            "/download/start",
            body={
                "targetNodeId": node_id,
                "shardMetadata": shard,
            },
        )
        logger.info(f"Started download on {node_id}")

    # Wait for downloads
    start = time.time()
    while time.time() - start < timeout:
        state = client.request_json("GET", "/state")
        downloads = state.get("downloads", {})
        all_done = True
        for node_id in node_ids:
            done = any(
                "DownloadCompleted" in p
                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
                    "modelCard"
                ]["modelId"]
                == full_model_id
                for p in downloads.get(node_id, [])
            )
            failed = [
                p["DownloadFailed"]["errorMessage"]
                for p in downloads.get(node_id, [])
                if "DownloadFailed" in p
                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
                    "modelId"
                ]
                == full_model_id
            ]
            if failed:
                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
            if not done:
                all_done = False
        if all_done:
            return
        time.sleep(1)

    raise TimeoutError("Downloads did not complete in time")


def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
    ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
    ap.add_argument(
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
        default=0,
        help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
    )
    ap.add_argument(
        "--danger-delete-downloads",
        action="store_true",
        help="Delete existing models from smallest to largest to make room for benchmark model.",
    )
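For orientation, the sketch below shows the rough shape of the `/state` payload that `run_planning_phase` above walks through. The key names are taken from the diff; every node ID, model ID, and byte count is hypothetical and only illustrates the nesting, not the real API response.

```python
# Hypothetical /state payload illustrating the fields run_planning_phase reads.
# Key names come from the diff above; node IDs, model IDs, and sizes are made up.
# In the real response, shardMetadata is a tagged wrapper that unwrap_instance()
# reduces to the inner dict shown here.
example_state = {
    "downloads": {
        "node-a": [
            {
                "DownloadCompleted": {
                    "shardMetadata": {"modelCard": {"modelId": "org/some-model"}},
                    "totalBytes": {"inBytes": 12_000_000_000},
                }
            },
            {
                "DownloadFailed": {
                    "shardMetadata": {"modelCard": {"modelId": "org/other-model"}},
                    "errorMessage": "disk full",
                }
            },
        ]
    },
    "nodeDisk": {
        "node-a": {"available": {"inBytes": 250_000_000_000}}
    },
}
```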
@@ -14,6 +14,21 @@
      : 0,
  );

  const etaText = $derived.by(() => {
    if (progress.processed <= 0 || progress.total <= 0) return null;
    const elapsedMs = performance.now() - progress.startedAt;
    if (elapsedMs < 200) return null; // need a minimum sample window
    const tokensPerMs = progress.processed / elapsedMs;
    const remainingTokens = progress.total - progress.processed;
    const remainingMs = remainingTokens / tokensPerMs;
    const remainingSec = Math.ceil(remainingMs / 1000);
    if (remainingSec <= 0) return null;
    if (remainingSec < 60) return `~${remainingSec}s remaining`;
    const mins = Math.floor(remainingSec / 60);
    const secs = remainingSec % 60;
    return `~${mins}m ${secs}s remaining`;
  });

  function formatTokenCount(count: number | undefined): string {
    if (count == null) return "0";
    if (count >= 1000) {
@@ -40,8 +55,11 @@
        style="width: {percentage}%"
      ></div>
    </div>
    <div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
      {percentage}%
    <div
      class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
    >
      <span>{etaText ?? ""}</span>
      <span>{percentage}%</span>
    </div>
  </div>
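The ETA shown above is a plain rate estimate: tokens processed so far over elapsed milliseconds, extrapolated to the remaining tokens. As a cross-reference, here is the same arithmetic as a small Python sketch; the field meanings mirror the PrefillProgress interface further down, and the numbers are hypothetical.

```python
import math
import time


def prefill_eta_text(processed: int, total: int, started_at: float, now: float) -> str | None:
    """Mirror of the Svelte etaText computation: tokens/ms rate -> remaining time."""
    if processed <= 0 or total <= 0:
        return None
    elapsed_ms = (now - started_at) * 1000.0
    if elapsed_ms < 200:  # need a minimum sample window, as in the component
        return None
    tokens_per_ms = processed / elapsed_ms
    remaining_sec = math.ceil((total - processed) / tokens_per_ms / 1000.0)
    if remaining_sec <= 0:
        return None
    if remaining_sec < 60:
        return f"~{remaining_sec}s remaining"
    return f"~{remaining_sec // 60}m {remaining_sec % 60}s remaining"


# Hypothetical numbers: 2,000 of 10,000 prompt tokens processed in 4 seconds.
start = time.monotonic() - 4.0
print(prefill_eta_text(2000, 10000, start, time.monotonic()))  # ~16s remaining
```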
@@ -276,6 +276,8 @@ export interface TokenData {
|
||||
export interface PrefillProgress {
|
||||
processed: number;
|
||||
total: number;
|
||||
/** Timestamp (performance.now()) when prefill started. */
|
||||
startedAt: number;
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
@@ -1652,11 +1654,12 @@ class AppStore {
|
||||
if (!reader) throw new Error("No response body");
|
||||
|
||||
let fullContent = prefixText;
|
||||
let streamedThinking = "";
|
||||
const collectedTokens: TokenData[] = [...tokensToKeep];
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
delta?: { content?: string; reasoning_content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
@@ -1677,6 +1680,7 @@ class AppStore {
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const delta = choice?.delta?.content;
|
||||
const thinkingDelta = choice?.delta?.reasoning_content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
@@ -1695,7 +1699,11 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
if (delta) {
|
||||
if (thinkingDelta) {
|
||||
streamedThinking += thinkingDelta;
|
||||
}
|
||||
|
||||
if (delta || thinkingDelta) {
|
||||
if (firstTokenTime === null) {
|
||||
firstTokenTime = performance.now();
|
||||
this.ttftMs = firstTokenTime - requestStartTime;
|
||||
@@ -1709,9 +1717,14 @@ class AppStore {
|
||||
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
|
||||
}
|
||||
|
||||
fullContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
if (delta) {
|
||||
fullContent += delta;
|
||||
}
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
const combinedThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
this.currentResponse = displayContent;
|
||||
@@ -1723,7 +1736,7 @@ class AppStore {
|
||||
messageId,
|
||||
(m) => {
|
||||
m.content = displayContent;
|
||||
m.thinking = thinkingContent || undefined;
|
||||
m.thinking = combinedThinking || undefined;
|
||||
m.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -1735,11 +1748,14 @@ class AppStore {
|
||||
|
||||
// Final update
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent } =
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
const finalThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
this.updateConversationMessage(targetConversationId, messageId, (m) => {
|
||||
m.content = displayContent;
|
||||
m.thinking = thinkingContent || undefined;
|
||||
m.thinking = finalThinking || undefined;
|
||||
m.tokens = [...collectedTokens];
|
||||
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
|
||||
if (this.tps !== null) m.tps = this.tps;
|
||||
@@ -1847,11 +1863,12 @@ class AppStore {
|
||||
}
|
||||
|
||||
let streamedContent = "";
|
||||
let streamedThinking = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
delta?: { content?: string; reasoning_content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
@@ -1872,6 +1889,7 @@ class AppStore {
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const delta = choice?.delta?.content;
|
||||
const thinkingDelta = choice?.delta?.reasoning_content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
@@ -1890,10 +1908,19 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
if (delta) {
|
||||
streamedContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
if (thinkingDelta) {
|
||||
streamedThinking += thinkingDelta;
|
||||
}
|
||||
|
||||
if (delta || thinkingDelta) {
|
||||
if (delta) {
|
||||
streamedContent += delta;
|
||||
}
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const combinedThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
|
||||
// Only update currentResponse if target conversation is active
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
@@ -1906,7 +1933,7 @@ class AppStore {
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.thinking = combinedThinking || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -1918,14 +1945,17 @@ class AppStore {
|
||||
|
||||
// Final cleanup of the message (if conversation still exists)
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent } =
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const finalThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.thinking = finalThinking || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -2317,10 +2347,11 @@ class AppStore {
|
||||
}
|
||||
|
||||
let streamedContent = "";
|
||||
let streamedThinking = "";
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
delta?: { content?: string; reasoning_content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
@@ -2348,6 +2379,7 @@ class AppStore {
|
||||
|
||||
const choice = parsed.choices?.[0];
|
||||
const tokenContent = choice?.delta?.content;
|
||||
const thinkingContent = choice?.delta?.reasoning_content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
@@ -2366,7 +2398,11 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
if (tokenContent) {
|
||||
if (thinkingContent) {
|
||||
streamedThinking += thinkingContent;
|
||||
}
|
||||
|
||||
if (tokenContent || thinkingContent) {
|
||||
// Track first token for TTFT
|
||||
if (firstTokenTime === null) {
|
||||
firstTokenTime = performance.now();
|
||||
@@ -2383,11 +2419,16 @@ class AppStore {
|
||||
this.tps = (tokenCount / elapsed) * 1000;
|
||||
}
|
||||
|
||||
streamedContent += tokenContent;
|
||||
if (tokenContent) {
|
||||
streamedContent += tokenContent;
|
||||
}
|
||||
|
||||
// Strip thinking tags for display and extract thinking content
|
||||
const { displayContent, thinkingContent } =
|
||||
// Use stripThinkingTags as fallback for any <think> tags still in content
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const combinedThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
|
||||
// Only update currentResponse if target conversation is active
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
@@ -2400,7 +2441,7 @@ class AppStore {
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.thinking = combinedThinking || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
@@ -2420,6 +2461,7 @@ class AppStore {
|
||||
this.prefillProgress = {
|
||||
processed: inner.processed_tokens,
|
||||
total: inner.total_tokens,
|
||||
startedAt: this.prefillProgress?.startedAt ?? performance.now(),
|
||||
};
|
||||
},
|
||||
},
|
||||
@@ -2436,14 +2478,17 @@ class AppStore {
|
||||
|
||||
// Final cleanup of the message (if conversation still exists)
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent } =
|
||||
const { displayContent, thinkingContent: tagThinking } =
|
||||
this.stripThinkingTags(streamedContent);
|
||||
const finalThinking = [streamedThinking, tagThinking]
|
||||
.filter(Boolean)
|
||||
.join("\n\n");
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.thinking = finalThinking || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
// Store performance metrics on the message
|
||||
if (this.ttftMs !== null) {
|
||||
|
||||
@@ -114,6 +114,74 @@
|
||||
});
|
||||
let tb5InfoDismissed = $state(false);
|
||||
|
||||
// Detect Mac Studio nodes using RDMA on en2 (the port next to ethernet — RDMA doesn't work there)
|
||||
const macStudioEn2RdmaWarning = $derived.by(() => {
|
||||
const edges = data?.edges;
|
||||
const ids = tbIdentifiers;
|
||||
const rdmaCtl = rdmaCtlData;
|
||||
if (!edges || !ids || !rdmaCtl) return null;
|
||||
|
||||
const affectedConnections: Array<{
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
peerNodeId: string;
|
||||
peerNodeName: string;
|
||||
rdmaIface: string;
|
||||
}> = [];
|
||||
|
||||
const isMacStudio = (node: (typeof data.nodes)[string] | undefined) =>
|
||||
node?.system_info?.model_id === "Mac Studio";
|
||||
|
||||
for (const edge of edges) {
|
||||
if (!edge.sourceRdmaIface && !edge.sinkRdmaIface) continue;
|
||||
|
||||
const sourceNode = data?.nodes?.[edge.source];
|
||||
if (
|
||||
isMacStudio(sourceNode) &&
|
||||
edge.sourceRdmaIface === "rdma_en2" &&
|
||||
rdmaCtl[edge.source]?.enabled
|
||||
) {
|
||||
affectedConnections.push({
|
||||
nodeId: edge.source,
|
||||
nodeName:
|
||||
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
|
||||
peerNodeId: edge.target,
|
||||
peerNodeName:
|
||||
data?.nodes?.[edge.target]?.friendly_name ||
|
||||
edge.target.slice(0, 8) + "...",
|
||||
rdmaIface: "en2",
|
||||
});
|
||||
}
|
||||
|
||||
const sinkNode = data?.nodes?.[edge.target];
|
||||
if (
|
||||
isMacStudio(sinkNode) &&
|
||||
edge.sinkRdmaIface === "rdma_en2" &&
|
||||
rdmaCtl[edge.target]?.enabled
|
||||
) {
|
||||
affectedConnections.push({
|
||||
nodeId: edge.target,
|
||||
nodeName: sinkNode?.friendly_name || edge.target.slice(0, 8) + "...",
|
||||
peerNodeId: edge.source,
|
||||
peerNodeName:
|
||||
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
|
||||
rdmaIface: "en2",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate by nodeId
|
||||
const seen = new Set<string>();
|
||||
const unique = affectedConnections.filter((c) => {
|
||||
if (seen.has(c.nodeId)) return false;
|
||||
seen.add(c.nodeId);
|
||||
return true;
|
||||
});
|
||||
|
||||
return unique.length > 0 ? unique : null;
|
||||
});
|
||||
let macStudioEn2Dismissed = $state(false);
|
||||
|
||||
// Helper to get friendly node name from node ID
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -1077,7 +1145,7 @@
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: statusInfo.statusText === "FAILED",
|
||||
errorMessage: statusInfo.errorMessage,
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: statusInfo.statusText,
|
||||
perNode: [],
|
||||
@@ -1135,15 +1203,10 @@
|
||||
function deriveInstanceStatus(instanceWrapped: unknown): {
|
||||
statusText: string;
|
||||
statusClass: string;
|
||||
errorMessage: string | null;
|
||||
} {
|
||||
const [, instance] = getTagged(instanceWrapped);
|
||||
if (!instance || typeof instance !== "object") {
|
||||
return {
|
||||
statusText: "PREPARING",
|
||||
statusClass: "inactive",
|
||||
errorMessage: null,
|
||||
};
|
||||
return { statusText: "PREPARING", statusClass: "inactive" };
|
||||
}
|
||||
|
||||
const inst = instance as {
|
||||
@@ -1151,106 +1214,50 @@
|
||||
};
|
||||
const runnerIds = Object.keys(inst.shardAssignments?.runnerToShard || {});
|
||||
|
||||
const statusMap: Record<string, string> = {
|
||||
RunnerWaitingForInitialization: "WaitingForInitialization",
|
||||
RunnerInitializingBackend: "InitializingBackend",
|
||||
RunnerWaitingForModel: "WaitingForModel",
|
||||
RunnerLoading: "Loading",
|
||||
RunnerLoaded: "Loaded",
|
||||
RunnerWarmingUp: "WarmingUp",
|
||||
RunnerReady: "Ready",
|
||||
RunnerRunning: "Running",
|
||||
RunnerShutdown: "Shutdown",
|
||||
RunnerFailed: "Failed",
|
||||
};
|
||||
|
||||
const statuses = runnerIds
|
||||
.map((rid) => {
|
||||
const r = runnersData[rid];
|
||||
if (!r) return null;
|
||||
const [kind, payload] = getTagged(r);
|
||||
if (!kind || !statusMap[kind]) return null;
|
||||
const errorMessage =
|
||||
kind === "RunnerFailed" && payload && typeof payload === "object"
|
||||
? (((payload as Record<string, unknown>).errorMessage as string) ??
|
||||
null)
|
||||
: null;
|
||||
return { status: statusMap[kind], errorMessage };
|
||||
const [kind] = getTagged(r);
|
||||
const statusMap: Record<string, string> = {
|
||||
RunnerWaitingForInitialization: "WaitingForInitialization",
|
||||
RunnerInitializingBackend: "InitializingBackend",
|
||||
RunnerWaitingForModel: "WaitingForModel",
|
||||
RunnerLoading: "Loading",
|
||||
RunnerLoaded: "Loaded",
|
||||
RunnerWarmingUp: "WarmingUp",
|
||||
RunnerReady: "Ready",
|
||||
RunnerRunning: "Running",
|
||||
RunnerShutdown: "Shutdown",
|
||||
RunnerFailed: "Failed",
|
||||
};
|
||||
return kind ? statusMap[kind] || null : null;
|
||||
})
|
||||
.filter(
|
||||
(s): s is { status: string; errorMessage: string | null } => s !== null,
|
||||
);
|
||||
.filter((s): s is string => s !== null);
|
||||
|
||||
const has = (s: string) => statuses.some((e) => e.status === s);
|
||||
const has = (s: string) => statuses.includes(s);
|
||||
|
||||
if (statuses.length === 0)
|
||||
return {
|
||||
statusText: "PREPARING",
|
||||
statusClass: "inactive",
|
||||
errorMessage: null,
|
||||
};
|
||||
if (has("Failed")) {
|
||||
const failedEntry = statuses.find(
|
||||
(e) => e.status === "Failed" && e.errorMessage,
|
||||
);
|
||||
return {
|
||||
statusText: "FAILED",
|
||||
statusClass: "failed",
|
||||
errorMessage: failedEntry?.errorMessage ?? null,
|
||||
};
|
||||
}
|
||||
return { statusText: "PREPARING", statusClass: "inactive" };
|
||||
if (has("Failed")) return { statusText: "FAILED", statusClass: "failed" };
|
||||
if (has("Shutdown"))
|
||||
return {
|
||||
statusText: "SHUTDOWN",
|
||||
statusClass: "inactive",
|
||||
errorMessage: null,
|
||||
};
|
||||
return { statusText: "SHUTDOWN", statusClass: "inactive" };
|
||||
if (has("Loading"))
|
||||
return {
|
||||
statusText: "LOADING",
|
||||
statusClass: "starting",
|
||||
errorMessage: null,
|
||||
};
|
||||
return { statusText: "LOADING", statusClass: "starting" };
|
||||
if (has("WarmingUp"))
|
||||
return {
|
||||
statusText: "WARMING UP",
|
||||
statusClass: "starting",
|
||||
errorMessage: null,
|
||||
};
|
||||
return { statusText: "WARMING UP", statusClass: "starting" };
|
||||
if (has("Running"))
|
||||
return {
|
||||
statusText: "RUNNING",
|
||||
statusClass: "running",
|
||||
errorMessage: null,
|
||||
};
|
||||
if (has("Ready"))
|
||||
return { statusText: "READY", statusClass: "loaded", errorMessage: null };
|
||||
if (has("Loaded"))
|
||||
return {
|
||||
statusText: "LOADED",
|
||||
statusClass: "loaded",
|
||||
errorMessage: null,
|
||||
};
|
||||
return { statusText: "RUNNING", statusClass: "running" };
|
||||
if (has("Ready")) return { statusText: "READY", statusClass: "loaded" };
|
||||
if (has("Loaded")) return { statusText: "LOADED", statusClass: "loaded" };
|
||||
if (has("WaitingForModel"))
|
||||
return {
|
||||
statusText: "WAITING",
|
||||
statusClass: "starting",
|
||||
errorMessage: null,
|
||||
};
|
||||
return { statusText: "WAITING", statusClass: "starting" };
|
||||
if (has("InitializingBackend"))
|
||||
return {
|
||||
statusText: "INITIALIZING",
|
||||
statusClass: "starting",
|
||||
errorMessage: null,
|
||||
};
|
||||
return { statusText: "INITIALIZING", statusClass: "starting" };
|
||||
if (has("WaitingForInitialization"))
|
||||
return {
|
||||
statusText: "INITIALIZING",
|
||||
statusClass: "starting",
|
||||
errorMessage: null,
|
||||
};
|
||||
return { statusText: "INITIALIZING", statusClass: "starting" };
|
||||
|
||||
return { statusText: "RUNNING", statusClass: "active", errorMessage: null };
|
||||
return { statusText: "RUNNING", statusClass: "active" };
|
||||
}
|
||||
|
||||
function getBytes(value: unknown): number {
|
||||
@@ -1819,7 +1826,7 @@
|
||||
</script>
|
||||
|
||||
{#snippet clusterWarnings()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
|
||||
<div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
@@ -1984,12 +1991,260 @@
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
|
||||
<div class="group relative" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-red-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-red-200">
|
||||
RDMA INCOMPATIBLE PORT
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => (macStudioEn2Dismissed = true)}
|
||||
class="ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer"
|
||||
title="Dismiss"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Expanded tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-96 p-4 rounded border border-red-500/30 bg-[#1a1a1a]/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-3">
|
||||
The Thunderbolt 5 port next to the Ethernet port on Mac Studio
|
||||
does
|
||||
<span class="text-red-400 font-semibold">not support RDMA</span>.
|
||||
Move the cable to one of the other three TB5 ports.
|
||||
</p>
|
||||
|
||||
<div class="text-xs text-white/60 mb-3">
|
||||
<span class="text-red-300">Affected:</span>
|
||||
{#each macStudioEn2RdmaWarning as conn}
|
||||
<div class="ml-2 mt-0.5">
|
||||
<span class="text-white/80">{conn.nodeName}</span>
|
||||
<span class="text-white/30">→</span>
|
||||
<span class="text-white/60">{conn.peerNodeName}</span>
|
||||
<span class="text-white/30 ml-1">(en2)</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<!-- Mac Studio back panel illustration -->
|
||||
<div class="bg-black/40 rounded p-3 mb-3">
|
||||
<p
|
||||
class="text-[10px] font-mono text-white/30 uppercase tracking-wider mb-2"
|
||||
>
|
||||
Mac Studio — Rear Panel
|
||||
</p>
|
||||
<svg
|
||||
viewBox="0 0 320 72"
|
||||
class="w-full"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<rect
|
||||
x="1"
|
||||
y="1"
|
||||
width="318"
|
||||
height="70"
|
||||
rx="6"
|
||||
ry="6"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.12)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<!-- TB5 port 1 -->
|
||||
<rect
|
||||
x="24"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.3)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<text
|
||||
x="38"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>TB5</text
|
||||
>
|
||||
<!-- TB5 port 2 -->
|
||||
<rect
|
||||
x="62"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.3)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<text
|
||||
x="76"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>TB5</text
|
||||
>
|
||||
<!-- TB5 port 3 -->
|
||||
<rect
|
||||
x="100"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.3)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<text
|
||||
x="114"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>TB5</text
|
||||
>
|
||||
<!-- TB5 port 4: INCOMPATIBLE (en2) — equally spaced with ports 1-3 -->
|
||||
<rect
|
||||
x="138"
|
||||
y="22"
|
||||
width="28"
|
||||
height="14"
|
||||
rx="4"
|
||||
fill="rgba(239,68,68,0.1)"
|
||||
stroke="rgba(239,68,68,0.7)"
|
||||
stroke-width="1.5"
|
||||
/>
|
||||
<line
|
||||
x1="142"
|
||||
y1="25"
|
||||
x2="162"
|
||||
y2="33"
|
||||
stroke="rgba(239,68,68,0.8)"
|
||||
stroke-width="1.5"
|
||||
stroke-linecap="round"
|
||||
/>
|
||||
<line
|
||||
x1="162"
|
||||
y1="25"
|
||||
x2="142"
|
||||
y2="33"
|
||||
stroke="rgba(239,68,68,0.8)"
|
||||
stroke-width="1.5"
|
||||
stroke-linecap="round"
|
||||
/>
|
||||
<text
|
||||
x="152"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(239,68,68,0.6)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;font-weight:600;"
|
||||
>en2</text
|
||||
>
|
||||
<!-- Ethernet port -->
|
||||
<rect
|
||||
x="196"
|
||||
y="19"
|
||||
width="24"
|
||||
height="20"
|
||||
rx="2"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.2)"
|
||||
stroke-width="1"
|
||||
/>
|
||||
<rect
|
||||
x="200"
|
||||
y="23"
|
||||
width="16"
|
||||
height="12"
|
||||
rx="1"
|
||||
fill="none"
|
||||
stroke="rgba(255,255,255,0.12)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
<text
|
||||
x="208"
|
||||
y="52"
|
||||
text-anchor="middle"
|
||||
fill="rgba(255,255,255,0.25)"
|
||||
style="font-size:7px;font-family:ui-monospace,monospace;"
|
||||
>ETH</text
|
||||
>
|
||||
<!-- Green checkmarks on working ports -->
|
||||
<circle
|
||||
cx="38"
|
||||
cy="62"
|
||||
r="3"
|
||||
fill="none"
|
||||
stroke="rgba(74,222,128,0.5)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
<circle
|
||||
cx="76"
|
||||
cy="62"
|
||||
r="3"
|
||||
fill="none"
|
||||
stroke="rgba(74,222,128,0.5)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
<circle
|
||||
cx="114"
|
||||
cy="62"
|
||||
r="3"
|
||||
fill="none"
|
||||
stroke="rgba(74,222,128,0.5)"
|
||||
stroke-width="0.75"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
|
||||
<p class="text-xs text-white/50">
|
||||
<span class="text-green-400">Fix:</span> Move the Thunderbolt cable
|
||||
to any of the three leftmost ports (all support RDMA).
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/snippet}
|
||||
|
||||
{#snippet clusterWarningsCompact()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
|
||||
<div class="absolute top-2 left-2 flex flex-col gap-1">
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
@@ -2057,6 +2312,27 @@
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
|
||||
<div
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm"
|
||||
title="Mac Studio RDMA incompatible port (en2) — move cable to another TB5 port"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-red-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-red-200">BAD RDMA PORT</span>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
{/snippet}
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
modelDirectory?: string;
|
||||
}
|
||||
| { kind: "pending"; modelDirectory?: string }
|
||||
| { kind: "failed"; modelDirectory?: string; errorMessage?: string }
|
||||
| { kind: "failed"; modelDirectory?: string }
|
||||
| { kind: "not_present" };
|
||||
|
||||
type ModelCardInfo = {
|
||||
@@ -263,10 +263,7 @@
|
||||
modelDirectory,
|
||||
};
|
||||
} else if (tag === "DownloadFailed") {
|
||||
const errorMsg =
|
||||
((payload.errorMessage ?? payload.error_message) as string) ||
|
||||
undefined;
|
||||
cell = { kind: "failed", modelDirectory, errorMessage: errorMsg };
|
||||
cell = { kind: "failed", modelDirectory };
|
||||
} else {
|
||||
cell = { kind: "pending", modelDirectory };
|
||||
}
|
||||
@@ -502,7 +499,7 @@
|
||||
{:else if cell.kind === "failed"}
|
||||
<div
|
||||
class="flex flex-col items-center gap-0.5"
|
||||
title={cell.errorMessage ?? "Download failed"}
|
||||
title="Download failed"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-red-400"
|
||||
@@ -515,14 +512,6 @@
|
||||
clip-rule="evenodd"
|
||||
></path>
|
||||
</svg>
|
||||
{#if cell.errorMessage}
|
||||
<span
|
||||
class="text-[9px] text-red-400/70 max-w-[120px] truncate"
|
||||
title={cell.errorMessage}
|
||||
>
|
||||
{cell.errorMessage}
|
||||
</span>
|
||||
{/if}
|
||||
{#if row.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
@@ -718,11 +707,6 @@
|
||||
({clampPercent(cellStatus.percentage).toFixed(0)}%)
|
||||
{/if}
|
||||
</span>
|
||||
{#if cellStatus.kind === "failed" && "errorMessage" in cellStatus && cellStatus.errorMessage}
|
||||
<span class="text-[9px] text-red-400/70 break-all pl-1">
|
||||
{cellStatus.errorMessage}
|
||||
</span>
|
||||
{/if}
|
||||
{#if "modelDirectory" in cellStatus && cellStatus.modelDirectory}
|
||||
<span
|
||||
class="text-[9px] text-white/30 break-all pl-1"
|
||||
|
||||
BIN prefill-eta-screenshot.png (new file, 110 KiB; binary file not shown)
@@ -19,7 +19,7 @@ class ConnectionUpdate:
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> PeerId:
|
||||
def peer_id(self) -> builtins.str:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@@ -40,92 +40,22 @@ class Keypair:
|
||||
Identity keypair of a node.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ed25519() -> Keypair:
|
||||
def generate() -> Keypair:
|
||||
r"""
|
||||
Generate a new Ed25519 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_ecdsa() -> Keypair:
|
||||
def from_bytes(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Generate a new ECDSA keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def generate_secp256k1() -> Keypair:
|
||||
r"""
|
||||
Generate a new Secp256k1 keypair.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_protobuf_encoding(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
"""
|
||||
@staticmethod
|
||||
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
format (i.e. unencrypted) as defined in [RFC5208].
|
||||
|
||||
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
"""
|
||||
@staticmethod
|
||||
def secp256k1_from_der(bytes: bytes) -> Keypair:
|
||||
r"""
|
||||
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
structure as defined in [RFC5915].
|
||||
|
||||
[RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
"""
|
||||
@staticmethod
|
||||
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
|
||||
def to_protobuf_encoding(self) -> bytes:
|
||||
r"""
|
||||
Encode a private key as protobuf structure.
|
||||
"""
|
||||
def to_peer_id(self) -> PeerId:
|
||||
r"""
|
||||
Convert the `Keypair` into the corresponding `PeerId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Multiaddr:
|
||||
r"""
|
||||
Representation of a Multiaddr.
|
||||
"""
|
||||
@staticmethod
|
||||
def empty() -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress.
|
||||
"""
|
||||
@staticmethod
|
||||
def with_capacity(n: builtins.int) -> Multiaddr:
|
||||
r"""
|
||||
Create a new, empty multiaddress with the given capacity.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its byte slice representation.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_string(string: builtins.str) -> Multiaddr:
|
||||
r"""
|
||||
Parse a `Multiaddr` value from its string representation.
|
||||
"""
|
||||
def len(self) -> builtins.int:
|
||||
r"""
|
||||
Return the length in bytes of this multiaddress.
|
||||
"""
|
||||
def is_empty(self) -> builtins.bool:
|
||||
r"""
|
||||
Returns true if the length of this multiaddress is 0.
|
||||
Construct an Ed25519 keypair from secret key bytes
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Return a copy of this [`Multiaddr`]'s byte representation.
|
||||
Get the secret key bytes underlying the keypair
|
||||
"""
|
||||
def to_string(self) -> builtins.str:
|
||||
def to_node_id(self) -> builtins.str:
|
||||
r"""
|
||||
Convert a Multiaddr to a string.
|
||||
Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class PeerId:
|
||||
r"""
|
||||
Identifier of a peer of the network.
|
||||
|
||||
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
"""
|
||||
@staticmethod
|
||||
def random() -> PeerId:
|
||||
r"""
|
||||
Generates a random peer ID from a cryptographically secure PRNG.
|
||||
|
||||
This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
"""
|
||||
@staticmethod
|
||||
def from_bytes(bytes: bytes) -> PeerId:
|
||||
r"""
|
||||
Parses a `PeerId` from bytes.
|
||||
"""
|
||||
def to_bytes(self) -> bytes:
|
||||
r"""
|
||||
Returns a raw bytes representation of this `PeerId`.
|
||||
"""
|
||||
def to_base58(self) -> builtins.str:
|
||||
r"""
|
||||
Returns a base-58 encoded string of this `PeerId`.
|
||||
"""
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use crate::ext::ResultExt as _;
|
||||
use libp2p::PeerId;
|
||||
use libp2p::identity::Keypair;
|
||||
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::types::{PyBytes, PyBytesMethods as _};
|
||||
use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
|
||||
|
||||
@@ -18,142 +16,32 @@ pub struct PyKeypair(pub Keypair);
|
||||
impl PyKeypair {
|
||||
/// Generate a new Ed25519 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_ed25519() -> Self {
|
||||
fn generate() -> Self {
|
||||
Self(Keypair::generate_ed25519())
|
||||
}
|
||||
|
||||
/// Generate a new ECDSA keypair.
|
||||
/// Construct an Ed25519 keypair from secret key bytes
|
||||
#[staticmethod]
|
||||
fn generate_ecdsa() -> Self {
|
||||
Self(Keypair::generate_ecdsa())
|
||||
}
|
||||
|
||||
/// Generate a new Secp256k1 keypair.
|
||||
#[staticmethod]
|
||||
fn generate_secp256k1() -> Self {
|
||||
Self(Keypair::generate_secp256k1())
|
||||
}
|
||||
|
||||
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
|
||||
#[staticmethod]
|
||||
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
|
||||
/// format (i.e. unencrypted) as defined in [RFC5208].
|
||||
///
|
||||
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
|
||||
#[staticmethod]
|
||||
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
|
||||
/// structure as defined in [RFC5915].
|
||||
///
|
||||
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
|
||||
#[staticmethod]
|
||||
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let mut bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Encode a private key as protobuf structure.
|
||||
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self.0.to_protobuf_encoding().pyerr()?;
|
||||
/// Get the secret key bytes underlying the keypair
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
let bytes = self
|
||||
.0
|
||||
.clone()
|
||||
.try_into_ed25519()
|
||||
.pyerr()?
|
||||
.secret()
|
||||
.as_ref()
|
||||
.to_vec();
|
||||
Ok(PyBytes::new(py, &bytes))
|
||||
}
|
||||
|
||||
/// Convert the `Keypair` into the corresponding `PeerId`.
|
||||
fn to_peer_id(&self) -> PyPeerId {
|
||||
PyPeerId(self.0.public().to_peer_id())
|
||||
}
|
||||
|
||||
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
|
||||
// #[gen_stub(skip)]
|
||||
// #[new]
|
||||
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
// Self::from_protobuf_encoding(bytes)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
|
||||
// *self = Self::from_protobuf_encoding(state)?;
|
||||
// Ok(())
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
|
||||
// self.to_protobuf_encoding(py)
|
||||
// }
|
||||
//
|
||||
// #[gen_stub(skip)]
|
||||
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
|
||||
// Ok((self.to_protobuf_encoding(py)?,))
|
||||
// }
|
||||
}
|
||||
|
||||
/// Identifier of a peer of the network.
|
||||
///
|
||||
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
|
||||
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "PeerId", frozen)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[repr(transparent)]
|
||||
pub struct PyPeerId(pub PeerId);
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
impl PyPeerId {
|
||||
/// Generates a random peer ID from a cryptographically secure PRNG.
|
||||
///
|
||||
/// This is useful for randomly walking on a DHT, or for testing purposes.
|
||||
#[staticmethod]
|
||||
fn random() -> Self {
|
||||
Self(PeerId::random())
|
||||
}
|
||||
|
||||
/// Parses a `PeerId` from bytes.
|
||||
#[staticmethod]
|
||||
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
|
||||
let bytes = Vec::from(bytes.as_bytes());
|
||||
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
|
||||
}
|
||||
|
||||
/// Returns a raw bytes representation of this `PeerId`.
|
||||
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
|
||||
let bytes = self.0.to_bytes();
|
||||
PyBytes::new(py, &bytes)
|
||||
}
|
||||
|
||||
/// Returns a base-58 encoded string of this `PeerId`.
|
||||
fn to_base58(&self) -> String {
|
||||
self.0.to_base58()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PeerId({})", self.to_base58())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
self.to_base58()
|
||||
/// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
|
||||
fn to_node_id(&self) -> String {
|
||||
self.0.public().to_peer_id().to_base58()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<PyKeypair>()?;
|
||||
m.add_class::<PyPeerId>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -8,9 +8,10 @@ mod allow_threading;
|
||||
mod ident;
|
||||
mod networking;
|
||||
|
||||
use crate::ident::ident_submodule;
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::networking_submodule;
|
||||
use pyo3::prelude::PyModule;
|
||||
use pyo3::types::PyModuleMethods;
|
||||
use pyo3::{Bound, PyResult, pyclass, pymodule};
|
||||
use pyo3_stub_gen::define_stub_info_gatherer;
|
||||
|
||||
@@ -158,7 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
// too many importing issues...
|
||||
ident_submodule(m)?;
|
||||
m.add_class::<PyKeypair>()?;
|
||||
networking_submodule(m)?;
|
||||
|
||||
// top-level constructs
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::{PyKeypair, PyPeerId};
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::pyclass;
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: PyPeerId,
|
||||
peer_id: String,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
@@ -251,7 +251,7 @@ async fn networking_task(
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
@@ -272,7 +272,7 @@ async fn networking_task(
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: PyPeerId(peer_id),
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
|
||||
@@ -45,7 +45,7 @@ class Node:
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> "Self":
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
router = Router.create(keypair)
|
||||
await router.register_topic(topics.GLOBAL_EVENTS)
|
||||
|
||||
@@ -59,7 +59,11 @@ def chat_request_to_text_generation(
|
||||
chat_template_messages.append({"role": "system", "content": content})
|
||||
else:
|
||||
# Skip messages with no meaningful content
|
||||
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
|
||||
if (
|
||||
msg.content is None
|
||||
and msg.reasoning_content is None
|
||||
and msg.tool_calls is None
|
||||
):
|
||||
continue
|
||||
|
||||
if msg.role in ("user", "assistant", "developer"):
|
||||
@@ -111,6 +115,11 @@ def chunk_to_response(
|
||||
]
|
||||
)
|
||||
|
||||
if chunk.is_thinking:
|
||||
delta = ChatCompletionMessage(role="assistant", reasoning_content=chunk.text)
|
||||
else:
|
||||
delta = ChatCompletionMessage(role="assistant", content=chunk.text)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -118,7 +127,7 @@ def chunk_to_response(
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
delta=delta,
|
||||
logprobs=logprobs,
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
@@ -208,6 +217,7 @@ async def collect_chat_response(
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
logprobs_content: list[LogprobsContentItem] = []
|
||||
model: str | None = None
|
||||
@@ -228,7 +238,10 @@ async def collect_chat_response(
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
last_usage = chunk.usage or last_usage
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
else:
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
logprobs_content.append(
|
||||
LogprobsContentItem(
|
||||
@@ -258,6 +271,7 @@ async def collect_chat_response(
|
||||
raise ValueError(error_message)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_thinking = "".join(thinking_parts) if thinking_parts else None
|
||||
assert model is not None
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
@@ -270,6 +284,7 @@ async def collect_chat_response(
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
reasoning_content=combined_thinking,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
logprobs=Logprobs(content=logprobs_content)
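With the `is_thinking` handling above, streamed deltas can now carry `reasoning_content` instead of `content`. A minimal client-side sketch of splitting the two follows, assuming OpenAI-style SSE chunks in the shape shown by the web UI's ChatCompletionChunk interface; the chunk payloads themselves are invented.

```python
import json


def accumulate(sse_data_lines: list[str]) -> tuple[str, str]:
    """Collect visible text and thinking text from streamed chat-completion chunks."""
    content: list[str] = []
    thinking: list[str] = []
    for line in sse_data_lines:
        if not line.startswith("data: ") or line == "data: [DONE]":
            continue
        delta = json.loads(line[len("data: "):]).get("choices", [{}])[0].get("delta", {})
        if delta.get("reasoning_content"):
            thinking.append(delta["reasoning_content"])
        if delta.get("content"):
            content.append(delta["content"])
    return "".join(content), "".join(thinking)


# Hypothetical stream: one thinking delta followed by one answer delta.
stream = [
    'data: {"choices":[{"delta":{"reasoning_content":"Let me check..."}}]}',
    'data: {"choices":[{"delta":{"content":"The answer is 4."}}]}',
    "data: [DONE]",
]
print(accumulate(stream))  # ('The answer is 4.', 'Let me check...')
```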
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Claude Messages API adapter for converting requests/responses."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
@@ -28,6 +29,8 @@ from exo.shared.types.claude_api import (
|
||||
ClaudeStopReason,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeThinkingBlock,
|
||||
ClaudeThinkingDelta,
|
||||
ClaudeToolResultBlock,
|
||||
ClaudeToolUseBlock,
|
||||
ClaudeUsage,
|
||||
@@ -61,6 +64,22 @@ def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
|
||||
return "".join(sub_block.text for sub_block in block.content)
|
||||
|
||||
|
||||
# Matches "x-anthropic-billing-header: ...;" (with optional trailing newline)
# or similar telemetry headers that change every request and break KV prefix caching.
_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)


def _strip_volatile_headers(text: str) -> str:
    """Remove Anthropic billing/telemetry headers from system prompt text.

    Claude Code prepends headers like 'x-anthropic-billing-header: cc_version=...;
    cc_entrypoint=...; cch=...;' that contain per-request content hashes. These
    change every request and break KV prefix caching (the prefix diverges at ~20
    tokens instead of matching thousands of conversation tokens).
    """
    return _VOLATILE_HEADER_RE.sub("", text)
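A quick usage sketch of the stripper above: the sample system prompt is invented, but its first line follows the x-anthropic-billing-header pattern the regex targets, so only that line is removed.

```python
import re

_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)

# Hypothetical system prompt as Claude Code might send it; only the first line
# matches the volatile-header pattern and gets dropped, preserving the rest.
system_prompt = (
    "x-anthropic-billing-header: cc_version=1.2.3; cc_entrypoint=cli; cch=abc123;\n"
    "You are Claude Code, a helpful coding assistant."
)

stripped = _VOLATILE_HEADER_RE.sub("", system_prompt)
assert stripped == "You are Claude Code, a helpful coding assistant."
```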
|
||||
|
||||
|
||||
def claude_request_to_text_generation(
|
||||
request: ClaudeMessagesRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
@@ -73,6 +92,8 @@ def claude_request_to_text_generation(
|
||||
instructions = request.system
|
||||
else:
|
||||
instructions = "".join(block.text for block in request.system)
|
||||
|
||||
instructions = _strip_volatile_headers(instructions)
|
||||
chat_template_messages.append({"role": "system", "content": instructions})
|
||||
|
||||
# Convert messages to input
|
||||
@@ -85,12 +106,15 @@ def claude_request_to_text_generation(
|
||||
|
||||
# Process structured content blocks
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
tool_results: list[ClaudeToolResultBlock] = []
|
||||
|
||||
for block in msg.content:
|
||||
if isinstance(block, ClaudeTextBlock):
|
||||
text_parts.append(block.text)
|
||||
elif isinstance(block, ClaudeThinkingBlock):
|
||||
thinking_parts.append(block.thinking)
|
||||
elif isinstance(block, ClaudeToolUseBlock):
|
||||
tool_calls.append(
|
||||
{
|
||||
@@ -106,6 +130,7 @@ def claude_request_to_text_generation(
|
||||
tool_results.append(block)
|
||||
|
||||
content = "".join(text_parts)
|
||||
reasoning_content = "".join(thinking_parts) if thinking_parts else None
|
||||
|
||||
# Build InputMessage from text content
|
||||
if msg.role in ("user", "assistant"):
|
||||
@@ -113,9 +138,14 @@ def claude_request_to_text_generation(
|
||||
|
||||
# Build chat_template_messages preserving tool structure
|
||||
if tool_calls:
|
||||
chat_template_messages.append(
|
||||
{"role": "assistant", "content": content, "tool_calls": tool_calls}
|
||||
)
|
||||
chat_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
if reasoning_content:
|
||||
chat_msg["reasoning_content"] = reasoning_content
|
||||
chat_template_messages.append(chat_msg)
|
||||
elif tool_results:
|
||||
for tr in tool_results:
|
||||
chat_template_messages.append(
|
||||
@@ -126,7 +156,10 @@ def claude_request_to_text_generation(
|
||||
}
|
||||
)
|
||||
else:
|
||||
chat_template_messages.append({"role": msg.role, "content": content})
|
||||
chat_msg = {"role": msg.role, "content": content}
|
||||
if reasoning_content:
|
||||
chat_msg["reasoning_content"] = reasoning_content
|
||||
chat_template_messages.append(chat_msg)
|
||||
|
||||
# Convert Claude tool definitions to OpenAI-style function tools
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
@@ -143,6 +176,10 @@ def claude_request_to_text_generation(
|
||||
for tool in request.tools
|
||||
]
|
||||
|
||||
enable_thinking: bool | None = None
|
||||
if request.thinking is not None:
|
||||
enable_thinking = request.thinking.type in ("enabled", "adaptive")
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages
|
||||
@@ -156,6 +193,7 @@ def claude_request_to_text_generation(
|
||||
stop=request.stop_sequences,
|
||||
stream=request.stream,
|
||||
tools=tools,
|
||||
enable_thinking=enable_thinking,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
@@ -173,6 +211,7 @@ async def collect_claude_response(
|
||||
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
|
||||
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
|
||||
text_parts: list[str] = []
|
||||
thinking_parts: list[str] = []
|
||||
tool_use_blocks: list[ClaudeToolUseBlock] = []
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_usage: Usage | None = None
|
||||
@@ -200,7 +239,10 @@ async def collect_claude_response(
|
||||
stop_reason = "tool_use"
|
||||
continue
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
else:
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -209,9 +251,12 @@ async def collect_claude_response(
|
||||
raise ValueError(error_message)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
combined_thinking = "".join(thinking_parts)
|
||||
|
||||
# Build content blocks
|
||||
content: list[ClaudeContentBlock] = []
|
||||
if combined_thinking:
|
||||
content.append(ClaudeThinkingBlock(thinking=combined_thinking))
|
||||
if combined_text:
|
||||
content.append(ClaudeTextBlock(text=combined_text))
|
||||
content.extend(tool_use_blocks)
|
||||
@@ -256,16 +301,16 @@ async def generate_claude_stream(
|
||||
start_event = ClaudeMessageStartEvent(message=initial_message)
|
||||
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_start for text block at index 0
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_usage: Usage | None = None
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
next_block_index = 0
|
||||
|
||||
# Track whether we've started thinking/text blocks
|
||||
thinking_block_started = False
|
||||
thinking_block_index = -1
|
||||
text_block_started = False
|
||||
text_block_index = -1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
@@ -310,12 +355,45 @@ async def generate_claude_stream(
|
||||
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
if chunk.is_thinking:
|
||||
# Start thinking block on first thinking token
|
||||
if not thinking_block_started:
|
||||
thinking_block_started = True
|
||||
thinking_block_index = next_block_index
|
||||
next_block_index += 1
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=thinking_block_index,
|
||||
content_block=ClaudeThinkingBlock(thinking=""),
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=thinking_block_index,
|
||||
delta=ClaudeThinkingDelta(thinking=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
else:
|
||||
# Close thinking block when transitioning to text
|
||||
if thinking_block_started and text_block_index == -1:
|
||||
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
# Start text block on first text token
|
||||
if not text_block_started:
|
||||
text_block_started = True
|
||||
text_block_index = next_block_index
|
||||
next_block_index += 1
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=text_block_index,
|
||||
content_block=ClaudeTextBlock(text=""),
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=text_block_index,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
@@ -324,9 +402,22 @@ async def generate_claude_stream(
|
||||
if last_usage is not None:
|
||||
output_tokens = last_usage.completion_tokens
|
||||
|
||||
# content_block_stop for text block
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
# Close any open blocks
|
||||
if thinking_block_started and text_block_index == -1:
|
||||
block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
if text_block_started:
|
||||
block_stop = ClaudeContentBlockStopEvent(index=text_block_index)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
if not thinking_block_started and not text_block_started:
|
||||
empty_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {empty_start.model_dump_json()}\n\n"
|
||||
empty_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {empty_stop.model_dump_json()}\n\n"
|
||||
|
||||
# message_delta
|
||||
message_delta = ClaudeMessageDeltaEvent(
|
||||
|
||||
@@ -29,8 +29,15 @@ from exo.shared.types.openai_responses import (
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponseReasoningItem,
|
||||
ResponseReasoningSummaryPartAddedEvent,
|
||||
ResponseReasoningSummaryPartDoneEvent,
|
||||
ResponseReasoningSummaryText,
|
||||
ResponseReasoningSummaryTextDeltaEvent,
|
||||
ResponseReasoningSummaryTextDoneEvent,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponsesStreamEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
@@ -38,6 +45,11 @@ from exo.shared.types.openai_responses import (
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def _format_sse(event: ResponsesStreamEvent) -> str:
|
||||
"""Format a streaming event as an SSE message."""
|
||||
return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
||||
"""Extract plain text from a content field that may be a string or list of parts."""
|
||||
if isinstance(content, str):
|
||||
@@ -135,7 +147,9 @@ async def collect_responses_response(
|
||||
"""Collect all token chunks and return a single ResponsesResponse."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
reasoning_id = f"rs_{command_id}"
|
||||
accumulated_text = ""
|
||||
thinking_parts: list[str] = []
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_usage: Usage | None = None
|
||||
error_message: str | None = None
|
||||
@@ -162,6 +176,10 @@ async def collect_responses_response(
|
||||
)
|
||||
continue
|
||||
|
||||
if chunk.is_thinking:
|
||||
thinking_parts.append(chunk.text)
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
|
||||
if error_message is not None:
|
||||
@@ -176,13 +194,21 @@ async def collect_responses_response(
|
||||
total_tokens=last_usage.total_tokens,
|
||||
)
|
||||
|
||||
output: list[ResponseItem] = [
|
||||
output: list[ResponseItem] = []
|
||||
if thinking_parts:
|
||||
output.append(
|
||||
ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text="".join(thinking_parts))],
|
||||
)
|
||||
)
|
||||
output.append(
|
||||
ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
]
|
||||
)
|
||||
output.extend(function_call_items)
|
||||
|
||||
yield ResponsesResponse(
|
||||
@@ -206,6 +232,7 @@ async def generate_responses_stream(
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
reasoning_id = f"rs_{command_id}"
|
||||
seq = count(1)
|
||||
|
||||
# response.created
|
||||
@@ -219,40 +246,25 @@ async def generate_responses_stream(
|
||||
created_event = ResponseCreatedEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
||||
yield _format_sse(created_event)
|
||||
|
||||
# response.in_progress
|
||||
in_progress_event = ResponseInProgressEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq), output_index=0, item=initial_item
|
||||
)
|
||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
||||
yield _format_sse(in_progress_event)
|
||||
|
||||
accumulated_text = ""
|
||||
accumulated_thinking = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_usage: Usage | None = None
|
||||
next_output_index = 1 # message item is at 0
|
||||
next_output_index = 0
|
||||
|
||||
# Track dynamic block creation
|
||||
reasoning_started = False
|
||||
reasoning_output_index = -1
|
||||
message_started = False
|
||||
message_output_index = -1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, PrefillProgressChunk):
|
||||
@@ -281,7 +293,7 @@ async def generate_responses_stream(
|
||||
output_index=next_output_index,
|
||||
item=fc_item,
|
||||
)
|
||||
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
|
||||
yield _format_sse(fc_added)
|
||||
|
||||
# response.function_call_arguments.delta
|
||||
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
||||
@@ -290,7 +302,7 @@ async def generate_responses_stream(
|
||||
output_index=next_output_index,
|
||||
delta=tool.arguments,
|
||||
)
|
||||
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
|
||||
yield _format_sse(args_delta)
|
||||
|
||||
# response.function_call_arguments.done
|
||||
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
||||
@@ -300,7 +312,7 @@ async def generate_responses_stream(
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(args_done)
|
||||
|
||||
# response.output_item.done
|
||||
fc_done_item = ResponseFunctionCallItem(
|
||||
@@ -315,44 +327,205 @@ async def generate_responses_stream(
|
||||
output_index=next_output_index,
|
||||
item=fc_done_item,
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(fc_item_done)
|
||||
|
||||
function_call_items.append(fc_done_item)
|
||||
next_output_index += 1
|
||||
continue
|
||||
|
||||
if chunk.is_thinking:
|
||||
# Start reasoning block on first thinking token
|
||||
if not reasoning_started:
|
||||
reasoning_started = True
|
||||
reasoning_output_index = next_output_index
|
||||
next_output_index += 1
|
||||
|
||||
# response.output_item.added for reasoning
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[],
|
||||
status="in_progress",
|
||||
)
|
||||
rs_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=reasoning_output_index,
|
||||
item=reasoning_item,
|
||||
)
|
||||
yield _format_sse(rs_added)
|
||||
|
||||
# response.reasoning_summary_part.added
|
||||
part_added = ResponseReasoningSummaryPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
part=ResponseReasoningSummaryText(text=""),
|
||||
)
|
||||
yield _format_sse(part_added)
|
||||
|
||||
accumulated_thinking += chunk.text
|
||||
|
||||
# response.reasoning_summary_text.delta
|
||||
rs_delta = ResponseReasoningSummaryTextDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield _format_sse(rs_delta)
|
||||
continue
|
||||
|
||||
# Close reasoning block when transitioning to text
|
||||
if reasoning_started and not message_started:
|
||||
# response.reasoning_summary_text.done
|
||||
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
text=accumulated_thinking,
|
||||
)
|
||||
yield _format_sse(rs_text_done)
|
||||
|
||||
# response.reasoning_summary_part.done
|
||||
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
part=ResponseReasoningSummaryText(text=accumulated_thinking),
|
||||
)
|
||||
yield _format_sse(rs_part_done)
|
||||
|
||||
# response.output_item.done for reasoning
|
||||
rs_item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=reasoning_output_index,
|
||||
item=ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||
),
|
||||
)
|
||||
yield _format_sse(rs_item_done)
|
||||
|
||||
# Start message block on first text token
|
||||
if not message_started:
|
||||
message_started = True
|
||||
message_output_index = next_output_index
|
||||
next_output_index += 1
|
||||
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=message_output_index,
|
||||
item=initial_item,
|
||||
)
|
||||
yield _format_sse(item_added)
|
||||
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield _format_sse(part_added)
|
||||
|
||||
accumulated_text += chunk.text
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
yield _format_sse(delta_event)
|
||||
|
||||
# Close reasoning block if it was never followed by text
|
||||
if reasoning_started and not message_started:
|
||||
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
text=accumulated_thinking,
|
||||
)
|
||||
yield _format_sse(rs_text_done)
|
||||
|
||||
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=reasoning_id,
|
||||
output_index=reasoning_output_index,
|
||||
summary_index=0,
|
||||
part=ResponseReasoningSummaryText(text=accumulated_thinking),
|
||||
)
|
||||
yield _format_sse(rs_part_done)
|
||||
|
||||
rs_item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=reasoning_output_index,
|
||||
item=ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||
),
|
||||
)
|
||||
yield _format_sse(rs_item_done)
|
||||
|
||||
# If no message block was started, create one now (empty text)
|
||||
if not message_started:
|
||||
message_output_index = next_output_index
|
||||
next_output_index += 1
|
||||
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=message_output_index,
|
||||
item=initial_item,
|
||||
)
|
||||
yield _format_sse(item_added)
|
||||
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added_evt = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield _format_sse(part_added_evt)
|
||||
|
||||
# response.output_text.done
|
||||
text_done = ResponseTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
text=accumulated_text,
|
||||
)
|
||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(text_done)
|
||||
|
||||
# response.content_part.done
|
||||
final_part = ResponseOutputText(text=accumulated_text)
|
||||
part_done = ResponseContentPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
output_index=message_output_index,
|
||||
content_index=0,
|
||||
part=final_part,
|
||||
)
|
||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(part_done)
|
||||
|
||||
# response.output_item.done
|
||||
final_message_item = ResponseMessageItem(
|
||||
@@ -361,9 +534,11 @@ async def generate_responses_stream(
|
||||
status="completed",
|
||||
)
|
||||
item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
||||
sequence_number=next(seq),
|
||||
output_index=message_output_index,
|
||||
item=final_message_item,
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
yield _format_sse(item_done)
|
||||
|
||||
# Create usage from usage data if available
|
||||
usage = None
|
||||
@@ -375,7 +550,15 @@ async def generate_responses_stream(
|
||||
)
|
||||
|
||||
# response.completed
|
||||
output: list[ResponseItem] = [final_message_item]
|
||||
output: list[ResponseItem] = []
|
||||
if reasoning_started:
|
||||
output.append(
|
||||
ResponseReasoningItem(
|
||||
id=reasoning_id,
|
||||
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
|
||||
)
|
||||
)
|
||||
output.append(final_message_item)
|
||||
output.extend(function_call_items)
|
||||
final_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
@@ -388,4 +571,4 @@ async def generate_responses_stream(
|
||||
completed_event = ResponseCompletedEvent(
|
||||
sequence_number=next(seq), response=final_response
|
||||
)
|
||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
||||
yield _format_sse(completed_event)
|
||||
|
||||
@@ -138,7 +138,6 @@ from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
PrefillProgress,
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -1455,22 +1454,6 @@ class API:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
|
||||
elif isinstance(event, PrefillProgress):
|
||||
if queue := self._text_generation_queues.get(
|
||||
event.command_id, None
|
||||
):
|
||||
try:
|
||||
await queue.send(
|
||||
PrefillProgressChunk(
|
||||
model=event.model,
|
||||
processed_tokens=event.processed_tokens,
|
||||
total_tokens=event.total_tokens,
|
||||
)
|
||||
)
|
||||
except BrokenResourceError:
|
||||
self._text_generation_queues.pop(event.command_id, None)
|
||||
|
||||
if isinstance(event, TracesMerged):
|
||||
self._save_merged_trace(event)
|
||||
|
||||
|
||||
@@ -261,7 +261,7 @@ class TestGenerateClaudeStreamToolUse:
|
||||
|
||||
parsed = _parse_sse_events(events)
|
||||
|
||||
# Two tool block starts (at indices 1 and 2)
|
||||
# Two tool block starts (at indices 0 and 1 — no text block when only tools)
|
||||
tool_starts = [
|
||||
e
|
||||
for e in parsed
|
||||
@@ -270,12 +270,11 @@ class TestGenerateClaudeStreamToolUse:
|
||||
== "tool_use"
|
||||
]
|
||||
assert len(tool_starts) == 2
|
||||
assert tool_starts[0]["index"] == 1
|
||||
assert tool_starts[1]["index"] == 2
|
||||
assert tool_starts[0]["index"] == 0
|
||||
assert tool_starts[1]["index"] == 1
|
||||
|
||||
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
|
||||
# Two tool block stops (at indices 0 and 1)
|
||||
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
|
||||
stop_indices = [e["index"] for e in block_stops]
|
||||
assert 0 in stop_indices
|
||||
assert 1 in stop_indices
|
||||
assert 2 in stop_indices
|
||||
|
||||
@@ -42,7 +42,7 @@ from exo.utils.channels import channel
|
||||
@pytest.mark.asyncio
|
||||
async def test_master():
|
||||
keypair = get_node_id_keypair()
|
||||
node_id = NodeId(keypair.to_peer_id().to_base58())
|
||||
node_id = NodeId(keypair.to_node_id())
|
||||
session_id = SessionId(master_node_id=node_id, election_clock=0)
|
||||
|
||||
ge_sender, global_event_receiver = channel[ForwarderEvent]()
|
||||
@@ -75,7 +75,7 @@ async def test_master():
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(master.run)
|
||||
|
||||
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
|
||||
sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
|
||||
# inject a NodeGatheredInfo event
|
||||
logger.info("inject a NodeGatheredInfo event")
|
||||
await local_event_sender.send(
|
||||
|
||||
@@ -14,12 +14,10 @@ from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import PlaceInstance
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted, TaskStatusUpdated
|
||||
from exo.shared.types.events import InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
@@ -458,117 +456,3 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
else:
|
||||
ip_part = coordinator.split(":")[0]
|
||||
assert len(ip_part.split(".")) == 4
|
||||
|
||||
|
||||
def _make_task(
|
||||
instance_id: InstanceId,
|
||||
status: TaskStatus = TaskStatus.Running,
|
||||
) -> TextGeneration:
|
||||
return TextGeneration(
|
||||
task_id=TaskId(),
|
||||
task_status=status,
|
||||
instance_id=instance_id,
|
||||
command_id=CommandId(),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=ModelId("test-model"),
|
||||
input=[InputMessage(role="user", content="hello")],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_running_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Running)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – cancellation event should come before the deletion event
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
assert events[1].instance_id == instance_id
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_pending_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
task = _make_task(instance_id, TaskStatus.Pending)
|
||||
tasks = {task.task_id: task}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], TaskStatusUpdated)
|
||||
assert events[0].task_id == task.task_id
|
||||
assert events[0].task_status == TaskStatus.Cancelled
|
||||
assert isinstance(events[1], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_ignores_completed_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
tasks = {
|
||||
t.task_id: t
|
||||
for t in [
|
||||
_make_task(instance_id, TaskStatus.Complete),
|
||||
_make_task(instance_id, TaskStatus.Failed),
|
||||
_make_task(instance_id, TaskStatus.TimedOut),
|
||||
_make_task(instance_id, TaskStatus.Cancelled),
|
||||
]
|
||||
}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only the InstanceDeleted event, no cancellations
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
def test_get_transition_events_delete_instance_cancels_only_matching_tasks(
|
||||
instance: Instance,
|
||||
):
|
||||
# arrange
|
||||
instance_id_a = InstanceId()
|
||||
instance_id_b = InstanceId()
|
||||
current_instances: dict[InstanceId, Instance] = {
|
||||
instance_id_a: instance,
|
||||
instance_id_b: instance,
|
||||
}
|
||||
# only delete instance A, keep instance B
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id_b: instance}
|
||||
|
||||
task_a = _make_task(instance_id_a, TaskStatus.Running)
|
||||
task_b = _make_task(instance_id_b, TaskStatus.Running)
|
||||
tasks = {task_a.task_id: task_a, task_b.task_id: task_b}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances, tasks)
|
||||
|
||||
# assert – only task_a should be cancelled
|
||||
cancel_events = [e for e in events if isinstance(e, TaskStatusUpdated)]
|
||||
delete_events = [e for e in events if isinstance(e, InstanceDeleted)]
|
||||
assert len(cancel_events) == 1
|
||||
assert cancel_events[0].task_id == task_a.task_id
|
||||
assert cancel_events[0].task_status == TaskStatus.Cancelled
|
||||
assert len(delete_events) == 1
|
||||
assert delete_events[0].instance_id == instance_id_a
|
||||
|
||||
@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id.to_base58()),
|
||||
node_id=NodeId(update.peer_id),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
|
||||
@@ -221,7 +221,7 @@ def get_node_id_keypair(
|
||||
Obtain the :class:`PeerId` from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate_ed25519()
|
||||
return Keypair.generate()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
@@ -235,12 +235,12 @@ def get_node_id_keypair(
|
||||
protobuf_encoded = f.read()
|
||||
|
||||
try: # if decoded successfully, save & return
|
||||
return Keypair.from_protobuf_encoding(protobuf_encoded)
|
||||
return Keypair.from_bytes(protobuf_encoded)
|
||||
except ValueError as e: # on runtime error, assume corrupt file
|
||||
logger.warning(f"Encountered error when trying to get keypair: {e}")
|
||||
|
||||
# if no valid credentials, create new ones and persist
|
||||
with open(path, "w+b") as f:
|
||||
keypair = Keypair.generate_ed25519()
|
||||
f.write(keypair.to_protobuf_encoding())
|
||||
f.write(keypair.to_bytes())
|
||||
return keypair
|
||||
|
||||
@@ -15,7 +15,6 @@ from exo.shared.types.events import (
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
NodeTimedOut,
|
||||
PrefillProgress,
|
||||
RunnerDeleted,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
@@ -65,7 +64,6 @@ def event_apply(event: Event, state: State) -> State:
|
||||
| ChunkGenerated()
|
||||
| TaskAcknowledged()
|
||||
| InputChunkReceived()
|
||||
| PrefillProgress()
|
||||
| TracesCollected()
|
||||
| TracesMerged()
|
||||
): # Pass-through events that don't modify state
|
||||
|
||||
@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
|
||||
sem.release()
|
||||
# wait to be told to begin simultaneous read
|
||||
ev.wait()
|
||||
queue.put(get_node_id_keypair().to_protobuf_encoding())
|
||||
queue.put(get_node_id_keypair().to_bytes())
|
||||
|
||||
|
||||
def _get_keypair_concurrent(num_procs: int) -> bytes:
|
||||
|
||||
@@ -77,7 +77,7 @@ class ChatCompletionMessage(BaseModel):
|
||||
content: (
|
||||
str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
|
||||
) = None
|
||||
thinking: str | None = None # Added for GPT-OSS harmony format support
|
||||
reasoning_content: str | None = None
|
||||
name: str | None = None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
tool_call_id: str | None = None
|
||||
|
||||
@@ -27,6 +27,7 @@ class TokenChunk(BaseChunk):
|
||||
stats: GenerationStats | None = None
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
is_thinking: bool = False
|
||||
|
||||
|
||||
class ErrorChunk(BaseChunk):
|
||||
|
||||
@@ -47,6 +47,14 @@ class ClaudeImageBlock(BaseModel, frozen=True):
|
||||
source: ClaudeImageSource
|
||||
|
||||
|
||||
class ClaudeThinkingBlock(BaseModel, frozen=True):
|
||||
"""Thinking content block in Claude Messages API."""
|
||||
|
||||
type: Literal["thinking"] = "thinking"
|
||||
thinking: str
|
||||
signature: str | None = None
|
||||
|
||||
|
||||
class ClaudeToolUseBlock(BaseModel, frozen=True):
|
||||
"""Tool use content block in Claude Messages API."""
|
||||
|
||||
@@ -66,11 +74,17 @@ class ClaudeToolResultBlock(BaseModel, frozen=True):
|
||||
cache_control: dict[str, str] | None = None
|
||||
|
||||
|
||||
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock
|
||||
ClaudeContentBlock = (
|
||||
ClaudeTextBlock | ClaudeImageBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
|
||||
)
|
||||
|
||||
# Input content blocks can also include tool_result (sent by user after tool_use)
|
||||
ClaudeInputContentBlock = (
|
||||
ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock | ClaudeToolResultBlock
|
||||
ClaudeTextBlock
|
||||
| ClaudeImageBlock
|
||||
| ClaudeThinkingBlock
|
||||
| ClaudeToolUseBlock
|
||||
| ClaudeToolResultBlock
|
||||
)
|
||||
|
||||
|
||||
@@ -82,6 +96,11 @@ class ClaudeMessage(BaseModel, frozen=True):
|
||||
content: str | list[ClaudeInputContentBlock]
|
||||
|
||||
|
||||
class ClaudeThinkingConfig(BaseModel, frozen=True):
|
||||
type: Literal["enabled", "disabled", "adaptive"]
|
||||
budget_tokens: int | None = None
|
||||
|
||||
|
||||
class ClaudeMessagesRequest(BaseModel):
|
||||
"""Request body for Claude Messages API."""
|
||||
|
||||
@@ -96,6 +115,7 @@ class ClaudeMessagesRequest(BaseModel):
|
||||
top_k: int | None = None
|
||||
tools: list[ClaudeToolDefinition] | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
thinking: ClaudeThinkingConfig | None = None
|
||||
|
||||
|
||||
# Response types
|
||||
@@ -145,7 +165,7 @@ class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
|
||||
|
||||
type: Literal["content_block_start"] = "content_block_start"
|
||||
index: int
|
||||
content_block: ClaudeTextBlock | ClaudeToolUseBlock
|
||||
content_block: ClaudeTextBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
|
||||
|
||||
|
||||
class ClaudeTextDelta(BaseModel, frozen=True):
|
||||
@@ -155,6 +175,13 @@ class ClaudeTextDelta(BaseModel, frozen=True):
|
||||
text: str
|
||||
|
||||
|
||||
class ClaudeThinkingDelta(BaseModel, frozen=True):
|
||||
"""Delta for thinking content block."""
|
||||
|
||||
type: Literal["thinking_delta"] = "thinking_delta"
|
||||
thinking: str
|
||||
|
||||
|
||||
class ClaudeInputJsonDelta(BaseModel, frozen=True):
|
||||
"""Delta for tool use input JSON content block."""
|
||||
|
||||
@@ -167,7 +194,7 @@ class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
|
||||
|
||||
type: Literal["content_block_delta"] = "content_block_delta"
|
||||
index: int
|
||||
delta: ClaudeTextDelta | ClaudeInputJsonDelta
|
||||
delta: ClaudeTextDelta | ClaudeThinkingDelta | ClaudeInputJsonDelta
|
||||
|
||||
|
||||
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
@@ -102,13 +102,6 @@ class InputChunkReceived(BaseEvent):
|
||||
chunk: InputImageChunk
|
||||
|
||||
|
||||
class PrefillProgress(BaseEvent):
|
||||
command_id: CommandId
|
||||
model: ModelId
|
||||
processed_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class TopologyEdgeCreated(BaseEvent):
|
||||
conn: Connection
|
||||
|
||||
@@ -155,7 +148,6 @@ Event = (
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
| InputChunkReceived
|
||||
| PrefillProgress
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
| TracesCollected
|
||||
|
||||
@@ -145,7 +145,23 @@ class ResponseFunctionCallItem(BaseModel, frozen=True):
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
|
||||
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem
|
||||
class ResponseReasoningSummaryText(BaseModel, frozen=True):
|
||||
"""Summary text part in a reasoning output item."""
|
||||
|
||||
type: Literal["summary_text"] = "summary_text"
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseReasoningItem(BaseModel, frozen=True):
|
||||
"""Reasoning output item in response output array."""
|
||||
|
||||
type: Literal["reasoning"] = "reasoning"
|
||||
id: str
|
||||
summary: list[ResponseReasoningSummaryText] = Field(default_factory=list)
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
|
||||
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem | ResponseReasoningItem
|
||||
|
||||
|
||||
class ResponseUsage(BaseModel, frozen=True):
|
||||
@@ -273,6 +289,58 @@ class ResponseFunctionCallArgumentsDoneEvent(BaseModel, frozen=True):
|
||||
arguments: str
|
||||
|
||||
|
||||
class ResponseReasoningSummaryPartAddedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a reasoning summary part is added."""
|
||||
|
||||
type: Literal["response.reasoning_summary_part.added"] = (
|
||||
"response.reasoning_summary_part.added"
|
||||
)
|
||||
sequence_number: int
|
||||
item_id: str
|
||||
output_index: int
|
||||
summary_index: int
|
||||
part: ResponseReasoningSummaryText
|
||||
|
||||
|
||||
class ResponseReasoningSummaryTextDeltaEvent(BaseModel, frozen=True):
|
||||
"""Event sent for reasoning summary text delta during streaming."""
|
||||
|
||||
type: Literal["response.reasoning_summary_text.delta"] = (
|
||||
"response.reasoning_summary_text.delta"
|
||||
)
|
||||
sequence_number: int
|
||||
item_id: str
|
||||
output_index: int
|
||||
summary_index: int
|
||||
delta: str
|
||||
|
||||
|
||||
class ResponseReasoningSummaryTextDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when reasoning summary text is done."""
|
||||
|
||||
type: Literal["response.reasoning_summary_text.done"] = (
|
||||
"response.reasoning_summary_text.done"
|
||||
)
|
||||
sequence_number: int
|
||||
item_id: str
|
||||
output_index: int
|
||||
summary_index: int
|
||||
text: str
|
||||
|
||||
|
||||
class ResponseReasoningSummaryPartDoneEvent(BaseModel, frozen=True):
|
||||
"""Event sent when a reasoning summary part is done."""
|
||||
|
||||
type: Literal["response.reasoning_summary_part.done"] = (
|
||||
"response.reasoning_summary_part.done"
|
||||
)
|
||||
sequence_number: int
|
||||
item_id: str
|
||||
output_index: int
|
||||
summary_index: int
|
||||
part: ResponseReasoningSummaryText
|
||||
|
||||
|
||||
class ResponseCompletedEvent(BaseModel, frozen=True):
|
||||
"""Event sent when response is completed."""
|
||||
|
||||
@@ -292,5 +360,9 @@ ResponsesStreamEvent = (
|
||||
| ResponseOutputItemDoneEvent
|
||||
| ResponseFunctionCallArgumentsDeltaEvent
|
||||
| ResponseFunctionCallArgumentsDoneEvent
|
||||
| ResponseReasoningSummaryPartAddedEvent
|
||||
| ResponseReasoningSummaryTextDeltaEvent
|
||||
| ResponseReasoningSummaryTextDoneEvent
|
||||
| ResponseReasoningSummaryPartDoneEvent
|
||||
| ResponseCompletedEvent
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
usage: Usage | None
|
||||
is_thinking: bool = False
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
|
||||
src/exo/worker/engines/mlx/dsml_encoding.py (new file, 72 lines)
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from mlx_lm.chat_templates import deepseek_v32
|
||||
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
|
||||
BOS_TOKEN: str = deepseek_v32.bos_token
|
||||
EOS_TOKEN: str = deepseek_v32.eos_token
|
||||
DSML_TOKEN: str = deepseek_v32.dsml_token
|
||||
THINKING_START: str = deepseek_v32.thinking_start_token
|
||||
THINKING_END: str = deepseek_v32.thinking_end_token
|
||||
USER_TOKEN = "<\uff5cUser\uff5c>"
|
||||
ASSISTANT_TOKEN = "<\uff5cAssistant\uff5c>"
|
||||
TOOL_CALLS_START = f"<{DSML_TOKEN}function_calls>"
|
||||
TOOL_CALLS_END = f"</{DSML_TOKEN}function_calls>"
|
||||
encode_messages = deepseek_v32.encode_messages
|
||||
|
||||
_INVOKE_PATTERN = re.compile(
|
||||
rf"<{re.escape(DSML_TOKEN)}invoke\s+name=\"([^\"]+)\">"
|
||||
rf"(.*?)"
|
||||
rf"</{re.escape(DSML_TOKEN)}invoke>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
_PARAM_PATTERN = re.compile(
|
||||
rf"<{re.escape(DSML_TOKEN)}parameter\s+name=\"([^\"]+)\"\s+string=\"(true|false)\">"
|
||||
rf"(.*?)"
|
||||
rf"</{re.escape(DSML_TOKEN)}parameter>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def parse_dsml_output(text: str) -> list[ToolCallItem] | None:
|
||||
"""Parse DSML function_calls block from model output text.
|
||||
|
||||
Args:
|
||||
text: The text containing the DSML function_calls block
|
||||
(including the start/end markers).
|
||||
|
||||
Returns:
|
||||
List of ToolCallItem, or None if parsing fails.
|
||||
"""
|
||||
tool_calls: list[ToolCallItem] = []
|
||||
|
||||
for invoke_match in _INVOKE_PATTERN.finditer(text):
|
||||
func_name = invoke_match.group(1)
|
||||
invoke_body = invoke_match.group(2)
|
||||
|
||||
args: dict[str, Any] = {}
|
||||
for param_match in _PARAM_PATTERN.finditer(invoke_body):
|
||||
param_name = param_match.group(1)
|
||||
is_string = param_match.group(2) == "true"
|
||||
param_value = param_match.group(3)
|
||||
|
||||
if is_string:
|
||||
args[param_name] = param_value
|
||||
else:
|
||||
try:
|
||||
args[param_name] = json.loads(param_value)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
args[param_name] = param_value
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallItem(
|
||||
name=func_name,
|
||||
arguments=json.dumps(args),
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls if tool_calls else None
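A usage sketch, assembling a single invocation with the same markers the tests later in this diff use (DSML_TOKEN itself comes from mlx_lm's deepseek_v32 chat template):

from exo.worker.engines.mlx.dsml_encoding import DSML_TOKEN, parse_dsml_output

dsml_block = (
    f"<{DSML_TOKEN}function_calls>\n"
    f'<{DSML_TOKEN}invoke name="get_weather">\n'
    f'<{DSML_TOKEN}parameter name="city" string="true">Tokyo</{DSML_TOKEN}parameter>\n'
    f"</{DSML_TOKEN}invoke>\n"
    f"</{DSML_TOKEN}function_calls>"
)

calls = parse_dsml_output(dsml_block)
# One ToolCallItem is expected here: name="get_weather",
# arguments='{"city": "Tokyo"}' (string="true" skips JSON decoding of the value).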
|
||||
@@ -458,6 +458,19 @@ def _patch_lossy_chat_template(template: str) -> str | None:
|
||||
return patched if n > 0 else None
|
||||
|
||||
|
||||
def _needs_dsml_encoding(task_params: TextGenerationTaskParams) -> bool:
|
||||
if "deepseek-v3.2" not in task_params.model.lower():
|
||||
return False
|
||||
# Use DSML encoding when tools are provided or tool results are in the conversation
|
||||
if task_params.tools:
|
||||
return True
|
||||
if task_params.chat_template_messages:
|
||||
return any(
|
||||
msg.get("role") == "tool" for msg in task_params.chat_template_messages
|
||||
)
|
||||
return False
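Illustrative decision table (hypothetical model ids; only the "deepseek-v3.2" substring check and the presence of tools or tool-role messages matter):

# model contains "deepseek-v3.2" and task_params.tools is non-empty        -> True
# model contains "deepseek-v3.2" and chat_template_messages includes a
#     {"role": "tool", ...} entry (a prior tool result)                    -> True
# model contains "deepseek-v3.2", plain chat, no tools                     -> False
# any other model, even with tools                                         -> False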
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: TokenizerWrapper,
|
||||
task_params: TextGenerationTaskParams,
|
||||
@@ -469,7 +482,6 @@ def apply_chat_template(
|
||||
|
||||
When chat_template_messages is available (from Chat Completions API),
|
||||
uses those directly to preserve tool_calls, thinking, and other fields.
|
||||
Otherwise builds messages from the task params input/instructions.
|
||||
"""
|
||||
formatted_messages: list[dict[str, Any]] = []
|
||||
if task_params.chat_template_messages is not None:
|
||||
@@ -497,6 +509,19 @@ def apply_chat_template(
|
||||
partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
|
||||
formatted_messages = formatted_messages[:-1]
|
||||
|
||||
if _needs_dsml_encoding(task_params):
|
||||
from exo.worker.engines.mlx.dsml_encoding import encode_messages
|
||||
|
||||
prompt = encode_messages(
|
||||
messages=formatted_messages,
|
||||
thinking_mode="thinking" if task_params.enable_thinking else "chat",
|
||||
tools=task_params.tools,
|
||||
)
|
||||
if partial_assistant_content:
|
||||
prompt += partial_assistant_content
|
||||
logger.info(prompt)
|
||||
return prompt
|
||||
|
||||
extra_kwargs: dict[str, Any] = {}
|
||||
if task_params.enable_thinking is not None:
|
||||
# Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".
|
||||
|
||||
@@ -7,6 +7,7 @@ from functools import cache
|
||||
from typing import Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
@@ -21,12 +22,17 @@ from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
|
||||
from exo.shared.types.api import ImageGenerationStats
|
||||
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
ImageChunk,
|
||||
PrefillProgressChunk,
|
||||
TokenChunk,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
PrefillProgress,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
@@ -315,11 +321,13 @@ def main(
|
||||
) -> None:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
PrefillProgress(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
processed_tokens=processed,
|
||||
total_tokens=total,
|
||||
chunk=PrefillProgressChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
processed_tokens=processed,
|
||||
total_tokens=total,
|
||||
),
|
||||
)
|
||||
)
|
||||
cancelled_tasks.update(cancel_receiver.collect())
|
||||
@@ -346,16 +354,22 @@ def main(
|
||||
group=group,
|
||||
)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
if detect_thinking_prompt_suffix(prompt, tokenizer):
|
||||
if tokenizer.has_thinking:
|
||||
mlx_generator = parse_thinking_models(
|
||||
mlx_generator, tokenizer
|
||||
mlx_generator,
|
||||
tokenizer,
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
starts_in_thinking=detect_thinking_prompt_suffix(
|
||||
prompt, tokenizer
|
||||
),
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
# Model-specific output parsing for tool calls.
|
||||
if isinstance(inference_model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
elif isinstance(inference_model, DeepseekV32Model):
|
||||
mlx_generator = parse_deepseek_v32(mlx_generator)
|
||||
elif tool_parser:
|
||||
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
|
||||
|
||||
@@ -407,6 +421,7 @@ def main(
|
||||
stats=response.stats,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
is_thinking=response.is_thinking,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -668,44 +683,208 @@ def parse_gpt_oss(
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield response.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield response.model_copy(update={"text": delta})
|
||||
yield response.model_copy(update={"text": delta, "is_thinking": thinking})
|
||||
|
||||
if response.finish_reason is not None:
|
||||
if thinking:
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
yield response
|
||||
|
||||
|
||||
def parse_deepseek_v32(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
|
||||
|
||||
Uses accumulated-text matching (not per-token marker checks) because
|
||||
DSML markers like <|DSML|function_calls> may span multiple tokens.
|
||||
Also handles <think>...</think> blocks for thinking mode.
|
||||
"""
|
||||
from exo.worker.engines.mlx.dsml_encoding import (
|
||||
THINKING_END,
|
||||
THINKING_START,
|
||||
TOOL_CALLS_END,
|
||||
TOOL_CALLS_START,
|
||||
parse_dsml_output,
|
||||
)
|
||||
|
||||
accumulated = ""
|
||||
in_tool_call = False
|
||||
thinking = False
|
||||
# Tokens buffered while we detect the start of a DSML block
|
||||
pending_buffer: list[GenerationResponse] = []
|
||||
# Text accumulated during a tool call block
|
||||
tool_call_text = ""
|
||||
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
|
||||
# ── Handle thinking tags ──
|
||||
if not thinking and THINKING_START in response.text:
|
||||
thinking = True
|
||||
# Yield any text before the <think> tag
|
||||
before = response.text[: response.text.index(THINKING_START)]
|
||||
if before:
|
||||
yield response.model_copy(update={"text": before})
|
||||
continue
|
||||
|
||||
if thinking and THINKING_END in response.text:
|
||||
thinking = False
|
||||
# Yield any text after the </think> tag
|
||||
after = response.text[
|
||||
response.text.index(THINKING_END) + len(THINKING_END) :
|
||||
]
|
||||
if after:
|
||||
yield response.model_copy(update={"text": after, "is_thinking": False})
|
||||
continue
|
||||
|
||||
if thinking:
|
||||
yield response.model_copy(update={"is_thinking": True})
|
||||
continue
|
||||
|
||||
# ── Handle tool call accumulation ──
|
||||
if in_tool_call:
|
||||
tool_call_text += response.text
|
||||
if TOOL_CALLS_END in tool_call_text:
|
||||
# Parse the accumulated DSML block
|
||||
parsed = parse_dsml_output(tool_call_text)
|
||||
if parsed is not None:
|
||||
logger.info(f"parsed DSML tool calls: {parsed}")
|
||||
yield ToolCallResponse(
|
||||
tool_calls=parsed,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"DSML tool call parsing failed for: {tool_call_text}"
|
||||
)
|
||||
yield response.model_copy(update={"text": tool_call_text})
|
||||
in_tool_call = False
|
||||
tool_call_text = ""
|
||||
continue
|
||||
|
||||
# EOS reached before end marker — yield buffered text as-is
|
||||
if response.finish_reason is not None:
|
||||
logger.info("DSML tool call parsing interrupted by EOS")
|
||||
yield response.model_copy(update={"text": tool_call_text})
|
||||
in_tool_call = False
|
||||
tool_call_text = ""
|
||||
continue
|
||||
|
||||
# ── Detect start of tool call block ──
|
||||
accumulated += response.text
|
||||
|
||||
if TOOL_CALLS_START in accumulated:
|
||||
# The start marker might be split across pending_buffer + current token
|
||||
start_idx = accumulated.index(TOOL_CALLS_START)
|
||||
# Yield any pending tokens that are purely before the marker
|
||||
pre_text = accumulated[:start_idx]
|
||||
if pre_text:
|
||||
# Flush pending buffer tokens that contributed text before the marker
|
||||
for buf_resp in pending_buffer:
|
||||
if pre_text:
|
||||
chunk = buf_resp.text
|
||||
if len(chunk) <= len(pre_text):
|
||||
yield buf_resp
|
||||
pre_text = pre_text[len(chunk) :]
|
||||
else:
|
||||
yield buf_resp.model_copy(update={"text": pre_text})
|
||||
pre_text = ""
|
||||
pending_buffer = []
|
||||
tool_call_text = accumulated[start_idx:]
|
||||
accumulated = ""
|
||||
|
||||
# Check if the end marker is already present (entire tool call in one token)
|
||||
if TOOL_CALLS_END in tool_call_text:
|
||||
parsed = parse_dsml_output(tool_call_text)
|
||||
if parsed is not None:
|
||||
logger.info(f"parsed DSML tool calls: {parsed}")
|
||||
yield ToolCallResponse(
|
||||
tool_calls=parsed,
|
||||
usage=response.usage,
|
||||
stats=response.stats,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"DSML tool call parsing failed for: {tool_call_text}"
|
||||
)
|
||||
yield response.model_copy(update={"text": tool_call_text})
|
||||
tool_call_text = ""
|
||||
else:
|
||||
in_tool_call = True
|
||||
continue
|
||||
|
||||
# Check if accumulated text might be the start of a DSML marker
|
||||
# Buffer tokens if we see a partial match at the end
|
||||
if _could_be_dsml_prefix(accumulated):
|
||||
pending_buffer.append(response)
|
||||
continue
|
||||
|
||||
# No partial match — flush all pending tokens and the current one
|
||||
for buf_resp in pending_buffer:
|
||||
yield buf_resp
|
||||
pending_buffer = []
|
||||
accumulated = ""
|
||||
yield response
|
||||
|
||||
# Flush any remaining pending buffer at generator end
|
||||
for buf_resp in pending_buffer:
|
||||
yield buf_resp
|
||||
|
||||
|
||||
def _could_be_dsml_prefix(text: str) -> bool:
|
||||
"""Check if the end of text could be the start of a DSML function_calls marker.
|
||||
|
||||
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
|
||||
This allows us to buffer tokens until we can determine if a tool call is starting.
|
||||
"""
|
||||
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
|
||||
|
||||
# Only check the last portion of text that could overlap with the marker
|
||||
max_check = len(TOOL_CALLS_START)
|
||||
tail = text[-max_check:] if len(text) > max_check else text
|
||||
|
||||
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
|
||||
for i in range(len(tail)):
|
||||
suffix = tail[i:]
|
||||
if TOOL_CALLS_START.startswith(suffix):
|
||||
return True
|
||||
return False
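To make the buffering behaviour concrete (illustrative inputs; only the fact that TOOL_CALLS_START begins with "<" is relied on):

# _could_be_dsml_prefix("let me check <")  -> True   buffer this token, since "<"
#                                                     could be the start of the marker
# _could_be_dsml_prefix("let me check.")   -> False  no suffix matches, flush now
#
# Buffered tokens are later released either verbatim (when the accumulated text
# stops matching any prefix of TOOL_CALLS_START) or split around the marker once
# the full TOOL_CALLS_START string shows up in the accumulated text.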
|
||||
|
||||
|
||||
def parse_thinking_models(
|
||||
responses: Generator[GenerationResponse],
|
||||
tokenizer: TokenizerWrapper,
|
||||
starts_in_thinking: bool = True,
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Route thinking tokens via is_thinking flag.
|
||||
|
||||
Swallows think tag tokens, sets is_thinking on all others.
|
||||
Always yields tokens with finish_reason to avoid hanging the chunk stream.
|
||||
"""
|
||||
For models that inject thinking tags in the prompt (like GLM-4.7),
|
||||
prepend the thinking tag to the output stream so the frontend
|
||||
can properly parse thinking content.
|
||||
"""
|
||||
first = True
|
||||
in_thinking = starts_in_thinking
|
||||
for response in responses:
|
||||
if isinstance(response, ToolCallResponse):
|
||||
yield response
|
||||
continue
|
||||
if first:
|
||||
first = False
|
||||
yield response.model_copy(
|
||||
update={
|
||||
"text": tokenizer.think_start,
|
||||
"token": tokenizer.think_start_id,
|
||||
}
|
||||
)
|
||||
yield response
|
||||
|
||||
is_think_tag = (
|
||||
tokenizer.think_end is not None and response.text == tokenizer.think_end
|
||||
) or (
|
||||
tokenizer.think_start is not None and response.text == tokenizer.think_start
|
||||
)
|
||||
|
||||
if is_think_tag:
|
||||
in_thinking = response.text != tokenizer.think_end
|
||||
# Never swallow finish_reason — the chunk stream needs it to terminate.
|
||||
if response.finish_reason is not None:
|
||||
yield response.model_copy(update={"text": "", "is_thinking": False})
|
||||
continue
|
||||
yield response.model_copy(update={"is_thinking": in_thinking})
|
||||
|
||||
|
||||
def _send_image_chunk(
|
||||
|
||||
src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py (new file, 967 lines)
@@ -0,0 +1,967 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ToolCallResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx.dsml_encoding import (
|
||||
ASSISTANT_TOKEN,
|
||||
BOS_TOKEN,
|
||||
DSML_TOKEN,
|
||||
EOS_TOKEN,
|
||||
THINKING_END,
|
||||
THINKING_START,
|
||||
TOOL_CALLS_END,
|
||||
TOOL_CALLS_START,
|
||||
USER_TOKEN,
|
||||
encode_messages,
|
||||
parse_dsml_output,
|
||||
)
|
||||
from exo.worker.runner.runner import parse_deepseek_v32
|
||||
|
||||
# ── Shared fixtures ──────────────────────────────────────────────
|
||||
|
||||
_WEATHER_TOOLS: list[dict[str, Any]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
"units": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Temperature units",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"description": "Get the current time in a timezone",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {"type": "string"},
|
||||
},
|
||||
"required": ["timezone"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _simulate_tokens(
|
||||
texts: list[str],
|
||||
finish_on_last: bool = True,
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Simulate a model producing tokens from a list of text strings."""
|
||||
for i, text in enumerate(texts):
|
||||
is_last = i == len(texts) - 1
|
||||
yield GenerationResponse(
|
||||
text=text,
|
||||
token=i,
|
||||
finish_reason="stop" if (is_last and finish_on_last) else None,
|
||||
usage=None,
|
||||
)
|
||||
|
||||
|
||||
# ── Test: Standard text response (no tool calls) ────────────────
|
||||
|
||||
|
||||
class TestE2EStandardResponse:
|
||||
"""Model generates a plain text response — no tool calling involved."""
|
||||
|
||||
def test_plain_text_passthrough(self):
|
||||
"""Simulate model producing: 'The weather in NYC is 72°F and sunny.'"""
|
||||
# Step 1: Encode the prompt (with tools available)
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather in NYC?"},
|
||||
]
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Verify prompt structure
|
||||
assert BOS_TOKEN in prompt
|
||||
assert "## Tools" in prompt
|
||||
assert "get_weather" in prompt
|
||||
assert f"{USER_TOKEN}What's the weather in NYC?{ASSISTANT_TOKEN}" in prompt
|
||||
|
||||
# Step 2: Simulate model response — plain text tokens (no DSML)
|
||||
model_tokens = [
|
||||
"The weather",
|
||||
" in NYC",
|
||||
" is 72",
|
||||
"°F",
|
||||
" and sunny",
|
||||
".",
|
||||
]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
# Step 3: Verify all tokens pass through as GenerationResponse
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 0
|
||||
assert len(gen_results) == 6
|
||||
full_text = "".join(r.text for r in gen_results)
|
||||
assert full_text == "The weather in NYC is 72°F and sunny."
|
||||
assert gen_results[-1].finish_reason == "stop"
|
||||
|
||||
|
||||
# ── Test: Tool call response ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestE2EToolCallResponse:
|
||||
"""Model generates a DSML tool call — realistic token boundaries."""
|
||||
|
||||
def test_realistic_tool_call_tokens(self):
|
||||
"""Simulate model generating a get_weather tool call with realistic token splits.
|
||||
|
||||
Real models split DSML markers across tokens unpredictably.
|
||||
This simulates how DeepSeek V3.2 actually tokenizes DSML output.
|
||||
"""
|
||||
# Step 1: Encode prompt
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What's the weather in San Francisco?"},
|
||||
]
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
assert "get_weather" in prompt
|
||||
|
||||
# Step 2: Simulate realistic token-by-token model output
|
||||
# The model first produces some text, then a DSML tool call block
|
||||
model_tokens = [
|
||||
"I'll check the weather for you.",
|
||||
"\n\n",
|
||||
f"<{DSML_TOKEN}", # marker split across tokens
|
||||
"function_calls>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">',
|
||||
"San Francisco",
|
||||
f"</{DSML_TOKEN}parameter>\n",
|
||||
f'<{DSML_TOKEN}parameter name="units" string="false">',
|
||||
'"celsius"',
|
||||
f"</{DSML_TOKEN}parameter>\n",
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
f"</{DSML_TOKEN}function_calls>",
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
# Step 3: Verify
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
# Should have text tokens before tool call + one ToolCallResponse
|
||||
assert len(tool_results) == 1
|
||||
assert len(tool_results[0].tool_calls) == 1
|
||||
|
||||
tc = tool_results[0].tool_calls[0]
|
||||
assert tc.name == "get_weather"
|
||||
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
|
||||
assert args["city"] == "San Francisco"
|
||||
assert args["units"] == "celsius"
|
||||
|
||||
# The text before the tool call should still be yielded
|
||||
text_before = "".join(r.text for r in gen_results if not r.is_thinking)
|
||||
assert "check the weather" in text_before
|
||||
|
||||
def test_multiple_tool_calls_in_one_block(self):
|
||||
"""Model generates two tool calls in a single function_calls block."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Weather in NYC and time in EST?"},
|
||||
]
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
assert "get_weather" in prompt
|
||||
assert "get_time" in prompt
|
||||
|
||||
# Simulate model output with two invocations
|
||||
model_tokens = [
|
||||
"Let me check both.\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_time">\n',
|
||||
f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert len(tool_results[0].tool_calls) == 2
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
assert tool_results[0].tool_calls[1].name == "get_time"
|
||||
|
||||
args0 = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
args1 = json.loads(tool_results[0].tool_calls[1].arguments) # pyright: ignore[reportAny]
|
||||
assert args0 == {"city": "NYC"}
|
||||
assert args1 == {"timezone": "EST"}
|
||||
|
||||
|
||||
# ── Test: Multi-turn tool use flow ───────────────────────────────
|
||||
|
||||
|
||||
class TestE2EMultiTurnToolUse:
|
||||
"""Full multi-turn: user asks → model calls tool → tool result → model answers."""
|
||||
|
||||
def test_encode_multi_turn_with_tool_results(self):
|
||||
"""Verify the prompt for turn 2 (after tool results) is correctly encoded."""
|
||||
# Turn 1: user asks, model calls tool
|
||||
# Turn 2: tool result provided, model answers
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a weather assistant."},
|
||||
{"role": "user", "content": "What's the weather in NYC?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "NYC"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": '{"temperature": 72, "condition": "sunny"}'},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Verify multi-turn structure
|
||||
assert BOS_TOKEN in prompt
|
||||
assert "You are a weather assistant." in prompt
|
||||
assert "## Tools" in prompt
|
||||
|
||||
# The assistant's tool call should be encoded as DSML
|
||||
assert TOOL_CALLS_START in prompt
|
||||
assert f'<{DSML_TOKEN}invoke name="get_weather">' in prompt
|
||||
assert EOS_TOKEN in prompt
|
||||
|
||||
# The tool result should be wrapped in function_results
|
||||
assert "<function_results>" in prompt
|
||||
assert "<result>" in prompt
|
||||
assert "72" in prompt
|
||||
assert "</function_results>" in prompt
|
||||
|
||||
# Now simulate model answering after seeing the tool result
|
||||
model_tokens = [
|
||||
"The current",
|
||||
" weather in NYC",
|
||||
" is 72°F",
|
||||
" and sunny.",
|
||||
]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 0
|
||||
full_text = "".join(r.text for r in gen_results)
|
||||
assert full_text == "The current weather in NYC is 72°F and sunny."
|
||||
|
||||
def test_multi_tool_results_encoding(self):
|
||||
"""Verify encoding when model called two tools and both return results."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Weather and time?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "LA"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "PST"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "85F, clear skies"},
|
||||
{"role": "tool", "content": "3:42 PM PST"},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Should have one function_results block with two results
|
||||
assert prompt.count("<function_results>") == 1
|
||||
assert prompt.count("</function_results>") == 1
|
||||
assert "<result>85F, clear skies</result>" in prompt
|
||||
assert "<result>3:42 PM PST</result>" in prompt
|
||||
|
||||
|
||||
# ── Test: Thinking + tool call ───────────────────────────────────
|
||||
|
||||
|
||||
class TestE2EThinkingAndToolCall:
|
||||
"""Model uses thinking mode, reasons, then makes a tool call."""
|
||||
|
||||
def test_thinking_then_tool_call(self):
|
||||
"""Model thinks first, then produces a DSML tool call block."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
]
|
||||
prompt = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
# Thinking mode: prompt should end with <think>
|
||||
assert prompt.endswith(THINKING_START)
|
||||
|
||||
# Simulate: model outputs <think>, thinks, closes thinking, then tool call.
|
||||
# In the full pipeline, parse_thinking_models handles the case where
|
||||
# <think> is in the prompt. Here we test parse_deepseek_v32 directly,
|
||||
# which detects <think>/</think> markers in the stream.
|
||||
model_tokens = [
|
||||
THINKING_START,
|
||||
"The user wants weather",
|
||||
" information. I should use",
|
||||
" the get_weather tool.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">',
|
||||
"San Francisco",
|
||||
f"</{DSML_TOKEN}parameter>\n",
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
# Should have thinking tokens + tool call
|
||||
thinking_results = [r for r in gen_results if r.is_thinking]
|
||||
|
||||
assert len(thinking_results) >= 1
|
||||
thinking_text = "".join(r.text for r in thinking_results)
|
||||
assert "get_weather tool" in thinking_text
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
assert args["city"] == "San Francisco"
|
||||
|
||||
def test_thinking_prompt_encoding(self):
|
||||
"""Verify thinking mode affects prompt encoding correctly."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "Be thorough."},
|
||||
{"role": "user", "content": "What's the weather?"},
|
||||
]
|
||||
|
||||
# With thinking enabled
|
||||
prompt_think = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_think.endswith(THINKING_START)
|
||||
|
||||
# With thinking disabled
|
||||
prompt_no_think = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="chat"
|
||||
)
|
||||
assert prompt_no_think.endswith(THINKING_END)
|
||||
|
||||
# Both should have the same tool definitions
|
||||
assert "get_weather" in prompt_think
|
||||
assert "get_weather" in prompt_no_think
|
||||
|
||||
|
||||
# ── Test: Round-trip encode → parse ──────────────────────────────
|
||||
|
||||
|
||||
class TestE2ERoundTrip:
|
||||
"""Verify that DSML we encode can be parsed back correctly."""
|
||||
|
||||
def test_encoded_tool_call_is_parseable(self):
|
||||
"""Encode an assistant tool call message, then parse the DSML output."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Weather?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Tokyo", "units": "celsius"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
# Extract the DSML function_calls block from the prompt
|
||||
start = prompt.index(TOOL_CALLS_START)
|
||||
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
|
||||
dsml_block = prompt[start:end]
|
||||
|
||||
# Parse it back
|
||||
parsed = parse_dsml_output(dsml_block)
|
||||
assert parsed is not None
|
||||
assert len(parsed) == 1
|
||||
assert parsed[0].name == "get_weather"
|
||||
args = json.loads(parsed[0].arguments) # pyright: ignore[reportAny]
|
||||
assert args["city"] == "Tokyo"
|
||||
assert args["units"] == "celsius"
|
||||
|
||||
def test_encoded_multi_tool_call_round_trips(self):
|
||||
"""Encode multiple tool calls, verify they parse back correctly."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Both please"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Paris"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "CET"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
|
||||
|
||||
start = prompt.index(TOOL_CALLS_START)
|
||||
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
|
||||
dsml_block = prompt[start:end]
|
||||
|
||||
parsed = parse_dsml_output(dsml_block)
|
||||
assert parsed is not None
|
||||
assert len(parsed) == 2
|
||||
assert parsed[0].name == "get_weather"
|
||||
assert parsed[1].name == "get_time"
|
||||
assert json.loads(parsed[0].arguments) == {"city": "Paris"}
|
||||
assert json.loads(parsed[1].arguments) == {"timezone": "CET"}
|
||||
|
||||
|
||||
# ── Test: Edge cases with realistic token boundaries ─────────────
|
||||
|
||||
|
||||
class TestE2EEdgeCases:
|
||||
"""Edge cases that occur in real model inference."""
|
||||
|
||||
def test_dsml_marker_split_at_fullwidth_pipe(self):
|
||||
"""The fullwidth pipe character | might be its own token."""
|
||||
# This is a realistic tokenization: the DSML marker is split at the | chars
|
||||
model_tokens = [
|
||||
"Let me help.\n\n",
|
||||
"<\uff5c", # start of |DSML|
|
||||
"DSML\uff5c", # rest of DSML token
|
||||
"function_calls>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
|
||||
def test_tool_call_with_nested_json_object(self):
|
||||
"""Model passes a complex JSON object as a non-string parameter."""
|
||||
dsml_block = (
|
||||
f"{TOOL_CALLS_START}\n"
|
||||
f'<{DSML_TOKEN}invoke name="create_event">\n'
|
||||
f'<{DSML_TOKEN}parameter name="title" string="true">Team Standup</{DSML_TOKEN}parameter>\n'
|
||||
f'<{DSML_TOKEN}parameter name="config" string="false">'
|
||||
f'{{"recurring": true, "days": ["mon", "wed", "fri"], "time": "09:00"}}'
|
||||
f"</{DSML_TOKEN}parameter>\n"
|
||||
f"</{DSML_TOKEN}invoke>\n"
|
||||
f"{TOOL_CALLS_END}"
|
||||
)
|
||||
|
||||
# Feed as single token (model might produce it all at once after prefill)
|
||||
results = list(parse_deepseek_v32(_simulate_tokens([dsml_block])))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
tc = tool_results[0].tool_calls[0]
|
||||
assert tc.name == "create_event"
|
||||
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
|
||||
assert args["title"] == "Team Standup"
|
||||
assert args["config"]["recurring"] is True
|
||||
assert args["config"]["days"] == ["mon", "wed", "fri"]
|
||||
|
||||
def test_text_with_angle_brackets_not_mistaken_for_dsml(self):
|
||||
"""Angle brackets in normal text should not trigger DSML buffering."""
|
||||
model_tokens = [
|
||||
"The formula is ",
|
||||
"<x, y>",
|
||||
" where x > 0",
|
||||
" and y < 100.",
|
||||
]
|
||||
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 0
|
||||
full_text = "".join(r.text for r in gen_results)
|
||||
assert "formula" in full_text
|
||||
assert "<x, y>" in full_text
|
||||
|
||||
def test_empty_model_response(self):
|
||||
"""Model produces only EOS (empty response)."""
|
||||
model_tokens = [""]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
|
||||
assert len(gen_results) == 1
|
||||
assert gen_results[0].text == ""
|
||||
assert gen_results[0].finish_reason == "stop"
|
||||
|
||||
|
||||
# ── Test: Full EPDP spec round-trip ──────────────────────────────
|
||||
|
||||
|
||||
class TestE2EFullRoundTrip:
|
||||
"""Full round-trip matching the vLLM EPDP spec.
|
||||
|
||||
Simulates the complete multi-turn flow:
|
||||
Turn 1: user asks → think → tool call → tool result → think → answer
|
||||
Turn 2: user asks again → old reasoning stripped → think → answer
|
||||
"""
|
||||
|
||||
def test_single_tool_full_flow_with_thinking(self):
|
||||
"""Complete flow: user → think → tool call → tool result → think → answer.
|
||||
|
||||
This is the core EPDP flow from the vLLM spec.
|
||||
"""
|
||||
# ── Turn 1.1: User asks, encode prompt ──
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a weather assistant."},
|
||||
{"role": "user", "content": "How's the weather in Hangzhou?"},
|
||||
]
|
||||
prompt_1 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_1.endswith(THINKING_START)
|
||||
assert "## Tools" in prompt_1
|
||||
assert "get_weather" in prompt_1
|
||||
|
||||
# ── Turn 1.1: Model thinks, then calls tool ──
|
||||
model_tokens_1 = [
|
||||
THINKING_START,
|
||||
"The user wants to know the weather in Hangzhou.",
|
||||
" I need to use the get_weather tool.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
|
||||
|
||||
# Verify: thinking tokens + tool call
|
||||
gen_1 = [r for r in results_1 if isinstance(r, GenerationResponse)]
|
||||
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
|
||||
thinking_1 = [r for r in gen_1 if r.is_thinking]
|
||||
|
||||
assert len(thinking_1) >= 1
|
||||
assert "get_weather tool" in "".join(r.text for r in thinking_1)
|
||||
assert len(tool_1) == 1
|
||||
assert tool_1[0].tool_calls[0].name == "get_weather"
|
||||
tc_args = json.loads(tool_1[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
assert tc_args == {"city": "Hangzhou"}
|
||||
|
||||
# ── Turn 1.2: Add assistant response + tool result to messages ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "The user wants to know the weather in Hangzhou. I need to use the get_weather tool.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Hangzhou"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": '{"temperature": "7~13°C", "condition": "Cloudy"}',
|
||||
}
|
||||
)
|
||||
|
||||
# Encode prompt for turn 1.2
|
||||
prompt_2 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
|
||||
# Verify: prompt has the full conversation structure
|
||||
assert TOOL_CALLS_START in prompt_2 # assistant's encoded tool call
|
||||
assert EOS_TOKEN in prompt_2 # assistant turn ends with EOS
|
||||
assert "<function_results>" in prompt_2
|
||||
assert "<result>" in prompt_2
|
||||
assert "Cloudy" in prompt_2
|
||||
assert "</function_results>" in prompt_2
|
||||
# After tool results with thinking enabled → <think> appended
|
||||
assert prompt_2.endswith(THINKING_START)
|
||||
# The assistant's reasoning_content should appear (it's after last_user_idx)
|
||||
assert "get_weather tool" in prompt_2
|
||||
|
||||
# ── Turn 1.2: Model thinks about results, then answers ──
|
||||
model_tokens_2 = [
|
||||
THINKING_START,
|
||||
"The weather in Hangzhou is Cloudy, 7~13°C.",
|
||||
" I'll tell the user.",
|
||||
THINKING_END,
|
||||
"The weather in Hangzhou is currently cloudy with temperatures between 7°C and 13°C.",
|
||||
]
|
||||
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
|
||||
|
||||
gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
|
||||
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
|
||||
thinking_2 = [r for r in gen_2 if r.is_thinking]
|
||||
non_thinking_2 = [r for r in gen_2 if not r.is_thinking]
|
||||
|
||||
assert len(tool_2) == 0 # No more tool calls
|
||||
assert len(thinking_2) >= 1
|
||||
assert "Cloudy" in "".join(r.text for r in thinking_2)
|
||||
assert len(non_thinking_2) >= 1
|
||||
final_text = "".join(r.text for r in non_thinking_2)
|
||||
assert "7°C" in final_text
|
||||
assert "13°C" in final_text
|
||||
|
||||
def test_multi_tool_full_flow(self):
|
||||
"""Flow with two tools: user → think → 2 tool calls → 2 results → think → answer."""
|
||||
# ── Initial prompt ──
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You help with weather and time."},
|
||||
{"role": "user", "content": "Weather in NYC and time in EST?"},
|
||||
]
|
||||
prompt_1 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_1.endswith(THINKING_START)
|
||||
|
||||
# ── Model thinks, calls both tools ──
|
||||
model_tokens_1 = [
|
||||
THINKING_START,
|
||||
"Two requests: weather and time. I'll call both.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_time">\n',
|
||||
f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
|
||||
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_1) == 1
|
||||
assert len(tool_1[0].tool_calls) == 2
|
||||
assert tool_1[0].tool_calls[0].name == "get_weather"
|
||||
assert tool_1[0].tool_calls[1].name == "get_time"
|
||||
|
||||
# ── Add assistant + both tool results ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Two requests: weather and time. I'll call both.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "NYC"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "EST"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append({"role": "tool", "content": "72°F, sunny"})
|
||||
messages.append({"role": "tool", "content": "2:30 PM EST"})
|
||||
|
||||
prompt_2 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
|
||||
# Verify multi-tool result encoding
|
||||
# Count is 2: 1 in _TOOLS_SYSTEM_TEMPLATE example + 1 in conversation
|
||||
assert prompt_2.count("<function_results>") == 2
|
||||
assert prompt_2.count("</function_results>") == 2
|
||||
assert "<result>72°F, sunny</result>" in prompt_2
|
||||
assert "<result>2:30 PM EST</result>" in prompt_2
|
||||
assert prompt_2.endswith(THINKING_START)
|
||||
|
||||
# ── Model thinks about results, answers ──
|
||||
model_tokens_2 = [
|
||||
THINKING_START,
|
||||
"Got both results. Weather is 72°F sunny, time is 2:30 PM.",
|
||||
THINKING_END,
|
||||
"In NYC it's currently 72°F and sunny. The time in EST is 2:30 PM.",
|
||||
]
|
||||
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
|
||||
|
||||
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
|
||||
gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
|
||||
non_thinking_2 = [r for r in gen_2 if not r.is_thinking]
|
||||
|
||||
assert len(tool_2) == 0
|
||||
final_text = "".join(r.text for r in non_thinking_2)
|
||||
assert "72°F" in final_text
|
||||
assert "2:30 PM" in final_text
|
||||
|
||||
def test_two_user_turns_reasoning_stripped(self):
|
||||
"""Turn 2: old reasoning_content is stripped from history.
|
||||
|
||||
Per the vLLM spec, clear_reasoning_content is called between user turns
|
||||
to save bandwidth. Our _drop_old_thinking handles this.
|
||||
"""
|
||||
# Full turn 1 conversation (already completed)
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Weather in Hangzhou?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "I need to call get_weather for Hangzhou.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Hangzhou"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "Cloudy 7~13°C"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The weather in Hangzhou is cloudy, 7-13°C.",
|
||||
"reasoning_content": "The tool returned cloudy weather. I'll summarize.",
|
||||
},
|
||||
# Turn 2: user asks again
|
||||
{"role": "user", "content": "What about Beijing?"},
|
||||
]
|
||||
|
||||
prompt = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
|
||||
# Old reasoning_content from turn 1 assistants should be STRIPPED
|
||||
# (they're before the last user message at index 5)
|
||||
assert "I need to call get_weather" not in prompt
|
||||
assert "tool returned cloudy" not in prompt
|
||||
|
||||
# But the assistant's content and tool calls should still be there
|
||||
assert "cloudy, 7-13°C" in prompt
|
||||
assert TOOL_CALLS_START in prompt
|
||||
|
||||
# Prompt ends with <think> for the new turn
|
||||
assert prompt.endswith(THINKING_START)
|
||||
|
||||
# ── Turn 2: Model thinks, calls tool for Beijing ──
|
||||
model_tokens = [
|
||||
THINKING_START,
|
||||
"Now the user wants Beijing weather.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">Beijing</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
|
||||
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
|
||||
|
||||
assert len(tool_results) == 1
|
||||
assert tool_results[0].tool_calls[0].name == "get_weather"
|
||||
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
|
||||
assert args == {"city": "Beijing"}
|
||||
|
||||
def test_chained_tool_calls_loop(self):
|
||||
"""Model calls tool, gets result, calls another tool, gets result, answers.
|
||||
|
||||
This simulates the inner while loop from the vLLM spec where the model
|
||||
may need multiple sub-turns of tool calling before it has enough info.
|
||||
"""
|
||||
# ── Sub-turn 1: user asks, model calls get_time ──
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "What's the weather in Hangzhou tomorrow?"},
|
||||
]
|
||||
|
||||
prompt_1 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert prompt_1.endswith(THINKING_START)
|
||||
|
||||
# Model first calls get_time to figure out the date
|
||||
model_tokens_1 = [
|
||||
THINKING_START,
|
||||
"I need the current date first to calculate tomorrow.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_time">\n',
|
||||
f'<{DSML_TOKEN}parameter name="timezone" string="true">Asia/Shanghai</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
|
||||
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
|
||||
assert len(tool_1) == 1
|
||||
assert tool_1[0].tool_calls[0].name == "get_time"
|
||||
|
||||
# ── Sub-turn 2: add tool result, model calls get_weather ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "I need the current date first to calculate tomorrow.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_time",
|
||||
"arguments": '{"timezone": "Asia/Shanghai"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append({"role": "tool", "content": "2025-12-01 14:30 CST"})
|
||||
|
||||
prompt_2 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
assert "<result>2025-12-01 14:30 CST</result>" in prompt_2
|
||||
assert prompt_2.endswith(THINKING_START)
|
||||
|
||||
# Model now knows the date, calls get_weather
|
||||
model_tokens_2 = [
|
||||
THINKING_START,
|
||||
"Today is 2025-12-01, so tomorrow is 2025-12-02.",
|
||||
" Now I can check weather for Hangzhou.",
|
||||
THINKING_END,
|
||||
"\n\n",
|
||||
TOOL_CALLS_START,
|
||||
"\n",
|
||||
f'<{DSML_TOKEN}invoke name="get_weather">\n',
|
||||
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
|
||||
f"</{DSML_TOKEN}invoke>\n",
|
||||
TOOL_CALLS_END,
|
||||
]
|
||||
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
|
||||
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
|
||||
assert len(tool_2) == 1
|
||||
assert tool_2[0].tool_calls[0].name == "get_weather"
|
||||
|
||||
# ── Sub-turn 3: add weather result, model answers ──
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "Today is 2025-12-01, so tomorrow is 2025-12-02. Now I can check weather for Hangzhou.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Hangzhou"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
messages.append({"role": "tool", "content": "Sunny, 5~12°C"})
|
||||
|
||||
prompt_3 = encode_messages(
|
||||
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
|
||||
)
|
||||
# Should have both function_results blocks (one per tool round)
|
||||
# Count is 3: 1 in _TOOLS_SYSTEM_TEMPLATE example + 2 in conversation
|
||||
assert prompt_3.count("<function_results>") == 3
|
||||
assert prompt_3.count("</function_results>") == 3
|
||||
assert "<result>2025-12-01 14:30 CST</result>" in prompt_3
|
||||
assert "<result>Sunny, 5~12°C</result>" in prompt_3
|
||||
assert prompt_3.endswith(THINKING_START)
|
||||
|
||||
# Model finally answers
|
||||
model_tokens_3 = [
|
||||
THINKING_START,
|
||||
"I have the weather for tomorrow in Hangzhou.",
|
||||
THINKING_END,
|
||||
"Tomorrow in Hangzhou will be sunny with temperatures between 5°C and 12°C.",
|
||||
]
|
||||
results_3 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_3)))
|
||||
|
||||
tool_3 = [r for r in results_3 if isinstance(r, ToolCallResponse)]
|
||||
gen_3 = [r for r in results_3 if isinstance(r, GenerationResponse)]
|
||||
non_thinking_3 = [r for r in gen_3 if not r.is_thinking]
|
||||
|
||||
assert len(tool_3) == 0 # No more tool calls — loop ends
|
||||
final_text = "".join(r.text for r in non_thinking_3)
|
||||
assert "sunny" in final_text.lower()
|
||||
assert "5°C" in final_text
|
||||
assert "12°C" in final_text
|
||||
@@ -148,6 +148,7 @@ class MockTokenizer:
tool_call_start = None
tool_call_end = None
has_tool_calling = False
has_thinking = False


class MockGroup:

@@ -149,12 +149,23 @@ class TestParseGptOssThinkingThenToolCall:
def test_thinking_then_tool_call(self):
results = _collect(THINKING_THEN_TOOL_TOKENS)

# Should have thinking tags + content + tool call
text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
combined = "".join(text_parts)
assert "<think>" in combined
assert "</think>" in combined
assert "Let me think about this." in combined
# Thinking tokens should have is_thinking=True and no <think> tags
thinking_responses = [
r for r in results if isinstance(r, GenerationResponse) and r.is_thinking
]
thinking_text = "".join(r.text for r in thinking_responses)
assert "Let me think about this." in thinking_text
assert "<think>" not in thinking_text
assert "</think>" not in thinking_text

# Non-thinking tokens should have is_thinking=False
non_thinking = [
r
for r in results
if isinstance(r, GenerationResponse) and not r.is_thinking
]
non_thinking_text = "".join(r.text for r in non_thinking)
assert "<think>" not in non_thinking_text

# And the tool call
tc = _get_tool_call(results)

8
tmp/config_examples/claude_code.sh
Executable file
@@ -0,0 +1,8 @@
#!/bin/bash
# Run Claude Code against a local exo cluster! (Here, GPT OSS 120B)
ANTHROPIC_BASE_URL="http://localhost:52415/" \
ANTHROPIC_AUTH_TOKEN="dummy" \
ANTHROPIC_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
ANTHROPIC_SMALL_FAST_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \
claude
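For reference, the example above is presumably run from the repository root as ./tmp/config_examples/claude_code.sh (it ships as an executable file), with an exo node already serving the API on port 52415 as the base URL in the script implies; Claude Code then talks to the local cluster instead of Anthropic's hosted endpoint.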