Compare commits

..

19 Commits

Author SHA1 Message Date
Alex Cheema
3a181fbf33 style: format app.svelte.ts with nix fmt 2026-01-19 14:55:42 +00:00
Alex Cheema
67d4f23c61 Fix localStorage quota issues by stripping tokens and auto-pruning
- Strip tokens (logprobs data) from messages before saving to localStorage
  since they're large and not essential for persistence
- Add pruneOldConversations() to automatically remove oldest conversations
  when quota is exceeded
- This prevents QuotaExceededError from crashing the app

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:55:42 +00:00
Alex Cheema
0b0d0f7faf Fix ReferenceError: controller undefined in sendMessage finally block
Move AbortController creation before the try block in both
sendMessageWithLogprobs and regenerateFromToken functions.
Previously, controller was defined inside the try block but
referenced in the finally block, causing a ReferenceError
if an exception was thrown before the controller was created.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:55:42 +00:00
Alex Cheema
28f7521540 Add SSE headers to properly close streaming connections
Add Cache-Control, Connection: close, and X-Accel-Buffering headers
to all SSE streaming responses. These headers help ensure:
- No caching of streaming responses
- Connection closes when stream ends (instead of keep-alive)
- No proxy buffering that could delay stream closure

This should fix the issue where the frontend stays on "PROCESSING"
even after receiving the complete response.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:55:42 +00:00
Alex Cheema
94f9a09f24 Add debug logging to generate_chat_stream
Add logging to help diagnose why streaming might not be ending properly.
This will show when [DONE] is yielded, when return is called, and when
the finally block runs.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:55:42 +00:00
Alex Cheema
2bf64ffd47 Fix streaming not ending after [DONE] is yielded
Add missing return statement after yielding [DONE] in generate_chat_stream.
Without this, the async generator continues waiting for more chunks from
chunk_stream even though generation is complete, causing the stream to hang
indefinitely. The frontend waits for the stream to close (reader.done) which
never happens, resulting in the chat button staying on "PROCESSING" forever.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:55:42 +00:00
Alex Cheema
d091c84dc5 fix: restore extract_top_logprobs function for uncertainty visualization
The extract_top_logprobs function was lost during rebases. This function
processes the out.logprobs array (full vocabulary logprobs from MLX) to
extract the selected token's logprob and top-k alternatives.

The previous code tried to use getattr(out, "logprob", None) which
doesn't exist - mlx_lm returns logprobs as an mx.array, not individual
values.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:55:42 +00:00
Alex Cheema
e2c15f76b0 fix: remove unsupported logprob params from stream_generate
The mlx_lm.stream_generate already returns logprobs in its output -
we don't need to pass return_logprob or return_top_logprobs kwargs.
The uncertainty visualization feature extracts logprobs from the
existing out.logprobs field.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:55:42 +00:00
Alex Cheema
7d77043217 feat: add uncertainty visualization with token-level logprobs
- Add TokenHeatmap component for visualizing token confidence
- Collect and stream logprobs in generation pipeline
- Add regenerate-from-token feature with continue_from_prefix
- Add AbortController for request cancellation
- Support continue_final_message for seamless prefix continuation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:55:42 +00:00
Alex Cheema
e1e4516a8f style: move inline imports to top of file in api.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 14:55:21 +00:00
Alex Cheema
8a67e949d1 fix: restore try/except structure in runner.py
Replace non-existent context manager with proper try/except block
and remove unused ModelId import.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 12:08:04 +00:00
Alex Cheema
fbf58bebd2 style: fix formatting issues caught by treefmt
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 12:06:15 +00:00
Alex Cheema
71b8e88d4b refactor: use ResponsesRequest as canonical internal type
- Extend ResponsesRequest with fields: top_k, seed, stop, tools
- Remove redundant InternalTaskParams and InputMessage types
- Update all adapters to convert to ResponsesRequest
- Simplify Responses API (no conversion needed - native passthrough)
- Update all imports across codebase and tests

This eliminates type duplication and makes the Responses API
relationship explicit throughout the codebase.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 12:06:14 +00:00
Alex Cheema
4b0ebb8ae4 refactor: make Responses API the canonical internal format
Restructure the API layer so that OpenAI Responses API is the native
format, with Chat Completions and Claude Messages as adapters on top.

Changes:
- Add new chat_completions.py adapter with streaming/non-streaming support
- Update responses.py with collect_responses_response() for non-streaming
- Update claude.py with collect_claude_response() for non-streaming
- Refactor api.py so all endpoints use adapters uniformly
- Rename _chat_chunk_stream to _token_chunk_stream (generic internal format)
- Remove unused chat_response_to_* converter functions
- Update tests to remove tests for deleted functions

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 12:06:14 +00:00
Alex Cheema
4df036d796 feat: add Claude Messages API and OpenAI Responses API support
Adds two new API endpoints that wrap the existing chat completions:

- /v1/messages - Claude Messages API compatible endpoint
- /v1/responses - OpenAI Responses API compatible endpoint

Both support streaming (SSE) and non-streaming modes with proper
token usage reporting from actual inference stats.

