Compare commits

...

16 Commits

Author SHA1 Message Date
rltakashige
27b4993e64 Merge branch 'main' into fix-partial-download-progress 2026-02-20 12:44:19 +00:00
Alex Cheema
bddad7e79c feat: show ETA on prefill progress bar (#1557)
## Summary
- Show estimated time remaining during prefill (prompt processing phase)
- Track prefill start time via performance.now() and extrapolate from
observed token throughput
- Display ~Xs remaining or ~Xm Ys remaining next to the percentage on
the progress bar
- Wait 200ms before showing ETA to ensure a stable sample window

## Changes
**PrefillProgressBar.svelte**: Add etaText derived computation that
calculates remaining time from (remainingTokens / tokensPerMs). Renders
in a new flex row below the progress bar alongside the percentage.

**app.svelte.ts**: Add startedAt: number field to PrefillProgress
interface. Set on first prefill_progress SSE event, preserved across
subsequent updates.
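
For reference, the ETA is just linear extrapolation from the observed prefill throughput. A Python rendering of the computation (the actual implementation is the Svelte `$derived` block shown in the diff further down):

```python
import math

# Python rendering of the etaText computation described above; the real
# implementation is the Svelte $derived block in PrefillProgressBar.svelte.
def eta_text(processed: int, total: int, elapsed_ms: float) -> str | None:
    if processed <= 0 or total <= 0:
        return None
    if elapsed_ms < 200:  # minimum sample window before extrapolating
        return None
    tokens_per_ms = processed / elapsed_ms
    remaining_sec = math.ceil((total - processed) / tokens_per_ms / 1000)
    if remaining_sec <= 0:
        return None
    if remaining_sec < 60:
        return f"~{remaining_sec}s remaining"
    return f"~{remaining_sec // 60}m {remaining_sec % 60}s remaining"
```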

## Test plan
- [ ] Start inference with a long prompt (10k+ tokens) on a multi-node
cluster
- [ ] Verify the progress bar shows ~Xs remaining after ~200ms of
prefill
- [ ] Verify the ETA decreases as prefill progresses
- [ ] Verify short prefills (<200ms) don't briefly flash an ETA
- [ ] Verify ETA disappears when prefill completes and token generation
begins

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-20 12:37:56 +00:00
rltakashige
addf73a144 Add support for Ollama API (#1560)
## Motivation

Ollama has a bunch of integrations, such as OpenWebUI, that are very
handy. Let's support it :)

## Test Plan

### Manual Testing
<img width="3426" height="1998" alt="image"
src="https://github.com/user-attachments/assets/44b07f1e-308e-4ff1-9a11-922d8279939f"
/>
2026-02-20 12:03:27 +00:00
Mustafa Alp Yılmaz
a16ff2c047 fix: correct misleading docstring in seed_models (#1561)
## Summary
- Fixed stale docstring in `seed_models()` that referenced
`.cache/huggingface/hub` when the function actually moves models to
`EXO_MODELS_DIR` (resolved via `ensure_models_dir()`)
- The old docstring was misleading for AI coding agents analyzing the
codebase, causing incorrect conclusions about model storage paths

## Changes
`src/exo/download/download_utils.py`: Updated docstring from `"Move
model in resources folder of app to .cache/huggingface/hub"` to `"Move
models from resources folder to EXO_MODELS_DIR."`

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-20 11:57:55 +00:00
rltakashige
3006c8ea4e Ensure coordinator is rank 0 (#1559)
## Motivation

The coordinator can currently end up on a random rank. Let's pin it to
rank 0, since that's what we typically assume.

## Test Plan

### Manual Testing
Works as normal on 2 nodes.


Let's wait for a little more testing before merging this.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-02-20 11:46:24 +00:00
rltakashige
f662c129dd Prioritise tb for ring instances (#1556)
## Motivation

TB has higher bandwidth and lower latency than Ethernet, so we should
prioritise TB5 where possible. This drastically improves distributed
image generation performance.
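
The gist of the prioritisation: Thunderbolt interfaces show up with 169.254.x.x link-local addresses, so those should sort ahead of Ethernet addresses when building a ring. A hedged sketch of that idea (the function and names below are illustrative, not exo's actual code):

```python
# Illustrative sketch only: prefer Thunderbolt link-local (169.254.x.x)
# addresses over Ethernet when ordering candidate peer addresses for a ring.
def is_thunderbolt(addr: str) -> bool:
    return addr.startswith("169.254.")

def order_ring_addresses(addresses: list[str]) -> list[str]:
    # Stable sort: TB addresses first, everything else keeps its relative order.
    return sorted(addresses, key=lambda addr: 0 if is_thunderbolt(addr) else 1)
```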

## Test Plan

### Manual Testing
Saw on the dashboard that TB (169.254) addresses were prioritised.

Tested that image models scale much better.

### Automated Testing
No regression on Kimi K2.5
2026-02-19 21:32:48 +00:00
Evan Quiney
c45ff9ad43 memory tidy (#1558)
add some pythonic extensions to Memory and do a bunch of cleanup.
2026-02-19 21:15:33 +00:00
rltakashige
7031901ae5 Prevent common fatal crashes (#1555)
## Motivation
Occasionally, memory does not get released when we shut down. There is
no reason to delay deleting the model.

Handles can also become None during shutdown, causing unhandled
TypeErrors that bring down exo.

Similarly, we were closing the event sender in the wrong place.

Finally, let's not verify the SSL certificate for HTTP connections to
local peers, as this sometimes fails and crashes exo.
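
For the handle case, the defensive pattern is simply to tolerate handles that have already been torn down. A hypothetical illustration (not exo's actual shutdown code):

```python
# Hypothetical illustration: a handle may already be None mid-shutdown, so bail
# out quietly instead of letting a TypeError take down the process.
def close_handle(handle) -> None:
    if handle is None:
        return
    handle.close()
```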

## Test Plan

### Manual Testing
No more crashes as described.
2026-02-19 20:51:17 +00:00
rltakashige
cf648a53b8 Add thinking in thinking blocks, and fix DeepSeek interleaved tool calls (#1548)
## Motivation

OpenCode shows raw <think> tags instead of thinking blocks because we
aren't following the API specs properly.

Claude was also getting horrible prefix cache hit rates because it sends
headers that change with each message.

## Changes

Handle thinking tokens properly for each of the API endpoints, emitting
them as thinking blocks (reasoning_content) rather than raw <think> tags.
Also support DeepSeek V3.2 interleaved tool calling as a minor feature.
Strip Claude headers at the API level.
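
The chat-completions side of this is visible in the diff below; condensed, the routing looks like:

```python
# Condensed from the chunk_to_response change in the diff below: thinking
# tokens are emitted as reasoning_content, regular tokens as content.
def delta_for_chunk(chunk) -> dict[str, str]:
    if chunk.is_thinking:
        return {"role": "assistant", "reasoning_content": chunk.text}
    return {"role": "assistant", "content": chunk.text}
```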

## Test Plan

### Manual Testing
Tested OpenCode manually
Needs testing with Claude.

### Automated Testing
All CI and tests passing - added a new e2e test for DeepSeekV32 tool
parsing.
2026-02-19 18:44:49 +00:00
Alex Cheema
94b2ce6922 feat: Mac Studio en2 RDMA port warning v2 (#1551)
Rebuilt from scratch (replaces PR #1543). Detects when a Mac Studio uses
RDMA over en2 (the TB5 port next to the Ethernet port), which does not
support RDMA. Shows a dismissible warning banner with a hover tooltip
listing affected devices, an SVG rear-panel illustration, and fix
instructions. 205 lines in +page.svelte.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-19 18:39:17 +00:00
rltakashige
423ed0f07f Strip Claude headers to improve prefix cache hit rates (#1552)
## Motivation
Our prefix cache hit rate is really bad at the moment (0.2%). This PR
brings it to 98.5% on average.

## Changes

Strips the per-message Claude headers at the API level. Also adds an
example of how to run Claude using Exo.

## Why It Works
Claude sends some billing and session headers that change with each
message.
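
A hedged sketch of the stripping itself (the header prefixes here are hypothetical placeholders, not the actual list exo filters):

```python
# Hypothetical sketch: drop per-message client headers before they can vary the
# request and spoil prefix caching. The prefixes below are placeholders only.
VOLATILE_HEADER_PREFIXES = ("x-session-", "x-billing-")

def strip_volatile_headers(headers: dict[str, str]) -> dict[str, str]:
    return {
        name: value
        for name, value in headers.items()
        if not name.lower().startswith(VOLATILE_HEADER_PREFIXES)
    }
```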

## Test Plan

### Manual Testing
Works in manual testing.
2026-02-19 18:29:34 +00:00
Evan Quiney
ed001f2409 remove prefillprogress event (#1550)
this should never have been a separate event, but i didn't quite
communicate that well when this was merged. convert PrefillProgress to a
chunk like the rest of the runner responses.

tested with Llama-3.3-70B, prefill progress events still show up in the
dashboard as usual
2026-02-19 18:23:28 +00:00
Evan Quiney
4c4c6ce99f simplify rust ident module
this is partly dead code, partly narrowing the rust-python boundary in
prep for future rewrites. no testing as this is all type-safe
refactoring.
2026-02-19 17:19:31 +00:00
Jake Hillion
42e1e7322b bench: restore --danger-delete-downloads planning phase (#1542)
c2f2111b extracted shared utilities from exo_bench.py into harness.py
but accidentally dropped the run_planning_phase function and
--danger-delete-downloads CLI argument in the process.

Restored run_planning_phase in harness.py (where its dependencies now
live) and re-added the --danger-delete-downloads argument to
add_common_instance_args. Re-wired the planning phase call in
exo_bench.py's main() before the benchmark loop.
2026-02-19 15:42:02 +00:00
Alex Cheema
aa3f106fb9 fix: import ResponsesStreamEvent and DRY up SSE formatting (#1499)
## Summary
- `ResponsesStreamEvent` was defined in `openai_responses.py` as a union
of all 11 streaming event types but never imported or used anywhere in
the codebase
- Import it in the responses adapter and add a `_format_sse(event:
ResponsesStreamEvent) -> str` helper
- Replace 13 hardcoded `f"event: {type}\ndata:
{event.model_dump_json()}\n\n"` strings with `_format_sse()` calls
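
A minimal sketch of the helper (assuming each streaming event is a pydantic model that exposes its SSE event name as `event.type`; the real helper lives in the responses adapter):

```python
# Minimal sketch of the helper described above. Assumes each streaming event is
# a pydantic model exposing its SSE event name as `event.type`.
def _format_sse(event: "ResponsesStreamEvent") -> str:
    return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
```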

## Test plan
- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all checks passed
- [x] `nix fmt` — 0 files changed
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 13:40:24 +00:00
Alex Cheema
526cd9f333 fix partial download progress showing 0% on restart
On restart, _emit_existing_download_progress() checked
downloaded_bytes_this_session to decide if a download was pending.
Since this field is always 0 in a new session, partially downloaded
models were reported as DownloadPending (0%) instead of DownloadOngoing
with their actual progress. Check downloaded_bytes (actual data on
disk) instead.
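
A hypothetical condensation of the decision after the fix:

```python
from dataclasses import dataclass

# Hypothetical condensation of the fix described above.
@dataclass
class Progress:
    downloaded_bytes: int               # bytes actually present on disk
    downloaded_bytes_this_session: int  # always 0 right after a restart

def restart_status(p: Progress) -> str:
    # Before the fix this keyed off downloaded_bytes_this_session, so partial
    # downloads always looked like DownloadPending (0%) after a restart.
    return "DownloadPending" if p.downloaded_bytes == 0 else "DownloadOngoing"
```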

Closes #1042

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 10:13:24 -08:00
52 changed files with 3452 additions and 560 deletions

View File

@@ -20,6 +20,7 @@ from harness import (
instance_id_from_instance, instance_id_from_instance,
nodes_used_in_instance, nodes_used_in_instance,
resolve_model_short_id, resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements, settle_and_fetch_placements,
wait_for_instance_gone, wait_for_instance_gone,
wait_for_instance_ready, wait_for_instance_ready,
@@ -962,6 +963,21 @@ Examples:
selected.sort(key=_placement_sort_key) selected.sort(key=_placement_sort_key)
preview = selected[0] preview = selected[0]
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
print("Planning phase: checking downloads...", file=log)
run_planning_phase(
exo,
full_model_id,
preview,
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
instance = preview["instance"] instance = preview["instance"]
instance_id = instance_id_from_instance(instance) instance_id = instance_id_from_instance(instance)
sharding = str(preview["sharding"]) sharding = str(preview["sharding"])

View File

@@ -35,6 +35,7 @@ from harness import (
instance_id_from_instance, instance_id_from_instance,
nodes_used_in_instance, nodes_used_in_instance,
resolve_model_short_id, resolve_model_short_id,
run_planning_phase,
settle_and_fetch_placements, settle_and_fetch_placements,
wait_for_instance_gone, wait_for_instance_gone,
wait_for_instance_ready, wait_for_instance_ready,
@@ -332,6 +333,20 @@ def main() -> int:
if args.dry_run: if args.dry_run:
return 0 return 0
settle_deadline = (
time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
)
logger.info("Planning phase: checking downloads...")
run_planning_phase(
client,
full_model_id,
selected[0],
args.danger_delete_downloads,
args.timeout,
settle_deadline,
)
all_rows: list[dict[str, Any]] = [] all_rows: list[dict[str, Any]] = []
for preview in selected: for preview in selected:

View File

@@ -282,6 +282,151 @@ def settle_and_fetch_placements(
return selected return selected
def run_planning_phase(
client: ExoClient,
full_model_id: str,
preview: dict[str, Any],
danger_delete: bool,
timeout: float,
settle_deadline: float | None,
) -> None:
"""Check disk space and ensure model is downloaded before benchmarking."""
# Get model size from /models
models = client.request_json("GET", "/models") or {}
model_bytes = 0
for m in models.get("data", []):
if m.get("hugging_face_id") == full_model_id:
model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
break
if not model_bytes:
logger.warning(
f"Could not determine size for {full_model_id}, skipping disk check"
)
return
# Get nodes from preview
inner = unwrap_instance(preview["instance"])
node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
runner_to_shard = inner["shardAssignments"]["runnerToShard"]
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
node_disk = state.get("nodeDisk", {})
for node_id in node_ids:
node_downloads = downloads.get(node_id, [])
# Check if model already downloaded on this node
already_downloaded = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
for p in node_downloads
)
if already_downloaded:
continue
# Wait for disk info if settle_deadline is set
disk_info = node_disk.get(node_id, {})
backoff = _SETTLE_INITIAL_BACKOFF_S
while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
remaining = settle_deadline - time.monotonic()
logger.info(
f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
)
time.sleep(min(backoff, remaining))
backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
state = client.request_json("GET", "/state")
node_disk = state.get("nodeDisk", {})
disk_info = node_disk.get(node_id, {})
if not disk_info:
logger.warning(f"No disk info for {node_id}, skipping space check")
continue
avail = disk_info.get("available", {}).get("inBytes", 0)
if avail >= model_bytes:
continue
if not danger_delete:
raise RuntimeError(
f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
)
# Delete from smallest to largest
completed = [
(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p
]
for del_model, size in sorted(completed, key=lambda x: x[1]):
logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
client.request_json("DELETE", f"/download/{node_id}/{del_model}")
avail += size
if avail >= model_bytes:
break
if avail < model_bytes:
raise RuntimeError(f"Could not free enough space on {node_id}")
# Start downloads (idempotent)
for node_id in node_ids:
runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
shard = runner_to_shard[runner_id]
client.request_json(
"POST",
"/download/start",
body={
"targetNodeId": node_id,
"shardMetadata": shard,
},
)
logger.info(f"Started download on {node_id}")
# Wait for downloads
start = time.time()
while time.time() - start < timeout:
state = client.request_json("GET", "/state")
downloads = state.get("downloads", {})
all_done = True
for node_id in node_ids:
done = any(
"DownloadCompleted" in p
and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
"modelCard"
]["modelId"]
== full_model_id
for p in downloads.get(node_id, [])
)
failed = [
p["DownloadFailed"]["errorMessage"]
for p in downloads.get(node_id, [])
if "DownloadFailed" in p
and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
"modelId"
]
== full_model_id
]
if failed:
raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
if not done:
all_done = False
if all_done:
return
time.sleep(1)
raise TimeoutError("Downloads did not complete in time")
def add_common_instance_args(ap: argparse.ArgumentParser) -> None: def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost")) ap.add_argument("--host", default=os.environ.get("EXO_HOST", "localhost"))
ap.add_argument( ap.add_argument(
@@ -325,3 +470,8 @@ def add_common_instance_args(ap: argparse.ArgumentParser) -> None:
default=0, default=0,
help="Max seconds to wait for the cluster to produce valid placements (0 = try once).", help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
) )
ap.add_argument(
"--danger-delete-downloads",
action="store_true",
help="Delete existing models from smallest to largest to make room for benchmark model.",
)

View File

@@ -14,6 +14,21 @@
: 0, : 0,
); );
const etaText = $derived.by(() => {
if (progress.processed <= 0 || progress.total <= 0) return null;
const elapsedMs = performance.now() - progress.startedAt;
if (elapsedMs < 200) return null; // need a minimum sample window
const tokensPerMs = progress.processed / elapsedMs;
const remainingTokens = progress.total - progress.processed;
const remainingMs = remainingTokens / tokensPerMs;
const remainingSec = Math.ceil(remainingMs / 1000);
if (remainingSec <= 0) return null;
if (remainingSec < 60) return `~${remainingSec}s remaining`;
const mins = Math.floor(remainingSec / 60);
const secs = remainingSec % 60;
return `~${mins}m ${secs}s remaining`;
});
function formatTokenCount(count: number | undefined): string { function formatTokenCount(count: number | undefined): string {
if (count == null) return "0"; if (count == null) return "0";
if (count >= 1000) { if (count >= 1000) {
@@ -40,8 +55,11 @@
style="width: {percentage}%" style="width: {percentage}%"
></div> ></div>
</div> </div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono"> <div
{percentage}% class="flex items-center justify-between text-xs text-exo-light-gray/70 mt-0.5 font-mono"
>
<span>{etaText ?? ""}</span>
<span>{percentage}%</span>
</div> </div>
</div> </div>

View File

@@ -250,6 +250,11 @@ interface RawStateResponse {
>; >;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops) // Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][]; thunderboltBridgeCycles?: string[][];
// Disk usage per node
nodeDisk?: Record<
string,
{ total: { inBytes: number }; available: { inBytes: number } }
>;
} }
export interface MessageAttachment { export interface MessageAttachment {
@@ -276,6 +281,8 @@ export interface TokenData {
export interface PrefillProgress { export interface PrefillProgress {
processed: number; processed: number;
total: number; total: number;
/** Timestamp (performance.now()) when prefill started. */
startedAt: number;
} }
export interface Message { export interface Message {
@@ -1652,11 +1659,12 @@ class AppStore {
if (!reader) throw new Error("No response body"); if (!reader) throw new Error("No response body");
let fullContent = prefixText; let fullContent = prefixText;
let streamedThinking = "";
const collectedTokens: TokenData[] = [...tokensToKeep]; const collectedTokens: TokenData[] = [...tokensToKeep];
interface ChatCompletionChunk { interface ChatCompletionChunk {
choices?: Array<{ choices?: Array<{
delta?: { content?: string }; delta?: { content?: string; reasoning_content?: string };
logprobs?: { logprobs?: {
content?: Array<{ content?: Array<{
token: string; token: string;
@@ -1677,6 +1685,7 @@ class AppStore {
(parsed) => { (parsed) => {
const choice = parsed.choices?.[0]; const choice = parsed.choices?.[0];
const delta = choice?.delta?.content; const delta = choice?.delta?.content;
const thinkingDelta = choice?.delta?.reasoning_content;
// Collect logprobs data // Collect logprobs data
const logprobsContent = choice?.logprobs?.content; const logprobsContent = choice?.logprobs?.content;
@@ -1695,7 +1704,11 @@ class AppStore {
} }
} }
if (delta) { if (thinkingDelta) {
streamedThinking += thinkingDelta;
}
if (delta || thinkingDelta) {
if (firstTokenTime === null) { if (firstTokenTime === null) {
firstTokenTime = performance.now(); firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime; this.ttftMs = firstTokenTime - requestStartTime;
@@ -1709,9 +1722,14 @@ class AppStore {
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000; this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
} }
fullContent += delta; if (delta) {
const { displayContent, thinkingContent } = fullContent += delta;
}
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(fullContent); this.stripThinkingTags(fullContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
if (this.activeConversationId === targetConversationId) { if (this.activeConversationId === targetConversationId) {
this.currentResponse = displayContent; this.currentResponse = displayContent;
@@ -1723,7 +1741,7 @@ class AppStore {
messageId, messageId,
(m) => { (m) => {
m.content = displayContent; m.content = displayContent;
m.thinking = thinkingContent || undefined; m.thinking = combinedThinking || undefined;
m.tokens = [...collectedTokens]; m.tokens = [...collectedTokens];
}, },
); );
@@ -1735,11 +1753,14 @@ class AppStore {
// Final update // Final update
if (this.conversationExists(targetConversationId)) { if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } = const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(fullContent); this.stripThinkingTags(fullContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage(targetConversationId, messageId, (m) => { this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = displayContent; m.content = displayContent;
m.thinking = thinkingContent || undefined; m.thinking = finalThinking || undefined;
m.tokens = [...collectedTokens]; m.tokens = [...collectedTokens];
if (this.ttftMs !== null) m.ttftMs = this.ttftMs; if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
if (this.tps !== null) m.tps = this.tps; if (this.tps !== null) m.tps = this.tps;
@@ -1847,11 +1868,12 @@ class AppStore {
} }
let streamedContent = ""; let streamedContent = "";
let streamedThinking = "";
const collectedTokens: TokenData[] = []; const collectedTokens: TokenData[] = [];
interface ChatCompletionChunk { interface ChatCompletionChunk {
choices?: Array<{ choices?: Array<{
delta?: { content?: string }; delta?: { content?: string; reasoning_content?: string };
logprobs?: { logprobs?: {
content?: Array<{ content?: Array<{
token: string; token: string;
@@ -1872,6 +1894,7 @@ class AppStore {
(parsed) => { (parsed) => {
const choice = parsed.choices?.[0]; const choice = parsed.choices?.[0];
const delta = choice?.delta?.content; const delta = choice?.delta?.content;
const thinkingDelta = choice?.delta?.reasoning_content;
// Collect logprobs data // Collect logprobs data
const logprobsContent = choice?.logprobs?.content; const logprobsContent = choice?.logprobs?.content;
@@ -1890,10 +1913,19 @@ class AppStore {
} }
} }
if (delta) { if (thinkingDelta) {
streamedContent += delta; streamedThinking += thinkingDelta;
const { displayContent, thinkingContent } = }
if (delta || thinkingDelta) {
if (delta) {
streamedContent += delta;
}
const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent); this.stripThinkingTags(streamedContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
// Only update currentResponse if target conversation is active // Only update currentResponse if target conversation is active
if (this.activeConversationId === targetConversationId) { if (this.activeConversationId === targetConversationId) {
@@ -1906,7 +1938,7 @@ class AppStore {
assistantMessage.id, assistantMessage.id,
(msg) => { (msg) => {
msg.content = displayContent; msg.content = displayContent;
msg.thinking = thinkingContent || undefined; msg.thinking = combinedThinking || undefined;
msg.tokens = [...collectedTokens]; msg.tokens = [...collectedTokens];
}, },
); );
@@ -1918,14 +1950,17 @@ class AppStore {
// Final cleanup of the message (if conversation still exists) // Final cleanup of the message (if conversation still exists)
if (this.conversationExists(targetConversationId)) { if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } = const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent); this.stripThinkingTags(streamedContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage( this.updateConversationMessage(
targetConversationId, targetConversationId,
assistantMessage.id, assistantMessage.id,
(msg) => { (msg) => {
msg.content = displayContent; msg.content = displayContent;
msg.thinking = thinkingContent || undefined; msg.thinking = finalThinking || undefined;
msg.tokens = [...collectedTokens]; msg.tokens = [...collectedTokens];
}, },
); );
@@ -2317,10 +2352,11 @@ class AppStore {
} }
let streamedContent = ""; let streamedContent = "";
let streamedThinking = "";
interface ChatCompletionChunk { interface ChatCompletionChunk {
choices?: Array<{ choices?: Array<{
delta?: { content?: string }; delta?: { content?: string; reasoning_content?: string };
logprobs?: { logprobs?: {
content?: Array<{ content?: Array<{
token: string; token: string;
@@ -2348,6 +2384,7 @@ class AppStore {
const choice = parsed.choices?.[0]; const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content; const tokenContent = choice?.delta?.content;
const thinkingContent = choice?.delta?.reasoning_content;
// Collect logprobs data // Collect logprobs data
const logprobsContent = choice?.logprobs?.content; const logprobsContent = choice?.logprobs?.content;
@@ -2366,7 +2403,11 @@ class AppStore {
} }
} }
if (tokenContent) { if (thinkingContent) {
streamedThinking += thinkingContent;
}
if (tokenContent || thinkingContent) {
// Track first token for TTFT // Track first token for TTFT
if (firstTokenTime === null) { if (firstTokenTime === null) {
firstTokenTime = performance.now(); firstTokenTime = performance.now();
@@ -2383,11 +2424,16 @@ class AppStore {
this.tps = (tokenCount / elapsed) * 1000; this.tps = (tokenCount / elapsed) * 1000;
} }
streamedContent += tokenContent; if (tokenContent) {
streamedContent += tokenContent;
}
// Strip thinking tags for display and extract thinking content // Use stripThinkingTags as fallback for any <think> tags still in content
const { displayContent, thinkingContent } = const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent); this.stripThinkingTags(streamedContent);
const combinedThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
// Only update currentResponse if target conversation is active // Only update currentResponse if target conversation is active
if (this.activeConversationId === targetConversationId) { if (this.activeConversationId === targetConversationId) {
@@ -2400,7 +2446,7 @@ class AppStore {
assistantMessage.id, assistantMessage.id,
(msg) => { (msg) => {
msg.content = displayContent; msg.content = displayContent;
msg.thinking = thinkingContent || undefined; msg.thinking = combinedThinking || undefined;
msg.tokens = [...collectedTokens]; msg.tokens = [...collectedTokens];
}, },
); );
@@ -2420,6 +2466,7 @@ class AppStore {
this.prefillProgress = { this.prefillProgress = {
processed: inner.processed_tokens, processed: inner.processed_tokens,
total: inner.total_tokens, total: inner.total_tokens,
startedAt: this.prefillProgress?.startedAt ?? performance.now(),
}; };
}, },
}, },
@@ -2436,14 +2483,17 @@ class AppStore {
// Final cleanup of the message (if conversation still exists) // Final cleanup of the message (if conversation still exists)
if (this.conversationExists(targetConversationId)) { if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } = const { displayContent, thinkingContent: tagThinking } =
this.stripThinkingTags(streamedContent); this.stripThinkingTags(streamedContent);
const finalThinking = [streamedThinking, tagThinking]
.filter(Boolean)
.join("\n\n");
this.updateConversationMessage( this.updateConversationMessage(
targetConversationId, targetConversationId,
assistantMessage.id, assistantMessage.id,
(msg) => { (msg) => {
msg.content = displayContent; msg.content = displayContent;
msg.thinking = thinkingContent || undefined; msg.thinking = finalThinking || undefined;
msg.tokens = [...collectedTokens]; msg.tokens = [...collectedTokens];
// Store performance metrics on the message // Store performance metrics on the message
if (this.ttftMs !== null) { if (this.ttftMs !== null) {

View File

@@ -114,6 +114,74 @@
}); });
let tb5InfoDismissed = $state(false); let tb5InfoDismissed = $state(false);
// Detect Mac Studio nodes using RDMA on en2 (the port next to ethernet — RDMA doesn't work there)
const macStudioEn2RdmaWarning = $derived.by(() => {
const edges = data?.edges;
const ids = tbIdentifiers;
const rdmaCtl = rdmaCtlData;
if (!edges || !ids || !rdmaCtl) return null;
const affectedConnections: Array<{
nodeId: string;
nodeName: string;
peerNodeId: string;
peerNodeName: string;
rdmaIface: string;
}> = [];
const isMacStudio = (node: (typeof data.nodes)[string] | undefined) =>
node?.system_info?.model_id === "Mac Studio";
for (const edge of edges) {
if (!edge.sourceRdmaIface && !edge.sinkRdmaIface) continue;
const sourceNode = data?.nodes?.[edge.source];
if (
isMacStudio(sourceNode) &&
edge.sourceRdmaIface === "rdma_en2" &&
rdmaCtl[edge.source]?.enabled
) {
affectedConnections.push({
nodeId: edge.source,
nodeName:
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
peerNodeId: edge.target,
peerNodeName:
data?.nodes?.[edge.target]?.friendly_name ||
edge.target.slice(0, 8) + "...",
rdmaIface: "en2",
});
}
const sinkNode = data?.nodes?.[edge.target];
if (
isMacStudio(sinkNode) &&
edge.sinkRdmaIface === "rdma_en2" &&
rdmaCtl[edge.target]?.enabled
) {
affectedConnections.push({
nodeId: edge.target,
nodeName: sinkNode?.friendly_name || edge.target.slice(0, 8) + "...",
peerNodeId: edge.source,
peerNodeName:
sourceNode?.friendly_name || edge.source.slice(0, 8) + "...",
rdmaIface: "en2",
});
}
}
// Deduplicate by nodeId
const seen = new Set<string>();
const unique = affectedConnections.filter((c) => {
if (seen.has(c.nodeId)) return false;
seen.add(c.nodeId);
return true;
});
return unique.length > 0 ? unique : null;
});
let macStudioEn2Dismissed = $state(false);
// Helper to get friendly node name from node ID // Helper to get friendly node name from node ID
function getNodeName(nodeId: string): string { function getNodeName(nodeId: string): string {
const node = data?.nodes?.[nodeId]; const node = data?.nodes?.[nodeId];
@@ -790,10 +858,8 @@
if (!progress || typeof progress !== "object") return null; if (!progress || typeof progress !== "object") return null;
const prog = progress as Record<string, unknown>; const prog = progress as Record<string, unknown>;
const totalBytes = getBytes(prog.total_bytes ?? prog.totalBytes); const totalBytes = getBytes(prog.total);
const downloadedBytes = getBytes( const downloadedBytes = getBytes(prog.downloaded);
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const speed = (prog.speed as number) ?? 0; const speed = (prog.speed as number) ?? 0;
const completedFiles = const completedFiles =
(prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0; (prog.completed_files as number) ?? (prog.completedFiles as number) ?? 0;
@@ -806,8 +872,8 @@
for (const [fileName, fileData] of Object.entries(filesObj)) { for (const [fileName, fileData] of Object.entries(filesObj)) {
if (!fileData || typeof fileData !== "object") continue; if (!fileData || typeof fileData !== "object") continue;
const fd = fileData as Record<string, unknown>; const fd = fileData as Record<string, unknown>;
const fTotal = getBytes(fd.total_bytes ?? fd.totalBytes); const fTotal = getBytes(fd.total);
const fDownloaded = getBytes(fd.downloaded_bytes ?? fd.downloadedBytes); const fDownloaded = getBytes(fd.downloaded);
files.push({ files.push({
name: fileName, name: fileName,
totalBytes: fTotal, totalBytes: fTotal,
@@ -1196,7 +1262,6 @@
if (typeof value === "number") return value; if (typeof value === "number") return value;
if (value && typeof value === "object") { if (value && typeof value === "object") {
const v = value as Record<string, unknown>; const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes; if (typeof v.inBytes === "number") return v.inBytes;
} }
return 0; return 0;
@@ -1758,7 +1823,7 @@
</script> </script>
{#snippet clusterWarnings()} {#snippet clusterWarnings()}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)} {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
<div class="absolute top-4 left-4 flex flex-col gap-2 z-40"> <div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
{#if tbBridgeCycles.length > 0} {#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]} {@const cycle = tbBridgeCycles[0]}
@@ -1923,12 +1988,260 @@
</button> </button>
</div> </div>
{/if} {/if}
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
<div class="group relative" role="alert">
<div
class="flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help"
>
<svg
class="w-5 h-5 text-red-400 flex-shrink-0"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d={warningIconPath}
/>
</svg>
<span class="text-sm font-mono text-red-200">
RDMA INCOMPATIBLE PORT
</span>
<button
type="button"
onclick={() => (macStudioEn2Dismissed = true)}
class="ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer"
title="Dismiss"
>
<svg
class="w-4 h-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
</div>
<!-- Expanded tooltip on hover -->
<div
class="absolute top-full left-0 mt-2 w-96 p-4 rounded border border-red-500/30 bg-[#1a1a1a]/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
>
<p class="text-xs text-white/80 mb-3">
The Thunderbolt 5 port next to the Ethernet port on Mac Studio
does
<span class="text-red-400 font-semibold">not support RDMA</span>.
Move the cable to one of the other three TB5 ports.
</p>
<div class="text-xs text-white/60 mb-3">
<span class="text-red-300">Affected:</span>
{#each macStudioEn2RdmaWarning as conn}
<div class="ml-2 mt-0.5">
<span class="text-white/80">{conn.nodeName}</span>
<span class="text-white/30">&rarr;</span>
<span class="text-white/60">{conn.peerNodeName}</span>
<span class="text-white/30 ml-1">(en2)</span>
</div>
{/each}
</div>
<!-- Mac Studio back panel illustration -->
<div class="bg-black/40 rounded p-3 mb-3">
<p
class="text-[10px] font-mono text-white/30 uppercase tracking-wider mb-2"
>
Mac Studio Rear Panel
</p>
<svg
viewBox="0 0 320 72"
class="w-full"
xmlns="http://www.w3.org/2000/svg"
>
<rect
x="1"
y="1"
width="318"
height="70"
rx="6"
ry="6"
fill="none"
stroke="rgba(255,255,255,0.12)"
stroke-width="1"
/>
<!-- TB5 port 1 -->
<rect
x="24"
y="22"
width="28"
height="14"
rx="4"
fill="none"
stroke="rgba(255,255,255,0.3)"
stroke-width="1"
/>
<text
x="38"
y="52"
text-anchor="middle"
fill="rgba(255,255,255,0.25)"
style="font-size:7px;font-family:ui-monospace,monospace;"
>TB5</text
>
<!-- TB5 port 2 -->
<rect
x="62"
y="22"
width="28"
height="14"
rx="4"
fill="none"
stroke="rgba(255,255,255,0.3)"
stroke-width="1"
/>
<text
x="76"
y="52"
text-anchor="middle"
fill="rgba(255,255,255,0.25)"
style="font-size:7px;font-family:ui-monospace,monospace;"
>TB5</text
>
<!-- TB5 port 3 -->
<rect
x="100"
y="22"
width="28"
height="14"
rx="4"
fill="none"
stroke="rgba(255,255,255,0.3)"
stroke-width="1"
/>
<text
x="114"
y="52"
text-anchor="middle"
fill="rgba(255,255,255,0.25)"
style="font-size:7px;font-family:ui-monospace,monospace;"
>TB5</text
>
<!-- TB5 port 4: INCOMPATIBLE (en2) — equally spaced with ports 1-3 -->
<rect
x="138"
y="22"
width="28"
height="14"
rx="4"
fill="rgba(239,68,68,0.1)"
stroke="rgba(239,68,68,0.7)"
stroke-width="1.5"
/>
<line
x1="142"
y1="25"
x2="162"
y2="33"
stroke="rgba(239,68,68,0.8)"
stroke-width="1.5"
stroke-linecap="round"
/>
<line
x1="162"
y1="25"
x2="142"
y2="33"
stroke="rgba(239,68,68,0.8)"
stroke-width="1.5"
stroke-linecap="round"
/>
<text
x="152"
y="52"
text-anchor="middle"
fill="rgba(239,68,68,0.6)"
style="font-size:7px;font-family:ui-monospace,monospace;font-weight:600;"
>en2</text
>
<!-- Ethernet port -->
<rect
x="196"
y="19"
width="24"
height="20"
rx="2"
fill="none"
stroke="rgba(255,255,255,0.2)"
stroke-width="1"
/>
<rect
x="200"
y="23"
width="16"
height="12"
rx="1"
fill="none"
stroke="rgba(255,255,255,0.12)"
stroke-width="0.75"
/>
<text
x="208"
y="52"
text-anchor="middle"
fill="rgba(255,255,255,0.25)"
style="font-size:7px;font-family:ui-monospace,monospace;"
>ETH</text
>
<!-- Green checkmarks on working ports -->
<circle
cx="38"
cy="62"
r="3"
fill="none"
stroke="rgba(74,222,128,0.5)"
stroke-width="0.75"
/>
<circle
cx="76"
cy="62"
r="3"
fill="none"
stroke="rgba(74,222,128,0.5)"
stroke-width="0.75"
/>
<circle
cx="114"
cy="62"
r="3"
fill="none"
stroke="rgba(74,222,128,0.5)"
stroke-width="0.75"
/>
</svg>
</div>
<p class="text-xs text-white/50">
<span class="text-green-400">Fix:</span> Move the Thunderbolt cable
to any of the three leftmost ports (all support RDMA).
</p>
</div>
</div>
{/if}
</div> </div>
{/if} {/if}
{/snippet} {/snippet}
{#snippet clusterWarningsCompact()} {#snippet clusterWarningsCompact()}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)} {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (macStudioEn2RdmaWarning && !macStudioEn2Dismissed)}
<div class="absolute top-2 left-2 flex flex-col gap-1"> <div class="absolute top-2 left-2 flex flex-col gap-1">
{#if tbBridgeCycles.length > 0} {#if tbBridgeCycles.length > 0}
<div <div
@@ -1996,6 +2309,27 @@
> >
</div> </div>
{/if} {/if}
{#if macStudioEn2RdmaWarning && !macStudioEn2Dismissed}
<div
class="flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm"
title="Mac Studio RDMA incompatible port (en2) — move cable to another TB5 port"
>
<svg
class="w-3.5 h-3.5 text-red-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d={warningIconPath}
/>
</svg>
<span class="text-[10px] font-mono text-red-200">BAD RDMA PORT</span>
</div>
{/if}
</div> </div>
{/if} {/if}
{/snippet} {/snippet}

View File

@@ -74,7 +74,6 @@
if (typeof value === "number") return value; if (typeof value === "number") return value;
if (value && typeof value === "object") { if (value && typeof value === "object") {
const v = value as Record<string, unknown>; const v = value as Record<string, unknown>;
if (typeof v.in_bytes === "number") return v.in_bytes;
if (typeof v.inBytes === "number") return v.inBytes; if (typeof v.inBytes === "number") return v.inBytes;
} }
return 0; return 0;
@@ -231,23 +230,14 @@
undefined; undefined;
let cell: CellStatus; let cell: CellStatus;
if (tag === "DownloadCompleted") { if (tag === "DownloadCompleted") {
const totalBytes = getBytes( const totalBytes = getBytes(payload.total);
payload.total_bytes ?? payload.totalBytes,
);
cell = { kind: "completed", totalBytes, modelDirectory }; cell = { kind: "completed", totalBytes, modelDirectory };
} else if (tag === "DownloadOngoing") { } else if (tag === "DownloadOngoing") {
const rawProgress = const rawProgress =
payload.download_progress ?? payload.downloadProgress ?? {}; payload.download_progress ?? payload.downloadProgress ?? {};
const prog = rawProgress as Record<string, unknown>; const prog = rawProgress as Record<string, unknown>;
const totalBytes = getBytes( const totalBytes = getBytes(prog.total ?? payload.total);
prog.total_bytes ?? const downloadedBytes = getBytes(prog.downloaded);
prog.totalBytes ??
payload.total_bytes ??
payload.totalBytes,
);
const downloadedBytes = getBytes(
prog.downloaded_bytes ?? prog.downloadedBytes,
);
const speed = (prog.speed as number) ?? 0; const speed = (prog.speed as number) ?? 0;
const etaMs = const etaMs =
(prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0; (prog.eta_ms as number) ?? (prog.etaMs as number) ?? 0;

View File

@@ -19,7 +19,7 @@ class ConnectionUpdate:
Whether this is a connection or disconnection event Whether this is a connection or disconnection event
""" """
@property @property
def peer_id(self) -> PeerId: def peer_id(self) -> builtins.str:
r""" r"""
Identity of the peer that we have connected to or disconnected from. Identity of the peer that we have connected to or disconnected from.
""" """
@@ -40,92 +40,22 @@ class Keypair:
Identity keypair of a node. Identity keypair of a node.
""" """
@staticmethod @staticmethod
def generate_ed25519() -> Keypair: def generate() -> Keypair:
r""" r"""
Generate a new Ed25519 keypair. Generate a new Ed25519 keypair.
""" """
@staticmethod @staticmethod
def generate_ecdsa() -> Keypair: def from_bytes(bytes: bytes) -> Keypair:
r""" r"""
Generate a new ECDSA keypair. Construct an Ed25519 keypair from secret key bytes
"""
@staticmethod
def generate_secp256k1() -> Keypair:
r"""
Generate a new Secp256k1 keypair.
"""
@staticmethod
def from_protobuf_encoding(bytes: bytes) -> Keypair:
r"""
Decode a private key from a protobuf structure and parse it as a `Keypair`.
"""
@staticmethod
def rsa_from_pkcs8(bytes: bytes) -> Keypair:
r"""
Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
format (i.e. unencrypted) as defined in [RFC5208].
[RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
"""
@staticmethod
def secp256k1_from_der(bytes: bytes) -> Keypair:
r"""
Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
structure as defined in [RFC5915].
[RFC5915]: https://tools.ietf.org/html/rfc5915
"""
@staticmethod
def ed25519_from_bytes(bytes: bytes) -> Keypair: ...
def to_protobuf_encoding(self) -> bytes:
r"""
Encode a private key as protobuf structure.
"""
def to_peer_id(self) -> PeerId:
r"""
Convert the `Keypair` into the corresponding `PeerId`.
"""
@typing.final
class Multiaddr:
r"""
Representation of a Multiaddr.
"""
@staticmethod
def empty() -> Multiaddr:
r"""
Create a new, empty multiaddress.
"""
@staticmethod
def with_capacity(n: builtins.int) -> Multiaddr:
r"""
Create a new, empty multiaddress with the given capacity.
"""
@staticmethod
def from_bytes(bytes: bytes) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its byte slice representation.
"""
@staticmethod
def from_string(string: builtins.str) -> Multiaddr:
r"""
Parse a `Multiaddr` value from its string representation.
"""
def len(self) -> builtins.int:
r"""
Return the length in bytes of this multiaddress.
"""
def is_empty(self) -> builtins.bool:
r"""
Returns true if the length of this multiaddress is 0.
""" """
def to_bytes(self) -> bytes: def to_bytes(self) -> bytes:
r""" r"""
Return a copy of this [`Multiaddr`]'s byte representation. Get the secret key bytes underlying the keypair
""" """
def to_string(self) -> builtins.str: def to_node_id(self) -> builtins.str:
r""" r"""
Convert a Multiaddr to a string. Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
""" """
@typing.final @typing.final
@@ -180,37 +110,6 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ... def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ... def __str__(self) -> builtins.str: ...
@typing.final
class PeerId:
r"""
Identifier of a peer of the network.
The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
"""
@staticmethod
def random() -> PeerId:
r"""
Generates a random peer ID from a cryptographically secure PRNG.
This is useful for randomly walking on a DHT, or for testing purposes.
"""
@staticmethod
def from_bytes(bytes: bytes) -> PeerId:
r"""
Parses a `PeerId` from bytes.
"""
def to_bytes(self) -> bytes:
r"""
Returns a raw bytes representation of this `PeerId`.
"""
def to_base58(self) -> builtins.str:
r"""
Returns a base-58 encoded string of this `PeerId`.
"""
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final @typing.final
class ConnectionUpdateType(enum.Enum): class ConnectionUpdateType(enum.Enum):
r""" r"""

View File

@@ -1,8 +1,6 @@
use crate::ext::ResultExt as _; use crate::ext::ResultExt as _;
use libp2p::PeerId;
use libp2p::identity::Keypair; use libp2p::identity::Keypair;
use pyo3::prelude::{PyBytesMethods as _, PyModule, PyModuleMethods as _}; use pyo3::types::{PyBytes, PyBytesMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyResult, Python, pyclass, pymethods}; use pyo3::{Bound, PyResult, Python, pyclass, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
@@ -18,142 +16,32 @@ pub struct PyKeypair(pub Keypair);
impl PyKeypair { impl PyKeypair {
/// Generate a new Ed25519 keypair. /// Generate a new Ed25519 keypair.
#[staticmethod] #[staticmethod]
fn generate_ed25519() -> Self { fn generate() -> Self {
Self(Keypair::generate_ed25519()) Self(Keypair::generate_ed25519())
} }
/// Generate a new ECDSA keypair. /// Construct an Ed25519 keypair from secret key bytes
#[staticmethod] #[staticmethod]
fn generate_ecdsa() -> Self { fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
Self(Keypair::generate_ecdsa())
}
/// Generate a new Secp256k1 keypair.
#[staticmethod]
fn generate_secp256k1() -> Self {
Self(Keypair::generate_secp256k1())
}
/// Decode a private key from a protobuf structure and parse it as a `Keypair`.
#[staticmethod]
fn from_protobuf_encoding(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::from_protobuf_encoding(&bytes).pyerr()?))
}
/// Decode an keypair from a DER-encoded secret key in PKCS#8 `PrivateKeyInfo`
/// format (i.e. unencrypted) as defined in [RFC5208].
///
/// [RFC5208]: https://tools.ietf.org/html/rfc5208#section-5
#[staticmethod]
fn rsa_from_pkcs8(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::rsa_from_pkcs8(&mut bytes).pyerr()?))
}
/// Decode a keypair from a DER-encoded Secp256k1 secret key in an `ECPrivateKey`
/// structure as defined in [RFC5915].
///
/// [RFC5915]: https://tools.ietf.org/html/rfc5915
#[staticmethod]
fn secp256k1_from_der(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::secp256k1_from_der(&mut bytes).pyerr()?))
}
#[staticmethod]
fn ed25519_from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let mut bytes = Vec::from(bytes.as_bytes()); let mut bytes = Vec::from(bytes.as_bytes());
Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?)) Ok(Self(Keypair::ed25519_from_bytes(&mut bytes).pyerr()?))
} }
/// Encode a private key as protobuf structure. /// Get the secret key bytes underlying the keypair
fn to_protobuf_encoding<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> { fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
let bytes = self.0.to_protobuf_encoding().pyerr()?; let bytes = self
.0
.clone()
.try_into_ed25519()
.pyerr()?
.secret()
.as_ref()
.to_vec();
Ok(PyBytes::new(py, &bytes)) Ok(PyBytes::new(py, &bytes))
} }
/// Convert the `Keypair` into the corresponding `PeerId`. /// Convert the `Keypair` into the corresponding `PeerId` string, which we use as our `NodeId`.
fn to_peer_id(&self) -> PyPeerId { fn to_node_id(&self) -> String {
PyPeerId(self.0.public().to_peer_id()) self.0.public().to_peer_id().to_base58()
}
// /// Hidden constructor for pickling support. TODO: figure out how to do pickling...
// #[gen_stub(skip)]
// #[new]
// fn py_new(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
// Self::from_protobuf_encoding(bytes)
// }
//
// #[gen_stub(skip)]
// fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
// *self = Self::from_protobuf_encoding(state)?;
// Ok(())
// }
//
// #[gen_stub(skip)]
// fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
// self.to_protobuf_encoding(py)
// }
//
// #[gen_stub(skip)]
// pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyBytes>,)> {
// Ok((self.to_protobuf_encoding(py)?,))
// }
}
/// Identifier of a peer of the network.
///
/// The data is a `CIDv0` compatible multihash of the protobuf encoded public key of the peer
/// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md).
#[gen_stub_pyclass]
#[pyclass(name = "PeerId", frozen)]
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct PyPeerId(pub PeerId);
#[gen_stub_pymethods]
#[pymethods]
#[allow(clippy::needless_pass_by_value)]
impl PyPeerId {
/// Generates a random peer ID from a cryptographically secure PRNG.
///
/// This is useful for randomly walking on a DHT, or for testing purposes.
#[staticmethod]
fn random() -> Self {
Self(PeerId::random())
}
/// Parses a `PeerId` from bytes.
#[staticmethod]
fn from_bytes(bytes: Bound<'_, PyBytes>) -> PyResult<Self> {
let bytes = Vec::from(bytes.as_bytes());
Ok(Self(PeerId::from_bytes(&bytes).pyerr()?))
}
/// Returns a raw bytes representation of this `PeerId`.
fn to_bytes<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
let bytes = self.0.to_bytes();
PyBytes::new(py, &bytes)
}
/// Returns a base-58 encoded string of this `PeerId`.
fn to_base58(&self) -> String {
self.0.to_base58()
}
fn __repr__(&self) -> String {
format!("PeerId({})", self.to_base58())
}
fn __str__(&self) -> String {
self.to_base58()
} }
} }
pub fn ident_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyKeypair>()?;
m.add_class::<PyPeerId>()?;
Ok(())
}

View File

@@ -8,9 +8,10 @@ mod allow_threading;
mod ident; mod ident;
mod networking; mod networking;
use crate::ident::ident_submodule; use crate::ident::PyKeypair;
use crate::networking::networking_submodule; use crate::networking::networking_submodule;
use pyo3::prelude::PyModule; use pyo3::prelude::PyModule;
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyResult, pyclass, pymodule}; use pyo3::{Bound, PyResult, pyclass, pymodule};
use pyo3_stub_gen::define_stub_info_gatherer; use pyo3_stub_gen::define_stub_info_gatherer;
@@ -158,7 +159,7 @@ fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system // TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without // work with maturin, where the types generate correctly, in the right folder, without
// too many importing issues... // too many importing issues...
ident_submodule(m)?; m.add_class::<PyKeypair>()?;
networking_submodule(m)?; networking_submodule(m)?;
// top-level constructs // top-level constructs

View File

@@ -8,7 +8,7 @@
use crate::r#const::MPSC_CHANNEL_SIZE; use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _}; use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _}; use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ident::{PyKeypair, PyPeerId}; use crate::ident::PyKeypair;
use crate::pyclass; use crate::pyclass;
use libp2p::futures::StreamExt as _; use libp2p::futures::StreamExt as _;
use libp2p::gossipsub; use libp2p::gossipsub;
@@ -119,7 +119,7 @@ struct PyConnectionUpdate {
/// Identity of the peer that we have connected to or disconnected from. /// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)] #[pyo3(get)]
peer_id: PyPeerId, peer_id: String,
/// Remote connection's IPv4 address. /// Remote connection's IPv4 address.
#[pyo3(get)] #[pyo3(get)]
@@ -251,7 +251,7 @@ async fn networking_task(
// send connection event to channel (or exit if connection closed) // send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate { if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected, update_type: PyConnectionUpdateType::Connected,
peer_id: PyPeerId(peer_id), peer_id: peer_id.to_base58(),
remote_ipv4, remote_ipv4,
remote_tcp_port, remote_tcp_port,
}).await { }).await {
@@ -272,7 +272,7 @@ async fn networking_task(
// send disconnection event to channel (or exit if connection closed) // send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate { if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected, update_type: PyConnectionUpdateType::Disconnected,
peer_id: PyPeerId(peer_id), peer_id: peer_id.to_base58(),
remote_ipv4, remote_ipv4,
remote_tcp_port, remote_tcp_port,
}).await { }).await {

View File

@@ -80,7 +80,7 @@ class DownloadCoordinator:
completed = DownloadCompleted( completed = DownloadCompleted(
shard_metadata=callback_shard, shard_metadata=callback_shard,
node_id=self.node_id, node_id=self.node_id,
total_bytes=progress.total_bytes, total=progress.total,
model_directory=self._model_dir(model_id), model_directory=self._model_dir(model_id),
) )
self.download_status[model_id] = completed self.download_status[model_id] = completed
@@ -203,7 +203,7 @@ class DownloadCoordinator:
completed = DownloadCompleted( completed = DownloadCompleted(
shard_metadata=shard, shard_metadata=shard,
node_id=self.node_id, node_id=self.node_id,
total_bytes=initial_progress.total_bytes, total=initial_progress.total,
model_directory=self._model_dir(model_id), model_directory=self._model_dir(model_id),
) )
self.download_status[model_id] = completed self.download_status[model_id] = completed
@@ -332,13 +332,13 @@ class DownloadCoordinator:
status: DownloadProgress = DownloadCompleted( status: DownloadProgress = DownloadCompleted(
node_id=self.node_id, node_id=self.node_id,
shard_metadata=progress.shard, shard_metadata=progress.shard,
total_bytes=progress.total_bytes, total=progress.total,
model_directory=self._model_dir( model_directory=self._model_dir(
progress.shard.model_card.model_id progress.shard.model_card.model_id
), ),
) )
elif progress.status in ["in_progress", "not_started"]: elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0: if progress.downloaded_bytes.in_bytes == 0:
status = DownloadPending( status = DownloadPending(
node_id=self.node_id, node_id=self.node_id,
shard_metadata=progress.shard, shard_metadata=progress.shard,

View File

@@ -80,9 +80,9 @@ def map_repo_file_download_progress_to_download_progress_data(
repo_file_download_progress: RepoFileDownloadProgress, repo_file_download_progress: RepoFileDownloadProgress,
) -> DownloadProgressData: ) -> DownloadProgressData:
return DownloadProgressData( return DownloadProgressData(
downloaded_bytes=repo_file_download_progress.downloaded, downloaded=repo_file_download_progress.downloaded,
downloaded_bytes_this_session=repo_file_download_progress.downloaded_this_session, downloaded_this_session=repo_file_download_progress.downloaded_this_session,
total_bytes=repo_file_download_progress.total, total=repo_file_download_progress.total,
completed_files=1 if repo_file_download_progress.status == "complete" else 0, completed_files=1 if repo_file_download_progress.status == "complete" else 0,
total_files=1, total_files=1,
speed=repo_file_download_progress.speed, speed=repo_file_download_progress.speed,
@@ -95,9 +95,9 @@ def map_repo_download_progress_to_download_progress_data(
repo_download_progress: RepoDownloadProgress, repo_download_progress: RepoDownloadProgress,
) -> DownloadProgressData: ) -> DownloadProgressData:
return DownloadProgressData( return DownloadProgressData(
total_bytes=repo_download_progress.total_bytes, total=repo_download_progress.total,
downloaded_bytes=repo_download_progress.downloaded_bytes, downloaded=repo_download_progress.downloaded,
downloaded_bytes_this_session=repo_download_progress.downloaded_bytes_this_session, downloaded_this_session=repo_download_progress.downloaded_this_session,
completed_files=repo_download_progress.completed_files, completed_files=repo_download_progress.completed_files,
total_files=repo_download_progress.total_files, total_files=repo_download_progress.total_files,
speed=repo_download_progress.overall_speed, speed=repo_download_progress.overall_speed,
@@ -142,7 +142,7 @@ async def delete_model(model_id: ModelId) -> bool:
async def seed_models(seed_dir: str | Path): async def seed_models(seed_dir: str | Path):
"""Move model in resources folder of app to .cache/huggingface/hub""" """Move models from resources folder to EXO_MODELS_DIR."""
source_dir = Path(seed_dir) source_dir = Path(seed_dir)
dest_dir = await ensure_models_dir() dest_dir = await ensure_models_dir()
for path in source_dir.iterdir(): for path in source_dir.iterdir():
@@ -578,19 +578,20 @@ def calculate_repo_progress(
  file_progress: dict[str, RepoFileDownloadProgress],
  all_start_time: float,
  ) -> RepoDownloadProgress:
- all_total_bytes = sum((p.total.in_bytes for p in file_progress.values()), 0)
- all_downloaded_bytes = sum(
-   (p.downloaded.in_bytes for p in file_progress.values()), 0
+ all_total = sum((p.total for p in file_progress.values()), Memory.from_bytes(0))
+ all_downloaded = sum(
+   (p.downloaded for p in file_progress.values()), Memory.from_bytes(0)
  )
- all_downloaded_bytes_this_session = sum(
-   (p.downloaded_this_session.in_bytes for p in file_progress.values()), 0
+ all_downloaded_this_session = sum(
+   (p.downloaded_this_session for p in file_progress.values()),
+   Memory.from_bytes(0),
  )
  elapsed_time = time.time() - all_start_time
  all_speed = (
- all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
+ all_downloaded_this_session.in_bytes / elapsed_time if elapsed_time > 0 else 0
  )
  all_eta = (
- timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed)
+ timedelta(seconds=(all_total - all_downloaded).in_bytes / all_speed)
  if all_speed > 0
  else timedelta(seconds=0)
  )
@@ -609,11 +610,9 @@ def calculate_repo_progress(
  [p for p in file_progress.values() if p.downloaded == p.total]
  ),
  total_files=len(file_progress),
- downloaded_bytes=Memory.from_bytes(all_downloaded_bytes),
- downloaded_bytes_this_session=Memory.from_bytes(
-   all_downloaded_bytes_this_session
- ),
- total_bytes=Memory.from_bytes(all_total_bytes),
+ downloaded=all_downloaded,
+ downloaded_this_session=all_downloaded_this_session,
+ total=all_total,
  overall_speed=all_speed,
  overall_eta=all_eta,
  status=status,
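The hunks above move calculate_repo_progress from raw byte integers onto the Memory type end to end. As a minimal sketch of the arithmetic involved, using only the Memory operations this diff itself exercises (from_bytes, +, -, .in_bytes) and invented numbers; it is an editorial illustration, runnable only inside the exo codebase, not part of the commit:

from datetime import timedelta
from exo.shared.types.memory import Memory  # import path as used elsewhere in this diff

# Invented progress: 1.5 GB total, 0.5 GB downloaded this session over 50 s.
total = Memory.from_bytes(1_500_000_000)
downloaded = Memory.from_bytes(500_000_000)
elapsed = 50.0

speed = downloaded.in_bytes / elapsed if elapsed > 0 else 0      # 10 MB/s
eta = timedelta(seconds=(total - downloaded).in_bytes / speed)   # 100 s of download left
print(speed, eta)                                                # 10000000.0 0:01:40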

View File

@@ -107,9 +107,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
  ),
  completed_files=0,
  total_files=0,
- downloaded_bytes=Memory.from_bytes(0),
- downloaded_bytes_this_session=Memory.from_bytes(0),
- total_bytes=Memory.from_bytes(0),
+ downloaded=Memory.from_bytes(0),
+ downloaded_this_session=Memory.from_bytes(0),
+ total=Memory.from_bytes(0),
  overall_speed=0,
  overall_eta=timedelta(seconds=0),
  status="complete",

View File

@@ -45,7 +45,7 @@ class Node:
  @classmethod
  async def create(cls, args: "Args") -> "Self":
  keypair = get_node_id_keypair()
- node_id = NodeId(keypair.to_peer_id().to_base58())
+ node_id = NodeId(keypair.to_node_id())
  session_id = SessionId(master_node_id=node_id, election_clock=0)
  router = Router.create(keypair)
  await router.register_topic(topics.GLOBAL_EVENTS)

View File

@@ -59,7 +59,11 @@ def chat_request_to_text_generation(
  chat_template_messages.append({"role": "system", "content": content})
  else:
  # Skip messages with no meaningful content
- if msg.content is None and msg.thinking is None and msg.tool_calls is None:
+ if (
+   msg.content is None
+   and msg.reasoning_content is None
+   and msg.tool_calls is None
+ ):
  continue
  if msg.role in ("user", "assistant", "developer"):
@@ -111,6 +115,11 @@ def chunk_to_response(
  ]
  )
+ if chunk.is_thinking:
+   delta = ChatCompletionMessage(role="assistant", reasoning_content=chunk.text)
+ else:
+   delta = ChatCompletionMessage(role="assistant", content=chunk.text)
  return ChatCompletionResponse(
  id=command_id,
  created=int(time.time()),
@@ -118,7 +127,7 @@ def chunk_to_response(
  choices=[
  StreamingChoiceResponse(
  index=0,
- delta=ChatCompletionMessage(role="assistant", content=chunk.text),
+ delta=delta,
  logprobs=logprobs,
  finish_reason=chunk.finish_reason,
  )
@@ -208,6 +217,7 @@ async def collect_chat_response(
  # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
  """Collect all token chunks and return a single ChatCompletionResponse."""
  text_parts: list[str] = []
+ thinking_parts: list[str] = []
  tool_calls: list[ToolCall] = []
  logprobs_content: list[LogprobsContentItem] = []
  model: str | None = None
@@ -228,7 +238,10 @@ async def collect_chat_response(
  if model is None:
  model = chunk.model
  last_usage = chunk.usage or last_usage
- text_parts.append(chunk.text)
+ if chunk.is_thinking:
+   thinking_parts.append(chunk.text)
+ else:
+   text_parts.append(chunk.text)
  if chunk.logprob is not None:
  logprobs_content.append(
  LogprobsContentItem(
@@ -258,6 +271,7 @@ async def collect_chat_response(
  raise ValueError(error_message)
  combined_text = "".join(text_parts)
+ combined_thinking = "".join(thinking_parts) if thinking_parts else None
  assert model is not None
  yield ChatCompletionResponse(
@@ -270,6 +284,7 @@ async def collect_chat_response(
  message=ChatCompletionMessage(
  role="assistant",
  content=combined_text,
+ reasoning_content=combined_thinking,
  tool_calls=tool_calls if tool_calls else None,
  ),
  logprobs=Logprobs(content=logprobs_content)
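The chat adapter changes above route thinking tokens into reasoning_content while ordinary tokens keep flowing into content. A standalone sketch of that split, using plain tuples as stand-ins for exo's TokenChunk objects (all values invented, for illustration only):

# (is_thinking, text) tuples stand in for TokenChunk; values are invented.
chunks = [(True, "User asked 2+2. "), (True, "That is 4."), (False, "The answer is 4.")]

thinking_parts: list[str] = []
text_parts: list[str] = []
for is_thinking, text in chunks:
    (thinking_parts if is_thinking else text_parts).append(text)

message = {
    "role": "assistant",
    "content": "".join(text_parts),                        # "The answer is 4."
    "reasoning_content": "".join(thinking_parts) or None,  # "User asked 2+2. That is 4."
}
print(message)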

View File

@@ -1,6 +1,7 @@
"""Claude Messages API adapter for converting requests/responses.""" """Claude Messages API adapter for converting requests/responses."""
import json import json
import re
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
@@ -28,6 +29,8 @@ from exo.shared.types.claude_api import (
  ClaudeStopReason,
  ClaudeTextBlock,
  ClaudeTextDelta,
+ ClaudeThinkingBlock,
+ ClaudeThinkingDelta,
  ClaudeToolResultBlock,
  ClaudeToolUseBlock,
  ClaudeUsage,
@@ -61,6 +64,22 @@ def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
return "".join(sub_block.text for sub_block in block.content) return "".join(sub_block.text for sub_block in block.content)
# Matches "x-anthropic-billing-header: ...;" (with optional trailing newline)
# or similar telemetry headers that change every request and break KV prefix caching.
_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)
def _strip_volatile_headers(text: str) -> str:
"""Remove Anthropic billing/telemetry headers from system prompt text.
Claude Code prepends headers like 'x-anthropic-billing-header: cc_version=...;
cc_entrypoint=...; cch=...;' that contain per-request content hashes. These
change every request and break KV prefix caching (the prefix diverges at ~20
tokens instead of matching thousands of conversation tokens).
"""
return _VOLATILE_HEADER_RE.sub("", text)
def claude_request_to_text_generation( def claude_request_to_text_generation(
request: ClaudeMessagesRequest, request: ClaudeMessagesRequest,
) -> TextGenerationTaskParams: ) -> TextGenerationTaskParams:
@@ -73,6 +92,8 @@ def claude_request_to_text_generation(
  instructions = request.system
  else:
  instructions = "".join(block.text for block in request.system)
+ instructions = _strip_volatile_headers(instructions)
  chat_template_messages.append({"role": "system", "content": instructions})
  # Convert messages to input
@@ -85,12 +106,15 @@ def claude_request_to_text_generation(
  # Process structured content blocks
  text_parts: list[str] = []
+ thinking_parts: list[str] = []
  tool_calls: list[dict[str, Any]] = []
  tool_results: list[ClaudeToolResultBlock] = []
  for block in msg.content:
  if isinstance(block, ClaudeTextBlock):
  text_parts.append(block.text)
+ elif isinstance(block, ClaudeThinkingBlock):
+   thinking_parts.append(block.thinking)
  elif isinstance(block, ClaudeToolUseBlock):
  tool_calls.append(
  {
@@ -106,6 +130,7 @@ def claude_request_to_text_generation(
  tool_results.append(block)
  content = "".join(text_parts)
+ reasoning_content = "".join(thinking_parts) if thinking_parts else None
  # Build InputMessage from text content
  if msg.role in ("user", "assistant"):
@@ -113,9 +138,14 @@ def claude_request_to_text_generation(
  # Build chat_template_messages preserving tool structure
  if tool_calls:
- chat_template_messages.append(
-   {"role": "assistant", "content": content, "tool_calls": tool_calls}
- )
+ chat_msg: dict[str, Any] = {
+   "role": "assistant",
+   "content": content,
+   "tool_calls": tool_calls,
+ }
+ if reasoning_content:
+   chat_msg["reasoning_content"] = reasoning_content
+ chat_template_messages.append(chat_msg)
  elif tool_results:
  for tr in tool_results:
  chat_template_messages.append(
@@ -126,7 +156,10 @@ def claude_request_to_text_generation(
  }
  )
  else:
- chat_template_messages.append({"role": msg.role, "content": content})
+ chat_msg = {"role": msg.role, "content": content}
+ if reasoning_content:
+   chat_msg["reasoning_content"] = reasoning_content
+ chat_template_messages.append(chat_msg)
  # Convert Claude tool definitions to OpenAI-style function tools
  tools: list[dict[str, Any]] | None = None
@@ -143,6 +176,10 @@ def claude_request_to_text_generation(
  for tool in request.tools
  ]
+ enable_thinking: bool | None = None
+ if request.thinking is not None:
+   enable_thinking = request.thinking.type in ("enabled", "adaptive")
  return TextGenerationTaskParams(
  model=request.model,
  input=input_messages
@@ -156,6 +193,7 @@ def claude_request_to_text_generation(
  stop=request.stop_sequences,
  stream=request.stream,
  tools=tools,
+ enable_thinking=enable_thinking,
  chat_template_messages=chat_template_messages
  if chat_template_messages
  else None,
@@ -173,6 +211,7 @@ async def collect_claude_response(
  # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
  """Collect all token chunks and return a single ClaudeMessagesResponse."""
  text_parts: list[str] = []
+ thinking_parts: list[str] = []
  tool_use_blocks: list[ClaudeToolUseBlock] = []
  stop_reason: ClaudeStopReason | None = None
  last_usage: Usage | None = None
@@ -200,7 +239,10 @@ async def collect_claude_response(
stop_reason = "tool_use" stop_reason = "tool_use"
continue continue
text_parts.append(chunk.text) if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None: if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason) stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -209,9 +251,12 @@ async def collect_claude_response(
  raise ValueError(error_message)
  combined_text = "".join(text_parts)
+ combined_thinking = "".join(thinking_parts)
  # Build content blocks
  content: list[ClaudeContentBlock] = []
+ if combined_thinking:
+   content.append(ClaudeThinkingBlock(thinking=combined_thinking))
  if combined_text:
  content.append(ClaudeTextBlock(text=combined_text))
  content.extend(tool_use_blocks)
@@ -256,16 +301,16 @@ async def generate_claude_stream(
  start_event = ClaudeMessageStartEvent(message=initial_message)
  yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
- # content_block_start for text block at index 0
- block_start = ClaudeContentBlockStartEvent(
-   index=0, content_block=ClaudeTextBlock(text="")
- )
- yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
  output_tokens = 0
  stop_reason: ClaudeStopReason | None = None
  last_usage: Usage | None = None
- next_block_index = 1 # text block is 0, tool blocks start at 1
+ next_block_index = 0
+ # Track whether we've started thinking/text blocks
+ thinking_block_started = False
+ thinking_block_index = -1
+ text_block_started = False
+ text_block_index = -1
  async for chunk in chunk_stream:
  if isinstance(chunk, PrefillProgressChunk):
@@ -310,12 +355,45 @@ async def generate_claude_stream(
  output_tokens += 1 # Count each chunk as one token
- # content_block_delta
- delta_event = ClaudeContentBlockDeltaEvent(
-   index=0,
-   delta=ClaudeTextDelta(text=chunk.text),
- )
- yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
+ if chunk.is_thinking:
+   # Start thinking block on first thinking token
+   if not thinking_block_started:
+     thinking_block_started = True
+     thinking_block_index = next_block_index
+     next_block_index += 1
+     block_start = ClaudeContentBlockStartEvent(
+       index=thinking_block_index,
+       content_block=ClaudeThinkingBlock(thinking=""),
+     )
+     yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
+   delta_event = ClaudeContentBlockDeltaEvent(
+     index=thinking_block_index,
+     delta=ClaudeThinkingDelta(thinking=chunk.text),
+   )
+   yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
+ else:
+   # Close thinking block when transitioning to text
+   if thinking_block_started and text_block_index == -1:
+     block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
+     yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
+   # Start text block on first text token
+   if not text_block_started:
+     text_block_started = True
+     text_block_index = next_block_index
+     next_block_index += 1
+     block_start = ClaudeContentBlockStartEvent(
+       index=text_block_index,
+       content_block=ClaudeTextBlock(text=""),
+     )
+     yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
+   delta_event = ClaudeContentBlockDeltaEvent(
+     index=text_block_index,
+     delta=ClaudeTextDelta(text=chunk.text),
+   )
+   yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
  if chunk.finish_reason is not None:
  stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -324,9 +402,22 @@ async def generate_claude_stream(
  if last_usage is not None:
  output_tokens = last_usage.completion_tokens
- # content_block_stop for text block
- block_stop = ClaudeContentBlockStopEvent(index=0)
- yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
+ # Close any open blocks
+ if thinking_block_started and text_block_index == -1:
+   block_stop = ClaudeContentBlockStopEvent(index=thinking_block_index)
+   yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
+ if text_block_started:
+   block_stop = ClaudeContentBlockStopEvent(index=text_block_index)
+   yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
+ if not thinking_block_started and not text_block_started:
+   empty_start = ClaudeContentBlockStartEvent(
+     index=0, content_block=ClaudeTextBlock(text="")
+   )
+   yield f"event: content_block_start\ndata: {empty_start.model_dump_json()}\n\n"
+   empty_stop = ClaudeContentBlockStopEvent(index=0)
+   yield f"event: content_block_stop\ndata: {empty_stop.model_dump_json()}\n\n"
  # message_delta
  message_delta = ClaudeMessageDeltaEvent(
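Because the _strip_volatile_headers regex introduced in this file is easy to misread, here is a quick self-contained check; the header line below is fabricated for illustration and is not a real Anthropic header:

import re

_VOLATILE_HEADER_RE = re.compile(r"^x-anthropic-[^\n]*;\n?", re.MULTILINE)  # copied from the diff

system_prompt = (
    "x-anthropic-billing-header: cc_version=1.2.3; cc_entrypoint=cli; cch=deadbeef;\n"
    "You are Claude Code, an agentic coding assistant."
)
print(_VOLATILE_HEADER_RE.sub("", system_prompt))
# -> "You are Claude Code, an agentic coding assistant."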

View File

@@ -0,0 +1,456 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaDoneReason,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaMessage,
OllamaToolCall,
OllamaToolFunction,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _map_done_reason(
finish_reason: str | None,
) -> OllamaDoneReason | None:
if finish_reason is None:
return None
if finish_reason == "stop":
return "stop"
if finish_reason == "length":
return "length"
if finish_reason in ("tool_calls", "function_call"):
return "tool_call"
if finish_reason == "error":
return "error"
return "stop"
def _try_parse_json(value: str) -> dict[str, Any] | str:
try:
return json.loads(value) # type: ignore
except json.JSONDecodeError:
return value
def _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:
tool_calls: list[OllamaToolCall] = []
for index, tool in enumerate(chunk.tool_calls):
# tool.arguments is always str; try to parse as JSON dict for Ollama format
arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)
tool_calls.append(
OllamaToolCall(
id=tool.id,
type="function",
function=OllamaToolFunction(
name=tool.name, arguments=arguments, index=index
),
)
)
return tool_calls
def _get_usage(
chunk: TokenChunk | ToolCallChunk,
) -> tuple[int | None, int | None]:
"""Extract (prompt_eval_count, eval_count) from a chunk."""
if chunk.usage is not None:
return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
if chunk.stats is not None:
return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)
return (None, None)
def ollama_request_to_text_generation(
request: OllamaChatRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama chat request to exo's internal text generation format."""
instructions: str | None = None
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
tool_message_index = 0
for msg in request.messages:
content = msg.content or ""
if msg.role == "system":
if instructions is None:
instructions = content
else:
instructions = f"{instructions}\n{content}"
chat_template_messages.append({"role": "system", "content": content})
continue
if msg.role in ("user", "assistant") and (
msg.content is not None or msg.thinking is not None or msg.tool_calls
):
input_messages.append(InputMessage(role=msg.role, content=content))
dumped: dict[str, Any] = {"role": msg.role, "content": content}
if msg.thinking is not None:
dumped["thinking"] = msg.thinking
if msg.tool_calls is not None:
tool_calls_list: list[dict[str, Any]] = []
for tc in msg.tool_calls:
function: dict[str, Any] = {
"name": tc.function.name,
"arguments": (
json.dumps(tc.function.arguments)
if isinstance(tc.function.arguments, dict)
else tc.function.arguments
),
}
if tc.function.index is not None:
function["index"] = tc.function.index
tool_call: dict[str, Any] = {"function": function}
if tc.id is not None:
tool_call["id"] = tc.id
if tc.type is not None:
tool_call["type"] = tc.type
tool_calls_list.append(tool_call)
dumped["tool_calls"] = tool_calls_list
if msg.name is not None:
dumped["name"] = msg.name
if msg.role == "tool":
tool_message_index += 1
tool_call_id = msg.tool_name or msg.name or f"tool_{tool_message_index}"
dumped["tool_call_id"] = tool_call_id
if msg.tool_name is not None:
dumped["tool_name"] = msg.tool_name
chat_template_messages.append(dumped)
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
tools=request.tools,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_chat_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses in Ollama format (newline-delimited JSON)."""
thinking_parts: list[str] = []
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
error_response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content=chunk.error_message
),
done=True,
done_reason="error",
)
yield f"{error_response.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content="",
tool_calls=_build_tool_calls(chunk),
thinking="".join(thinking_parts) if thinking_parts else None,
),
done=True,
done_reason="tool_call",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content="", thinking=chunk.text
),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content=chunk.text,
),
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
else:
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(role="assistant", content=chunk.text),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
if done:
return
async def collect_ollama_chat_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect streaming chunks into a single non-streaming Ollama response.
Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI
StreamingResponse cancellation handling.
"""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[OllamaToolCall] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
tool_calls.extend(_build_tool_calls(chunk))
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts) if thinking_parts else None
assert model is not None
yield OllamaChatResponse(
model=model,
message=OllamaMessage(
role="assistant",
content=combined_text,
thinking=combined_thinking,
tool_calls=tool_calls if tool_calls else None,
),
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return
# ── /api/generate ──
def ollama_generate_request_to_text_generation(
request: OllamaGenerateRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama generate request to exo's internal text generation format."""
chat_template_messages: list[dict[str, Any]] = []
if request.system:
chat_template_messages.append({"role": "system", "content": request.system})
chat_template_messages.append({"role": "user", "content": request.prompt})
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=[InputMessage(role="user", content=request.prompt)],
instructions=request.system,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_generate_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses for /api/generate in Ollama NDJSON format."""
thinking_parts: list[str] = []
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="error",
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
# generate endpoint doesn't support tools; emit as done
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="stop",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
thinking=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
else:
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
if done:
return
async def collect_ollama_generate_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect chunks into a single non-streaming /api/generate response."""
text_parts: list[str] = []
thinking_parts: list[str] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
assert model is not None
yield OllamaGenerateResponse(
model=model,
response="".join(text_parts),
thinking="".join(thinking_parts) if thinking_parts else None,
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return
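As a quick, hedged sanity check of the small helpers defined at the top of this new adapter (inputs invented; the import path mirrors the api.py import further down, and the underscore names are reached only for illustration):

from exo.master.adapters.ollama import _map_done_reason, _try_parse_json

print(_map_done_reason("tool_calls"))         # -> "tool_call"
print(_map_done_reason(None))                 # -> None
print(_try_parse_json('{"city": "Paris"}'))   # -> {'city': 'Paris'}
print(_try_parse_json("not json"))            # -> "not json"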

View File

@@ -29,8 +29,15 @@ from exo.shared.types.openai_responses import (
  ResponseOutputItemAddedEvent,
  ResponseOutputItemDoneEvent,
  ResponseOutputText,
+ ResponseReasoningItem,
+ ResponseReasoningSummaryPartAddedEvent,
+ ResponseReasoningSummaryPartDoneEvent,
+ ResponseReasoningSummaryText,
+ ResponseReasoningSummaryTextDeltaEvent,
+ ResponseReasoningSummaryTextDoneEvent,
  ResponsesRequest,
  ResponsesResponse,
+ ResponsesStreamEvent,
  ResponseTextDeltaEvent,
  ResponseTextDoneEvent,
  ResponseUsage,
@@ -38,6 +45,11 @@ from exo.shared.types.openai_responses import (
  from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
+ def _format_sse(event: ResponsesStreamEvent) -> str:
+   """Format a streaming event as an SSE message."""
+   return f"event: {event.type}\ndata: {event.model_dump_json()}\n\n"
  def _extract_content(content: str | list[ResponseContentPart]) -> str:
  """Extract plain text from a content field that may be a string or list of parts."""
  if isinstance(content, str):
@@ -135,7 +147,9 @@ async def collect_responses_response(
"""Collect all token chunks and return a single ResponsesResponse.""" """Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}" response_id = f"resp_{command_id}"
item_id = f"item_{command_id}" item_id = f"item_{command_id}"
reasoning_id = f"rs_{command_id}"
accumulated_text = "" accumulated_text = ""
thinking_parts: list[str] = []
function_call_items: list[ResponseFunctionCallItem] = [] function_call_items: list[ResponseFunctionCallItem] = []
last_usage: Usage | None = None last_usage: Usage | None = None
error_message: str | None = None error_message: str | None = None
@@ -162,6 +176,10 @@ async def collect_responses_response(
  )
  continue
+ if chunk.is_thinking:
+   thinking_parts.append(chunk.text)
+   continue
  accumulated_text += chunk.text
  if error_message is not None:
@@ -176,13 +194,21 @@ async def collect_responses_response(
  total_tokens=last_usage.total_tokens,
  )
- output: list[ResponseItem] = [
+ output: list[ResponseItem] = []
+ if thinking_parts:
+   output.append(
+     ResponseReasoningItem(
+       id=reasoning_id,
+       summary=[ResponseReasoningSummaryText(text="".join(thinking_parts))],
+     )
+   )
+ output.append(
  ResponseMessageItem(
  id=item_id,
  content=[ResponseOutputText(text=accumulated_text)],
  status="completed",
  )
- ]
+ )
  output.extend(function_call_items)
  yield ResponsesResponse(
@@ -206,6 +232,7 @@ async def generate_responses_stream(
"""Generate OpenAI Responses API streaming events from TokenChunks.""" """Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}" response_id = f"resp_{command_id}"
item_id = f"item_{command_id}" item_id = f"item_{command_id}"
reasoning_id = f"rs_{command_id}"
seq = count(1) seq = count(1)
# response.created # response.created
@@ -219,40 +246,25 @@ async def generate_responses_stream(
  created_event = ResponseCreatedEvent(
  sequence_number=next(seq), response=initial_response
  )
- yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
+ yield _format_sse(created_event)
  # response.in_progress
  in_progress_event = ResponseInProgressEvent(
  sequence_number=next(seq), response=initial_response
  )
- yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
+ yield _format_sse(in_progress_event)
- # response.output_item.added
- initial_item = ResponseMessageItem(
-   id=item_id,
-   content=[ResponseOutputText(text="")],
-   status="in_progress",
- )
- item_added = ResponseOutputItemAddedEvent(
-   sequence_number=next(seq), output_index=0, item=initial_item
- )
- yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
- # response.content_part.added
- initial_part = ResponseOutputText(text="")
- part_added = ResponseContentPartAddedEvent(
-   sequence_number=next(seq),
-   item_id=item_id,
-   output_index=0,
-   content_index=0,
-   part=initial_part,
- )
- yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
  accumulated_text = ""
+ accumulated_thinking = ""
  function_call_items: list[ResponseFunctionCallItem] = []
  last_usage: Usage | None = None
- next_output_index = 1 # message item is at 0
+ next_output_index = 0
+ # Track dynamic block creation
+ reasoning_started = False
+ reasoning_output_index = -1
+ message_started = False
+ message_output_index = -1
  async for chunk in chunk_stream:
  if isinstance(chunk, PrefillProgressChunk):
@@ -281,7 +293,7 @@ async def generate_responses_stream(
  output_index=next_output_index,
  item=fc_item,
  )
- yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
+ yield _format_sse(fc_added)
  # response.function_call_arguments.delta
  args_delta = ResponseFunctionCallArgumentsDeltaEvent(
@@ -290,7 +302,7 @@ async def generate_responses_stream(
  output_index=next_output_index,
  delta=tool.arguments,
  )
- yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
+ yield _format_sse(args_delta)
  # response.function_call_arguments.done
  args_done = ResponseFunctionCallArgumentsDoneEvent(
@@ -300,7 +312,7 @@ async def generate_responses_stream(
  name=tool.name,
  arguments=tool.arguments,
  )
- yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
+ yield _format_sse(args_done)
  # response.output_item.done
  fc_done_item = ResponseFunctionCallItem(
@@ -315,44 +327,205 @@ async def generate_responses_stream(
  output_index=next_output_index,
  item=fc_done_item,
  )
- yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
+ yield _format_sse(fc_item_done)
  function_call_items.append(fc_done_item)
  next_output_index += 1
  continue
if chunk.is_thinking:
# Start reasoning block on first thinking token
if not reasoning_started:
reasoning_started = True
reasoning_output_index = next_output_index
next_output_index += 1
# response.output_item.added for reasoning
reasoning_item = ResponseReasoningItem(
id=reasoning_id,
summary=[],
status="in_progress",
)
rs_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=reasoning_item,
)
yield _format_sse(rs_added)
# response.reasoning_summary_part.added
part_added = ResponseReasoningSummaryPartAddedEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=""),
)
yield _format_sse(part_added)
accumulated_thinking += chunk.text
# response.reasoning_summary_text.delta
rs_delta = ResponseReasoningSummaryTextDeltaEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
delta=chunk.text,
)
yield _format_sse(rs_delta)
continue
# Close reasoning block when transitioning to text
if reasoning_started and not message_started:
# response.reasoning_summary_text.done
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
text=accumulated_thinking,
)
yield _format_sse(rs_text_done)
# response.reasoning_summary_part.done
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=accumulated_thinking),
)
yield _format_sse(rs_part_done)
# response.output_item.done for reasoning
rs_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
),
)
yield _format_sse(rs_item_done)
# Start message block on first text token
if not message_started:
message_started = True
message_output_index = next_output_index
next_output_index += 1
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=message_output_index,
item=initial_item,
)
yield _format_sse(item_added)
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=message_output_index,
content_index=0,
part=initial_part,
)
yield _format_sse(part_added)
  accumulated_text += chunk.text
  # response.output_text.delta
  delta_event = ResponseTextDeltaEvent(
  sequence_number=next(seq),
  item_id=item_id,
- output_index=0,
+ output_index=message_output_index,
  content_index=0,
  delta=chunk.text,
  )
- yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
+ yield _format_sse(delta_event)
# Close reasoning block if it was never followed by text
if reasoning_started and not message_started:
rs_text_done = ResponseReasoningSummaryTextDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
text=accumulated_thinking,
)
yield _format_sse(rs_text_done)
rs_part_done = ResponseReasoningSummaryPartDoneEvent(
sequence_number=next(seq),
item_id=reasoning_id,
output_index=reasoning_output_index,
summary_index=0,
part=ResponseReasoningSummaryText(text=accumulated_thinking),
)
yield _format_sse(rs_part_done)
rs_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=reasoning_output_index,
item=ResponseReasoningItem(
id=reasoning_id,
summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
),
)
yield _format_sse(rs_item_done)
# If no message block was started, create one now (empty text)
if not message_started:
message_output_index = next_output_index
next_output_index += 1
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=message_output_index,
item=initial_item,
)
yield _format_sse(item_added)
initial_part = ResponseOutputText(text="")
part_added_evt = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=message_output_index,
content_index=0,
part=initial_part,
)
yield _format_sse(part_added_evt)
  # response.output_text.done
  text_done = ResponseTextDoneEvent(
  sequence_number=next(seq),
  item_id=item_id,
- output_index=0,
+ output_index=message_output_index,
  content_index=0,
  text=accumulated_text,
  )
- yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
+ yield _format_sse(text_done)
  # response.content_part.done
  final_part = ResponseOutputText(text=accumulated_text)
  part_done = ResponseContentPartDoneEvent(
  sequence_number=next(seq),
  item_id=item_id,
- output_index=0,
+ output_index=message_output_index,
  content_index=0,
  part=final_part,
  )
- yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
+ yield _format_sse(part_done)
  # response.output_item.done
  final_message_item = ResponseMessageItem(
@@ -361,9 +534,11 @@ async def generate_responses_stream(
status="completed", status="completed",
) )
item_done = ResponseOutputItemDoneEvent( item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item sequence_number=next(seq),
output_index=message_output_index,
item=final_message_item,
) )
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n" yield _format_sse(item_done)
# Create usage from usage data if available # Create usage from usage data if available
usage = None usage = None
@@ -375,7 +550,15 @@ async def generate_responses_stream(
  )
  # response.completed
- output: list[ResponseItem] = [final_message_item]
+ output: list[ResponseItem] = []
+ if reasoning_started:
+   output.append(
+     ResponseReasoningItem(
+       id=reasoning_id,
+       summary=[ResponseReasoningSummaryText(text=accumulated_thinking)],
+     )
+   )
+ output.append(final_message_item)
  output.extend(function_call_items)
  final_response = ResponsesResponse(
  id=response_id,
@@ -388,4 +571,4 @@ async def generate_responses_stream(
  completed_event = ResponseCompletedEvent(
  sequence_number=next(seq), response=final_response
  )
- yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
+ yield _format_sse(completed_event)

View File

@@ -32,6 +32,14 @@ from exo.master.adapters.claude import (
  collect_claude_response,
  generate_claude_stream,
  )
from exo.master.adapters.ollama import (
collect_ollama_chat_response,
collect_ollama_generate_response,
generate_ollama_chat_stream,
generate_ollama_generate_stream,
ollama_generate_request_to_text_generation,
ollama_request_to_text_generation,
)
from exo.master.adapters.responses import ( from exo.master.adapters.responses import (
collect_responses_response, collect_responses_response,
generate_responses_stream, generate_responses_stream,
@@ -138,10 +146,22 @@ from exo.shared.types.events import (
  Event,
  ForwarderEvent,
  IndexedEvent,
- PrefillProgress,
  TracesMerged,
  )
  from exo.shared.types.memory import Memory
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaModelDetails,
OllamaModelTag,
OllamaPsModel,
OllamaPsResponse,
OllamaShowRequest,
OllamaShowResponse,
OllamaTagsResponse,
)
  from exo.shared.types.openai_responses import (
  ResponsesRequest,
  ResponsesResponse,
@@ -301,6 +321,21 @@ class API:
self.app.get("/images/{image_id}")(self.get_image) self.app.get("/images/{image_id}")(self.get_image)
self.app.post("/v1/messages", response_model=None)(self.claude_messages) self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses) self.app.post("/v1/responses", response_model=None)(self.openai_responses)
# Ollama API
self.app.head("/ollama/")(self.ollama_version)
self.app.head("/ollama/api/version")(self.ollama_version)
self.app.post("/ollama/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/ollama/api/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/ollama/api/v1/chat", response_model=None)(self.ollama_chat)
self.app.post("/ollama/api/generate", response_model=None)(self.ollama_generate)
self.app.get("/ollama/api/tags")(self.ollama_tags)
self.app.get("/ollama/api/api/tags")(self.ollama_tags)
self.app.get("/ollama/api/v1/tags")(self.ollama_tags)
self.app.post("/ollama/api/show")(self.ollama_show)
self.app.get("/ollama/api/ps")(self.ollama_ps)
self.app.get("/ollama/api/version")(self.ollama_version)
self.app.get("/state")(lambda: self.state) self.app.get("/state")(lambda: self.state)
self.app.get("/events")(self.stream_events) self.app.get("/events")(self.stream_events)
self.app.post("/download/start")(self.start_download) self.app.post("/download/start")(self.start_download)
@@ -1294,6 +1329,163 @@ class API:
  media_type="application/json",
  )
async def _ollama_root(self) -> JSONResponse:
"""Respond to HEAD / from Ollama CLI connectivity checks."""
return JSONResponse(content="Ollama is running")
async def ollama_chat(
self, request: Request
) -> OllamaChatResponse | StreamingResponse:
"""Ollama Chat API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaChatRequest.model_validate_json(body)
task_params = ollama_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_ollama_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_generate(
self, request: Request
) -> OllamaGenerateResponse | StreamingResponse:
"""Ollama Generate API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaGenerateRequest.model_validate_json(body)
task_params = ollama_generate_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_ollama_generate_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_generate_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_tags(self) -> OllamaTagsResponse:
"""Returns list of models in Ollama tags format. We return the downloaded ones only."""
def none_if_empty(value: str) -> str | None:
return value or None
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [
c for c in await get_model_cards() if c.model_id in downloaded_model_ids
]
now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
return OllamaTagsResponse(
models=[
OllamaModelTag(
name=str(card.model_id),
model=str(card.model_id),
modified_at=now,
size=card.storage_size.in_bytes,
digest="sha256:000000000000",
details=OllamaModelDetails(
family=none_if_empty(card.family),
quantization_level=none_if_empty(card.quantization),
),
)
for card in cards
]
)
async def ollama_show(self, request: Request) -> OllamaShowResponse:
"""Returns model information in Ollama show format."""
body = await request.body()
payload = OllamaShowRequest.model_validate_json(body)
model_name = payload.name or payload.model
if not model_name:
raise HTTPException(status_code=400, detail="name or model is required")
try:
card = await ModelCard.load(ModelId(model_name))
except Exception as exc:
raise HTTPException(
status_code=404, detail=f"Model not found: {model_name}"
) from exc
return OllamaShowResponse(
modelfile=f"FROM {card.model_id}",
template="{{ .Prompt }}",
details=OllamaModelDetails(
family=card.family or None,
quantization_level=card.quantization or None,
),
)
async def ollama_ps(self) -> OllamaPsResponse:
"""Returns list of running models (active instances)."""
models: list[OllamaPsModel] = []
seen: set[str] = set()
for instance in self.state.instances.values():
model_id = str(instance.shard_assignments.model_id)
if model_id in seen:
continue
seen.add(model_id)
models.append(
OllamaPsModel(
name=model_id,
model=model_id,
size=0,
)
)
return OllamaPsResponse(models=models)
async def ollama_version(self) -> dict[str, str]:
"""Returns version information for Ollama API compatibility."""
return {"version": "exo v1.0"}
  def _calculate_total_available_memory(self) -> Memory:
  """Calculate total available memory across all nodes in bytes."""
  total_available = Memory()
@@ -1323,7 +1515,7 @@ class API:
  name=card.model_id.short(),
  description="",
  tags=[],
- storage_size_megabytes=int(card.storage_size.in_mb),
+ storage_size_megabytes=card.storage_size.in_mb,
  supports_tensor=card.supports_tensor,
  tasks=[task.value for task in card.tasks],
  is_custom=is_custom_card(card.model_id),
@@ -1455,22 +1647,6 @@ class API:
  await queue.send(event.chunk)
  except BrokenResourceError:
  self._text_generation_queues.pop(event.command_id, None)
- elif isinstance(event, PrefillProgress):
-   if queue := self._text_generation_queues.get(
-     event.command_id, None
-   ):
-     try:
-       await queue.send(
-         PrefillProgressChunk(
-           model=event.model,
-           processed_tokens=event.processed_tokens,
-           total_tokens=event.total_tokens,
-         )
-       )
-     except BrokenResourceError:
-       self._text_generation_queues.pop(event.command_id, None)
  if isinstance(event, TracesMerged):
  self._save_merged_trace(event)

View File

@@ -141,15 +141,29 @@ def place_instance(
  if len(selected_cycle) == 1:
  command.instance_meta = InstanceMeta.MlxRing
+ # TODO: Single node instances
  match command.instance_meta:
  case InstanceMeta.MlxJaccl:
+   # TODO(evan): shard assignments should contain information about ranks, this is ugly
+   def get_device_rank(node_id: NodeId) -> int:
+     runner_id = shard_assignments.node_to_runner[node_id]
+     shard_metadata = shard_assignments.runner_to_shard.get(runner_id)
+     assert shard_metadata is not None
+     return shard_metadata.device_rank
+   zero_node_ids = [
+     node_id
+     for node_id in selected_cycle.node_ids
+     if get_device_rank(node_id) == 0
+   ]
+   assert len(zero_node_ids) == 1
+   coordinator_node_id = zero_node_ids[0]
  mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
  [node_id for node_id in selected_cycle],
  cycle_digraph,
  )
  mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
- coordinator=selected_cycle.node_ids[0],
+ coordinator=coordinator_node_id,
  coordinator_port=random_ephemeral_port(),
  cycle_digraph=cycle_digraph,
  node_network=node_network,
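The coordinator selection above boils down to "pick the unique node whose shard has device_rank 0". A toy version with a plain dict standing in for the shard assignments (node ids and ranks invented, illustration only):

device_rank = {"node-a": 2, "node-b": 0, "node-c": 1}  # invented assignment

zero_node_ids = [node_id for node_id, rank in device_rank.items() if rank == 0]
assert len(zero_node_ids) == 1, "exactly one rank-0 node expected"
coordinator_node_id = zero_node_ids[0]
print(coordinator_node_id)  # -> "node-b"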

View File

@@ -102,22 +102,21 @@ def _allocate_and_validate_layers(
  layer_allocations = allocate_layers_proportionally(
  total_layers=model_card.n_layers,
  memory_fractions=[
- node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
- for node_id in node_ids
+ node_memory[node_id].ram_available / total_memory for node_id in node_ids
  ],
  )
- total_storage_bytes = model_card.storage_size.in_bytes
+ total_storage = model_card.storage_size
  total_layers = model_card.n_layers
  for i, node_id in enumerate(node_ids):
  node_layers = layer_allocations[i]
- required_memory = (total_storage_bytes * node_layers) // total_layers
- available_memory = node_memory[node_id].ram_available.in_bytes
+ required_memory = (total_storage * node_layers) // total_layers
+ available_memory = node_memory[node_id].ram_available
  if required_memory > available_memory:
  raise ValueError(
  f"Node {i} ({node_id}) has insufficient memory: "
- f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
- f"but only has {available_memory / (1024**3):.2f} GB available"
+ f"requires {required_memory.in_gb:.2f} GB for {node_layers} layers, "
+ f"but only has {available_memory.in_gb:.2f} GB available"
  )
  return layer_allocations
@@ -342,6 +341,7 @@ def _find_ip_prioritised(
  other_node_id: NodeId,
  cycle_digraph: Topology,
  node_network: Mapping[NodeId, NodeNetworkInfo],
+ ring: bool,
  ) -> str | None:
  """Find an IP address between nodes with prioritization.
@@ -354,13 +354,27 @@ def _find_ip_prioritised(
  ip_to_type = {
  iface.ip_address: iface.interface_type for iface in other_network.interfaces
  }
- priority = {
-   "ethernet": 0,
-   "wifi": 1,
-   "unknown": 2,
-   "maybe_ethernet": 3,
-   "thunderbolt": 4,
- }
+ # Ring should prioritise fastest connection. As a best-effort, we prioritise TB.
+ # TODO: Profile and get actual connection speeds.
+ if ring:
+   priority = {
+     "thunderbolt": 0,
+     "maybe_ethernet": 1,
+     "ethernet": 2,
+     "wifi": 3,
+     "unknown": 4,
+   }
+ # RDMA prefers ethernet coordinator
+ else:
+   priority = {
+     "ethernet": 0,
+     "wifi": 1,
+     "unknown": 2,
+     "maybe_ethernet": 3,
+     "thunderbolt": 4,
+   }
  return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
@@ -400,7 +414,7 @@ def get_mlx_ring_hosts_by_node(
  continue
  connection_ip = _find_ip_prioritised(
- node_id, other_node_id, cycle_digraph, node_network
+ node_id, other_node_id, cycle_digraph, node_network, ring=True
  )
  if connection_ip is None:
  raise ValueError(
@@ -431,7 +445,9 @@ def get_mlx_jaccl_coordinators(
if n == coordinator: if n == coordinator:
return "0.0.0.0" return "0.0.0.0"
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network) ip = _find_ip_prioritised(
n, coordinator, cycle_digraph, node_network, ring=False
)
if ip is not None: if ip is not None:
return ip return ip
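
A minimal, self-contained sketch of how the two priority tables behave (interface data is made up):

```python
# Hypothetical interfaces: one ethernet link and one Thunderbolt (169.254.x) link.
ips = ["192.168.1.10", "169.254.7.2"]
ip_to_type = {"192.168.1.10": "ethernet", "169.254.7.2": "thunderbolt"}

def pick(ring: bool) -> str:
    priority = (
        {"thunderbolt": 0, "maybe_ethernet": 1, "ethernet": 2, "wifi": 3, "unknown": 4}
        if ring
        else {"ethernet": 0, "wifi": 1, "unknown": 2, "maybe_ethernet": 3, "thunderbolt": 4}
    )
    return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))

assert pick(ring=True) == "169.254.7.2"    # ring traffic prefers the TB link
assert pick(ring=False) == "192.168.1.10"  # RDMA coordinator prefers ethernet
```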

View File

@@ -261,7 +261,7 @@ class TestGenerateClaudeStreamToolUse:
    parsed = _parse_sse_events(events)
-   # Two tool block starts (at indices 1 and 2)
+   # Two tool block starts (at indices 0 and 1 — no text block when only tools)
    tool_starts = [
        e
        for e in parsed
@@ -270,12 +270,11 @@ class TestGenerateClaudeStreamToolUse:
        == "tool_use"
    ]
    assert len(tool_starts) == 2
-   assert tool_starts[0]["index"] == 1
-   assert tool_starts[1]["index"] == 2
-   # Two tool block stops (at indices 1 and 2), plus text block stop at 0
+   assert tool_starts[0]["index"] == 0
+   assert tool_starts[1]["index"] == 1
+   # Two tool block stops (at indices 0 and 1)
    block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
    stop_indices = [e["index"] for e in block_stops]
    assert 0 in stop_indices
    assert 1 in stop_indices
-   assert 2 in stop_indices

View File

@@ -42,7 +42,7 @@ from exo.utils.channels import channel
@pytest.mark.asyncio
async def test_master():
    keypair = get_node_id_keypair()
-   node_id = NodeId(keypair.to_peer_id().to_base58())
+   node_id = NodeId(keypair.to_node_id())
    session_id = SessionId(master_node_id=node_id, election_clock=0)
    ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)
-       sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
+       sender_node_id = NodeId(f"{keypair.to_node_id()}_sender")
        # inject a NodeGatheredInfo event
        logger.info("inject a NodeGatheredInfo event")
        await local_event_sender.send(

View File

@@ -80,8 +80,8 @@ def test_get_instance_placements_create_instance(
):
    # arrange
    model_card.n_layers = total_layers
-   model_card.storage_size.in_bytes = sum(
-       available_memory
+   model_card.storage_size = Memory.from_bytes(
+       sum(available_memory)
    )  # make it exactly fit across all nodes
    topology = Topology()
@@ -349,7 +349,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
    # arrange
    topology = Topology()
    model_card.n_layers = 12
-   model_card.storage_size.in_bytes = 1500
+   model_card.storage_size = Memory.from_bytes(1500)
    node_a = NodeId()
    node_b = NodeId()

View File

@@ -30,7 +30,7 @@ class ConnectionMessage(CamelCaseModel):
    @classmethod
    def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
        return cls(
-           node_id=NodeId(update.peer_id.to_base58()),
+           node_id=NodeId(update.peer_id),
            connection_type=ConnectionMessageType.from_update_type(update.update_type),
            remote_ipv4=update.remote_ipv4,
            remote_tcp_port=update.remote_tcp_port,

View File

@@ -221,7 +221,7 @@ def get_node_id_keypair(
    Obtain the :class:`PeerId` by from it.
    """
    # TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
-   return Keypair.generate_ed25519()
+   return Keypair.generate()

def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
    return Path(str(path) + ".lock")
@@ -235,12 +235,12 @@ def get_node_id_keypair(
        protobuf_encoded = f.read()
        try:  # if decoded successfully, save & return
-           return Keypair.from_protobuf_encoding(protobuf_encoded)
+           return Keypair.from_bytes(protobuf_encoded)
        except ValueError as e:  # on runtime error, assume corrupt file
            logger.warning(f"Encountered error when trying to get keypair: {e}")
    # if no valid credentials, create new ones and persist
    with open(path, "w+b") as f:
        keypair = Keypair.generate_ed25519()
-       f.write(keypair.to_protobuf_encoding())
+       f.write(keypair.to_bytes())
        return keypair

View File

@@ -15,7 +15,6 @@ from exo.shared.types.events import (
    NodeDownloadProgress,
    NodeGatheredInfo,
    NodeTimedOut,
-   PrefillProgress,
    RunnerDeleted,
    RunnerStatusUpdated,
    TaskAcknowledged,
@@ -65,7 +64,6 @@ def event_apply(event: Event, state: State) -> State:
        | ChunkGenerated()
        | TaskAcknowledged()
        | InputChunkReceived()
-       | PrefillProgress()
        | TracesCollected()
        | TracesMerged()
    ):  # Pass-through events that don't modify state

View File

@@ -14,7 +14,7 @@ def test_apply_node_download_progress():
    event = DownloadCompleted(
        node_id=NodeId("node-1"),
        shard_metadata=shard1,
-       total_bytes=Memory(),
+       total=Memory(),
    )
    new_state = apply_node_download_progress(
@@ -30,12 +30,12 @@ def test_apply_two_node_download_progress():
    event1 = DownloadCompleted(
        node_id=NodeId("node-1"),
        shard_metadata=shard1,
-       total_bytes=Memory(),
+       total=Memory(),
    )
    event2 = DownloadCompleted(
        node_id=NodeId("node-1"),
        shard_metadata=shard2,
-       total_bytes=Memory(),
+       total=Memory(),
    )
    state = State(downloads={NodeId("node-1"): [event1]})

View File

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
    sem.release()
    # wait to be told to begin simultaneous read
    ev.wait()
-   queue.put(get_node_id_keypair().to_protobuf_encoding())
+   queue.put(get_node_id_keypair().to_bytes())

def _get_keypair_concurrent(num_procs: int) -> bytes:

View File

@@ -77,7 +77,7 @@ class ChatCompletionMessage(BaseModel):
    content: (
        str | ChatCompletionMessageText | list[ChatCompletionMessageText] | None
    ) = None
-   thinking: str | None = None  # Added for GPT-OSS harmony format support
+   reasoning_content: str | None = None
    name: str | None = None
    tool_calls: list[ToolCall] | None = None
    tool_call_id: str | None = None

View File

@@ -27,6 +27,7 @@ class TokenChunk(BaseChunk):
    stats: GenerationStats | None = None
    logprob: float | None = None
    top_logprobs: list[TopLogprobItem] | None = None
+   is_thinking: bool = False

class ErrorChunk(BaseChunk):

View File

@@ -47,6 +47,14 @@ class ClaudeImageBlock(BaseModel, frozen=True):
    source: ClaudeImageSource

+class ClaudeThinkingBlock(BaseModel, frozen=True):
+    """Thinking content block in Claude Messages API."""
+    type: Literal["thinking"] = "thinking"
+    thinking: str
+    signature: str | None = None

class ClaudeToolUseBlock(BaseModel, frozen=True):
    """Tool use content block in Claude Messages API."""
@@ -66,11 +74,17 @@ class ClaudeToolResultBlock(BaseModel, frozen=True):
    cache_control: dict[str, str] | None = None

-ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock
+ClaudeContentBlock = (
+    ClaudeTextBlock | ClaudeImageBlock | ClaudeThinkingBlock | ClaudeToolUseBlock
+)

# Input content blocks can also include tool_result (sent by user after tool_use)
ClaudeInputContentBlock = (
-   ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock | ClaudeToolResultBlock
+   ClaudeTextBlock
+   | ClaudeImageBlock
+   | ClaudeThinkingBlock
+   | ClaudeToolUseBlock
+   | ClaudeToolResultBlock
)
@@ -82,6 +96,11 @@ class ClaudeMessage(BaseModel, frozen=True):
    content: str | list[ClaudeInputContentBlock]

+class ClaudeThinkingConfig(BaseModel, frozen=True):
+    type: Literal["enabled", "disabled", "adaptive"]
+    budget_tokens: int | None = None

class ClaudeMessagesRequest(BaseModel):
    """Request body for Claude Messages API."""
@@ -96,6 +115,7 @@ class ClaudeMessagesRequest(BaseModel):
    top_k: int | None = None
    tools: list[ClaudeToolDefinition] | None = None
    metadata: dict[str, str] | None = None
+   thinking: ClaudeThinkingConfig | None = None

# Response types
@@ -145,7 +165,7 @@ class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
    type: Literal["content_block_start"] = "content_block_start"
    index: int
-   content_block: ClaudeTextBlock | ClaudeToolUseBlock
+   content_block: ClaudeTextBlock | ClaudeThinkingBlock | ClaudeToolUseBlock

class ClaudeTextDelta(BaseModel, frozen=True):
@@ -155,6 +175,13 @@ class ClaudeTextDelta(BaseModel, frozen=True):
    text: str

+class ClaudeThinkingDelta(BaseModel, frozen=True):
+    """Delta for thinking content block."""
+    type: Literal["thinking_delta"] = "thinking_delta"
+    thinking: str

class ClaudeInputJsonDelta(BaseModel, frozen=True):
    """Delta for tool use input JSON content block."""
@@ -167,7 +194,7 @@ class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
    type: Literal["content_block_delta"] = "content_block_delta"
    index: int
-   delta: ClaudeTextDelta | ClaudeInputJsonDelta
+   delta: ClaudeTextDelta | ClaudeThinkingDelta | ClaudeInputJsonDelta

class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
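
For orientation, a hedged sketch of how these models compose into the streamed thinking events (indices and text are illustrative; the classes are assumed importable from this module):

```python
# Illustrative only: a thinking block start followed by a thinking delta.
start = ClaudeContentBlockStartEvent(
    index=0, content_block=ClaudeThinkingBlock(thinking="")
)
delta = ClaudeContentBlockDeltaEvent(
    index=0, delta=ClaudeThinkingDelta(thinking="Checking the tool schema...")
)
# start.model_dump_json() and delta.model_dump_json() are what get written out as
# the content_block_start / content_block_delta SSE payloads.
```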

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
-from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
+from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -102,13 +102,6 @@ class InputChunkReceived(BaseEvent):
    chunk: InputImageChunk

-class PrefillProgress(BaseEvent):
-    command_id: CommandId
-    model: ModelId
-    processed_tokens: int
-    total_tokens: int

class TopologyEdgeCreated(BaseEvent):
    conn: Connection
@@ -155,7 +148,6 @@ Event = (
    | NodeDownloadProgress
    | ChunkGenerated
    | InputChunkReceived
-   | PrefillProgress
    | TopologyEdgeCreated
    | TopologyEdgeDeleted
    | TracesCollected

View File

@@ -1,10 +1,10 @@
from math import ceil
-from typing import Self
+from typing import Self, overload

-from exo.utils.pydantic_ext import CamelCaseModel
+from exo.utils.pydantic_ext import FrozenModel

-class Memory(CamelCaseModel):
+class Memory(FrozenModel):
    in_bytes: int = 0

    @classmethod
@@ -33,12 +33,22 @@ class Memory(CamelCaseModel):
        return cls(in_bytes=round(val * 1024))

    @property
-   def in_mb(self) -> float:
-       """The approximate megabytes this memory represents. Setting this property rounds to the nearest byte."""
-       return self.in_bytes / (1024**2)
+   def in_mb(self) -> int:
+       """The approximate megabytes this memory represents, rounded to nearest MB. Setting this property rounds to the nearest byte."""
+       return round(self.in_bytes / (1024**2))

    @in_mb.setter
-   def in_mb(self, val: float):
+   def in_mb(self, val: int):
+       """Set the megabytes for this memory."""
+       self.in_bytes = val * (1024**2)

+   @property
+   def in_float_mb(self) -> float:
+       """The megabytes this memory represents as a float. Setting this property rounds to the nearest byte."""
+       return self.in_bytes / (1024**2)

+   @in_float_mb.setter
+   def in_float_mb(self, val: float):
        """Set the megabytes for this memory, rounded to the nearest byte."""
        self.in_bytes = round(val * (1024**2))
@@ -57,17 +67,85 @@ class Memory(CamelCaseModel):
        """The approximate gigabytes this memory represents."""
        return self.in_bytes / (1024**3)

-   def __add__(self, other: "Memory") -> "Memory":
-       return Memory.from_bytes(self.in_bytes + other.in_bytes)
+   def __add__(self, other: object) -> "Memory":
+       if isinstance(other, Memory):
+           return Memory.from_bytes(self.in_bytes + other.in_bytes)
+       return NotImplemented

-   def __lt__(self, other: Self) -> bool:
-       return self.in_bytes < other.in_bytes
+   def __radd__(self, other: object) -> "Memory":
+       if other == 0:
+           return self
+       return NotImplemented

-   def __le__(self, other: Self) -> bool:
-       return self.in_bytes <= other.in_bytes
+   def __sub__(self, other: object) -> "Memory":
+       if isinstance(other, Memory):
+           return Memory.from_bytes(self.in_bytes - other.in_bytes)
+       return NotImplemented

-   def __gt__(self, other: Self) -> bool:
-       return self.in_bytes > other.in_bytes
+   def __mul__(self, other: int | float):
+       return Memory.from_bytes(round(self.in_bytes * other))

-   def __ge__(self, other: Self) -> bool:
-       return self.in_bytes >= other.in_bytes
+   def __rmul__(self, other: int | float):
+       return self * other
@overload
def __truediv__(self, other: "Memory") -> float: ...
@overload
def __truediv__(self, other: int) -> "Memory": ...
@overload
def __truediv__(self, other: float) -> "Memory": ...
def __truediv__(self, other: object) -> "Memory | float":
if isinstance(other, Memory):
return self.in_bytes / other.in_bytes
if isinstance(other, (int, float)):
return Memory.from_bytes(round(self.in_bytes / other))
return NotImplemented
def __floordiv__(self, other: object) -> "Memory":
if isinstance(other, (int, float)):
return Memory.from_bytes(int(self.in_bytes // other))
return NotImplemented
def __lt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes < other.in_bytes
return NotImplemented
def __le__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes <= other.in_bytes
return NotImplemented
def __gt__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes > other.in_bytes
return NotImplemented
def __ge__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes >= other.in_bytes
return NotImplemented
def __eq__(self, other: object) -> bool:
if isinstance(other, Memory):
return self.in_bytes == other.in_bytes
return NotImplemented
def __repr__(self) -> str:
return f"Memory.from_bytes({self.in_bytes})"
def __str__(self) -> str:
if self.in_gb > 2:
val = self.in_gb
unit = "GiB"
elif self.in_mb > 2:
val = self.in_mb
unit = "MiB"
elif self.in_kb > 3:
val = self.in_kb
unit = "KiB"
else:
val = self.in_bytes
unit = "B"
return f"{val:.2f} {unit}".rstrip("0").rstrip(".") + f" {unit}"

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
import time
from typing import Any, Literal
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelId
# https://github.com/ollama/ollama/blob/main/docs/api.md
OllamaRole = Literal["system", "user", "assistant", "tool"]
OllamaDoneReason = Literal["stop", "length", "tool_call", "error"]
class OllamaToolFunction(BaseModel, frozen=True):
name: str
arguments: dict[str, Any] | str
index: int | None = None
class OllamaToolCall(BaseModel, frozen=True):
id: str | None = None
type: Literal["function"] | None = None
function: OllamaToolFunction
class OllamaMessage(BaseModel, frozen=True):
role: OllamaRole
content: str | None = None
thinking: str | None = None
tool_calls: list[OllamaToolCall] | None = None
name: str | None = None
tool_name: str | None = None
images: list[str] | None = None
class OllamaOptions(BaseModel, frozen=True):
num_predict: int | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None
class OllamaChatRequest(BaseModel, frozen=True):
model: ModelId
messages: list[OllamaMessage]
stream: bool = True
options: OllamaOptions | None = None
tools: list[dict[str, Any]] | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None
class OllamaGenerateRequest(BaseModel, frozen=True):
model: ModelId
prompt: str = ""
system: str | None = None
stream: bool = True
options: OllamaOptions | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None
raw: bool = False
class OllamaGenerateResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
response: str
thinking: str | None = None
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None
class OllamaShowRequest(BaseModel, frozen=True):
name: str | None = None
model: str | None = None
verbose: bool | None = None
class OllamaChatResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
message: OllamaMessage
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None
class OllamaModelDetails(BaseModel, frozen=True, strict=True):
format: str | None = None
family: str | None = None
parameter_size: str | None = None
quantization_level: str | None = None
class OllamaModelTag(BaseModel, frozen=True, strict=True):
name: str
model: str | None = None
modified_at: str | None = None
size: int | None = None
digest: str | None = None
details: OllamaModelDetails | None = None
class OllamaTagsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaModelTag]
class OllamaShowResponse(BaseModel, frozen=True, strict=True):
modelfile: str | None = None
parameters: str | None = None
template: str | None = None
details: OllamaModelDetails | None = None
model_info: dict[str, Any] | None = None
class OllamaPsModel(BaseModel, frozen=True, strict=True):
name: str
model: str
size: int
digest: str | None = None
details: OllamaModelDetails | None = None
expires_at: str | None = None
size_vram: int | None = None
class OllamaPsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaPsModel]
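
A hedged sketch of building a request body with these models; the model id is illustrative and the endpoint path follows the upstream Ollama API docs linked at the top of the file:

```python
# Illustrative only: a non-streaming /api/chat request body.
req = OllamaChatRequest(
    model=ModelId("llama3.2"),  # placeholder id; use whatever model exo is serving
    messages=[OllamaMessage(role="user", content="Why is the sky blue?")],
    stream=False,
    options=OllamaOptions(temperature=0.2, num_predict=128),
)
payload = req.model_dump_json(exclude_none=True)
# POST this payload to exo's /api/chat and validate the reply as OllamaChatResponse.
```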

View File

@@ -145,7 +145,23 @@ class ResponseFunctionCallItem(BaseModel, frozen=True):
    status: ResponseStatus = "completed"

-ResponseItem = ResponseMessageItem | ResponseFunctionCallItem
+class ResponseReasoningSummaryText(BaseModel, frozen=True):
+    """Summary text part in a reasoning output item."""
+    type: Literal["summary_text"] = "summary_text"
+    text: str

+class ResponseReasoningItem(BaseModel, frozen=True):
+    """Reasoning output item in response output array."""
+    type: Literal["reasoning"] = "reasoning"
+    id: str
+    summary: list[ResponseReasoningSummaryText] = Field(default_factory=list)
+    status: ResponseStatus = "completed"

+ResponseItem = ResponseMessageItem | ResponseFunctionCallItem | ResponseReasoningItem

class ResponseUsage(BaseModel, frozen=True):
@@ -273,6 +289,58 @@ class ResponseFunctionCallArgumentsDoneEvent(BaseModel, frozen=True):
    arguments: str
class ResponseReasoningSummaryPartAddedEvent(BaseModel, frozen=True):
"""Event sent when a reasoning summary part is added."""
type: Literal["response.reasoning_summary_part.added"] = (
"response.reasoning_summary_part.added"
)
sequence_number: int
item_id: str
output_index: int
summary_index: int
part: ResponseReasoningSummaryText
class ResponseReasoningSummaryTextDeltaEvent(BaseModel, frozen=True):
"""Event sent for reasoning summary text delta during streaming."""
type: Literal["response.reasoning_summary_text.delta"] = (
"response.reasoning_summary_text.delta"
)
sequence_number: int
item_id: str
output_index: int
summary_index: int
delta: str
class ResponseReasoningSummaryTextDoneEvent(BaseModel, frozen=True):
"""Event sent when reasoning summary text is done."""
type: Literal["response.reasoning_summary_text.done"] = (
"response.reasoning_summary_text.done"
)
sequence_number: int
item_id: str
output_index: int
summary_index: int
text: str
class ResponseReasoningSummaryPartDoneEvent(BaseModel, frozen=True):
"""Event sent when a reasoning summary part is done."""
type: Literal["response.reasoning_summary_part.done"] = (
"response.reasoning_summary_part.done"
)
sequence_number: int
item_id: str
output_index: int
summary_index: int
part: ResponseReasoningSummaryText
class ResponseCompletedEvent(BaseModel, frozen=True):
    """Event sent when response is completed."""
@@ -292,5 +360,9 @@ ResponsesStreamEvent = (
    | ResponseOutputItemDoneEvent
    | ResponseFunctionCallArgumentsDeltaEvent
    | ResponseFunctionCallArgumentsDoneEvent
+   | ResponseReasoningSummaryPartAddedEvent
+   | ResponseReasoningSummaryTextDeltaEvent
+   | ResponseReasoningSummaryTextDoneEvent
+   | ResponseReasoningSummaryPartDoneEvent
    | ResponseCompletedEvent
)
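
For reference, a sketch of the reasoning-summary event order a streaming client would observe (ids and sequence numbers are invented):

```python
part_added = ResponseReasoningSummaryPartAddedEvent(
    sequence_number=3, item_id="rs_1", output_index=0, summary_index=0,
    part=ResponseReasoningSummaryText(text=""),
)
text_delta = ResponseReasoningSummaryTextDeltaEvent(
    sequence_number=4, item_id="rs_1", output_index=0, summary_index=0,
    delta="Checking the weather tool...",
)
text_done = ResponseReasoningSummaryTextDoneEvent(
    sequence_number=5, item_id="rs_1", output_index=0, summary_index=0,
    text="Checking the weather tool...",
)
part_done = ResponseReasoningSummaryPartDoneEvent(
    sequence_number=6, item_id="rs_1", output_index=0, summary_index=0,
    part=ResponseReasoningSummaryText(text="Checking the weather tool..."),
)
# added -> delta(s) -> text done -> part done, interleaved into the existing stream.
```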

View File

@@ -10,9 +10,9 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class DownloadProgressData(CamelCaseModel):
-   total_bytes: Memory
-   downloaded_bytes: Memory
-   downloaded_bytes_this_session: Memory
+   total: Memory
+   downloaded: Memory
+   downloaded_this_session: Memory
    completed_files: int
    total_files: int
@@ -34,7 +34,7 @@ class DownloadPending(BaseDownloadProgress):

class DownloadCompleted(BaseDownloadProgress):
-   total_bytes: Memory
+   total: Memory

class DownloadFailed(BaseDownloadProgress):
@@ -86,9 +86,9 @@ class RepoDownloadProgress(BaseModel):
    shard: ShardMetadata
    completed_files: int
    total_files: int
-   downloaded_bytes: Memory
-   downloaded_bytes_this_session: Memory
-   total_bytes: Memory
+   downloaded: Memory
+   downloaded_this_session: Memory
+   total: Memory
    overall_speed: float
    overall_eta: timedelta
    status: Literal["not_started", "in_progress", "complete"]

View File

@@ -28,6 +28,7 @@ class GenerationResponse(BaseRunnerResponse):
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None
    usage: Usage | None
+   is_thinking: bool = False

class ImageGenerationResponse(BaseRunnerResponse):

View File

@@ -192,7 +192,13 @@ class MpReceiver[T]:
        try:
            return self.receive_nowait()
        except WouldBlock:
-           item = self._state.buffer.get()
+           try:
+               item = self._state.buffer.get()
+           except (TypeError, OSError):
+               # Queue pipe can get closed while we are blocked on get().
+               # The underlying connection._handle becomes None, causing
+               # TypeError in read(handle, remaining).
+               raise ClosedResourceError from None
            if isinstance(item, _MpEndOfStream):
                self.close()
                raise EndOfStream from None

View File

@@ -108,7 +108,7 @@ async def check_reachable(
    await send.send((target_ip, expected_node_id))
    async with (
-       httpx.AsyncClient(timeout=timeout, limits=limits) as client,
+       httpx.AsyncClient(timeout=timeout, limits=limits, verify=False) as client,
        create_task_group() as tg,
    ):
        for node_id in topology.list_nodes():

View File

@@ -166,7 +166,7 @@ def generate_image(
        else 0.0
    )
-   peak_memory_gb = mx.get_peak_memory() / (1024**3)
+   peak_memory = Memory.from_bytes(mx.get_peak_memory())

    stats = ImageGenerationStats(
        seconds_per_step=seconds_per_step,
@@ -175,7 +175,7 @@
        num_images=num_images,
        image_width=width,
        image_height=height,
-       peak_memory_usage=Memory.from_gb(peak_memory_gb),
+       peak_memory_usage=peak_memory,
    )
    buffer = io.BytesIO()

View File

@@ -22,7 +22,7 @@ from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in.
# Smaller machines need more aggressive eviction.
def _default_memory_threshold() -> float:
-   total_gb = psutil.virtual_memory().total / (1024**3)
+   total_gb = Memory.from_bytes(psutil.virtual_memory().total).in_gb
    if total_gb >= 128:
        return 0.85
    if total_gb >= 64:

View File

@@ -0,0 +1,72 @@
import json
import re
from typing import Any
from mlx_lm.chat_templates import deepseek_v32
from exo.shared.types.api import ToolCallItem
BOS_TOKEN: str = deepseek_v32.bos_token
EOS_TOKEN: str = deepseek_v32.eos_token
DSML_TOKEN: str = deepseek_v32.dsml_token
THINKING_START: str = deepseek_v32.thinking_start_token
THINKING_END: str = deepseek_v32.thinking_end_token
USER_TOKEN = "<\uff5cUser\uff5c>"
ASSISTANT_TOKEN = "<\uff5cAssistant\uff5c>"
TOOL_CALLS_START = f"<{DSML_TOKEN}function_calls>"
TOOL_CALLS_END = f"</{DSML_TOKEN}function_calls>"
encode_messages = deepseek_v32.encode_messages
_INVOKE_PATTERN = re.compile(
rf"<{re.escape(DSML_TOKEN)}invoke\s+name=\"([^\"]+)\">"
rf"(.*?)"
rf"</{re.escape(DSML_TOKEN)}invoke>",
re.DOTALL,
)
_PARAM_PATTERN = re.compile(
rf"<{re.escape(DSML_TOKEN)}parameter\s+name=\"([^\"]+)\"\s+string=\"(true|false)\">"
rf"(.*?)"
rf"</{re.escape(DSML_TOKEN)}parameter>",
re.DOTALL,
)
def parse_dsml_output(text: str) -> list[ToolCallItem] | None:
"""Parse DSML function_calls block from model output text.
Args:
text: The text containing the DSML function_calls block
(including the start/end markers).
Returns:
List of ToolCallItem, or None if parsing fails.
"""
tool_calls: list[ToolCallItem] = []
for invoke_match in _INVOKE_PATTERN.finditer(text):
func_name = invoke_match.group(1)
invoke_body = invoke_match.group(2)
args: dict[str, Any] = {}
for param_match in _PARAM_PATTERN.finditer(invoke_body):
param_name = param_match.group(1)
is_string = param_match.group(2) == "true"
param_value = param_match.group(3)
if is_string:
args[param_name] = param_value
else:
try:
args[param_name] = json.loads(param_value)
except (json.JSONDecodeError, ValueError):
args[param_name] = param_value
tool_calls.append(
ToolCallItem(
name=func_name,
arguments=json.dumps(args),
)
)
return tool_calls if tool_calls else None
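
A minimal usage sketch, mirroring the tests added later in this diff:

```python
# Parse a single DSML invoke block (tokens shown joined into one string).
dsml = (
    f"{TOOL_CALLS_START}\n"
    f'<{DSML_TOKEN}invoke name="get_weather">\n'
    f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n'
    f'<{DSML_TOKEN}parameter name="units" string="false">"celsius"</{DSML_TOKEN}parameter>\n'
    f"</{DSML_TOKEN}invoke>\n"
    f"{TOOL_CALLS_END}"
)
calls = parse_dsml_output(dsml)
assert calls is not None and calls[0].name == "get_weather"
# calls[0].arguments == '{"city": "NYC", "units": "celsius"}'
```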

View File

@@ -232,11 +232,11 @@ def shard_and_load(
    # Estimate timeout based on model size (5x default for large queued workloads)
    base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
-   model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
-   timeout_seconds = base_timeout + model_size_gb
+   model_size = get_weights_size(shard_metadata)
+   timeout_seconds = base_timeout + model_size.in_gb
    logger.info(
        f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
-       f"(model size: {model_size_gb:.1f}GB)"
+       f"(model size: {model_size.in_gb:.1f}GB)"
    )
    match shard_metadata:
@@ -458,6 +458,19 @@ def _patch_lossy_chat_template(template: str) -> str | None:
    return patched if n > 0 else None
def _needs_dsml_encoding(task_params: TextGenerationTaskParams) -> bool:
if "deepseek-v3.2" not in task_params.model.lower():
return False
# Use DSML encoding when tools are provided or tool results are in the conversation
if task_params.tools:
return True
if task_params.chat_template_messages:
return any(
msg.get("role") == "tool" for msg in task_params.chat_template_messages
)
return False
def apply_chat_template(
    tokenizer: TokenizerWrapper,
    task_params: TextGenerationTaskParams,
@@ -469,7 +482,6 @@ def apply_chat_template(
    When chat_template_messages is available (from Chat Completions API),
    uses those directly to preserve tool_calls, thinking, and other fields.
-   Otherwise builds messages from the task params input/instructions.
    """
    formatted_messages: list[dict[str, Any]] = []
    if task_params.chat_template_messages is not None:
@@ -497,6 +509,19 @@ def apply_chat_template(
        partial_assistant_content = cast(str, formatted_messages[-1].get("content", ""))
        formatted_messages = formatted_messages[:-1]
if _needs_dsml_encoding(task_params):
from exo.worker.engines.mlx.dsml_encoding import encode_messages
prompt = encode_messages(
messages=formatted_messages,
thinking_mode="thinking" if task_params.enable_thinking else "chat",
tools=task_params.tools,
)
if partial_assistant_content:
prompt += partial_assistant_content
logger.info(prompt)
return prompt
    extra_kwargs: dict[str, Any] = {}
    if task_params.enable_thinking is not None:
        # Qwen3 and GLM use "enable_thinking"; DeepSeek uses "thinking".
@@ -617,18 +642,17 @@ def set_wired_limit_for_model(model_size: Memory):
    if not mx.metal.is_available():
        return
-   model_bytes = model_size.in_bytes
-   max_rec_size = int(mx.metal.device_info()["max_recommended_working_set_size"])
-   if model_bytes > 0.9 * max_rec_size:
-       model_mb = model_bytes // 2**20
-       max_rec_mb = max_rec_size // 2**20
+   max_rec_size = Memory.from_bytes(
+       int(mx.metal.device_info()["max_recommended_working_set_size"])
+   )
+   if model_size > 0.9 * max_rec_size:
        logger.warning(
-           f"Generating with a model that requires {model_mb} MB "
-           f"which is close to the maximum recommended size of {max_rec_mb} "
+           f"Generating with a model that requires {model_size.in_float_mb:.1f} MB "
+           f"which is close to the maximum recommended size of {max_rec_size.in_float_mb:.1f} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
        )
-   mx.set_wired_limit(max_rec_size)
+   mx.set_wired_limit(max_rec_size.in_bytes)
    logger.info(f"Wired limit set to {max_rec_size}.")

View File

@@ -4,9 +4,10 @@ import resource
import time
from collections.abc import Generator
from functools import cache
-from typing import Literal
+from typing import TYPE_CHECKING, Literal

import mlx.core as mx
+from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
@@ -21,12 +22,17 @@ from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer
from exo.shared.types.api import ImageGenerationStats
-from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
+from exo.shared.types.chunks import (
+    ErrorChunk,
+    ImageChunk,
+    PrefillProgressChunk,
+    TokenChunk,
+    ToolCallChunk,
+)
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
    ChunkGenerated,
    Event,
-   PrefillProgress,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,
@@ -315,11 +321,13 @@ def main(
    ) -> None:
        if device_rank == 0:
            event_sender.send(
-               PrefillProgress(
+               ChunkGenerated(
                    command_id=command_id,
-                   model=shard_metadata.model_card.model_id,
-                   processed_tokens=processed,
-                   total_tokens=total,
+                   chunk=PrefillProgressChunk(
+                       model=shard_metadata.model_card.model_id,
+                       processed_tokens=processed,
+                       total_tokens=total,
+                   ),
                )
            )
        cancelled_tasks.update(cancel_receiver.collect())
@@ -346,16 +354,22 @@ def main(
                group=group,
            )

-           # For other thinking models (GLM, etc.), check if we need to
-           # prepend the thinking tag that was consumed by the chat template
-           if detect_thinking_prompt_suffix(prompt, tokenizer):
+           if tokenizer.has_thinking:
                mlx_generator = parse_thinking_models(
-                   mlx_generator, tokenizer
+                   mlx_generator,
+                   tokenizer,
+                   # For other thinking models (GLM, etc.), check if we need to
+                   # prepend the thinking tag that was consumed by the chat template
+                   starts_in_thinking=detect_thinking_prompt_suffix(
+                       prompt, tokenizer
+                   ),
                )

-           # GPT-OSS specific parsing to match other model formats.
+           # Model-specific output parsing for tool calls.
            if isinstance(inference_model, GptOssModel):
                mlx_generator = parse_gpt_oss(mlx_generator)
+           elif isinstance(inference_model, DeepseekV32Model):
+               mlx_generator = parse_deepseek_v32(mlx_generator)
            elif tool_parser:
                mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
@@ -407,6 +421,7 @@ def main(
                        stats=response.stats,
                        logprob=response.logprob,
                        top_logprobs=response.top_logprobs,
+                       is_thinking=response.is_thinking,
                    ),
                )
            )
@@ -573,6 +588,13 @@ def main(
            case Shutdown():
                current_status = RunnerShuttingDown()
                logger.info("runner shutting down")
+               if not TYPE_CHECKING:
+                   del inference_model, image_model, tokenizer, group
+               mx.clear_cache()
+               import gc
+               gc.collect()
                event_sender.send(
                    RunnerStatusUpdated(
                        runner_id=runner_id, runner_status=current_status
@@ -597,12 +619,8 @@ def main(
        event_sender.send(
            RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
        )
-       if isinstance(current_status, RunnerShutdown):
-           del inference_model, image_model, tokenizer, group
-           mx.clear_cache()
-           import gc
-           gc.collect()
+       if isinstance(current_status, RunnerShutdown):
            break
@@ -668,44 +686,208 @@ def parse_gpt_oss(
if ch == "analysis" and not thinking: if ch == "analysis" and not thinking:
thinking = True thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking: if ch != "analysis" and thinking:
thinking = False thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta: if delta:
yield response.model_copy(update={"text": delta}) yield response.model_copy(update={"text": delta, "is_thinking": thinking})
if response.finish_reason is not None: if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response yield response
def parse_deepseek_v32(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
Uses accumulated-text matching (not per-token marker checks) because
DSML markers like <DSMLfunction_calls> may span multiple tokens.
Also handles <think>...</think> blocks for thinking mode.
"""
from exo.worker.engines.mlx.dsml_encoding import (
THINKING_END,
THINKING_START,
TOOL_CALLS_END,
TOOL_CALLS_START,
parse_dsml_output,
)
accumulated = ""
in_tool_call = False
thinking = False
# Tokens buffered while we detect the start of a DSML block
pending_buffer: list[GenerationResponse] = []
# Text accumulated during a tool call block
tool_call_text = ""
for response in responses:
assert isinstance(response, GenerationResponse)
# ── Handle thinking tags ──
if not thinking and THINKING_START in response.text:
thinking = True
# Yield any text before the <think> tag
before = response.text[: response.text.index(THINKING_START)]
if before:
yield response.model_copy(update={"text": before})
continue
if thinking and THINKING_END in response.text:
thinking = False
# Yield any text after the </think> tag
after = response.text[
response.text.index(THINKING_END) + len(THINKING_END) :
]
if after:
yield response.model_copy(update={"text": after, "is_thinking": False})
continue
if thinking:
yield response.model_copy(update={"is_thinking": True})
continue
# ── Handle tool call accumulation ──
if in_tool_call:
tool_call_text += response.text
if TOOL_CALLS_END in tool_call_text:
# Parse the accumulated DSML block
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# EOS reached before end marker — yield buffered text as-is
if response.finish_reason is not None:
logger.info("DSML tool call parsing interrupted by EOS")
yield response.model_copy(update={"text": tool_call_text})
in_tool_call = False
tool_call_text = ""
continue
# ── Detect start of tool call block ──
accumulated += response.text
if TOOL_CALLS_START in accumulated:
# The start marker might be split across pending_buffer + current token
start_idx = accumulated.index(TOOL_CALLS_START)
# Yield any pending tokens that are purely before the marker
pre_text = accumulated[:start_idx]
if pre_text:
# Flush pending buffer tokens that contributed text before the marker
for buf_resp in pending_buffer:
if pre_text:
chunk = buf_resp.text
if len(chunk) <= len(pre_text):
yield buf_resp
pre_text = pre_text[len(chunk) :]
else:
yield buf_resp.model_copy(update={"text": pre_text})
pre_text = ""
pending_buffer = []
tool_call_text = accumulated[start_idx:]
accumulated = ""
# Check if the end marker is already present (entire tool call in one token)
if TOOL_CALLS_END in tool_call_text:
parsed = parse_dsml_output(tool_call_text)
if parsed is not None:
logger.info(f"parsed DSML tool calls: {parsed}")
yield ToolCallResponse(
tool_calls=parsed,
usage=response.usage,
stats=response.stats,
)
else:
logger.warning(
f"DSML tool call parsing failed for: {tool_call_text}"
)
yield response.model_copy(update={"text": tool_call_text})
tool_call_text = ""
else:
in_tool_call = True
continue
# Check if accumulated text might be the start of a DSML marker
# Buffer tokens if we see a partial match at the end
if _could_be_dsml_prefix(accumulated):
pending_buffer.append(response)
continue
# No partial match — flush all pending tokens and the current one
for buf_resp in pending_buffer:
yield buf_resp
pending_buffer = []
accumulated = ""
yield response
# Flush any remaining pending buffer at generator end
for buf_resp in pending_buffer:
yield buf_resp
def _could_be_dsml_prefix(text: str) -> bool:
"""Check if the end of text could be the start of a DSML function_calls marker.
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
This allows us to buffer tokens until we can determine if a tool call is starting.
"""
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
# Only check the last portion of text that could overlap with the marker
max_check = len(TOOL_CALLS_START)
tail = text[-max_check:] if len(text) > max_check else text
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
for i in range(len(tail)):
suffix = tail[i:]
if TOOL_CALLS_START.startswith(suffix):
return True
return False
def parse_thinking_models(
    responses: Generator[GenerationResponse],
    tokenizer: TokenizerWrapper,
+   starts_in_thinking: bool = True,
) -> Generator[GenerationResponse]:
-   """
-   For models that inject thinking tags in the prompt (like GLM-4.7),
-   prepend the thinking tag to the output stream so the frontend
-   can properly parse thinking content.
-   """
-   first = True
+   """Route thinking tokens via is_thinking flag.
+
+   Swallows think tag tokens, sets is_thinking on all others.
+   Always yields tokens with finish_reason to avoid hanging the chunk stream.
+   """
+   in_thinking = starts_in_thinking
    for response in responses:
        if isinstance(response, ToolCallResponse):
            yield response
            continue
-       if first:
-           first = False
-           yield response.model_copy(
-               update={
-                   "text": tokenizer.think_start,
-                   "token": tokenizer.think_start_id,
-               }
-           )
-       yield response
+       is_think_tag = (
+           tokenizer.think_end is not None and response.text == tokenizer.think_end
+       ) or (
+           tokenizer.think_start is not None and response.text == tokenizer.think_start
+       )
+       if is_think_tag:
+           in_thinking = response.text != tokenizer.think_end
+           # Never swallow finish_reason — the chunk stream needs it to terminate.
+           if response.finish_reason is not None:
+               yield response.model_copy(update={"text": "", "is_thinking": False})
+           continue
+       yield response.model_copy(update={"is_thinking": in_thinking})

def _send_image_chunk(

View File

@@ -100,8 +100,8 @@ class RunnerSupervisor:
logger.info("Runner supervisor shutting down") logger.info("Runner supervisor shutting down")
self._ev_recv.close() self._ev_recv.close()
self._task_sender.close() self._task_sender.close()
self._event_sender.close() with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK")) self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close() self._cancel_sender.close()
self.runner_process.join(5) self.runner_process.join(5)
if not self.runner_process.is_alive(): if not self.runner_process.is_alive():
@@ -180,6 +180,7 @@ class RunnerSupervisor:
await self._check_runner(e) await self._check_runner(e)
for tid in self.pending: for tid in self.pending:
self.pending[tid].set() self.pending[tid].set()
self._event_sender.close()
def __del__(self) -> None: def __del__(self) -> None:
if self.runner_process.is_alive(): if self.runner_process.is_alive():
@@ -208,10 +209,15 @@ class RunnerSupervisor:
logger.opt(exception=e).error(f"Runner terminated ({cause})") logger.opt(exception=e).error(f"Runner terminated ({cause})")
await self._event_sender.send( try:
RunnerStatusUpdated( await self._event_sender.send(
runner_id=self.bound_instance.bound_runner_id, RunnerStatusUpdated(
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"), runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
except (ClosedResourceError, BrokenResourceError):
logger.warning(
"Event sender already closed, unable to report runner failure"
) )
)
self.shutdown() self.shutdown()

View File

@@ -90,14 +90,10 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
    global_download_status = {
        NODE_A: [
-           DownloadCompleted(
-               shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
-           )
+           DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [
-           DownloadCompleted(
-               shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
-           )
+           DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
        ],
    }
@@ -138,9 +134,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
    # Global state shows shard is downloaded for NODE_A
    global_download_status: dict[NodeId, list[DownloadProgress]] = {
        NODE_A: [
-           DownloadCompleted(
-               shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
-           )
+           DownloadCompleted(shard_metadata=shard, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [],
    }
@@ -187,9 +181,7 @@ test_plan_does_not_load_model_until_all_shards_downloaded_globally():
    global_download_status = {
        NODE_A: [
-           DownloadCompleted(
-               shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
-           )
+           DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [],  # NODE_B has no downloads completed yet
    }
@@ -207,14 +199,10 @@
    global_download_status = {
        NODE_A: [
-           DownloadCompleted(
-               shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
-           )
+           DownloadCompleted(shard_metadata=shard1, node_id=NODE_A, total=Memory())
        ],
        NODE_B: [
-           DownloadCompleted(
-               shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
-           )
+           DownloadCompleted(shard_metadata=shard2, node_id=NODE_B, total=Memory())
        ],  # NODE_B has no downloads completed yet
    }

View File

@@ -0,0 +1,967 @@
import json
from collections.abc import Generator
from typing import Any
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ToolCallResponse,
)
from exo.worker.engines.mlx.dsml_encoding import (
ASSISTANT_TOKEN,
BOS_TOKEN,
DSML_TOKEN,
EOS_TOKEN,
THINKING_END,
THINKING_START,
TOOL_CALLS_END,
TOOL_CALLS_START,
USER_TOKEN,
encode_messages,
parse_dsml_output,
)
from exo.worker.runner.runner import parse_deepseek_v32
# ── Shared fixtures ──────────────────────────────────────────────
_WEATHER_TOOLS: list[dict[str, Any]] = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "The city name"},
"units": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature units",
},
},
"required": ["city"],
},
},
},
{
"type": "function",
"function": {
"name": "get_time",
"description": "Get the current time in a timezone",
"parameters": {
"type": "object",
"properties": {
"timezone": {"type": "string"},
},
"required": ["timezone"],
},
},
},
]
def _simulate_tokens(
texts: list[str],
finish_on_last: bool = True,
) -> Generator[GenerationResponse]:
"""Simulate a model producing tokens from a list of text strings."""
for i, text in enumerate(texts):
is_last = i == len(texts) - 1
yield GenerationResponse(
text=text,
token=i,
finish_reason="stop" if (is_last and finish_on_last) else None,
usage=None,
)
# ── Test: Standard text response (no tool calls) ────────────────
class TestE2EStandardResponse:
"""Model generates a plain text response — no tool calling involved."""
def test_plain_text_passthrough(self):
"""Simulate model producing: 'The weather in NYC is 72°F and sunny.'"""
# Step 1: Encode the prompt (with tools available)
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather in NYC?"},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
# Verify prompt structure
assert BOS_TOKEN in prompt
assert "## Tools" in prompt
assert "get_weather" in prompt
assert f"{USER_TOKEN}What's the weather in NYC?{ASSISTANT_TOKEN}" in prompt
# Step 2: Simulate model response — plain text tokens (no DSML)
model_tokens = [
"The weather",
" in NYC",
" is 72",
"°F",
" and sunny",
".",
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
# Step 3: Verify all tokens pass through as GenerationResponse
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 0
assert len(gen_results) == 6
full_text = "".join(r.text for r in gen_results)
assert full_text == "The weather in NYC is 72°F and sunny."
assert gen_results[-1].finish_reason == "stop"
# ── Test: Tool call response ─────────────────────────────────────
class TestE2EToolCallResponse:
"""Model generates a DSML tool call — realistic token boundaries."""
def test_realistic_tool_call_tokens(self):
"""Simulate model generating a get_weather tool call with realistic token splits.
Real models split DSML markers across tokens unpredictably.
This simulates how DeepSeek V3.2 actually tokenizes DSML output.
"""
# Step 1: Encode prompt
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather in San Francisco?"},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
assert "get_weather" in prompt
# Step 2: Simulate realistic token-by-token model output
# The model first produces some text, then a DSML tool call block
model_tokens = [
"I'll check the weather for you.",
"\n\n",
f"<{DSML_TOKEN}", # marker split across tokens
"function_calls>\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">',
"San Francisco",
f"</{DSML_TOKEN}parameter>\n",
f'<{DSML_TOKEN}parameter name="units" string="false">',
'"celsius"',
f"</{DSML_TOKEN}parameter>\n",
f"</{DSML_TOKEN}invoke>\n",
f"</{DSML_TOKEN}function_calls>",
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
# Step 3: Verify
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
# Should have text tokens before tool call + one ToolCallResponse
assert len(tool_results) == 1
assert len(tool_results[0].tool_calls) == 1
tc = tool_results[0].tool_calls[0]
assert tc.name == "get_weather"
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
assert args["city"] == "San Francisco"
assert args["units"] == "celsius"
# The text before the tool call should still be yielded
text_before = "".join(r.text for r in gen_results if not r.is_thinking)
assert "check the weather" in text_before
def test_multiple_tool_calls_in_one_block(self):
"""Model generates two tool calls in a single function_calls block."""
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Weather in NYC and time in EST?"},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
assert "get_weather" in prompt
assert "get_time" in prompt
# Simulate model output with two invocations
model_tokens = [
"Let me check both.\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
f'<{DSML_TOKEN}invoke name="get_time">\n',
f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 1
assert len(tool_results[0].tool_calls) == 2
assert tool_results[0].tool_calls[0].name == "get_weather"
assert tool_results[0].tool_calls[1].name == "get_time"
args0 = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
args1 = json.loads(tool_results[0].tool_calls[1].arguments) # pyright: ignore[reportAny]
assert args0 == {"city": "NYC"}
assert args1 == {"timezone": "EST"}
# ── Test: Multi-turn tool use flow ───────────────────────────────
class TestE2EMultiTurnToolUse:
"""Full multi-turn: user asks → model calls tool → tool result → model answers."""
def test_encode_multi_turn_with_tool_results(self):
"""Verify the prompt for turn 2 (after tool results) is correctly encoded."""
# Turn 1: user asks, model calls tool
# Turn 2: tool result provided, model answers
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are a weather assistant."},
{"role": "user", "content": "What's the weather in NYC?"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "NYC"}',
},
}
],
},
{"role": "tool", "content": '{"temperature": 72, "condition": "sunny"}'},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
# Verify multi-turn structure
assert BOS_TOKEN in prompt
assert "You are a weather assistant." in prompt
assert "## Tools" in prompt
# The assistant's tool call should be encoded as DSML
assert TOOL_CALLS_START in prompt
assert f'<{DSML_TOKEN}invoke name="get_weather">' in prompt
assert EOS_TOKEN in prompt
# The tool result should be wrapped in function_results
assert "<function_results>" in prompt
assert "<result>" in prompt
assert "72" in prompt
assert "</function_results>" in prompt
# Now simulate model answering after seeing the tool result
model_tokens = [
"The current",
" weather in NYC",
" is 72°F",
" and sunny.",
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 0
full_text = "".join(r.text for r in gen_results)
assert full_text == "The current weather in NYC is 72°F and sunny."
def test_multi_tool_results_encoding(self):
"""Verify encoding when model called two tools and both return results."""
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Weather and time?"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "LA"}',
},
},
{
"type": "function",
"function": {
"name": "get_time",
"arguments": '{"timezone": "PST"}',
},
},
],
},
{"role": "tool", "content": "85F, clear skies"},
{"role": "tool", "content": "3:42 PM PST"},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
# Should have one function_results block with two results
assert prompt.count("<function_results>") == 1
assert prompt.count("</function_results>") == 1
assert "<result>85F, clear skies</result>" in prompt
assert "<result>3:42 PM PST</result>" in prompt
# ── Test: Thinking + tool call ───────────────────────────────────
class TestE2EThinkingAndToolCall:
"""Model uses thinking mode, reasons, then makes a tool call."""
def test_thinking_then_tool_call(self):
"""Model thinks first, then produces a DSML tool call block."""
messages: list[dict[str, Any]] = [
{"role": "user", "content": "What's the weather?"},
]
prompt = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Thinking mode: prompt should end with <think>
assert prompt.endswith(THINKING_START)
# Simulate: model outputs <think>, thinks, closes thinking, then tool call.
# In the full pipeline, parse_thinking_models handles the case where
# <think> is in the prompt. Here we test parse_deepseek_v32 directly,
# which detects <think>/</think> markers in the stream.
model_tokens = [
THINKING_START,
"The user wants weather",
" information. I should use",
" the get_weather tool.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">',
"San Francisco",
f"</{DSML_TOKEN}parameter>\n",
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
# Should have thinking tokens + tool call
thinking_results = [r for r in gen_results if r.is_thinking]
assert len(thinking_results) >= 1
thinking_text = "".join(r.text for r in thinking_results)
assert "get_weather tool" in thinking_text
assert len(tool_results) == 1
assert tool_results[0].tool_calls[0].name == "get_weather"
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
assert args["city"] == "San Francisco"
def test_thinking_prompt_encoding(self):
"""Verify thinking mode affects prompt encoding correctly."""
messages: list[dict[str, Any]] = [
{"role": "system", "content": "Be thorough."},
{"role": "user", "content": "What's the weather?"},
]
# With thinking enabled
prompt_think = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert prompt_think.endswith(THINKING_START)
# With thinking disabled
prompt_no_think = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="chat"
)
assert prompt_no_think.endswith(THINKING_END)
# Both should have the same tool definitions
assert "get_weather" in prompt_think
assert "get_weather" in prompt_no_think
# ── Test: Round-trip encode → parse ──────────────────────────────
class TestE2ERoundTrip:
"""Verify that DSML we encode can be parsed back correctly."""
def test_encoded_tool_call_is_parseable(self):
"""Encode an assistant tool call message, then parse the DSML output."""
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Weather?"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Tokyo", "units": "celsius"}',
},
}
],
},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
# Extract the DSML function_calls block from the prompt
start = prompt.index(TOOL_CALLS_START)
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
dsml_block = prompt[start:end]
# Parse it back
parsed = parse_dsml_output(dsml_block)
assert parsed is not None
assert len(parsed) == 1
assert parsed[0].name == "get_weather"
args = json.loads(parsed[0].arguments) # pyright: ignore[reportAny]
assert args["city"] == "Tokyo"
assert args["units"] == "celsius"
def test_encoded_multi_tool_call_round_trips(self):
"""Encode multiple tool calls, verify they parse back correctly."""
messages: list[dict[str, Any]] = [
{"role": "user", "content": "Both please"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Paris"}',
},
},
{
"type": "function",
"function": {
"name": "get_time",
"arguments": '{"timezone": "CET"}',
},
},
],
},
]
prompt = encode_messages(messages, thinking_mode="chat", tools=_WEATHER_TOOLS)
start = prompt.index(TOOL_CALLS_START)
end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
dsml_block = prompt[start:end]
parsed = parse_dsml_output(dsml_block)
assert parsed is not None
assert len(parsed) == 2
assert parsed[0].name == "get_weather"
assert parsed[1].name == "get_time"
assert json.loads(parsed[0].arguments) == {"city": "Paris"}
assert json.loads(parsed[1].arguments) == {"timezone": "CET"}
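# The extraction used by the round-trip tests, factored out for clarity (illustrative
# only): slice the first DSML function_calls block out of an encoded prompt and hand
# it back to parse_dsml_output.
def _example_extract_and_parse(prompt: str) -> Any:
    start = prompt.index(TOOL_CALLS_START)
    end = prompt.index(TOOL_CALLS_END) + len(TOOL_CALLS_END)
    return parse_dsml_output(prompt[start:end])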
# ── Test: Edge cases with realistic token boundaries ─────────────
class TestE2EEdgeCases:
"""Edge cases that occur in real model inference."""
def test_dsml_marker_split_at_fullwidth_pipe(self):
"""The fullwidth pipe character might be its own token."""
# Realistic tokenization: the DSML marker itself arrives split across token boundaries
model_tokens = [
"Let me help.\n\n",
"<\uff5c", # start of DSML
"DSML\uff5c", # rest of DSML token
"function_calls>\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 1
assert tool_results[0].tool_calls[0].name == "get_weather"
def test_tool_call_with_nested_json_object(self):
"""Model passes a complex JSON object as a non-string parameter."""
dsml_block = (
f"{TOOL_CALLS_START}\n"
f'<{DSML_TOKEN}invoke name="create_event">\n'
f'<{DSML_TOKEN}parameter name="title" string="true">Team Standup</{DSML_TOKEN}parameter>\n'
f'<{DSML_TOKEN}parameter name="config" string="false">'
f'{{"recurring": true, "days": ["mon", "wed", "fri"], "time": "09:00"}}'
f"</{DSML_TOKEN}parameter>\n"
f"</{DSML_TOKEN}invoke>\n"
f"{TOOL_CALLS_END}"
)
# Feed as a single token (the model might produce it all at once after prefill)
results = list(parse_deepseek_v32(_simulate_tokens([dsml_block])))
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 1
tc = tool_results[0].tool_calls[0]
assert tc.name == "create_event"
args = json.loads(tc.arguments) # pyright: ignore[reportAny]
assert args["title"] == "Team Standup"
assert args["config"]["recurring"] is True
assert args["config"]["days"] == ["mon", "wed", "fri"]
def test_text_with_angle_brackets_not_mistaken_for_dsml(self):
"""Angle brackets in normal text should not trigger DSML buffering."""
model_tokens = [
"The formula is ",
"<x, y>",
" where x > 0",
" and y < 100.",
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 0
full_text = "".join(r.text for r in gen_results)
assert "formula" in full_text
assert "<x, y>" in full_text
def test_empty_model_response(self):
"""Model produces only EOS (empty response)."""
model_tokens = [""]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
gen_results = [r for r in results if isinstance(r, GenerationResponse)]
assert len(gen_results) == 1
assert gen_results[0].text == ""
assert gen_results[0].finish_reason == "stop"
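# The marker-splitting edge case above relies on the DSML marker being an ordinary
# string that token boundaries can cut anywhere. Assuming DSML_TOKEN is the
# fullwidth-pipe sequence "\uff5cDSML\uff5c" (as the split tokens in
# test_dsml_marker_split_at_fullwidth_pipe imply), the pieces reassemble into the same
# opening marker the other tests build in one go:
def _example_marker_reassembly() -> bool:
    pieces = ["<\uff5c", "DSML\uff5c", "function_calls>"]
    return "".join(pieces) == f"<{DSML_TOKEN}function_calls>"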
# ── Test: Full EPDP spec round-trip ──────────────────────────────
class TestE2EFullRoundTrip:
"""Full round-trip matching the vLLM EPDP spec.
Simulates the complete multi-turn flow:
Turn 1: user asks → think → tool call → tool result → think → answer
Turn 2: user asks again → old reasoning stripped → think → answer
"""
def test_single_tool_full_flow_with_thinking(self):
"""Complete flow: user → think → tool call → tool result → think → answer.
This is the core EPDP flow from the vLLM spec.
"""
# ── Turn 1.1: User asks, encode prompt ──
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are a weather assistant."},
{"role": "user", "content": "How's the weather in Hangzhou?"},
]
prompt_1 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert prompt_1.endswith(THINKING_START)
assert "## Tools" in prompt_1
assert "get_weather" in prompt_1
# ── Turn 1.1: Model thinks, then calls tool ──
model_tokens_1 = [
THINKING_START,
"The user wants to know the weather in Hangzhou.",
" I need to use the get_weather tool.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
# Verify: thinking tokens + tool call
gen_1 = [r for r in results_1 if isinstance(r, GenerationResponse)]
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
thinking_1 = [r for r in gen_1 if r.is_thinking]
assert len(thinking_1) >= 1
assert "get_weather tool" in "".join(r.text for r in thinking_1)
assert len(tool_1) == 1
assert tool_1[0].tool_calls[0].name == "get_weather"
tc_args = json.loads(tool_1[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
assert tc_args == {"city": "Hangzhou"}
# ── Turn 1.2: Add assistant response + tool result to messages ──
messages.append(
{
"role": "assistant",
"content": "",
"reasoning_content": "The user wants to know the weather in Hangzhou. I need to use the get_weather tool.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Hangzhou"}',
},
}
],
}
)
messages.append(
{
"role": "tool",
"content": '{"temperature": "7~13°C", "condition": "Cloudy"}',
}
)
# Encode prompt for turn 1.2
prompt_2 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Verify: prompt has the full conversation structure
assert TOOL_CALLS_START in prompt_2 # assistant's encoded tool call
assert EOS_TOKEN in prompt_2 # assistant turn ends with EOS
assert "<function_results>" in prompt_2
assert "<result>" in prompt_2
assert "Cloudy" in prompt_2
assert "</function_results>" in prompt_2
# After tool results with thinking enabled → <think> appended
assert prompt_2.endswith(THINKING_START)
# The assistant's reasoning_content should appear (it's after last_user_idx)
assert "get_weather tool" in prompt_2
# ── Turn 1.2: Model thinks about results, then answers ──
model_tokens_2 = [
THINKING_START,
"The weather in Hangzhou is Cloudy, 7~13°C.",
" I'll tell the user.",
THINKING_END,
"The weather in Hangzhou is currently cloudy with temperatures between 7°C and 13°C.",
]
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
thinking_2 = [r for r in gen_2 if r.is_thinking]
non_thinking_2 = [r for r in gen_2 if not r.is_thinking]
assert len(tool_2) == 0 # No more tool calls
assert len(thinking_2) >= 1
assert "Cloudy" in "".join(r.text for r in thinking_2)
assert len(non_thinking_2) >= 1
final_text = "".join(r.text for r in non_thinking_2)
assert "7°C" in final_text
assert "13°C" in final_text
def test_multi_tool_full_flow(self):
"""Flow with two tools: user → think → 2 tool calls → 2 results → think → answer."""
# ── Initial prompt ──
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You help with weather and time."},
{"role": "user", "content": "Weather in NYC and time in EST?"},
]
prompt_1 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert prompt_1.endswith(THINKING_START)
# ── Model thinks, calls both tools ──
model_tokens_1 = [
THINKING_START,
"Two requests: weather and time. I'll call both.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">NYC</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
f'<{DSML_TOKEN}invoke name="get_time">\n',
f'<{DSML_TOKEN}parameter name="timezone" string="true">EST</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
assert len(tool_1) == 1
assert len(tool_1[0].tool_calls) == 2
assert tool_1[0].tool_calls[0].name == "get_weather"
assert tool_1[0].tool_calls[1].name == "get_time"
# ── Add assistant + both tool results ──
messages.append(
{
"role": "assistant",
"content": "",
"reasoning_content": "Two requests: weather and time. I'll call both.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "NYC"}',
},
},
{
"type": "function",
"function": {
"name": "get_time",
"arguments": '{"timezone": "EST"}',
},
},
],
}
)
messages.append({"role": "tool", "content": "72°F, sunny"})
messages.append({"role": "tool", "content": "2:30 PM EST"})
prompt_2 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Verify multi-tool result encoding
# Count is 2: 1 in _TOOLS_SYSTEM_TEMPLATE example + 1 in conversation
assert prompt_2.count("<function_results>") == 2
assert prompt_2.count("</function_results>") == 2
assert "<result>72°F, sunny</result>" in prompt_2
assert "<result>2:30 PM EST</result>" in prompt_2
assert prompt_2.endswith(THINKING_START)
# ── Model thinks about results, answers ──
model_tokens_2 = [
THINKING_START,
"Got both results. Weather is 72°F sunny, time is 2:30 PM.",
THINKING_END,
"In NYC it's currently 72°F and sunny. The time in EST is 2:30 PM.",
]
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
gen_2 = [r for r in results_2 if isinstance(r, GenerationResponse)]
non_thinking_2 = [r for r in gen_2 if not r.is_thinking]
assert len(tool_2) == 0
final_text = "".join(r.text for r in non_thinking_2)
assert "72°F" in final_text
assert "2:30 PM" in final_text
def test_two_user_turns_reasoning_stripped(self):
"""Turn 2: old reasoning_content is stripped from history.
Per the vLLM spec, clear_reasoning_content is called between user turns
to save bandwidth. Our _drop_old_thinking handles this.
"""
# Full turn 1 conversation (already completed)
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Weather in Hangzhou?"},
{
"role": "assistant",
"content": "",
"reasoning_content": "I need to call get_weather for Hangzhou.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Hangzhou"}',
},
}
],
},
{"role": "tool", "content": "Cloudy 7~13°C"},
{
"role": "assistant",
"content": "The weather in Hangzhou is cloudy, 7-13°C.",
"reasoning_content": "The tool returned cloudy weather. I'll summarize.",
},
# Turn 2: user asks again
{"role": "user", "content": "What about Beijing?"},
]
prompt = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Old reasoning_content from turn 1 assistants should be STRIPPED
# (they're before the last user message at index 5)
assert "I need to call get_weather" not in prompt
assert "tool returned cloudy" not in prompt
# But the assistant's content and tool calls should still be there
assert "cloudy, 7-13°C" in prompt
assert TOOL_CALLS_START in prompt
# Prompt ends with <think> for the new turn
assert prompt.endswith(THINKING_START)
# ── Turn 2: Model thinks, calls tool for Beijing ──
model_tokens = [
THINKING_START,
"Now the user wants Beijing weather.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">Beijing</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results = list(parse_deepseek_v32(_simulate_tokens(model_tokens)))
tool_results = [r for r in results if isinstance(r, ToolCallResponse)]
assert len(tool_results) == 1
assert tool_results[0].tool_calls[0].name == "get_weather"
args = json.loads(tool_results[0].tool_calls[0].arguments) # pyright: ignore[reportAny]
assert args == {"city": "Beijing"}
def test_chained_tool_calls_loop(self):
"""Model calls tool, gets result, calls another tool, gets result, answers.
This simulates the inner while loop from the vLLM spec where the model
may need multiple sub-turns of tool calling before it has enough info.
"""
# ── Sub-turn 1: user asks, model calls get_time ──
messages: list[dict[str, Any]] = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "What's the weather in Hangzhou tomorrow?"},
]
prompt_1 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert prompt_1.endswith(THINKING_START)
# Model first calls get_time to figure out the date
model_tokens_1 = [
THINKING_START,
"I need the current date first to calculate tomorrow.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_time">\n',
f'<{DSML_TOKEN}parameter name="timezone" string="true">Asia/Shanghai</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results_1 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_1)))
tool_1 = [r for r in results_1 if isinstance(r, ToolCallResponse)]
assert len(tool_1) == 1
assert tool_1[0].tool_calls[0].name == "get_time"
# ── Sub-turn 2: add tool result, model calls get_weather ──
messages.append(
{
"role": "assistant",
"content": "",
"reasoning_content": "I need the current date first to calculate tomorrow.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_time",
"arguments": '{"timezone": "Asia/Shanghai"}',
},
}
],
}
)
messages.append({"role": "tool", "content": "2025-12-01 14:30 CST"})
prompt_2 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
assert "<result>2025-12-01 14:30 CST</result>" in prompt_2
assert prompt_2.endswith(THINKING_START)
# Model now knows the date, calls get_weather
model_tokens_2 = [
THINKING_START,
"Today is 2025-12-01, so tomorrow is 2025-12-02.",
" Now I can check weather for Hangzhou.",
THINKING_END,
"\n\n",
TOOL_CALLS_START,
"\n",
f'<{DSML_TOKEN}invoke name="get_weather">\n',
f'<{DSML_TOKEN}parameter name="city" string="true">Hangzhou</{DSML_TOKEN}parameter>\n',
f"</{DSML_TOKEN}invoke>\n",
TOOL_CALLS_END,
]
results_2 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_2)))
tool_2 = [r for r in results_2 if isinstance(r, ToolCallResponse)]
assert len(tool_2) == 1
assert tool_2[0].tool_calls[0].name == "get_weather"
# ── Sub-turn 3: add weather result, model answers ──
messages.append(
{
"role": "assistant",
"content": "",
"reasoning_content": "Today is 2025-12-01, so tomorrow is 2025-12-02. Now I can check weather for Hangzhou.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"city": "Hangzhou"}',
},
}
],
}
)
messages.append({"role": "tool", "content": "Sunny, 5~12°C"})
prompt_3 = encode_messages(
messages, tools=_WEATHER_TOOLS, thinking_mode="thinking"
)
# Should have both function_results blocks (one per tool round)
# Count is 3: 1 in _TOOLS_SYSTEM_TEMPLATE example + 2 in conversation
assert prompt_3.count("<function_results>") == 3
assert prompt_3.count("</function_results>") == 3
assert "<result>2025-12-01 14:30 CST</result>" in prompt_3
assert "<result>Sunny, 5~12°C</result>" in prompt_3
assert prompt_3.endswith(THINKING_START)
# Model finally answers
model_tokens_3 = [
THINKING_START,
"I have the weather for tomorrow in Hangzhou.",
THINKING_END,
"Tomorrow in Hangzhou will be sunny with temperatures between 5°C and 12°C.",
]
results_3 = list(parse_deepseek_v32(_simulate_tokens(model_tokens_3)))
tool_3 = [r for r in results_3 if isinstance(r, ToolCallResponse)]
gen_3 = [r for r in results_3 if isinstance(r, GenerationResponse)]
non_thinking_3 = [r for r in gen_3 if not r.is_thinking]
assert len(tool_3) == 0 # No more tool calls — loop ends
final_text = "".join(r.text for r in non_thinking_3)
assert "sunny" in final_text.lower()
assert "5°C" in final_text
assert "12°C" in final_text

View File

@@ -148,6 +148,7 @@ class MockTokenizer:
 tool_call_start = None
 tool_call_end = None
 has_tool_calling = False
+has_thinking = False
 class MockGroup:

View File

@@ -149,12 +149,23 @@ class TestParseGptOssThinkingThenToolCall:
 def test_thinking_then_tool_call(self):
 results = _collect(THINKING_THEN_TOOL_TOKENS)
-# Should have thinking tags + content + tool call
-text_parts = [r.text for r in results if isinstance(r, GenerationResponse)]
-combined = "".join(text_parts)
-assert "<think>" in combined
-assert "</think>" in combined
-assert "Let me think about this." in combined
+# Thinking tokens should have is_thinking=True and no <think> tags
+thinking_responses = [
+r for r in results if isinstance(r, GenerationResponse) and r.is_thinking
+]
+thinking_text = "".join(r.text for r in thinking_responses)
+assert "Let me think about this." in thinking_text
+assert "<think>" not in thinking_text
+assert "</think>" not in thinking_text
+# Non-thinking tokens should have is_thinking=False
+non_thinking = [
+r
+for r in results
+if isinstance(r, GenerationResponse) and not r.is_thinking
+]
+non_thinking_text = "".join(r.text for r in non_thinking)
+assert "<think>" not in non_thinking_text
 # And the tool call
 tc = _get_tool_call(results)

View File

@@ -0,0 +1,8 @@
#!/bin/bash
# Run Claude Code against a local exo cluster! (Here, GPT OSS 120B)
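# Assumes an exo cluster is already running locally and serving its
# Anthropic-compatible API on the default port 52415; swap the model IDs
# below for whatever your cluster has loaded.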
ANTHROPIC_BASE_URL="http://localhost:52415/" \
ANTHROPIC_AUTH_TOKEN="dummy" \
ANTHROPIC_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
ANTHROPIC_SMALL_FAST_MODEL="mlx-community/gpt-oss-120b-MXFP4-Q8" \
CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 \
claude