Compare commits

...

21 Commits

Author SHA1 Message Date
Alex Cheema
7ba2408eed fix: restore dashboard build by using main's app.svelte.ts
The prefill-progress branch's app.svelte.ts was missing image generation
features from main. To fix the dashboard build, this commit restores main's
app.svelte.ts and removes the uncertainty visualization and prefill
progress bar features from ChatMessages.svelte that depended on the
missing exports.

Note: TokenHeatmap and PrefillProgressBar components still exist but
are not currently used. The prefill progress backend code is still in
place and can be re-enabled in the dashboard once app.svelte.ts is
properly updated to include both image generation and prefill/uncertainty
features.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 12:07:16 +00:00
Alex Cheema
ce4d7f4d43 style: simplify prefill progress bar and use exo color palette
- Remove spinner (progress bar is dynamic enough)
- Use exo-yellow for progress bar fill
- Use exo-black/60 for progress bar background
- Use exo-light-gray for text

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 12:00:55 +00:00
Alex Cheema
6727523eab fix: wire prefill progress events to chat completions stream
- Move PrefillProgressData to shared types (chunks.py) to avoid circular imports
- Update generate_chat_stream adapter to handle both TokenChunk and PrefillProgressData
- Use _stream_events instead of _chat_chunk_stream for streaming endpoint
- Prefill progress now properly sent as SSE 'event: prefill_progress' to frontend

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 12:00:55 +00:00
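For context on the shared types this commit moves into chunks.py: their exact definitions are not shown in this diff, but based on the fields the adapter reads (processed_tokens, total_tokens) and the isinstance checks in generate_chat_stream, they presumably look roughly like the sketch below (TokenChunk is stubbed here purely for illustration).

# Sketch only: not the actual chunks.py. PrefillProgressData's fields match
# what the adapter below reads; the StreamEvent alias is inferred from the
# isinstance dispatch in generate_chat_stream and _stream_events.
from dataclasses import dataclass


@dataclass
class TokenChunk:  # stub for illustration; the real class carries more fields
    text: str
    model: str
    finish_reason: str | None = None


@dataclass
class PrefillProgressData:
    processed_tokens: int  # prompt tokens processed so far
    total_tokens: int      # total prompt tokens in this request


# Either a generated token or a prefill progress update can flow through a
# chat-completion queue.
StreamEvent = TokenChunk | PrefillProgressData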
Alex Cheema
a8f81e0495 feat: add prefill progress bar for long prompts
Shows real-time progress during prompt processing (prefill phase).
Progress is sent via SSE named events that maintain OpenAI API compatibility.

- Add PrefillProgress event type and PrefillProgressData dataclass
- Wire prompt_progress_callback through MLX stream_generate
- Send progress events directly from callback for real-time updates
- Add PrefillProgressBar.svelte component
- Parse event: prefill_progress SSE events in dashboard

Note: prefill_step_size temporarily set to 256 for testing (normally 2048)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:59:52 +00:00
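To illustrate the wire format described above, here is a minimal client-side sketch (not part of this diff) that separates the named prefill_progress frames from the regular OpenAI-style data: chunks. The endpoint path, the use of httpx, and the payload shape are assumptions for the example.

# Sketch: consuming the mixed SSE stream. Named "event: prefill_progress"
# frames carry {"processed": N, "total": M}; everything else is a standard
# Chat Completions "data:" chunk or the "[DONE]" sentinel.
import json

import httpx  # assumption: any HTTP client that can iterate response lines works


def stream_chat(base_url: str, payload: dict) -> None:
    url = f"{base_url}/v1/chat/completions"  # assumed endpoint path
    with httpx.stream("POST", url, json=payload, timeout=None) as resp:
        event_name = None
        for line in resp.iter_lines():
            if line.startswith("event: "):
                event_name = line[len("event: "):]
            elif line.startswith("data: "):
                data = line[len("data: "):]
                if event_name == "prefill_progress":
                    progress = json.loads(data)
                    print(f"prefill: {progress['processed']}/{progress['total']} tokens")
                elif data == "[DONE]":
                    break
                else:
                    chunk = json.loads(data)
                    print(chunk["choices"][0]["delta"].get("content", ""), end="")
                event_name = None  # each SSE event here has a single data line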
Alex Cheema
ba7148ccec style: format app.svelte.ts with nix fmt 2026-01-22 11:53:43 +00:00
Alex Cheema
a64b8addc6 Fix localStorage quota issues by stripping tokens and auto-pruning
- Strip tokens (logprobs data) from messages before saving to localStorage
  since they're large and not essential for persistence
- Add pruneOldConversations() to automatically remove oldest conversations
  when quota is exceeded
- This prevents QuotaExceededError from crashing the app

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
e6599a9408 Fix ReferenceError: controller undefined in sendMessage finally block
Move AbortController creation before the try block in both
sendMessageWithLogprobs and regenerateFromToken functions.
Previously, controller was defined inside the try block but
referenced in the finally block, causing a ReferenceError
if an exception was thrown before the controller was created.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
93f4753598 Add SSE headers to properly close streaming connections
Add Cache-Control, Connection: close, and X-Accel-Buffering headers
to all SSE streaming responses. These headers help ensure:
- No caching of streaming responses
- Connection closes when stream ends (instead of keep-alive)
- No proxy buffering that could delay stream closure

This should fix the issue where the frontend stays on "PROCESSING"
even after receiving the complete response.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
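For reference, a small sketch of what these headers do when attached to a streaming response. It mirrors the headers added in the api.py diff further down; the helper function itself is illustrative, not part of the change.

# Sketch: the three SSE headers this commit adds, with the intent of each.
from collections.abc import AsyncGenerator

from fastapi.responses import StreamingResponse

SSE_HEADERS = {
    "Cache-Control": "no-cache",  # never cache a live event stream
    "Connection": "close",        # close the socket when the stream ends, no keep-alive
    "X-Accel-Buffering": "no",    # ask reverse proxies (e.g. nginx) not to buffer frames
}


def sse_response(events: AsyncGenerator[str, None]) -> StreamingResponse:
    return StreamingResponse(events, media_type="text/event-stream", headers=SSE_HEADERS)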
Alex Cheema
75fe505275 Add debug logging to generate_chat_stream
Add logging to help diagnose why streaming might not be ending properly.
This will show when [DONE] is yielded, when return is called, and when
the finally block runs.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
d7c044e349 Fix streaming not ending after [DONE] is yielded
Add missing return statement after yielding [DONE] in generate_chat_stream.
Without this, the async generator continues waiting for more chunks from
chunk_stream even though generation is complete, causing the stream to hang
indefinitely. The frontend waits for the stream to close (reader.done), which
never happens, resulting in the chat button staying on "PROCESSING" forever.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
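The failure mode described above can be reproduced with a stripped-down sketch (this is not the actual generate_chat_stream): without the return, the consumer re-enters the async for loop after [DONE] and blocks on a chunk source that never yields again, so the SSE response never closes.

# Sketch of the bug pattern, using a stand-in chunk stream.
import asyncio
from collections.abc import AsyncGenerator


async def chunks() -> AsyncGenerator[dict, None]:
    yield {"text": "hi", "finish_reason": None}
    yield {"text": "", "finish_reason": "stop"}
    # A real upstream queue would now sit idle rather than finish, so anyone
    # still iterating blocks forever.
    await asyncio.Event().wait()


async def stream() -> AsyncGenerator[str, None]:
    async for chunk in chunks():
        yield f"data: {chunk['text']}\n\n"
        if chunk["finish_reason"] is not None:
            yield "data: [DONE]\n\n"
            return  # the fix: without this, the loop resumes chunks() and hangs


async def main() -> None:
    async for frame in stream():
        print(frame, end="")


asyncio.run(main())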
Alex Cheema
53b6d56e9f fix: restore extract_top_logprobs function for uncertainty visualization
The extract_top_logprobs function was lost during rebases. This function
processes the out.logprobs array (full vocabulary logprobs from MLX) to
extract the selected token's logprob and top-k alternatives.

The previous code tried to use getattr(out, "logprob", None), which
doesn't exist: mlx_lm returns logprobs as an mx.array, not as individual
values.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
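As a rough schematic of what a function like extract_top_logprobs does (this is not the restored code; it assumes logprobs is a plain sequence of per-token log-probabilities, whereas the real function operates on the mx.array returned by mlx_lm):

# Schematic sketch: given full-vocabulary logprobs for one generation step,
# return the chosen token's logprob plus the top-k alternatives.
import math
from collections.abc import Sequence


def extract_top_logprobs_sketch(
    logprobs: Sequence[float], selected_token_id: int, top_k: int = 5
) -> tuple[float, list[tuple[int, float]]]:
    selected_logprob = logprobs[selected_token_id]
    # Rank every vocabulary entry by logprob and keep the k most likely.
    ranked = sorted(range(len(logprobs)), key=lambda i: logprobs[i], reverse=True)
    top = [(i, logprobs[i]) for i in ranked[:top_k]]
    return selected_logprob, top


# Toy example: a 4-token vocabulary where token 2 was sampled.
toy_logprobs = [math.log(p) for p in (0.1, 0.2, 0.6, 0.1)]
print(extract_top_logprobs_sketch(toy_logprobs, selected_token_id=2, top_k=2))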
Alex Cheema
7fe0a61230 fix: remove unsupported logprob params from stream_generate
mlx_lm.stream_generate already returns logprobs in its output, so we
don't need to pass the return_logprob or return_top_logprobs kwargs.
The uncertainty visualization feature extracts logprobs from the
existing out.logprobs field.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
5a36542631 feat: add uncertainty visualization with token-level logprobs
- Add TokenHeatmap component for visualizing token confidence
- Collect and stream logprobs in generation pipeline
- Add regenerate-from-token feature with continue_from_prefix
- Add AbortController for request cancellation
- Support continue_final_message for seamless prefix continuation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:53:43 +00:00
Alex Cheema
955e0105b3 fix: resolve import and type errors from rebase
- Use claude_request_to_internal instead of old function name
- Fix ModelId imports in runner.py and test files
- Update test_mlx/conftest.py to use ResponsesRequest format
- Remove unused imports

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:36:11 +00:00
Evan
4d1eb1d9bd fix: rebase fix 2026-01-22 11:32:46 +00:00
Alex Cheema
365416c65e style: move inline imports to top of file in api.py
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:32:26 +00:00
Alex Cheema
04af76e10f fix: restore try/except structure in runner.py
Replace non-existent context manager with proper try/except block
and remove unused ModelId import.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:32:04 +00:00
Alex Cheema
a84c3431cd style: fix formatting issues caught by treefmt
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:31:45 +00:00
Alex Cheema
52445b21f6 refactor: use ResponsesRequest as canonical internal type
- Extend ResponsesRequest with fields: top_k, seed, stop, tools
- Remove redundant InternalTaskParams and InputMessage types
- Update all adapters to convert to ResponsesRequest
- Simplify Responses API (no conversion needed - native passthrough)
- Update all imports across codebase and tests

This eliminates type duplication and makes the Responses API
relationship explicit throughout the codebase.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:31:44 +00:00
Alex Cheema
435bd7f6fa refactor: make Responses API the canonical internal format
Restructure the API layer so that OpenAI Responses API is the native
format, with Chat Completions and Claude Messages as adapters on top.

Changes:
- Add new chat_completions.py adapter with streaming/non-streaming support
- Update responses.py with collect_responses_response() for non-streaming
- Update claude.py with collect_claude_response() for non-streaming
- Refactor api.py so all endpoints use adapters uniformly
- Rename _chat_chunk_stream to _token_chunk_stream (generic internal format)
- Remove unused chat_response_to_* converter functions
- Update tests to remove tests for deleted functions

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:30:27 +00:00
Alex Cheema
dd25b5b90e feat: add Claude Messages API and OpenAI Responses API support
Adds two new API endpoints that wrap the existing chat completions:

- /v1/messages - Claude Messages API compatible endpoint
- /v1/responses - OpenAI Responses API compatible endpoint

Both support streaming (SSE) and non-streaming modes with proper
token usage reporting from actual inference stats.

Also adds top_k sampling parameter and stop sequence support to the
MLX inference engine.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 11:28:49 +00:00
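As a quick usage sketch for the two new endpoints (not taken from the diff; the base URL, port, model name, and use of httpx are assumptions), a non-streaming call to each might look like this:

# Sketch: exercising /v1/messages and /v1/responses over plain HTTP. Field
# names follow the request/response shapes handled by the adapters below.
import httpx

BASE_URL = "http://localhost:52415"  # assumption: wherever the exo API listens

# Claude Messages API
claude_resp = httpx.post(
    f"{BASE_URL}/v1/messages",
    json={
        "model": "llama-3.2-1b",
        "max_tokens": 128,
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
    },
    timeout=None,
)
print(claude_resp.json()["content"][0]["text"])

# OpenAI Responses API
responses_resp = httpx.post(
    f"{BASE_URL}/v1/responses",
    json={"model": "llama-3.2-1b", "input": "Hello, how are you?"},
    timeout=None,
)
print(responses_resp.json()["output_text"])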
28 changed files with 2294 additions and 194 deletions

View File

@@ -8,7 +8,6 @@
regenerateLastResponse,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";

View File

@@ -0,0 +1,51 @@
<script lang="ts">
import type { PrefillProgress } from "$lib/stores/app.svelte";
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = "" }: Props = $props();
const percentage = $derived(
progress.total > 0
? Math.round((progress.processed / progress.total) * 100)
: 0,
);
function formatTokenCount(count: number): string {
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div
class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
>
<span>Processing prompt</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(
progress.total,
)} tokens
</span>
</div>
<div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
<div
class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -0,0 +1,229 @@
<script lang="ts">
import type { TokenData } from "$lib/stores/app.svelte";
interface Props {
tokens: TokenData[];
class?: string;
isGenerating?: boolean;
onRegenerateFrom?: (tokenIndex: number) => void;
}
let {
tokens,
class: className = "",
isGenerating = false,
onRegenerateFrom,
}: Props = $props();
// Tooltip state - track both token data and index
let hoveredTokenIndex = $state<number | null>(null);
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
let isTooltipHovered = $state(false);
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
// Derive the hovered token from the index (stable across re-renders)
const hoveredToken = $derived(
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
? {
token: tokens[hoveredTokenIndex],
index: hoveredTokenIndex,
...hoveredPosition,
}
: null,
);
/**
* Get confidence styling based on probability.
* Following Apple design principles: high confidence tokens blend in,
* only uncertainty draws attention.
*/
function getConfidenceClass(probability: number): string {
if (probability > 0.8) return "text-inherit"; // Expected tokens - blend in
if (probability > 0.5) return "bg-gray-500/10 text-inherit"; // Slight hint
if (probability > 0.2) return "bg-amber-500/15 text-amber-200/90"; // Subtle warmth
return "bg-red-500/20 text-red-200/90"; // Draws attention
}
/**
* Get border/underline styling for uncertain tokens
*/
function getBorderClass(probability: number): string {
if (probability > 0.8) return "border-transparent"; // No border for expected
if (probability > 0.5) return "border-gray-500/20";
if (probability > 0.2) return "border-amber-500/30";
return "border-red-500/40";
}
function clearHideTimeout() {
if (hideTimeoutId) {
clearTimeout(hideTimeoutId);
hideTimeoutId = null;
}
}
function handleMouseEnter(
event: MouseEvent,
token: TokenData,
index: number,
) {
clearHideTimeout();
const rect = (event.target as HTMLElement).getBoundingClientRect();
hoveredTokenIndex = index;
hoveredPosition = {
x: rect.left + rect.width / 2,
y: rect.top - 10,
};
}
function handleMouseLeave() {
clearHideTimeout();
// Use longer delay during generation to account for re-renders
const delay = isGenerating ? 300 : 100;
hideTimeoutId = setTimeout(() => {
if (!isTooltipHovered) {
hoveredTokenIndex = null;
hoveredPosition = null;
}
}, delay);
}
function handleTooltipEnter() {
clearHideTimeout();
isTooltipHovered = true;
}
function handleTooltipLeave() {
isTooltipHovered = false;
hoveredTokenIndex = null;
hoveredPosition = null;
}
function handleRegenerate() {
if (hoveredToken && onRegenerateFrom) {
const indexToRegenerate = hoveredToken.index;
// Clear hover state immediately
hoveredTokenIndex = null;
hoveredPosition = null;
isTooltipHovered = false;
// Call regenerate
onRegenerateFrom(indexToRegenerate);
}
}
function formatProbability(prob: number): string {
return (prob * 100).toFixed(1) + "%";
}
function formatLogprob(logprob: number): string {
return logprob.toFixed(3);
}
function getProbabilityColor(probability: number): string {
if (probability > 0.8) return "text-gray-300";
if (probability > 0.5) return "text-gray-400";
if (probability > 0.2) return "text-amber-400";
return "text-red-400";
}
</script>
<div class="token-heatmap leading-relaxed {className}">
{#each tokens as tokenData, i (i)}
<span
role="button"
tabindex="0"
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(
tokenData.probability,
)} {getBorderClass(tokenData.probability)} hover:opacity-80"
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
onmouseleave={handleMouseLeave}>{tokenData.token}</span
>
{/each}
</div>
<!-- Tooltip -->
{#if hoveredToken}
<div
class="fixed z-50"
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
onmouseenter={handleTooltipEnter}
onmouseleave={handleTooltipLeave}
>
<div
class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48"
>
<!-- Token info -->
<div class="mb-2">
<span class="text-gray-500 text-xs">Token:</span>
<span class="text-white font-mono ml-1"
>"{hoveredToken.token.token}"</span
>
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2"
>{formatProbability(hoveredToken.token.probability)}</span
>
</div>
<div class="text-gray-400 text-xs mb-1">
logprob: <span class="text-gray-300 font-mono"
>{formatLogprob(hoveredToken.token.logprob)}</span
>
</div>
<!-- Top alternatives -->
{#if hoveredToken.token.topLogprobs.length > 0}
<div class="border-t border-gray-700/50 mt-2 pt-2">
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
{@const altProb = Math.exp(alt.logprob)}
<div class="flex justify-between items-center text-xs py-0.5">
<span class="text-gray-300 font-mono truncate max-w-24"
>"{alt.token}"</span
>
<span class="text-gray-400 ml-2"
>{formatProbability(altProb)}</span
>
</div>
{/each}
</div>
{/if}
<!-- Regenerate button -->
{#if onRegenerateFrom}
<button
onclick={handleRegenerate}
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
/>
</svg>
Regenerate from here
</button>
{/if}
</div>
<!-- Arrow -->
<div class="absolute left-1/2 -translate-x-1/2 top-full">
<div class="border-8 border-transparent border-t-gray-900"></div>
</div>
</div>
{/if}
<style>
.token-heatmap {
word-wrap: break-word;
white-space: pre-wrap;
}
.token-span {
margin: 0;
border-width: 1px;
}
</style>

View File

@@ -0,0 +1 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

View File

@@ -0,0 +1,197 @@
"""OpenAI Chat Completions API adapter for converting requests/responses."""
import time
from collections.abc import AsyncGenerator
from loguru import logger
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionMessageText,
ChatCompletionResponse,
ChatCompletionTaskParams,
ErrorInfo,
ErrorResponse,
FinishReason,
Logprobs,
LogprobsContentItem,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import PrefillProgressData, StreamEvent, TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import ResponseInputMessage, ResponsesRequest
def chat_request_to_internal(request: ChatCompletionTaskParams) -> ResponsesRequest:
"""Convert Chat Completions API request to ResponsesRequest (canonical internal format).
Extracts system message as instructions, converts messages to input.
"""
instructions: str | None = None
input_messages: list[ResponseInputMessage] = []
for msg in request.messages:
# Normalize content to string
content: str
if msg.content is None:
content = ""
elif isinstance(msg.content, str):
content = msg.content
elif isinstance(msg.content, ChatCompletionMessageText):
content = msg.content.text
else:
# List of ChatCompletionMessageText
content = "\n".join(item.text for item in msg.content)
# Extract system message as instructions
if msg.role == "system":
if instructions is None:
instructions = content
else:
# Append additional system messages
instructions = f"{instructions}\n{content}"
else:
# Convert to ResponseInputMessage (only user, assistant, developer roles)
if msg.role in ("user", "assistant", "developer"):
input_messages.append(
ResponseInputMessage(role=msg.role, content=content)
)
return ResponsesRequest(
model=request.model,
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
seed=request.seed,
stream=request.stream,
tools=request.tools,
continue_from_prefix=request.continue_from_prefix,
)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
# Build logprobs if available
logprobs: Logprobs | None = None
if chunk.logprob is not None:
logprobs = Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
]
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
)
async def generate_chat_stream(
command_id: CommandId,
event_stream: AsyncGenerator[StreamEvent, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from StreamEvents.
Handles both TokenChunks (token generation) and PrefillProgressData (prefill progress).
"""
try:
async for event in event_stream:
if isinstance(event, PrefillProgressData):
# Send prefill progress as a named SSE event
progress_json = f'{{"processed":{event.processed_tokens},"total":{event.total_tokens}}}'
yield f"event: prefill_progress\ndata: {progress_json}\n\n"
continue
# TokenChunk handling
chunk = event
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
logger.info(f"generate_chat_stream ending (error): {command_id}")
return
chunk_response = chunk_to_response(chunk, command_id)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
logger.info(
f"generate_chat_stream yielding [DONE] for finish_reason={chunk.finish_reason}: {command_id}"
)
yield "data: [DONE]\n\n"
logger.info(f"generate_chat_stream returning: {command_id}")
return
finally:
logger.info(f"generate_chat_stream finally block: {command_id}")
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> ChatCompletionResponse:
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_message = chunk.error_message or "Internal server error"
break
if model is None:
model = chunk.model
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
),
finish_reason=finish_reason,
)
],
)

View File

@@ -0,0 +1,190 @@
"""Claude Messages API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessagesResponse,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import ResponseInputMessage, ResponsesRequest
def finish_reason_to_claude_stop_reason(
finish_reason: FinishReason | None,
) -> ClaudeStopReason | None:
"""Map OpenAI finish_reason to Claude stop_reason."""
if finish_reason is None:
return None
mapping: dict[FinishReason, ClaudeStopReason] = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
"function_call": "tool_use",
}
return mapping.get(finish_reason, "end_turn")
def claude_request_to_internal(request: ClaudeMessagesRequest) -> ResponsesRequest:
"""Convert Claude Messages API request to ResponsesRequest (canonical internal format).
Converts Claude's system parameter to instructions,
and messages to input.
"""
# Handle system message
instructions: str | None = None
if request.system:
if isinstance(request.system, str):
instructions = request.system
else:
# List of text blocks
instructions = "".join(block.text for block in request.system)
# Convert messages to input
input_messages: list[ResponseInputMessage] = []
for msg in request.messages:
content: str
if isinstance(msg.content, str):
content = msg.content
else:
# Concatenate text blocks (images not supported for MVP)
text_parts: list[str] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
content = "".join(text_parts)
# Claude uses "user" and "assistant" roles
input_messages.append(ResponseInputMessage(role=msg.role, content=content))
return ResponsesRequest(
model=request.model,
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop_sequences,
stream=request.stream,
)
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> ClaudeMessagesResponse:
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
stop_reason: ClaudeStopReason | None = None
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_message = chunk.error_message or "Internal server error"
break
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
return ClaudeMessagesResponse(
id=f"msg_{command_id}",
model=model,
content=[ClaudeTextBlock(text=combined_text)],
stop_reason=stop_reason,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
initial_message = ClaudeMessageStart(
id=f"msg_{command_id}",
model=model,
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
)
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
async for chunk in chunk_stream:
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
# message_stop
message_stop = ClaudeMessageStopEvent()
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"

View File

@@ -0,0 +1,173 @@
"""OpenAI Responses API adapter for converting requests/responses.
ResponsesRequest is the canonical internal format. Responses API is the most featureful,
making it the natural choice for the internal format. All other API formats (Chat
Completions, Claude) are converted TO ResponsesRequest.
"""
from collections.abc import AsyncGenerator
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> ResponsesResponse:
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
accumulated_text = ""
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if chunk.finish_reason == "error":
error_message = chunk.error_message or "Internal server error"
break
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
output_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
return ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=[output_item],
output_text=accumulated_text,
usage=usage,
)
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
# response.created
initial_response = ResponsesResponse(
id=response_id,
model=model,
status="in_progress",
output=[],
output_text="",
)
created_event = ResponseCreatedEvent(response=initial_response)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress
in_progress_event = ResponseInProgressEvent(response=initial_response)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(output_index=0, item=initial_item)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
output_index=0, content_index=0, part=initial_part
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = ""
last_stats = None
async for chunk in chunk_stream:
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done
text_done = ResponseTextDoneEvent(
output_index=0, content_index=0, text=accumulated_text
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
output_index=0, content_index=0, part=final_part
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done
final_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
item_done = ResponseOutputItemDoneEvent(output_index=0, item=final_item)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed
final_response = ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=[final_item],
output_text=accumulated_text,
usage=usage,
)
completed_event = ResponseCompletedEvent(response=final_response)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

@@ -17,6 +17,21 @@ from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from exo.master.adapters.chat_completions import (
chat_request_to_internal,
chunk_to_response,
collect_chat_response,
generate_chat_stream,
)
from exo.master.adapters.claude import (
claude_request_to_internal,
collect_claude_response,
generate_claude_stream,
)
from exo.master.adapters.responses import (
collect_responses_response,
generate_responses_stream,
)
from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -36,6 +51,7 @@ from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
ChatCompletionTaskParams,
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
@@ -55,9 +71,18 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.chunks import (
ImageChunk,
InputImageChunk,
PrefillProgressData,
StreamEvent,
TokenChunk,
)
from exo.shared.types.claude_api import (
ClaudeMessagesRequest,
ClaudeMessagesResponse,
)
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -76,10 +101,14 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
)
from exo.shared.types.memory import Memory
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -92,23 +121,6 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
],
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
@@ -162,7 +174,7 @@ class API:
name="dashboard",
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._chat_completion_queues: dict[CommandId, Sender[StreamEvent]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup | None = None
@@ -229,6 +241,8 @@ class API:
self.app.post("/bench/images/edits")(self.bench_image_edits)
self.app.get("/images")(self.list_images)
self.app.get("/images/{image_id}")(self.get_image)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -437,18 +451,23 @@ class API:
instance_id=instance_id,
)
async def _chat_chunk_stream(
async def _stream_events(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
) -> AsyncGenerator[StreamEvent, None]:
"""Yield stream events (TokenChunks or PrefillProgressData) for a command.
This is the internal low-level stream used by all API adapters.
"""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
self._chat_completion_queues[command_id], recv = channel[StreamEvent]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
with recv as events:
async for event in events:
yield event
if (
isinstance(event, TokenChunk)
and event.finish_reason is not None
):
break
except anyio.get_cancelled_exc_class():
@@ -464,33 +483,36 @@ class API:
await self._send(command)
del self._chat_completion_queues[command_id]
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield only TokenChunks, filtering out progress events."""
async for event in self._stream_events(command_id):
if isinstance(event, TokenChunk):
yield event
async def _generate_chat_stream(
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
async for event in self._stream_events(command_id):
if isinstance(event, PrefillProgressData):
# Send prefill progress as a named SSE event
progress_json = f'{{"processed":{event.processed_tokens},"total":{event.total_tokens}}}'
yield f"event: prefill_progress\ndata: {progress_json}\n\n"
else:
# TokenChunk - regular token generation
chunk_response: ChatCompletionResponse = chunk_to_response(
event, command_id
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
logger.debug(f"chunk_response: {chunk_response}")
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
if event.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
@@ -502,12 +524,6 @@ class API:
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -580,7 +596,7 @@ class API:
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
)
@@ -588,49 +604,67 @@ class API:
async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
"""OpenAI Chat Completions API - adapter."""
internal_params = chat_request_to_internal(payload)
if not any(
instance.shard_assignments.model_id == payload.model
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
await self._trigger_notify_user_to_download_model(
ModelId(internal_params.model)
)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
command = ChatCompletion(
request_params=payload,
)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
generate_chat_stream(
command.command_id,
self._stream_events(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
return await self._collect_chat_completion(command.command_id)
try:
return await collect_chat_response(
command.command_id,
self._chat_chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
# Convert to internal format (BenchChatCompletionTaskParams extends ChatCompletionTaskParams)
internal_params = chat_request_to_internal(payload)
if not any(
instance.shard_assignments.model_id == payload.model
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
await self._trigger_notify_user_to_download_model(
ModelId(internal_params.model)
)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
payload.stream = False
internal_params.stream = False
command = ChatCompletion(request_params=payload)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
@@ -1088,6 +1122,98 @@ class API:
response_format=response_format,
)
async def claude_messages(
self, payload: ClaudeMessagesRequest
) -> ClaudeMessagesResponse | StreamingResponse:
"""Claude Messages API - adapter."""
internal_params = claude_request_to_internal(payload)
model_card = await resolve_model_card(ModelId(internal_params.model))
internal_params.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(internal_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_claude_stream(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
try:
return await collect_claude_response(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
async def openai_responses(
self, payload: ResponsesRequest
) -> ResponsesResponse | StreamingResponse:
"""OpenAI Responses API - native format."""
internal_params = payload
model_card = await resolve_model_card(internal_params.model)
internal_params.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == internal_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(
ModelId(internal_params.model)
)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {internal_params.model}",
)
command = ChatCompletion(request_params=internal_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_responses_stream(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
try:
return await collect_responses_response(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -1169,6 +1295,18 @@ class API:
self._image_generation_queues.pop(
event.command_id, None
)
elif isinstance(event, PrefillProgress):
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(
PrefillProgressData(
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -0,0 +1,284 @@
"""Tests for Claude Messages API conversion functions and types."""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.master.adapters.claude import (
claude_request_to_internal,
finish_reason_to_claude_stop_reason,
)
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessage,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
from exo.shared.types.common import ModelId
class TestFinishReasonToClaudeStopReason:
"""Tests for finish_reason to Claude stop_reason mapping."""
def test_stop_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
def test_length_maps_to_max_tokens(self):
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
def test_tool_calls_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
def test_function_call_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
def test_content_filter_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
def test_none_returns_none(self):
assert finish_reason_to_claude_stop_reason(None) is None
class TestClaudeRequestToInternal:
"""Tests for converting Claude Messages API requests to ResponsesRequest."""
def test_basic_request_conversion(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_internal(request)
assert params.model == "claude-3-opus"
assert params.max_output_tokens == 100
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
assert params.instructions is None
def test_request_with_system_string(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
system="You are a helpful assistant.",
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_internal(request)
assert params.instructions == "You are a helpful assistant."
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
def test_request_with_system_text_blocks(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
system=[
ClaudeTextBlock(text="You are helpful. "),
ClaudeTextBlock(text="Be concise."),
],
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_internal(request)
assert params.instructions == "You are helpful. Be concise."
assert isinstance(params.input, list)
assert len(params.input) == 1
def test_request_with_content_blocks(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(
role="user",
content=[
ClaudeTextBlock(text="First part. "),
ClaudeTextBlock(text="Second part."),
],
),
],
)
params = claude_request_to_internal(request)
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].content == "First part. Second part."
def test_request_with_multi_turn_conversation(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
ClaudeMessage(role="assistant", content="Hi there!"),
ClaudeMessage(role="user", content="How are you?"),
],
)
params = claude_request_to_internal(request)
assert isinstance(params.input, list)
assert len(params.input) == 3
assert params.input[0].role == "user"
assert params.input[1].role == "assistant"
assert params.input[2].role == "user"
def test_request_with_optional_parameters(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[ClaudeMessage(role="user", content="Hello")],
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["STOP", "END"],
stream=True,
)
params = claude_request_to_internal(request)
assert params.temperature == 0.7
assert params.top_p == 0.9
assert params.top_k == 40
assert params.stop == ["STOP", "END"]
assert params.stream is True
class TestClaudeMessagesRequestValidation:
"""Tests for Claude Messages API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"max_tokens": 100,
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_max_tokens(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_messages(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"max_tokens": 100,
}
)
class TestClaudeStreamingEvents:
"""Tests for Claude Messages API streaming event serialization."""
def test_message_start_event_format(self):
message = ClaudeMessageStart(
id="msg_123",
model="claude-3-opus",
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=10, output_tokens=0),
)
event = ClaudeMessageStartEvent(message=message)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_start"
assert parsed["message"]["id"] == "msg_123"
assert parsed["message"]["type"] == "message"
assert parsed["message"]["role"] == "assistant"
assert parsed["message"]["model"] == "claude-3-opus"
def test_content_block_start_event_format(self):
event = ClaudeContentBlockStartEvent(
index=0,
content_block=ClaudeTextBlock(text=""),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_start"
assert parsed["index"] == 0
assert parsed["content_block"]["type"] == "text"
assert parsed["content_block"]["text"] == ""
def test_content_block_delta_event_format(self):
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_delta"
assert parsed["index"] == 0
assert parsed["delta"]["type"] == "text_delta"
assert parsed["delta"]["text"] == "Hello"
def test_content_block_stop_event_format(self):
event = ClaudeContentBlockStopEvent(index=0)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_stop"
assert parsed["index"] == 0
def test_message_delta_event_format(self):
event = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason="end_turn"),
usage=ClaudeMessageDeltaUsage(output_tokens=25),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_delta"
assert parsed["delta"]["stop_reason"] == "end_turn"
assert parsed["usage"]["output_tokens"] == 25
def test_message_stop_event_format(self):
event = ClaudeMessageStopEvent()
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_stop"
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
# Simulate the SSE format used in the streaming generator
sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
assert sse_line.startswith("event: content_block_delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")

View File

@@ -7,15 +7,14 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
ChatCompletion,
CommandId,
ForwarderCommand,
PlaceInstance,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -24,6 +23,7 @@ from exo.shared.types.events import (
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.profiling import (
MemoryUsage,
)
@@ -134,13 +134,9 @@ async def test_master():
command=(
ChatCompletion(
command_id=CommandId(),
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(
role="user", content="Hello, how are you?"
)
],
request_params=ResponsesRequest(
model=ModelId("llama-3.2-1b"),
input="Hello, how are you?",
),
)
),
@@ -191,11 +187,9 @@ async def test_master():
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)
assert events[2].event.task.task_params == ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(role="user", content="Hello, how are you?")
],
assert events[2].event.task.task_params == ResponsesRequest(
model=ModelId("llama-3.2-1b"),
input="Hello, how are you?",
)
await master.shutdown()

View File

@@ -0,0 +1,294 @@
"""Tests for OpenAI Responses API types.
ResponsesRequest is the canonical internal type used throughout the pipeline.
No conversion is needed for Responses API requests.
"""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.shared.types.common import ModelId
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseCreatedEvent,
ResponseInputMessage,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
class TestResponsesRequestAsCanonicalType:
"""Tests for ResponsesRequest as the canonical internal type."""
def test_string_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input="Hello, how are you?",
)
assert request.model == "gpt-4o"
assert request.input == "Hello, how are you?"
assert request.instructions is None
def test_message_array_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input=[
ResponseInputMessage(role="user", content="Hello"),
ResponseInputMessage(role="assistant", content="Hi there!"),
ResponseInputMessage(role="user", content="How are you?"),
],
)
assert isinstance(request.input, list)
assert len(request.input) == 3
assert request.input[0].role == "user"
assert request.input[0].content == "Hello"
assert request.input[1].role == "assistant"
assert request.input[1].content == "Hi there!"
assert request.input[2].role == "user"
assert request.input[2].content == "How are you?"
def test_request_with_instructions(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input="Hello",
instructions="You are a helpful assistant. Be concise.",
)
assert request.input == "Hello"
assert request.instructions == "You are a helpful assistant. Be concise."
def test_request_with_optional_parameters(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input="Hello",
max_output_tokens=500,
temperature=0.8,
top_p=0.95,
stream=True,
)
assert request.max_output_tokens == 500
assert request.temperature == 0.8
assert request.top_p == 0.95
assert request.stream is True
def test_request_with_new_fields(self):
"""Test the additional fields added for internal use."""
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input="Hello",
top_k=40,
seed=42,
stop=["STOP", "END"],
tools=[{"type": "function", "function": {"name": "test"}}],
)
assert request.top_k == 40
assert request.seed == 42
assert request.stop == ["STOP", "END"]
assert request.tools == [{"type": "function", "function": {"name": "test"}}]
def test_request_with_system_role_in_messages(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input=[
ResponseInputMessage(role="system", content="Be helpful"),
ResponseInputMessage(role="user", content="Hello"),
],
)
assert isinstance(request.input, list)
assert len(request.input) == 2
assert request.input[0].role == "system"
assert request.input[1].role == "user"
def test_request_with_developer_role(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input=[
ResponseInputMessage(role="developer", content="Internal note"),
ResponseInputMessage(role="user", content="Hello"),
],
)
assert isinstance(request.input, list)
assert len(request.input) == 2
assert request.input[0].role == "developer"
class TestResponsesRequestValidation:
"""Tests for OpenAI Responses API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"input": "Hello",
}
)
def test_request_requires_input(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"model": "gpt-4o",
}
)
def test_request_accepts_string_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input="Hello",
)
assert request.input == "Hello"
def test_request_accepts_message_array_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input=[ResponseInputMessage(role="user", content="Hello")],
)
assert len(request.input) == 1
class TestResponsesStreamingEvents:
"""Tests for OpenAI Responses API streaming event serialization."""
def test_response_created_event_format(self):
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="in_progress",
output=[],
output_text="",
)
event = ResponseCreatedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.created"
assert parsed["response"]["id"] == "resp_123"
assert parsed["response"]["object"] == "response"
assert parsed["response"]["status"] == "in_progress"
def test_output_item_added_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="")],
status="in_progress",
)
event = ResponseOutputItemAddedEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.added"
assert parsed["output_index"] == 0
assert parsed["item"]["type"] == "message"
assert parsed["item"]["id"] == "item_123"
assert parsed["item"]["role"] == "assistant"
def test_content_part_added_event_format(self):
part = ResponseOutputText(text="")
event = ResponseContentPartAddedEvent(
output_index=0,
content_index=0,
part=part,
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.content_part.added"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["part"]["type"] == "output_text"
def test_text_delta_event_format(self):
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.delta"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["delta"] == "Hello"
def test_text_done_event_format(self):
event = ResponseTextDoneEvent(
output_index=0,
content_index=0,
text="Hello, world!",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.done"
assert parsed["text"] == "Hello, world!"
def test_output_item_done_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello, world!")],
status="completed",
)
event = ResponseOutputItemDoneEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.done"
assert parsed["item"]["status"] == "completed"
assert parsed["item"]["content"][0]["text"] == "Hello, world!"
def test_response_completed_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello!")],
status="completed",
)
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="completed",
output=[item],
output_text="Hello!",
usage=ResponseUsage(input_tokens=10, output_tokens=5, total_tokens=15),
)
event = ResponseCompletedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.completed"
assert parsed["response"]["status"] == "completed"
assert parsed["response"]["output_text"] == "Hello!"
assert parsed["response"]["usage"]["total_tokens"] == 15
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
# Simulate the SSE format used in the streaming generator
sse_line = (
f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
)
assert sse_line.startswith("event: response.output_text.delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")
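Taken together, these event types describe a full streaming session. Below is a minimal sketch (editor's illustration, not part of the diff) of how a streaming generator might frame them as SSE for the single-message, single-text-part flow exercised by the tests above; the event ordering is inferred from the type names and is not confirmed by the source.
from collections.abc import Iterator
from exo.shared.types.openai_responses import (
    ResponseCompletedEvent,
    ResponseCreatedEvent,
    ResponseMessageItem,
    ResponseOutputText,
    ResponsesResponse,
    ResponseTextDeltaEvent,
    ResponseTextDoneEvent,
    ResponseUsage,
)
def sse_lines(deltas: list[str]) -> Iterator[str]:
    # response.created with an empty in-progress response
    in_progress = ResponsesResponse(
        id="resp_sketch", model="gpt-4o", status="in_progress", output=[], output_text=""
    )
    created = ResponseCreatedEvent(response=in_progress)
    yield f"event: {created.type}\ndata: {created.model_dump_json()}\n\n"
    # one response.output_text.delta per generated chunk
    for delta in deltas:
        ev = ResponseTextDeltaEvent(output_index=0, content_index=0, delta=delta)
        yield f"event: {ev.type}\ndata: {ev.model_dump_json()}\n\n"
    # response.output_text.done with the accumulated text
    text = "".join(deltas)
    done = ResponseTextDoneEvent(output_index=0, content_index=0, text=text)
    yield f"event: {done.type}\ndata: {done.model_dump_json()}\n\n"
    # response.completed with the final item and placeholder usage numbers
    item = ResponseMessageItem(
        id="item_sketch", content=[ResponseOutputText(text=text)], status="completed"
    )
    completed = ResponseCompletedEvent(
        response=ResponsesResponse(
            id="resp_sketch",
            model="gpt-4o",
            status="completed",
            output=[item],
            output_text=text,
            usage=ResponseUsage(input_tokens=0, output_tokens=len(deltas), total_tokens=len(deltas)),
        )
    )
    yield f"event: {completed.type}\ndata: {completed.model_dump_json()}\n\n"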

View File

@@ -15,6 +15,7 @@ from exo.shared.types.events import (
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -53,7 +54,11 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
TestEvent()
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
): # Pass-through events that don't modify state
return state
case InstanceCreated():

View File

@@ -158,7 +158,7 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
class ChatCompletionTaskParams(BaseModel):
model: str
model: ModelId
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]
logit_bias: dict[str, int] | None = None
@@ -173,10 +173,13 @@ class ChatCompletionTaskParams(BaseModel):
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# When True, continue the last assistant message without EOS tokens
continue_from_prefix: bool = False
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):

View File

@@ -2,8 +2,8 @@ from collections.abc import Generator
from enum import Enum
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats, TopLogprobItem
from exo.shared.types.common import ModelId
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -23,6 +23,8 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
@@ -64,3 +66,14 @@ class InputImageChunk(BaseChunk):
GenerationChunk = TokenChunk | ImageChunk
class PrefillProgressData(TaggedModel):
"""Data class for prefill progress events during streaming."""
processed_tokens: int
total_tokens: int
# Stream events can be either token chunks or prefill progress
StreamEvent = TokenChunk | PrefillProgressData
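Because StreamEvent is a union, downstream streaming code has to branch on the concrete type. A minimal sketch of that branch follows (editor's illustration, not part of the diff); it assumes the chunk types are pydantic models, and the "prefill_progress" SSE event name is an assumption.
from exo.shared.types.chunks import PrefillProgressData, StreamEvent, TokenChunk
def format_sse(event: StreamEvent) -> str:
    match event:
        case PrefillProgressData():
            # Named event: OpenAI-compatible clients that only read plain
            # 'data:' lines for chat chunks can safely ignore it.
            return f"event: prefill_progress\ndata: {event.model_dump_json()}\n\n"
        case TokenChunk():
            # Regular token chunk, serialized as an ordinary data event.
            return f"data: {event.model_dump_json()}\n\n"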

View File

@@ -0,0 +1,170 @@
"""Claude Messages API types for request/response conversion."""
from typing import Literal
from pydantic import BaseModel, Field
from exo.shared.types.common import ModelId
# Type aliases
ClaudeRole = Literal["user", "assistant"]
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
# Content block types
class ClaudeTextBlock(BaseModel, frozen=True):
"""Text content block in Claude Messages API."""
type: Literal["text"] = "text"
text: str
class ClaudeImageSource(BaseModel, frozen=True):
"""Image source for Claude image blocks."""
type: Literal["base64", "url"]
media_type: str | None = None
data: str | None = None
url: str | None = None
class ClaudeImageBlock(BaseModel, frozen=True):
"""Image content block in Claude Messages API."""
type: Literal["image"] = "image"
source: ClaudeImageSource
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock
# Request types
class ClaudeMessage(BaseModel, frozen=True):
"""Message in Claude Messages API request."""
role: ClaudeRole
content: str | list[ClaudeContentBlock]
class ClaudeMessagesRequest(BaseModel):
"""Request body for Claude Messages API."""
model: ModelId
max_tokens: int
messages: list[ClaudeMessage]
system: str | list[ClaudeTextBlock] | None = None
stop_sequences: list[str] | None = None
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
metadata: dict[str, str] | None = None
# Response types
class ClaudeUsage(BaseModel, frozen=True):
"""Token usage in Claude Messages API response."""
input_tokens: int
output_tokens: int
class ClaudeMessagesResponse(BaseModel, frozen=True):
"""Response body for Claude Messages API."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock]
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
# Streaming event types
class ClaudeMessageStart(BaseModel, frozen=True):
"""Partial message in message_start event."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock] = Field(default_factory=list)
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
class ClaudeMessageStartEvent(BaseModel, frozen=True):
"""Event sent at start of message stream."""
type: Literal["message_start"] = "message_start"
message: ClaudeMessageStart
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
"""Event sent at start of a content block."""
type: Literal["content_block_start"] = "content_block_start"
index: int
content_block: ClaudeTextBlock
class ClaudeTextDelta(BaseModel, frozen=True):
"""Delta for text content block."""
type: Literal["text_delta"] = "text_delta"
text: str
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
"""Event sent for content block delta."""
type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: ClaudeTextDelta
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
"""Event sent at end of a content block."""
type: Literal["content_block_stop"] = "content_block_stop"
index: int
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
"""Usage in message_delta event."""
output_tokens: int
class ClaudeMessageDelta(BaseModel, frozen=True):
"""Delta in message_delta event."""
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
"""Event sent with final message delta."""
type: Literal["message_delta"] = "message_delta"
delta: ClaudeMessageDelta
usage: ClaudeMessageDeltaUsage
class ClaudeMessageStopEvent(BaseModel, frozen=True):
"""Event sent at end of message stream."""
type: Literal["message_stop"] = "message_stop"
ClaudeStreamEvent = (
ClaudeMessageStartEvent
| ClaudeContentBlockStartEvent
| ClaudeContentBlockDeltaEvent
| ClaudeContentBlockStopEvent
| ClaudeMessageDeltaEvent
| ClaudeMessageStopEvent
)
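The docstring further down in this diff (openai_responses.py) states that Claude requests are converted to ResponsesRequest at the API boundary: system becomes instructions, messages become input, max_tokens becomes max_output_tokens, stop_sequences becomes stop. A minimal sketch of that mapping (editor's illustration, not part of the diff; the helper name and the module path in the import are assumptions):
from exo.shared.types.claude_messages import (  # module path assumed
    ClaudeContentBlock,
    ClaudeMessagesRequest,
    ClaudeTextBlock,
)
from exo.shared.types.openai_responses import ResponseInputMessage, ResponsesRequest
def claude_to_responses(req: ClaudeMessagesRequest) -> ResponsesRequest:
    # Claude's 'system' may be a string or a list of text blocks; either way it
    # becomes the 'instructions' field of the canonical request.
    if isinstance(req.system, list):
        instructions: str | None = "\n".join(b.text for b in req.system) or None
    else:
        instructions = req.system
    def block_text(content: str | list[ClaudeContentBlock]) -> str:
        if isinstance(content, str):
            return content
        # Image blocks are dropped here; only text blocks are flattened.
        return "\n".join(b.text for b in content if isinstance(b, ClaudeTextBlock))
    return ResponsesRequest(
        model=req.model,
        input=[
            ResponseInputMessage(role=m.role, content=block_text(m.content))
            for m in req.messages
        ],
        instructions=instructions,
        max_output_tokens=req.max_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        top_k=req.top_k,
        stop=req.stop_sequences,
        stream=req.stream,
        metadata=req.metadata,
    )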

View File

@@ -2,12 +2,12 @@ from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -22,7 +22,7 @@ class TestCommand(BaseCommand):
class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
request_params: ResponsesRequest
class ImageGeneration(BaseCommand):

View File

@@ -101,6 +101,12 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
conn: Connection
@@ -125,6 +131,7 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
)
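PrefillProgress carries the same counters as the PrefillProgressData stream chunk plus the command_id needed for routing. A minimal sketch (editor's illustration, not part of the diff; both helpers are hypothetical) of forwarding one event to a per-command stream and deriving a percentage for display:
from exo.shared.types.chunks import PrefillProgressData
from exo.shared.types.events import PrefillProgress
def forward_prefill_progress(event: PrefillProgress) -> PrefillProgressData:
    # The per-command stream is already scoped to one command, so command_id is dropped.
    return PrefillProgressData(
        processed_tokens=event.processed_tokens,
        total_tokens=event.total_tokens,
    )
def prefill_percent(event: PrefillProgress) -> float:
    # Guard against a zero-length prompt.
    return 100.0 * event.processed_tokens / max(event.total_tokens, 1)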

View File

@@ -0,0 +1,192 @@
"""OpenAI Responses API types for request/response conversion.
ResponsesRequest serves as both:
1. The external API request type for /v1/responses
2. The canonical internal type used throughout the inference pipeline
All external API formats (Chat Completions, Claude) are converted to
ResponsesRequest at the API boundary.
"""
import time
from typing import Any, Literal
from pydantic import BaseModel, Field
from exo.shared.types.common import ModelId
# Type aliases
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
ResponseRole = Literal["user", "assistant", "system", "developer"]
# Request types
class ResponseInputMessage(BaseModel, frozen=True):
"""Input message for Responses API.
This is also used as the internal message format throughout the pipeline.
"""
role: ResponseRole
content: str
class ResponsesRequest(BaseModel):
"""Request body for OpenAI Responses API.
This is also the canonical internal task params format used throughout
the inference pipeline. All external API formats are converted to this
format at the API boundary.
Field mapping from other APIs:
- input: Replaces 'messages' from Chat Completions
- instructions: System message, extracted from messages or Claude's 'system'
- max_output_tokens: Replaces 'max_tokens' from Chat Completions
"""
model: ModelId
input: str | list[ResponseInputMessage]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None
stream: bool = False
# Tools support
tools: list[dict[str, Any]] | None = None
# previous_response_id not supported in MVP
metadata: dict[str, str] | None = None
# When True, continue the last assistant message without EOS tokens
continue_from_prefix: bool = False
# Response types
class ResponseOutputText(BaseModel, frozen=True):
"""Text content in response output."""
type: Literal["output_text"] = "output_text"
text: str
annotations: list[dict[str, str]] = Field(default_factory=list)
class ResponseMessageItem(BaseModel, frozen=True):
"""Message item in response output array."""
type: Literal["message"] = "message"
id: str
role: Literal["assistant"] = "assistant"
content: list[ResponseOutputText]
status: ResponseStatus = "completed"
ResponseItem = ResponseMessageItem # Can expand for function_call, reasoning, etc.
class ResponseUsage(BaseModel, frozen=True):
"""Token usage in Responses API response."""
input_tokens: int
output_tokens: int
total_tokens: int
class ResponsesResponse(BaseModel, frozen=True):
"""Response body for OpenAI Responses API."""
id: str
object: Literal["response"] = "response"
created_at: int = Field(default_factory=lambda: int(time.time()))
status: ResponseStatus = "completed"
model: str
output: list[ResponseItem]
output_text: str
usage: ResponseUsage | None = None
# Streaming event types
class ResponseCreatedEvent(BaseModel, frozen=True):
"""Event sent when response is created."""
type: Literal["response.created"] = "response.created"
response: ResponsesResponse
class ResponseInProgressEvent(BaseModel, frozen=True):
"""Event sent when response starts processing."""
type: Literal["response.in_progress"] = "response.in_progress"
response: ResponsesResponse
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
"""Event sent when an output item is added."""
type: Literal["response.output_item.added"] = "response.output_item.added"
output_index: int
item: ResponseItem
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
"""Event sent when a content part is added."""
type: Literal["response.content_part.added"] = "response.content_part.added"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseTextDeltaEvent(BaseModel, frozen=True):
"""Event sent for text delta during streaming."""
type: Literal["response.output_text.delta"] = "response.output_text.delta"
output_index: int
content_index: int
delta: str
class ResponseTextDoneEvent(BaseModel, frozen=True):
"""Event sent when text content is done."""
type: Literal["response.output_text.done"] = "response.output_text.done"
output_index: int
content_index: int
text: str
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
"""Event sent when a content part is done."""
type: Literal["response.content_part.done"] = "response.content_part.done"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
"""Event sent when an output item is done."""
type: Literal["response.output_item.done"] = "response.output_item.done"
output_index: int
item: ResponseItem
class ResponseCompletedEvent(BaseModel, frozen=True):
"""Event sent when response is completed."""
type: Literal["response.completed"] = "response.completed"
response: ResponsesResponse
ResponsesStreamEvent = (
ResponseCreatedEvent
| ResponseInProgressEvent
| ResponseOutputItemAddedEvent
| ResponseContentPartAddedEvent
| ResponseTextDeltaEvent
| ResponseTextDoneEvent
| ResponseContentPartDoneEvent
| ResponseOutputItemDoneEvent
| ResponseCompletedEvent
)
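The "Field mapping from other APIs" note above implies a straightforward Chat Completions adapter. A minimal sketch of it (editor's illustration, not part of the diff; the helper name is hypothetical and string-only message content is assumed):
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponseInputMessage, ResponsesRequest
def chat_completions_to_responses(params: ChatCompletionTaskParams) -> ResponsesRequest:
    messages = list(params.messages)
    # A leading system message becomes 'instructions'.
    instructions: str | None = None
    if messages and messages[0].role == "system" and isinstance(messages[0].content, str):
        instructions = messages.pop(0).content
    return ResponsesRequest(
        model=params.model,
        input=[
            ResponseInputMessage(role=m.role, content=m.content)
            for m in messages
            if isinstance(m.content, str)
        ],
        instructions=instructions,
        max_output_tokens=params.max_tokens,
        temperature=params.temperature,
        top_p=params.top_p,
        top_k=params.top_k,
        stream=params.stream,
        tools=params.tools,
        continue_from_prefix=params.continue_from_prefix,
    )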

View File

@@ -3,11 +3,11 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, Id
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
@@ -54,7 +54,7 @@ class StartWarmup(BaseTask): # emitted by Worker
class ChatCompletion(BaseTask): # emitted by Master
command_id: CommandId
task_params: ChatCompletionTaskParams
task_params: ResponsesRequest
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)

View File

@@ -1,7 +1,12 @@
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
from exo.shared.types.api import (
FinishReason,
GenerationStats,
ImageGenerationStats,
TopLogprobItem,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -16,7 +21,8 @@ class TokenizedResponse(BaseRunnerResponse):
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
@@ -50,3 +56,8 @@ class PartialImageResponse(BaseRunnerResponse):
class FinishedResponse(BaseRunnerResponse):
pass
class PrefillProgressResponse(BaseRunnerResponse):
processed_tokens: int
total_tokens: int

View File

@@ -1,3 +1,5 @@
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
@@ -15,3 +17,29 @@ class Model(nn.Module):
cache: list[KVCache] | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...
class Detokenizer:
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
@property
def last_segment(self) -> str: ...
class TokenizerWrapper:
bos_token: str | None
eos_token_ids: list[int]
detokenizer: Detokenizer
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
def apply_chat_template(
self,
messages_dicts: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
) -> str: ...
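The Detokenizer protocol above mirrors mlx_lm's streaming detokenizer interface. A minimal usage sketch (editor's illustration, not part of the diff), assuming last_segment yields only the text produced since it was last read:
from mlx_lm.tokenizer_utils import TokenizerWrapper
def decode_incrementally(tokenizer: TokenizerWrapper, token_ids: list[int]) -> str:
    detok = tokenizer.detokenizer
    detok.reset()
    pieces: list[str] = []
    for token_id in token_ids:
        detok.add_token(token_id)
        # Each read of last_segment returns only the newly completed text, so
        # partial multi-byte sequences are held back until they decode cleanly.
        pieces.append(detok.last_segment)
    detok.finalize()
    pieces.append(detok.last_segment)  # flush any trailing text
    return "".join(pieces)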

View File

@@ -8,13 +8,13 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
FinishReason,
GenerationStats,
TopLogprobItem,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
@@ -53,14 +53,9 @@ def warmup_inference(
warmup_prompt = apply_chat_template(
tokenizer=tokenizer,
chat_task_data=ChatCompletionTaskParams(
model="",
messages=[
ChatCompletionMessage(
role="user",
content=content,
)
],
task_params=ResponsesRequest(
model=ModelId(""),
input=content,
),
)
@@ -81,7 +76,7 @@ def warmup_inference(
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
prefill_step_size=256, # Temporarily reduced from 2048 for testing progress bar
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
@@ -115,18 +110,70 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def extract_top_logprobs(
logprobs: mx.array,
tokenizer: TokenizerWrapper,
top_k: int,
selected_token: int,
) -> tuple[float, list[TopLogprobItem]]:
"""Extract the selected token's logprob and top-k alternative tokens.
Args:
logprobs: Full vocabulary logprobs array from MLX
tokenizer: Tokenizer for decoding token IDs to strings
top_k: Number of top alternatives to return
selected_token: The token ID that was actually sampled
Returns:
Tuple of (selected_token_logprob, list of TopLogprobItem for top-k tokens)
"""
# Get the logprob of the selected token
selected_logprob = float(logprobs[selected_token].item())
# Get top-k indices (most probable tokens)
# mx.argpartition gives indices that would partition the array
# We negate logprobs since argpartition finds smallest, and we want largest
top_k = min(top_k, logprobs.shape[0]) # Don't exceed vocab size
top_indices = mx.argpartition(-logprobs, top_k)[:top_k]
# Get the actual logprob values for these indices
top_values = logprobs[top_indices]
# Sort by logprob (descending) for consistent ordering
sort_order = mx.argsort(-top_values)
top_indices = top_indices[sort_order]
top_values = top_values[sort_order]
# Convert to list of TopLogprobItem
top_logprob_items: list[TopLogprobItem] = []
for i in range(top_k):
token_id = int(top_indices[i].item())
token_logprob = float(top_values[i].item())
# Decode token ID to string
token_str = tokenizer.decode([token_id])
# Get byte representation
token_bytes = list(token_str.encode("utf-8"))
top_logprob_items.append(
TopLogprobItem(
token=token_str,
logprob=token_logprob,
bytes=token_bytes,
)
)
return selected_logprob, top_logprob_items
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
task: ResponsesRequest,
prompt: str,
is_bench: bool = False,
on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
# Currently we support chat-completion tasks only.
logger.debug(f"task_params: {task}")
if task.seed is not None:
mx.random.seed(task.seed)
@@ -142,9 +189,20 @@ def mlx_generate(
sampler = make_sampler(
temp=task.temperature if task.temperature is not None else 0.7,
top_p=task.top_p if task.top_p is not None else 1.0,
top_k=task.top_k if task.top_k is not None else 0,
)
max_tokens = task.max_tokens or MAX_TOKENS
# Normalize stop sequences to a list
stop_sequences: list[str] = (
([task.stop] if isinstance(task.stop, str) else task.stop)
if task.stop is not None
else []
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
max_tokens = task.max_output_tokens or MAX_TOKENS
accumulated_text = ""
for out in stream_generate(
model=model,
tokenizer=tokenizer,
@@ -154,14 +212,36 @@ def mlx_generate(
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
prefill_step_size=256, # Temporarily reduced from 2048 for testing progress bar
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=on_prefill_progress,
):
logger.info(out.text)
accumulated_text += out.text
# Check for stop sequences
text = out.text
finish_reason: FinishReason | None = cast(
FinishReason | None, out.finish_reason
)
stop_matched = False
if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in accumulated_text:
# Trim text to just before the stop sequence
stop_index = accumulated_text.find(stop_seq)
text_before_stop = accumulated_text[:stop_index]
chunk_start = len(accumulated_text) - len(out.text)
text = text_before_stop[chunk_start:]
finish_reason = "stop"
stop_matched = True
break
is_done = finish_reason is not None
stats: GenerationStats | None = None
if out.finish_reason is not None:
if is_done:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
@@ -169,22 +249,33 @@ def mlx_generate(
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
if not stop_matched and out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
# Extract logprobs from the full vocabulary logprobs array
logprob, top_logprobs = extract_top_logprobs(
logprobs=out.logprobs,
tokenizer=tokenizer,
top_k=5,
selected_token=out.token,
)
yield GenerationResponse(
text=out.text,
text=text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
)
if out.finish_reason is not None:
if is_done:
break
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
# TODO: Do we want an mx_barrier?
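The stop-sequence handling above trims the current chunk so nothing at or after the stop string is emitted. The same arithmetic in isolation, with a worked example (editor's illustration, not part of the diff):
def trim_chunk_at_stop(accumulated_text: str, chunk: str, stop_seq: str) -> str:
    """Return the part of `chunk` that precedes `stop_seq` (may be empty)."""
    stop_index = accumulated_text.find(stop_seq)
    if stop_index == -1:
        return chunk
    text_before_stop = accumulated_text[:stop_index]
    chunk_start = len(accumulated_text) - len(chunk)
    return text_before_stop[chunk_start:]
# Example: accumulated_text="Paris is the capital.STOP", chunk=".STOP", stop_seq="STOP"
# -> stop_index=21, chunk_start=20, returned text="." (the "STOP" suffix is dropped).
assert trim_chunk_at_stop("Paris is the capital.STOP", ".STOP", "STOP") == "."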

View File

@@ -41,10 +41,9 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.instances import (
BoundInstance,
MlxJacclInstance,
@@ -367,35 +366,53 @@ def load_tokenizer_for_model_id(
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
task_params: ResponsesRequest,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
"""Convert ResponsesRequest to a chat template prompt.
Converts the internal format (input + instructions) to a messages list
that can be processed by the tokenizer's chat template.
"""
formatted_messages: list[dict[str, Any]] = []
for message in messages:
if isinstance(message.content, ChatCompletionMessageText):
message.content = message.content.text
if isinstance(message.content, list):
if len(message.content) == 0:
logger.warning("Received prompt with no content, skipping")
continue
message.content = "\n".join(c.text for c in message.content).strip()
if message.content is None and message.thinking is None:
continue
# Null values are not valid when applying templates in tokenizer
# Add system message (instructions) if present
if task_params.instructions:
formatted_messages.append(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
{"role": "system", "content": task_params.instructions}
)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
)
# Convert input to messages
if isinstance(task_params.input, str):
# Simple string input becomes a single user message
formatted_messages.append({"role": "user", "content": task_params.input})
else:
# List of InputMessage
for msg in task_params.input:
if not msg.content:
logger.warning("Received message with empty content, skipping")
continue
formatted_messages.append({"role": msg.role, "content": msg.content})
# Use continue_final_message when continuing from prefix (e.g., regenerate from token)
# This keeps the final assistant message open without EOS tokens
# Note: explicitly set add_generation_prompt=False when using continue_final_message
# because some tokenizers (e.g., Kimi) default add_generation_prompt=True
prompt: str
if task_params.continue_from_prefix:
prompt = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
continue_final_message=True,
add_generation_prompt=False,
tools=task_params.tools,
)
else:
prompt = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=task_params.tools,
)
logger.info(prompt)
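For a typical request, the conversion above produces an ordinary chat-template message list. A small illustration (editor's note, not part of the diff) of the intermediate formatted_messages for the instructions-plus-string-input case:
from exo.shared.types.common import ModelId
from exo.shared.types.openai_responses import ResponsesRequest
request = ResponsesRequest(
    model=ModelId("llama-3.2-1b"),
    instructions="You are a helpful assistant",
    input="What is the capital of France?",
)
# apply_chat_template(tokenizer, request) first builds:
# [
#     {"role": "system", "content": "You are a helpful assistant"},
#     {"role": "user", "content": "What is the capital of France?"},
# ]
# and then passes that list to tokenizer.apply_chat_template(..., add_generation_prompt=True),
# because continue_from_prefix is False here.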

View File

@@ -15,17 +15,19 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
)
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.api import ImageGenerationStats
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.common import CommandId, ModelId
from exo.shared.types.events import (
ChunkGenerated,
Event,
PrefillProgress,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -217,10 +219,20 @@ def main(
)
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
assert task_params.messages[0].content is not None
# Define callback to send prefill progress events directly
def on_prefill_progress(processed: int, total: int) -> None:
if device_rank == 0:
event_sender.send(
PrefillProgress(
command_id=command_id,
processed_tokens=processed,
total_tokens=total,
)
)
try:
_check_for_debug_prompts(task_params.messages[0].content)
_check_for_debug_prompts(task_params)
# Build prompt once - used for both generation and thinking detection
prompt = apply_chat_template(tokenizer, task_params)
@@ -231,6 +243,7 @@ def main(
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
on_prefill_progress=on_prefill_progress,
)
# GPT-OSS specific parsing to match other model formats.
@@ -258,6 +271,8 @@ def main(
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
logprob=response.logprob,
top_logprobs=response.top_logprobs,
finish_reason=response.finish_reason,
stats=response.stats,
),
@@ -573,17 +588,23 @@ EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
def _check_for_debug_prompts(
prompt: str | ChatCompletionMessageText | list[ChatCompletionMessageText],
):
if isinstance(prompt, list):
if len(prompt) == 0:
logger.debug("Empty message prompt received in debug prompt")
return
prompt = prompt[0]
def _check_for_debug_prompts(task_params: ResponsesRequest) -> None:
"""Check for debug prompt triggers in the input.
if isinstance(prompt, ChatCompletionMessageText):
prompt = prompt.text
Extracts the first user input text and checks for debug triggers.
"""
prompt: str
if isinstance(task_params.input, str):
prompt = task_params.input
else:
# List of InputMessage - get first message content
if len(task_params.input) == 0:
logger.debug("Empty message list in debug prompt check")
return
prompt = task_params.input[0].content
if not prompt:
return
if EXO_RUNNER_MUST_FAIL in prompt:
logger.info("raising exception")

View File

@@ -12,10 +12,9 @@ import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
@@ -113,10 +112,10 @@ def run_gpt_oss_pipeline_device(
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = ChatCompletionTaskParams(
task = ResponsesRequest(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
input=prompt_text,
max_output_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
@@ -181,10 +180,10 @@ def run_gpt_oss_tensor_parallel_device(
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = ChatCompletionTaskParams(
task = ResponsesRequest(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
input=prompt_text,
max_output_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)

View File

@@ -1,7 +1,7 @@
from typing import cast
import exo.worker.plan as plan_mod
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
@@ -59,7 +59,7 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -107,7 +107,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -152,7 +152,7 @@ def test_plan_does_not_forward_tasks_for_other_instances():
instance_id=other_instance_id,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
result = plan_mod.plan(
@@ -201,7 +201,7 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Complete,
command_id=COMMAND_1_ID,
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
task_params=ResponsesRequest(model=MODEL_A_ID, input=""),
)
other_task_id = TaskId("other-task")

View File

@@ -5,7 +5,6 @@ from typing import Callable
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
@@ -14,9 +13,9 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
@@ -85,11 +84,11 @@ SHUTDOWN_TASK = Shutdown(
runner_id=RUNNER_1_ID,
)
CHAT_PARAMS = ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
CHAT_PARAMS = ResponsesRequest(
model=MODEL_A_ID,
input="hello",
stream=True,
max_tokens=4,
max_output_tokens=4,
temperature=0.0,
)

View File

@@ -13,10 +13,10 @@ from pydantic import BaseModel
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.openai_responses import ResponsesRequest
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -180,16 +180,10 @@ async def execute_test(test: Tests, instance: Instance, hn: str):
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
task_params=ResponsesRequest(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
instructions="You are a helpful assistant",
input="What is the capital of France?",
),
command_id=CommandId("yo"),
instance_id=iid,