Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-19 11:28:51 -05:00)

Compare commits: 1 commit on branch leo/fix-pi
Author: alexcheema · Commit: 71efccd10e

@@ -69,6 +69,8 @@ export interface Instance {
    runnerToShard?: Record<string, unknown>;
    nodeToRunner?: Record<string, string>;
  };
  draftModel?: string;
  numDraftTokens?: number;
}

interface RawNodeProfile {

@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);

// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';

@@ -59,7 +59,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
  instanceType: InstanceMeta;
  minNodes: number;
}

function saveLaunchDefaults(): void {
  const defaults: LaunchDefaults = {
    modelId: selectedPreviewModelId(),

@@ -88,16 +88,16 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
  const defaults = loadLaunchDefaults();
  if (!defaults) return;

  // Apply sharding and instance type unconditionally
  selectedSharding = defaults.sharding;
  selectedInstanceType = defaults.instanceType;

  // Apply minNodes if valid (between 1 and maxNodes)
  if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
    selectedMinNodes = defaults.minNodes;
  }

  // Only apply model if it exists in the available models
  if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
    selectPreviewModel(defaults.modelId);

@@ -109,11 +109,19 @@ const sidebarVisible = $derived(chatSidebarVisible());
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());

// Draft model edit modal state
let editingDraftInstanceId = $state<string | null>(null);
let editDraftModel = $state<string | null>(null);
let editNumDraftTokens = $state<number>(4);
let isDraftEditDropdownOpen = $state(false);
let draftEditDropdownSearch = $state('');
let isSavingDraftModel = $state(false);

// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');

// Slider dragging state
let isDraggingSlider = $state(false);
let sliderTrackElement: HTMLDivElement | null = $state(null);

@@ -362,47 +370,36 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
  if (!modelId || launchingModelId) return;

  launchingModelId = modelId;

  try {
    // Use the specific preview if provided, otherwise fall back to filtered preview
    const preview = specificPreview ?? filteredPreview();

    let instanceData: unknown;

    if (preview?.instance) {
      // Use the instance from the preview
      instanceData = preview.instance;
    } else {
      // Fallback: GET placement from API
      const placementResponse = await fetch(
        `/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
      );

      if (!placementResponse.ok) {
        const errorText = await placementResponse.text();
        console.error('Failed to get placement:', errorText);
        return;
      }

      instanceData = await placementResponse.json();
    }

    // POST the instance to create it
    const response = await fetch('/instance', {

    let response: Response;

    // Use /place_instance endpoint - it handles placement and creation in one step
    const placePayload = {
      model_id: modelId,
      sharding: preview?.sharding ?? selectedSharding,
      instance_meta: preview?.instance_meta ?? selectedInstanceType,
      min_nodes: selectedMinNodes,
    };

    response = await fetch('/place_instance', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ instance: instanceData })
      body: JSON.stringify(placePayload)
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error('Failed to launch instance:', errorText);
    } else {
      // Always auto-select the newly launched model so the user chats to what they just launched
      setSelectedChatModel(modelId);

      // Scroll to the bottom of instances container to show the new instance
      // Use multiple attempts to ensure DOM has updated with the new instance
      const scrollToBottom = () => {

@@ -797,6 +794,52 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  }
}

// Open draft model edit modal for an instance
function openDraftModelEdit(instanceId: string, currentDraftModel: string | null, currentNumTokens: number | null) {
  editingDraftInstanceId = instanceId;
  editDraftModel = currentDraftModel;
  editNumDraftTokens = currentNumTokens ?? 4;
  isDraftEditDropdownOpen = false;
  draftEditDropdownSearch = '';
}

// Close draft model edit modal
function closeDraftModelEdit() {
  editingDraftInstanceId = null;
  editDraftModel = null;
  editNumDraftTokens = 4;
  isDraftEditDropdownOpen = false;
  draftEditDropdownSearch = '';
}

// Save draft model settings for an instance
async function saveDraftModel() {
  if (!editingDraftInstanceId || isSavingDraftModel) return;

  isSavingDraftModel = true;
  try {
    const response = await fetch(`/instance/${editingDraftInstanceId}/draft_model`, {
      method: 'PUT',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        draft_model: editDraftModel,
        num_draft_tokens: editNumDraftTokens,
      })
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error('Failed to set draft model:', errorText);
    } else {
      closeDraftModelEdit();
    }
  } catch (error) {
    console.error('Error setting draft model:', error);
  } finally {
    isSavingDraftModel = false;
  }
}

// Helper to unwrap tagged unions like { MlxRingInstance: {...} }
function getTagged(obj: unknown): [string | null, unknown] {
  if (!obj || typeof obj !== 'object') return [null, null];

@@ -816,30 +859,34 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}

// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
  instanceType: string;
  sharding: string;
function getInstanceInfo(instanceWrapped: unknown): {
  instanceType: string;
  sharding: string;
  nodeNames: string[];
  nodeIds: string[];
  nodeCount: number;
  draftModel: string | null;
  numDraftTokens: number | null;
} {
  const [instanceTag, instance] = getTagged(instanceWrapped);
  if (!instance || typeof instance !== 'object') {
    return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
    return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
  }

  // Instance type from tag
  let instanceType = 'Unknown';
  if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
  else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';

  const inst = instance as {
    shardAssignments?: {
      nodeToRunner?: Record<string, string>;

  const inst = instance as {
    shardAssignments?: {
      nodeToRunner?: Record<string, string>;
      runnerToShard?: Record<string, unknown>;
    }
  };
    draftModel?: string;
    numDraftTokens?: number;
  };

  // Sharding strategy from first shard
  let sharding = 'Unknown';
  const runnerToShard = inst.shardAssignments?.runnerToShard || {};

@@ -850,7 +897,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
    else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
  }

  // Node names from topology
  const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
  const nodeIds = Object.keys(nodeToRunner);

@@ -858,8 +905,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    const node = data?.nodes?.[nodeId];
    return node?.friendly_name || nodeId.slice(0, 8);
  });

  return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };

  // Draft model for speculative decoding
  const draftModel = inst.draftModel ?? null;
  const numDraftTokens = inst.numDraftTokens ?? null;

  return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
}

function formatLastUpdate(): string {

@@ -1335,16 +1386,31 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  <div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
  <span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<button
  onclick={() => deleteInstance(id)}
  class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
  DELETE
</button>
<div class="flex items-center gap-2">
  <!-- Draft Model Button -->
  <button
    onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
    class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
    title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
  >
    <svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
      <path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
    </svg>
  </button>
  <button
    onclick={() => deleteInstance(id)}
    class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
  >
    DELETE
  </button>
</div>
</div>
<div class="pl-2">
  <div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
  <div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
  {#if instanceInfo.draftModel}
    <div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
  {/if}
  {#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
    <a
      class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"

@@ -1679,7 +1745,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
      </div>
    </div>
  </div>

  <!-- Selected Model Preview -->
  <div class="space-y-3">
    {#if models.length === 0}

@@ -1838,16 +1904,31 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
  <div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
  <span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<button
  onclick={() => deleteInstance(id)}
  class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
  DELETE
</button>
<div class="flex items-center gap-2">
  <!-- Draft Model Button -->
  <button
    onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
    class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
    title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
  >
    <svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
      <path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
    </svg>
  </button>
  <button
    onclick={() => deleteInstance(id)}
    class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
  >
    DELETE
  </button>
</div>
</div>
<div class="pl-2">
  <div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
  <div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
  {#if instanceInfo.draftModel}
    <div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
  {/if}
  {#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
    <a
      class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"

@@ -1978,4 +2059,120 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{/if}
</main>

<!-- Draft Model Edit Modal -->
{#if editingDraftInstanceId}
  <!-- svelte-ignore a11y_no_static_element_interactions -->
  <div
    class="fixed inset-0 z-50 flex items-center justify-center bg-black/70 backdrop-blur-sm"
    onclick={closeDraftModelEdit}
    onkeydown={(e) => e.key === 'Escape' && closeDraftModelEdit()}
  >
    <!-- svelte-ignore a11y_click_events_have_key_events -->
    <div
      class="bg-exo-dark-gray border border-exo-medium-gray/50 rounded-lg shadow-2xl p-6 w-full max-w-md mx-4"
      onclick={(e) => e.stopPropagation()}
    >
      <div class="flex items-center justify-between mb-4">
        <h3 class="text-lg font-mono text-exo-yellow tracking-wide">Speculative Decoding</h3>
        <button
          onclick={closeDraftModelEdit}
          class="text-white/60 hover:text-white transition-colors cursor-pointer"
          aria-label="Close"
        >
          <svg class="w-5 h-5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
            <path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" />
          </svg>
        </button>
      </div>

      <p class="text-white/60 text-sm font-mono mb-4">
        Configure a draft model for faster generation. The draft model proposes tokens that the main model verifies.
      </p>

      <!-- Draft Model Dropdown -->
      <div class="mb-4">
        <div class="text-xs text-white/70 font-mono mb-2">Draft Model:</div>
        <div class="relative">
          <button
            onclick={() => { isDraftEditDropdownOpen = !isDraftEditDropdownOpen; draftEditDropdownSearch = ''; }}
            class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {editDraftModel ? 'bg-transparent text-cyan-400 border-cyan-500/50' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-cyan-500/50'}"
          >
            <span class="truncate">{editDraftModel ? editDraftModel.split('/').pop() : 'None'}</span>
            <svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftEditDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
              <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
            </svg>
          </button>
          {#if isDraftEditDropdownOpen}
            <div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
              <div class="p-2 border-b border-exo-medium-gray/30">
                <input
                  type="text"
                  bind:value={draftEditDropdownSearch}
                  placeholder="Search models..."
                  class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-cyan-500/50"
                />
              </div>
              <div class="overflow-y-auto max-h-36">
                <!-- None option -->
                <button
                  onclick={() => { editDraftModel = null; isDraftEditDropdownOpen = false; }}
                  class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {editDraftModel === null ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
                >
                  <span>None (Disable)</span>
                </button>
                {#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftEditDropdownSearch.toLowerCase())) as model}
                  {@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
                  {@const modelHfId = model.hugging_face_id ?? model.id}
                  <button
                    onclick={() => { editDraftModel = modelHfId; isDraftEditDropdownOpen = false; }}
                    class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {editDraftModel === modelHfId ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
                  >
                    <span class="truncate">{model.name || model.id}</span>
                    <span class="flex-shrink-0 text-xs text-white/50">
                      {sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
                    </span>
                  </button>
                {:else}
                  <div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
                {/each}
              </div>
            </div>
          {/if}
        </div>
      </div>

      <!-- Draft Tokens -->
      {#if editDraftModel}
        <div class="mb-6">
          <div class="text-xs text-white/70 font-mono mb-2">Draft Tokens per Iteration:</div>
          <div class="flex items-center gap-2">
            {#each [2, 3, 4, 5, 6] as n}
              <button
                onclick={() => editNumDraftTokens = n}
                class="w-8 h-8 text-sm font-mono rounded transition-all {editNumDraftTokens === n ? 'bg-cyan-500/20 text-cyan-400 border border-cyan-500/50' : 'text-white/50 hover:text-white/80 border border-exo-medium-gray/50 hover:border-white/30'} cursor-pointer"
              >{n}</button>
            {/each}
          </div>
        </div>
      {/if}

      <!-- Action Buttons -->
      <div class="flex items-center justify-end gap-3">
        <button
          onclick={closeDraftModelEdit}
          class="px-4 py-2 text-sm font-mono text-white/70 hover:text-white transition-colors cursor-pointer"
        >
          Cancel
        </button>
        <button
          onclick={saveDraftModel}
          disabled={isSavingDraftModel}
          class="px-4 py-2 text-sm font-mono border border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500 transition-all disabled:opacity-50 disabled:cursor-not-allowed cursor-pointer"
        >
          {isSavingDraftModel ? 'Saving...' : 'Save'}
        </button>
      </div>
    </div>
  </div>
{/if}
</div>

@@ -39,6 +39,8 @@ from exo.shared.types.api import (
    PlaceInstanceParams,
    PlacementPreview,
    PlacementPreviewResponse,
    SetDraftModelParams,
    SetDraftModelResponse,
    StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk

@@ -49,6 +51,7 @@ from exo.shared.types.commands import (
    DeleteInstance,
    ForwarderCommand,
    PlaceInstance,
    SetInstanceDraftModel,
    TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId

@@ -185,6 +188,7 @@ class API:
        self.app.get("/instance/previews")(self.get_placement_previews)
        self.app.get("/instance/{instance_id}")(self.get_instance)
        self.app.delete("/instance/{instance_id}")(self.delete_instance)
        self.app.put("/instance/{instance_id}/draft_model")(self.set_draft_model)
        self.app.get("/models")(self.get_models)
        self.app.get("/v1/models")(self.get_models)
        self.app.post("/v1/chat/completions", response_model=None)(

@@ -200,6 +204,8 @@ class API:
            sharding=payload.sharding,
            instance_meta=payload.instance_meta,
            min_nodes=payload.min_nodes,
            draft_model=payload.draft_model,
            num_draft_tokens=payload.num_draft_tokens,
        )
        await self._send(command)
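
For reference, a POST /place_instance body can now opt into speculative decoding at launch. A hedged sketch of such a payload (model ids are illustrative placeholders; defaults follow PlaceInstanceParams):

place_payload = {
    "model_id": "mlx-community/Qwen2.5-7B-Instruct-4bit",        # illustrative
    "sharding": "Pipeline",
    "instance_meta": "MlxRing",
    "min_nodes": 1,
    "draft_model": "mlx-community/Qwen2.5-0.5B-Instruct-4bit",   # illustrative
    "num_draft_tokens": 4,
}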

@@ -396,6 +402,24 @@ class API:
            instance_id=instance_id,
        )

    async def set_draft_model(
        self, instance_id: InstanceId, payload: SetDraftModelParams
    ) -> SetDraftModelResponse:
        if instance_id not in self.state.instances:
            raise HTTPException(status_code=404, detail="Instance not found")

        command = SetInstanceDraftModel(
            instance_id=instance_id,
            draft_model=payload.draft_model,
            num_draft_tokens=payload.num_draft_tokens,
        )
        await self._send(command)
        return SetDraftModelResponse(
            message="Command received.",
            command_id=command.command_id,
            instance_id=instance_id,
        )
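
The dashboard's saveDraftModel drives this handler with a PUT; an equivalent standalone client sketch (base URL, instance id, and model id are assumptions, and `requests` is just one possible HTTP client, not an exo dependency):

import requests

BASE_URL = "http://localhost:8000"   # placeholder; use your exo API address
instance_id = "YOUR_INSTANCE_ID"     # placeholder

resp = requests.put(
    f"{BASE_URL}/instance/{instance_id}/draft_model",
    json={
        "draft_model": "mlx-community/Qwen2.5-0.5B-Instruct-4bit",  # illustrative
        "num_draft_tokens": 4,
    },
)
resp.raise_for_status()  # body is a SetDraftModelResponse
# Sending {"draft_model": None} instead disables speculative decoding.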

    async def _chat_chunk_stream(
        self, command_id: CommandId
    ) -> AsyncGenerator[TokenChunk, None]:

@@ -18,6 +18,7 @@ from exo.shared.types.commands import (
    ForwarderCommand,
    PlaceInstance,
    RequestEventLog,
    SetInstanceDraftModel,
    TaskFinished,
    TestCommand,
)

@@ -27,6 +28,7 @@ from exo.shared.types.events import (
    ForwarderEvent,
    IndexedEvent,
    InstanceDeleted,
    InstanceDraftModelUpdated,
    NodeTimedOut,
    TaskCreated,
    TaskDeleted,

@@ -173,6 +175,14 @@ class Master:
                    self.state.instances, placement
                )
                generated_events.extend(transition_events)
            case SetInstanceDraftModel():
                generated_events.append(
                    InstanceDraftModelUpdated(
                        instance_id=command.instance_id,
                        draft_model=command.draft_model,
                        num_draft_tokens=command.num_draft_tokens,
                    )
                )
            case TaskFinished():
                generated_events.append(
                    TaskDeleted(

@@ -3,8 +3,6 @@ from collections.abc import Mapping
from copy import deepcopy
from typing import Sequence

from loguru import logger

from exo.master.placement_utils import (
    filter_cycles_by_memory,
    get_mlx_ibv_devices_matrix,

@@ -55,7 +53,6 @@ def place_instance(
) -> dict[InstanceId, Instance]:
    all_nodes = list(topology.list_nodes())

    logger.info("finding cycles:")
    cycles = topology.get_cycles()
    singleton_cycles = [[node] for node in all_nodes]
    candidate_cycles = list(

@@ -128,10 +125,6 @@ def place_instance(
    target_instances = dict(deepcopy(current_instances))

    if len(selected_cycle) == 1:
        logger.warning(
            "You have likely selected ibv for a single node instance; falling back to MlxRing"
        )

        command.instance_meta = InstanceMeta.MlxRing

    # TODO: Single node instances

@@ -151,6 +144,8 @@ def place_instance(
                shard_assignments=shard_assignments,
                ibv_devices=mlx_ibv_devices,
                jaccl_coordinators=mlx_jaccl_coordinators,
                draft_model=command.draft_model,
                num_draft_tokens=command.num_draft_tokens,
            )
        case InstanceMeta.MlxRing:
            ephemeral_port = random_ephemeral_port()

@@ -164,6 +159,8 @@ def place_instance(
                shard_assignments=shard_assignments,
                hosts_by_node=hosts_by_node,
                ephemeral_port=ephemeral_port,
                draft_model=command.draft_model,
                num_draft_tokens=command.num_draft_tokens,
            )

    return target_instances

@@ -11,6 +11,7 @@ from exo.shared.types.events import (
    IndexedEvent,
    InstanceCreated,
    InstanceDeleted,
    InstanceDraftModelUpdated,
    NodeCreated,
    NodeDownloadProgress,
    NodeMemoryMeasured,

@@ -47,6 +48,8 @@ def event_apply(event: Event, state: State) -> State:
            return apply_instance_created(event, state)
        case InstanceDeleted():
            return apply_instance_deleted(event, state)
        case InstanceDraftModelUpdated():
            return apply_instance_draft_model_updated(event, state)
        case NodeCreated():
            return apply_topology_node_created(event, state)
        case NodeTimedOut():

@@ -169,6 +172,25 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
    return state.model_copy(update={"instances": new_instances})


def apply_instance_draft_model_updated(
    event: InstanceDraftModelUpdated, state: State
) -> State:
    if event.instance_id not in state.instances:
        return state
    instance = state.instances[event.instance_id]
    updated_instance = instance.model_copy(
        update={
            "draft_model": event.draft_model,
            "num_draft_tokens": event.num_draft_tokens,
        }
    )
    new_instances: Mapping[InstanceId, Instance] = {
        **state.instances,
        event.instance_id: updated_instance,
    }
    return state.model_copy(update={"instances": new_instances})
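
Because event_apply is a pure function over pydantic state, the new reducer is straightforward to exercise in isolation. A hedged test sketch (building `state` with one instance is assumed to work as in existing tests; the model id is illustrative):

event = InstanceDraftModelUpdated(
    instance_id=instance_id,
    draft_model=ModelId("mlx-community/Qwen2.5-0.5B-Instruct-4bit"),  # illustrative
    num_draft_tokens=4,
)
new_state = event_apply(event, state)
assert new_state.instances[instance_id].draft_model == event.draft_model
assert new_state.instances[instance_id].num_draft_tokens == 4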

def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
    new_runners: Mapping[RunnerId, RunnerStatus] = {
        **state.runners,

@@ -161,6 +161,8 @@ class ChatCompletionTaskParams(BaseModel):
    tool_choice: str | dict[str, Any] | None = None
    parallel_tool_calls: bool | None = None
    user: str | None = None
    # Speculative decoding: tokens to draft per iteration (if instance has draft model)
    num_draft_tokens: int = 3


class BenchChatCompletionTaskParams(ChatCompletionTaskParams):

@@ -172,6 +174,8 @@ class PlaceInstanceParams(BaseModel):
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1
    draft_model: ModelId | None = None  # For speculative decoding
    num_draft_tokens: int = 4  # Tokens to draft per iteration

    @field_validator("sharding", "instance_meta", mode="plain")
    @classmethod

@@ -213,3 +217,14 @@ class DeleteInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    instance_id: InstanceId


class SetDraftModelParams(BaseModel):
    draft_model: ModelId | None = None  # None to disable speculative decoding
    num_draft_tokens: int = 4


class SetDraftModelResponse(BaseModel):
    message: str
    command_id: CommandId
    instance_id: InstanceId

@@ -2,7 +2,7 @@ from pydantic import Field

from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -25,6 +25,8 @@ class PlaceInstance(BaseCommand):
    sharding: Sharding
    instance_meta: InstanceMeta
    min_nodes: int
    draft_model: ModelId | None = None  # For speculative decoding
    num_draft_tokens: int = 4  # Tokens to draft per iteration


class CreateInstance(BaseCommand):

@@ -35,6 +37,14 @@ class DeleteInstance(BaseCommand):
    instance_id: InstanceId


class SetInstanceDraftModel(BaseCommand):
    """Set or update the draft model for an existing instance."""

    instance_id: InstanceId
    draft_model: ModelId | None  # None to disable speculative decoding
    num_draft_tokens: int = 4


class TaskFinished(BaseCommand):
    finished_command_id: CommandId

@@ -50,6 +60,7 @@ Command = (
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | SetInstanceDraftModel
    | TaskFinished
)

@@ -5,6 +5,7 @@ from pydantic import Field

from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress

@@ -67,6 +68,14 @@ class InstanceDeleted(BaseEvent):
    instance_id: InstanceId


class InstanceDraftModelUpdated(BaseEvent):
    """Draft model updated on an existing instance."""

    instance_id: InstanceId
    draft_model: ModelId | None
    num_draft_tokens: int


class RunnerStatusUpdated(BaseEvent):
    runner_id: RunnerId
    runner_status: RunnerStatus

@@ -123,6 +132,7 @@ Event = (
    | TaskAcknowledged
    | InstanceCreated
    | InstanceDeleted
    | InstanceDraftModelUpdated
    | RunnerStatusUpdated
    | RunnerDeleted
    | NodeCreated

@@ -36,6 +36,12 @@ class DownloadModel(BaseTask):  # emitted by Worker
    shard_metadata: ShardMetadata


class DownloadDraftModel(BaseTask):  # emitted by Worker
    """Download a draft model for speculative decoding (rank 0 only)."""

    model_id: str  # HuggingFace model ID


class LoadModel(BaseTask):  # emitted by Worker
    pass

@@ -60,12 +66,21 @@ class Shutdown(BaseTask):  # emitted by Worker
    runner_id: RunnerId


class SetDraftModel(BaseTask):  # emitted by Worker
    """Load or clear a draft model on an already-running instance."""

    model_id: str | None  # HuggingFace model ID, or None to clear
    num_draft_tokens: int = 4


Task = (
    CreateRunner
    | DownloadModel
    | DownloadDraftModel
    | ConnectToGroup
    | LoadModel
    | StartWarmup
    | ChatCompletion
    | Shutdown
    | SetDraftModel
)
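
Together with the planner rules added further down, setting a draft model on a running instance yields a two-step task sequence on the rank-0 node. A hedged sketch (instance id and model name illustrative; other BaseTask fields assumed to default):

download = DownloadDraftModel(
    instance_id=instance_id,
    model_id="mlx-community/Qwen2.5-0.5B-Instruct-4bit",
)
# ...once that download completes, the planner emits:
update = SetDraftModel(
    instance_id=instance_id,
    model_id="mlx-community/Qwen2.5-0.5B-Instruct-4bit",
    num_draft_tokens=4,
)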

@@ -3,6 +3,7 @@ from enum import Enum
from pydantic import model_validator

from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -19,6 +20,8 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
    instance_id: InstanceId
    shard_assignments: ShardAssignments
    draft_model: ModelId | None = None  # For speculative decoding (rank 0 only)
    num_draft_tokens: int = 4  # Tokens to draft per iteration (when draft_model is set)

    def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
        return self.shard_assignments.runner_to_shard.get(runner_id, None)

@@ -48,6 +48,8 @@ def maybe_quantize_kv_cache(
def warmup_inference(
    model: Model,
    tokenizer: TokenizerWrapper,
    draft_model: Model | None = None,
    num_draft_tokens: int = 4,
) -> int:
    content = "Prompt to warm up the inference engine. Repeat this."

@@ -66,25 +68,30 @@ def warmup_inference(

    tokens_generated = 0

    cache = make_kv_cache(
        model=model,
    )

    # Use a default sampler for warmup
    sampler = make_sampler(temp=0.7)

    generate_kwargs: dict[str, object] = {
        "model": model,
        "tokenizer": tokenizer,
        "prompt": warmup_prompt,
        "max_tokens": 50,
        "sampler": sampler,
        "prefill_step_size": 2048,
        "kv_group_size": KV_GROUP_SIZE,
        "kv_bits": KV_BITS,
    }

    # Warm up with draft model if provided (speculative decoding path)
    if draft_model is not None:
        logger.info("Warming up with speculative decoding (draft model)")
        generate_kwargs["draft_model"] = draft_model
        generate_kwargs["num_draft_tokens"] = num_draft_tokens
    else:
        generate_kwargs["prompt_cache"] = make_kv_cache(model=model)

    logger.info("Generating warmup tokens")
    for _r in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=warmup_prompt,
        max_tokens=50,
        sampler=sampler,
        prompt_cache=cache,
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
    for _r in stream_generate(**generate_kwargs):  # type: ignore[arg-type]
        logger.info("Generated warmup token: " + str(_r.text))
        tokens_generated += 1

@@ -119,6 +126,8 @@ def mlx_generate(
    model: Model,
    tokenizer: TokenizerWrapper,
    task: ChatCompletionTaskParams,
    draft_model: Model | None = None,
    num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()

@@ -135,8 +144,6 @@ def mlx_generate(
        chat_task_data=task,
    )

    caches = make_kv_cache(model=model)

    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
    if is_bench:
        # Only sample length eos tokens

@@ -149,19 +156,31 @@ def mlx_generate(
    )

    max_tokens = task.max_tokens or MAX_TOKENS
    for out in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
        prompt_cache=caches,
        # TODO: Dynamically change prefill step size to be the maximum possible without timing out.
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):

    # Build kwargs for stream_generate, conditionally adding draft model params
    generate_kwargs: dict[str, object] = {
        "model": model,
        "tokenizer": tokenizer,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "sampler": sampler,
        "logits_processors": logits_processors,
        "prefill_step_size": 2048,
        "kv_group_size": KV_GROUP_SIZE,
        "kv_bits": KV_BITS,
    }

    # Add speculative decoding parameters if draft model is provided
    # Note: When using draft_model, we let mlx_lm create its own trimmable cache
    # as speculative decoding requires cache trimming capabilities
    if draft_model is not None:
        generate_kwargs["draft_model"] = draft_model
        generate_kwargs["num_draft_tokens"] = num_draft_tokens
    else:
        # Only use custom cache for non-speculative generation
        generate_kwargs["prompt_cache"] = make_kv_cache(model=model)

    for out in stream_generate(**generate_kwargs):  # type: ignore[arg-type]
        logger.info(out.text)

    stats: GenerationStats | None = None
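
The kwargs indirection reflects mlx_lm's API: stream_generate accepts draft_model and num_draft_tokens as optional keywords and manages its own trimmable cache on the speculative path. A standalone sketch of that underlying call, outside exo (model ids illustrative; assumes an mlx_lm build with speculative decoding support):

from mlx_lm import load, stream_generate

main_model, tokenizer = load("mlx-community/Qwen2.5-7B-Instruct-4bit")  # illustrative
small_model, _ = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")       # illustrative draft

for out in stream_generate(
    main_model,
    tokenizer,
    prompt="Repeat this sentence.",
    max_tokens=32,
    draft_model=small_model,   # drafts num_draft_tokens candidates per step
    num_draft_tokens=4,        # main model verifies them in one forward pass
):
    print(out.text, end="")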

@@ -258,6 +258,27 @@ def load_mlx_items(
    return cast(Model, model), tokenizer


def load_draft_model(model_id: str) -> nn.Module:
    """Load a draft model for speculative decoding (rank 0 only).

    Draft models are small models (typically 0.5B-2B parameters) used to
    generate candidate tokens quickly, which are then verified by the main
    model in a single forward pass.

    Assumes the model has already been downloaded by the worker.

    Args:
        model_id: HuggingFace model ID for the draft model

    Returns:
        The loaded draft model
    """
    model_path = build_model_path(model_id)
    draft_model, _ = load_model(model_path, strict=True)
    logger.info(f"Loaded draft model from {model_path}")
    return draft_model
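
Callers below wrap the result in a cast to the engine's Model type; a usage sketch (model id illustrative):

draft = cast(Model, load_draft_model("mlx-community/Qwen2.5-0.5B-Instruct-4bit"))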

def shard_and_load(
    shard_metadata: ShardMetadata,
    group: Group,

@@ -29,7 +29,9 @@ from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformance
from exo.shared.types.state import State
from exo.shared.types.tasks import (
    CreateRunner,
    DownloadDraftModel,
    DownloadModel,
    SetDraftModel,
    Shutdown,
    Task,
    TaskStatus,

@@ -48,6 +50,7 @@ from exo.utils.event_buffer import OrderedBuffer
from exo.worker.download.download_utils import (
    map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.impl_shard_downloader import build_full_shard
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor

@@ -202,42 +205,10 @@ class Worker:
                    )
                )
            case DownloadModel(shard_metadata=shard):
                if shard.model_meta.model_id not in self.download_status:
                    progress = DownloadPending(
                        shard_metadata=shard, node_id=self.node_id
                    )
                    self.download_status[shard.model_meta.model_id] = progress
                    await self.event_sender.send(
                        NodeDownloadProgress(download_progress=progress)
                    )
                initial_progress = (
                    await self.shard_downloader.get_shard_download_status_for_shard(
                        shard
                    )
                )
                if initial_progress.status == "complete":
                    progress = DownloadCompleted(
                        shard_metadata=shard,
                        node_id=self.node_id,
                        total_bytes=initial_progress.total_bytes,
                    )
                    self.download_status[shard.model_meta.model_id] = progress
                    await self.event_sender.send(
                        NodeDownloadProgress(download_progress=progress)
                    )
                    await self.event_sender.send(
                        TaskStatusUpdated(
                            task_id=task.task_id,
                            task_status=TaskStatus.Complete,
                        )
                    )
                else:
                    await self.event_sender.send(
                        TaskStatusUpdated(
                            task_id=task.task_id, task_status=TaskStatus.Running
                        )
                    )
                    self._handle_shard_download_process(task, initial_progress)
                await self._handle_download(shard, task)
            case DownloadDraftModel(model_id=model_id):
                shard = await build_full_shard(model_id)
                await self._handle_download(shard, task)
            case Shutdown(runner_id=runner_id):
                try:
                    with fail_after(3):

@@ -248,6 +219,25 @@ class Worker:
                        task_id=task.task_id, task_status=TaskStatus.TimedOut
                    )
                )
            case SetDraftModel(
                model_id=draft_model_id, num_draft_tokens=num_tokens
            ):
                runner = self.runners[self._task_to_runner_id(task)]
                await runner.start_task(task)
                # Update bound_instance to reflect new/cleared draft model
                updated_instance = runner.bound_instance.instance.model_copy(
                    update={
                        "draft_model": (
                            ModelId(draft_model_id)
                            if draft_model_id is not None
                            else None
                        ),
                        "num_draft_tokens": num_tokens,
                    }
                )
                runner.bound_instance = runner.bound_instance.model_copy(
                    update={"instance": updated_instance}
                )
            case task:
                await self.runners[self._task_to_runner_id(task)].start_task(task)

@@ -340,6 +330,46 @@ class Worker:
        self._tg.start_soon(runner.run)
        return runner

    async def _handle_download(self, shard: ShardMetadata, task: Task) -> None:
        """Handle model download - shared logic for main and draft models."""
        model_id = shard.model_meta.model_id

        if model_id not in self.download_status:
            progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
            self.download_status[model_id] = progress
            await self.event_sender.send(
                NodeDownloadProgress(download_progress=progress)
            )

        initial_progress = (
            await self.shard_downloader.get_shard_download_status_for_shard(shard)
        )

        if initial_progress.status == "complete":
            progress = DownloadCompleted(
                shard_metadata=shard,
                node_id=self.node_id,
                total_bytes=initial_progress.total_bytes,
            )
            self.download_status[model_id] = progress
            await self.event_sender.send(
                NodeDownloadProgress(download_progress=progress)
            )
            await self.event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
            )
        else:
            await self.event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
            )
            download_task = DownloadModel(
                instance_id=task.instance_id,
                shard_metadata=shard,
                task_id=task.task_id,
                task_status=task.task_status,
            )
            self._handle_shard_download_process(download_task, initial_progress)

    def _handle_shard_download_process(
        self,
        task: DownloadModel,

@@ -8,8 +8,10 @@ from exo.shared.types.tasks import (
    ChatCompletion,
    ConnectToGroup,
    CreateRunner,
    DownloadDraftModel,
    DownloadModel,
    LoadModel,
    SetDraftModel,
    Shutdown,
    StartWarmup,
    Task,

@@ -38,6 +40,16 @@ from exo.shared.types.worker.runners import (
from exo.worker.runner.runner_supervisor import RunnerSupervisor


def _is_download_in_progress_or_complete(
    model_id: ModelId,
    download_status: Mapping[ModelId, DownloadProgress],
) -> bool:
    """Check if model download is in progress or complete."""
    return model_id in download_status and isinstance(
        download_status[model_id], (DownloadOngoing, DownloadCompleted)
    )


def plan(
    node_id: NodeId,
    # Runners is expected to be FRESH and so should not come from state

@@ -55,9 +67,11 @@ def plan(
        _kill_runner(runners, all_runners, instances)
        or _create_runner(node_id, runners, instances)
        or _model_needs_download(runners, download_status)
        or _draft_model_needs_download(runners, download_status, instances)
        or _init_distributed_backend(runners, all_runners)
        or _load_model(runners, all_runners, global_download_status)
        or _load_model(runners, all_runners, global_download_status, download_status)
        or _ready_to_warmup(runners, all_runners)
        or _set_draft_model(runners, instances, download_status)
        or _pending_tasks(runners, tasks, all_runners)
    )
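
plan() is a priority chain: every rule returns a Task or None and `or` short-circuits to the first match, so download rules always pre-empt load, warmup, and draft-model updates. The same dispatch pattern in isolation:

def first_non_none(*rules):
    """Return the first rule's non-None result, mirroring plan()'s `or` chain."""
    for rule in rules:
        task = rule()
        if task is not None:
            return task
    return None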

@@ -115,12 +129,9 @@ def _model_needs_download(
) -> DownloadModel | None:
    for runner in runners.values():
        model_id = runner.bound_instance.bound_shard.model_meta.model_id
        if isinstance(runner.status, RunnerIdle) and (
            model_id not in download_status
            or not isinstance(
                download_status[model_id], (DownloadOngoing, DownloadCompleted)
            )
        ):
        if isinstance(
            runner.status, RunnerIdle
        ) and not _is_download_in_progress_or_complete(model_id, download_status):
            # We don't invalidate download_status randomly in case a file gets deleted on disk
            return DownloadModel(
                instance_id=runner.bound_instance.instance.instance_id,

@@ -128,6 +139,43 @@ ...
            )


def _draft_model_needs_download(
    runners: Mapping[RunnerId, RunnerSupervisor],
    download_status: Mapping[ModelId, DownloadProgress],
    instances: Mapping[InstanceId, Instance],
) -> DownloadDraftModel | None:
    """Check if draft model needs download for rank 0 runner.

    Triggers download when:
    - RunnerIdle with draft model (initial setup)
    - RunnerReady with new draft model (updated via API)
    """
    rank_0_runner = next(
        (r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
        None,
    )
    if rank_0_runner is None:
        return None
    if not isinstance(rank_0_runner.status, (RunnerIdle, RunnerReady)):
        return None

    # Use current instance state (may have been updated via API)
    instance_id = rank_0_runner.bound_instance.instance.instance_id
    current_instance = instances.get(instance_id)
    if current_instance is None:
        return None

    draft_model_id = current_instance.draft_model
    if draft_model_id is None:
        return None
    if _is_download_in_progress_or_complete(draft_model_id, download_status):
        return None
    return DownloadDraftModel(
        instance_id=instance_id,
        model_id=str(draft_model_id),
    )


def _init_distributed_backend(
    runners: Mapping[RunnerId, RunnerSupervisor],
    all_runners: Mapping[RunnerId, RunnerStatus],

@@ -182,10 +230,12 @@ def _load_model(
    runners: Mapping[RunnerId, RunnerSupervisor],
    all_runners: Mapping[RunnerId, RunnerStatus],
    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
    download_status: Mapping[ModelId, DownloadProgress],
) -> LoadModel | None:
    for runner in runners.values():
        instance = runner.bound_instance.instance
        shard_assignments = instance.shard_assignments
        shard = runner.bound_instance.bound_shard

        all_local_downloads_complete = all(
            nid in global_download_status

@@ -199,6 +249,14 @@ ...
        if not all_local_downloads_complete:
            continue

        # Rank 0 with draft model must wait for draft download before loading
        if shard.device_rank == 0:
            draft_model_id = instance.draft_model
            if draft_model_id is not None and not isinstance(
                download_status.get(draft_model_id), DownloadCompleted
            ):
                continue

        is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
        if is_single_node_instance and isinstance(runner.status, RunnerIdle):
            return LoadModel(instance_id=instance.instance_id)

@@ -258,6 +316,53 @@ ...
    return None


def _set_draft_model(
    runners: Mapping[RunnerId, RunnerSupervisor],
    instances: Mapping[InstanceId, Instance],
    download_status: Mapping[ModelId, DownloadProgress],
) -> SetDraftModel | None:
    """Check if rank 0 runner needs to load or clear a draft model."""
    rank_0_runner = next(
        (r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
        None,
    )
    if rank_0_runner is None:
        return None
    if not isinstance(rank_0_runner.status, RunnerReady):
        return None

    instance_id = rank_0_runner.bound_instance.instance.instance_id
    current_instance = instances.get(instance_id)
    if current_instance is None:
        return None

    # Compare runner's bound draft model vs current instance draft model
    runner_draft_model = rank_0_runner.bound_instance.instance.draft_model
    current_draft_model = current_instance.draft_model

    if runner_draft_model == current_draft_model:
        return None

    # Draft model changed - need to update
    if current_draft_model is None:
        # Clear draft model
        return SetDraftModel(
            instance_id=instance_id,
            model_id=None,
            num_draft_tokens=4,
        )

    # Wait for draft model to be downloaded
    if not isinstance(download_status.get(current_draft_model), DownloadCompleted):
        return None

    return SetDraftModel(
        instance_id=instance_id,
        model_id=str(current_draft_model),
        num_draft_tokens=current_instance.num_draft_tokens,
    )
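
The rule reconciles desired state (the instance record the API last updated) against actual state (the copy bound to the rank-0 runner); compressed, the comparison is:

# Compressed sketch of the reconciliation above:
desired = current_instance.draft_model                        # set via the API
actual = rank_0_runner.bound_instance.instance.draft_model    # what is loaded
if desired != actual:
    ...  # clear now, or load once the draft download has completed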

def _pending_tasks(
    runners: Mapping[RunnerId, RunnerSupervisor],
    tasks: Mapping[TaskId, Task],

@@ -28,6 +28,7 @@ from exo.shared.types.tasks import (
    ChatCompletion,
    ConnectToGroup,
    LoadModel,
    SetDraftModel,
    Shutdown,
    StartWarmup,
    Task,

@@ -56,6 +57,7 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
    initialize_mlx,
    load_draft_model,
    load_mlx_items,
    mlx_force_oom,
)

@@ -110,6 +112,7 @@ def main(
    model = None
    tokenizer = None
    group = None
    draft_model: Model | None = None  # Loaded during warmup if instance has draft_model

    current_status: RunnerStatus = RunnerIdle()
    logger.info("runner created")

@@ -165,6 +168,16 @@ ...
                    bound_instance, group, on_timeout=on_model_load_timeout
                )

                # Load draft model for speculative decoding (rank 0 only)
                if (
                    instance.draft_model is not None
                    and shard_metadata.device_rank == 0
                ):
                    logger.info(f"Loading draft model: {instance.draft_model}")
                    draft_model = cast(
                        Model, load_draft_model(str(instance.draft_model))
                    )

                current_status = RunnerLoaded()
                logger.info("runner loaded")
            case StartWarmup() if isinstance(current_status, RunnerLoaded):

@@ -182,7 +195,8 @@ ...
                toks = warmup_inference(
                    model=cast(Model, model),
                    tokenizer=tokenizer,
                    # kv_prefix_cache=kv_prefix_cache,  # supply for warmup-time prefix caching
                    draft_model=draft_model,
                    num_draft_tokens=instance.num_draft_tokens,
                )
                logger.info(f"warmed up by generating {toks} tokens")
                logger.info(

@@ -212,11 +226,13 @@ ...
                assert task_params.messages[0].content is not None
                _check_for_debug_prompts(task_params.messages[0].content)

                # Generate responses using the actual MLX generation
                # Generate responses (draft_model loaded at warmup if configured)
                mlx_generator = mlx_generate(
                    model=cast(Model, model),
                    tokenizer=tokenizer,
                    task=task_params,
                    draft_model=draft_model,
                    num_draft_tokens=instance.num_draft_tokens,
                )

                # GPT-OSS specific parsing to match other model formats.

@@ -245,6 +261,50 @@ ...

                current_status = RunnerReady()
                logger.info("runner ready")
            case SetDraftModel(
                model_id=draft_model_id, num_draft_tokens=num_tokens
            ) if isinstance(current_status, RunnerReady):
                current_status = RunnerWarmingUp()
                logger.info("runner warming up (setting draft model)")
                event_sender.send(
                    RunnerStatusUpdated(
                        runner_id=runner_id, runner_status=current_status
                    )
                )
                assert model is not None
                assert tokenizer is not None

                if draft_model_id is None:
                    # Clear draft model
                    logger.info("Clearing draft model")
                    draft_model = None
                    instance = instance.model_copy(
                        update={
                            "draft_model": None,
                            "num_draft_tokens": 4,
                        }
                    )
                else:
                    # Load new draft model
                    logger.info(f"Loading draft model: {draft_model_id}")
                    draft_model = cast(Model, load_draft_model(draft_model_id))
                    instance = instance.model_copy(
                        update={
                            "draft_model": ModelId(draft_model_id),
                            "num_draft_tokens": num_tokens,
                        }
                    )
                    # Warm up with speculative decoding
                    logger.info("Warming up with new draft model")
                    warmup_inference(
                        model=cast(Model, model),
                        tokenizer=tokenizer,
                        draft_model=draft_model,
                        num_draft_tokens=num_tokens,
                    )
                    logger.info("Draft model loaded and warmed up")

                current_status = RunnerReady()
            case Shutdown():
                current_status = RunnerShuttingDown()
                logger.info("runner shutting down")

@@ -265,7 +325,7 @@ ...
        RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
    )
    if isinstance(current_status, RunnerShutdown):
        del model, tokenizer, group
        del model, tokenizer, group, draft_model
        mx.clear_cache()
        import gc