Also adds top_k sampling parameter and stop sequence support to the
MLX inference engine.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 12:06:14 +00:00
Evan Quiney
746589ba6b tidy: remove context manager from api (#1199) 2026-01-19 11:58:13 +00:00
rltakashige
f82f862fd7 Fix several issues with placement (#1200)
## Motivation

Uneven placements were causing issues for some users with lopsided
setups. While fixing, I ran into another issue with impossible
allocation of memory.

## Changes

- Allocate at least 1 layer per device.
- Catch overallocation of memory with an error.

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
Tested that GPT OSS is placed correctly.

### Automated Testing
Added breaking tests in the first commit. Resolved with new placement
algorithm in the second one.
2026-01-19 11:52:35 +00:00
Alex Cheema
7ff937d8a1 Add dashboard screenshots to README (#1185)
## Motivation

The README showcases exo's features and benchmarks but doesn't show what
the dashboard actually looks like. Adding a screenshot helps users
understand what they'll get when they run exo.

## Changes

- Added dashboard screenshot to `docs/imgs/dashboard-cluster-view.png`:
Shows the cluster topology view with 4 × 512GB M3 Ultra Mac Studio
running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)
- Added a new "Dashboard" section to README.md below Features,
displaying the screenshot with caption

## Why It Works

Visual documentation helps users understand what exo offers before they
install it. The screenshot demonstrates the cluster management
capabilities.

## Test Plan

### Manual Testing
- Verified image renders correctly in GitHub markdown preview

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 10:43:27 +00:00
Evan Quiney
d19bf02404 re-raise exceptions in the runner (#1198)
## Motivation

Runners that crash can swallow errors - we should re-raise. Also the
exception handler annoyed me.

## Changes

The try: except in the runner's chat now re-raises.
2026-01-19 10:35:23 +00:00
39 changed files with 2794 additions and 1021 deletions

View File

@@ -27,6 +27,15 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Dashboard
exo includes a built-in dashboard for managing your cluster and chatting with models.
<p align="center">
<img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
</p>
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
## Benchmarks
<details>

View File

@@ -1,14 +1,16 @@
<script lang="ts">
import {
messages,
currentResponse,
import {
messages,
currentResponse,
isLoading,
deleteMessage,
editAndRegenerate,
regenerateLastResponse
regenerateLastResponse,
regenerateFromToken
} from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import MarkdownContent from './MarkdownContent.svelte';
import TokenHeatmap from './TokenHeatmap.svelte';
interface Props {
class?: string;
@@ -95,6 +97,23 @@
let copiedMessageId = $state<string | null>(null);
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
// Uncertainty view state - tracks which messages show token heatmap
let uncertaintyViewMessageIds = $state<Set<string>>(new Set());
function toggleUncertaintyView(messageId: string) {
const newSet = new Set(uncertaintyViewMessageIds);
if (newSet.has(messageId)) {
newSet.delete(messageId);
} else {
newSet.add(messageId);
}
uncertaintyViewMessageIds = newSet;
}
function isUncertaintyViewEnabled(messageId: string): boolean {
return uncertaintyViewMessageIds.has(messageId);
}
function formatTimestamp(timestamp: number): string {
return new Date(timestamp).toLocaleTimeString('en-US', {
hour12: false,
@@ -366,7 +385,17 @@ function isThinkingExpanded(messageId: string): boolean {
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if message.role === 'assistant' && isUncertaintyViewEnabled(message.id) && message.tokens && message.tokens.length > 0}
<!-- Uncertainty heatmap view -->
<TokenHeatmap
tokens={message.tokens}
isGenerating={loading}
onRegenerateFrom={(tokenIndex) => regenerateFromToken(message.id, tokenIndex)}
/>
{:else}
<!-- Normal markdown view -->
<MarkdownContent content={message.content || (loading ? response : '')} />
{/if}
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
@@ -419,7 +448,20 @@ function isThinkingExpanded(messageId: string): boolean {
</svg>
</button>
{/if}
<!-- Uncertainty view toggle (assistant messages with tokens only) -->
{#if message.role === 'assistant' && message.tokens && message.tokens.length > 0}
<button
onclick={() => toggleUncertaintyView(message.id)}
class="p-1.5 transition-colors rounded cursor-pointer {isUncertaintyViewEnabled(message.id) ? 'text-exo-yellow' : 'text-exo-light-gray hover:text-exo-yellow'}"
title={isUncertaintyViewEnabled(message.id) ? 'Hide uncertainty' : 'Show uncertainty'}
>
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
</svg>
</button>
{/if}
<!-- Delete button -->
<button
onclick={() => handleDeleteClick(message.id)}

View File

@@ -0,0 +1,192 @@
<script lang="ts">
import type { TokenData } from '$lib/stores/app.svelte';
interface Props {
tokens: TokenData[];
class?: string;
isGenerating?: boolean;
onRegenerateFrom?: (tokenIndex: number) => void;
}
let { tokens, class: className = '', isGenerating = false, onRegenerateFrom }: Props = $props();
// Tooltip state - track both token data and index
let hoveredTokenIndex = $state<number | null>(null);
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
let isTooltipHovered = $state(false);
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
// Derive the hovered token from the index (stable across re-renders)
const hoveredToken = $derived(
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
? { token: tokens[hoveredTokenIndex], index: hoveredTokenIndex, ...hoveredPosition }
: null
);
/**
* Get confidence styling based on probability.
* Following Apple design principles: high confidence tokens blend in,
* only uncertainty draws attention.
*/
function getConfidenceClass(probability: number): string {
if (probability > 0.8) return 'text-inherit'; // Expected tokens - blend in
if (probability > 0.5) return 'bg-gray-500/10 text-inherit'; // Slight hint
if (probability > 0.2) return 'bg-amber-500/15 text-amber-200/90'; // Subtle warmth
return 'bg-red-500/20 text-red-200/90'; // Draws attention
}
/**
* Get border/underline styling for uncertain tokens
*/
function getBorderClass(probability: number): string {
if (probability > 0.8) return 'border-transparent'; // No border for expected
if (probability > 0.5) return 'border-gray-500/20';
if (probability > 0.2) return 'border-amber-500/30';
return 'border-red-500/40';
}
function clearHideTimeout() {
if (hideTimeoutId) {
clearTimeout(hideTimeoutId);
hideTimeoutId = null;
}
}
function handleMouseEnter(event: MouseEvent, token: TokenData, index: number) {
clearHideTimeout();
const rect = (event.target as HTMLElement).getBoundingClientRect();
hoveredTokenIndex = index;
hoveredPosition = {
x: rect.left + rect.width / 2,
y: rect.top - 10
};
}
function handleMouseLeave() {
clearHideTimeout();
// Use longer delay during generation to account for re-renders
const delay = isGenerating ? 300 : 100;
hideTimeoutId = setTimeout(() => {
if (!isTooltipHovered) {
hoveredTokenIndex = null;
hoveredPosition = null;
}
}, delay);
}
function handleTooltipEnter() {
clearHideTimeout();
isTooltipHovered = true;
}
function handleTooltipLeave() {
isTooltipHovered = false;
hoveredTokenIndex = null;
hoveredPosition = null;
}
function handleRegenerate() {
if (hoveredToken && onRegenerateFrom) {
const indexToRegenerate = hoveredToken.index;
// Clear hover state immediately
hoveredTokenIndex = null;
hoveredPosition = null;
isTooltipHovered = false;
// Call regenerate
onRegenerateFrom(indexToRegenerate);
}
}
function formatProbability(prob: number): string {
return (prob * 100).toFixed(1) + '%';
}
function formatLogprob(logprob: number): string {
return logprob.toFixed(3);
}
function getProbabilityColor(probability: number): string {
if (probability > 0.8) return 'text-gray-300';
if (probability > 0.5) return 'text-gray-400';
if (probability > 0.2) return 'text-amber-400';
return 'text-red-400';
}
</script>
<div class="token-heatmap leading-relaxed {className}">
{#each tokens as tokenData, i (i)}
<span
role="button"
tabindex="0"
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(tokenData.probability)} {getBorderClass(tokenData.probability)} hover:opacity-80"
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
onmouseleave={handleMouseLeave}
>{tokenData.token}</span>
{/each}
</div>
<!-- Tooltip -->
{#if hoveredToken}
<div
class="fixed z-50"
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
onmouseenter={handleTooltipEnter}
onmouseleave={handleTooltipLeave}
>
<div class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48">
<!-- Token info -->
<div class="mb-2">
<span class="text-gray-500 text-xs">Token:</span>
<span class="text-white font-mono ml-1">"{hoveredToken.token.token}"</span>
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2">{formatProbability(hoveredToken.token.probability)}</span>
</div>
<div class="text-gray-400 text-xs mb-1">
logprob: <span class="text-gray-300 font-mono">{formatLogprob(hoveredToken.token.logprob)}</span>
</div>
<!-- Top alternatives -->
{#if hoveredToken.token.topLogprobs.length > 0}
<div class="border-t border-gray-700/50 mt-2 pt-2">
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
{@const altProb = Math.exp(alt.logprob)}
<div class="flex justify-between items-center text-xs py-0.5">
<span class="text-gray-300 font-mono truncate max-w-24">"{alt.token}"</span>
<span class="text-gray-400 ml-2">{formatProbability(altProb)}</span>
</div>
{/each}
</div>
{/if}
<!-- Regenerate button -->
{#if onRegenerateFrom}
<button
onclick={handleRegenerate}
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
>
<svg class="w-3 h-3" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15" />
</svg>
Regenerate from here
</button>
{/if}
</div>
<!-- Arrow -->
<div class="absolute left-1/2 -translate-x-1/2 top-full">
<div class="border-8 border-transparent border-t-gray-900"></div>
</div>
</div>
{/if}
<style>
.token-heatmap {
word-wrap: break-word;
white-space: pre-wrap;
}
.token-span {
margin: 0;
border-width: 1px;
}
</style>

View File

@@ -69,8 +69,6 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
interface RawNodeProfile {
@@ -184,6 +182,20 @@ export interface MessageAttachment {
mimeType?: string;
}
// Token-level data for uncertainty visualization
export interface TopLogprob {
token: string;
logprob: number;
bytes?: number[];
}
export interface TokenData {
token: string;
logprob: number;
probability: number; // exp(logprob)
topLogprobs: TopLogprob[];
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -193,6 +205,7 @@ export interface Message {
attachments?: MessageAttachment[];
ttftMs?: number; // Time to first token in ms (for assistant messages)
tps?: number; // Tokens per second (for assistant messages)
tokens?: TokenData[]; // Token-level data for uncertainty visualization
}
export interface Conversation {
@@ -370,6 +383,21 @@ class AppStore {
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private currentRequestController: AbortController | null = null;
/**
* Abort any in-flight generation request
*/
abortCurrentRequest(): boolean {
if (this.currentRequestController) {
this.currentRequestController.abort();
this.currentRequestController = null;
this.isLoading = false;
this.currentResponse = "";
return true;
}
return false;
}
constructor() {
if (browser) {
@@ -407,12 +435,61 @@ class AppStore {
/**
* Save conversations to localStorage
* Note: We strip tokens (logprobs data) to save space - they're large and not essential for persistence
*/
private saveConversationsToStorage() {
try {
localStorage.setItem(STORAGE_KEY, JSON.stringify(this.conversations));
// Strip tokens from messages to save localStorage space
const conversationsToSave = this.conversations.map((conv) => ({
...conv,
messages: conv.messages.map((msg) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { tokens, ...msgWithoutTokens } = msg;
return msgWithoutTokens;
}),
}));
localStorage.setItem(STORAGE_KEY, JSON.stringify(conversationsToSave));
} catch (error) {
console.error("Failed to save conversations:", error);
// If quota exceeded, try to clear old conversations and retry
if (
error instanceof DOMException &&
error.name === "QuotaExceededError"
) {
console.warn(
"Storage quota exceeded, clearing oldest conversations...",
);
this.pruneOldConversations();
}
}
}
/**
* Remove oldest conversations to free up storage space
*/
private pruneOldConversations() {
if (this.conversations.length <= 1) return;
// Sort by updatedAt and remove oldest half
const sorted = [...this.conversations].sort(
(a, b) => (b.updatedAt || 0) - (a.updatedAt || 0),
);
const keepCount = Math.max(1, Math.ceil(sorted.length / 2));
this.conversations = sorted.slice(0, keepCount);
// Try saving again
try {
const conversationsToSave = this.conversations.map((conv) => ({
...conv,
messages: conv.messages.map((msg) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { tokens, ...msgWithoutTokens } = msg;
return msgWithoutTokens;
}),
}));
localStorage.setItem(STORAGE_KEY, JSON.stringify(conversationsToSave));
} catch {
console.error("Still failed to save after pruning");
}
}
@@ -1333,6 +1410,11 @@ class AppStore {
const assistantMessage = this.addMessage("assistant", "");
this.updateActiveConversation();
// Create abort controller for this request - must be defined before try block
// so it's available in the finally block
const controller = new AbortController();
this.currentRequestController = controller;
try {
// Build the messages array for the API with system prompt
const systemPrompt = {
@@ -1410,7 +1492,10 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
signal: controller.signal,
});
if (!response.ok) {
@@ -1426,6 +1511,7 @@ class AppStore {
const decoder = new TextDecoder();
let fullContent = "";
let buffer = "";
const collectedTokens: TokenData[] = [];
while (true) {
const { done, value } = await reader.read();
@@ -1447,8 +1533,8 @@ class AppStore {
try {
const parsed = JSON.parse(data);
const tokenContent = parsed.choices?.[0]?.delta?.content;
if (tokenContent) {
const delta = parsed.choices?.[0]?.delta?.content;
if (delta) {
// Track first token for TTFT
if (firstTokenTime === null) {
firstTokenTime = performance.now();
@@ -1465,7 +1551,30 @@ class AppStore {
this.tps = (tokenCount / elapsed) * 1000;
}
fullContent += tokenContent;
// Extract logprobs for uncertainty visualization
const logprobsData = parsed.choices?.[0]?.logprobs;
if (logprobsData?.content?.[0]) {
const logprobItem = logprobsData.content[0];
const tokenData: TokenData = {
token: logprobItem.token || delta,
logprob: logprobItem.logprob ?? 0,
probability: Math.exp(logprobItem.logprob ?? 0),
topLogprobs: (logprobItem.top_logprobs || []).map(
(item: {
token: string;
logprob: number;
bytes?: number[];
}) => ({
token: item.token,
logprob: item.logprob,
bytes: item.bytes,
}),
),
};
collectedTokens.push(tokenData);
}
fullContent += delta;
// Strip thinking tags for display and extract thinking content
const { displayContent, thinkingContent } =
@@ -1479,6 +1588,7 @@ class AppStore {
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
this.messages[idx].tokens = [...collectedTokens];
}
this.persistActiveConversation();
}
@@ -1526,9 +1636,16 @@ class AppStore {
if (this.tps !== null) {
this.messages[idx].tps = this.tps;
}
if (collectedTokens.length > 0) {
this.messages[idx].tokens = collectedTokens;
}
}
this.persistActiveConversation();
} catch (error) {
// Don't show error for aborted requests (user cancelled)
if (error instanceof Error && error.name === "AbortError") {
return;
}
console.error("Error sending message:", error);
// Update the assistant message with error
const idx = this.messages.findIndex((m) => m.id === assistantMessage.id);
@@ -1538,6 +1655,237 @@ class AppStore {
}
this.persistActiveConversation();
} finally {
// Clean up controller if this is still the active request
if (this.currentRequestController === controller) {
this.currentRequestController = null;
}
this.isLoading = false;
this.currentResponse = "";
this.updateActiveConversation();
}
}
/**
* Regenerate from a specific token in an assistant message.
* Keeps content up to and including the specified token, then continues generation.
* If a generation is already in progress, it will be aborted first.
*/
async regenerateFromToken(
messageId: string,
tokenIndex: number,
): Promise<void> {
// Abort any in-flight request first
this.abortCurrentRequest();
const messageIdx = this.messages.findIndex((m) => m.id === messageId);
if (messageIdx === -1) return;
const message = this.messages[messageIdx];
if (message.role !== "assistant" || !message.tokens) return;
// Get tokens up to and including the specified index
const tokensToKeep = message.tokens.slice(0, tokenIndex + 1);
const prefixText = tokensToKeep.map((t) => t.token).join("");
// Remove all messages after this assistant message
this.messages = this.messages.slice(0, messageIdx + 1);
// Update the message to show the prefix
this.messages[messageIdx].content = prefixText;
this.messages[messageIdx].tokens = tokensToKeep;
// Set up for continuation
this.isLoading = true;
this.currentResponse = prefixText;
this.ttftMs = null;
this.tps = null;
this.totalTokens = tokensToKeep.length;
// Create abort controller before try block so it's available in finally
const controller = new AbortController();
this.currentRequestController = controller;
try {
// Build messages for API - include the partial assistant message
const systemPrompt = {
role: "system" as const,
content:
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
};
// Get all messages up to and including the one we're regenerating from
const apiMessages = [
systemPrompt,
...this.messages.map((m) => {
let msgContent = m.content;
if (m.attachments) {
for (const attachment of m.attachments) {
if (attachment.type === "text" && attachment.content) {
msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
}
}
}
return { role: m.role, content: msgContent };
}),
];
// Determine model
let modelToUse = this.selectedChatModel;
if (!modelToUse) {
for (const [, instanceWrapper] of Object.entries(this.instances)) {
if (instanceWrapper && typeof instanceWrapper === "object") {
const keys = Object.keys(
instanceWrapper as Record<string, unknown>,
);
if (keys.length === 1) {
const instance = (instanceWrapper as Record<string, unknown>)[
keys[0]
] as { shardAssignments?: { modelId?: string } };
if (instance?.shardAssignments?.modelId) {
modelToUse = instance.shardAssignments.modelId;
break;
}
}
}
}
}
if (!modelToUse) {
throw new Error("No model available");
}
// Start timing
const requestStartTime = performance.now();
let firstTokenTime: number | null = null;
let tokenCount = tokensToKeep.length;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
continue_from_prefix: true,
}),
signal: controller.signal,
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`API error: ${response.status} - ${errorText}`);
}
const reader = response.body?.getReader();
if (!reader) throw new Error("No response body");
const decoder = new TextDecoder();
let fullContent = prefixText;
let buffer = "";
const collectedTokens: TokenData[] = [...tokensToKeep];
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() || "";
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed || trimmed === "data: [DONE]") continue;
if (trimmed.startsWith("data: ")) {
try {
const json = JSON.parse(trimmed.slice(6));
const delta = json.choices?.[0]?.delta?.content;
if (delta) {
if (firstTokenTime === null) {
firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime;
}
tokenCount += 1;
this.totalTokens = tokenCount;
if (
firstTokenTime !== null &&
tokenCount > tokensToKeep.length
) {
const elapsed = performance.now() - firstTokenTime;
this.tps =
((tokenCount - tokensToKeep.length) / elapsed) * 1000;
}
// Extract logprobs
const logprobsData = json.choices?.[0]?.logprobs;
if (logprobsData?.content?.[0]) {
const logprobItem = logprobsData.content[0];
collectedTokens.push({
token: logprobItem.token || delta,
logprob: logprobItem.logprob ?? 0,
probability: Math.exp(logprobItem.logprob ?? 0),
topLogprobs: (logprobItem.top_logprobs || []).map(
(item: {
token: string;
logprob: number;
bytes?: number[];
}) => ({
token: item.token,
logprob: item.logprob,
bytes: item.bytes,
}),
),
});
}
fullContent += delta;
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.currentResponse = displayContent;
this.messages[messageIdx].content = displayContent;
this.messages[messageIdx].thinking =
thinkingContent || undefined;
this.messages[messageIdx].tokens = [...collectedTokens];
this.persistActiveConversation();
}
} catch {
// Skip malformed JSON
}
}
}
}
// Final update
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.messages[messageIdx].content = displayContent;
this.messages[messageIdx].thinking = thinkingContent || undefined;
this.messages[messageIdx].tokens = collectedTokens;
if (this.ttftMs !== null) {
this.messages[messageIdx].ttftMs = this.ttftMs;
}
if (this.tps !== null) {
this.messages[messageIdx].tps = this.tps;
}
this.persistActiveConversation();
} catch (error) {
if (error instanceof Error && error.name === "AbortError") {
return;
}
console.error("Error regenerating from token:", error);
this.messages[messageIdx].content =
`${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
this.persistActiveConversation();
} finally {
if (this.currentRequestController === controller) {
this.currentRequestController = null;
}
this.isLoading = false;
this.currentResponse = "";
this.updateActiveConversation();
@@ -1617,6 +1965,8 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
appStore.regenerateFromToken(messageId, tokenIndex);
// Conversation actions
export const conversations = () => appStore.conversations;

View File

@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -59,7 +59,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
instanceType: InstanceMeta;
minNodes: number;
}
function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
@@ -88,16 +88,16 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
@@ -109,19 +109,11 @@ const sidebarVisible = $derived(chatSidebarVisible());
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Draft model edit modal state
let editingDraftInstanceId = $state<string | null>(null);
let editDraftModel = $state<string | null>(null);
let editNumDraftTokens = $state<number>(4);
let isDraftEditDropdownOpen = $state(false);
let draftEditDropdownSearch = $state('');
let isSavingDraftModel = $state(false);
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');
// Slider dragging state
let isDraggingSlider = $state(false);
let sliderTrackElement: HTMLDivElement | null = $state(null);
@@ -370,36 +362,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
if (!modelId || launchingModelId) return;
launchingModelId = modelId;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
const preview = specificPreview ?? filteredPreview();
let response: Response;
// Use /place_instance endpoint - it handles placement and creation in one step
const placePayload = {
model_id: modelId,
sharding: preview?.sharding ?? selectedSharding,
instance_meta: preview?.instance_meta ?? selectedInstanceType,
min_nodes: selectedMinNodes,
};
response = await fetch('/place_instance', {
let instanceData: unknown;
if (preview?.instance) {
// Use the instance from the preview
instanceData = preview.instance;
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
);
if (!placementResponse.ok) {
const errorText = await placementResponse.text();
console.error('Failed to get placement:', errorText);
return;
}
instanceData = await placementResponse.json();
}
// POST the instance to create it
const response = await fetch('/instance', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(placePayload)
body: JSON.stringify({ instance: instanceData })
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
const scrollToBottom = () => {
@@ -794,52 +797,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
}
// Open draft model edit modal for an instance
function openDraftModelEdit(instanceId: string, currentDraftModel: string | null, currentNumTokens: number | null) {
editingDraftInstanceId = instanceId;
editDraftModel = currentDraftModel;
editNumDraftTokens = currentNumTokens ?? 4;
isDraftEditDropdownOpen = false;
draftEditDropdownSearch = '';
}
// Close draft model edit modal
function closeDraftModelEdit() {
editingDraftInstanceId = null;
editDraftModel = null;
editNumDraftTokens = 4;
isDraftEditDropdownOpen = false;
draftEditDropdownSearch = '';
}
// Save draft model settings for an instance
async function saveDraftModel() {
if (!editingDraftInstanceId || isSavingDraftModel) return;
isSavingDraftModel = true;
try {
const response = await fetch(`/instance/${editingDraftInstanceId}/draft_model`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
draft_model: editDraftModel,
num_draft_tokens: editNumDraftTokens,
})
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to set draft model:', errorText);
} else {
closeDraftModelEdit();
}
} catch (error) {
console.error('Error setting draft model:', error);
} finally {
isSavingDraftModel = false;
}
}
// Helper to unwrap tagged unions like { MlxRingInstance: {...} }
function getTagged(obj: unknown): [string | null, unknown] {
if (!obj || typeof obj !== 'object') return [null, null];
@@ -859,34 +816,30 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
nodeNames: string[];
nodeIds: string[];
nodeCount: number;
draftModel: string | null;
numDraftTokens: number | null;
} {
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
}
// Instance type from tag
let instanceType = 'Unknown';
if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
runnerToShard?: Record<string, unknown>;
};
draftModel?: string;
numDraftTokens?: number;
}
};
// Sharding strategy from first shard
let sharding = 'Unknown';
const runnerToShard = inst.shardAssignments?.runnerToShard || {};
@@ -897,7 +850,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
}
// Node names from topology
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
const nodeIds = Object.keys(nodeToRunner);
@@ -905,12 +858,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
});
// Draft model for speculative decoding
const draftModel = inst.draftModel ?? null;
const numDraftTokens = inst.numDraftTokens ?? null;
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
}
function formatLastUpdate(): string {
@@ -1386,31 +1335,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<div class="flex items-center gap-2">
<!-- Draft Model Button -->
<button
onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
</svg>
</button>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -1745,7 +1679,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
</div>
</div>
</div>
<!-- Selected Model Preview -->
<div class="space-y-3">
{#if models.length === 0}
@@ -1904,31 +1838,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<div class="flex items-center gap-2">
<!-- Draft Model Button -->
<button
onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
</svg>
</button>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -2059,120 +1978,4 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{/if}
</main>
<!-- Draft Model Edit Modal -->
{#if editingDraftInstanceId}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="fixed inset-0 z-50 flex items-center justify-center bg-black/70 backdrop-blur-sm"
onclick={closeDraftModelEdit}
onkeydown={(e) => e.key === 'Escape' && closeDraftModelEdit()}
>
<!-- svelte-ignore a11y_click_events_have_key_events -->
<div
class="bg-exo-dark-gray border border-exo-medium-gray/50 rounded-lg shadow-2xl p-6 w-full max-w-md mx-4"
onclick={(e) => e.stopPropagation()}
>
<div class="flex items-center justify-between mb-4">
<h3 class="text-lg font-mono text-exo-yellow tracking-wide">Speculative Decoding</h3>
<button
onclick={closeDraftModelEdit}
class="text-white/60 hover:text-white transition-colors cursor-pointer"
aria-label="Close"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
</div>
<p class="text-white/60 text-sm font-mono mb-4">
Configure a draft model for faster generation. The draft model proposes tokens that the main model verifies.
</p>
<!-- Draft Model Dropdown -->
<div class="mb-4">
<div class="text-xs text-white/70 font-mono mb-2">Draft Model:</div>
<div class="relative">
<button
onclick={() => { isDraftEditDropdownOpen = !isDraftEditDropdownOpen; draftEditDropdownSearch = ''; }}
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {editDraftModel ? 'bg-transparent text-cyan-400 border-cyan-500/50' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-cyan-500/50'}"
>
<span class="truncate">{editDraftModel ? editDraftModel.split('/').pop() : 'None'}</span>
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftEditDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
</svg>
</button>
{#if isDraftEditDropdownOpen}
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
<div class="p-2 border-b border-exo-medium-gray/30">
<input
type="text"
bind:value={draftEditDropdownSearch}
placeholder="Search models..."
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-cyan-500/50"
/>
</div>
<div class="overflow-y-auto max-h-36">
<!-- None option -->
<button
onclick={() => { editDraftModel = null; isDraftEditDropdownOpen = false; }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {editDraftModel === null ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
>
<span>None (Disable)</span>
</button>
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftEditDropdownSearch.toLowerCase())) as model}
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
{@const modelHfId = model.hugging_face_id ?? model.id}
<button
onclick={() => { editDraftModel = modelHfId; isDraftEditDropdownOpen = false; }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {editDraftModel === modelHfId ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex-shrink-0 text-xs text-white/50">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
{/each}
</div>
</div>
{/if}
</div>
</div>
<!-- Draft Tokens -->
{#if editDraftModel}
<div class="mb-6">
<div class="text-xs text-white/70 font-mono mb-2">Draft Tokens per Iteration:</div>
<div class="flex items-center gap-2">
{#each [2, 3, 4, 5, 6] as n}
<button
onclick={() => editNumDraftTokens = n}
class="w-8 h-8 text-sm font-mono rounded transition-all {editNumDraftTokens === n ? 'bg-cyan-500/20 text-cyan-400 border border-cyan-500/50' : 'text-white/50 hover:text-white/80 border border-exo-medium-gray/50 hover:border-white/30'} cursor-pointer"
>{n}</button>
{/each}
</div>
</div>
{/if}
<!-- Action Buttons -->
<div class="flex items-center justify-end gap-3">
<button
onclick={closeDraftModelEdit}
class="px-4 py-2 text-sm font-mono text-white/70 hover:text-white transition-colors cursor-pointer"
>
Cancel
</button>
<button
onclick={saveDraftModel}
disabled={isSavingDraftModel}
class="px-4 py-2 text-sm font-mono border border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500 transition-all disabled:opacity-50 disabled:cursor-not-allowed cursor-pointer"
>
{isSavingDraftModel ? 'Saving...' : 'Save'}
</button>
</div>
</div>
</div>
{/if}
</div>

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 187 KiB

View File

@@ -0,0 +1 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

View File

@@ -0,0 +1,186 @@
"""OpenAI Chat Completions API adapter for converting requests/responses."""
import time
from collections.abc import AsyncGenerator
from loguru import logger
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionMessageText,
ChatCompletionResponse,
ChatCompletionTaskParams,
ErrorInfo,
ErrorResponse,
FinishReason,
Logprobs,
LogprobsContentItem,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import ResponseInputMessage, ResponsesRequest
def chat_request_to_internal(request: ChatCompletionTaskParams) -> ResponsesRequest:
"""Convert Chat Completions API request to ResponsesRequest (canonical internal format).
Extracts system message as instructions, converts messages to input.
"""
instructions: str | None = None
input_messages: list[ResponseInputMessage] = []
for msg in request.messages:
# Normalize content to string
content: str
if msg.content is None:
content = ""
elif isinstance(msg.content, str):
content = msg.content
elif isinstance(msg.content, ChatCompletionMessageText):
content = msg.content.text
else:
# List of ChatCompletionMessageText
content = "\n".join(item.text for item in msg.content)
# Extract system message as instructions
if msg.role == "system":
if instructions is None:
instructions = content
else:
# Append additional system messages
instructions = f"{instructions}\n{content}"
else:
# Convert to ResponseInputMessage (only user, assistant, developer roles)
if msg.role in ("user", "assistant", "developer"):
input_messages.append(
ResponseInputMessage(role=msg.role, content=content)
)
return ResponsesRequest(
model=request.model,
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
seed=request.seed,
stream=request.stream,
tools=request.tools,
continue_from_prefix=request.continue_from_prefix,
)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
# Build logprobs if available
logprobs: Logprobs | None = None
if chunk.logprob is not None:
logprobs = Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
]
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
)
async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from TokenChunks."""
try:
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
logger.info(f"generate_chat_stream ending (error): {command_id}")
return
chunk_response = chunk_to_response(chunk, command_id)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
logger.info(
f"generate_chat_stream yielding [DONE] for finish_reason={chunk.finish_reason}: {command_id}"
)
yield "data: [DONE]\n\n"
logger.info(f"generate_chat_stream returning: {command_id}")
return
finally:
logger.info(f"generate_chat_stream finally block: {command_id}")
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> ChatCompletionResponse:
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_message = chunk.error_message or "Internal server error"
break
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
),
finish_reason=finish_reason,
)
],
)

View File

@@ -0,0 +1,190 @@
"""Claude Messages API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessagesResponse,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import ResponseInputMessage, ResponsesRequest
def finish_reason_to_claude_stop_reason(
finish_reason: FinishReason | None,
) -> ClaudeStopReason | None:
"""Map OpenAI finish_reason to Claude stop_reason."""
if finish_reason is None:
return None
mapping: dict[FinishReason, ClaudeStopReason] = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
"function_call": "tool_use",
}
return mapping.get(finish_reason, "end_turn")
def claude_request_to_internal(request: ClaudeMessagesRequest) -> ResponsesRequest:
"""Convert Claude Messages API request to ResponsesRequest (canonical internal format).
Converts Claude's system parameter to instructions,
and messages to input.
"""
# Handle system message
instructions: str | None = None
if request.system:
if isinstance(request.system, str):
instructions = request.system
else:
# List of text blocks
instructions = "".join(block.text for block in request.system)
# Convert messages to input
input_messages: list[ResponseInputMessage] = []
for msg in request.messages:
content: str
if isinstance(msg.content, str):
content = msg.content
else:
# Concatenate text blocks (images not supported for MVP)
text_parts: list[str] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
content = "".join(text_parts)
# Claude uses "user" and "assistant" roles
input_messages.append(ResponseInputMessage(role=msg.role, content=content))
return ResponsesRequest(
model=request.model,
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop_sequences,
stream=request.stream,
)
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> ClaudeMessagesResponse:
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
stop_reason: ClaudeStopReason | None = None
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_message = chunk.error_message or "Internal server error"
break
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
return ClaudeMessagesResponse(
id=f"msg_{command_id}",
model=model,
content=[ClaudeTextBlock(text=combined_text)],
stop_reason=stop_reason,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
initial_message = ClaudeMessageStart(
id=f"msg_{command_id}",
model=model,
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
)
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
async for chunk in chunk_stream:
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
# message_stop
message_stop = ClaudeMessageStopEvent()
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"

View File

@@ -0,0 +1,173 @@
"""OpenAI Responses API adapter for converting requests/responses.
ResponsesRequest is the canonical internal format. Responses API is the most featureful,
making it the natural choice for the internal format. All other API formats (Chat
Completions, Claude) are converted TO ResponsesRequest.
"""
from collections.abc import AsyncGenerator
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> ResponsesResponse:
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
accumulated_text = ""
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_message = chunk.error_message or "Internal server error"
break
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
output_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
return ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=[output_item],
output_text=accumulated_text,
usage=usage,
)
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
# response.created
initial_response = ResponsesResponse(
id=response_id,
model=model,
status="in_progress",
output=[],
output_text="",
)
created_event = ResponseCreatedEvent(response=initial_response)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress
in_progress_event = ResponseInProgressEvent(response=initial_response)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(output_index=0, item=initial_item)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
output_index=0, content_index=0, part=initial_part
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = ""
last_stats = None
async for chunk in chunk_stream:
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done
text_done = ResponseTextDoneEvent(
output_index=0, content_index=0, text=accumulated_text
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
output_index=0, content_index=0, part=final_part
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done
final_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
item_done = ResponseOutputItemDoneEvent(output_index=0, item=final_item)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed
final_response = ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=[final_item],
output_text=accumulated_text,
usage=usage,
)
completed_event = ResponseCompletedEvent(response=final_response)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

@@ -15,6 +15,20 @@ from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from exo.master.adapters.chat_completions import (
chat_request_to_internal,
collect_chat_response,
generate_chat_stream,
)
from exo.master.adapters.claude import (
claude_request_to_internal,
collect_claude_response,
generate_claude_stream,
)
from exo.master.adapters.responses import (
collect_responses_response,
generate_responses_stream,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
@@ -27,6 +41,7 @@ from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
ChatCompletionTaskParams,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
@@ -39,11 +54,12 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
SetDraftModelParams,
SetDraftModelResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeMessagesRequest,
ClaudeMessagesResponse,
)
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -51,7 +67,6 @@ from exo.shared.types.commands import (
DeleteInstance,
ForwarderCommand,
PlaceInstance,
SetInstanceDraftModel,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -63,8 +78,11 @@ from exo.shared.types.events import (
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -73,23 +91,6 @@ from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
],
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
@@ -158,18 +159,19 @@ class API:
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
@self.app.exception_handler(HTTPException)
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
_: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
self.app.exception_handler(HTTPException)(self.http_exception_handler)
async def http_exception_handler(
self, _: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
@@ -188,13 +190,14 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.put("/instance/{instance_id}/draft_model")(self.set_draft_model)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/v1/chat/completions", response_model=None)(
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -204,8 +207,6 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)
@@ -402,29 +403,13 @@ class API:
instance_id=instance_id,
)
async def set_draft_model(
self, instance_id: InstanceId, payload: SetDraftModelParams
) -> SetDraftModelResponse:
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
command = SetInstanceDraftModel(
instance_id=instance_id,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)
return SetDraftModelResponse(
message="Command received.",
command_id=command.command_id,
instance_id=instance_id,
)
async def _chat_chunk_stream(
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
"""Yield `TokenChunk`s for a given command until completion.
This is the internal low-level stream used by all API adapters.
"""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
@@ -447,77 +432,6 @@ class API:
await self._send(command)
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
),
finish_reason=finish_reason,
)
],
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
@@ -527,7 +441,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._token_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -571,54 +485,162 @@ class API:
async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
"""OpenAI Chat Completions API - adapter."""
internal_params = chat_request_to_internal(payload)
model_meta = await resolve_model_meta(internal_params.model)
internal_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
await self._trigger_notify_user_to_download_model(internal_params.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
command = ChatCompletion(
request_params=payload,
)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
return await self._collect_chat_completion(command.command_id)
try:
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
# Convert to internal format (BenchChatCompletionTaskParams extends ChatCompletionTaskParams)
internal_params = chat_request_to_internal(payload)
model_meta = await resolve_model_meta(internal_params.model)
internal_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
await self._trigger_notify_user_to_download_model(internal_params.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
payload.stream = False
internal_params.stream = False
command = ChatCompletion(request_params=payload)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def claude_messages(
self, payload: ClaudeMessagesRequest
) -> ClaudeMessagesResponse | StreamingResponse:
"""Claude Messages API - adapter."""
internal_params = claude_request_to_internal(payload)
model_meta = await resolve_model_meta(internal_params.model)
internal_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(internal_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_claude_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
try:
return await collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def openai_responses(
self, payload: ResponsesRequest
) -> ResponsesResponse | StreamingResponse:
"""OpenAI Responses API - native format (no conversion needed)."""
model_meta = await resolve_model_meta(payload.model)
# Update model to resolved model_id
request_params = payload.model_copy(update={"model": model_meta.model_id})
if not any(
instance.shard_assignments.model_id == request_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(request_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {request_params.model}",
)
command = ChatCompletion(request_params=request_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_responses_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
try:
return await collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()

View File

@@ -18,7 +18,6 @@ from exo.shared.types.commands import (
ForwarderCommand,
PlaceInstance,
RequestEventLog,
SetInstanceDraftModel,
TaskFinished,
TestCommand,
)
@@ -28,7 +27,6 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeTimedOut,
TaskCreated,
TaskDeleted,
@@ -175,14 +173,6 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SetInstanceDraftModel():
generated_events.append(
InstanceDraftModelUpdated(
instance_id=command.instance_id,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -3,6 +3,8 @@ from collections.abc import Mapping
from copy import deepcopy
from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
@@ -53,6 +55,7 @@ def place_instance(
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
@@ -125,6 +128,10 @@ def place_instance(
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
# TODO: Single node instances
@@ -144,8 +151,6 @@ def place_instance(
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -159,8 +164,6 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -49,33 +49,83 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
total_layers: int,
memory_fractions: list[float],
) -> list[int]:
n = len(memory_fractions)
if n == 0:
raise ValueError("Cannot allocate layers to an empty node list")
if total_layers < n:
raise ValueError(
f"Cannot distribute {total_layers} layers across {n} nodes "
"(need at least 1 layer per node)"
)
# Largest remainder: floor each, then distribute remainder by fractional part
raw = [f * total_layers for f in memory_fractions]
result = [int(r) for r in raw]
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
result[by_remainder[i]] += 1
# Ensure minimum 1 per node by taking from the largest
for i in range(n):
if result[i] == 0:
max_idx = max(range(n), key=lambda j: result[j])
assert result[max_idx] > 1
result[max_idx] -= 1
result[i] = 1
return result
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
if not selected_cycle:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layers_assigned = 0
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node.node_profile.memory.ram_available.in_bytes / cycle_memory.in_bytes
for node in selected_cycle
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node.node_profile.memory.ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node.node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
runner_id = RunnerId()
shard = PipelineShardMetadata(

View File

@@ -0,0 +1,283 @@
"""Tests for Claude Messages API conversion functions and types."""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.master.adapters.claude import (
claude_request_to_internal,
finish_reason_to_claude_stop_reason,
)
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessage,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
class TestFinishReasonToClaudeStopReason:
"""Tests for finish_reason to Claude stop_reason mapping."""
def test_stop_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
def test_length_maps_to_max_tokens(self):
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
def test_tool_calls_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
def test_function_call_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
def test_content_filter_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
def test_none_returns_none(self):
assert finish_reason_to_claude_stop_reason(None) is None
class TestClaudeRequestToInternal:
"""Tests for converting Claude Messages API requests to ResponsesRequest."""
def test_basic_request_conversion(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_internal(request)
assert params.model == "claude-3-opus"
assert params.max_output_tokens == 100
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
assert params.instructions is None
def test_request_with_system_string(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system="You are a helpful assistant.",
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_internal(request)
assert params.instructions == "You are a helpful assistant."
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
def test_request_with_system_text_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system=[
ClaudeTextBlock(text="You are helpful. "),
ClaudeTextBlock(text="Be concise."),
],
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_internal(request)
assert params.instructions == "You are helpful. Be concise."
assert isinstance(params.input, list)
assert len(params.input) == 1
def test_request_with_content_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(
role="user",
content=[
ClaudeTextBlock(text="First part. "),
ClaudeTextBlock(text="Second part."),
],
),
],
)
params = claude_request_to_internal(request)
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].content == "First part. Second part."
def test_request_with_multi_turn_conversation(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
ClaudeMessage(role="assistant", content="Hi there!"),
ClaudeMessage(role="user", content="How are you?"),
],
)
params = claude_request_to_internal(request)
assert isinstance(params.input, list)
assert len(params.input) == 3
assert params.input[0].role == "user"
assert params.input[1].role == "assistant"
assert params.input[2].role == "user"
def test_request_with_optional_parameters(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[ClaudeMessage(role="user", content="Hello")],
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["STOP", "END"],
stream=True,
)
params = claude_request_to_internal(request)
assert params.temperature == 0.7
assert params.top_p == 0.9
assert params.top_k == 40
assert params.stop == ["STOP", "END"]
assert params.stream is True
class TestClaudeMessagesRequestValidation:
"""Tests for Claude Messages API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"max_tokens": 100,
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_max_tokens(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_messages(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"max_tokens": 100,
}
)
class TestClaudeStreamingEvents:
"""Tests for Claude Messages API streaming event serialization."""
def test_message_start_event_format(self):
message = ClaudeMessageStart(
id="msg_123",
model="claude-3-opus",
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=10, output_tokens=0),
)
event = ClaudeMessageStartEvent(message=message)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_start"
assert parsed["message"]["id"] == "msg_123"
assert parsed["message"]["type"] == "message"
assert parsed["message"]["role"] == "assistant"
assert parsed["message"]["model"] == "claude-3-opus"
def test_content_block_start_event_format(self):
event = ClaudeContentBlockStartEvent(
index=0,
content_block=ClaudeTextBlock(text=""),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_start"
assert parsed["index"] == 0
assert parsed["content_block"]["type"] == "text"
assert parsed["content_block"]["text"] == ""
def test_content_block_delta_event_format(self):
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_delta"
assert parsed["index"] == 0
assert parsed["delta"]["type"] == "text_delta"
assert parsed["delta"]["text"] == "Hello"
def test_content_block_stop_event_format(self):
event = ClaudeContentBlockStopEvent(index=0)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_stop"
assert parsed["index"] == 0
def test_message_delta_event_format(self):
event = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason="end_turn"),
usage=ClaudeMessageDeltaUsage(output_tokens=25),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_delta"
assert parsed["delta"]["stop_reason"] == "end_turn"
assert parsed["usage"]["output_tokens"] == 25
def test_message_stop_event_format(self):
event = ClaudeMessageStopEvent()
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_stop"
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
# Simulate the SSE format used in the streaming generator
sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
assert sse_line.startswith("event: content_block_delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")

View File

@@ -7,7 +7,6 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
CommandId,
@@ -24,6 +23,7 @@ from exo.shared.types.events import (
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
@@ -143,13 +143,9 @@ async def test_master():
command=(
ChatCompletion(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
request_params=ResponsesRequest(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(
role="user", content="Hello, how are you?"
)
],
input="Hello, how are you?",
),
)
),
@@ -200,11 +196,9 @@ async def test_master():
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)
assert events[2].event.task.task_params == ChatCompletionTaskParams(
assert events[2].event.task.task_params == ResponsesRequest(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(role="user", content="Hello, how are you?")
],
input="Hello, how are you?",
)
await master.shutdown()

View File

@@ -0,0 +1,293 @@
"""Tests for OpenAI Responses API types.
ResponsesRequest is the canonical internal type used throughout the pipeline.
No conversion is needed for Responses API requests.
"""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseCreatedEvent,
ResponseInputMessage,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
class TestResponsesRequestAsCanonicalType:
"""Tests for ResponsesRequest as the canonical internal type."""
def test_string_input(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello, how are you?",
)
assert request.model == "gpt-4o"
assert request.input == "Hello, how are you?"
assert request.instructions is None
def test_message_array_input(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="user", content="Hello"),
ResponseInputMessage(role="assistant", content="Hi there!"),
ResponseInputMessage(role="user", content="How are you?"),
],
)
assert isinstance(request.input, list)
assert len(request.input) == 3
assert request.input[0].role == "user"
assert request.input[0].content == "Hello"
assert request.input[1].role == "assistant"
assert request.input[1].content == "Hi there!"
assert request.input[2].role == "user"
assert request.input[2].content == "How are you?"
def test_request_with_instructions(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
instructions="You are a helpful assistant. Be concise.",
)
assert request.input == "Hello"
assert request.instructions == "You are a helpful assistant. Be concise."
def test_request_with_optional_parameters(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
max_output_tokens=500,
temperature=0.8,
top_p=0.95,
stream=True,
)
assert request.max_output_tokens == 500
assert request.temperature == 0.8
assert request.top_p == 0.95
assert request.stream is True
def test_request_with_new_fields(self):
"""Test the additional fields added for internal use."""
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
top_k=40,
seed=42,
stop=["STOP", "END"],
tools=[{"type": "function", "function": {"name": "test"}}],
)
assert request.top_k == 40
assert request.seed == 42
assert request.stop == ["STOP", "END"]
assert request.tools == [{"type": "function", "function": {"name": "test"}}]
def test_request_with_system_role_in_messages(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="system", content="Be helpful"),
ResponseInputMessage(role="user", content="Hello"),
],
)
assert isinstance(request.input, list)
assert len(request.input) == 2
assert request.input[0].role == "system"
assert request.input[1].role == "user"
def test_request_with_developer_role(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="developer", content="Internal note"),
ResponseInputMessage(role="user", content="Hello"),
],
)
assert isinstance(request.input, list)
assert len(request.input) == 2
assert request.input[0].role == "developer"
class TestResponsesRequestValidation:
"""Tests for OpenAI Responses API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"input": "Hello",
}
)
def test_request_requires_input(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"model": "gpt-4o",
}
)
def test_request_accepts_string_input(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
)
assert request.input == "Hello"
def test_request_accepts_message_array_input(self):
request = ResponsesRequest(
model="gpt-4o",
input=[ResponseInputMessage(role="user", content="Hello")],
)
assert len(request.input) == 1
class TestResponsesStreamingEvents:
"""Tests for OpenAI Responses API streaming event serialization."""
def test_response_created_event_format(self):
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="in_progress",
output=[],
output_text="",
)
event = ResponseCreatedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.created"
assert parsed["response"]["id"] == "resp_123"
assert parsed["response"]["object"] == "response"
assert parsed["response"]["status"] == "in_progress"
def test_output_item_added_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="")],
status="in_progress",
)
event = ResponseOutputItemAddedEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.added"
assert parsed["output_index"] == 0
assert parsed["item"]["type"] == "message"
assert parsed["item"]["id"] == "item_123"
assert parsed["item"]["role"] == "assistant"
def test_content_part_added_event_format(self):
part = ResponseOutputText(text="")
event = ResponseContentPartAddedEvent(
output_index=0,
content_index=0,
part=part,
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.content_part.added"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["part"]["type"] == "output_text"
def test_text_delta_event_format(self):
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.delta"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["delta"] == "Hello"
def test_text_done_event_format(self):
event = ResponseTextDoneEvent(
output_index=0,
content_index=0,
text="Hello, world!",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.done"
assert parsed["text"] == "Hello, world!"
def test_output_item_done_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello, world!")],
status="completed",
)
event = ResponseOutputItemDoneEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.done"
assert parsed["item"]["status"] == "completed"
assert parsed["item"]["content"][0]["text"] == "Hello, world!"
def test_response_completed_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello!")],
status="completed",
)
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="completed",
output=[item],
output_text="Hello!",
usage=ResponseUsage(input_tokens=10, output_tokens=5, total_tokens=15),
)
event = ResponseCompletedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.completed"
assert parsed["response"]["status"] == "completed"
assert parsed["response"]["output_text"] == "Hello!"
assert parsed["response"]["usage"]["total_tokens"] == 15
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
# Simulate the SSE format used in the streaming generator
sse_line = (
f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
)
assert sse_line.startswith("event: response.output_text.delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")

View File

@@ -70,7 +70,7 @@ def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
[
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
((312, 468, 1092), 12, (2, 3, 7)),
],
)
def test_get_instance_placements_create_instance(

View File

@@ -3,6 +3,7 @@ from typing import Callable
import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
@@ -165,6 +166,9 @@ def test_get_smallest_cycles(
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
# Edge case: one node has ~90% of memory - should not over-allocate.
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
((900, 50, 50), 20, (18, 1, 1)),
],
)
def test_get_shard_assignments(
@@ -397,3 +401,96 @@ def test_get_mlx_jaccl_coordinators(
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
class TestAllocateLayersProportionally:
def test_empty_node_list_raises(self):
with pytest.raises(ValueError, match="empty node list"):
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
def test_zero_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
def test_negative_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
def test_fewer_layers_than_nodes_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
)
def test_equal_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
)
assert result == [3, 3, 3, 3]
assert sum(result) == 12
def test_proportional_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
)
assert result == [3, 3, 6]
assert sum(result) == 12
def test_extreme_imbalance_ensures_minimum(self):
result = allocate_layers_proportionally(
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
)
assert all(layers >= 1 for layers in result)
assert sum(result) == 20
# Small nodes get minimum 1 layer
assert result == [18, 1, 1]
def test_single_node_gets_all_layers(self):
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
assert result == [10]
def test_minimum_viable_allocation(self):
result = allocate_layers_proportionally(
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
)
assert result == [1, 1, 1]
assert sum(result) == 3
def test_get_shard_assignments_insufficient_memory_raises(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a = create_node(900 * 1024, node_a_id)
node_b = create_node(50 * 1024, node_b_id)
node_c = create_node(10 * 1024, node_c_id) # Insufficient memory
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(model_meta, selected_cycle, Sharding.Pipeline)

View File

@@ -11,7 +11,6 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeCreated,
NodeDownloadProgress,
NodeMemoryMeasured,
@@ -48,8 +47,6 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceDraftModelUpdated():
return apply_instance_draft_model_updated(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
@@ -172,25 +169,6 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
def apply_instance_draft_model_updated(
event: InstanceDraftModelUpdated, state: State
) -> State:
if event.instance_id not in state.instances:
return state
instance = state.instances[event.instance_id]
updated_instance = instance.model_copy(
update={
"draft_model": event.draft_model,
"num_draft_tokens": event.num_draft_tokens,
}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
event.instance_id: updated_instance,
}
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
**state.runners,

View File

@@ -157,12 +157,13 @@ class ChatCompletionTaskParams(BaseModel):
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
num_draft_tokens: int = 3
# When True, continue the last assistant message without EOS tokens
continue_from_prefix: bool = False
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -174,8 +175,6 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod
@@ -217,14 +216,3 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class SetDraftModelParams(BaseModel):
draft_model: ModelId | None = None # None to disable speculative decoding
num_draft_tokens: int = 4
class SetDraftModelResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId

View File

@@ -1,6 +1,6 @@
from enum import Enum
from exo.shared.types.api import GenerationStats
from exo.shared.types.api import GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -20,6 +20,8 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None

View File

@@ -0,0 +1,168 @@
"""Claude Messages API types for request/response conversion."""
from typing import Literal
from pydantic import BaseModel, Field
# Type aliases
ClaudeRole = Literal["user", "assistant"]
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
# Content block types
class ClaudeTextBlock(BaseModel, frozen=True):
"""Text content block in Claude Messages API."""
type: Literal["text"] = "text"
text: str
class ClaudeImageSource(BaseModel, frozen=True):
"""Image source for Claude image blocks."""
type: Literal["base64", "url"]
media_type: str | None = None
data: str | None = None
url: str | None = None
class ClaudeImageBlock(BaseModel, frozen=True):
"""Image content block in Claude Messages API."""
type: Literal["image"] = "image"
source: ClaudeImageSource
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock
# Request types
class ClaudeMessage(BaseModel, frozen=True):
"""Message in Claude Messages API request."""
role: ClaudeRole
content: str | list[ClaudeContentBlock]
class ClaudeMessagesRequest(BaseModel):
"""Request body for Claude Messages API."""
model: str
max_tokens: int
messages: list[ClaudeMessage]
system: str | list[ClaudeTextBlock] | None = None
stop_sequences: list[str] | None = None
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
metadata: dict[str, str] | None = None
# Response types
class ClaudeUsage(BaseModel, frozen=True):
"""Token usage in Claude Messages API response."""
input_tokens: int
output_tokens: int
class ClaudeMessagesResponse(BaseModel, frozen=True):
"""Response body for Claude Messages API."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock]
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
# Streaming event types
class ClaudeMessageStart(BaseModel, frozen=True):
"""Partial message in message_start event."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock] = Field(default_factory=list)
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
class ClaudeMessageStartEvent(BaseModel, frozen=True):
"""Event sent at start of message stream."""
type: Literal["message_start"] = "message_start"
message: ClaudeMessageStart
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
"""Event sent at start of a content block."""
type: Literal["content_block_start"] = "content_block_start"
index: int
content_block: ClaudeTextBlock
class ClaudeTextDelta(BaseModel, frozen=True):
"""Delta for text content block."""
type: Literal["text_delta"] = "text_delta"
text: str
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
"""Event sent for content block delta."""
type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: ClaudeTextDelta
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
"""Event sent at end of a content block."""
type: Literal["content_block_stop"] = "content_block_stop"
index: int
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
"""Usage in message_delta event."""
output_tokens: int
class ClaudeMessageDelta(BaseModel, frozen=True):
"""Delta in message_delta event."""
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
"""Event sent with final message delta."""
type: Literal["message_delta"] = "message_delta"
delta: ClaudeMessageDelta
usage: ClaudeMessageDeltaUsage
class ClaudeMessageStopEvent(BaseModel, frozen=True):
"""Event sent at end of message stream."""
type: Literal["message_stop"] = "message_stop"
ClaudeStreamEvent = (
ClaudeMessageStartEvent
| ClaudeContentBlockStartEvent
| ClaudeContentBlockDeltaEvent
| ClaudeContentBlockStopEvent
| ClaudeMessageDeltaEvent
| ClaudeMessageStopEvent
)

View File

@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelMetadata
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -17,7 +17,7 @@ class TestCommand(BaseCommand):
class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
request_params: ResponsesRequest
class PlaceInstance(BaseCommand):
@@ -25,8 +25,6 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
class CreateInstance(BaseCommand):
@@ -37,14 +35,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class SetInstanceDraftModel(BaseCommand):
"""Set or update the draft model for an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None # None to disable speculative decoding
num_draft_tokens: int = 4
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -60,7 +50,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| SetInstanceDraftModel
| TaskFinished
)

View File

@@ -5,7 +5,6 @@ from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
@@ -68,14 +67,6 @@ class InstanceDeleted(BaseEvent):
instance_id: InstanceId
class InstanceDraftModelUpdated(BaseEvent):
"""Draft model updated on an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None
num_draft_tokens: int
class RunnerStatusUpdated(BaseEvent):
runner_id: RunnerId
runner_status: RunnerStatus
@@ -132,7 +123,6 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceDraftModelUpdated
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated

View File

@@ -0,0 +1,190 @@
"""OpenAI Responses API types for request/response conversion.
ResponsesRequest serves as both:
1. The external API request type for /v1/responses
2. The canonical internal type used throughout the inference pipeline
All external API formats (Chat Completions, Claude) are converted to
ResponsesRequest at the API boundary.
"""
import time
from typing import Any, Literal
from pydantic import BaseModel, Field
# Type aliases
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
ResponseRole = Literal["user", "assistant", "system", "developer"]
# Request types
class ResponseInputMessage(BaseModel, frozen=True):
"""Input message for Responses API.
This is also used as the internal message format throughout the pipeline.
"""
role: ResponseRole
content: str
class ResponsesRequest(BaseModel):
"""Request body for OpenAI Responses API.
This is also the canonical internal task params format used throughout
the inference pipeline. All external API formats are converted to this
format at the API boundary.
Field mapping from other APIs:
- input: Replaces 'messages' from Chat Completions
- instructions: System message, extracted from messages or Claude's 'system'
- max_output_tokens: Replaces 'max_tokens' from Chat Completions
"""
model: str
input: str | list[ResponseInputMessage]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None
stream: bool = False
# Tools support
tools: list[dict[str, Any]] | None = None
# previous_response_id not supported in MVP
metadata: dict[str, str] | None = None
# When True, continue the last assistant message without EOS tokens
continue_from_prefix: bool = False
# Response types
class ResponseOutputText(BaseModel, frozen=True):
"""Text content in response output."""
type: Literal["output_text"] = "output_text"
text: str
annotations: list[dict[str, str]] = Field(default_factory=list)
class ResponseMessageItem(BaseModel, frozen=True):
"""Message item in response output array."""
type: Literal["message"] = "message"
id: str
role: Literal["assistant"] = "assistant"
content: list[ResponseOutputText]
status: ResponseStatus = "completed"
ResponseItem = ResponseMessageItem # Can expand for function_call, reasoning, etc.
class ResponseUsage(BaseModel, frozen=True):
"""Token usage in Responses API response."""
input_tokens: int
output_tokens: int
total_tokens: int
class ResponsesResponse(BaseModel, frozen=True):
"""Response body for OpenAI Responses API."""
id: str
object: Literal["response"] = "response"
created_at: int = Field(default_factory=lambda: int(time.time()))
status: ResponseStatus = "completed"
model: str
output: list[ResponseItem]
output_text: str
usage: ResponseUsage | None = None
# Streaming event types
class ResponseCreatedEvent(BaseModel, frozen=True):
"""Event sent when response is created."""
type: Literal["response.created"] = "response.created"
response: ResponsesResponse
class ResponseInProgressEvent(BaseModel, frozen=True):
"""Event sent when response starts processing."""
type: Literal["response.in_progress"] = "response.in_progress"
response: ResponsesResponse
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
"""Event sent when an output item is added."""
type: Literal["response.output_item.added"] = "response.output_item.added"
output_index: int
item: ResponseItem
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
"""Event sent when a content part is added."""
type: Literal["response.content_part.added"] = "response.content_part.added"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseTextDeltaEvent(BaseModel, frozen=True):
"""Event sent for text delta during streaming."""
type: Literal["response.output_text.delta"] = "response.output_text.delta"
output_index: int
content_index: int
delta: str
class ResponseTextDoneEvent(BaseModel, frozen=True):
"""Event sent when text content is done."""
type: Literal["response.output_text.done"] = "response.output_text.done"
output_index: int
content_index: int
text: str
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
"""Event sent when a content part is done."""
type: Literal["response.content_part.done"] = "response.content_part.done"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
"""Event sent when an output item is done."""
type: Literal["response.output_item.done"] = "response.output_item.done"
output_index: int
item: ResponseItem
class ResponseCompletedEvent(BaseModel, frozen=True):
"""Event sent when response is completed."""
type: Literal["response.completed"] = "response.completed"
response: ResponsesResponse
ResponsesStreamEvent = (
ResponseCreatedEvent
| ResponseInProgressEvent
| ResponseOutputItemAddedEvent
| ResponseContentPartAddedEvent
| ResponseTextDeltaEvent
| ResponseTextDoneEvent
| ResponseContentPartDoneEvent
| ResponseOutputItemDoneEvent
| ResponseCompletedEvent
)

View File

@@ -2,8 +2,8 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, Id
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
@@ -36,12 +36,6 @@ class DownloadModel(BaseTask): # emitted by Worker
shard_metadata: ShardMetadata
class DownloadDraftModel(BaseTask): # emitted by Worker
"""Download a draft model for speculative decoding (rank 0 only)."""
model_id: str # HuggingFace model ID
class LoadModel(BaseTask): # emitted by Worker
pass
@@ -56,7 +50,7 @@ class StartWarmup(BaseTask): # emitted by Worker
class ChatCompletion(BaseTask): # emitted by Master
command_id: CommandId
task_params: ChatCompletionTaskParams
task_params: ResponsesRequest
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
@@ -66,21 +60,12 @@ class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
class SetDraftModel(BaseTask): # emitted by Worker
"""Load or clear a draft model on an already-running instance."""
model_id: str | None # HuggingFace model ID, or None to clear
num_draft_tokens: int = 4
Task = (
CreateRunner
| DownloadModel
| DownloadDraftModel
| ConnectToGroup
| LoadModel
| StartWarmup
| ChatCompletion
| Shutdown
| SetDraftModel
)

View File

@@ -3,7 +3,6 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -20,8 +19,6 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -1,4 +1,4 @@
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.api import FinishReason, GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
@@ -13,7 +13,8 @@ class TokenizedResponse(BaseRunnerResponse):
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None

View File

@@ -40,4 +40,6 @@ class TokenizerWrapper:
messages_dicts: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
) -> str: ...

View File

@@ -8,13 +8,12 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
FinishReason,
GenerationStats,
TopLogprobItem,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
@@ -48,50 +47,38 @@ def maybe_quantize_kv_cache(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
warmup_prompt = apply_chat_template(
tokenizer=tokenizer,
chat_task_data=ChatCompletionTaskParams(
task_params=ResponsesRequest(
model="",
messages=[
ChatCompletionMessage(
role="user",
content=content,
)
],
input=content,
),
)
tokens_generated = 0
cache = make_kv_cache(
model=model,
)
# Use a default sampler for warmup
sampler = make_sampler(temp=0.7)
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": warmup_prompt,
"max_tokens": 50,
"sampler": sampler,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Warm up with draft model if provided (speculative decoding path)
if draft_model is not None:
logger.info("Warming up with speculative decoding (draft model)")
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
logger.info("Generating warmup tokens")
for _r in stream_generate(**generate_kwargs): # type: ignore[arg-type]
for _r in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=warmup_prompt,
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info("Generated warmup token: " + str(_r.text))
tokens_generated += 1
@@ -122,16 +109,68 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def extract_top_logprobs(
logprobs: mx.array,
tokenizer: TokenizerWrapper,
top_k: int,
selected_token: int,
) -> tuple[float, list[TopLogprobItem]]:
"""Extract the selected token's logprob and top-k alternative tokens.
Args:
logprobs: Full vocabulary logprobs array from MLX
tokenizer: Tokenizer for decoding token IDs to strings
top_k: Number of top alternatives to return
selected_token: The token ID that was actually sampled
Returns:
Tuple of (selected_token_logprob, list of TopLogprobItem for top-k tokens)
"""
# Get the logprob of the selected token
selected_logprob = float(logprobs[selected_token].item())
# Get top-k indices (most probable tokens)
# mx.argpartition gives indices that would partition the array
# We negate logprobs since argpartition finds smallest, and we want largest
top_k = min(top_k, logprobs.shape[0]) # Don't exceed vocab size
top_indices = mx.argpartition(-logprobs, top_k)[:top_k]
# Get the actual logprob values for these indices
top_values = logprobs[top_indices]
# Sort by logprob (descending) for consistent ordering
sort_order = mx.argsort(-top_values)
top_indices = top_indices[sort_order]
top_values = top_values[sort_order]
# Convert to list of TopLogprobItem
top_logprob_items: list[TopLogprobItem] = []
for i in range(top_k):
token_id = int(top_indices[i].item())
token_logprob = float(top_values[i].item())
# Decode token ID to string
token_str = tokenizer.decode([token_id])
# Get byte representation
token_bytes = list(token_str.encode("utf-8"))
top_logprob_items.append(
TopLogprobItem(
token=token_str,
logprob=token_logprob,
bytes=token_bytes,
)
)
return selected_logprob, top_logprob_items
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
task: ResponsesRequest,
is_bench: bool = False,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
# Currently we support chat-completion tasks only.
logger.info(f"task_params: {task}")
@@ -141,9 +180,11 @@ def mlx_generate(
prompt = apply_chat_template(
tokenizer=tokenizer,
chat_task_data=task,
task_params=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -153,38 +194,57 @@ def mlx_generate(
sampler = make_sampler(
temp=task.temperature if task.temperature is not None else 0.7,
top_p=task.top_p if task.top_p is not None else 1.0,
top_k=task.top_k if task.top_k is not None else 0,
)
max_tokens = task.max_tokens or MAX_TOKENS
# Normalize stop sequences to a list
stop_sequences: list[str] = (
([task.stop] if isinstance(task.stop, str) else task.stop)
if task.stop is not None
else []
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
# Build kwargs for stream_generate, conditionally adding draft model params
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": prompt,
"max_tokens": max_tokens,
"sampler": sampler,
"logits_processors": logits_processors,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
max_tokens = task.max_output_tokens or MAX_TOKENS
accumulated_text = ""
# Add speculative decoding parameters if draft model is provided
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
# as speculative decoding requires cache trimming capabilities
if draft_model is not None:
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
# Only use custom cache for non-speculative generation
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info(out.text)
accumulated_text += out.text
# Check for stop sequences
text = out.text
finish_reason: FinishReason | None = cast(
FinishReason | None, out.finish_reason
)
stop_matched = False
if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in accumulated_text:
# Trim text to just before the stop sequence
stop_index = accumulated_text.find(stop_seq)
text_before_stop = accumulated_text[:stop_index]
chunk_start = len(accumulated_text) - len(out.text)
text = text_before_stop[chunk_start:]
finish_reason = "stop"
stop_matched = True
break
is_done = finish_reason is not None
stats: GenerationStats | None = None
if out.finish_reason is not None:
if is_done:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
@@ -192,22 +252,33 @@ def mlx_generate(
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
if not stop_matched and out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
# Extract logprobs from the full vocabulary logprobs array
logprob, top_logprobs = extract_top_logprobs(
logprobs=out.logprobs,
tokenizer=tokenizer,
top_k=5,
selected_token=out.token,
)
yield GenerationResponse(
text=out.text,
text=text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
)
if out.finish_reason is not None:
if is_done:
break
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
# TODO: Do we want an mx_barrier?

View File

@@ -42,10 +42,9 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.instances import (
BoundInstance,
MlxJacclInstance,
@@ -258,27 +257,6 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: str) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
@@ -411,35 +389,53 @@ def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWra
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
task_params: ResponsesRequest,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
"""Convert ResponsesRequest to a chat template prompt.
Converts the internal format (input + instructions) to a messages list
that can be processed by the tokenizer's chat template.
"""
formatted_messages: list[dict[str, Any]] = []
for message in messages:
if isinstance(message.content, ChatCompletionMessageText):
message.content = message.content.text
if isinstance(message.content, list):
if len(message.content) == 0:
logger.warning("Received prompt with no content, skipping")
continue
message.content = "\n".join(c.text for c in message.content).strip()
if message.content is None and message.thinking is None:
continue
# Null values are not valid when applying templates in tokenizer
# Add system message (instructions) if present
if task_params.instructions:
formatted_messages.append(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
{"role": "system", "content": task_params.instructions}
)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
)
# Convert input to messages
if isinstance(task_params.input, str):
# Simple string input becomes a single user message
formatted_messages.append({"role": "user", "content": task_params.input})
else:
# List of InputMessage
for msg in task_params.input:
if not msg.content:
logger.warning("Received message with empty content, skipping")
continue
formatted_messages.append({"role": msg.role, "content": msg.content})
# Use continue_final_message when continuing from prefix (e.g., regenerate from token)
# This keeps the final assistant message open without EOS tokens
# Note: explicitly set add_generation_prompt=False when using continue_final_message
# because some tokenizers (e.g., Kimi) default add_generation_prompt=True
prompt: str
if task_params.continue_from_prefix:
prompt = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
continue_final_message=True,
add_generation_prompt=False,
tools=task_params.tools,
)
else:
prompt = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=task_params.tools,
)
logger.info(prompt)

View File

@@ -29,9 +29,7 @@ from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformance
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadDraftModel,
DownloadModel,
SetDraftModel,
Shutdown,
Task,
TaskStatus,
@@ -50,7 +48,6 @@ from exo.utils.event_buffer import OrderedBuffer
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.impl_shard_downloader import build_full_shard
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -205,10 +202,42 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
await self._handle_download(shard, task)
case DownloadDraftModel(model_id=model_id):
shard = await build_full_shard(model_id)
await self._handle_download(shard, task)
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -219,25 +248,6 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case SetDraftModel(
model_id=draft_model_id, num_draft_tokens=num_tokens
):
runner = self.runners[self._task_to_runner_id(task)]
await runner.start_task(task)
# Update bound_instance to reflect new/cleared draft model
updated_instance = runner.bound_instance.instance.model_copy(
update={
"draft_model": (
ModelId(draft_model_id)
if draft_model_id is not None
else None
),
"num_draft_tokens": num_tokens,
}
)
runner.bound_instance = runner.bound_instance.model_copy(
update={"instance": updated_instance}
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
@@ -330,46 +340,6 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
async def _handle_download(self, shard: ShardMetadata, task: Task) -> None:
"""Handle model download - shared logic for main and draft models."""
model_id = shard.model_meta.model_id
if model_id not in self.download_status:
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
else:
await self.event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
download_task = DownloadModel(
instance_id=task.instance_id,
shard_metadata=shard,
task_id=task.task_id,
task_status=task.task_status,
)
self._handle_shard_download_process(download_task, initial_progress)
def _handle_shard_download_process(
self,
task: DownloadModel,

View File

@@ -8,10 +8,8 @@ from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadDraftModel,
DownloadModel,
LoadModel,
SetDraftModel,
Shutdown,
StartWarmup,
Task,
@@ -40,16 +38,6 @@ from exo.shared.types.worker.runners import (
from exo.worker.runner.runner_supervisor import RunnerSupervisor
def _is_download_in_progress_or_complete(
model_id: ModelId,
download_status: Mapping[ModelId, DownloadProgress],
) -> bool:
"""Check if model download is in progress or complete."""
return model_id in download_status and isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
def plan(
node_id: NodeId,
# Runners is expected to be FRESH and so should not come from state
@@ -67,11 +55,9 @@ def plan(
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _draft_model_needs_download(runners, download_status, instances)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status, download_status)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _set_draft_model(runners, instances, download_status)
or _pending_tasks(runners, tasks, all_runners)
)
@@ -129,9 +115,12 @@ def _model_needs_download(
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(
runner.status, RunnerIdle
) and not _is_download_in_progress_or_complete(model_id, download_status):
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
instance_id=runner.bound_instance.instance.instance_id,
@@ -139,43 +128,6 @@ def _model_needs_download(
)
def _draft_model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
instances: Mapping[InstanceId, Instance],
) -> DownloadDraftModel | None:
"""Check if draft model needs download for rank 0 runner.
Triggers download when:
- RunnerIdle with draft model (initial setup)
- RunnerReady with new draft model (updated via API)
"""
rank_0_runner = next(
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
None,
)
if rank_0_runner is None:
return None
if not isinstance(rank_0_runner.status, (RunnerIdle, RunnerReady)):
return None
# Use current instance state (may have been updated via API)
instance_id = rank_0_runner.bound_instance.instance.instance_id
current_instance = instances.get(instance_id)
if current_instance is None:
return None
draft_model_id = current_instance.draft_model
if draft_model_id is None:
return None
if _is_download_in_progress_or_complete(draft_model_id, download_status):
return None
return DownloadDraftModel(
instance_id=instance_id,
model_id=str(draft_model_id),
)
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -230,12 +182,10 @@ def _load_model(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
download_status: Mapping[ModelId, DownloadProgress],
) -> LoadModel | None:
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
shard = runner.bound_instance.bound_shard
all_local_downloads_complete = all(
nid in global_download_status
@@ -249,14 +199,6 @@ def _load_model(
if not all_local_downloads_complete:
continue
# Rank 0 with draft model must wait for draft download before loading
if shard.device_rank == 0:
draft_model_id = instance.draft_model
if draft_model_id is not None and not isinstance(
download_status.get(draft_model_id), DownloadCompleted
):
continue
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
@@ -316,53 +258,6 @@ def _ready_to_warmup(
return None
def _set_draft_model(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
download_status: Mapping[ModelId, DownloadProgress],
) -> SetDraftModel | None:
"""Check if rank 0 runner needs to load or clear a draft model."""
rank_0_runner = next(
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
None,
)
if rank_0_runner is None:
return None
if not isinstance(rank_0_runner.status, RunnerReady):
return None
instance_id = rank_0_runner.bound_instance.instance.instance_id
current_instance = instances.get(instance_id)
if current_instance is None:
return None
# Compare runner's bound draft model vs current instance draft model
runner_draft_model = rank_0_runner.bound_instance.instance.draft_model
current_draft_model = current_instance.draft_model
if runner_draft_model == current_draft_model:
return None
# Draft model changed - need to update
if current_draft_model is None:
# Clear draft model
return SetDraftModel(
instance_id=instance_id,
model_id=None,
num_draft_tokens=4,
)
# Wait for draft model to be downloaded
if not isinstance(download_status.get(current_draft_model), DownloadCompleted):
return None
return SetDraftModel(
instance_id=instance_id,
model_id=str(current_draft_model),
num_draft_tokens=current_instance.num_draft_tokens,
)
def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],

View File

@@ -1,8 +1,6 @@
import time
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -13,9 +11,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
load_harmony_encoding,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -23,12 +19,11 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelId
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
LoadModel,
SetDraftModel,
Shutdown,
StartWarmup,
Task,
@@ -53,44 +48,15 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_draft_model,
load_mlx_items,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger
@contextmanager
def send_error_chunk_on_exception(
event_sender: MpSender[Event],
command_id: CommandId,
model_id: ModelId,
device_rank: int,
):
try:
yield
except Exception as e:
logger.error(e)
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -101,6 +67,7 @@ def main(
bound_instance.bound_runner_id,
bound_instance.bound_shard,
)
device_rank = shard_metadata.device_rank
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
@@ -112,7 +79,6 @@ def main(
model = None
tokenizer = None
group = None
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -168,16 +134,6 @@ def main(
bound_instance, group, on_timeout=on_model_load_timeout
)
# Load draft model for speculative decoding (rank 0 only)
if (
instance.draft_model is not None
and shard_metadata.device_rank == 0
):
logger.info(f"Loading draft model: {instance.draft_model}")
draft_model = cast(
Model, load_draft_model(str(instance.draft_model))
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -193,10 +149,9 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=cast(Model, model),
model=model,
tokenizer=tokenizer,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
@@ -215,24 +170,17 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
with send_error_chunk_on_exception(
event_sender,
command_id,
shard_metadata.model_meta.model_id,
shard_metadata.device_rank,
):
assert model
assert tokenizer
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
assert model
assert tokenizer
# Generate responses (draft_model loaded at warmup if configured)
try:
_check_for_debug_prompts(task_params)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=cast(Model, model),
model=model,
tokenizer=tokenizer,
task=task_params,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
)
# GPT-OSS specific parsing to match other model formats.
@@ -244,7 +192,7 @@ def main(
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -253,58 +201,34 @@ def main(
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
logprob=response.logprob,
top_logprobs=response.top_logprobs,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
# can we make this more explicit?
except Exception as e:
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_meta.model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
raise
current_status = RunnerReady()
logger.info("runner ready")
case SetDraftModel(
model_id=draft_model_id, num_draft_tokens=num_tokens
) if isinstance(current_status, RunnerReady):
current_status = RunnerWarmingUp()
logger.info("runner warming up (setting draft model)")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert model is not None
assert tokenizer is not None
if draft_model_id is None:
# Clear draft model
logger.info("Clearing draft model")
draft_model = None
instance = instance.model_copy(
update={
"draft_model": None,
"num_draft_tokens": 4,
}
)
else:
# Load new draft model
logger.info(f"Loading draft model: {draft_model_id}")
draft_model = cast(Model, load_draft_model(draft_model_id))
instance = instance.model_copy(
update={
"draft_model": ModelId(draft_model_id),
"num_draft_tokens": num_tokens,
}
)
# Warm up with speculative decoding
logger.info("Warming up with new draft model")
warmup_inference(
model=cast(Model, model),
tokenizer=tokenizer,
draft_model=draft_model,
num_draft_tokens=num_tokens,
)
logger.info("Draft model loaded and warmed up")
current_status = RunnerReady()
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
@@ -325,7 +249,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group, draft_model
del model, tokenizer, group
mx.clear_cache()
import gc
@@ -375,17 +299,23 @@ EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
def _check_for_debug_prompts(
prompt: str | ChatCompletionMessageText | list[ChatCompletionMessageText],
):
if isinstance(prompt, list):
if len(prompt) == 0:
logger.debug("Empty message prompt received in debug prompt")
return
prompt = prompt[0]
def _check_for_debug_prompts(task_params: ResponsesRequest) -> None:
"""Check for debug prompt triggers in the input.
if isinstance(prompt, ChatCompletionMessageText):
prompt = prompt.text
Extracts the first user input text and checks for debug triggers.
"""
prompt: str
if isinstance(task_params.input, str):
prompt = task_params.input
else:
# List of InputMessage - get first message content
if len(task_params.input) == 0:
logger.debug("Empty message list in debug prompt check")
return
prompt = task_params.input[0].content
if not prompt:
return
if EXO_RUNNER_MUST_FAIL in prompt:
logger.info("raising exception")

View File

@@ -1,7 +1,7 @@
from typing import cast
import exo.worker.plan as plan_mod
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
@@ -59,7 +59,7 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -107,7 +107,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -152,7 +152,7 @@ def test_plan_does_not_forward_tasks_for_other_instances():
instance_id=other_instance_id,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -201,7 +201,7 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Complete,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
other_task_id = TaskId("other-task")

View File

@@ -1,50 +0,0 @@
# pyright: reportAny=false
from unittest.mock import MagicMock
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID
def test_send_error_chunk_on_exception_no_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
_ = 1 + 1
event_sender.send.assert_not_called()
def test_send_error_chunk_on_exception_catches_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
raise ValueError("test error")
event_sender.send.assert_called_once()
call_args = event_sender.send.call_args[0][0]
assert isinstance(call_args, ChunkGenerated)
assert call_args.command_id == command_id
assert isinstance(call_args.chunk, TokenChunk)
assert call_args.chunk.finish_reason == "error"
assert call_args.chunk.error_message == "test error"
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=1
):
raise ValueError("test error")
event_sender.send.assert_not_called()

View File

@@ -5,7 +5,6 @@ from typing import Callable
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
@@ -14,9 +13,9 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
@@ -85,11 +84,11 @@ SHUTDOWN_TASK = Shutdown(
runner_id=RUNNER_1_ID,
)
CHAT_PARAMS = ChatCompletionTaskParams(
CHAT_PARAMS = ResponsesRequest(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
input="hello",
stream=True,
max_tokens=4,
max_output_tokens=4,
temperature=0.0,
)

View File

@@ -13,10 +13,10 @@ from pydantic import BaseModel
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -169,16 +169,10 @@ async def execute_test(test: Tests, instance: Instance, hn: str):
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
task_params=ResponsesRequest(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
instructions="You are a helpful assistant",
input="What is the capital of France?",
),
command_id=CommandId("yo"),
instance_id=iid,