Compare commits

..

12 Commits

Author SHA1 Message Date
ciaranbor
92de82bdf6 Add traces page to UI 2026-01-30 18:59:02 +00:00
ciaranbor
e792b19f5d Integrate trace collection with exo runners 2026-01-30 18:59:02 +00:00
ciaranbor
edc3a88b12 Minor tweaks 2026-01-30 18:58:30 +00:00
ciaranbor
0ae0c788d5 Average per step 2026-01-30 18:58:30 +00:00
ciaranbor
09abf44d49 Add statistics interpretation 2026-01-30 18:58:30 +00:00
ciaranbor
ac4b78349b Require external evals 2026-01-30 18:58:30 +00:00
ciaranbor
96c440c3b1 Instrument distributed runner 2026-01-30 18:58:30 +00:00
ciaranbor
c14d63cf61 Add tracing utils 2026-01-30 18:58:30 +00:00
Evan Quiney
cd946742f7 fix skipping logic in worker plan (#1342)
the worker plan function had some skipping logic missing, leading to
double-submitting tasks.
2026-01-30 14:31:40 +00:00
rltakashige
a5bc38ad1f Check all nodes to evict (#1341)
## Motivation

If nodes have uneven memory, one node may evict cache that remains on
another node. This will break prefill on some setups.

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-30 13:42:09 +00:00
Evan Quiney
2a4e0d4629 make node-ids unique per-session (#1338)
we currently have no strict reuqirements that node ids persist across
sessions, so we can generate fresh nodeids each time

this avoids issues like #1332, but prevents further features such as
caching downloads or node-id dialling

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-30 13:33:31 +00:00
Evan Quiney
46a14153dd switch to ModelCard.load outside of download log (#1339)
some attempts to load model cards (i.e. build_base_shard) always went
through networking rather than using downloaded model cards. we should
always default to ModelCard.load in these scenarios
2026-01-30 11:20:20 +00:00
21 changed files with 1709 additions and 227 deletions

View File

@@ -178,6 +178,36 @@ interface ImageApiResponse {
data: Array<{ b64_json?: string; url?: string }>;
}
// Trace API response types
export interface TraceCategoryStats {
totalUs: number;
count: number;
minUs: number;
maxUs: number;
avgUs: number;
}
export interface TraceRankStats {
byCategory: Record<string, TraceCategoryStats>;
}
export interface TraceStatsResponse {
taskId: string;
totalWallTimeUs: number;
byCategory: Record<string, TraceCategoryStats>;
byRank: Record<number, TraceRankStats>;
}
export interface TraceListItem {
taskId: string;
createdAt: string;
fileSize: number;
}
export interface TraceListResponse {
traces: TraceListItem[];
}
interface RawStateResponse {
topology?: RawTopology;
instances?: Record<
@@ -2555,6 +2585,49 @@ class AppStore {
throw error;
}
}
/**
* List all available traces
*/
async listTraces(): Promise<TraceListResponse> {
const response = await fetch("/v1/traces");
if (!response.ok) {
throw new Error(`Failed to list traces: ${response.status}`);
}
return (await response.json()) as TraceListResponse;
}
/**
* Check if a trace exists for a given task ID
*/
async checkTraceExists(taskId: string): Promise<boolean> {
try {
const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);
return response.ok;
} catch {
return false;
}
}
/**
* Get computed statistics for a task's trace
*/
async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {
const response = await fetch(
`/v1/traces/${encodeURIComponent(taskId)}/stats`,
);
if (!response.ok) {
throw new Error(`Failed to fetch trace stats: ${response.status}`);
}
return (await response.json()) as TraceStatsResponse;
}
/**
* Get the URL for the raw trace file (for Perfetto)
*/
getTraceRawUrl(taskId: string): string {
return `/v1/traces/${encodeURIComponent(taskId)}/raw`;
}
}
export const appStore = new AppStore();
@@ -2666,3 +2739,12 @@ export const startDownload = (nodeId: string, shardMetadata: object) =>
appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
appStore.deleteDownload(nodeId, modelId);
// Trace actions
export const listTraces = () => appStore.listTraces();
export const checkTraceExists = (taskId: string) =>
appStore.checkTraceExists(taskId);
export const fetchTraceStats = (taskId: string) =>
appStore.fetchTraceStats(taskId);
export const getTraceRawUrl = (taskId: string) =>
appStore.getTraceRawUrl(taskId);

View File

@@ -0,0 +1,172 @@
<script lang="ts">
import { onMount } from "svelte";
import {
listTraces,
getTraceRawUrl,
type TraceListItem,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
let traces = $state<TraceListItem[]>([]);
let loading = $state(true);
let error = $state<string | null>(null);
function formatBytes(bytes: number): string {
if (!bytes || bytes <= 0) return "0B";
const units = ["B", "KB", "MB", "GB"];
const i = Math.min(
Math.floor(Math.log(bytes) / Math.log(1024)),
units.length - 1,
);
const val = bytes / Math.pow(1024, i);
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
}
function formatDate(isoString: string): string {
const date = new Date(isoString);
return date.toLocaleString();
}
async function openInPerfetto(taskId: string) {
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
async function refresh() {
loading = true;
error = null;
try {
const response = await listTraces();
traces = response.traces;
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load traces";
} finally {
loading = false;
}
}
onMount(() => {
refresh();
});
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Traces
</h1>
</div>
<div class="flex items-center gap-3">
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={refresh}
disabled={loading}
>
Refresh
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading traces...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if traces.length === 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
>
<div class="text-sm">No traces found.</div>
<div class="text-xs text-exo-light-gray/70">
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
</div>
</div>
{:else}
<div class="space-y-3">
{#each traces as trace}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
>
<div class="min-w-0 flex-1">
<a
href="#/traces/{trace.taskId}"
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
>
{trace.taskId}
</a>
<div class="text-xs text-exo-light-gray font-mono mt-1">
{formatDate(trace.createdAt)} &bull; {formatBytes(
trace.fileSize,
)}
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
<a
href="#/traces/{trace.taskId}"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
>
View Stats
</a>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
onclick={() => openInPerfetto(trace.taskId)}
>
View Trace
</button>
</div>
</div>
{/each}
</div>
{/if}
</div>
</div>

View File

@@ -0,0 +1,347 @@
<script lang="ts">
import { page } from "$app/stores";
import { onMount } from "svelte";
import {
fetchTraceStats,
getTraceRawUrl,
type TraceStatsResponse,
type TraceCategoryStats,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
const taskId = $derived($page.params.taskId);
let stats = $state<TraceStatsResponse | null>(null);
let loading = $state(true);
let error = $state<string | null>(null);
function formatDuration(us: number): string {
if (us < 1000) return `${us.toFixed(0)}us`;
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
return `${(us / 1_000_000).toFixed(2)}s`;
}
function formatPercentage(part: number, total: number): string {
if (total === 0) return "0.0%";
return `${((part / total) * 100).toFixed(1)}%`;
}
// Parse hierarchical categories like "sync/compute" into phases
type PhaseData = {
name: string;
subcategories: { name: string; stats: TraceCategoryStats }[];
totalUs: number; // From outer span (e.g., "sync" category)
stepCount: number; // Count of outer span events
};
function parsePhases(
byCategory: Record<string, TraceCategoryStats>,
): PhaseData[] {
const phases = new Map<
string,
{
subcats: Map<string, TraceCategoryStats>;
outerStats: TraceCategoryStats | null;
}
>();
for (const [category, catStats] of Object.entries(byCategory)) {
if (category.includes("/")) {
const [phase, subcat] = category.split("/", 2);
if (!phases.has(phase)) {
phases.set(phase, { subcats: new Map(), outerStats: null });
}
phases.get(phase)!.subcats.set(subcat, catStats);
} else {
// Outer span - this IS the phase total
if (!phases.has(category)) {
phases.set(category, { subcats: new Map(), outerStats: null });
}
phases.get(category)!.outerStats = catStats;
}
}
return Array.from(phases.entries())
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
.map(([name, data]) => ({
name,
subcategories: Array.from(data.subcats.entries())
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
totalUs: data.outerStats!.totalUs, // Outer span total
stepCount: data.outerStats!.count, // Number of steps
}))
.sort((a, b) => b.totalUs - a.totalUs);
}
async function openInPerfetto() {
if (!taskId) return;
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
onMount(async () => {
if (!taskId) {
error = "No task ID provided";
loading = false;
return;
}
try {
stats = await fetchTraceStats(taskId);
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load trace";
} finally {
loading = false;
}
});
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
const sortedRanks = $derived(
stats
? Object.keys(stats.byRank)
.map(Number)
.sort((a, b) => a - b)
: [],
);
const nodeCount = $derived(sortedRanks.length || 1);
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Trace
</h1>
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
{taskId}
</p>
</div>
<div class="flex items-center gap-3">
<a
href="#/traces"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
>
All Traces
</a>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
onclick={openInPerfetto}
disabled={loading || !!error}
>
View Trace
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading trace data...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if stats}
<!-- Wall Time Summary -->
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
Summary
</h2>
<div class="text-3xl font-mono text-exo-yellow">
{formatDuration(stats.totalWallTimeUs)}
</div>
<div class="text-xs text-exo-light-gray">Total wall time</div>
</div>
<!-- By Phase -->
{#if phases.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
</h2>
<div class="space-y-4">
{#each phases as phase}
{@const normalizedTotal = phase.totalUs / nodeCount}
{@const normalizedStepCount = phase.stepCount / nodeCount}
<div class="space-y-2">
<div class="flex items-center justify-between">
<span class="text-sm font-mono text-white">{phase.name}</span>
<span class="text-sm font-mono">
<span class="text-exo-yellow"
>{formatDuration(normalizedTotal)}</span
>
<span class="text-exo-light-gray ml-2">
({normalizedStepCount} steps, {formatDuration(
normalizedTotal / normalizedStepCount,
)}/step)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-4 space-y-1.5">
{#each phase.subcategories as subcat}
{@const normalizedSubcat =
subcat.stats.totalUs / nodeCount}
{@const pct = formatPercentage(
normalizedSubcat,
normalizedTotal,
)}
{@const perStep = normalizedSubcat / normalizedStepCount}
<div
class="flex items-center justify-between text-xs font-mono"
>
<span class="text-exo-light-gray">{subcat.name}</span>
<span class="text-white">
{formatDuration(normalizedSubcat)}
<span class="text-exo-light-gray ml-2">({pct})</span>
<span class="text-exo-light-gray/60 ml-2"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
<!-- Progress bar -->
<div
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
style="width: {pct}"
></div>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/if}
<!-- By Rank -->
{#if sortedRanks.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Rank
</h2>
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{#each sortedRanks as rank}
{@const rankStats = stats.byRank[rank]}
{@const rankPhases = parsePhases(rankStats.byCategory)}
<div
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
>
<div class="text-sm font-mono text-exo-yellow">
Rank {rank}
</div>
<div class="space-y-2">
{#each rankPhases as phase}
<div class="space-y-1">
<div class="flex items-center justify-between text-xs">
<span class="font-mono text-exo-light-gray"
>{phase.name}</span
>
<span class="font-mono text-white">
{formatDuration(phase.totalUs)}
<span class="text-exo-light-gray/50 ml-1">
({phase.stepCount}x)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-2 space-y-0.5">
{#each phase.subcategories as subcat}
{@const pct = formatPercentage(
subcat.stats.totalUs,
phase.totalUs,
)}
{@const perStep =
subcat.stats.totalUs / phase.stepCount}
<div
class="flex items-center justify-between text-[10px] font-mono"
>
<span class="text-exo-light-gray/70"
>{subcat.name}</span
>
<span class="text-exo-light-gray">
{formatDuration(subcat.stats.totalUs)}
<span class="text-exo-light-gray/50"
>({pct})</span
>
<span class="text-exo-light-gray/30 ml-1"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>

View File

@@ -19,7 +19,7 @@ dependencies = [
"anyio==4.11.0",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm==0.30.5",
"mlx-lm",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -31,6 +31,8 @@ dependencies = [
]
[project.scripts]
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
# dependencies only required for development
@@ -44,6 +46,12 @@ dev = [
"ruff>=0.11.13",
]
# mlx[cuda] requires a newer version of mlx. the ideal on linux is: default to mlx[cpu] unless[cuda] specified.
[project.optional-dependencies]
# cuda = [
# "mlx[cuda]==0.26.3",
# ]
###
# workspace configuration
###
@@ -55,8 +63,8 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
@@ -108,7 +116,7 @@ environments = [
###
[tool.ruff]
extend-exclude = ["*mlx_typings/**", "rust/exo_pyo3_bindings/**"]
extend-exclude = ["shared/protobufs/**", "*mlx_typings/**", "rust/exo_pyo3_bindings/**"]
[tool.ruff.lint]
extend-select = ["I", "N", "B", "A", "PIE", "SIM"]

View File

@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
model_card = await ModelCard.load(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,

View File

@@ -3,7 +3,9 @@ import contextlib
import json
import time
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Literal, cast
from uuid import uuid4
@@ -33,6 +35,7 @@ from exo.shared.models.model_cards import (
ModelCard,
ModelId,
)
from exo.shared.tracing import compute_stats, load_trace_file
from exo.shared.types.api import (
AdvancedImageParams,
BenchChatCompletionResponse,
@@ -67,6 +70,13 @@ from exo.shared.types.api import (
StreamingChoiceResponse,
StreamOptions,
ToolCall,
TraceCategoryStats,
TraceEventResponse,
TraceListItem,
TraceListResponse,
TraceRankStats,
TraceResponse,
TraceStatsResponse,
Usage,
)
from exo.shared.types.chunks import (
@@ -146,18 +156,6 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
return await ModelCard.from_hf(model_id)
class API:
def __init__(
self,
@@ -276,10 +274,14 @@ class API:
self.app.get("/events")(lambda: self._event_log)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
self.app.get("/v1/traces")(self.list_traces)
self.app.get("/v1/traces/{task_id}")(self.get_trace)
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_card=await ModelCard.load(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -296,7 +298,7 @@ class API:
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
model_card = await ModelCard.load(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
@@ -324,7 +326,7 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_card = await ModelCard.load(model_id)
try:
placements = get_instance_placements(
@@ -565,7 +567,7 @@ class API:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
model: ModelId | None = None
finish_reason: FinishReason | None = None
usage: Usage | None = None
@@ -624,7 +626,7 @@ class API:
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
model: ModelId | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
@@ -677,7 +679,7 @@ class API:
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
)
@@ -686,7 +688,7 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await ModelCard.load(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -713,7 +715,7 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(ModelId(payload.model))
model_card = await ModelCard.load(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
@@ -733,12 +735,12 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
async def _validate_image_model(self, model: ModelId) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(ModelId(model))
model_card = await ModelCard.load(model)
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
@@ -784,7 +786,7 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
payload.model = await self._validate_image_model(ModelId(payload.model))
command = ImageGeneration(
request_params=payload,
@@ -1029,7 +1031,7 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(payload.model)
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.stream = False
payload.partial_images = 0
@@ -1050,7 +1052,7 @@ class API:
self,
image: UploadFile,
prompt: str,
model: str,
model: ModelId,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1145,7 +1147,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,
@@ -1201,7 +1203,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=model,
model=ModelId(model),
n=n,
size=size,
response_format=response_format,
@@ -1348,3 +1350,110 @@ class API:
)
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)
def _get_traces_dir(self) -> Path:
return Path.home() / ".exo" / "traces"
def _get_trace_path(self, task_id: str) -> Path:
return self._get_traces_dir() / f"trace_{task_id}.json"
async def list_traces(self) -> TraceListResponse:
traces_dir = self._get_traces_dir()
traces: list[TraceListItem] = []
if not traces_dir.exists():
return TraceListResponse(traces=[])
for trace_file in sorted(
traces_dir.glob("trace_*.json"),
key=lambda p: p.stat().st_mtime,
reverse=True,
):
# Extract task_id from filename (trace_{task_id}.json)
task_id = trace_file.stem.removeprefix("trace_")
stat = trace_file.stat()
created_at = datetime.fromtimestamp(
stat.st_mtime, tz=timezone.utc
).isoformat()
traces.append(
TraceListItem(
task_id=task_id,
created_at=created_at,
file_size=stat.st_size,
)
)
return TraceListResponse(traces=traces)
async def get_trace(self, task_id: str) -> TraceResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
trace_events = load_trace_file(trace_path)
return TraceResponse(
task_id=task_id,
traces=[
TraceEventResponse(
name=event.name,
start_us=event.start_us,
duration_us=event.duration_us,
rank=event.rank,
category=event.category,
)
for event in trace_events
],
)
async def get_trace_stats(self, task_id: str) -> TraceStatsResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
trace_events = load_trace_file(trace_path)
stats = compute_stats(trace_events)
return TraceStatsResponse(
task_id=task_id,
total_wall_time_us=stats.total_wall_time_us,
by_category={
category: TraceCategoryStats(
total_us=cat_stats.total_us,
count=cat_stats.count,
min_us=cat_stats.min_us,
max_us=cat_stats.max_us,
avg_us=cat_stats.avg_us,
)
for category, cat_stats in stats.by_category.items()
},
by_rank={
rank: TraceRankStats(
by_category={
category: TraceCategoryStats(
total_us=cat_stats.total_us,
count=cat_stats.count,
min_us=cat_stats.min_us,
max_us=cat_stats.max_us,
avg_us=cat_stats.avg_us,
)
for category, cat_stats in rank_stats.items()
}
)
for rank, rank_stats in stats.by_rank.items()
},
)
async def get_trace_raw(self, task_id: str) -> FileResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
return FileResponse(
path=trace_path,
media_type="application/json",
filename=f"trace_{task_id}.json",
)

View File

@@ -1,4 +1,5 @@
from datetime import datetime, timedelta, timezone
from pathlib import Path
import anyio
from anyio.abc import TaskGroup
@@ -11,6 +12,7 @@ from exo.master.placement import (
place_instance,
)
from exo.shared.apply import apply
from exo.shared.tracing import TraceEvent, export_trace, is_tracing_enabled
from exo.shared.types.commands import (
ChatCompletion,
CreateInstance,
@@ -35,6 +37,8 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TraceEventData,
TracesCollected,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -86,6 +90,8 @@ class Master:
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
async def run(self):
logger.info("Starting Master")
@@ -187,13 +193,14 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
instance_id=selected_instance_id,
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
@@ -201,6 +208,17 @@ class Master:
)
self.command_task_mapping[command.command_id] = task_id
if is_tracing_enabled():
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case ImageEdits():
for instance in self.state.instances.values():
if (
@@ -229,13 +247,14 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
instance_id=selected_instance_id,
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
@@ -243,6 +262,17 @@ class Master:
)
self.command_task_mapping[command.command_id] = task_id
if is_tracing_enabled():
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
@@ -335,6 +365,10 @@ class Master:
local_event.origin,
)
for event in self._multi_buffer.drain():
if isinstance(event, TracesCollected):
self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
@@ -373,3 +407,38 @@ class Master:
event=event.event,
)
)
def _handle_traces_collected(self, event: TracesCollected) -> None:
task_id = event.task_id
if task_id not in self._pending_traces:
self._pending_traces[task_id] = {}
self._pending_traces[task_id][event.rank] = event.traces
if (
task_id in self._expected_ranks
and set(self._pending_traces[task_id].keys())
>= self._expected_ranks[task_id]
):
self._merge_and_save_traces(task_id)
def _merge_and_save_traces(self, task_id: TaskId) -> None:
all_traces: list[TraceEvent] = []
for trace_data in self._pending_traces[task_id].values():
for t in trace_data:
all_traces.append(
TraceEvent(
name=t.name,
start_us=t.start_us,
duration_us=t.duration_us,
rank=t.rank,
category=t.category,
)
)
output_path = Path.home() / ".exo" / "traces" / f"trace_{task_id}.json"
export_trace(all_traces, output_path)
logger.info(f"Merged traces saved to {output_path}")
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]

View File

@@ -216,6 +216,8 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")

View File

@@ -25,6 +25,7 @@ from exo.shared.types.events import (
TestEvent,
TopologyEdgeCreated,
TopologyEdgeDeleted,
TracesCollected,
)
from exo.shared.types.profiling import (
NodeIdentity,
@@ -55,7 +56,11 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
TestEvent()
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| TracesCollected()
): # Pass-through events that don't modify state
return state
case InstanceCreated():

View File

@@ -53,3 +53,9 @@ EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
EXO_ENABLE_IMAGE_MODELS = (
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
)
EXO_TRACING_ENABLED = os.getenv("EXO_TRACING_ENABLED", "").lower() in (
"1",
"true",
"yes",
)

View File

@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT
from loguru import logger
from pytest import LogCaptureFixture
from pytest import LogCaptureFixture, mark
from exo.routing.router import get_node_id_keypair
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
@@ -74,6 +74,7 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
os.remove(p)
@mark.skip(reason="this functionality is currently disabled but may return in future")
def test_node_id_fetching(caplog: LogCaptureFixture):
reps = 10

450
src/exo/shared/tracing.py Normal file
View File

@@ -0,0 +1,450 @@
from __future__ import annotations
import json
import logging
import time
from collections import defaultdict
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from pathlib import Path
from typing import cast, final
from exo.shared.constants import EXO_TRACING_ENABLED
logger = logging.getLogger(__name__)
# Context variable to track the current trace category for hierarchical nesting
_current_category: ContextVar[str | None] = ContextVar("current_category", default=None)
@final
@dataclass(frozen=True)
class TraceEvent:
name: str
start_us: int
duration_us: int
rank: int
category: str
@final
@dataclass
class CategoryStats:
total_us: int = 0
count: int = 0
min_us: int = 0
max_us: int = 0
def add(self, duration_us: int) -> None:
if self.count == 0:
self.min_us = duration_us
self.max_us = duration_us
else:
self.min_us = min(self.min_us, duration_us)
self.max_us = max(self.max_us, duration_us)
self.total_us += duration_us
self.count += 1
@property
def avg_us(self) -> float:
return self.total_us / self.count if self.count > 0 else 0.0
@final
@dataclass
class TraceStats:
total_wall_time_us: int = 0
by_category: dict[str, CategoryStats] = field(default_factory=dict)
by_rank: dict[int, dict[str, CategoryStats]] = field(default_factory=dict)
# Global trace buffer - each rank accumulates traces here
_trace_buffer: list[TraceEvent] = []
def is_tracing_enabled() -> bool:
"""Check if tracing is enabled via environment variable."""
return EXO_TRACING_ENABLED
def _record_span(
name: str, start_us: int, duration_us: int, rank: int, category: str
) -> None:
_trace_buffer.append(
TraceEvent(
name=name,
start_us=start_us,
duration_us=duration_us,
rank=rank,
category=category,
)
)
@contextmanager
def trace(
name: str,
rank: int,
category: str = "compute",
) -> Generator[None, None, None]:
"""Context manager to trace any operation.
Nested traces automatically inherit the parent category, creating hierarchical
categories like "sync/compute" or "async/comms".
Args:
name: Name of the operation (e.g., "recv 0", "send 1", "joint_blocks")
rank: This rank's ID
category: Category for grouping in trace viewer ("comm", "compute", "step")
Example:
with trace(f"sync {t}", rank, "sync"):
with trace("joint_blocks", rank, "compute"):
# Recorded with category "sync/compute"
hidden_states = some_computation(...)
"""
if not is_tracing_enabled():
yield
return
# Combine with parent category if nested
parent = _current_category.get()
full_category = f"{parent}/{category}" if parent else category
# Set as current for nested traces
token = _current_category.set(full_category)
try:
start_us = int(time.time() * 1_000_000)
start_perf = time.perf_counter()
yield
duration_us = int((time.perf_counter() - start_perf) * 1_000_000)
_record_span(name, start_us, duration_us, rank, full_category)
finally:
_current_category.reset(token)
def get_trace_buffer() -> list[TraceEvent]:
return list(_trace_buffer)
def clear_trace_buffer() -> None:
_trace_buffer.clear()
def export_trace(traces: list[TraceEvent], output_path: Path) -> None:
trace_events: list[dict[str, object]] = []
for event in traces:
# Chrome trace format uses "X" for complete events (with duration)
chrome_event: dict[str, object] = {
"name": event.name,
"cat": event.category,
"ph": "X",
"ts": event.start_us,
"dur": event.duration_us,
"pid": 0,
"tid": event.rank,
"args": {"rank": event.rank},
}
trace_events.append(chrome_event)
ranks_seen = set(t.rank for t in traces)
for rank in ranks_seen:
trace_events.append(
{
"name": "thread_name",
"ph": "M", # Metadata event
"pid": 0,
"tid": rank,
"args": {"name": f"Rank {rank}"},
}
)
chrome_trace = {"traceEvents": trace_events}
try:
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(chrome_trace, f, indent=2)
except OSError as e:
logger.warning("Failed to export trace to %s: %s", output_path, e)
def export_local_traces(rank: int) -> None:
if not is_tracing_enabled():
return
local_traces = get_trace_buffer()
if local_traces:
output_path = Path.home() / ".exo" / "traces" / f"trace_{rank}.json"
try:
export_trace(local_traces, output_path)
except Exception as e:
logger.warning("Failed to export local traces for rank %d: %s", rank, e)
clear_trace_buffer()
def merge_trace_files(trace_dir: Path | None = None) -> Path | None:
if trace_dir is None:
trace_dir = Path.home() / ".exo" / "traces"
if not trace_dir.exists():
return None
trace_files = sorted(trace_dir.glob("trace_*.json"))
if not trace_files:
return None
merged_events: list[dict[str, object]] = []
for trace_file in trace_files:
file_rank = int(trace_file.stem.split("_")[1])
with open(trace_file) as f:
raw = f.read()
data = cast(dict[str, list[dict[str, object]]], json.loads(raw))
events: list[dict[str, object]] = data.get("traceEvents", [])
for event in events:
event["tid"] = file_rank
if "args" in event and isinstance(event["args"], dict):
event["args"]["rank"] = file_rank
merged_events.extend(events)
output_path = Path.home() / ".exo" / "traces" / "merged_trace.json"
try:
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump({"traceEvents": merged_events}, f, indent=2)
except OSError as e:
logger.warning("Failed to write merged trace to %s: %s", output_path, e)
return None
return output_path
def _format_duration(us: int | float) -> str:
if us < 1000:
return f"{us:.0f}µs"
elif us < 1_000_000:
return f"{us / 1000:.2f}ms"
else:
return f"{us / 1_000_000:.2f}s"
def load_trace_file(path: Path) -> list[TraceEvent]:
"""Load a Chrome Trace Format JSON file into TraceEvent objects."""
with open(path) as f:
data = cast(dict[str, list[dict[str, object]]], json.load(f))
events = data.get("traceEvents", [])
traces: list[TraceEvent] = []
for event in events:
# Skip metadata events
if event.get("ph") == "M":
continue
name = str(event.get("name", ""))
category = str(event.get("cat", ""))
ts_value = event.get("ts", 0)
dur_value = event.get("dur", 0)
tid_value = event.get("tid", 0)
start_us = int(ts_value) if isinstance(ts_value, (int, float, str)) else 0
duration_us = int(dur_value) if isinstance(dur_value, (int, float, str)) else 0
# Get rank from tid or args
rank = int(tid_value) if isinstance(tid_value, (int, float, str)) else 0
args = event.get("args")
if isinstance(args, dict):
args_dict = cast(dict[str, object], args)
rank_from_args = args_dict.get("rank")
if isinstance(rank_from_args, (int, float, str)):
rank = int(rank_from_args)
traces.append(
TraceEvent(
name=name,
start_us=start_us,
duration_us=duration_us,
rank=rank,
category=category,
)
)
return traces
def compute_stats(traces: list[TraceEvent]) -> TraceStats:
"""Compute comprehensive statistics from trace events."""
stats = TraceStats()
if not traces:
return stats
# Calculate wall time from earliest start to latest end
min_start = min(t.start_us for t in traces)
max_end = max(t.start_us + t.duration_us for t in traces)
stats.total_wall_time_us = max_end - min_start
# Initialize nested dicts
by_category: dict[str, CategoryStats] = defaultdict(CategoryStats)
by_rank: dict[int, dict[str, CategoryStats]] = defaultdict(
lambda: defaultdict(CategoryStats)
)
for event in traces:
# By category
by_category[event.category].add(event.duration_us)
# By rank and category
by_rank[event.rank][event.category].add(event.duration_us)
stats.by_category = dict(by_category)
stats.by_rank = {k: dict(v) for k, v in by_rank.items()}
return stats
def print_stats(stats: TraceStats) -> None:
"""Print formatted trace statistics."""
print("=== Trace Statistics ===")
print()
print(f"Wall Time: {_format_duration(stats.total_wall_time_us)}")
print()
# Parse hierarchical categories (e.g., "sync/compute" -> phase="sync", subcat="compute")
if stats.by_category:
phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
has_hierarchical = False
for cat, cat_stats in stats.by_category.items():
if "/" in cat:
phase, subcat = cat.split("/", 1)
phases[phase][subcat] = cat_stats
has_hierarchical = True
else:
phases[cat]["_total"] = cat_stats
if has_hierarchical:
print("By Phase:")
for phase in sorted(phases.keys()):
subcats = phases[phase]
# Skip phases that only have _total (non-hierarchical top-level categories)
non_total_subcats = {k: v for k, v in subcats.items() if k != "_total"}
if not non_total_subcats:
continue
phase_total = sum(s.total_us for s in non_total_subcats.values())
print(f" {phase}:")
for subcat, subcat_stats in sorted(
non_total_subcats.items(),
key=lambda x: x[1].total_us,
reverse=True,
):
pct = (
subcat_stats.total_us / phase_total * 100 if phase_total else 0
)
# Use parent phase's step count for per-step average
phase_step_count = subcats.get("_total", CategoryStats()).count
if phase_step_count > 0:
avg_per_step = subcat_stats.total_us / phase_step_count
else:
avg_per_step = subcat_stats.avg_us # fallback
print(
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
)
print()
else:
# Fall back to flat category display if no hierarchical categories
print("By Category:")
total_time = sum(c.total_us for c in stats.by_category.values())
for category, cat_stats in sorted(
stats.by_category.items(), key=lambda x: x[1].total_us, reverse=True
):
pct = (cat_stats.total_us / total_time * 100) if total_time > 0 else 0
print(
f" {category:12s} {_format_duration(cat_stats.total_us):>10s} "
f"({pct:5.1f}%) avg: {_format_duration(cat_stats.avg_us):>8s} "
f"count: {cat_stats.count}"
)
print()
# By Rank
if stats.by_rank:
print("By Rank:")
for rank in sorted(stats.by_rank.keys()):
rank_stats = stats.by_rank[rank]
print(f" Rank {rank}:")
# Parse hierarchical categories for this rank
rank_phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
has_hierarchical = False
for cat, cat_stats in rank_stats.items():
if "/" in cat:
phase, subcat = cat.split("/", 1)
rank_phases[phase][subcat] = cat_stats
has_hierarchical = True
else:
rank_phases[cat]["_total"] = cat_stats
if has_hierarchical:
for phase in sorted(rank_phases.keys()):
subcats = rank_phases[phase]
non_total_subcats = {
k: v for k, v in subcats.items() if k != "_total"
}
if not non_total_subcats:
continue
phase_total = sum(s.total_us for s in non_total_subcats.values())
print(f" {phase}:")
for subcat, subcat_stats in sorted(
non_total_subcats.items(),
key=lambda x: x[1].total_us,
reverse=True,
):
pct = (
subcat_stats.total_us / phase_total * 100
if phase_total
else 0
)
# Use parent phase's step count for per-step average
phase_step_count = subcats.get("_total", CategoryStats()).count
if phase_step_count > 0:
avg_per_step = subcat_stats.total_us / phase_step_count
else:
avg_per_step = subcat_stats.avg_us # fallback
print(
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
)
else:
# Flat display fallback
for category, cat_stats in sorted(
rank_stats.items(), key=lambda x: x[1].total_us, reverse=True
):
print(f" {category}: {_format_duration(cat_stats.total_us)}")
print()
if __name__ == "__main__":
import sys
path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("trace.json")
if not path.exists():
print(f"Error: File not found: {path}")
sys.exit(1)
traces = load_trace_file(path)
if not traces:
print("No trace events found in file.")
sys.exit(0)
computed_stats = compute_stats(traces)
print_stats(computed_stats)

View File

@@ -373,3 +373,45 @@ class StartDownloadResponse(CamelCaseModel):
class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId
class TraceEventResponse(CamelCaseModel):
name: str
start_us: int
duration_us: int
rank: int
category: str
class TraceResponse(CamelCaseModel):
task_id: str
traces: list[TraceEventResponse]
class TraceCategoryStats(CamelCaseModel):
total_us: int
count: int
min_us: int
max_us: int
avg_us: float
class TraceRankStats(CamelCaseModel):
by_category: dict[str, TraceCategoryStats]
class TraceStatsResponse(CamelCaseModel):
task_id: str
total_wall_time_us: int
by_category: dict[str, TraceCategoryStats]
by_rank: dict[int, TraceRankStats]
class TraceListItem(CamelCaseModel):
task_id: str
created_at: str
file_size: int
class TraceListResponse(CamelCaseModel):
traces: list[TraceListItem]

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import final
from pydantic import Field
@@ -10,7 +11,7 @@ from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
class EventId(Id):
@@ -109,6 +110,22 @@ class TopologyEdgeDeleted(BaseEvent):
conn: Connection
@final
class TraceEventData(FrozenModel):
name: str
start_us: int
duration_us: int
rank: int
category: str
@final
class TracesCollected(BaseEvent):
task_id: TaskId
rank: int
traces: list[TraceEventData]
Event = (
TestEvent
| TaskCreated
@@ -127,6 +144,7 @@ Event = (
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected
)

View File

@@ -6,6 +6,11 @@ from mflux.models.common.config.config import Config
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm
from exo.shared.tracing import (
clear_trace_buffer,
is_tracing_enabled,
trace,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
@@ -324,6 +329,7 @@ class DiffusionRunner:
capture_steps = set()
self._reset_all_caches()
clear_trace_buffer()
time_steps = tqdm(range(runtime_config.num_inference_steps))
@@ -465,20 +471,22 @@ class DiffusionRunner:
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif t < config.init_time_step + num_sync_steps:
return self._sync_pipeline_step(
t,
config,
latents,
prompt_data,
)
with trace(name=f"sync {t}", rank=self.rank, category="sync"):
return self._sync_pipeline_step(
t,
config,
latents,
prompt_data,
)
else:
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
is_first_async_step=t == config.init_time_step + num_sync_steps,
)
with trace(name=f"async {t}", rank=self.rank, category="async"):
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
is_first_async_step=t == config.init_time_step + num_sync_steps,
)
def _single_node_step(
self,
@@ -586,30 +594,41 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
with trace(
name="joint_blocks",
rank=self.rank,
category="compute",
):
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if is_tracing_enabled():
mx.eval(encoder_hidden_states, hidden_states)
if self.owns_concat_stage:
assert encoder_hidden_states is not None
@@ -620,45 +639,63 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if not self.is_last_stage:
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
with trace(
name="single blocks",
rank=self.rank,
category="compute",
):
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if is_tracing_enabled():
mx.eval(hidden_states)
if not self.is_last_stage:
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = hidden_states[:, text_seq_len:, ...]
@@ -742,14 +779,20 @@ class DiffusionRunner:
)
if not self.is_first_stage:
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
with trace(name="send 0", rank=self.rank, category="comms"):
hidden_states = mx.distributed.send(
hidden_states, 0, group=self.group
)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
with trace(
name=f"recv {self.world_size - 1}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
hidden_states = prev_latents
@@ -809,10 +852,13 @@ class DiffusionRunner:
and not self.is_last_stage
and not is_first_async_step
):
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
@@ -843,10 +889,13 @@ class DiffusionRunner:
)
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
return mx.concatenate(patch_latents, axis=1)
@@ -885,22 +934,28 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
if patch_idx == 0:
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
mx.eval(patch)
if patch_idx == 0:
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -909,14 +964,22 @@ class DiffusionRunner:
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
with trace(
name=f"joint patch {patch_idx}",
rank=self.rank,
category="compute",
):
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if is_tracing_enabled():
mx.eval(encoder_hidden_states, patch)
if self.owns_concat_stage:
assert encoder_hidden_states is not None
@@ -925,49 +988,70 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
with trace(
name=f"single patch {patch_idx}",
rank=self.rank,
category="compute",
):
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if is_tracing_enabled():
mx.eval(patch)
if not self.is_last_stage:
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None
if self.is_last_stage:
patch_img_only = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
patch = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch, text_embeddings)
return noise, encoder_hidden_states

View File

@@ -3,6 +3,7 @@ from copy import deepcopy
from typing import Any, cast
import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
@@ -12,25 +13,29 @@ from mlx_lm.models.cache import (
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.85
_DEFAULT_MEMORY_THRESHOLD = 0.9
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
)
class KVPrefixCache:
def __init__(self, tokenizer: TokenizerWrapper):
def __init__(
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
self._group = group
def clear(self):
"""Clear all cached prompts and caches."""
@@ -81,13 +86,13 @@ class KVPrefixCache:
best_snapshot_index, best_snapshot_length = None, 0
for i, cached_prompt in enumerate(self.prompts):
length = _get_prefix_length(tokenized_prompt, cached_prompt)
length = get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = _cache_length(self.caches[i])
cached_length = cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -109,7 +114,7 @@ class KVPrefixCache:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = _cache_length(self.caches[best_snapshot_index])
cached_length = cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -131,29 +136,37 @@ class KVPrefixCache:
return prompt_cache, tokenized_prompt, None
def _evict_if_needed(self):
"""Evict least recently used entries while memory pressure is high."""
"""Evict least recently used entries while memory usage is high."""
if len(self.caches) == 0:
return
active: int = mx.metal.get_active_memory()
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
if active < limit * _MEMORY_THRESHOLD:
return
# Evict LRU entries until below threshold or only one entry left
while len(self.caches) > 0:
while (
len(self.caches) > 1
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
)
active = mx.metal.get_active_memory()
if active < limit * _MEMORY_THRESHOLD:
break
def get_memory_used_percentage(self) -> float:
local_pressure: float = get_memory_used_percentage()
if self._group is None:
return local_pressure
all_pressure = mx.distributed.all_gather(
mx.array([local_pressure], dtype=mx.float32),
group=self._group,
)
# .item() evals.
max_pressure = float(mx.max(all_pressure).item())
return max_pressure
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
@@ -168,13 +181,13 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
return mx.array(tokenized_prompt)
def _cache_length(cache: KVCacheType) -> int:
def cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
@@ -185,6 +198,17 @@ def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
return int(mx.sum(prefix_mask).item())
def get_available_memory() -> Memory:
mem: int = psutil.virtual_memory().available
return Memory.from_bytes(mem)
def get_memory_used_percentage() -> float:
mem = psutil.virtual_memory()
# percent is 0-100
return float(mem.percent / 100)
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> KVCacheType:

View File

@@ -18,6 +18,7 @@ from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.models.model_cards import ModelId, ModelTask
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer, is_tracing_enabled
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
@@ -27,6 +28,8 @@ from exo.shared.types.events import (
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
TraceEventData,
TracesCollected,
)
from exo.shared.types.tasks import (
ChatCompletion,
@@ -37,6 +40,7 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -111,8 +115,12 @@ def main(
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
seen = set[TaskId]()
with task_receiver as tasks:
for task in tasks:
if task.task_id in seen:
logger.warning("repeat task - potential error")
seen.add(task.task_id)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -163,7 +171,7 @@ def main(
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
)
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache = KVPrefixCache(tokenizer, group)
elif (
ModelTask.TextToImage in shard_metadata.model_card.tasks
@@ -403,6 +411,10 @@ def main(
)
)
raise
finally:
_send_traces_if_enabled(
event_sender, task.task_id, shard_metadata.device_rank
)
current_status = RunnerReady()
logger.info("runner ready")
@@ -461,6 +473,10 @@ def main(
)
)
raise
finally:
_send_traces_if_enabled(
event_sender, task.task_id, shard_metadata.device_rank
)
current_status = RunnerReady()
logger.info("runner ready")
@@ -635,6 +651,36 @@ def _send_image_chunk(
)
def _send_traces_if_enabled(
event_sender: MpSender[Event],
task_id: TaskId,
rank: int,
) -> None:
if not is_tracing_enabled():
return
traces = get_trace_buffer()
if traces:
trace_data = [
TraceEventData(
name=t.name,
start_us=t.start_us,
duration_us=t.duration_us,
rank=t.rank,
category=t.category,
)
for t in traces
]
event_sender.send(
TracesCollected(
task_id=task_id,
rank=rank,
traces=trace_data,
)
)
clear_trace_buffer()
def _process_image_response(
response: ImageGenerationResponse | PartialImageResponse,
command_id: CommandId,

View File

@@ -127,20 +127,25 @@ class RunnerSupervisor:
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(
f"Skipping invalid task {task} as it has already been submitted"
)
return
if task.task_id in self.completed:
logger.info(
logger.warning(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
try:
self._task_sender.send(task)
await self._task_sender.send_async(task)
except ClosedResourceError:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:

View File

@@ -14,9 +14,9 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
_cache_length,
_get_prefix_length,
cache_length,
encode_prompt,
get_prefix_length,
make_kv_cache,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
@@ -35,47 +35,47 @@ class TestGetPrefixLength:
def test_identical_arrays(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 5
assert get_prefix_length(a, b) == 5
def test_no_common_prefix(self):
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
def test_partial_prefix(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 7, 8])
assert _get_prefix_length(a, b) == 3
assert get_prefix_length(a, b) == 3
def test_prompt_longer_than_cached(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 3
assert get_prefix_length(a, b) == 3
def test_cached_longer_than_prompt(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 3
assert get_prefix_length(a, b) == 3
def test_single_token_match(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 5, 6])
assert _get_prefix_length(a, b) == 1
assert get_prefix_length(a, b) == 1
def test_empty_prompt(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
def test_empty_cached(self):
a = mx.array([1, 2, 3])
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
def test_both_empty(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
assert get_prefix_length(a, b) == 0
class TestKVPrefix:
@@ -146,7 +146,7 @@ class TestKVPrefixCacheWithModel:
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert _cache_length(cache) == len(tokens)
assert cache_length(cache) == len(tokens)
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
@@ -166,7 +166,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache.add_kv_cache(prompt, cache)
assert len(kv_prefix_cache.prompts) == 1
stored_length = _cache_length(kv_prefix_cache.caches[0])
stored_length = cache_length(kv_prefix_cache.caches[0])
assert stored_length > 0
# Retrieve with same prompt: exact match
@@ -209,7 +209,7 @@ class TestKVPrefixCacheWithModel:
long_tokens = encode_prompt(tokenizer, long_prompt)
# The prompts share a prefix (chat template preamble + "Hi")
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
expected_prefix = get_prefix_length(long_tokens, short_tokens)
assert expected_prefix > 0, (
"Prompts should share a prefix from the chat template"
)
@@ -243,7 +243,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
stored_length = cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -259,7 +259,7 @@ class TestKVPrefixCacheWithModel:
mx.eval([c.keys for c in result_cache])
# Stored cache must be unchanged
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
def test_stored_cache_survives_repeated_get_mutate_cycles(
self, model_and_tokenizer
@@ -281,7 +281,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
stored_length = cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -293,7 +293,7 @@ class TestKVPrefixCacheWithModel:
layer_cache.update_and_fetch(extra, extra)
mx.eval([c.keys for c in result_cache])
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
f"Failed on loop {i}"
)
@@ -325,7 +325,7 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.caches) == 1
# Cache should contain prompt + generated tokens
expected_length = len(prompt_tokens) + generated_tokens
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
@@ -400,7 +400,7 @@ class TestKVPrefixCacheWithModel:
first_gen_time = time.perf_counter() - t0
assert len(kv_prefix_cache.prompts) == 1
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
first_cache_length = cache_length(kv_prefix_cache.caches[0])
# Second generation: same long prompt + extra content (simulating multi-turn)
task2 = ChatCompletionTaskParams(
@@ -416,7 +416,7 @@ class TestKVPrefixCacheWithModel:
prompt2_tokens = encode_prompt(tokenizer, prompt2)
# Verify the prompts share a long prefix
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
# Second generation should reuse the cached prefix (only prefill new tokens)
@@ -440,7 +440,7 @@ class TestKVPrefixCacheWithModel:
# With prefix_hit > 1000, should update in-place (not add a second entry)
assert len(kv_prefix_cache.prompts) == 1
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
assert updated_cache_length > first_cache_length
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
@@ -465,7 +465,7 @@ class TestKVPrefixCacheWithModel:
):
pass
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
firstcache_length = cache_length(kv_prefix_cache.caches[0])
# Second generation gets the cache and mutates it during generation
for _response in mlx_generate(
@@ -478,7 +478,7 @@ class TestKVPrefixCacheWithModel:
pass
# The first stored cache must not have been mutated by the second generation
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
@@ -540,6 +540,6 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)

View File

@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
# initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
@@ -147,6 +147,14 @@ class MockTokenizer:
has_tool_calling = False
class MockGroup:
def rank(self) -> int:
return 0
def size(self) -> int:
return 1
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,

30
uv.lock generated
View File

@@ -192,14 +192,20 @@ sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8
wheels = [
{ url = "https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446 },
{ url = "https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101 },
{ url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948 },
{ url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422 },
{ url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499 },
{ url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928 },
{ url = "https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302 },
{ url = "https://files.pythonhosted.org/packages/d6/43/0e822876f87ea8a4ef95442c3d766a06a51fc5298823f884ef87aaad168c/cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b", size = 220049 },
{ url = "https://files.pythonhosted.org/packages/b4/89/76799151d9c2d2d1ead63c2429da9ea9d7aac304603de0c6e8764e6e8e70/cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c", size = 207793 },
{ url = "https://files.pythonhosted.org/packages/bb/dd/3465b14bb9e24ee24cb88c9e3730f6de63111fffe513492bf8c808a3547e/cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef", size = 206300 },
{ url = "https://files.pythonhosted.org/packages/47/d9/d83e293854571c877a92da46fdec39158f8d7e68da75bf73581225d28e90/cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775", size = 219244 },
{ url = "https://files.pythonhosted.org/packages/2b/0f/1f177e3683aead2bb00f7679a16451d302c436b5cbf2505f0ea8146ef59e/cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205", size = 222828 },
{ url = "https://files.pythonhosted.org/packages/c6/0f/cafacebd4b040e3119dcb32fed8bdef8dfe94da653155f9d0b9dc660166e/cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1", size = 220926 },
{ url = "https://files.pythonhosted.org/packages/be/b4/c56878d0d1755cf9caa54ba71e5d049479c52f9e4afc230f06822162ab2f/cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c", size = 221593 },
{ url = "https://files.pythonhosted.org/packages/e0/0d/eb704606dfe8033e7128df5e90fee946bbcb64a04fcdaa97321309004000/cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8", size = 209354 },
{ url = "https://files.pythonhosted.org/packages/d8/19/3c435d727b368ca475fb8742ab97c9cb13a0de600ce86f62eab7fa3eea60/cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc", size = 208480 },
{ url = "https://files.pythonhosted.org/packages/d0/44/681604464ed9541673e486521497406fadcc15b5217c3e326b061696899a/cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592", size = 221584 },
{ url = "https://files.pythonhosted.org/packages/25/8e/342a504ff018a2825d395d44d63a767dd8ebc927ebda557fecdaca3ac33a/cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512", size = 224443 },
{ url = "https://files.pythonhosted.org/packages/e1/5e/b666bacbbc60fbf415ba9988324a132c9a7a0448a9a8f125074671c0f2c3/cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4", size = 223437 },
@@ -305,8 +311,10 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/5c/49/498c86566a1d80e978b42f0d702795f69887005548c041636df6ae1ca64c/cryptography-46.0.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:01ca9ff2885f3acc98c29f1860552e37f6d7c7d013d7334ff2a9de43a449315d", size = 4450807 },
{ url = "https://files.pythonhosted.org/packages/4b/0a/863a3604112174c8624a2ac3c038662d9e59970c7f926acdcfaed8d61142/cryptography-46.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6eae65d4c3d33da080cff9c4ab1f711b15c1d9760809dad6ea763f3812d254cb", size = 4299615 },
{ url = "https://files.pythonhosted.org/packages/64/02/b73a533f6b64a69f3cd3872acb6ebc12aef924d8d103133bb3ea750dc703/cryptography-46.0.3-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5bf0ed4490068a2e72ac03d786693adeb909981cc596425d09032d372bcc849", size = 4016800 },
{ url = "https://files.pythonhosted.org/packages/25/d5/16e41afbfa450cde85a3b7ec599bebefaef16b5c6ba4ec49a3532336ed72/cryptography-46.0.3-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5ecfccd2329e37e9b7112a888e76d9feca2347f12f37918facbb893d7bb88ee8", size = 4984707 },
{ url = "https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a2c0cd47381a3229c403062f764160d57d4d175e022c1df84e168c6251a22eec", size = 4482541 },
{ url = "https://files.pythonhosted.org/packages/78/f6/50736d40d97e8483172f1bb6e698895b92a223dba513b0ca6f06b2365339/cryptography-46.0.3-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:549e234ff32571b1f4076ac269fcce7a808d3bf98b76c8dd560e42dbc66d7d91", size = 4299464 },
{ url = "https://files.pythonhosted.org/packages/00/de/d8e26b1a855f19d9994a19c702fa2e93b0456beccbcfe437eda00e0701f2/cryptography-46.0.3-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:c0a7bb1a68a5d3471880e264621346c48665b3bf1c3759d682fc0864c540bd9e", size = 4950838 },
{ url = "https://files.pythonhosted.org/packages/8f/29/798fc4ec461a1c9e9f735f2fc58741b0daae30688f41b2497dcbc9ed1355/cryptography-46.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:10b01676fc208c3e6feeb25a8b83d81767e8059e1fe86e1dc62d10a3018fa926", size = 4481596 },
{ url = "https://files.pythonhosted.org/packages/15/8d/03cd48b20a573adfff7652b76271078e3045b9f49387920e7f1f631d125e/cryptography-46.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0abf1ffd6e57c67e92af68330d05760b7b7efb243aab8377e583284dbab72c71", size = 4426782 },
{ url = "https://files.pythonhosted.org/packages/fa/b1/ebacbfe53317d55cf33165bda24c86523497a6881f339f9aae5c2e13e57b/cryptography-46.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a04bee9ab6a4da801eb9b51f1b708a1b5b5c9eb48c03f74198464c66f0d344ac", size = 4698381 },
@@ -314,8 +322,10 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c5/fd/bc1daf8230eaa075184cbbf5f8cd00ba9db4fd32d63fb83da4671b72ed8a/cryptography-46.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39b6755623145ad5eff1dab323f4eae2a32a77a7abef2c5089a04a3d04366715", size = 4435078 },
{ url = "https://files.pythonhosted.org/packages/82/98/d3bd5407ce4c60017f8ff9e63ffee4200ab3e23fe05b765cab805a7db008/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:db391fa7c66df6762ee3f00c95a89e6d428f4d60e7abc8328f4fe155b5ac6e54", size = 4293460 },
{ url = "https://files.pythonhosted.org/packages/26/e9/e23e7900983c2b8af7a08098db406cf989d7f09caea7897e347598d4cd5b/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:78a97cf6a8839a48c49271cdcbd5cf37ca2c1d6b7fdd86cc864f302b5e9bf459", size = 3995237 },
{ url = "https://files.pythonhosted.org/packages/91/15/af68c509d4a138cfe299d0d7ddb14afba15233223ebd933b4bbdbc7155d3/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:dfb781ff7eaa91a6f7fd41776ec37c5853c795d3b358d4896fdbb5df168af422", size = 4967344 },
{ url = "https://files.pythonhosted.org/packages/ca/e3/8643d077c53868b681af077edf6b3cb58288b5423610f21c62aadcbe99f4/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:6f61efb26e76c45c4a227835ddeae96d83624fb0d29eb5df5b96e14ed1a0afb7", size = 4466564 },
{ url = "https://files.pythonhosted.org/packages/0e/43/c1e8726fa59c236ff477ff2b5dc071e54b21e5a1e51aa2cee1676f1c986f/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:23b1a8f26e43f47ceb6d6a43115f33a5a37d57df4ea0ca295b780ae8546e8044", size = 4292415 },
{ url = "https://files.pythonhosted.org/packages/42/f9/2f8fefdb1aee8a8e3256a0568cffc4e6d517b256a2fe97a029b3f1b9fe7e/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:b419ae593c86b87014b9be7396b385491ad7f320bde96826d0dd174459e54665", size = 4931457 },
{ url = "https://files.pythonhosted.org/packages/79/30/9b54127a9a778ccd6d27c3da7563e9f2d341826075ceab89ae3b41bf5be2/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:50fc3343ac490c6b08c0cf0d704e881d0d660be923fd3076db3e932007e726e3", size = 4466074 },
{ url = "https://files.pythonhosted.org/packages/ac/68/b4f4a10928e26c941b1b6a179143af9f4d27d88fe84a6a3c53592d2e76bf/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22d7e97932f511d6b0b04f2bfd818d73dcd5928db509460aaf48384778eb6d20", size = 4420569 },
{ url = "https://files.pythonhosted.org/packages/a3/49/3746dab4c0d1979888f125226357d3262a6dd40e114ac29e3d2abdf1ec55/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d55f3dffadd674514ad19451161118fd010988540cee43d8bc20675e775925de", size = 4681941 },
@@ -323,8 +333,10 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/26/42/fa8389d4478368743e24e61eea78846a0006caffaf72ea24a15159215a14/cryptography-46.0.3-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:15ab9b093e8f09daab0f2159bb7e47532596075139dd74365da52ecc9cb46c5d", size = 4440029 },
{ url = "https://files.pythonhosted.org/packages/5f/eb/f483db0ec5ac040824f269e93dd2bd8a21ecd1027e77ad7bdf6914f2fd80/cryptography-46.0.3-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:46acf53b40ea38f9c6c229599a4a13f0d46a6c3fa9ef19fc1a124d62e338dfa0", size = 4297222 },
{ url = "https://files.pythonhosted.org/packages/fd/cf/da9502c4e1912cb1da3807ea3618a6829bee8207456fbbeebc361ec38ba3/cryptography-46.0.3-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10ca84c4668d066a9878890047f03546f3ae0a6b8b39b697457b7757aaf18dbc", size = 4012280 },
{ url = "https://files.pythonhosted.org/packages/6b/8f/9adb86b93330e0df8b3dcf03eae67c33ba89958fc2e03862ef1ac2b42465/cryptography-46.0.3-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:36e627112085bb3b81b19fed209c05ce2a52ee8b15d161b7c643a7d5a88491f3", size = 4978958 },
{ url = "https://files.pythonhosted.org/packages/d1/a0/5fa77988289c34bdb9f913f5606ecc9ada1adb5ae870bd0d1054a7021cc4/cryptography-46.0.3-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1000713389b75c449a6e979ffc7dcc8ac90b437048766cef052d4d30b8220971", size = 4473714 },
{ url = "https://files.pythonhosted.org/packages/14/e5/fc82d72a58d41c393697aa18c9abe5ae1214ff6f2a5c18ac470f92777895/cryptography-46.0.3-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:b02cf04496f6576afffef5ddd04a0cb7d49cf6be16a9059d793a30b035f6b6ac", size = 4296970 },
{ url = "https://files.pythonhosted.org/packages/78/06/5663ed35438d0b09056973994f1aec467492b33bd31da36e468b01ec1097/cryptography-46.0.3-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:71e842ec9bc7abf543b47cf86b9a743baa95f4677d22baa4c7d5c69e49e9bc04", size = 4940236 },
{ url = "https://files.pythonhosted.org/packages/fc/59/873633f3f2dcd8a053b8dd1d38f783043b5fce589c0f6988bf55ef57e43e/cryptography-46.0.3-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:402b58fc32614f00980b66d6e56a5b4118e6cb362ae8f3fda141ba4689bd4506", size = 4472642 },
{ url = "https://files.pythonhosted.org/packages/3d/39/8e71f3930e40f6877737d6f69248cf74d4e34b886a3967d32f919cc50d3b/cryptography-46.0.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ef639cb3372f69ec44915fafcd6698b6cc78fbe0c2ea41be867f6ed612811963", size = 4423126 },
{ url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573 },
@@ -403,7 +415,7 @@ requires-dist = [
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.4" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.4" },
{ name = "mlx-lm", specifier = "==0.30.5" },
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
{ name = "pillow", specifier = ">=11.0,<12.0" },
{ name = "psutil", specifier = ">=7.0.0" },
@@ -1061,7 +1073,7 @@ wheels = [
[[package]]
name = "mlx-lm"
version = "0.30.5"
source = { registry = "https://pypi.org/simple" }
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#96699e6dadb13b82b28285bb131a0741997d19ae" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
@@ -1071,10 +1083,6 @@ dependencies = [
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/90/4469d9f75f196e6255f59a89441abe0079925d30a001462e1c1c4bc4e6a1/mlx_lm-0.30.5.tar.gz", hash = "sha256:9e6cb258c65b766c6af25cb90958aef40acab67139f05839eef19864cb3154f6", size = 262367 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/89/ba/66db6e1e5f1ef506655b562932f6bd8f72600116d5f31f92d71c1f200b3f/mlx_lm-0.30.5-py3-none-any.whl", hash = "sha256:a80bc8e3efdebe81813b0f6eb403fb66a7a15071e256f4e7102ada986acb75bb", size = 366716 },
]
[[package]]
name = "mlx-metal"
@@ -1495,9 +1503,6 @@ version = "11.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/d0d6dea55cd152ce3d6767bb38a8fc10e33796ba4ba210cbab9354b6d238/pillow-11.3.0.tar.gz", hash = "sha256:3828ee7586cd0b2091b6209e5ad53e20d0649bbe87164a459d0676e035e8f523", size = 47113069 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/93/0952f2ed8db3a5a4c7a11f91965d6184ebc8cd7cbb7941a260d5f018cd2d/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:1c627742b539bba4309df89171356fcb3cc5a9178355b2727d1b74a6cf155fbd", size = 2128328 },
{ url = "https://files.pythonhosted.org/packages/4b/e8/100c3d114b1a0bf4042f27e0f87d2f25e857e838034e98ca98fe7b8c0a9c/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:30b7c02f3899d10f13d7a48163c8969e4e653f8b43416d23d13d1bbfdc93b9f8", size = 2170652 },
{ url = "https://files.pythonhosted.org/packages/aa/86/3f758a28a6e381758545f7cdb4942e1cb79abd271bea932998fc0db93cb6/pillow-11.3.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:7859a4cc7c9295f5838015d8cc0a9c215b77e43d07a25e460f35cf516df8626f", size = 2227443 },
{ url = "https://files.pythonhosted.org/packages/01/f4/91d5b3ffa718df2f53b0dc109877993e511f4fd055d7e9508682e8aba092/pillow-11.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec1ee50470b0d050984394423d96325b744d55c701a439d2bd66089bff963d3c", size = 5278474 },
{ url = "https://files.pythonhosted.org/packages/f9/0e/37d7d3eca6c879fbd9dba21268427dffda1ab00d4eb05b32923d4fbe3b12/pillow-11.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7db51d222548ccfd274e4572fdbf3e810a5e66b00608862f947b163e613b67dd", size = 4686038 },
{ url = "https://files.pythonhosted.org/packages/ff/b0/3426e5c7f6565e752d81221af9d3676fdbb4f352317ceafd42899aaf5d8a/pillow-11.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2d6fcc902a24ac74495df63faad1884282239265c6839a0a6416d33faedfae7e", size = 5864407 },
@@ -2273,7 +2278,7 @@ wheels = [
[[package]]
name = "transformers"
version = "5.0.0rc3"
version = "5.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2282,15 +2287,14 @@ dependencies = [
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "regex", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "safetensors", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tokenizers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3f/a3/7c116a8d85f69ea7749cf4c2df79e64c35d028e5fc7ea0168f299d03b8c7/transformers-5.0.0rc3.tar.gz", hash = "sha256:a0315b92b7e087617ade42ec9e6e92ee7620541cc5d6a3331886c52cbe306f5c", size = 8388520 }
sdist = { url = "https://files.pythonhosted.org/packages/bc/79/845941711811789c85fb7e2599cea425a14a07eda40f50896b9d3fda7492/transformers-5.0.0.tar.gz", hash = "sha256:5f5634efed6cf76ad068cc5834c7adbc32db78bbd6211fb70df2325a9c37dec8", size = 8424830 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/f2/ae2b8968764253bdf38a48dee3c299b8d0bedf7c8ffbe3449fca9bd95338/transformers-5.0.0rc3-py3-none-any.whl", hash = "sha256:383fad27f4f73092d330e45fae384681e5c8521e1dc1cf6cb1a297780e68bf2d", size = 10107087 },
{ url = "https://files.pythonhosted.org/packages/52/f3/ac976fa8e305c9e49772527e09fbdc27cc6831b8a2f6b6063406626be5dd/transformers-5.0.0-py3-none-any.whl", hash = "sha256:587086f249ce64c817213cf36afdb318d087f790723e9b3d4500b97832afd52d", size = 10142091 },
]
[[package]]