Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-19 11:28:51 -05:00)

Compare commits: 10 commits, branch leo/fix-pi (alexcheema)
| Author | SHA1 | Date |
|---|---|---|
| | 0c5c87cd9d | |
| | 4c1af11f14 | |
| | f654b98d97 | |
| | 060dc8a3d8 | |
| | ea0588429b | |
| | 73b3f87e07 | |
| | 746589ba6b | |
| | f82f862fd7 | |
| | 7ff937d8a1 | |
| | d19bf02404 | |
@@ -27,6 +27,15 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni

- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.

## Dashboard

exo includes a built-in dashboard for managing your cluster and chatting with models.

<p align="center">
  <img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
</p>
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>

## Benchmarks

<details>
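For readers unfamiliar with the MLX distributed API referenced in the README bullet above, the communication primitive exo builds on looks roughly like this. This is a minimal sketch based on MLX's documented `mx.distributed` module, not exo code; exo's own group setup is handled by the instance types that appear later in this diff.

```python
# Minimal sketch of MLX distributed communication (illustrative, not exo code).
# Assumes MLX's documented mx.distributed API; run under `mlx.launch` or an MPI launcher.
import mlx.core as mx

group = mx.distributed.init()       # join the distributed group
x = mx.ones((4,))                   # each rank holds a local array
total = mx.distributed.all_sum(x)   # all-reduce: sums the array across all ranks
mx.eval(total)
print(f"rank {group.rank()} of {group.size()}: {total}")
```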
@@ -490,17 +490,17 @@ def main() -> int:
                logger.debug(f" warmup {i + 1}/{args.warmup} done")

            for pp in pp_list:
                if (
                    pp * n_nodes > 2048
                    and "ring" in instance_meta.lower()
                    and "tensor" in sharding.lower()
                ):
                    model_card = MODEL_CARDS[short_id]
                    if model_card.metadata.storage_size > Memory.from_gb(10):
                        logger.info(
                            f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
                        )
                        continue
                # if (
                #     pp * n_nodes > 2048
                #     and "ring" in instance_meta.lower()
                #     and "tensor" in sharding.lower()
                # ):
                #     model_card = MODEL_CARDS[short_id]
                #     if model_card.metadata.storage_size > Memory.from_gb(10):
                #         logger.info(
                #             f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
                #         )
                #         continue
                for tg in tg_list:
                    runs: list[dict[str, Any]] = []
                    for r in range(args.repeat):
@@ -69,8 +69,6 @@ export interface Instance {
    runnerToShard?: Record<string, unknown>;
    nodeToRunner?: Record<string, string>;
  };
  draftModel?: string;
  numDraftTokens?: number;
}

interface RawNodeProfile {
@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);

// Instance launch state
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -59,7 +59,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
  instanceType: InstanceMeta;
  minNodes: number;
}

function saveLaunchDefaults(): void {
  const defaults: LaunchDefaults = {
    modelId: selectedPreviewModelId(),
@@ -88,16 +88,16 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
  const defaults = loadLaunchDefaults();
  if (!defaults) return;

  // Apply sharding and instance type unconditionally
  selectedSharding = defaults.sharding;
  selectedInstanceType = defaults.instanceType;

  // Apply minNodes if valid (between 1 and maxNodes)
  if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
    selectedMinNodes = defaults.minNodes;
  }

  // Only apply model if it exists in the available models
  if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
    selectPreviewModel(defaults.modelId);
@@ -109,19 +109,11 @@ const sidebarVisible = $derived(chatSidebarVisible());
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());

// Draft model edit modal state
let editingDraftInstanceId = $state<string | null>(null);
let editDraftModel = $state<string | null>(null);
let editNumDraftTokens = $state<number>(4);
let isDraftEditDropdownOpen = $state(false);
let draftEditDropdownSearch = $state('');
let isSavingDraftModel = $state(false);

// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');

// Slider dragging state
let isDraggingSlider = $state(false);
let sliderTrackElement: HTMLDivElement | null = $state(null);
@@ -370,36 +362,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
  if (!modelId || launchingModelId) return;

  launchingModelId = modelId;

  try {
    // Use the specific preview if provided, otherwise fall back to filtered preview
    const preview = specificPreview ?? filteredPreview();

    let response: Response;

    // Use /place_instance endpoint - it handles placement and creation in one step
    const placePayload = {
      model_id: modelId,
      sharding: preview?.sharding ?? selectedSharding,
      instance_meta: preview?.instance_meta ?? selectedInstanceType,
      min_nodes: selectedMinNodes,
    };

    response = await fetch('/place_instance', {

    let instanceData: unknown;

    if (preview?.instance) {
      // Use the instance from the preview
      instanceData = preview.instance;
    } else {
      // Fallback: GET placement from API
      const placementResponse = await fetch(
        `/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
      );

      if (!placementResponse.ok) {
        const errorText = await placementResponse.text();
        console.error('Failed to get placement:', errorText);
        return;
      }

      instanceData = await placementResponse.json();
    }

    // POST the instance to create it
    const response = await fetch('/instance', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(placePayload)
      body: JSON.stringify({ instance: instanceData })
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error('Failed to launch instance:', errorText);
    } else {
      // Always auto-select the newly launched model so the user chats to what they just launched
      setSelectedChatModel(modelId);

      // Scroll to the bottom of instances container to show the new instance
      // Use multiple attempts to ensure DOM has updated with the new instance
      const scrollToBottom = () => {
@@ -794,52 +797,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  }
}

// Open draft model edit modal for an instance
function openDraftModelEdit(instanceId: string, currentDraftModel: string | null, currentNumTokens: number | null) {
  editingDraftInstanceId = instanceId;
  editDraftModel = currentDraftModel;
  editNumDraftTokens = currentNumTokens ?? 4;
  isDraftEditDropdownOpen = false;
  draftEditDropdownSearch = '';
}

// Close draft model edit modal
function closeDraftModelEdit() {
  editingDraftInstanceId = null;
  editDraftModel = null;
  editNumDraftTokens = 4;
  isDraftEditDropdownOpen = false;
  draftEditDropdownSearch = '';
}

// Save draft model settings for an instance
async function saveDraftModel() {
  if (!editingDraftInstanceId || isSavingDraftModel) return;

  isSavingDraftModel = true;
  try {
    const response = await fetch(`/instance/${editingDraftInstanceId}/draft_model`, {
      method: 'PUT',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        draft_model: editDraftModel,
        num_draft_tokens: editNumDraftTokens,
      })
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error('Failed to set draft model:', errorText);
    } else {
      closeDraftModelEdit();
    }
  } catch (error) {
    console.error('Error setting draft model:', error);
  } finally {
    isSavingDraftModel = false;
  }
}

// Helper to unwrap tagged unions like { MlxRingInstance: {...} }
function getTagged(obj: unknown): [string | null, unknown] {
  if (!obj || typeof obj !== 'object') return [null, null];
@@ -859,34 +816,30 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}

// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
  instanceType: string;
  sharding: string;
function getInstanceInfo(instanceWrapped: unknown): {
  instanceType: string;
  sharding: string;
  nodeNames: string[];
  nodeIds: string[];
  nodeCount: number;
  draftModel: string | null;
  numDraftTokens: number | null;
} {
  const [instanceTag, instance] = getTagged(instanceWrapped);
  if (!instance || typeof instance !== 'object') {
    return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
    return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
  }

  // Instance type from tag
  let instanceType = 'Unknown';
  if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
  else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';

  const inst = instance as {
    shardAssignments?: {
      nodeToRunner?: Record<string, string>;

  const inst = instance as {
    shardAssignments?: {
      nodeToRunner?: Record<string, string>;
      runnerToShard?: Record<string, unknown>;
    };
    draftModel?: string;
    numDraftTokens?: number;
  }
  };

  // Sharding strategy from first shard
  let sharding = 'Unknown';
  const runnerToShard = inst.shardAssignments?.runnerToShard || {};
@@ -897,7 +850,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
    else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
  }

  // Node names from topology
  const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
  const nodeIds = Object.keys(nodeToRunner);
@@ -905,12 +858,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    const node = data?.nodes?.[nodeId];
    return node?.friendly_name || nodeId.slice(0, 8);
  });

  // Draft model for speculative decoding
  const draftModel = inst.draftModel ?? null;
  const numDraftTokens = inst.numDraftTokens ?? null;

  return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };

  return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
}

function formatLastUpdate(): string {
@@ -1386,31 +1335,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  <div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
  <span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<div class="flex items-center gap-2">
  <!-- Draft Model Button -->
  <button
    onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
    class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
    title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
  >
    <svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
      <path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
    </svg>
  </button>
  <button
    onclick={() => deleteInstance(id)}
    class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
  >
    DELETE
  </button>
</div>
<button
  onclick={() => deleteInstance(id)}
  class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
  DELETE
</button>
</div>
<div class="pl-2">
  <div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
  <div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
  {#if instanceInfo.draftModel}
    <div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
  {/if}
  {#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
    <a
      class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -1745,7 +1679,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
      </div>
    </div>
  </div>

  <!-- Selected Model Preview -->
  <div class="space-y-3">
    {#if models.length === 0}
@@ -1904,31 +1838,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  <div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
  <span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<div class="flex items-center gap-2">
  <!-- Draft Model Button -->
  <button
    onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
    class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
    title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
  >
    <svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
      <path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
    </svg>
  </button>
  <button
    onclick={() => deleteInstance(id)}
    class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
  >
    DELETE
  </button>
</div>
<button
  onclick={() => deleteInstance(id)}
  class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
  DELETE
</button>
</div>
<div class="pl-2">
  <div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
  <div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
  {#if instanceInfo.draftModel}
    <div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
  {/if}
  {#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
    <a
      class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -2059,120 +1978,4 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  {/if}
</main>

<!-- Draft Model Edit Modal -->
{#if editingDraftInstanceId}
  <!-- svelte-ignore a11y_no_static_element_interactions -->
  <div
    class="fixed inset-0 z-50 flex items-center justify-center bg-black/70 backdrop-blur-sm"
    onclick={closeDraftModelEdit}
    onkeydown={(e) => e.key === 'Escape' && closeDraftModelEdit()}
  >
    <!-- svelte-ignore a11y_click_events_have_key_events -->
    <div
      class="bg-exo-dark-gray border border-exo-medium-gray/50 rounded-lg shadow-2xl p-6 w-full max-w-md mx-4"
      onclick={(e) => e.stopPropagation()}
    >
      <div class="flex items-center justify-between mb-4">
        <h3 class="text-lg font-mono text-exo-yellow tracking-wide">Speculative Decoding</h3>
        <button
          onclick={closeDraftModelEdit}
          class="text-white/60 hover:text-white transition-colors cursor-pointer"
          aria-label="Close"
        >
          <svg class="w-5 h-5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
            <path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" />
          </svg>
        </button>
      </div>

      <p class="text-white/60 text-sm font-mono mb-4">
        Configure a draft model for faster generation. The draft model proposes tokens that the main model verifies.
      </p>

      <!-- Draft Model Dropdown -->
      <div class="mb-4">
        <div class="text-xs text-white/70 font-mono mb-2">Draft Model:</div>
        <div class="relative">
          <button
            onclick={() => { isDraftEditDropdownOpen = !isDraftEditDropdownOpen; draftEditDropdownSearch = ''; }}
            class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {editDraftModel ? 'bg-transparent text-cyan-400 border-cyan-500/50' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-cyan-500/50'}"
          >
            <span class="truncate">{editDraftModel ? editDraftModel.split('/').pop() : 'None'}</span>
            <svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftEditDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
              <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
            </svg>
          </button>
          {#if isDraftEditDropdownOpen}
            <div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
              <div class="p-2 border-b border-exo-medium-gray/30">
                <input
                  type="text"
                  bind:value={draftEditDropdownSearch}
                  placeholder="Search models..."
                  class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-cyan-500/50"
                />
              </div>
              <div class="overflow-y-auto max-h-36">
                <!-- None option -->
                <button
                  onclick={() => { editDraftModel = null; isDraftEditDropdownOpen = false; }}
                  class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {editDraftModel === null ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
                >
                  <span>None (Disable)</span>
                </button>
                {#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftEditDropdownSearch.toLowerCase())) as model}
                  {@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
                  {@const modelHfId = model.hugging_face_id ?? model.id}
                  <button
                    onclick={() => { editDraftModel = modelHfId; isDraftEditDropdownOpen = false; }}
                    class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {editDraftModel === modelHfId ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
                  >
                    <span class="truncate">{model.name || model.id}</span>
                    <span class="flex-shrink-0 text-xs text-white/50">
                      {sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
                    </span>
                  </button>
                {:else}
                  <div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
                {/each}
              </div>
            </div>
          {/if}
        </div>
      </div>

      <!-- Draft Tokens -->
      {#if editDraftModel}
        <div class="mb-6">
          <div class="text-xs text-white/70 font-mono mb-2">Draft Tokens per Iteration:</div>
          <div class="flex items-center gap-2">
            {#each [2, 3, 4, 5, 6] as n}
              <button
                onclick={() => editNumDraftTokens = n}
                class="w-8 h-8 text-sm font-mono rounded transition-all {editNumDraftTokens === n ? 'bg-cyan-500/20 text-cyan-400 border border-cyan-500/50' : 'text-white/50 hover:text-white/80 border border-exo-medium-gray/50 hover:border-white/30'} cursor-pointer"
              >{n}</button>
            {/each}
          </div>
        </div>
      {/if}

      <!-- Action Buttons -->
      <div class="flex items-center justify-end gap-3">
        <button
          onclick={closeDraftModelEdit}
          class="px-4 py-2 text-sm font-mono text-white/70 hover:text-white transition-colors cursor-pointer"
        >
          Cancel
        </button>
        <button
          onclick={saveDraftModel}
          disabled={isSavingDraftModel}
          class="px-4 py-2 text-sm font-mono border border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500 transition-all disabled:opacity-50 disabled:cursor-not-allowed cursor-pointer"
        >
          {isSavingDraftModel ? 'Saving...' : 'Save'}
        </button>
      </div>
    </div>
  </div>
{/if}
</div>
BIN docs/imgs/dashboard-cluster-view.png (new file, 187 KiB; binary file not shown)
@@ -39,8 +39,6 @@ from exo.shared.types.api import (
    PlaceInstanceParams,
    PlacementPreview,
    PlacementPreviewResponse,
    SetDraftModelParams,
    SetDraftModelResponse,
    StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
@@ -51,7 +49,6 @@ from exo.shared.types.commands import (
    DeleteInstance,
    ForwarderCommand,
    PlaceInstance,
    SetInstanceDraftModel,
    TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -158,18 +155,19 @@ class API:
        self.paused_ev = anyio.Event()

    def _setup_exception_handlers(self) -> None:
        @self.app.exception_handler(HTTPException)
        async def http_exception_handler(  # pyright: ignore[reportUnusedFunction]
            _: Request, exc: HTTPException
        ) -> JSONResponse:
            err = ErrorResponse(
                error=ErrorInfo(
                    message=exc.detail,
                    type=HTTPStatus(exc.status_code).phrase,
                    code=exc.status_code,
                )
        self.app.exception_handler(HTTPException)(self.http_exception_handler)

    async def http_exception_handler(
        self, _: Request, exc: HTTPException
    ) -> JSONResponse:
        err = ErrorResponse(
            error=ErrorInfo(
                message=exc.detail,
                type=HTTPStatus(exc.status_code).phrase,
                code=exc.status_code,
            )
            return JSONResponse(err.model_dump(), status_code=exc.status_code)
        )
        return JSONResponse(err.model_dump(), status_code=exc.status_code)

    def _setup_cors(self) -> None:
        self.app.add_middleware(
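The hunk above swaps a nested closure for a bound method registered through the same FastAPI decorator call. A minimal self-contained sketch of that registration pattern, with a simplified error body standing in for exo's ErrorResponse/ErrorInfo models:

```python
# Sketch of registering a bound method as a FastAPI exception handler (illustrative).
from http import HTTPStatus

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse


class API:
    def __init__(self) -> None:
        self.app = FastAPI()
        # app.exception_handler(...) returns a decorator; calling it with a bound
        # method registers that method as the handler, as in the diff above.
        self.app.exception_handler(HTTPException)(self.http_exception_handler)

    async def http_exception_handler(self, _: Request, exc: HTTPException) -> JSONResponse:
        err = {
            "error": {
                "message": exc.detail,
                "type": HTTPStatus(exc.status_code).phrase,
                "code": exc.status_code,
            }
        }
        return JSONResponse(err, status_code=exc.status_code)
```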
@@ -188,7 +186,6 @@ class API:
        self.app.get("/instance/previews")(self.get_placement_previews)
        self.app.get("/instance/{instance_id}")(self.get_instance)
        self.app.delete("/instance/{instance_id}")(self.delete_instance)
        self.app.put("/instance/{instance_id}/draft_model")(self.set_draft_model)
        self.app.get("/models")(self.get_models)
        self.app.get("/v1/models")(self.get_models)
        self.app.post("/v1/chat/completions", response_model=None)(
@@ -204,8 +201,6 @@ class API:
            sharding=payload.sharding,
            instance_meta=payload.instance_meta,
            min_nodes=payload.min_nodes,
            draft_model=payload.draft_model,
            num_draft_tokens=payload.num_draft_tokens,
        )
        await self._send(command)
@@ -402,24 +397,6 @@ class API:
            instance_id=instance_id,
        )

    async def set_draft_model(
        self, instance_id: InstanceId, payload: SetDraftModelParams
    ) -> SetDraftModelResponse:
        if instance_id not in self.state.instances:
            raise HTTPException(status_code=404, detail="Instance not found")

        command = SetInstanceDraftModel(
            instance_id=instance_id,
            draft_model=payload.draft_model,
            num_draft_tokens=payload.num_draft_tokens,
        )
        await self._send(command)
        return SetDraftModelResponse(
            message="Command received.",
            command_id=command.command_id,
            instance_id=instance_id,
        )

    async def _chat_chunk_stream(
        self, command_id: CommandId
    ) -> AsyncGenerator[TokenChunk, None]:
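For reference, the `set_draft_model` handler touched in this hunk is bound to `PUT /instance/{instance_id}/draft_model` and takes a `SetDraftModelParams` body (`draft_model`, `num_draft_tokens`). A hypothetical client call is sketched below with the `requests` library; the host, port, instance id, and model id are placeholders, not values from the repo.

```python
import requests  # assumption: any HTTP client would do

BASE_URL = "http://localhost:8000"   # placeholder API address
instance_id = "<instance-id>"        # placeholder

resp = requests.put(
    f"{BASE_URL}/instance/{instance_id}/draft_model",
    json={
        "draft_model": "some-org/tiny-draft-model",  # or None to disable speculative decoding
        "num_draft_tokens": 4,
    },
)
resp.raise_for_status()
# Expected SetDraftModelResponse fields: message, command_id, instance_id
print(resp.json())
```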
@@ -18,7 +18,6 @@ from exo.shared.types.commands import (
    ForwarderCommand,
    PlaceInstance,
    RequestEventLog,
    SetInstanceDraftModel,
    TaskFinished,
    TestCommand,
)

@@ -28,7 +27,6 @@ from exo.shared.types.events import (
    ForwarderEvent,
    IndexedEvent,
    InstanceDeleted,
    InstanceDraftModelUpdated,
    NodeTimedOut,
    TaskCreated,
    TaskDeleted,

@@ -175,14 +173,6 @@ class Master:
                    self.state.instances, placement
                )
                generated_events.extend(transition_events)
            case SetInstanceDraftModel():
                generated_events.append(
                    InstanceDraftModelUpdated(
                        instance_id=command.instance_id,
                        draft_model=command.draft_model,
                        num_draft_tokens=command.num_draft_tokens,
                    )
                )
            case TaskFinished():
                generated_events.append(
                    TaskDeleted(
@@ -3,6 +3,8 @@ from collections.abc import Mapping
from copy import deepcopy
from typing import Sequence

from loguru import logger

from exo.master.placement_utils import (
    filter_cycles_by_memory,
    get_mlx_ibv_devices_matrix,

@@ -53,6 +55,7 @@ def place_instance(
) -> dict[InstanceId, Instance]:
    all_nodes = list(topology.list_nodes())

    logger.info("finding cycles:")
    cycles = topology.get_cycles()
    singleton_cycles = [[node] for node in all_nodes]
    candidate_cycles = list(

@@ -125,6 +128,10 @@ def place_instance(
    target_instances = dict(deepcopy(current_instances))

    if len(selected_cycle) == 1:
        logger.warning(
            "You have likely selected ibv for a single node instance; falling back to MlxRing"
        )
        command.instance_meta = InstanceMeta.MlxRing

    # TODO: Single node instances

@@ -144,8 +151,6 @@ def place_instance(
                shard_assignments=shard_assignments,
                ibv_devices=mlx_ibv_devices,
                jaccl_coordinators=mlx_jaccl_coordinators,
                draft_model=command.draft_model,
                num_draft_tokens=command.num_draft_tokens,
            )
        case InstanceMeta.MlxRing:
            ephemeral_port = random_ephemeral_port()

@@ -159,8 +164,6 @@ def place_instance(
                shard_assignments=shard_assignments,
                hosts_by_node=hosts_by_node,
                ephemeral_port=ephemeral_port,
                draft_model=command.draft_model,
                num_draft_tokens=command.num_draft_tokens,
            )

    return target_instances
@@ -49,33 +49,83 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
    return [cycle for cycle in cycles if len(cycle) == min_nodes]


def allocate_layers_proportionally(
    total_layers: int,
    memory_fractions: list[float],
) -> list[int]:
    n = len(memory_fractions)
    if n == 0:
        raise ValueError("Cannot allocate layers to an empty node list")
    if total_layers < n:
        raise ValueError(
            f"Cannot distribute {total_layers} layers across {n} nodes "
            "(need at least 1 layer per node)"
        )

    # Largest remainder: floor each, then distribute remainder by fractional part
    raw = [f * total_layers for f in memory_fractions]
    result = [int(r) for r in raw]
    by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
    for i in range(total_layers - sum(result)):
        result[by_remainder[i]] += 1

    # Ensure minimum 1 per node by taking from the largest
    for i in range(n):
        if result[i] == 0:
            max_idx = max(range(n), key=lambda j: result[j])
            assert result[max_idx] > 1
            result[max_idx] -= 1
            result[i] = 1

    return result


def get_shard_assignments_for_pipeline_parallel(
    model_meta: ModelMetadata,
    selected_cycle: list[NodeWithProfile],
):
    if not selected_cycle:
        raise ValueError("Cannot create shard assignments for empty node cycle")

    cycle_memory = sum(
        (node.node_profile.memory.ram_available for node in selected_cycle),
        start=Memory(),
    )

    if cycle_memory.in_bytes == 0:
        raise ValueError("Cannot create shard assignments: total available memory is 0")

    total_layers = model_meta.n_layers
    world_size = len(selected_cycle)
    runner_to_shard: dict[RunnerId, ShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

    layers_assigned = 0
    for i, node in enumerate(selected_cycle):
        if i == len(selected_cycle) - 1:
            node_layers = total_layers - layers_assigned
        else:
            node_layers = round(
                total_layers
                * (
                    node.node_profile.memory.ram_available.in_bytes
                    / cycle_memory.in_bytes
                )
            )
            node_layers = max(1, node_layers)
    layer_allocations = allocate_layers_proportionally(
        total_layers=total_layers,
        memory_fractions=[
            node.node_profile.memory.ram_available.in_bytes / cycle_memory.in_bytes
            for node in selected_cycle
        ],
    )

    # Validate each node has sufficient memory for its assigned layers
    memory_per_layer = model_meta.storage_size.in_bytes / total_layers
    for i, (node, node_layers) in enumerate(
        zip(selected_cycle, layer_allocations, strict=True)
    ):
        required_memory = node_layers * memory_per_layer
        available_memory = node.node_profile.memory.ram_available.in_bytes
        if required_memory > available_memory:
            raise ValueError(
                f"Node {i} ({node.node_id}) has insufficient memory: "
                f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
                f"but only has {available_memory / (1024**3):.2f} GB available"
            )

    layers_assigned = 0
    for i, (node, node_layers) in enumerate(
        zip(selected_cycle, layer_allocations, strict=True)
    ):
        runner_id = RunnerId()

        shard = PipelineShardMetadata(
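As context for `allocate_layers_proportionally` in the hunk above: it is a largest-remainder apportionment with a one-layer-per-node floor, so every node in the selected cycle receives at least one transformer layer while the split tracks available memory. A standalone sketch of the same scheme, simplified without the repo's Memory/NodeInfo types; the expected outputs are taken from this diff's own tests.

```python
# Minimal sketch of largest-remainder layer allocation (illustrative, not the repo code).
def allocate(total_layers: int, fractions: list[float]) -> list[int]:
    raw = [f * total_layers for f in fractions]
    out = [int(r) for r in raw]                        # floor each proportional share
    order = sorted(range(len(out)), key=lambda i: raw[i] - out[i], reverse=True)
    for i in range(total_layers - sum(out)):           # hand out leftover layers,
        out[order[i]] += 1                             # largest fractional part first
    for i, v in enumerate(out):                        # enforce at least 1 layer per node
        if v == 0:
            biggest = max(range(len(out)), key=lambda j: out[j])
            out[biggest] -= 1
            out[i] = 1
    return out


assert allocate(12, [0.25, 0.25, 0.50]) == [3, 3, 6]
assert allocate(20, [0.975, 0.0125, 0.0125]) == [18, 1, 1]
```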
@@ -70,7 +70,7 @@ def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
    [
        ((500, 500, 1000), 12, (3, 3, 6)),
        ((500, 500, 500), 12, (4, 4, 4)),
        ((312, 518, 1024), 12, (2, 3, 7)),
        ((312, 468, 1092), 12, (2, 3, 7)),
    ],
)
def test_get_instance_placements_create_instance(
@@ -3,6 +3,7 @@ from typing import Callable
import pytest

from exo.master.placement_utils import (
    allocate_layers_proportionally,
    filter_cycles_by_memory,
    get_hosts_from_subgraph,
    get_mlx_jaccl_coordinators,

@@ -165,6 +166,9 @@ def test_get_smallest_cycles(
        ((500, 500, 1000), 12, (3, 3, 6)),
        ((500, 500, 500), 12, (4, 4, 4)),
        ((312, 518, 1024), 12, (2, 3, 7)),
        # Edge case: one node has ~90% of memory - should not over-allocate.
        # Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
        ((900, 50, 50), 20, (18, 1, 1)),
    ],
)
def test_get_shard_assignments(

@@ -397,3 +401,96 @@ def test_get_mlx_jaccl_coordinators(
    assert coordinators[node_c_id] == (
        f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
    ), "node_c should use the IP from conn_c_a"


class TestAllocateLayersProportionally:
    def test_empty_node_list_raises(self):
        with pytest.raises(ValueError, match="empty node list"):
            allocate_layers_proportionally(total_layers=10, memory_fractions=[])

    def test_zero_layers_raises(self):
        with pytest.raises(ValueError, match="need at least 1 layer per node"):
            allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])

    def test_negative_layers_raises(self):
        with pytest.raises(ValueError, match="need at least 1 layer per node"):
            allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])

    def test_fewer_layers_than_nodes_raises(self):
        with pytest.raises(ValueError, match="need at least 1 layer per node"):
            allocate_layers_proportionally(
                total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
            )

    def test_equal_distribution(self):
        result = allocate_layers_proportionally(
            total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
        )
        assert result == [3, 3, 3, 3]
        assert sum(result) == 12

    def test_proportional_distribution(self):
        result = allocate_layers_proportionally(
            total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
        )
        assert result == [3, 3, 6]
        assert sum(result) == 12

    def test_extreme_imbalance_ensures_minimum(self):
        result = allocate_layers_proportionally(
            total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
        )
        assert all(layers >= 1 for layers in result)
        assert sum(result) == 20
        # Small nodes get minimum 1 layer
        assert result == [18, 1, 1]

    def test_single_node_gets_all_layers(self):
        result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
        assert result == [10]

    def test_minimum_viable_allocation(self):
        result = allocate_layers_proportionally(
            total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
        )
        assert result == [1, 1, 1]
        assert sum(result) == 3


def test_get_shard_assignments_insufficient_memory_raises(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    """Test that ValueError is raised when a node has insufficient memory for its layers."""
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()

    # Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
    node_a = create_node(900 * 1024, node_a_id)
    node_b = create_node(50 * 1024, node_b_id)
    node_c = create_node(10 * 1024, node_c_id)  # Insufficient memory

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_a_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))

    model_meta = ModelMetadata(
        model_id=ModelId("test-model"),
        pretty_name="Test Model",
        n_layers=20,
        storage_size=Memory.from_kb(1000),
        hidden_size=1000,
        supports_tensor=True,
    )
    cycles = topology.get_cycles()
    selected_cycle = cycles[0]

    with pytest.raises(ValueError, match="insufficient memory"):
        get_shard_assignments(model_meta, selected_cycle, Sharding.Pipeline)
@@ -11,7 +11,6 @@ from exo.shared.types.events import (
    IndexedEvent,
    InstanceCreated,
    InstanceDeleted,
    InstanceDraftModelUpdated,
    NodeCreated,
    NodeDownloadProgress,
    NodeMemoryMeasured,

@@ -48,8 +47,6 @@ def event_apply(event: Event, state: State) -> State:
            return apply_instance_created(event, state)
        case InstanceDeleted():
            return apply_instance_deleted(event, state)
        case InstanceDraftModelUpdated():
            return apply_instance_draft_model_updated(event, state)
        case NodeCreated():
            return apply_topology_node_created(event, state)
        case NodeTimedOut():

@@ -172,25 +169,6 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
    return state.model_copy(update={"instances": new_instances})


def apply_instance_draft_model_updated(
    event: InstanceDraftModelUpdated, state: State
) -> State:
    if event.instance_id not in state.instances:
        return state
    instance = state.instances[event.instance_id]
    updated_instance = instance.model_copy(
        update={
            "draft_model": event.draft_model,
            "num_draft_tokens": event.num_draft_tokens,
        }
    )
    new_instances: Mapping[InstanceId, Instance] = {
        **state.instances,
        event.instance_id: updated_instance,
    }
    return state.model_copy(update={"instances": new_instances})


def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
    new_runners: Mapping[RunnerId, RunnerStatus] = {
        **state.runners,
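The `apply_instance_draft_model_updated` reducer in this hunk follows the project's event-apply pattern: the `State` is never mutated, and an updated copy is built with pydantic's `model_copy(update=...)`. A minimal self-contained sketch of that pattern with stand-in models, not the real exo State/Instance types:

```python
# Sketch of the immutable event-apply pattern (stand-in models, illustrative only).
from pydantic import BaseModel


class Instance(BaseModel):             # stand-in for exo's Instance
    draft_model: str | None = None
    num_draft_tokens: int = 4


class State(BaseModel):                # stand-in for exo's State
    instances: dict[str, Instance] = {}


def apply_draft_model_updated(state: State, instance_id: str,
                              draft_model: str | None, num_draft_tokens: int) -> State:
    if instance_id not in state.instances:
        return state                   # unknown instance: event is a no-op
    updated = state.instances[instance_id].model_copy(
        update={"draft_model": draft_model, "num_draft_tokens": num_draft_tokens}
    )
    return state.model_copy(update={"instances": {**state.instances, instance_id: updated}})


s0 = State(instances={"a": Instance()})
s1 = apply_draft_model_updated(s0, "a", "tiny-draft", 4)
assert s0.instances["a"].draft_model is None        # original state unchanged
assert s1.instances["a"].draft_model == "tiny-draft"
```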
@@ -161,8 +161,6 @@ class ChatCompletionTaskParams(BaseModel):
    tool_choice: str | dict[str, Any] | None = None
    parallel_tool_calls: bool | None = None
    user: str | None = None
    # Speculative decoding: tokens to draft per iteration (if instance has draft model)
    num_draft_tokens: int = 3


class BenchChatCompletionTaskParams(ChatCompletionTaskParams):

@@ -174,8 +172,6 @@ class PlaceInstanceParams(BaseModel):
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1
    draft_model: ModelId | None = None  # For speculative decoding
    num_draft_tokens: int = 4  # Tokens to draft per iteration

    @field_validator("sharding", "instance_meta", mode="plain")
    @classmethod

@@ -217,14 +213,3 @@ class DeleteInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    instance_id: InstanceId


class SetDraftModelParams(BaseModel):
    draft_model: ModelId | None = None  # None to disable speculative decoding
    num_draft_tokens: int = 4


class SetDraftModelResponse(BaseModel):
    message: str
    command_id: CommandId
    instance_id: InstanceId
@@ -2,7 +2,7 @@ from pydantic import Field

from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -25,8 +25,6 @@ class PlaceInstance(BaseCommand):
    sharding: Sharding
    instance_meta: InstanceMeta
    min_nodes: int
    draft_model: ModelId | None = None  # For speculative decoding
    num_draft_tokens: int = 4  # Tokens to draft per iteration


class CreateInstance(BaseCommand):

@@ -37,14 +35,6 @@ class DeleteInstance(BaseCommand):
    instance_id: InstanceId


class SetInstanceDraftModel(BaseCommand):
    """Set or update the draft model for an existing instance."""

    instance_id: InstanceId
    draft_model: ModelId | None  # None to disable speculative decoding
    num_draft_tokens: int = 4


class TaskFinished(BaseCommand):
    finished_command_id: CommandId

@@ -60,7 +50,6 @@ Command = (
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | SetInstanceDraftModel
    | TaskFinished
)
@@ -5,7 +5,6 @@ from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress

@@ -68,14 +67,6 @@ class InstanceDeleted(BaseEvent):
    instance_id: InstanceId


class InstanceDraftModelUpdated(BaseEvent):
    """Draft model updated on an existing instance."""

    instance_id: InstanceId
    draft_model: ModelId | None
    num_draft_tokens: int


class RunnerStatusUpdated(BaseEvent):
    runner_id: RunnerId
    runner_status: RunnerStatus

@@ -132,7 +123,6 @@ Event = (
    | TaskAcknowledged
    | InstanceCreated
    | InstanceDeleted
    | InstanceDraftModelUpdated
    | RunnerStatusUpdated
    | RunnerDeleted
    | NodeCreated
@@ -36,12 +36,6 @@ class DownloadModel(BaseTask):  # emitted by Worker
    shard_metadata: ShardMetadata


class DownloadDraftModel(BaseTask):  # emitted by Worker
    """Download a draft model for speculative decoding (rank 0 only)."""

    model_id: str  # HuggingFace model ID


class LoadModel(BaseTask):  # emitted by Worker
    pass

@@ -66,21 +60,12 @@ class Shutdown(BaseTask):  # emitted by Worker
    runner_id: RunnerId


class SetDraftModel(BaseTask):  # emitted by Worker
    """Load or clear a draft model on an already-running instance."""

    model_id: str | None  # HuggingFace model ID, or None to clear
    num_draft_tokens: int = 4


Task = (
    CreateRunner
    | DownloadModel
    | DownloadDraftModel
    | ConnectToGroup
    | LoadModel
    | StartWarmup
    | ChatCompletion
    | Shutdown
    | SetDraftModel
)
@@ -3,7 +3,6 @@ from enum import Enum
from pydantic import model_validator

from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -20,8 +19,6 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
    instance_id: InstanceId
    shard_assignments: ShardAssignments
    draft_model: ModelId | None = None  # For speculative decoding (rank 0 only)
    num_draft_tokens: int = 4  # Tokens to draft per iteration (when draft_model is set)

    def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
        return self.shard_assignments.runner_to_shard.get(runner_id, None)
@@ -41,14 +41,16 @@ class _LayerCallable(Protocol):
    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...


class CustomMlxLayer(nn.Module):
class CustomMlxModule(nn.Module):
    """Base class for replacing an MLX layer with a custom implementation."""

    def __init__(self, original_layer: _LayerCallable):
        super().__init__()
        # Set twice to avoid __setattr__ recursion
        object.__setattr__(self, "_original_layer", original_layer)
        self.original_layer: _LayerCallable = original_layer

    @property
    def original_layer(self) -> _LayerCallable:
        return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))

    # Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
    if not TYPE_CHECKING:

@@ -58,10 +60,10 @@ class CustomMlxLayer(nn.Module):
                return super().__getattr__(name)
            except AttributeError:
                original_layer = object.__getattribute__(self, "_original_layer")
                return object.__getattribute__(original_layer, name)
                return getattr(original_layer, name)


class PipelineFirstLayer(CustomMlxLayer):
class PipelineFirstLayer(CustomMlxModule):
    def __init__(
        self,
        original_layer: _LayerCallable,

@@ -78,7 +80,7 @@ class PipelineFirstLayer(CustomMlxLayer):
        return self.original_layer(x, *args, **kwargs)


class PipelineLastLayer(CustomMlxLayer):
class PipelineLastLayer(CustomMlxModule):
    def __init__(
        self,
        original_layer: _LayerCallable,

@@ -168,11 +170,21 @@ def pipeline_auto_parallel(
    inner_model_instance.layer_types = inner_model_instance.layer_types[  # type: ignore
        start_layer:end_layer
    ]
    inner_model_instance.swa_idx = inner_model_instance.layer_types.index(  # type: ignore
        "sliding_attention"
    # We can assume the model has at least one layer thanks to placement.
    # If a layer type doesn't exist, we can set it to 0.
    inner_model_instance.swa_idx = (
        0
        if "sliding_attention" not in inner_model_instance.layer_types  # type: ignore
        else inner_model_instance.layer_types.index(  # type: ignore
            "sliding_attention"
        )
    )
    inner_model_instance.ga_idx = inner_model_instance.layer_types.index(  # type: ignore
        "full_attention"
    inner_model_instance.ga_idx = (
        0
        if "full_attention" not in inner_model_instance.layer_types  # type: ignore
        else inner_model_instance.layer_types.index(  # type: ignore
            "full_attention"
        )
    )

    _set_layers(model, layers)

@@ -181,7 +193,32 @@ def pipeline_auto_parallel(
        "Expected a list of layers after auto-parallel initialisation"
    )

    return model
    return PipelineParallelModel(model, group)


class PipelineParallelModel(CustomMlxModule):
    def __init__(self, model: nn.Module, group: mx.distributed.Group):
        super().__init__(model)
        self.original_call_signature = signature(self.original_layer.__call__)
        self.group = group
        dict.__setitem__(self, "original_layer", model)

    def __call__(
        self,
        *args: object,
        **kwargs: object,
    ) -> mx.array:
        logits: mx.array = self.original_layer(*args, **kwargs)  # type: ignore
        cache = self.original_call_signature.bind_partial(
            *args, **kwargs
        ).arguments.get("cache", None)

        if cache is not None:
            for c in cache:  # type: ignore
                if hasattr(c, "state") and c.state is not None:  # type: ignore
                    c.state = mx.depends(c.state, logits)  # type: ignore

        return logits


def tensor_auto_parallel(

@@ -389,7 +426,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
        return model


class ShardedDeepseekV3MoE(CustomMlxLayer):
class ShardedDeepseekV3MoE(CustomMlxModule):
    def __init__(self, layer: _LayerCallable):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None

@@ -464,7 +501,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
        return model


class ShardedQwenMoE(CustomMlxLayer):
class ShardedQwenMoE(CustomMlxModule):
    def __init__(self, layer: _LayerCallable):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None

@@ -511,7 +548,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
        return model


class ShardedGptOssMoE(CustomMlxLayer):
class ShardedGptOssMoE(CustomMlxModule):
    def __init__(self, layer: nn.Module):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None
@@ -48,8 +48,6 @@ def maybe_quantize_kv_cache(
def warmup_inference(
    model: Model,
    tokenizer: TokenizerWrapper,
    draft_model: Model | None = None,
    num_draft_tokens: int = 4,
) -> int:
    content = "Prompt to warm up the inference engine. Repeat this."

@@ -68,30 +66,25 @@
    tokens_generated = 0

    cache = make_kv_cache(
        model=model,
    )

    # Use a default sampler for warmup
    sampler = make_sampler(temp=0.7)

    generate_kwargs: dict[str, object] = {
        "model": model,
        "tokenizer": tokenizer,
        "prompt": warmup_prompt,
        "max_tokens": 50,
        "sampler": sampler,
        "prefill_step_size": 2048,
        "kv_group_size": KV_GROUP_SIZE,
        "kv_bits": KV_BITS,
    }

    # Warm up with draft model if provided (speculative decoding path)
    if draft_model is not None:
        logger.info("Warming up with speculative decoding (draft model)")
        generate_kwargs["draft_model"] = draft_model
        generate_kwargs["num_draft_tokens"] = num_draft_tokens
    else:
        generate_kwargs["prompt_cache"] = make_kv_cache(model=model)

    logger.info("Generating warmup tokens")
    for _r in stream_generate(**generate_kwargs):  # type: ignore[arg-type]
    for _r in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=warmup_prompt,
        max_tokens=50,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        logger.info("Generated warmup token: " + str(_r.text))
        tokens_generated += 1

@@ -126,8 +119,6 @@
    model: Model,
    tokenizer: TokenizerWrapper,
    task: ChatCompletionTaskParams,
    draft_model: Model | None = None,
    num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()

@@ -144,6 +135,8 @@
        chat_task_data=task,
    )

    caches = make_kv_cache(model=model)

    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
    if is_bench:
        # Only sample length eos tokens

@@ -156,31 +149,19 @@
    )

    max_tokens = task.max_tokens or MAX_TOKENS

    # Build kwargs for stream_generate, conditionally adding draft model params
    generate_kwargs: dict[str, object] = {
        "model": model,
        "tokenizer": tokenizer,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "sampler": sampler,
        "logits_processors": logits_processors,
        "prefill_step_size": 2048,
        "kv_group_size": KV_GROUP_SIZE,
        "kv_bits": KV_BITS,
    }

    # Add speculative decoding parameters if draft model is provided
    # Note: When using draft_model, we let mlx_lm create its own trimmable cache
    # as speculative decoding requires cache trimming capabilities
    if draft_model is not None:
        generate_kwargs["draft_model"] = draft_model
        generate_kwargs["num_draft_tokens"] = num_draft_tokens
    else:
        # Only use custom cache for non-speculative generation
        generate_kwargs["prompt_cache"] = make_kv_cache(model=model)

    for out in stream_generate(**generate_kwargs):  # type: ignore[arg-type]
    for out in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
        prompt_cache=caches,
        # TODO: Dynamically change prefill step size to be the maximum possible without timing out.
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        logger.info(out.text)

    stats: GenerationStats | None = None
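The generation hunks above switch between two `stream_generate` call shapes: a plain call that passes an explicit `prompt_cache`, and a speculative-decoding call that passes `draft_model`/`num_draft_tokens` and lets mlx_lm manage its own trimmable cache. A condensed sketch of that branching, assuming mlx_lm's `stream_generate` accepts the keywords used in this diff; model/tokenizer/prompt setup and the KV constants are stand-ins.

```python
# Sketch of the two stream_generate call shapes used above (illustrative only).
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler

KV_GROUP_SIZE, KV_BITS = 64, None  # stand-ins for the module constants


def generate_text(model, tokenizer, prompt, make_kv_cache,
                  draft_model=None, num_draft_tokens=4, max_tokens=50):
    kwargs: dict[str, object] = {
        "max_tokens": max_tokens,
        "sampler": make_sampler(temp=0.7),
        "prefill_step_size": 2048,
        "kv_group_size": KV_GROUP_SIZE,
        "kv_bits": KV_BITS,
    }
    if draft_model is not None:
        # Speculative decoding: mlx_lm needs a cache it can trim, so pass no prompt_cache.
        kwargs["draft_model"] = draft_model
        kwargs["num_draft_tokens"] = num_draft_tokens
    else:
        kwargs["prompt_cache"] = make_kv_cache(model=model)

    for out in stream_generate(model, tokenizer, prompt, **kwargs):
        yield out.text
```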
@@ -258,27 +258,6 @@ def load_mlx_items(
    return cast(Model, model), tokenizer


def load_draft_model(model_id: str) -> nn.Module:
    """Load a draft model for speculative decoding (rank 0 only).

    Draft models are small models (typically 0.5B-2B parameters) used to
    generate candidate tokens quickly, which are then verified by the main
    model in a single forward pass.

    Assumes the model has already been downloaded by the worker.

    Args:
        model_id: HuggingFace model ID for the draft model

    Returns:
        The loaded draft model
    """
    model_path = build_model_path(model_id)
    draft_model, _ = load_model(model_path, strict=True)
    logger.info(f"Loaded draft model from {model_path}")
    return draft_model


def shard_and_load(
    shard_metadata: ShardMetadata,
    group: Group,
@@ -29,9 +29,7 @@ from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformance
from exo.shared.types.state import State
from exo.shared.types.tasks import (
    CreateRunner,
    DownloadDraftModel,
    DownloadModel,
    SetDraftModel,
    Shutdown,
    Task,
    TaskStatus,

@@ -50,7 +48,6 @@ from exo.utils.event_buffer import OrderedBuffer
from exo.worker.download.download_utils import (
    map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.impl_shard_downloader import build_full_shard
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor

@@ -205,10 +202,42 @@ class Worker:
                        )
                    )
            case DownloadModel(shard_metadata=shard):
                await self._handle_download(shard, task)
            case DownloadDraftModel(model_id=model_id):
                shard = await build_full_shard(model_id)
                await self._handle_download(shard, task)
                if shard.model_meta.model_id not in self.download_status:
                    progress = DownloadPending(
                        shard_metadata=shard, node_id=self.node_id
                    )
                    self.download_status[shard.model_meta.model_id] = progress
                    await self.event_sender.send(
                        NodeDownloadProgress(download_progress=progress)
                    )
                initial_progress = (
                    await self.shard_downloader.get_shard_download_status_for_shard(
                        shard
                    )
                )
                if initial_progress.status == "complete":
                    progress = DownloadCompleted(
                        shard_metadata=shard,
                        node_id=self.node_id,
                        total_bytes=initial_progress.total_bytes,
                    )
                    self.download_status[shard.model_meta.model_id] = progress
                    await self.event_sender.send(
                        NodeDownloadProgress(download_progress=progress)
                    )
                    await self.event_sender.send(
                        TaskStatusUpdated(
                            task_id=task.task_id,
                            task_status=TaskStatus.Complete,
                        )
                    )
                else:
                    await self.event_sender.send(
                        TaskStatusUpdated(
                            task_id=task.task_id, task_status=TaskStatus.Running
                        )
                    )
                    self._handle_shard_download_process(task, initial_progress)
            case Shutdown(runner_id=runner_id):
                try:
                    with fail_after(3):

@@ -219,25 +248,6 @@ class Worker:
                        task_id=task.task_id, task_status=TaskStatus.TimedOut
                    )
                )
            case SetDraftModel(
                model_id=draft_model_id, num_draft_tokens=num_tokens
            ):
                runner = self.runners[self._task_to_runner_id(task)]
                await runner.start_task(task)
                # Update bound_instance to reflect new/cleared draft model
                updated_instance = runner.bound_instance.instance.model_copy(
                    update={
                        "draft_model": (
                            ModelId(draft_model_id)
                            if draft_model_id is not None
                            else None
                        ),
                        "num_draft_tokens": num_tokens,
                    }
                )
                runner.bound_instance = runner.bound_instance.model_copy(
                    update={"instance": updated_instance}
                )
            case task:
                await self.runners[self._task_to_runner_id(task)].start_task(task)

@@ -330,46 +340,6 @@ class Worker:
            self._tg.start_soon(runner.run)
        return runner

    async def _handle_download(self, shard: ShardMetadata, task: Task) -> None:
        """Handle model download - shared logic for main and draft models."""
        model_id = shard.model_meta.model_id

        if model_id not in self.download_status:
            progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
            self.download_status[model_id] = progress
            await self.event_sender.send(
                NodeDownloadProgress(download_progress=progress)
            )

        initial_progress = (
            await self.shard_downloader.get_shard_download_status_for_shard(shard)
        )

        if initial_progress.status == "complete":
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
)
|
||||
self.download_status[model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
else:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
download_task = DownloadModel(
|
||||
instance_id=task.instance_id,
|
||||
shard_metadata=shard,
|
||||
task_id=task.task_id,
|
||||
task_status=task.task_status,
|
||||
)
|
||||
self._handle_shard_download_process(download_task, initial_progress)
|
||||
|
||||
def _handle_shard_download_process(
|
||||
self,
|
||||
task: DownloadModel,
|
||||
|
||||
@@ -8,10 +8,8 @@ from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadDraftModel,
|
||||
DownloadModel,
|
||||
LoadModel,
|
||||
SetDraftModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
@@ -40,16 +38,6 @@ from exo.shared.types.worker.runners import (
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
|
||||
def _is_download_in_progress_or_complete(
|
||||
model_id: ModelId,
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> bool:
|
||||
"""Check if model download is in progress or complete."""
|
||||
return model_id in download_status and isinstance(
|
||||
download_status[model_id], (DownloadOngoing, DownloadCompleted)
|
||||
)
|
||||
|
||||
|
||||
def plan(
|
||||
node_id: NodeId,
|
||||
# Runners is expected to be FRESH and so should not come from state
|
||||
@@ -67,11 +55,9 @@ def plan(
        _kill_runner(runners, all_runners, instances)
        or _create_runner(node_id, runners, instances)
        or _model_needs_download(runners, download_status)
        or _draft_model_needs_download(runners, download_status, instances)
        or _init_distributed_backend(runners, all_runners)
        or _load_model(runners, all_runners, global_download_status, download_status)
        or _load_model(runners, all_runners, global_download_status)
        or _ready_to_warmup(runners, all_runners)
        or _set_draft_model(runners, instances, download_status)
        or _pending_tasks(runners, tasks, all_runners)
    )
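The chain above works because each planner helper returns either a concrete task or None, and Python's `or` short-circuits on the first non-None result, so at most one task is planned per call. A minimal sketch of the same pattern (the helper names and dict-based state are hypothetical stand-ins, not the exo planners):

    from typing import Optional

    def _needs_download(state: dict) -> Optional[str]:
        return "download" if not state.get("downloaded") else None

    def _needs_load(state: dict) -> Optional[str]:
        return "load" if not state.get("loaded") else None

    def plan_next(state: dict) -> Optional[str]:
        # The first planner that proposes an action wins; later ones never run.
        return _needs_download(state) or _needs_load(state)

    assert plan_next({}) == "download"
    assert plan_next({"downloaded": True}) == "load"
    assert plan_next({"downloaded": True, "loaded": True}) is None
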
|
||||
|
||||
@@ -129,9 +115,12 @@ def _model_needs_download(
|
||||
) -> DownloadModel | None:
|
||||
for runner in runners.values():
|
||||
model_id = runner.bound_instance.bound_shard.model_meta.model_id
|
||||
if isinstance(
|
||||
runner.status, RunnerIdle
|
||||
) and not _is_download_in_progress_or_complete(model_id, download_status):
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
model_id not in download_status
|
||||
or not isinstance(
|
||||
download_status[model_id], (DownloadOngoing, DownloadCompleted)
|
||||
)
|
||||
):
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
return DownloadModel(
|
||||
instance_id=runner.bound_instance.instance.instance_id,
|
||||
@@ -139,43 +128,6 @@ def _model_needs_download(
|
||||
)
|
||||
|
||||
|
||||
def _draft_model_needs_download(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
) -> DownloadDraftModel | None:
|
||||
"""Check if draft model needs download for rank 0 runner.
|
||||
|
||||
Triggers download when:
|
||||
- RunnerIdle with draft model (initial setup)
|
||||
- RunnerReady with new draft model (updated via API)
|
||||
"""
|
||||
rank_0_runner = next(
|
||||
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
|
||||
None,
|
||||
)
|
||||
if rank_0_runner is None:
|
||||
return None
|
||||
if not isinstance(rank_0_runner.status, (RunnerIdle, RunnerReady)):
|
||||
return None
|
||||
|
||||
# Use current instance state (may have been updated via API)
|
||||
instance_id = rank_0_runner.bound_instance.instance.instance_id
|
||||
current_instance = instances.get(instance_id)
|
||||
if current_instance is None:
|
||||
return None
|
||||
|
||||
draft_model_id = current_instance.draft_model
|
||||
if draft_model_id is None:
|
||||
return None
|
||||
if _is_download_in_progress_or_complete(draft_model_id, download_status):
|
||||
return None
|
||||
return DownloadDraftModel(
|
||||
instance_id=instance_id,
|
||||
model_id=str(draft_model_id),
|
||||
)
|
||||
|
||||
|
||||
def _init_distributed_backend(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
@@ -230,12 +182,10 @@ def _load_model(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> LoadModel | None:
|
||||
for runner in runners.values():
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
shard = runner.bound_instance.bound_shard
|
||||
|
||||
all_local_downloads_complete = all(
|
||||
nid in global_download_status
|
||||
@@ -249,14 +199,6 @@ def _load_model(
|
||||
if not all_local_downloads_complete:
|
||||
continue
|
||||
|
||||
# Rank 0 with draft model must wait for draft download before loading
|
||||
if shard.device_rank == 0:
|
||||
draft_model_id = instance.draft_model
|
||||
if draft_model_id is not None and not isinstance(
|
||||
download_status.get(draft_model_id), DownloadCompleted
|
||||
):
|
||||
continue
|
||||
|
||||
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
|
||||
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
@@ -316,53 +258,6 @@ def _ready_to_warmup(
|
||||
return None
|
||||
|
||||
|
||||
def _set_draft_model(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> SetDraftModel | None:
|
||||
"""Check if rank 0 runner needs to load or clear a draft model."""
|
||||
rank_0_runner = next(
|
||||
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
|
||||
None,
|
||||
)
|
||||
if rank_0_runner is None:
|
||||
return None
|
||||
if not isinstance(rank_0_runner.status, RunnerReady):
|
||||
return None
|
||||
|
||||
instance_id = rank_0_runner.bound_instance.instance.instance_id
|
||||
current_instance = instances.get(instance_id)
|
||||
if current_instance is None:
|
||||
return None
|
||||
|
||||
# Compare runner's bound draft model vs current instance draft model
|
||||
runner_draft_model = rank_0_runner.bound_instance.instance.draft_model
|
||||
current_draft_model = current_instance.draft_model
|
||||
|
||||
if runner_draft_model == current_draft_model:
|
||||
return None
|
||||
|
||||
# Draft model changed - need to update
|
||||
if current_draft_model is None:
|
||||
# Clear draft model
|
||||
return SetDraftModel(
|
||||
instance_id=instance_id,
|
||||
model_id=None,
|
||||
num_draft_tokens=4,
|
||||
)
|
||||
|
||||
# Wait for draft model to be downloaded
|
||||
if not isinstance(download_status.get(current_draft_model), DownloadCompleted):
|
||||
return None
|
||||
|
||||
return SetDraftModel(
|
||||
instance_id=instance_id,
|
||||
model_id=str(current_draft_model),
|
||||
num_draft_tokens=current_instance.num_draft_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from functools import cache
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
@@ -15,7 +13,6 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -23,12 +20,10 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
SetDraftModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
@@ -53,44 +48,15 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_draft_model,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@contextmanager
|
||||
def send_error_chunk_on_exception(
|
||||
event_sender: MpSender[Event],
|
||||
command_id: CommandId,
|
||||
model_id: ModelId,
|
||||
device_rank: int,
|
||||
):
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=model_id,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
@@ -101,6 +67,7 @@ def main(
|
||||
bound_instance.bound_runner_id,
|
||||
bound_instance.bound_shard,
|
||||
)
|
||||
device_rank = shard_metadata.device_rank
|
||||
logger.info("hello from the runner")
|
||||
if getattr(shard_metadata, "immediate_exception", False):
|
||||
raise Exception("Fake exception - runner failed to spin up.")
|
||||
@@ -112,7 +79,6 @@ def main(
|
||||
model = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -168,16 +134,6 @@ def main(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
|
||||
# Load draft model for speculative decoding (rank 0 only)
|
||||
if (
|
||||
instance.draft_model is not None
|
||||
and shard_metadata.device_rank == 0
|
||||
):
|
||||
logger.info(f"Loading draft model: {instance.draft_model}")
|
||||
draft_model = cast(
|
||||
Model, load_draft_model(str(instance.draft_model))
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
@@ -193,10 +149,9 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=cast(Model, model),
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=instance.num_draft_tokens,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
@@ -215,24 +170,18 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender,
|
||||
command_id,
|
||||
shard_metadata.model_meta.model_id,
|
||||
shard_metadata.device_rank,
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
|
||||
try:
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses (draft_model loaded at warmup if configured)
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=cast(Model, model),
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=instance.num_draft_tokens,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
@@ -244,7 +193,7 @@ def main(
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
@@ -259,52 +208,26 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case SetDraftModel(
|
||||
model_id=draft_model_id, num_draft_tokens=num_tokens
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up (setting draft model)")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert model is not None
|
||||
assert tokenizer is not None
|
||||
|
||||
if draft_model_id is None:
|
||||
# Clear draft model
|
||||
logger.info("Clearing draft model")
|
||||
draft_model = None
|
||||
instance = instance.model_copy(
|
||||
update={
|
||||
"draft_model": None,
|
||||
"num_draft_tokens": 4,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Load new draft model
|
||||
logger.info(f"Loading draft model: {draft_model_id}")
|
||||
draft_model = cast(Model, load_draft_model(draft_model_id))
|
||||
instance = instance.model_copy(
|
||||
update={
|
||||
"draft_model": ModelId(draft_model_id),
|
||||
"num_draft_tokens": num_tokens,
|
||||
}
|
||||
)
|
||||
# Warm up with speculative decoding
|
||||
logger.info("Warming up with new draft model")
|
||||
warmup_inference(
|
||||
model=cast(Model, model),
|
||||
tokenizer=tokenizer,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=num_tokens,
|
||||
)
|
||||
logger.info("Draft model loaded and warmed up")
|
||||
|
||||
current_status = RunnerReady()
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
@@ -325,7 +248,7 @@ def main(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group, draft_model
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
|
||||
220 src/exo/worker/tests/unittests/test_mlx/conftest.py Normal file
@@ -0,0 +1,220 @@
|
||||
# type: ignore
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
|
||||
|
||||
class MockLayer(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.custom_attr = "test_value"
|
||||
self.use_sliding = True
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
return x * 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineTestConfig:
|
||||
model_path: Path
|
||||
total_layers: int
|
||||
base_port: int
|
||||
max_tokens: int
|
||||
|
||||
|
||||
def create_hostfile(world_size: int, base_port: int) -> tuple[str, list[str]]:
|
||||
import json
|
||||
import tempfile
|
||||
|
||||
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(hosts, f)
|
||||
hostfile_path = f.name
|
||||
|
||||
return hostfile_path, hosts
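For two local ranks with base_port 29600, the helper above writes a temp file whose entire contents are the JSON list shown below; each spawned process then points MLX's ring backend at it via environment variables, exactly as the device functions later in this file do (the file path shown is hypothetical):

    # hostfile written by create_hostfile(2, 29600):
    #   ["127.0.0.1:29600", "127.0.0.1:29601"]
    import os
    os.environ["MLX_HOSTFILE"] = "/tmp/exo-hostfile.json"  # path returned by create_hostfile
    os.environ["MLX_RANK"] = "0"  # "1" in the second spawned process
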
|
||||
|
||||
|
||||
# Use GPT OSS 20b to test as it is a model with a lot of strange behaviour
|
||||
|
||||
DEFAULT_GPT_OSS_CONFIG = PipelineTestConfig(
|
||||
model_path=EXO_MODELS_DIR / "mlx-community--gpt-oss-20b-MXFP4-Q8",
|
||||
total_layers=24,
|
||||
base_port=29600,
|
||||
max_tokens=200,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_GPT_OSS_MODEL_ID = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
|
||||
|
||||
def run_gpt_oss_pipeline_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
layer_splits: list[tuple[int, int]],
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
max_tokens: int = 200,
|
||||
) -> None:
|
||||
import os
|
||||
import traceback
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import shard_and_load
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
start_layer, end_layer = layer_splits[rank]
|
||||
|
||||
shard_meta = PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
|
||||
pretty_name="GPT-OSS 20B",
|
||||
storage_size=Memory.from_gb(12),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
start_layer=start_layer,
|
||||
end_layer=end_layer,
|
||||
n_layers=24,
|
||||
)
|
||||
|
||||
model, tokenizer = shard_and_load(shard_meta, group)
|
||||
model = cast(Model, model)
|
||||
|
||||
# Generate a prompt of exact token length
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_tokens = tokenizer.encode(base_text)
|
||||
base_len = len(base_tokens)
|
||||
|
||||
# Build prompt with approximate target length
|
||||
repeats = (prompt_tokens // base_len) + 2
|
||||
long_text = base_text * repeats
|
||||
tokens = tokenizer.encode(long_text)
|
||||
# Truncate to exact target length
|
||||
tokens = tokens[:prompt_tokens]
|
||||
prompt_text = tokenizer.decode(tokens)
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
break
|
||||
|
||||
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
|
||||
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def run_gpt_oss_tensor_parallel_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
max_tokens: int = 10,
|
||||
) -> None:
|
||||
import os
|
||||
import traceback
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.shards import TensorShardMetadata
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate
|
||||
from exo.worker.engines.mlx.utils_mlx import shard_and_load
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
# For tensor parallelism, all devices run all layers
|
||||
shard_meta = TensorShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
|
||||
pretty_name="GPT-OSS 20B",
|
||||
storage_size=Memory.from_gb(12),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
device_rank=rank,
|
||||
world_size=world_size,
|
||||
start_layer=0,
|
||||
end_layer=24,
|
||||
n_layers=24,
|
||||
)
|
||||
|
||||
model, tokenizer = shard_and_load(shard_meta, group)
|
||||
model = cast(Model, model)
|
||||
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_tokens = tokenizer.encode(base_text)
|
||||
base_len = len(base_tokens)
|
||||
|
||||
repeats = (prompt_tokens // base_len) + 2
|
||||
long_text = base_text * repeats
|
||||
tokens = tokenizer.encode(long_text)
|
||||
tokens = tokens[:prompt_tokens]
|
||||
prompt_text = tokenizer.decode(tokens)
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
generated_text = ""
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
):
|
||||
generated_text += response.text
|
||||
if response.finish_reason is not None:
|
||||
break
|
||||
|
||||
result_queue.put((rank, True, generated_text)) # pyright: ignore[reportAny]
|
||||
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, f"{e}\n{traceback.format_exc()}")) # pyright: ignore[reportAny]
|
||||
154 src/exo/worker/tests/unittests/test_mlx/test_auto_parallel.py Normal file
@@ -0,0 +1,154 @@
|
||||
import multiprocessing as mp
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
CustomMlxModule,
|
||||
PipelineFirstLayer,
|
||||
PipelineLastLayer,
|
||||
PipelineParallelModel,
|
||||
)
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import MockLayer
|
||||
|
||||
|
||||
def run_pipeline_device(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
hostfile_path: str,
|
||||
result_queue: Any, # pyright: ignore[reportAny]
|
||||
) -> None:
|
||||
import os
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile_path
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
|
||||
import mlx.core as mlx_core
|
||||
import mlx.nn as mlx_nn
|
||||
|
||||
class MockLayerInner(mlx_nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.custom_attr = "test_value"
|
||||
|
||||
def __call__(
|
||||
self, x: mlx_core.array, *args: object, **kwargs: object
|
||||
) -> mlx_core.array:
|
||||
return x * 2
|
||||
|
||||
class MockModel(mlx_nn.Module):
|
||||
def __init__(self, layers: list[mlx_nn.Module]) -> None:
|
||||
super().__init__()
|
||||
self.layers = layers
|
||||
|
||||
def __call__(
|
||||
self, x: mlx_core.array, *args: object, **kwargs: object
|
||||
) -> mlx_core.array:
|
||||
for layer in self.layers:
|
||||
x = layer(x, *args, **kwargs) # pyright: ignore[reportUnknownVariableType]
|
||||
return x # pyright: ignore[reportUnknownVariableType]
|
||||
|
||||
try:
|
||||
group = mlx_core.distributed.init(backend="ring", strict=True)
|
||||
|
||||
mock = MockLayerInner()
|
||||
first = PipelineFirstLayer(mock, r=rank, group=group)
|
||||
composed = PipelineLastLayer(first, r=rank, s=world_size, group=group)
|
||||
|
||||
# Wrap in a mock model, then wrap in PipelineParallelModel for all_gather
|
||||
inner_model = MockModel([composed])
|
||||
model = PipelineParallelModel(inner_model, group)
|
||||
|
||||
x = mlx_core.ones((1, 4))
|
||||
result = model(x)
|
||||
mlx_core.eval(result)
|
||||
|
||||
success = result.shape == x.shape
|
||||
result_queue.put((rank, success, result)) # pyright: ignore[reportAny]
|
||||
except Exception as e:
|
||||
result_queue.put((rank, False, str(e))) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def test_single_wrapper_delegates_attributes() -> None:
|
||||
mock = MockLayer()
|
||||
wrapped = CustomMlxModule(mock)
|
||||
|
||||
assert wrapped.custom_attr == "test_value" # type: ignore[attr-defined]
|
||||
assert wrapped.use_sliding is True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_composed_wrappers_delegate_attributes() -> None:
|
||||
mock = MockLayer()
|
||||
group = mx.distributed.init()
|
||||
|
||||
first = PipelineFirstLayer(mock, r=0, group=group)
|
||||
composed = PipelineLastLayer(first, r=0, s=1, group=group)
|
||||
|
||||
assert composed.custom_attr == "test_value" # type: ignore[attr-defined]
|
||||
assert composed.use_sliding is True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_missing_attribute_raises() -> None:
|
||||
mock = MockLayer()
|
||||
wrapped = CustomMlxModule(mock)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
_ = wrapped.nonexistent_attr # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_composed_call_works() -> None:
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
|
||||
world_size = 2
|
||||
base_port = 29500
|
||||
|
||||
hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)]
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(hosts, f)
|
||||
hostfile_path = f.name
|
||||
|
||||
try:
|
||||
result_queue: Any = ctx.Queue()
|
||||
|
||||
processes: list[Any] = []
|
||||
for rank in range(world_size):
|
||||
p = ctx.Process(
|
||||
target=run_pipeline_device,
|
||||
args=(rank, world_size, hostfile_path, result_queue),
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes: # pyright: ignore[reportAny]
|
||||
p.join(timeout=10) # pyright: ignore[reportAny]
|
||||
|
||||
results: dict[int, Any] = {}
|
||||
errors: dict[int, str] = {}
|
||||
while not result_queue.empty(): # pyright: ignore[reportAny]
|
||||
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
|
||||
if success:
|
||||
results[rank] = value
|
||||
else:
|
||||
errors[rank] = value
|
||||
|
||||
assert len(results) == world_size, (
|
||||
f"Expected {world_size} results, got {len(results)}. Errors: {errors}"
|
||||
)
|
||||
|
||||
for rank in range(world_size):
|
||||
assert rank in results, (
|
||||
f"Device {rank} failed: {errors.get(rank, 'unknown')}"
|
||||
)
|
||||
result_array = results[rank]
|
||||
# Each rank's MockLayerInner doubles its input, so ones -> 2.0 on rank 0 -> 4.0 on rank 1;
# the all_gather in PipelineParallelModel then hands both devices the final 4.0
|
||||
assert (result_array == 4.0).all(), (
|
||||
f"Device {rank}: expected 4.0, got {result_array}"
|
||||
)
|
||||
finally:
|
||||
os.unlink(hostfile_path)
|
||||
230 src/exo/worker/tests/unittests/test_mlx/test_distributed_fix.py Normal file
@@ -0,0 +1,230 @@
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import (
|
||||
DEFAULT_GPT_OSS_CONFIG,
|
||||
create_hostfile,
|
||||
run_gpt_oss_pipeline_device,
|
||||
run_gpt_oss_tensor_parallel_device,
|
||||
)
|
||||
|
||||
|
||||
def _check_model_exists() -> bool:
|
||||
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
|
||||
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
not _check_model_exists(),
|
||||
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DistributedTestResult:
|
||||
timed_out: bool
|
||||
world_size: int
|
||||
results: dict[int, tuple[bool, str]]
|
||||
|
||||
@property
|
||||
def all_success(self) -> bool:
|
||||
if len(self.results) != self.world_size:
|
||||
return False
|
||||
return all(r[0] for r in self.results.values())
|
||||
|
||||
|
||||
def run_distributed_test(
|
||||
world_size: int,
|
||||
port_offset: int,
|
||||
process_timeout: int,
|
||||
target: Callable[..., None],
|
||||
make_args: Callable[[int], tuple[Any, ...]],
|
||||
) -> DistributedTestResult:
|
||||
ctx = mp.get_context("spawn")
|
||||
hostfile_path, _ = create_hostfile(
|
||||
world_size, DEFAULT_GPT_OSS_CONFIG.base_port + port_offset
|
||||
)
|
||||
|
||||
try:
|
||||
result_queue: Any = ctx.Queue()
|
||||
processes: list[Any] = []
|
||||
|
||||
for rank in range(world_size):
|
||||
args = make_args(rank)
|
||||
p = ctx.Process(
|
||||
target=target,
|
||||
args=(rank, world_size, hostfile_path, *args, result_queue),
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
for p in processes: # pyright: ignore[reportAny]
|
||||
p.join(timeout=process_timeout) # pyright: ignore[reportAny]
|
||||
|
||||
timed_out = any(p.is_alive() for p in processes) # pyright: ignore[reportAny]
|
||||
|
||||
for p in processes: # pyright: ignore[reportAny]
|
||||
if p.is_alive(): # pyright: ignore[reportAny]
|
||||
p.terminate() # pyright: ignore[reportAny]
|
||||
p.join(timeout=5) # pyright: ignore[reportAny]
|
||||
|
||||
results: dict[int, tuple[bool, str]] = {}
|
||||
while not result_queue.empty(): # pyright: ignore[reportAny]
|
||||
rank, success, value = result_queue.get() # pyright: ignore[reportAny]
|
||||
results[rank] = (success, value)
|
||||
|
||||
return DistributedTestResult(
|
||||
timed_out=timed_out, world_size=world_size, results=results
|
||||
)
|
||||
|
||||
finally:
|
||||
os.unlink(hostfile_path)
|
||||
|
||||
|
||||
def run_pipeline_test(
|
||||
layer_splits: list[tuple[int, int]],
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
port_offset: int = 0,
|
||||
process_timeout: int = 60,
|
||||
) -> DistributedTestResult:
|
||||
def make_args(rank: int) -> tuple[Any, ...]:
|
||||
return (
|
||||
layer_splits,
|
||||
prompt_tokens,
|
||||
prefill_step_size,
|
||||
)
|
||||
|
||||
return run_distributed_test(
|
||||
world_size=len(layer_splits),
|
||||
port_offset=port_offset,
|
||||
process_timeout=process_timeout,
|
||||
target=run_gpt_oss_pipeline_device,
|
||||
make_args=make_args,
|
||||
)
|
||||
|
||||
|
||||
def run_tensor_test(
|
||||
prompt_tokens: int,
|
||||
prefill_step_size: int,
|
||||
port_offset: int = 0,
|
||||
process_timeout: int = 60,
|
||||
) -> DistributedTestResult:
|
||||
def make_args(rank: int) -> tuple[Any, ...]:
|
||||
return (
|
||||
prompt_tokens,
|
||||
prefill_step_size,
|
||||
)
|
||||
|
||||
return run_distributed_test(
|
||||
world_size=2,
|
||||
port_offset=port_offset,
|
||||
process_timeout=process_timeout,
|
||||
target=run_gpt_oss_tensor_parallel_device,
|
||||
make_args=make_args,
|
||||
)
|
||||
|
||||
|
||||
class TestPipelineParallelFix:
|
||||
BUG_TRIGGER_SPLITS: list[tuple[int, int]] = [(0, 1), (1, 24)]
|
||||
|
||||
def test_pipeline_single_layer_first_device(self) -> None:
|
||||
result = run_pipeline_test(
|
||||
layer_splits=self.BUG_TRIGGER_SPLITS,
|
||||
prompt_tokens=100,
|
||||
prefill_step_size=64,
|
||||
process_timeout=60,
|
||||
)
|
||||
assert not result.timed_out, "Unexpected timeout - fix may not be working"
|
||||
assert result.all_success, f"Failures: {result.results}"
|
||||
|
||||
|
||||
class TestPipelineSplitConfigurations:
|
||||
@pytest.mark.parametrize(
|
||||
"layer_splits",
|
||||
[
|
||||
[(0, 1), (1, 24)],
|
||||
[(0, 6), (6, 24)],
|
||||
[(0, 12), (12, 24)],
|
||||
],
|
||||
ids=["1_23", "6_18", "12_12"],
|
||||
)
|
||||
def test_pipeline_splits(
|
||||
self,
|
||||
layer_splits: list[tuple[int, int]],
|
||||
) -> None:
|
||||
result = run_pipeline_test(
|
||||
layer_splits=layer_splits,
|
||||
prompt_tokens=600,
|
||||
prefill_step_size=512,
|
||||
port_offset=100,
|
||||
)
|
||||
assert not result.timed_out, f"Timeout with {layer_splits}"
|
||||
assert result.all_success, f"Failures with {layer_splits}: {result.results}"
|
||||
|
||||
|
||||
class TestPrefillStepSizeBoundaries:
|
||||
@pytest.mark.parametrize(
|
||||
"prefill_step_size,prompt_tokens",
|
||||
[
|
||||
(512, 511),
|
||||
(512, 512),
|
||||
(512, 513),
|
||||
(512, 1024),
|
||||
],
|
||||
ids=["under", "exact", "over", "double"],
|
||||
)
|
||||
def test_boundary_conditions(
|
||||
self,
|
||||
prefill_step_size: int,
|
||||
prompt_tokens: int,
|
||||
) -> None:
|
||||
result = run_pipeline_test(
|
||||
layer_splits=[(0, 12), (12, 24)],
|
||||
prompt_tokens=prompt_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
port_offset=200,
|
||||
)
|
||||
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
|
||||
assert result.all_success, f"Failures: {result.results}"
|
||||
|
||||
|
||||
class TestTensorParallelFix:
|
||||
def test_tensor_parallel(self) -> None:
|
||||
result = run_tensor_test(
|
||||
prompt_tokens=100,
|
||||
prefill_step_size=64,
|
||||
port_offset=400,
|
||||
)
|
||||
assert not result.timed_out, "Unexpected timeout"
|
||||
assert result.all_success, f"Failures: {result.results}"
|
||||
|
||||
|
||||
class TestTensorParallelBoundaries:
|
||||
@pytest.mark.parametrize(
|
||||
"prefill_step_size,prompt_tokens",
|
||||
[
|
||||
(512, 511),
|
||||
(512, 512),
|
||||
(512, 513),
|
||||
(512, 1024),
|
||||
],
|
||||
ids=["under", "exact", "over", "double"],
|
||||
)
|
||||
def test_tensor_parallel_boundaries(
|
||||
self,
|
||||
prefill_step_size: int,
|
||||
prompt_tokens: int,
|
||||
) -> None:
|
||||
result = run_tensor_test(
|
||||
prompt_tokens=prompt_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
port_offset=500,
|
||||
)
|
||||
assert not result.timed_out, f"Timeout: {prompt_tokens=}, {prefill_step_size=}"
|
||||
assert result.all_success, f"Failures: {result.results}"
|
||||
@@ -1,50 +0,0 @@
# pyright: reportAny=false
from unittest.mock import MagicMock

from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID


def test_send_error_chunk_on_exception_no_error() -> None:
    event_sender = MagicMock()
    command_id = CommandId()

    with send_error_chunk_on_exception(
        event_sender, command_id, MODEL_A_ID, device_rank=0
    ):
        _ = 1 + 1

    event_sender.send.assert_not_called()


def test_send_error_chunk_on_exception_catches_error() -> None:
    event_sender = MagicMock()
    command_id = CommandId()

    with send_error_chunk_on_exception(
        event_sender, command_id, MODEL_A_ID, device_rank=0
    ):
        raise ValueError("test error")

    event_sender.send.assert_called_once()
    call_args = event_sender.send.call_args[0][0]
    assert isinstance(call_args, ChunkGenerated)
    assert call_args.command_id == command_id
    assert isinstance(call_args.chunk, TokenChunk)
    assert call_args.chunk.finish_reason == "error"
    assert call_args.chunk.error_message == "test error"


def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
    event_sender = MagicMock()
    command_id = CommandId()

    with send_error_chunk_on_exception(
        event_sender, command_id, MODEL_A_ID, device_rank=1
    ):
        raise ValueError("test error")

    event_sender.send.assert_not_called()
Block a user