Compare commits

...

3 Commits

Author SHA1 Message Date
Alex Cheema
025ed9fd82 feat: add prefill progress bar for long prompts (#1181)
## Motivation

Users processing long prompts have no visibility into when token
generation will start. This feature adds a progress bar showing prefill
progress, giving users real-time feedback during prompt processing.

## Changes

### Backend
- Added `PrefillProgress` event type with `command_id`, `model`,
`processed_tokens`, `total_tokens`
- Added `PrefillProgressResponse` runner response type (though the final
implementation uses the direct callback approach)
- Wired `prompt_progress_callback` through MLX's `stream_generate()`
- Progress events are sent directly from the callback for real-time
updates (not batched)
- API emits SSE comment lines (`: prefill_progress {...}`) that
third-party clients ignore
- Added `PrefillProgressChunk` to the shared chunk types and extended
the `GenerationChunk` union

### Dashboard
- Added `PrefillProgress` interface to store
- Updated SSE parsing to handle comment lines (`: prefill_progress {...}`)
- Created `PrefillProgressBar.svelte` with animated progress bar
- Shows "Processing prompt: X/Y tokens" with percentage
- Progress bar disappears when first token arrives

## Why It Works

MLX's `stream_generate()` accepts a `prompt_progress_callback(processed,
total)` that's called after each prefill chunk. By sending events
directly from this callback (rather than yielding from the generator),
progress updates are sent in real-time during prefill.
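
For reference, a minimal sketch of what that wiring looks like when calling `stream_generate()` directly, assuming the `prompt_progress_callback` and `prefill_step_size` keyword arguments behave as described above (the model name and the print statements are illustrative only, not exo code):

```python
from mlx_lm import load, stream_generate

# Illustrative model; any MLX-format model works the same way.
model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

def on_prefill_progress(processed: int, total: int) -> None:
    # Invoked after each prefill chunk, before any token is generated.
    # exo forwards these values as PrefillProgress events instead of printing.
    print(f"prefill: {processed}/{total} tokens")

for response in stream_generate(
    model,
    tokenizer,
    prompt="<paste a very long prompt here>",
    prompt_progress_callback=on_prefill_progress,
    prefill_step_size=256,  # smaller steps -> more frequent progress callbacks
):
    print(response.text, end="", flush=True)
```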

Emitting progress as SSE comment lines (`: prefill_progress {...}`)
maintains full OpenAI/Claude API compatibility - spec-compliant SSE
clients drop comment lines they don't recognize, while the exo dashboard
explicitly parses them.
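
Roughly, the stream then interleaves comment lines with ordinary data lines. A sketch of the framing and of a tolerant line handler (payload fields follow the diff below; the handler itself is illustrative):

```python
# Wire format sketch (not verbatim server output):
#
#   : prefill_progress {"processed_tokens": 512, "total_tokens": 2048}
#   data: {"choices": [{"delta": {"content": "Hello"}}]}
#   data: [DONE]

def handle_sse_line(line: str) -> None:
    line = line.strip()
    if line.startswith(": "):
        # SSE comment: spec-compliant clients drop this; only the exo
        # dashboard looks inside it for prefill progress.
        tag, _, payload = line[2:].partition(" ")
        if tag == "prefill_progress":
            print("prefill progress:", payload)
        return
    if line.startswith("data: "):
        data = line[len("data: "):]
        if data == "[DONE]":
            return
        print("token chunk:", data)
```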

## Test Plan

### Manual Testing
- Hardware: MacBook Pro M3 Max
- Set `prefill_step_size=256` for more frequent updates
- Tested with long prompts (pasted large documents)
- Verified progress bar updates incrementally during prefill
- Confirmed progress bar disappears when generation starts
- Tested with curl - standard `data:` events still work normally

Here it is working:


https://github.com/user-attachments/assets/5cc6f075-c5b2-4a44-bb4d-9efb246bc5fe


### Automated Testing
- Type checker passes (0 errors)
- All 192 tests pass
- Dashboard builds successfully

### API Compatibility
- SSE comment lines are ignored by OpenAI SDK clients
- Regular token data uses standard `data: {...}` format
- `[DONE]` sentinel works as expected

---

**Note:** `prefill_step_size` is temporarily set to 256 for testing.
Should be changed back to 2048 before merging for production
performance.

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Evan <evanev7@gmail.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-19 03:18:25 +00:00
rltakashige
19bc09550d Add status=downloaded filter for model endpoint (#1539)
## Motivation

https://github.com/exo-explore/exo/issues/1346#issuecomment-3831427905
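
In practice the filter is just a query parameter on the models listing. A hypothetical usage sketch (the host, port, and `/v1/models` path are assumptions, not taken from this diff):

```python
import httpx

with httpx.Client(base_url="http://localhost:52415") as client:
    # Previous behaviour: every model card exo knows about.
    all_models = client.get("/v1/models").json()["data"]

    # New: only models with a completed download on at least one node.
    downloaded = client.get(
        "/v1/models", params={"status": "downloaded"}
    ).json()["data"]

    print(f"{len(downloaded)} of {len(all_models)} models are downloaded")
```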


## Test Plan

### Manual Testing
**Without filter**
<img width="1708" height="1010" alt="Screenshot 2026-02-18 at 22 26 22"
src="https://github.com/user-attachments/assets/f4bf7142-717d-4042-ac28-d8a55a8e45e7"
/>

**With filter**
<img width="1723" height="1021" alt="Screenshot 2026-02-18 at 22 26 45"
src="https://github.com/user-attachments/assets/40a522d5-c6e6-4148-b21a-02caa1221ebe"
/>
2026-02-18 22:34:11 +00:00
Alex Cheema
7cadca4f27 Try multiple endpoints for internet connectivity check (#1516)
## Summary
- `_test_internet_connection()` previously only tried `1.1.1.1:443`,
which some ISPs/networks block, causing exo to incorrectly report no
internet and fail downloads on startup
- Now tries `1.1.1.1`, `8.8.8.8`, and `1.0.0.1` in sequence, succeeding
if any endpoint responds
- Returns early on first success for minimal latency in the common case

Fixes #1425

## Test plan
- [ ] Verify downloads work on networks that block `1.1.1.1`
- [ ] Verify existing behavior unchanged on networks where `1.1.1.1`
works
- [ ] Verify `internet_connection` is set to `False` only when all three
endpoints fail
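
A hedged sketch of how the fallback ordering could be checked automatically; the helper below re-implements the endpoint loop from this diff for illustration rather than importing exo's `DownloadCoordinator`:

```python
import socket
from unittest import mock

def first_reachable(
    hosts: tuple[str, ...] = ("1.1.1.1", "8.8.8.8", "1.0.0.1"),
) -> str | None:
    """Simplified stand-in for _test_internet_connection's endpoint loop."""
    for host in hosts:
        try:
            socket.create_connection((host, 443), timeout=3).close()
            return host  # early return on first success
        except OSError:
            continue
    return None

def test_falls_back_when_first_endpoint_is_blocked() -> None:
    attempts: list[str] = []

    def fake_create_connection(address, timeout=None):
        attempts.append(address[0])
        if address[0] == "1.1.1.1":
            raise OSError("blocked by ISP")
        return mock.MagicMock()  # provides .close()

    with mock.patch.object(socket, "create_connection", fake_create_connection):
        assert first_reachable() == "8.8.8.8"

    # Stops probing after the first endpoint that responds.
    assert attempts == ["1.1.1.1", "8.8.8.8"]
```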

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-18 22:10:07 +00:00
19 changed files with 534 additions and 109 deletions

View File

@@ -38,6 +38,8 @@ class Scenario:
expected_function: str | None = None
required_arg_keys: list[str] | None = None
tool_result: str | None = None
nested_array_key: str | None = None
required_item_keys: list[str] | None = None
def load_scenarios(path: Path) -> list[Scenario]:
@@ -105,6 +107,8 @@ def load_scenarios(path: Path) -> list[Scenario]:
expected_function=s.get("expected_function"),
required_arg_keys=s.get("required_arg_keys"),
tool_result=tool_result,
nested_array_key=s.get("nested_array_key"),
required_item_keys=s.get("required_item_keys"),
)
)
@@ -147,6 +151,35 @@ def validate_args(args_str: str, required_keys: list[str]) -> tuple[bool, str |
return True, None
def validate_nested_args(
args_str: str,
array_key: str,
required_item_keys: list[str],
) -> tuple[bool, str | None]:
"""Check that args[array_key] is a list of objects with required keys."""
try:
args = json.loads(args_str)
except (json.JSONDecodeError, TypeError) as exc:
return False, f"Invalid JSON: {exc}"
if not isinstance(args, dict):
return False, f"Expected dict, got {type(args).__name__}"
arr = args.get(array_key)
if not isinstance(arr, list):
return False, f"'{array_key}' is not an array (got {type(arr).__name__})"
if len(arr) == 0:
return False, f"'{array_key}' is empty"
for i, item in enumerate(arr):
if not isinstance(item, dict):
return (
False,
f"'{array_key}[{i}]' is not an object (got {type(item).__name__})",
)
missing = [k for k in required_item_keys if k not in item]
if missing:
return False, f"'{array_key}[{i}]' missing keys: {missing}"
return True, None
def call_api(
client: httpx.Client,
host: str,
@@ -699,6 +732,15 @@ def run_scenario(
checks["valid_arguments"] = ok
else:
checks["valid_arguments"] = True
if scenario.nested_array_key and scenario.required_item_keys:
ok, nested_err = validate_nested_args(
parsed.tool_call["arguments"],
scenario.nested_array_key,
scenario.required_item_keys,
)
checks["valid_nested_structure"] = ok
if not ok:
args_err = nested_err
else:
checks["correct_function"] = False
checks["valid_arguments"] = False

View File

@@ -39,6 +39,30 @@ description = "Product category to filter by"
type = "number"
description = "Maximum price in USD"
[tools.create_todos]
description = "Create a structured todo list"
required = ["todos"]
[tools.create_todos.properties.todos]
type = "array"
description = "List of todo items"
[tools.create_todos.properties.todos.items]
type = "object"
required = ["content", "status", "priority"]
[tools.create_todos.properties.todos.items.properties.content]
type = "string"
description = "The todo item text"
[tools.create_todos.properties.todos.items.properties.status]
type = "string"
description = "Status: pending, in_progress, or completed"
[tools.create_todos.properties.todos.items.properties.priority]
type = "string"
description = "Priority: low, normal, or high"
# -- Should call a tool --
[[scenarios]]
@@ -219,6 +243,48 @@ role = "tool"
tool_call_id = "call_4"
content = '{"temperature": "18C", "condition": "cloudy"}'
# -- Nested object schema (regression for lossy chat template rendering) --
[[scenarios]]
name = "nested_schema_tool_call"
description = "Tool call with nested object array schema -> create_todos"
expect_tool_call = true
expected_function = "create_todos"
required_arg_keys = ["todos"]
nested_array_key = "todos"
required_item_keys = ["content", "status", "priority"]
tools = ["create_todos"]
[[scenarios.messages]]
role = "user"
content = "Create a todo list with 3 items to learn Python"
# -- Tool name integrity (regression for harmony token leaking into name) --
[tools.glob]
description = "Search for files matching a glob pattern in the codebase"
required = ["pattern"]
[tools.glob.properties.pattern]
type = "string"
description = "The glob pattern to match files against, e.g. '**/*.py'"
[tools.glob.properties.path]
type = "string"
description = "The directory to search in"
[[scenarios]]
name = "tool_name_integrity"
description = "Tool name must not contain harmony tokens like <|channel|>"
expect_tool_call = true
expected_function = "glob"
required_arg_keys = ["pattern"]
tools = ["glob"]
[[scenarios.messages]]
role = "user"
content = "Find all Python files in the src directory"
# -- Should NOT call a tool --
[[scenarios]]

View File

@@ -3,16 +3,17 @@
messages,
currentResponse,
isLoading,
prefillProgress,
deleteMessage,
editAndRegenerate,
regenerateLastResponse,
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
import PrefillProgressBar from "./PrefillProgressBar.svelte";
import ImageLightbox from "./ImageLightbox.svelte";
interface Props {
@@ -25,6 +26,7 @@
const messageList = $derived(messages());
const response = $derived(currentResponse());
const loading = $derived(isLoading());
const prefill = $derived(prefillProgress());
// Scroll management - user controls scroll, show button when not at bottom
const SCROLL_THRESHOLD = 100;
@@ -428,6 +430,9 @@
{:else}
<!-- Assistant message styling -->
<div class="p-3 sm:p-4">
{#if loading && isLastAssistantMessage(message.id) && prefill && !message.content}
<PrefillProgressBar progress={prefill} class="mb-3" />
{/if}
{#if message.thinking && message.thinking.trim().length > 0}
<div
class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40"

View File

@@ -26,7 +26,8 @@
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
function formatNumber(num: number | undefined): string {
if (num == null) return "0";
if (num >= 1000000) {
return `${(num / 1000000).toFixed(1)}M`;
} else if (num >= 1000) {

View File

@@ -0,0 +1,52 @@
<script lang="ts">
import type { PrefillProgress } from "$lib/stores/app.svelte";
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = "" }: Props = $props();
const percentage = $derived(
progress.total > 0
? Math.round((progress.processed / progress.total) * 100)
: 0,
);
function formatTokenCount(count: number | undefined): string {
if (count == null) return "0";
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div
class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
>
<span>Processing prompt</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(
progress.total,
)} tokens
</span>
</div>
<div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
<div
class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -273,6 +273,11 @@ export interface TokenData {
topLogprobs: TopLogprob[];
}
export interface PrefillProgress {
processed: number;
total: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -520,6 +525,7 @@ class AppStore {
ttftMs = $state<number | null>(null); // Time to first token in ms
tps = $state<number | null>(null); // Tokens per second
totalTokens = $state<number>(0); // Total tokens in current response
prefillProgress = $state<PrefillProgress | null>(null);
// Topology state
topologyData = $state<TopologyData | null>(null);
@@ -2005,6 +2011,7 @@ class AppStore {
reader: ReadableStreamDefaultReader<Uint8Array>,
targetConversationId: string,
onChunk: (parsed: T) => void,
onEvent?: Record<string, (data: unknown) => void>,
): Promise<void> {
const decoder = new TextDecoder();
let buffer = "";
@@ -2025,6 +2032,24 @@ class AppStore {
const trimmed = line.trim();
if (!trimmed) continue;
// Handle SSE comments (": key json") for prefill progress etc.
if (trimmed.startsWith(": ") && onEvent) {
const comment = trimmed.slice(2);
const spaceIdx = comment.indexOf(" ");
if (spaceIdx > 0) {
const key = comment.slice(0, spaceIdx);
if (onEvent[key]) {
try {
const parsed = JSON.parse(comment.slice(spaceIdx + 1));
onEvent[key](parsed);
} catch {
// Skip malformed JSON in comment
}
}
}
continue;
}
if (trimmed.startsWith("data: ")) {
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
@@ -2309,6 +2334,11 @@ class AppStore {
reader,
targetConversationId,
(parsed) => {
// Clear prefill progress when first token data arrives
if (this.prefillProgress) {
this.prefillProgress = null;
}
const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content;
@@ -2371,8 +2401,26 @@ class AppStore {
this.persistConversation(targetConversationId);
}
},
{
prefill_progress: (data) => {
// TaggedModel wraps as {"PrefillProgressChunk": {...}}
// model_dump_json() uses snake_case (by_alias defaults to False)
const raw = data as Record<string, unknown>;
const inner = (raw["PrefillProgressChunk"] ?? raw) as {
processed_tokens: number;
total_tokens: number;
};
this.prefillProgress = {
processed: inner.processed_tokens,
total: inner.total_tokens,
};
},
},
);
// Clear prefill progress after stream ends
this.prefillProgress = null;
// Calculate final TPS
if (firstTokenTime !== null && tokenCount > 1) {
const totalGenerationTime = performance.now() - firstTokenTime;
@@ -3043,6 +3091,7 @@ export const isLoading = () => appStore.isLoading;
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const prefillProgress = () => appStore.prefillProgress;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const runners = () => appStore.runners;

View File

@@ -932,13 +932,6 @@
};
}
// Debug: Log downloads data when it changes
$effect(() => {
if (downloadsData && Object.keys(downloadsData).length > 0) {
console.log("[Download Debug] Current downloads:", downloadsData);
}
});
// Helper to get download status for an instance
function getInstanceDownloadStatus(
instanceId: string,

View File

@@ -123,14 +123,17 @@ class DownloadCoordinator:
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
# Try multiple endpoints since some ISPs/networks block specific IPs
for host in ("1.1.1.1", "8.8.8.8", "1.0.0.1"):
try:
socket.create_connection((host, 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
logger.debug(f"Internet connectivity: True (via {host})")
return
except OSError:
continue
self.shard_downloader.set_internet_connection(False)
logger.debug("Internet connectivity: False")
async def _check_internet_connection(self) -> None:
first_connection = True

View File

@@ -19,7 +19,12 @@ from exo.shared.types.api import (
ToolCall,
Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
@@ -123,67 +128,81 @@ def chunk_to_response(
async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
PrefillProgressChunk | ErrorChunk | ToolCallChunk | TokenChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
match chunk:
case PrefillProgressChunk():
# Use SSE comment so third-party clients ignore it
yield f": prefill_progress {chunk.model_dump_json()}\n\n"
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
case ErrorChunk():
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
yield f"data: {chunk_response.model_dump_json()}\n\n"
case ToolCallChunk():
last_usage = chunk.usage or last_usage
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
tool_call_deltas = [
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
case TokenChunk():
last_usage = chunk.usage or last_usage
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(
update={"usage": last_usage}
)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -197,38 +216,43 @@ async def collect_chat_response(
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
match chunk:
case PrefillProgressChunk():
continue
if model is None:
model = chunk.model
case ErrorChunk():
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
case TokenChunk():
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=tool.id,
index=i,
function=tool,
case ToolCallChunk():
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
tool_calls.extend(
ToolCall(
id=tool.id,
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)

View File

@@ -5,7 +5,12 @@ from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.claude_api import (
ClaudeContentBlock,
ClaudeContentBlockDeltaEvent,
@@ -160,7 +165,9 @@ def claude_request_to_text_generation(
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -172,6 +179,9 @@ async def collect_claude_response(
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
@@ -230,7 +240,9 @@ async def collect_claude_response(
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
@@ -256,6 +268,9 @@ async def generate_claude_stream(
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
# Close text block and bail
break

View File

@@ -5,7 +5,12 @@ from itertools import count
from typing import Any
from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
FunctionCallInputItem,
@@ -121,7 +126,9 @@ def responses_request_to_text_generation(
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
# This is an AsyncGenerator[str] rather than returning a ChatCompletionReponse because
# FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
@@ -134,6 +141,9 @@ async def collect_responses_response(
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
@@ -189,7 +199,9 @@ async def collect_responses_response(
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
@@ -243,6 +255,9 @@ async def generate_responses_stream(
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, PrefillProgressChunk):
continue
if isinstance(chunk, ErrorChunk):
break

View File

@@ -107,6 +107,7 @@ from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
@@ -137,6 +138,7 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
TracesMerged,
)
from exo.shared.types.memory import Memory
@@ -145,6 +147,7 @@ from exo.shared.types.openai_responses import (
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -220,7 +223,8 @@ class API:
)
self._text_generation_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
CommandId,
Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk],
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
@@ -526,19 +530,23 @@ class API:
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
) -> AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk, None
]:
"""Yield chunks for a given command until completion.
This is the internal low-level stream used by all API adapters.
"""
try:
self._text_generation_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressChunk
]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if isinstance(chunk, PrefillProgressChunk):
continue
if chunk.finish_reason is not None:
break
@@ -565,6 +573,9 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._token_chunk_stream(command_id):
if isinstance(chunk, PrefillProgressChunk):
continue
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -1292,8 +1303,18 @@ class API:
return total_available
async def get_models(self) -> ModelList:
"""Returns list of available models."""
async def get_models(self, status: str | None = Query(default=None)) -> ModelList:
"""Returns list of available models, optionally filtered by being downloaded."""
cards = await get_model_cards()
if status == "downloaded":
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [c for c in cards if c.model_id in downloaded_model_ids]
return ModelList(
data=[
ModelListModel(
@@ -1311,7 +1332,7 @@ class API:
base_model=card.base_model,
capabilities=card.capabilities,
)
for card in await get_model_cards()
for card in cards
]
)
@@ -1435,6 +1456,21 @@ class API:
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
elif isinstance(event, PrefillProgress):
if queue := self._text_generation_queues.get(
event.command_id, None
):
try:
await queue.send(
PrefillProgressChunk(
model=event.model,
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)

View File

@@ -15,6 +15,7 @@ from exo.shared.types.events import (
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -64,6 +65,7 @@ def event_apply(event: Event, state: State) -> State:
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
| TracesCollected()
| TracesMerged()
): # Pass-through events that don't modify state

View File

@@ -76,4 +76,13 @@ class InputImageChunk(BaseChunk):
yield name, value
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
class PrefillProgressChunk(BaseChunk):
"""Data class for prefill progress events during streaming."""
processed_tokens: int
total_tokens: int
GenerationChunk = (
TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk | PrefillProgressChunk
)

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -102,6 +102,13 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
model: ModelId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
conn: Connection
@@ -148,6 +155,7 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected

View File

@@ -67,3 +67,8 @@ class ToolCallResponse(BaseRunnerResponse):
class FinishedResponse(BaseRunnerResponse):
pass
class PrefillProgressResponse(BaseRunnerResponse):
processed_tokens: int
total_tokens: int

View File

@@ -58,6 +58,7 @@ def prefill(
prompt_tokens: mx.array,
cache: KVCacheType,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None,
) -> tuple[float, int, list[CacheSnapshot]]:
"""Prefill the KV cache with prompt tokens.
@@ -85,6 +86,9 @@ def prefill(
if has_ssm:
snapshots.append(snapshot_ssm_states(cache))
if on_prefill_progress is not None:
on_prefill_progress(processed, total)
set_pipeline_prefill(model, is_prefill=True)
mx_barrier(group)
@@ -99,7 +103,7 @@ def prefill(
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=8192,
prefill_step_size=4096,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
@@ -257,6 +261,7 @@ def mlx_generate(
prompt: str,
kv_prefix_cache: KVPrefixCache | None,
group: mx.distributed.Group | None,
on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -311,7 +316,13 @@ def mlx_generate(
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches, group
model,
tokenizer,
sampler,
prompt_tokens[:-1],
caches,
group,
on_prefill_progress,
)
cache_snapshots: list[CacheSnapshot] | None = ssm_snapshots_list or None

View File

@@ -1,5 +1,6 @@
import json
import os
import re
import sys
import time
from pathlib import Path
@@ -407,6 +408,56 @@ def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
func["arguments"] = json.loads(args)
def _collect_nested_property_names(schema: dict[str, Any]) -> set[str]:
names: set[str] = set()
properties: dict[str, Any] = schema.get("properties", {}) # type: ignore[reportAny]
for prop_spec in properties.values(): # pyright: ignore[reportAny]
if not isinstance(prop_spec, dict):
continue
if prop_spec.get("type") == "array": # type: ignore[reportAny]
items: dict[str, Any] | None = prop_spec.get("items") # type: ignore[reportAny]
if isinstance(items, dict) and items.get("type") == "object": # type: ignore[reportAny]
inner_props: dict[str, Any] = items.get("properties", {}) # type: ignore[reportAny]
for k in inner_props: # pyright: ignore[reportUnknownVariableType]
names.add(str(k)) # pyright: ignore[reportUnknownArgumentType]
names.update(_collect_nested_property_names(items)) # pyright: ignore[reportUnknownArgumentType]
return names
def _schemas_lost_in_prompt(prompt: str, tools: list[dict[str, Any]]) -> bool:
"""Return True if nested property names from any tool schema are absent."""
for tool in tools:
fn: dict[str, Any] = tool.get("function", {}) # type: ignore
params: dict[str, Any] = fn.get("parameters", {}) # type: ignore
nested = _collect_nested_property_names(params)
if nested and not all(name in prompt for name in nested):
return True
return False
_LOSSY_TEMPLATE_PATTERN = re.compile(
r"""inner_type\s*==\s*["']object \| object["']\s*or\s*inner_type\|length\s*>\s*\d+""",
)
def _patch_lossy_chat_template(template: str) -> str | None:
"""Patch chat templates that collapse nested object schemas to ``any[]``.
Some templates (e.g., GPT-OSS) have a guard like::
inner_type == "object | object" or inner_type|length > 50
The length check silently drops complex array-of-object schemas.
We remove the length guard, keeping only the object-union check.
Returns the patched template, or *None* if no patch was needed.
"""
patched, n = _LOSSY_TEMPLATE_PATTERN.subn(
lambda m: m.group(0).split(" or ")[0], # keep only the object-union check
template,
)
return patched if n > 0 else None
def apply_chat_template(
tokenizer: TokenizerWrapper,
task_params: TextGenerationTaskParams,
@@ -453,14 +504,28 @@ def apply_chat_template(
extra_kwargs["enable_thinking"] = task_params.enable_thinking
extra_kwargs["thinking"] = task_params.enable_thinking
patched_template: str | None = None
if task_params.tools:
original_template: str | None = getattr(tokenizer, "chat_template", None)
if isinstance(original_template, str):
patched_template = _patch_lossy_chat_template(original_template)
if patched_template is not None:
logger.info(
"Patched lossy chat template (removed inner_type length guard)"
)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=task_params.tools,
**({"chat_template": patched_template} if patched_template is not None else {}),
**extra_kwargs,
)
if task_params.tools and _schemas_lost_in_prompt(prompt, task_params.tools):
logger.warning("Chat template lost nested tool schemas even after patching")
if partial_assistant_content:
prompt += partial_assistant_content

View File

@@ -26,6 +26,7 @@ from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
PrefillProgress,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
@@ -298,6 +299,18 @@ def main(
assert tokenizer
assert check_for_cancel_every
# Define callback to send prefill progress events directly
def on_prefill_progress(processed: int, total: int) -> None:
if device_rank == 0:
event_sender.send(
PrefillProgress(
command_id=command_id,
model=shard_metadata.model_card.model_id,
processed_tokens=processed,
total_tokens=total,
)
)
try:
_check_for_debug_prompts(task_params)
@@ -311,6 +324,7 @@ def main(
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
group=group,
)
@@ -599,11 +613,21 @@ def parse_gpt_oss(
ch = stream.current_channel
recipient = stream.current_recipient
# Debug: log every token with state
logger.debug(
f"parse_gpt_oss token={response.token} text={response.text!r} "
f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
f"state={stream.state} current_tool={current_tool_name!r}"
)
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
logger.info(
f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
)
yield ToolCallResponse(
tool_calls=[
ToolCallItem(