mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-30 16:51:02 -05:00
Compare commits
1 Commits
ciaran/pro
...
pyproject-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c890bd0beb |
@@ -178,36 +178,6 @@ interface ImageApiResponse {
|
||||
data: Array<{ b64_json?: string; url?: string }>;
|
||||
}
|
||||
|
||||
// Trace API response types
|
||||
export interface TraceCategoryStats {
|
||||
totalUs: number;
|
||||
count: number;
|
||||
minUs: number;
|
||||
maxUs: number;
|
||||
avgUs: number;
|
||||
}
|
||||
|
||||
export interface TraceRankStats {
|
||||
byCategory: Record<string, TraceCategoryStats>;
|
||||
}
|
||||
|
||||
export interface TraceStatsResponse {
|
||||
taskId: string;
|
||||
totalWallTimeUs: number;
|
||||
byCategory: Record<string, TraceCategoryStats>;
|
||||
byRank: Record<number, TraceRankStats>;
|
||||
}
|
||||
|
||||
export interface TraceListItem {
|
||||
taskId: string;
|
||||
createdAt: string;
|
||||
fileSize: number;
|
||||
}
|
||||
|
||||
export interface TraceListResponse {
|
||||
traces: TraceListItem[];
|
||||
}
|
||||
|
||||
interface RawStateResponse {
|
||||
topology?: RawTopology;
|
||||
instances?: Record<
|
||||
@@ -2585,49 +2555,6 @@ class AppStore {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* List all available traces
|
||||
*/
|
||||
async listTraces(): Promise<TraceListResponse> {
|
||||
const response = await fetch("/v1/traces");
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to list traces: ${response.status}`);
|
||||
}
|
||||
return (await response.json()) as TraceListResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a trace exists for a given task ID
|
||||
*/
|
||||
async checkTraceExists(taskId: string): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get computed statistics for a task's trace
|
||||
*/
|
||||
async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {
|
||||
const response = await fetch(
|
||||
`/v1/traces/${encodeURIComponent(taskId)}/stats`,
|
||||
);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch trace stats: ${response.status}`);
|
||||
}
|
||||
return (await response.json()) as TraceStatsResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the URL for the raw trace file (for Perfetto)
|
||||
*/
|
||||
getTraceRawUrl(taskId: string): string {
|
||||
return `/v1/traces/${encodeURIComponent(taskId)}/raw`;
|
||||
}
|
||||
}
|
||||
|
||||
export const appStore = new AppStore();
|
||||
@@ -2739,12 +2666,3 @@ export const startDownload = (nodeId: string, shardMetadata: object) =>
|
||||
appStore.startDownload(nodeId, shardMetadata);
|
||||
export const deleteDownload = (nodeId: string, modelId: string) =>
|
||||
appStore.deleteDownload(nodeId, modelId);
|
||||
|
||||
// Trace actions
|
||||
export const listTraces = () => appStore.listTraces();
|
||||
export const checkTraceExists = (taskId: string) =>
|
||||
appStore.checkTraceExists(taskId);
|
||||
export const fetchTraceStats = (taskId: string) =>
|
||||
appStore.fetchTraceStats(taskId);
|
||||
export const getTraceRawUrl = (taskId: string) =>
|
||||
appStore.getTraceRawUrl(taskId);
|
||||
|
||||
@@ -1,172 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
listTraces,
|
||||
getTraceRawUrl,
|
||||
type TraceListItem,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
let traces = $state<TraceListItem[]>([]);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (!bytes || bytes <= 0) return "0B";
|
||||
const units = ["B", "KB", "MB", "GB"];
|
||||
const i = Math.min(
|
||||
Math.floor(Math.log(bytes) / Math.log(1024)),
|
||||
units.length - 1,
|
||||
);
|
||||
const val = bytes / Math.pow(1024, i);
|
||||
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
|
||||
}
|
||||
|
||||
function formatDate(isoString: string): string {
|
||||
const date = new Date(isoString);
|
||||
return date.toLocaleString();
|
||||
}
|
||||
|
||||
async function openInPerfetto(taskId: string) {
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
async function refresh() {
|
||||
loading = true;
|
||||
error = null;
|
||||
try {
|
||||
const response = await listTraces();
|
||||
traces = response.traces;
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load traces";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
refresh();
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Traces
|
||||
</h1>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
onclick={refresh}
|
||||
disabled={loading}
|
||||
>
|
||||
Refresh
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading traces...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if traces.length === 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
|
||||
>
|
||||
<div class="text-sm">No traces found.</div>
|
||||
<div class="text-xs text-exo-light-gray/70">
|
||||
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="space-y-3">
|
||||
{#each traces as trace}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
|
||||
>
|
||||
<div class="min-w-0 flex-1">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
|
||||
>
|
||||
{trace.taskId}
|
||||
</a>
|
||||
<div class="text-xs text-exo-light-gray font-mono mt-1">
|
||||
{formatDate(trace.createdAt)} • {formatBytes(
|
||||
trace.fileSize,
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 shrink-0">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
>
|
||||
View Stats
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
|
||||
onclick={() => openInPerfetto(trace.taskId)}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -1,347 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { page } from "$app/stores";
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
fetchTraceStats,
|
||||
getTraceRawUrl,
|
||||
type TraceStatsResponse,
|
||||
type TraceCategoryStats,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
const taskId = $derived($page.params.taskId);
|
||||
|
||||
let stats = $state<TraceStatsResponse | null>(null);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatDuration(us: number): string {
|
||||
if (us < 1000) return `${us.toFixed(0)}us`;
|
||||
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
|
||||
return `${(us / 1_000_000).toFixed(2)}s`;
|
||||
}
|
||||
|
||||
function formatPercentage(part: number, total: number): string {
|
||||
if (total === 0) return "0.0%";
|
||||
return `${((part / total) * 100).toFixed(1)}%`;
|
||||
}
|
||||
|
||||
// Parse hierarchical categories like "sync/compute" into phases
|
||||
type PhaseData = {
|
||||
name: string;
|
||||
subcategories: { name: string; stats: TraceCategoryStats }[];
|
||||
totalUs: number; // From outer span (e.g., "sync" category)
|
||||
stepCount: number; // Count of outer span events
|
||||
};
|
||||
|
||||
function parsePhases(
|
||||
byCategory: Record<string, TraceCategoryStats>,
|
||||
): PhaseData[] {
|
||||
const phases = new Map<
|
||||
string,
|
||||
{
|
||||
subcats: Map<string, TraceCategoryStats>;
|
||||
outerStats: TraceCategoryStats | null;
|
||||
}
|
||||
>();
|
||||
|
||||
for (const [category, catStats] of Object.entries(byCategory)) {
|
||||
if (category.includes("/")) {
|
||||
const [phase, subcat] = category.split("/", 2);
|
||||
if (!phases.has(phase)) {
|
||||
phases.set(phase, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(phase)!.subcats.set(subcat, catStats);
|
||||
} else {
|
||||
// Outer span - this IS the phase total
|
||||
if (!phases.has(category)) {
|
||||
phases.set(category, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(category)!.outerStats = catStats;
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(phases.entries())
|
||||
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
|
||||
.map(([name, data]) => ({
|
||||
name,
|
||||
subcategories: Array.from(data.subcats.entries())
|
||||
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
|
||||
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
|
||||
totalUs: data.outerStats!.totalUs, // Outer span total
|
||||
stepCount: data.outerStats!.count, // Number of steps
|
||||
}))
|
||||
.sort((a, b) => b.totalUs - a.totalUs);
|
||||
}
|
||||
|
||||
async function openInPerfetto() {
|
||||
if (!taskId) return;
|
||||
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
onMount(async () => {
|
||||
if (!taskId) {
|
||||
error = "No task ID provided";
|
||||
loading = false;
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
stats = await fetchTraceStats(taskId);
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load trace";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
});
|
||||
|
||||
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
|
||||
const sortedRanks = $derived(
|
||||
stats
|
||||
? Object.keys(stats.byRank)
|
||||
.map(Number)
|
||||
.sort((a, b) => a - b)
|
||||
: [],
|
||||
);
|
||||
const nodeCount = $derived(sortedRanks.length || 1);
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Trace
|
||||
</h1>
|
||||
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
|
||||
{taskId}
|
||||
</p>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<a
|
||||
href="#/traces"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
|
||||
>
|
||||
All Traces
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
|
||||
onclick={openInPerfetto}
|
||||
disabled={loading || !!error}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading trace data...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if stats}
|
||||
<!-- Wall Time Summary -->
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
Summary
|
||||
</h2>
|
||||
<div class="text-3xl font-mono text-exo-yellow">
|
||||
{formatDuration(stats.totalWallTimeUs)}
|
||||
</div>
|
||||
<div class="text-xs text-exo-light-gray">Total wall time</div>
|
||||
</div>
|
||||
|
||||
<!-- By Phase -->
|
||||
{#if phases.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
|
||||
</h2>
|
||||
<div class="space-y-4">
|
||||
{#each phases as phase}
|
||||
{@const normalizedTotal = phase.totalUs / nodeCount}
|
||||
{@const normalizedStepCount = phase.stepCount / nodeCount}
|
||||
<div class="space-y-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-sm font-mono text-white">{phase.name}</span>
|
||||
<span class="text-sm font-mono">
|
||||
<span class="text-exo-yellow"
|
||||
>{formatDuration(normalizedTotal)}</span
|
||||
>
|
||||
<span class="text-exo-light-gray ml-2">
|
||||
({normalizedStepCount} steps, {formatDuration(
|
||||
normalizedTotal / normalizedStepCount,
|
||||
)}/step)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-4 space-y-1.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const normalizedSubcat =
|
||||
subcat.stats.totalUs / nodeCount}
|
||||
{@const pct = formatPercentage(
|
||||
normalizedSubcat,
|
||||
normalizedTotal,
|
||||
)}
|
||||
{@const perStep = normalizedSubcat / normalizedStepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-xs font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray">{subcat.name}</span>
|
||||
<span class="text-white">
|
||||
{formatDuration(normalizedSubcat)}
|
||||
<span class="text-exo-light-gray ml-2">({pct})</span>
|
||||
<span class="text-exo-light-gray/60 ml-2"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
<!-- Progress bar -->
|
||||
<div
|
||||
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
|
||||
>
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
style="width: {pct}"
|
||||
></div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- By Rank -->
|
||||
{#if sortedRanks.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Rank
|
||||
</h2>
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{#each sortedRanks as rank}
|
||||
{@const rankStats = stats.byRank[rank]}
|
||||
{@const rankPhases = parsePhases(rankStats.byCategory)}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
|
||||
>
|
||||
<div class="text-sm font-mono text-exo-yellow">
|
||||
Rank {rank}
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
{#each rankPhases as phase}
|
||||
<div class="space-y-1">
|
||||
<div class="flex items-center justify-between text-xs">
|
||||
<span class="font-mono text-exo-light-gray"
|
||||
>{phase.name}</span
|
||||
>
|
||||
<span class="font-mono text-white">
|
||||
{formatDuration(phase.totalUs)}
|
||||
<span class="text-exo-light-gray/50 ml-1">
|
||||
({phase.stepCount}x)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-2 space-y-0.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const pct = formatPercentage(
|
||||
subcat.stats.totalUs,
|
||||
phase.totalUs,
|
||||
)}
|
||||
{@const perStep =
|
||||
subcat.stats.totalUs / phase.stepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-[10px] font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray/70"
|
||||
>{subcat.name}</span
|
||||
>
|
||||
<span class="text-exo-light-gray">
|
||||
{formatDuration(subcat.stats.totalUs)}
|
||||
<span class="text-exo-light-gray/50"
|
||||
>({pct})</span
|
||||
>
|
||||
<span class="text-exo-light-gray/30 ml-1"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -19,7 +19,7 @@ dependencies = [
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.4; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
|
||||
"mlx-lm",
|
||||
"mlx-lm==0.30.5",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -31,8 +31,6 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
exo-master = "exo.master.main:main"
|
||||
exo-worker = "exo.worker.main:main"
|
||||
exo = "exo.main:main"
|
||||
|
||||
# dependencies only required for development
|
||||
@@ -46,12 +44,6 @@ dev = [
|
||||
"ruff>=0.11.13",
|
||||
]
|
||||
|
||||
# mlx[cuda] requires a newer version of mlx. the ideal on linux is: default to mlx[cpu] unless[cuda] specified.
|
||||
[project.optional-dependencies]
|
||||
# cuda = [
|
||||
# "mlx[cuda]==0.26.3",
|
||||
# ]
|
||||
|
||||
###
|
||||
# workspace configuration
|
||||
###
|
||||
@@ -63,8 +55,8 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
|
||||
@@ -116,7 +108,7 @@ environments = [
|
||||
###
|
||||
|
||||
[tool.ruff]
|
||||
extend-exclude = ["shared/protobufs/**", "*mlx_typings/**", "rust/exo_pyo3_bindings/**"]
|
||||
extend-exclude = ["*mlx_typings/**", "rust/exo_pyo3_bindings/**"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
extend-select = ["I", "N", "B", "A", "PIE", "SIM"]
|
||||
|
||||
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.load(model_id)
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
|
||||
@@ -3,9 +3,7 @@ import contextlib
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -35,7 +33,6 @@ from exo.shared.models.model_cards import (
|
||||
ModelCard,
|
||||
ModelId,
|
||||
)
|
||||
from exo.shared.tracing import compute_stats, load_trace_file
|
||||
from exo.shared.types.api import (
|
||||
AdvancedImageParams,
|
||||
BenchChatCompletionResponse,
|
||||
@@ -70,13 +67,6 @@ from exo.shared.types.api import (
|
||||
StreamingChoiceResponse,
|
||||
StreamOptions,
|
||||
ToolCall,
|
||||
TraceCategoryStats,
|
||||
TraceEventResponse,
|
||||
TraceListItem,
|
||||
TraceListResponse,
|
||||
TraceRankStats,
|
||||
TraceResponse,
|
||||
TraceStatsResponse,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
@@ -156,6 +146,18 @@ def chunk_to_response(
|
||||
)
|
||||
|
||||
|
||||
async def resolve_model_card(model_id: ModelId) -> ModelCard:
|
||||
if model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[model_id]
|
||||
return model_card
|
||||
|
||||
for card in MODEL_CARDS.values():
|
||||
if card.model_id == ModelId(model_id):
|
||||
return card
|
||||
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
|
||||
class API:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -274,14 +276,10 @@ class API:
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.get("/v1/traces")(self.list_traces)
|
||||
self.app.get("/v1/traces/{task_id}")(self.get_trace)
|
||||
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
|
||||
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
model_card=await ModelCard.load(payload.model_id),
|
||||
model_card=await resolve_model_card(payload.model_id),
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
@@ -298,7 +296,7 @@ class API:
|
||||
self, payload: CreateInstanceParams
|
||||
) -> CreateInstanceResponse:
|
||||
instance = payload.instance
|
||||
model_card = await ModelCard.load(instance.shard_assignments.model_id)
|
||||
model_card = await resolve_model_card(instance.shard_assignments.model_id)
|
||||
required_memory = model_card.storage_size
|
||||
available_memory = self._calculate_total_available_memory()
|
||||
|
||||
@@ -326,7 +324,7 @@ class API:
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
|
||||
min_nodes: int = 1,
|
||||
) -> Instance:
|
||||
model_card = await ModelCard.load(model_id)
|
||||
model_card = await resolve_model_card(model_id)
|
||||
|
||||
try:
|
||||
placements = get_instance_placements(
|
||||
@@ -567,7 +565,7 @@ class API:
|
||||
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: ModelId | None = None
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
usage: Usage | None = None
|
||||
|
||||
@@ -626,7 +624,7 @@ class API:
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: ModelId | None = None
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
@@ -679,7 +677,7 @@ class API:
|
||||
)
|
||||
return resp
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
)
|
||||
@@ -688,7 +686,7 @@ class API:
|
||||
self, payload: ChatCompletionTaskParams
|
||||
) -> ChatCompletionResponse | StreamingResponse:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_card = await ModelCard.load(ModelId(payload.model))
|
||||
model_card = await resolve_model_card(ModelId(payload.model))
|
||||
payload.model = model_card.model_id
|
||||
|
||||
if not any(
|
||||
@@ -715,7 +713,7 @@ class API:
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_card = await ModelCard.load(ModelId(payload.model))
|
||||
model_card = await resolve_model_card(ModelId(payload.model))
|
||||
payload.model = model_card.model_id
|
||||
|
||||
if not any(
|
||||
@@ -735,12 +733,12 @@ class API:
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def _validate_image_model(self, model: ModelId) -> ModelId:
|
||||
async def _validate_image_model(self, model: str) -> ModelId:
|
||||
"""Validate model exists and return resolved model ID.
|
||||
|
||||
Raises HTTPException 404 if no instance is found for the model.
|
||||
"""
|
||||
model_card = await ModelCard.load(model)
|
||||
model_card = await resolve_model_card(ModelId(model))
|
||||
resolved_model = model_card.model_id
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == resolved_model
|
||||
@@ -786,7 +784,7 @@ class API:
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
@@ -1031,7 +1029,7 @@ class API:
|
||||
async def bench_image_generations(
|
||||
self, request: Request, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
@@ -1052,7 +1050,7 @@ class API:
|
||||
self,
|
||||
image: UploadFile,
|
||||
prompt: str,
|
||||
model: ModelId,
|
||||
model: str,
|
||||
n: int,
|
||||
size: str,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
@@ -1147,7 +1145,7 @@ class API:
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
model=model,
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1203,7 +1201,7 @@ class API:
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=ModelId(model),
|
||||
model=model,
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1350,110 +1348,3 @@ class API:
|
||||
)
|
||||
await self._send_download(command)
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
def _get_traces_dir(self) -> Path:
|
||||
return Path.home() / ".exo" / "traces"
|
||||
|
||||
def _get_trace_path(self, task_id: str) -> Path:
|
||||
return self._get_traces_dir() / f"trace_{task_id}.json"
|
||||
|
||||
async def list_traces(self) -> TraceListResponse:
|
||||
traces_dir = self._get_traces_dir()
|
||||
traces: list[TraceListItem] = []
|
||||
|
||||
if not traces_dir.exists():
|
||||
return TraceListResponse(traces=[])
|
||||
|
||||
for trace_file in sorted(
|
||||
traces_dir.glob("trace_*.json"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
):
|
||||
# Extract task_id from filename (trace_{task_id}.json)
|
||||
task_id = trace_file.stem.removeprefix("trace_")
|
||||
stat = trace_file.stat()
|
||||
created_at = datetime.fromtimestamp(
|
||||
stat.st_mtime, tz=timezone.utc
|
||||
).isoformat()
|
||||
traces.append(
|
||||
TraceListItem(
|
||||
task_id=task_id,
|
||||
created_at=created_at,
|
||||
file_size=stat.st_size,
|
||||
)
|
||||
)
|
||||
|
||||
return TraceListResponse(traces=traces)
|
||||
|
||||
async def get_trace(self, task_id: str) -> TraceResponse:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
|
||||
if not trace_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
|
||||
|
||||
trace_events = load_trace_file(trace_path)
|
||||
|
||||
return TraceResponse(
|
||||
task_id=task_id,
|
||||
traces=[
|
||||
TraceEventResponse(
|
||||
name=event.name,
|
||||
start_us=event.start_us,
|
||||
duration_us=event.duration_us,
|
||||
rank=event.rank,
|
||||
category=event.category,
|
||||
)
|
||||
for event in trace_events
|
||||
],
|
||||
)
|
||||
|
||||
async def get_trace_stats(self, task_id: str) -> TraceStatsResponse:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
|
||||
if not trace_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
|
||||
|
||||
trace_events = load_trace_file(trace_path)
|
||||
stats = compute_stats(trace_events)
|
||||
|
||||
return TraceStatsResponse(
|
||||
task_id=task_id,
|
||||
total_wall_time_us=stats.total_wall_time_us,
|
||||
by_category={
|
||||
category: TraceCategoryStats(
|
||||
total_us=cat_stats.total_us,
|
||||
count=cat_stats.count,
|
||||
min_us=cat_stats.min_us,
|
||||
max_us=cat_stats.max_us,
|
||||
avg_us=cat_stats.avg_us,
|
||||
)
|
||||
for category, cat_stats in stats.by_category.items()
|
||||
},
|
||||
by_rank={
|
||||
rank: TraceRankStats(
|
||||
by_category={
|
||||
category: TraceCategoryStats(
|
||||
total_us=cat_stats.total_us,
|
||||
count=cat_stats.count,
|
||||
min_us=cat_stats.min_us,
|
||||
max_us=cat_stats.max_us,
|
||||
avg_us=cat_stats.avg_us,
|
||||
)
|
||||
for category, cat_stats in rank_stats.items()
|
||||
}
|
||||
)
|
||||
for rank, rank_stats in stats.by_rank.items()
|
||||
},
|
||||
)
|
||||
|
||||
async def get_trace_raw(self, task_id: str) -> FileResponse:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
|
||||
if not trace_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
|
||||
|
||||
return FileResponse(
|
||||
path=trace_path,
|
||||
media_type="application/json",
|
||||
filename=f"trace_{task_id}.json",
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -12,7 +11,6 @@ from exo.master.placement import (
|
||||
place_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.tracing import TraceEvent, export_trace, is_tracing_enabled
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CreateInstance,
|
||||
@@ -37,8 +35,6 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
@@ -90,8 +86,6 @@ class Master:
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
# TODO: not have this
|
||||
self._event_log: list[Event] = []
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -193,14 +187,13 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageGenerationTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=selected_instance_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
@@ -208,17 +201,6 @@ class Master:
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if is_tracing_enabled():
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case ImageEdits():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
@@ -247,14 +229,13 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageEditsTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=selected_instance_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
@@ -262,17 +243,6 @@ class Master:
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if is_tracing_enabled():
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
@@ -365,10 +335,6 @@ class Master:
|
||||
local_event.origin,
|
||||
)
|
||||
for event in self._multi_buffer.drain():
|
||||
if isinstance(event, TracesCollected):
|
||||
self._handle_traces_collected(event)
|
||||
continue
|
||||
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
@@ -407,38 +373,3 @@ class Master:
|
||||
event=event.event,
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_traces_collected(self, event: TracesCollected) -> None:
|
||||
task_id = event.task_id
|
||||
if task_id not in self._pending_traces:
|
||||
self._pending_traces[task_id] = {}
|
||||
self._pending_traces[task_id][event.rank] = event.traces
|
||||
|
||||
if (
|
||||
task_id in self._expected_ranks
|
||||
and set(self._pending_traces[task_id].keys())
|
||||
>= self._expected_ranks[task_id]
|
||||
):
|
||||
self._merge_and_save_traces(task_id)
|
||||
|
||||
def _merge_and_save_traces(self, task_id: TaskId) -> None:
|
||||
all_traces: list[TraceEvent] = []
|
||||
for trace_data in self._pending_traces[task_id].values():
|
||||
for t in trace_data:
|
||||
all_traces.append(
|
||||
TraceEvent(
|
||||
name=t.name,
|
||||
start_us=t.start_us,
|
||||
duration_us=t.duration_us,
|
||||
rank=t.rank,
|
||||
category=t.category,
|
||||
)
|
||||
)
|
||||
|
||||
output_path = Path.home() / ".exo" / "traces" / f"trace_{task_id}.json"
|
||||
export_trace(all_traces, output_path)
|
||||
logger.info(f"Merged traces saved to {output_path}")
|
||||
|
||||
del self._pending_traces[task_id]
|
||||
if task_id in self._expected_ranks:
|
||||
del self._expected_ranks[task_id]
|
||||
|
||||
@@ -216,8 +216,6 @@ def get_node_id_keypair(
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate_ed25519()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
@@ -25,7 +25,6 @@ from exo.shared.types.events import (
|
||||
TestEvent,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
TracesCollected,
|
||||
)
|
||||
from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
@@ -56,11 +55,7 @@ def event_apply(event: Event, state: State) -> State:
|
||||
"""Apply an event to state."""
|
||||
match event:
|
||||
case (
|
||||
TestEvent()
|
||||
| ChunkGenerated()
|
||||
| TaskAcknowledged()
|
||||
| InputChunkReceived()
|
||||
| TracesCollected()
|
||||
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
|
||||
): # Pass-through events that don't modify state
|
||||
return state
|
||||
case InstanceCreated():
|
||||
|
||||
@@ -53,9 +53,3 @@ EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
|
||||
EXO_ENABLE_IMAGE_MODELS = (
|
||||
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
EXO_TRACING_ENABLED = os.getenv("EXO_TRACING_ENABLED", "").lower() in (
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
|
||||
from multiprocessing.synchronize import Semaphore as SemaphoreT
|
||||
|
||||
from loguru import logger
|
||||
from pytest import LogCaptureFixture, mark
|
||||
from pytest import LogCaptureFixture
|
||||
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
@@ -74,7 +74,6 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
|
||||
os.remove(p)
|
||||
|
||||
|
||||
@mark.skip(reason="this functionality is currently disabled but may return in future")
|
||||
def test_node_id_fetching(caplog: LogCaptureFixture):
|
||||
reps = 10
|
||||
|
||||
|
||||
@@ -1,450 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import cast, final
|
||||
|
||||
from exo.shared.constants import EXO_TRACING_ENABLED
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable to track the current trace category for hierarchical nesting
|
||||
_current_category: ContextVar[str | None] = ContextVar("current_category", default=None)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass(frozen=True)
|
||||
class TraceEvent:
|
||||
name: str
|
||||
start_us: int
|
||||
duration_us: int
|
||||
rank: int
|
||||
category: str
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class CategoryStats:
|
||||
total_us: int = 0
|
||||
count: int = 0
|
||||
min_us: int = 0
|
||||
max_us: int = 0
|
||||
|
||||
def add(self, duration_us: int) -> None:
|
||||
if self.count == 0:
|
||||
self.min_us = duration_us
|
||||
self.max_us = duration_us
|
||||
else:
|
||||
self.min_us = min(self.min_us, duration_us)
|
||||
self.max_us = max(self.max_us, duration_us)
|
||||
self.total_us += duration_us
|
||||
self.count += 1
|
||||
|
||||
@property
|
||||
def avg_us(self) -> float:
|
||||
return self.total_us / self.count if self.count > 0 else 0.0
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class TraceStats:
|
||||
total_wall_time_us: int = 0
|
||||
by_category: dict[str, CategoryStats] = field(default_factory=dict)
|
||||
by_rank: dict[int, dict[str, CategoryStats]] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Global trace buffer - each rank accumulates traces here
|
||||
_trace_buffer: list[TraceEvent] = []
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
|
||||
"""Check if tracing is enabled via environment variable."""
|
||||
return EXO_TRACING_ENABLED
|
||||
|
||||
|
||||
def _record_span(
|
||||
name: str, start_us: int, duration_us: int, rank: int, category: str
|
||||
) -> None:
|
||||
_trace_buffer.append(
|
||||
TraceEvent(
|
||||
name=name,
|
||||
start_us=start_us,
|
||||
duration_us=duration_us,
|
||||
rank=rank,
|
||||
category=category,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace(
|
||||
name: str,
|
||||
rank: int,
|
||||
category: str = "compute",
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager to trace any operation.
|
||||
|
||||
Nested traces automatically inherit the parent category, creating hierarchical
|
||||
categories like "sync/compute" or "async/comms".
|
||||
|
||||
Args:
|
||||
name: Name of the operation (e.g., "recv 0", "send 1", "joint_blocks")
|
||||
rank: This rank's ID
|
||||
category: Category for grouping in trace viewer ("comm", "compute", "step")
|
||||
|
||||
Example:
|
||||
with trace(f"sync {t}", rank, "sync"):
|
||||
with trace("joint_blocks", rank, "compute"):
|
||||
# Recorded with category "sync/compute"
|
||||
hidden_states = some_computation(...)
|
||||
"""
|
||||
if not is_tracing_enabled():
|
||||
yield
|
||||
return
|
||||
|
||||
# Combine with parent category if nested
|
||||
parent = _current_category.get()
|
||||
full_category = f"{parent}/{category}" if parent else category
|
||||
|
||||
# Set as current for nested traces
|
||||
token = _current_category.set(full_category)
|
||||
|
||||
try:
|
||||
start_us = int(time.time() * 1_000_000)
|
||||
start_perf = time.perf_counter()
|
||||
yield
|
||||
duration_us = int((time.perf_counter() - start_perf) * 1_000_000)
|
||||
_record_span(name, start_us, duration_us, rank, full_category)
|
||||
finally:
|
||||
_current_category.reset(token)
|
||||
|
||||
|
||||
def get_trace_buffer() -> list[TraceEvent]:
|
||||
return list(_trace_buffer)
|
||||
|
||||
|
||||
def clear_trace_buffer() -> None:
|
||||
_trace_buffer.clear()
|
||||
|
||||
|
||||
def export_trace(traces: list[TraceEvent], output_path: Path) -> None:
|
||||
trace_events: list[dict[str, object]] = []
|
||||
|
||||
for event in traces:
|
||||
# Chrome trace format uses "X" for complete events (with duration)
|
||||
chrome_event: dict[str, object] = {
|
||||
"name": event.name,
|
||||
"cat": event.category,
|
||||
"ph": "X",
|
||||
"ts": event.start_us,
|
||||
"dur": event.duration_us,
|
||||
"pid": 0,
|
||||
"tid": event.rank,
|
||||
"args": {"rank": event.rank},
|
||||
}
|
||||
trace_events.append(chrome_event)
|
||||
|
||||
ranks_seen = set(t.rank for t in traces)
|
||||
for rank in ranks_seen:
|
||||
trace_events.append(
|
||||
{
|
||||
"name": "thread_name",
|
||||
"ph": "M", # Metadata event
|
||||
"pid": 0,
|
||||
"tid": rank,
|
||||
"args": {"name": f"Rank {rank}"},
|
||||
}
|
||||
)
|
||||
|
||||
chrome_trace = {"traceEvents": trace_events}
|
||||
|
||||
try:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(chrome_trace, f, indent=2)
|
||||
except OSError as e:
|
||||
logger.warning("Failed to export trace to %s: %s", output_path, e)
|
||||
|
||||
|
||||
def export_local_traces(rank: int) -> None:
|
||||
if not is_tracing_enabled():
|
||||
return
|
||||
|
||||
local_traces = get_trace_buffer()
|
||||
if local_traces:
|
||||
output_path = Path.home() / ".exo" / "traces" / f"trace_{rank}.json"
|
||||
try:
|
||||
export_trace(local_traces, output_path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to export local traces for rank %d: %s", rank, e)
|
||||
|
||||
clear_trace_buffer()
|
||||
|
||||
|
||||
def merge_trace_files(trace_dir: Path | None = None) -> Path | None:
|
||||
if trace_dir is None:
|
||||
trace_dir = Path.home() / ".exo" / "traces"
|
||||
|
||||
if not trace_dir.exists():
|
||||
return None
|
||||
|
||||
trace_files = sorted(trace_dir.glob("trace_*.json"))
|
||||
|
||||
if not trace_files:
|
||||
return None
|
||||
|
||||
merged_events: list[dict[str, object]] = []
|
||||
for trace_file in trace_files:
|
||||
file_rank = int(trace_file.stem.split("_")[1])
|
||||
|
||||
with open(trace_file) as f:
|
||||
raw = f.read()
|
||||
data = cast(dict[str, list[dict[str, object]]], json.loads(raw))
|
||||
events: list[dict[str, object]] = data.get("traceEvents", [])
|
||||
for event in events:
|
||||
event["tid"] = file_rank
|
||||
if "args" in event and isinstance(event["args"], dict):
|
||||
event["args"]["rank"] = file_rank
|
||||
merged_events.extend(events)
|
||||
|
||||
output_path = Path.home() / ".exo" / "traces" / "merged_trace.json"
|
||||
try:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w") as f:
|
||||
json.dump({"traceEvents": merged_events}, f, indent=2)
|
||||
except OSError as e:
|
||||
logger.warning("Failed to write merged trace to %s: %s", output_path, e)
|
||||
return None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def _format_duration(us: int | float) -> str:
|
||||
if us < 1000:
|
||||
return f"{us:.0f}µs"
|
||||
elif us < 1_000_000:
|
||||
return f"{us / 1000:.2f}ms"
|
||||
else:
|
||||
return f"{us / 1_000_000:.2f}s"
|
||||
|
||||
|
||||
def load_trace_file(path: Path) -> list[TraceEvent]:
|
||||
"""Load a Chrome Trace Format JSON file into TraceEvent objects."""
|
||||
with open(path) as f:
|
||||
data = cast(dict[str, list[dict[str, object]]], json.load(f))
|
||||
|
||||
events = data.get("traceEvents", [])
|
||||
traces: list[TraceEvent] = []
|
||||
|
||||
for event in events:
|
||||
# Skip metadata events
|
||||
if event.get("ph") == "M":
|
||||
continue
|
||||
|
||||
name = str(event.get("name", ""))
|
||||
category = str(event.get("cat", ""))
|
||||
ts_value = event.get("ts", 0)
|
||||
dur_value = event.get("dur", 0)
|
||||
tid_value = event.get("tid", 0)
|
||||
start_us = int(ts_value) if isinstance(ts_value, (int, float, str)) else 0
|
||||
duration_us = int(dur_value) if isinstance(dur_value, (int, float, str)) else 0
|
||||
|
||||
# Get rank from tid or args
|
||||
rank = int(tid_value) if isinstance(tid_value, (int, float, str)) else 0
|
||||
args = event.get("args")
|
||||
if isinstance(args, dict):
|
||||
args_dict = cast(dict[str, object], args)
|
||||
rank_from_args = args_dict.get("rank")
|
||||
if isinstance(rank_from_args, (int, float, str)):
|
||||
rank = int(rank_from_args)
|
||||
|
||||
traces.append(
|
||||
TraceEvent(
|
||||
name=name,
|
||||
start_us=start_us,
|
||||
duration_us=duration_us,
|
||||
rank=rank,
|
||||
category=category,
|
||||
)
|
||||
)
|
||||
|
||||
return traces
|
||||
|
||||
|
||||
def compute_stats(traces: list[TraceEvent]) -> TraceStats:
|
||||
"""Compute comprehensive statistics from trace events."""
|
||||
stats = TraceStats()
|
||||
|
||||
if not traces:
|
||||
return stats
|
||||
|
||||
# Calculate wall time from earliest start to latest end
|
||||
min_start = min(t.start_us for t in traces)
|
||||
max_end = max(t.start_us + t.duration_us for t in traces)
|
||||
stats.total_wall_time_us = max_end - min_start
|
||||
|
||||
# Initialize nested dicts
|
||||
by_category: dict[str, CategoryStats] = defaultdict(CategoryStats)
|
||||
by_rank: dict[int, dict[str, CategoryStats]] = defaultdict(
|
||||
lambda: defaultdict(CategoryStats)
|
||||
)
|
||||
|
||||
for event in traces:
|
||||
# By category
|
||||
by_category[event.category].add(event.duration_us)
|
||||
|
||||
# By rank and category
|
||||
by_rank[event.rank][event.category].add(event.duration_us)
|
||||
|
||||
stats.by_category = dict(by_category)
|
||||
stats.by_rank = {k: dict(v) for k, v in by_rank.items()}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def print_stats(stats: TraceStats) -> None:
|
||||
"""Print formatted trace statistics."""
|
||||
print("=== Trace Statistics ===")
|
||||
print()
|
||||
print(f"Wall Time: {_format_duration(stats.total_wall_time_us)}")
|
||||
print()
|
||||
|
||||
# Parse hierarchical categories (e.g., "sync/compute" -> phase="sync", subcat="compute")
|
||||
if stats.by_category:
|
||||
phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
|
||||
has_hierarchical = False
|
||||
|
||||
for cat, cat_stats in stats.by_category.items():
|
||||
if "/" in cat:
|
||||
phase, subcat = cat.split("/", 1)
|
||||
phases[phase][subcat] = cat_stats
|
||||
has_hierarchical = True
|
||||
else:
|
||||
phases[cat]["_total"] = cat_stats
|
||||
|
||||
if has_hierarchical:
|
||||
print("By Phase:")
|
||||
for phase in sorted(phases.keys()):
|
||||
subcats = phases[phase]
|
||||
# Skip phases that only have _total (non-hierarchical top-level categories)
|
||||
non_total_subcats = {k: v for k, v in subcats.items() if k != "_total"}
|
||||
if not non_total_subcats:
|
||||
continue
|
||||
|
||||
phase_total = sum(s.total_us for s in non_total_subcats.values())
|
||||
print(f" {phase}:")
|
||||
for subcat, subcat_stats in sorted(
|
||||
non_total_subcats.items(),
|
||||
key=lambda x: x[1].total_us,
|
||||
reverse=True,
|
||||
):
|
||||
pct = (
|
||||
subcat_stats.total_us / phase_total * 100 if phase_total else 0
|
||||
)
|
||||
# Use parent phase's step count for per-step average
|
||||
phase_step_count = subcats.get("_total", CategoryStats()).count
|
||||
if phase_step_count > 0:
|
||||
avg_per_step = subcat_stats.total_us / phase_step_count
|
||||
else:
|
||||
avg_per_step = subcat_stats.avg_us # fallback
|
||||
print(
|
||||
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
|
||||
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
|
||||
)
|
||||
print()
|
||||
else:
|
||||
# Fall back to flat category display if no hierarchical categories
|
||||
print("By Category:")
|
||||
total_time = sum(c.total_us for c in stats.by_category.values())
|
||||
for category, cat_stats in sorted(
|
||||
stats.by_category.items(), key=lambda x: x[1].total_us, reverse=True
|
||||
):
|
||||
pct = (cat_stats.total_us / total_time * 100) if total_time > 0 else 0
|
||||
print(
|
||||
f" {category:12s} {_format_duration(cat_stats.total_us):>10s} "
|
||||
f"({pct:5.1f}%) avg: {_format_duration(cat_stats.avg_us):>8s} "
|
||||
f"count: {cat_stats.count}"
|
||||
)
|
||||
print()
|
||||
|
||||
# By Rank
|
||||
if stats.by_rank:
|
||||
print("By Rank:")
|
||||
for rank in sorted(stats.by_rank.keys()):
|
||||
rank_stats = stats.by_rank[rank]
|
||||
print(f" Rank {rank}:")
|
||||
|
||||
# Parse hierarchical categories for this rank
|
||||
rank_phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
|
||||
has_hierarchical = False
|
||||
for cat, cat_stats in rank_stats.items():
|
||||
if "/" in cat:
|
||||
phase, subcat = cat.split("/", 1)
|
||||
rank_phases[phase][subcat] = cat_stats
|
||||
has_hierarchical = True
|
||||
else:
|
||||
rank_phases[cat]["_total"] = cat_stats
|
||||
|
||||
if has_hierarchical:
|
||||
for phase in sorted(rank_phases.keys()):
|
||||
subcats = rank_phases[phase]
|
||||
non_total_subcats = {
|
||||
k: v for k, v in subcats.items() if k != "_total"
|
||||
}
|
||||
if not non_total_subcats:
|
||||
continue
|
||||
|
||||
phase_total = sum(s.total_us for s in non_total_subcats.values())
|
||||
print(f" {phase}:")
|
||||
for subcat, subcat_stats in sorted(
|
||||
non_total_subcats.items(),
|
||||
key=lambda x: x[1].total_us,
|
||||
reverse=True,
|
||||
):
|
||||
pct = (
|
||||
subcat_stats.total_us / phase_total * 100
|
||||
if phase_total
|
||||
else 0
|
||||
)
|
||||
# Use parent phase's step count for per-step average
|
||||
phase_step_count = subcats.get("_total", CategoryStats()).count
|
||||
if phase_step_count > 0:
|
||||
avg_per_step = subcat_stats.total_us / phase_step_count
|
||||
else:
|
||||
avg_per_step = subcat_stats.avg_us # fallback
|
||||
print(
|
||||
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
|
||||
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
|
||||
)
|
||||
else:
|
||||
# Flat display fallback
|
||||
for category, cat_stats in sorted(
|
||||
rank_stats.items(), key=lambda x: x[1].total_us, reverse=True
|
||||
):
|
||||
print(f" {category}: {_format_duration(cat_stats.total_us)}")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("trace.json")
|
||||
|
||||
if not path.exists():
|
||||
print(f"Error: File not found: {path}")
|
||||
sys.exit(1)
|
||||
|
||||
traces = load_trace_file(path)
|
||||
if not traces:
|
||||
print("No trace events found in file.")
|
||||
sys.exit(0)
|
||||
|
||||
computed_stats = compute_stats(traces)
|
||||
print_stats(computed_stats)
|
||||
@@ -373,45 +373,3 @@ class StartDownloadResponse(CamelCaseModel):
|
||||
|
||||
class DeleteDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
|
||||
class TraceEventResponse(CamelCaseModel):
|
||||
name: str
|
||||
start_us: int
|
||||
duration_us: int
|
||||
rank: int
|
||||
category: str
|
||||
|
||||
|
||||
class TraceResponse(CamelCaseModel):
|
||||
task_id: str
|
||||
traces: list[TraceEventResponse]
|
||||
|
||||
|
||||
class TraceCategoryStats(CamelCaseModel):
|
||||
total_us: int
|
||||
count: int
|
||||
min_us: int
|
||||
max_us: int
|
||||
avg_us: float
|
||||
|
||||
|
||||
class TraceRankStats(CamelCaseModel):
|
||||
by_category: dict[str, TraceCategoryStats]
|
||||
|
||||
|
||||
class TraceStatsResponse(CamelCaseModel):
|
||||
task_id: str
|
||||
total_wall_time_us: int
|
||||
by_category: dict[str, TraceCategoryStats]
|
||||
by_rank: dict[int, TraceRankStats]
|
||||
|
||||
|
||||
class TraceListItem(CamelCaseModel):
|
||||
task_id: str
|
||||
created_at: str
|
||||
file_size: int
|
||||
|
||||
|
||||
class TraceListResponse(CamelCaseModel):
|
||||
traces: list[TraceListItem]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -11,7 +10,7 @@ from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
class EventId(Id):
|
||||
@@ -110,22 +109,6 @@ class TopologyEdgeDeleted(BaseEvent):
|
||||
conn: Connection
|
||||
|
||||
|
||||
@final
|
||||
class TraceEventData(FrozenModel):
|
||||
name: str
|
||||
start_us: int
|
||||
duration_us: int
|
||||
rank: int
|
||||
category: str
|
||||
|
||||
|
||||
@final
|
||||
class TracesCollected(BaseEvent):
|
||||
task_id: TaskId
|
||||
rank: int
|
||||
traces: list[TraceEventData]
|
||||
|
||||
|
||||
Event = (
|
||||
TestEvent
|
||||
| TaskCreated
|
||||
@@ -144,7 +127,6 @@ Event = (
|
||||
| InputChunkReceived
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
| TracesCollected
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -6,11 +6,6 @@ from mflux.models.common.config.config import Config
|
||||
from mflux.utils.exceptions import StopImageGenerationException
|
||||
from tqdm import tqdm
|
||||
|
||||
from exo.shared.tracing import (
|
||||
clear_trace_buffer,
|
||||
is_tracing_enabled,
|
||||
trace,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import (
|
||||
@@ -329,7 +324,6 @@ class DiffusionRunner:
|
||||
capture_steps = set()
|
||||
|
||||
self._reset_all_caches()
|
||||
clear_trace_buffer()
|
||||
|
||||
time_steps = tqdm(range(runtime_config.num_inference_steps))
|
||||
|
||||
@@ -471,22 +465,20 @@ class DiffusionRunner:
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif t < config.init_time_step + num_sync_steps:
|
||||
with trace(name=f"sync {t}", rank=self.rank, category="sync"):
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
)
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
)
|
||||
else:
|
||||
with trace(name=f"async {t}", rank=self.rank, category="async"):
|
||||
return self._async_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
is_first_async_step=t == config.init_time_step + num_sync_steps,
|
||||
)
|
||||
return self._async_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
is_first_async_step=t == config.init_time_step + num_sync_steps,
|
||||
)
|
||||
|
||||
def _single_node_step(
|
||||
self,
|
||||
@@ -594,41 +586,30 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
with trace(
|
||||
name="joint_blocks",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(encoder_hidden_states, hidden_states)
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if self.owns_concat_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
@@ -639,63 +620,45 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
with trace(
|
||||
name="single blocks",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(hidden_states)
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
@@ -779,20 +742,14 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
with trace(name="send 0", rank=self.rank, category="comms"):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, 0, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
with trace(
|
||||
name=f"recv {self.world_size - 1}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
else:
|
||||
hidden_states = prev_latents
|
||||
@@ -852,13 +809,10 @@ class DiffusionRunner:
|
||||
and not self.is_last_stage
|
||||
and not is_first_async_step
|
||||
):
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
|
||||
@@ -889,13 +843,10 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
return mx.concatenate(patch_latents, axis=1)
|
||||
|
||||
@@ -934,28 +885,22 @@ class DiffusionRunner:
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
mx.eval(encoder_hidden_states)
|
||||
|
||||
if self.is_first_stage:
|
||||
patch, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
@@ -964,22 +909,14 @@ class DiffusionRunner:
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
with trace(
|
||||
name=f"joint patch {patch_idx}",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
encoder_hidden_states, patch = wrapper(
|
||||
hidden_states=patch,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(encoder_hidden_states, patch)
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
encoder_hidden_states, patch = wrapper(
|
||||
hidden_states=patch,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if self.owns_concat_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
@@ -988,70 +925,49 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
with trace(
|
||||
name=f"single patch {patch_idx}",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
patch = wrapper(
|
||||
hidden_states=patch,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(patch)
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
patch = wrapper(
|
||||
hidden_states=patch,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
if self.is_last_stage:
|
||||
patch = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch, text_embeddings)
|
||||
patch_img_only = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
|
||||
|
||||
return noise, encoder_hidden_states
|
||||
|
||||
@@ -3,7 +3,6 @@ from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
@@ -13,29 +12,25 @@ from mlx_lm.models.cache import (
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.9
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.85
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
|
||||
)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(
|
||||
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
|
||||
):
|
||||
def __init__(self, tokenizer: TokenizerWrapper):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._last_used: list[int] = [] # monotonic counter of last access per entry
|
||||
self._access_counter: int = 0
|
||||
self._tokenizer: TokenizerWrapper = tokenizer
|
||||
self._group = group
|
||||
|
||||
def clear(self):
|
||||
"""Clear all cached prompts and caches."""
|
||||
@@ -86,13 +81,13 @@ class KVPrefixCache:
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
length = _get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
|
||||
if length == max_length:
|
||||
# Exact match - cached prompt starts with our entire prompt
|
||||
# Trim cache to prompt length - 1, return last token for stream_generate
|
||||
prompt_cache = deepcopy(self.caches[i])
|
||||
cached_length = cache_length(self.caches[i])
|
||||
cached_length = _cache_length(self.caches[i])
|
||||
tokens_to_trim = cached_length - (max_length - 1)
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
@@ -114,7 +109,7 @@ class KVPrefixCache:
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
|
||||
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
|
||||
cached_length = cache_length(self.caches[best_snapshot_index])
|
||||
cached_length = _cache_length(self.caches[best_snapshot_index])
|
||||
tokens_to_trim = cached_length - best_snapshot_length
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
@@ -136,37 +131,29 @@ class KVPrefixCache:
|
||||
return prompt_cache, tokenized_prompt, None
|
||||
|
||||
def _evict_if_needed(self):
|
||||
"""Evict least recently used entries while memory usage is high."""
|
||||
"""Evict least recently used entries while memory pressure is high."""
|
||||
if len(self.caches) == 0:
|
||||
return
|
||||
|
||||
active: int = mx.metal.get_active_memory()
|
||||
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
|
||||
if active < limit * _MEMORY_THRESHOLD:
|
||||
return
|
||||
|
||||
# Evict LRU entries until below threshold or only one entry left
|
||||
while (
|
||||
len(self.caches) > 1
|
||||
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
|
||||
):
|
||||
while len(self.caches) > 0:
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
self.prompts.pop(lru_index)
|
||||
self.caches.pop(lru_index)
|
||||
self._last_used.pop(lru_index)
|
||||
logger.info(
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
|
||||
)
|
||||
|
||||
def get_memory_used_percentage(self) -> float:
|
||||
local_pressure: float = get_memory_used_percentage()
|
||||
|
||||
if self._group is None:
|
||||
return local_pressure
|
||||
|
||||
all_pressure = mx.distributed.all_gather(
|
||||
mx.array([local_pressure], dtype=mx.float32),
|
||||
group=self._group,
|
||||
)
|
||||
# .item() evals.
|
||||
max_pressure = float(mx.max(all_pressure).item())
|
||||
return max_pressure
|
||||
active = mx.metal.get_active_memory()
|
||||
if active < limit * _MEMORY_THRESHOLD:
|
||||
break
|
||||
|
||||
|
||||
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
@@ -181,13 +168,13 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
return mx.array(tokenized_prompt)
|
||||
|
||||
|
||||
def cache_length(cache: KVCacheType) -> int:
|
||||
def _cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
|
||||
return max(c.offset for c in cache) # type: ignore
|
||||
|
||||
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
"""Find the length of the common prefix between two token arrays."""
|
||||
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
|
||||
if n == 0:
|
||||
@@ -198,17 +185,6 @@ def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
return int(mx.sum(prefix_mask).item())
|
||||
|
||||
|
||||
def get_available_memory() -> Memory:
|
||||
mem: int = psutil.virtual_memory().available
|
||||
return Memory.from_bytes(mem)
|
||||
|
||||
|
||||
def get_memory_used_percentage() -> float:
|
||||
mem = psutil.virtual_memory()
|
||||
# percent is 0-100
|
||||
return float(mem.percent / 100)
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> KVCacheType:
|
||||
|
||||
@@ -18,7 +18,6 @@ from pydantic import ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer, is_tracing_enabled
|
||||
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
|
||||
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
@@ -28,8 +27,6 @@ from exo.shared.types.events import (
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
@@ -40,7 +37,6 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -115,12 +111,8 @@ def main(
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
seen = set[TaskId]()
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
seen.add(task.task_id)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -171,7 +163,7 @@ def main(
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer, group)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
@@ -411,10 +403,6 @@ def main(
|
||||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
_send_traces_if_enabled(
|
||||
event_sender, task.task_id, shard_metadata.device_rank
|
||||
)
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
@@ -473,10 +461,6 @@ def main(
|
||||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
_send_traces_if_enabled(
|
||||
event_sender, task.task_id, shard_metadata.device_rank
|
||||
)
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
@@ -651,36 +635,6 @@ def _send_image_chunk(
|
||||
)
|
||||
|
||||
|
||||
def _send_traces_if_enabled(
|
||||
event_sender: MpSender[Event],
|
||||
task_id: TaskId,
|
||||
rank: int,
|
||||
) -> None:
|
||||
if not is_tracing_enabled():
|
||||
return
|
||||
|
||||
traces = get_trace_buffer()
|
||||
if traces:
|
||||
trace_data = [
|
||||
TraceEventData(
|
||||
name=t.name,
|
||||
start_us=t.start_us,
|
||||
duration_us=t.duration_us,
|
||||
rank=t.rank,
|
||||
category=t.category,
|
||||
)
|
||||
for t in traces
|
||||
]
|
||||
event_sender.send(
|
||||
TracesCollected(
|
||||
task_id=task_id,
|
||||
rank=rank,
|
||||
traces=trace_data,
|
||||
)
|
||||
)
|
||||
clear_trace_buffer()
|
||||
|
||||
|
||||
def _process_image_response(
|
||||
response: ImageGenerationResponse | PartialImageResponse,
|
||||
command_id: CommandId,
|
||||
|
||||
@@ -127,25 +127,20 @@ class RunnerSupervisor:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
f"Skipping invalid task {task} as it has already been submitted"
|
||||
)
|
||||
return
|
||||
if task.task_id in self.completed:
|
||||
logger.warning(
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
return
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
try:
|
||||
await self._task_sender.send_async(task)
|
||||
self._task_sender.send(task)
|
||||
except ClosedResourceError:
|
||||
logger.warning(f"Task {task} dropped, runner closed communication.")
|
||||
return
|
||||
await event.wait()
|
||||
logger.info(f"Finished task {task}")
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
|
||||
@@ -14,9 +14,9 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVPrefixCache,
|
||||
cache_length,
|
||||
_cache_length,
|
||||
_get_prefix_length,
|
||||
encode_prompt,
|
||||
get_prefix_length,
|
||||
make_kv_cache,
|
||||
)
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
|
||||
@@ -35,47 +35,47 @@ class TestGetPrefixLength:
|
||||
def test_identical_arrays(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert get_prefix_length(a, b) == 5
|
||||
assert _get_prefix_length(a, b) == 5
|
||||
|
||||
def test_no_common_prefix(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([4, 5, 6])
|
||||
assert get_prefix_length(a, b) == 0
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_partial_prefix(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 7, 8])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_prompt_longer_than_cached(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_cached_longer_than_prompt(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
|
||||
def test_single_token_match(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 5, 6])
|
||||
assert get_prefix_length(a, b) == 1
|
||||
assert _get_prefix_length(a, b) == 1
|
||||
|
||||
def test_empty_prompt(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([1, 2, 3])
|
||||
assert get_prefix_length(a, b) == 0
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_empty_cached(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert get_prefix_length(a, b) == 0
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
def test_both_empty(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert get_prefix_length(a, b) == 0
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
|
||||
|
||||
class TestKVPrefix:
|
||||
@@ -146,7 +146,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens
|
||||
assert cache_length(cache) == len(tokens)
|
||||
assert _cache_length(cache) == len(tokens)
|
||||
|
||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
@@ -166,7 +166,7 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
assert stored_length > 0
|
||||
|
||||
# Retrieve with same prompt: exact match
|
||||
@@ -209,7 +209,7 @@ class TestKVPrefixCacheWithModel:
|
||||
long_tokens = encode_prompt(tokenizer, long_prompt)
|
||||
|
||||
# The prompts share a prefix (chat template preamble + "Hi")
|
||||
expected_prefix = get_prefix_length(long_tokens, short_tokens)
|
||||
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
|
||||
assert expected_prefix > 0, (
|
||||
"Prompts should share a prefix from the chat template"
|
||||
)
|
||||
@@ -243,7 +243,7 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Get cache and mutate it (simulating what generation does)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
@@ -259,7 +259,7 @@ class TestKVPrefixCacheWithModel:
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
# Stored cache must be unchanged
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
|
||||
|
||||
def test_stored_cache_survives_repeated_get_mutate_cycles(
|
||||
self, model_and_tokenizer
|
||||
@@ -281,7 +281,7 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
for i in range(3):
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
@@ -293,7 +293,7 @@ class TestKVPrefixCacheWithModel:
|
||||
layer_cache.update_and_fetch(extra, extra)
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
|
||||
f"Failed on loop {i}"
|
||||
)
|
||||
|
||||
@@ -325,7 +325,7 @@ class TestKVPrefixCacheWithModel:
|
||||
assert len(kv_prefix_cache.caches) == 1
|
||||
# Cache should contain prompt + generated tokens
|
||||
expected_length = len(prompt_tokens) + generated_tokens
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
|
||||
|
||||
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
@@ -400,7 +400,7 @@ class TestKVPrefixCacheWithModel:
|
||||
first_gen_time = time.perf_counter() - t0
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
first_cache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation: same long prompt + extra content (simulating multi-turn)
|
||||
task2 = ChatCompletionTaskParams(
|
||||
@@ -416,7 +416,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt2_tokens = encode_prompt(tokenizer, prompt2)
|
||||
|
||||
# Verify the prompts share a long prefix
|
||||
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
|
||||
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
|
||||
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
|
||||
|
||||
# Second generation should reuse the cached prefix (only prefill new tokens)
|
||||
@@ -440,7 +440,7 @@ class TestKVPrefixCacheWithModel:
|
||||
# With prefix_hit > 1000, should update in-place (not add a second entry)
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
|
||||
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
assert updated_cache_length > first_cache_length
|
||||
|
||||
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
|
||||
@@ -465,7 +465,7 @@ class TestKVPrefixCacheWithModel:
|
||||
):
|
||||
pass
|
||||
|
||||
firstcache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation gets the cache and mutates it during generation
|
||||
for _response in mlx_generate(
|
||||
@@ -478,7 +478,7 @@ class TestKVPrefixCacheWithModel:
|
||||
pass
|
||||
|
||||
# The first stored cache must not have been mutated by the second generation
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
|
||||
|
||||
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
@@ -540,6 +540,6 @@ class TestKVPrefixCacheWithModel:
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# The surviving entry should be the newly added one
|
||||
new_tokens = encode_prompt(tokenizer, prompt)
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
new_tokens
|
||||
)
|
||||
|
||||
@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
|
||||
|
||||
@pytest.fixture
|
||||
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a mock group
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||
# initialize_mlx returns a "group" equal to 1
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
@@ -147,14 +147,6 @@ class MockTokenizer:
|
||||
has_tool_calling = False
|
||||
|
||||
|
||||
class MockGroup:
|
||||
def rank(self) -> int:
|
||||
return 0
|
||||
|
||||
def size(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _run(tasks: Iterable[Task]):
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
|
||||
30
uv.lock
generated
30
uv.lock
generated
@@ -192,20 +192,14 @@ sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422 },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499 },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928 },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/43/0e822876f87ea8a4ef95442c3d766a06a51fc5298823f884ef87aaad168c/cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b", size = 220049 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/89/76799151d9c2d2d1ead63c2429da9ea9d7aac304603de0c6e8764e6e8e70/cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c", size = 207793 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/dd/3465b14bb9e24ee24cb88c9e3730f6de63111fffe513492bf8c808a3547e/cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef", size = 206300 },
|
||||
{ url = "https://files.pythonhosted.org/packages/47/d9/d83e293854571c877a92da46fdec39158f8d7e68da75bf73581225d28e90/cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775", size = 219244 },
|
||||
{ url = "https://files.pythonhosted.org/packages/2b/0f/1f177e3683aead2bb00f7679a16451d302c436b5cbf2505f0ea8146ef59e/cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205", size = 222828 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/0f/cafacebd4b040e3119dcb32fed8bdef8dfe94da653155f9d0b9dc660166e/cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1", size = 220926 },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/b4/c56878d0d1755cf9caa54ba71e5d049479c52f9e4afc230f06822162ab2f/cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c", size = 221593 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e0/0d/eb704606dfe8033e7128df5e90fee946bbcb64a04fcdaa97321309004000/cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8", size = 209354 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/19/3c435d727b368ca475fb8742ab97c9cb13a0de600ce86f62eab7fa3eea60/cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc", size = 208480 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d0/44/681604464ed9541673e486521497406fadcc15b5217c3e326b061696899a/cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592", size = 221584 },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/8e/342a504ff018a2825d395d44d63a767dd8ebc927ebda557fecdaca3ac33a/cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512", size = 224443 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/5e/b666bacbbc60fbf415ba9988324a132c9a7a0448a9a8f125074671c0f2c3/cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4", size = 223437 },
|
||||
@@ -311,10 +305,8 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/49/498c86566a1d80e978b42f0d702795f69887005548c041636df6ae1ca64c/cryptography-46.0.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:01ca9ff2885f3acc98c29f1860552e37f6d7c7d013d7334ff2a9de43a449315d", size = 4450807 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/0a/863a3604112174c8624a2ac3c038662d9e59970c7f926acdcfaed8d61142/cryptography-46.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6eae65d4c3d33da080cff9c4ab1f711b15c1d9760809dad6ea763f3812d254cb", size = 4299615 },
|
||||
{ url = "https://files.pythonhosted.org/packages/64/02/b73a533f6b64a69f3cd3872acb6ebc12aef924d8d103133bb3ea750dc703/cryptography-46.0.3-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5bf0ed4490068a2e72ac03d786693adeb909981cc596425d09032d372bcc849", size = 4016800 },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/d5/16e41afbfa450cde85a3b7ec599bebefaef16b5c6ba4ec49a3532336ed72/cryptography-46.0.3-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5ecfccd2329e37e9b7112a888e76d9feca2347f12f37918facbb893d7bb88ee8", size = 4984707 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/56/e7e69b427c3878352c2fb9b450bd0e19ed552753491d39d7d0a2f5226d41/cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a2c0cd47381a3229c403062f764160d57d4d175e022c1df84e168c6251a22eec", size = 4482541 },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/f6/50736d40d97e8483172f1bb6e698895b92a223dba513b0ca6f06b2365339/cryptography-46.0.3-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:549e234ff32571b1f4076ac269fcce7a808d3bf98b76c8dd560e42dbc66d7d91", size = 4299464 },
|
||||
{ url = "https://files.pythonhosted.org/packages/00/de/d8e26b1a855f19d9994a19c702fa2e93b0456beccbcfe437eda00e0701f2/cryptography-46.0.3-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:c0a7bb1a68a5d3471880e264621346c48665b3bf1c3759d682fc0864c540bd9e", size = 4950838 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/29/798fc4ec461a1c9e9f735f2fc58741b0daae30688f41b2497dcbc9ed1355/cryptography-46.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:10b01676fc208c3e6feeb25a8b83d81767e8059e1fe86e1dc62d10a3018fa926", size = 4481596 },
|
||||
{ url = "https://files.pythonhosted.org/packages/15/8d/03cd48b20a573adfff7652b76271078e3045b9f49387920e7f1f631d125e/cryptography-46.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0abf1ffd6e57c67e92af68330d05760b7b7efb243aab8377e583284dbab72c71", size = 4426782 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/b1/ebacbfe53317d55cf33165bda24c86523497a6881f339f9aae5c2e13e57b/cryptography-46.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a04bee9ab6a4da801eb9b51f1b708a1b5b5c9eb48c03f74198464c66f0d344ac", size = 4698381 },
|
||||
@@ -322,10 +314,8 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c5/fd/bc1daf8230eaa075184cbbf5f8cd00ba9db4fd32d63fb83da4671b72ed8a/cryptography-46.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39b6755623145ad5eff1dab323f4eae2a32a77a7abef2c5089a04a3d04366715", size = 4435078 },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/98/d3bd5407ce4c60017f8ff9e63ffee4200ab3e23fe05b765cab805a7db008/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:db391fa7c66df6762ee3f00c95a89e6d428f4d60e7abc8328f4fe155b5ac6e54", size = 4293460 },
|
||||
{ url = "https://files.pythonhosted.org/packages/26/e9/e23e7900983c2b8af7a08098db406cf989d7f09caea7897e347598d4cd5b/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:78a97cf6a8839a48c49271cdcbd5cf37ca2c1d6b7fdd86cc864f302b5e9bf459", size = 3995237 },
|
||||
{ url = "https://files.pythonhosted.org/packages/91/15/af68c509d4a138cfe299d0d7ddb14afba15233223ebd933b4bbdbc7155d3/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:dfb781ff7eaa91a6f7fd41776ec37c5853c795d3b358d4896fdbb5df168af422", size = 4967344 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/e3/8643d077c53868b681af077edf6b3cb58288b5423610f21c62aadcbe99f4/cryptography-46.0.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:6f61efb26e76c45c4a227835ddeae96d83624fb0d29eb5df5b96e14ed1a0afb7", size = 4466564 },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/43/c1e8726fa59c236ff477ff2b5dc071e54b21e5a1e51aa2cee1676f1c986f/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:23b1a8f26e43f47ceb6d6a43115f33a5a37d57df4ea0ca295b780ae8546e8044", size = 4292415 },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/f9/2f8fefdb1aee8a8e3256a0568cffc4e6d517b256a2fe97a029b3f1b9fe7e/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:b419ae593c86b87014b9be7396b385491ad7f320bde96826d0dd174459e54665", size = 4931457 },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/30/9b54127a9a778ccd6d27c3da7563e9f2d341826075ceab89ae3b41bf5be2/cryptography-46.0.3-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:50fc3343ac490c6b08c0cf0d704e881d0d660be923fd3076db3e932007e726e3", size = 4466074 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/68/b4f4a10928e26c941b1b6a179143af9f4d27d88fe84a6a3c53592d2e76bf/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22d7e97932f511d6b0b04f2bfd818d73dcd5928db509460aaf48384778eb6d20", size = 4420569 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a3/49/3746dab4c0d1979888f125226357d3262a6dd40e114ac29e3d2abdf1ec55/cryptography-46.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d55f3dffadd674514ad19451161118fd010988540cee43d8bc20675e775925de", size = 4681941 },
|
||||
@@ -333,10 +323,8 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/26/42/fa8389d4478368743e24e61eea78846a0006caffaf72ea24a15159215a14/cryptography-46.0.3-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:15ab9b093e8f09daab0f2159bb7e47532596075139dd74365da52ecc9cb46c5d", size = 4440029 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/eb/f483db0ec5ac040824f269e93dd2bd8a21ecd1027e77ad7bdf6914f2fd80/cryptography-46.0.3-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:46acf53b40ea38f9c6c229599a4a13f0d46a6c3fa9ef19fc1a124d62e338dfa0", size = 4297222 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/cf/da9502c4e1912cb1da3807ea3618a6829bee8207456fbbeebc361ec38ba3/cryptography-46.0.3-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10ca84c4668d066a9878890047f03546f3ae0a6b8b39b697457b7757aaf18dbc", size = 4012280 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6b/8f/9adb86b93330e0df8b3dcf03eae67c33ba89958fc2e03862ef1ac2b42465/cryptography-46.0.3-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:36e627112085bb3b81b19fed209c05ce2a52ee8b15d161b7c643a7d5a88491f3", size = 4978958 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/a0/5fa77988289c34bdb9f913f5606ecc9ada1adb5ae870bd0d1054a7021cc4/cryptography-46.0.3-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1000713389b75c449a6e979ffc7dcc8ac90b437048766cef052d4d30b8220971", size = 4473714 },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/e5/fc82d72a58d41c393697aa18c9abe5ae1214ff6f2a5c18ac470f92777895/cryptography-46.0.3-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:b02cf04496f6576afffef5ddd04a0cb7d49cf6be16a9059d793a30b035f6b6ac", size = 4296970 },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/06/5663ed35438d0b09056973994f1aec467492b33bd31da36e468b01ec1097/cryptography-46.0.3-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:71e842ec9bc7abf543b47cf86b9a743baa95f4677d22baa4c7d5c69e49e9bc04", size = 4940236 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/59/873633f3f2dcd8a053b8dd1d38f783043b5fce589c0f6988bf55ef57e43e/cryptography-46.0.3-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:402b58fc32614f00980b66d6e56a5b4118e6cb362ae8f3fda141ba4689bd4506", size = 4472642 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/39/8e71f3930e40f6877737d6f69248cf74d4e34b886a3967d32f919cc50d3b/cryptography-46.0.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ef639cb3372f69ec44915fafcd6698b6cc78fbe0c2ea41be867f6ed612811963", size = 4423126 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/c7/f65027c2810e14c3e7268353b1681932b87e5a48e65505d8cc17c99e36ae/cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4", size = 4686573 },
|
||||
@@ -415,7 +403,7 @@ requires-dist = [
|
||||
{ name = "mflux", specifier = "==0.15.4" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.4" },
|
||||
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.4" },
|
||||
{ name = "mlx-lm", git = "https://github.com/ml-explore/mlx-lm?branch=main" },
|
||||
{ name = "mlx-lm", specifier = "==0.30.5" },
|
||||
{ name = "openai-harmony", specifier = ">=0.0.8" },
|
||||
{ name = "pillow", specifier = ">=11.0,<12.0" },
|
||||
{ name = "psutil", specifier = ">=7.0.0" },
|
||||
@@ -1073,7 +1061,7 @@ wheels = [
|
||||
[[package]]
|
||||
name = "mlx-lm"
|
||||
version = "0.30.5"
|
||||
source = { git = "https://github.com/ml-explore/mlx-lm?branch=main#96699e6dadb13b82b28285bb131a0741997d19ae" }
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "mlx", marker = "sys_platform == 'darwin'" },
|
||||
@@ -1083,6 +1071,10 @@ dependencies = [
|
||||
{ name = "sentencepiece", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0b/90/4469d9f75f196e6255f59a89441abe0079925d30a001462e1c1c4bc4e6a1/mlx_lm-0.30.5.tar.gz", hash = "sha256:9e6cb258c65b766c6af25cb90958aef40acab67139f05839eef19864cb3154f6", size = 262367 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/89/ba/66db6e1e5f1ef506655b562932f6bd8f72600116d5f31f92d71c1f200b3f/mlx_lm-0.30.5-py3-none-any.whl", hash = "sha256:a80bc8e3efdebe81813b0f6eb403fb66a7a15071e256f4e7102ada986acb75bb", size = 366716 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mlx-metal"
|
||||
@@ -1503,6 +1495,9 @@ version = "11.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/d0d6dea55cd152ce3d6767bb38a8fc10e33796ba4ba210cbab9354b6d238/pillow-11.3.0.tar.gz", hash = "sha256:3828ee7586cd0b2091b6209e5ad53e20d0649bbe87164a459d0676e035e8f523", size = 47113069 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/93/0952f2ed8db3a5a4c7a11f91965d6184ebc8cd7cbb7941a260d5f018cd2d/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:1c627742b539bba4309df89171356fcb3cc5a9178355b2727d1b74a6cf155fbd", size = 2128328 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/e8/100c3d114b1a0bf4042f27e0f87d2f25e857e838034e98ca98fe7b8c0a9c/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:30b7c02f3899d10f13d7a48163c8969e4e653f8b43416d23d13d1bbfdc93b9f8", size = 2170652 },
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/86/3f758a28a6e381758545f7cdb4942e1cb79abd271bea932998fc0db93cb6/pillow-11.3.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:7859a4cc7c9295f5838015d8cc0a9c215b77e43d07a25e460f35cf516df8626f", size = 2227443 },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/f4/91d5b3ffa718df2f53b0dc109877993e511f4fd055d7e9508682e8aba092/pillow-11.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec1ee50470b0d050984394423d96325b744d55c701a439d2bd66089bff963d3c", size = 5278474 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/0e/37d7d3eca6c879fbd9dba21268427dffda1ab00d4eb05b32923d4fbe3b12/pillow-11.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7db51d222548ccfd274e4572fdbf3e810a5e66b00608862f947b163e613b67dd", size = 4686038 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/b0/3426e5c7f6565e752d81221af9d3676fdbb4f352317ceafd42899aaf5d8a/pillow-11.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2d6fcc902a24ac74495df63faad1884282239265c6839a0a6416d33faedfae7e", size = 5864407 },
|
||||
@@ -2278,7 +2273,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "5.0.0"
|
||||
version = "5.0.0rc3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
@@ -2287,14 +2282,15 @@ dependencies = [
|
||||
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "regex", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "safetensors", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tokenizers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "tqdm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
{ name = "typer-slim", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bc/79/845941711811789c85fb7e2599cea425a14a07eda40f50896b9d3fda7492/transformers-5.0.0.tar.gz", hash = "sha256:5f5634efed6cf76ad068cc5834c7adbc32db78bbd6211fb70df2325a9c37dec8", size = 8424830 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3f/a3/7c116a8d85f69ea7749cf4c2df79e64c35d028e5fc7ea0168f299d03b8c7/transformers-5.0.0rc3.tar.gz", hash = "sha256:a0315b92b7e087617ade42ec9e6e92ee7620541cc5d6a3331886c52cbe306f5c", size = 8388520 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/52/f3/ac976fa8e305c9e49772527e09fbdc27cc6831b8a2f6b6063406626be5dd/transformers-5.0.0-py3-none-any.whl", hash = "sha256:587086f249ce64c817213cf36afdb318d087f790723e9b3d4500b97832afd52d", size = 10142091 },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/f2/ae2b8968764253bdf38a48dee3c299b8d0bedf7c8ffbe3449fca9bd95338/transformers-5.0.0rc3-py3-none-any.whl", hash = "sha256:383fad27f4f73092d330e45fae384681e5c8521e1dc1cf6cb1a297780e68bf2d", size = 10107087 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user