Compare commits

...

1 Commit

Author SHA1 Message Date
Alex Cheema
e7f61c3494 Add speculative decoding support with draft models
This adds support for speculative decoding using draft models to accelerate
inference. Key changes:

- Add draft_model and num_draft_tokens fields to Instance for configuration
- Add SetDraftModel task to load/clear draft models on running instances
- Add InstanceDraftModelUpdated event to propagate draft model changes
- Add SetInstanceDraftModel command and API endpoint for runtime updates
- Update plan.py to download draft models in parallel with main model
- Update runner to load draft model during LoadModel phase
- Add draft model UI to dashboard instances panel (both views)

The draft model can be configured when creating an instance or updated on
a running instance via the dashboard or API.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 13:09:26 +00:00
16 changed files with 687 additions and 91 deletions
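As the commit message notes, the draft model can be set at creation time (PlaceInstanceParams grows draft_model / num_draft_tokens fields) or changed on a running instance via the new PUT /instance/{instance_id}/draft_model route. A minimal sketch of the runtime update, assuming the third-party requests package plus a placeholder base URL and instance ID (only the route and payload shape come from this diff):

import requests

BASE = "http://localhost:8000"    # assumed base URL; substitute your exo API address
INSTANCE_ID = "your-instance-id"  # placeholder

# Enable speculative decoding on a running instance.
resp = requests.put(
    f"{BASE}/instance/{INSTANCE_ID}/draft_model",
    json={
        "draft_model": "mlx-community/Qwen2.5-0.5B-Instruct-4bit",  # example draft model
        "num_draft_tokens": 4,
    },
)
resp.raise_for_status()
# SetDraftModelResponse: {"message": "Command received.", "command_id": ..., "instance_id": ...}
print(resp.json())

# Sending draft_model: null clears the draft model and disables speculative decoding.
requests.put(
    f"{BASE}/instance/{INSTANCE_ID}/draft_model",
    json={"draft_model": None, "num_draft_tokens": 4},
)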

View File

@@ -69,6 +69,8 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
// Granular node state types from the new state structure

View File

@@ -162,6 +162,12 @@
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Draft model editing state
let editingDraftInstanceId = $state<string | null>(null);
let editDraftModel = $state<string | null>(null);
let editNumDraftTokens = $state<number>(4);
let draftEditDropdownSearch = $state("");
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state("");
@@ -1012,6 +1018,53 @@
}
}
// Open draft model edit modal for an instance
function openDraftModelEdit(
instanceId: string,
currentDraftModel: string | null,
currentNumTokens: number | null,
): void {
editingDraftInstanceId = instanceId;
editDraftModel = currentDraftModel;
editNumDraftTokens = currentNumTokens ?? 4;
draftEditDropdownSearch = "";
}
// Close draft model edit modal
function closeDraftModelEdit(): void {
editingDraftInstanceId = null;
editDraftModel = null;
editNumDraftTokens = 4;
draftEditDropdownSearch = "";
}
// Save draft model settings for an instance
async function saveDraftModel(): Promise<void> {
if (!editingDraftInstanceId) return;
try {
const response = await fetch(
`/instance/${editingDraftInstanceId}/draft_model`,
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
draft_model: editDraftModel,
num_draft_tokens: editNumDraftTokens,
}),
},
);
if (!response.ok) {
const errorText = await response.text();
console.error("Failed to set draft model:", errorText);
}
} catch (error) {
console.error("Error setting draft model:", error);
} finally {
closeDraftModelEdit();
}
}
// Helper to unwrap tagged unions like { MlxRingInstance: {...} }
function getTagged(obj: unknown): [string | null, unknown] {
if (!obj || typeof obj !== "object") return [null, null];
@@ -1037,6 +1090,8 @@
nodeNames: string[];
nodeIds: string[];
nodeCount: number;
draftModel: string | null;
numDraftTokens: number | null;
} {
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== "object") {
@@ -1046,6 +1101,8 @@
nodeNames: [],
nodeIds: [],
nodeCount: 0,
draftModel: null,
numDraftTokens: null,
};
}
@@ -1063,6 +1120,8 @@
nodeToRunner?: Record<string, string>;
runnerToShard?: Record<string, unknown>;
};
draftModel?: string;
numDraftTokens?: number;
};
// Sharding strategy from first shard
@@ -1085,12 +1144,18 @@
return node?.friendly_name || nodeId.slice(0, 8);
});
// Draft model for speculative decoding
const draftModel = inst.draftModel ?? null;
const numDraftTokens = inst.numDraftTokens ?? null;
return {
instanceType,
sharding,
nodeNames,
nodeIds,
nodeCount: nodeIds.length,
draftModel,
numDraftTokens,
};
}
@@ -1788,12 +1853,42 @@
>{id.slice(0, 8).toUpperCase()}</span
>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
<div class="flex items-center gap-2">
<button
onclick={() =>
openDraftModelEdit(
id,
instanceInfo.draftModel,
instanceInfo.numDraftTokens,
)}
class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel
? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500'
: 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
title={instanceInfo.draftModel
? `Draft: ${instanceInfo.draftModel.split("/").pop()} (${instanceInfo.numDraftTokens}t)`
: "Configure speculative decoding"}
>
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M13 10V3L4 14h7v7l9-11h-7z"
/>
</svg>
</button>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
</div>
<div class="pl-2">
<div
@@ -1806,6 +1901,13 @@
>{instanceInfo.sharding} ({instanceInfo.instanceType})</span
>
</div>
{#if instanceInfo.draftModel}
<div class="text-cyan-400/80 text-xs font-mono">
Draft: <span class="text-cyan-400"
>{instanceInfo.draftModel.split("/").pop()} ({instanceInfo.numDraftTokens}t)</span
>
</div>
{/if}
{#if instanceModelId && instanceModelId !== "Unknown" && instanceModelId !== "Unknown Model"}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -2710,12 +2812,42 @@
>{id.slice(0, 8).toUpperCase()}</span
>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
<div class="flex items-center gap-2">
<button
onclick={() =>
openDraftModelEdit(
id,
instanceInfo.draftModel,
instanceInfo.numDraftTokens,
)}
class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel
? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500'
: 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
title={instanceInfo.draftModel
? `Draft: ${instanceInfo.draftModel.split("/").pop()} (${instanceInfo.numDraftTokens}t)`
: "Configure speculative decoding"}
>
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M13 10V3L4 14h7v7l9-11h-7z"
/>
</svg>
</button>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
</div>
<div class="pl-2">
<div
@@ -2728,6 +2860,13 @@
>{instanceInfo.sharding} ({instanceInfo.instanceType})</span
>
</div>
{#if instanceInfo.draftModel}
<div class="text-cyan-400/80 text-xs font-mono">
Draft: <span class="text-cyan-400"
>{instanceInfo.draftModel.split("/").pop()} ({instanceInfo.numDraftTokens}t)</span
>
</div>
{/if}
{#if instanceModelId && instanceModelId !== "Unknown" && instanceModelId !== "Unknown Model"}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -3005,5 +3144,111 @@
{/if}
</div>
{/if}
<!-- Draft Model Edit Modal -->
{#if editingDraftInstanceId}
<div
class="fixed inset-0 bg-black/80 backdrop-blur-sm z-50 flex items-center justify-center"
role="dialog"
aria-modal="true"
onclick={closeDraftModelEdit}
onkeydown={(e) => e.key === "Escape" && closeDraftModelEdit()}
>
<div
class="bg-exo-dark-gray border border-exo-yellow/30 p-6 max-w-md w-full mx-4"
onclick={(e) => e.stopPropagation()}
onkeydown={(e) => e.stopPropagation()}
role="document"
>
<h3 class="text-exo-yellow font-mono text-lg tracking-wider mb-4">
SPECULATIVE DECODING
</h3>
<div class="space-y-4">
<!-- Draft Model Selection -->
<div>
<label
class="block text-white/60 text-xs font-mono tracking-wider mb-2"
>DRAFT MODEL</label
>
<div class="relative">
<input
type="text"
bind:value={draftEditDropdownSearch}
placeholder={editDraftModel || "Select a draft model..."}
class="w-full bg-exo-dark-gray/60 border border-exo-yellow/30 text-white text-sm font-mono px-3 py-2 focus:outline-none focus:border-exo-yellow/60"
/>
{#if draftEditDropdownSearch}
<div
class="absolute z-10 w-full mt-1 bg-exo-dark-gray border border-exo-yellow/30 max-h-48 overflow-y-auto"
>
{#each models.filter((m) => m.id
.toLowerCase()
.includes(draftEditDropdownSearch.toLowerCase())) as model}
<button
class="w-full text-left px-3 py-2 text-sm font-mono text-white/80 hover:bg-exo-yellow/10 hover:text-exo-yellow cursor-pointer"
onclick={() => {
editDraftModel = model.hugging_face_id || model.id;
draftEditDropdownSearch = "";
}}
>
{model.name || model.id}
</button>
{/each}
</div>
{/if}
</div>
{#if editDraftModel}
<div
class="mt-2 flex items-center justify-between text-cyan-400 text-xs font-mono"
>
<span>{editDraftModel}</span>
<button
class="text-red-400 hover:text-red-300 cursor-pointer"
onclick={() => (editDraftModel = null)}
>
Clear
</button>
</div>
{/if}
</div>
<!-- Num Draft Tokens -->
<div>
<label
class="block text-white/60 text-xs font-mono tracking-wider mb-2"
>DRAFT TOKENS</label
>
<input
type="number"
min="1"
max="16"
bind:value={editNumDraftTokens}
class="w-full bg-exo-dark-gray/60 border border-exo-yellow/30 text-white text-sm font-mono px-3 py-2 focus:outline-none focus:border-exo-yellow/60"
/>
<p class="text-white/40 text-xs font-mono mt-1">
Number of tokens to draft per step (1-16)
</p>
</div>
</div>
<!-- Actions -->
<div class="flex justify-end gap-3 mt-6">
<button
onclick={closeDraftModelEdit}
class="px-4 py-2 text-xs font-mono tracking-wider uppercase border border-white/30 text-white/60 hover:bg-white/10 hover:text-white transition-all cursor-pointer"
>
CANCEL
</button>
<button
onclick={saveDraftModel}
class="px-4 py-2 text-xs font-mono tracking-wider uppercase border border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 transition-all cursor-pointer"
>
SAVE
</button>
</div>
</div>
</div>
{/if}
</main>
</div>

View File

@@ -55,6 +55,8 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
SetDraftModelParams,
SetDraftModelResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
@@ -68,6 +70,7 @@ from exo.shared.types.commands import (
ImageGeneration,
PlaceInstance,
SendInputChunk,
SetInstanceDraftModel,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -215,6 +218,7 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.put("/instance/{instance_id}/draft_model")(self.set_draft_model)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/v1/chat/completions", response_model=None)(
@@ -238,6 +242,8 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)
@@ -437,6 +443,24 @@ class API:
instance_id=instance_id,
)
async def set_draft_model(
self, instance_id: InstanceId, payload: SetDraftModelParams
) -> SetDraftModelResponse:
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
command = SetInstanceDraftModel(
instance_id=instance_id,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)
return SetDraftModelResponse(
message="Command received.",
command_id=command.command_id,
instance_id=instance_id,
)
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:

View File

@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
SetInstanceDraftModel,
TaskFinished,
TestCommand,
)
@@ -31,6 +32,7 @@ from exo.shared.types.events import (
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
@@ -278,6 +280,14 @@ class Master:
chunk=chunk,
)
)
case SetInstanceDraftModel():
generated_events.append(
InstanceDraftModelUpdated(
instance_id=command.instance_id,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -138,6 +138,8 @@ def place_instance(
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -152,6 +154,8 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -12,6 +12,7 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -60,6 +61,8 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceDraftModelUpdated():
return apply_instance_draft_model_updated(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -178,6 +181,25 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
def apply_instance_draft_model_updated(
event: InstanceDraftModelUpdated, state: State
) -> State:
if event.instance_id not in state.instances:
return state
instance = state.instances[event.instance_id]
updated_instance = instance.model_copy(
update={
"draft_model": event.draft_model,
"num_draft_tokens": event.num_draft_tokens,
}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
event.instance_id: updated_instance,
}
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
**state.runners,

View File

@@ -177,6 +177,8 @@ class ChatCompletionTaskParams(BaseModel):
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
num_draft_tokens: int = 3
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -188,6 +190,8 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod
@@ -340,3 +344,14 @@ class ImageListItem(BaseModel, frozen=True):
class ImageListResponse(BaseModel, frozen=True):
data: list[ImageListItem]
class SetDraftModelParams(BaseModel):
draft_model: ModelId | None = None # None to disable speculative decoding
num_draft_tokens: int = 4
class SetDraftModelResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId

View File

@@ -7,7 +7,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, ModelId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -38,6 +38,8 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
class CreateInstance(BaseCommand):
@@ -48,6 +50,14 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class SetInstanceDraftModel(BaseCommand):
"""Set or update the draft model for an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None # None to disable speculative decoding
num_draft_tokens: int = 4
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -71,6 +81,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| SetInstanceDraftModel
| TaskFinished
| SendInputChunk
)

View File

@@ -4,7 +4,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -67,6 +67,14 @@ class InstanceDeleted(BaseEvent):
instance_id: InstanceId
class InstanceDraftModelUpdated(BaseEvent):
"""Draft model updated on an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None
num_draft_tokens: int
class RunnerStatusUpdated(BaseEvent):
runner_id: RunnerId
runner_status: RunnerStatus
@@ -118,6 +126,7 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceDraftModelUpdated
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut

View File

@@ -40,6 +40,12 @@ class DownloadModel(BaseTask): # emitted by Worker
shard_metadata: ShardMetadata
class DownloadDraftModel(BaseTask): # emitted by Worker
"""Download a draft model for speculative decoding (rank 0 only)."""
model_id: str # HuggingFace model ID
class LoadModel(BaseTask): # emitted by Worker
pass
@@ -80,9 +86,17 @@ class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
class SetDraftModel(BaseTask): # emitted by Worker
"""Load or clear a draft model on an already-running instance."""
model_id: str | None # HuggingFace model ID, or None to clear
num_draft_tokens: int = 4
Task = (
CreateRunner
| DownloadModel
| DownloadDraftModel
| ConnectToGroup
| LoadModel
| StartWarmup
@@ -90,4 +104,5 @@ Task = (
| ImageGeneration
| ImageEdits
| Shutdown
| SetDraftModel
)

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.common import Host, Id, ModelId, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +19,8 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -48,6 +48,8 @@ def maybe_quantize_kv_cache(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -66,25 +68,30 @@ def warmup_inference(
tokens_generated = 0
cache = make_kv_cache(
model=model,
)
# Use a default sampler for warmup
sampler = make_sampler(temp=0.7)
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": warmup_prompt,
"max_tokens": 50,
"sampler": sampler,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Warm up with draft model if provided (speculative decoding path)
if draft_model is not None:
logger.info("Warming up with speculative decoding (draft model)")
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
logger.info("Generating warmup tokens")
for _r in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=warmup_prompt,
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
for _r in stream_generate(**generate_kwargs): # type: ignore[arg-type]
logger.info("Generated warmup token: " + str(_r.text))
tokens_generated += 1
@@ -120,6 +127,8 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -131,8 +140,6 @@ def mlx_generate(
if task.seed is not None:
mx.random.seed(task.seed)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -145,19 +152,31 @@ def mlx_generate(
)
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
# Build kwargs for stream_generate, conditionally adding draft model params
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": prompt,
"max_tokens": max_tokens,
"sampler": sampler,
"logits_processors": logits_processors,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Add speculative decoding parameters if draft model is provided
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
# as speculative decoding requires cache trimming capabilities
if draft_model is not None:
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
# Only use custom cache for non-speculative generation
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
logger.info(out.text)
stats: GenerationStats | None = None
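The branch above is the behavioral core of this file's change: with a draft model, stream_generate is left to build its own trimmable cache (rejected draft tokens must be rolled back), otherwise the custom make_kv_cache path is kept. Stripped of the runner plumbing, the speculative path reduces to a plain mlx_lm call; a sketch with illustrative model IDs (the diff only establishes that stream_generate accepts draft_model and num_draft_tokens):

from mlx_lm import load, stream_generate

# Main and draft model must share a tokenizer/vocabulary; the Qwen2.5 family does.
model, tokenizer = load("mlx-community/Qwen2.5-7B-Instruct-4bit")  # example main model
draft_model, _ = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # example draft model

for out in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="Explain speculative decoding in one sentence.",
    max_tokens=64,
    draft_model=draft_model,  # the small model proposes tokens...
    num_draft_tokens=4,       # ...four at a time; the main model verifies them in one pass
    # No prompt_cache here: mlx_lm creates a trimmable cache internally so it
    # can discard rejected draft tokens.
):
    print(out.text, end="", flush=True)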

View File

@@ -230,6 +230,27 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: ModelId) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
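The docstring above is worth quantifying. With num_draft_tokens = n, each main-model forward pass accepts between 0 and n drafted tokens and always contributes one token of its own, so under an idealized independent per-token acceptance rate p (an assumption for illustration, not something this diff measures), the expected tokens per verification pass is 1 + p + p² + ... + pⁿ:

# Back-of-envelope tokens emitted per main-model forward pass under an
# assumed i.i.d. acceptance rate p; illustrative only.
def expected_tokens_per_step(n: int, p: float) -> float:
    # E[accepted drafts] = p + p^2 + ... + p^n, plus 1 token from the verify pass.
    return sum(p**k for k in range(n + 1))

for p in (0.5, 0.7, 0.9):
    print(f"p={p}: ~{expected_tokens_per_step(4, p):.2f} tokens/step")
# p=0.5: ~1.94, p=0.7: ~2.77, p=0.9: ~4.10 -- why a small, well-matched
# 0.5B-2B draft model can give a meaningful speedup at n=4.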

View File

@@ -29,8 +29,10 @@ from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadDraftModel,
DownloadModel,
ImageEdits,
SetDraftModel,
Shutdown,
Task,
TaskStatus,
@@ -51,6 +53,7 @@ from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.impl_shard_downloader import build_full_shard
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -206,42 +209,10 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
await self._handle_download(shard, task)
case DownloadDraftModel(model_id=model_id):
shard = await build_full_shard(ModelId(model_id))
await self._handle_download(shard, task)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -292,6 +263,25 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case SetDraftModel(
model_id=draft_model_id, num_draft_tokens=num_tokens
):
runner = self.runners[self._task_to_runner_id(task)]
await runner.start_task(task)
# Update bound_instance to reflect new/cleared draft model
updated_instance = runner.bound_instance.instance.model_copy(
update={
"draft_model": (
ModelId(draft_model_id)
if draft_model_id is not None
else None
),
"num_draft_tokens": num_tokens,
}
)
runner.bound_instance = runner.bound_instance.model_copy(
update={"instance": updated_instance}
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
@@ -386,6 +376,46 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
async def _handle_download(self, shard: ShardMetadata, task: Task) -> None:
"""Handle model download - shared logic for main and draft models."""
model_id = shard.model_card.model_id
if model_id not in self.download_status:
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
else:
await self.event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
download_task = DownloadModel(
instance_id=task.instance_id,
shard_metadata=shard,
task_id=task.task_id,
task_status=task.task_status,
)
self._handle_shard_download_process(download_task, initial_progress)
def _handle_shard_download_process(
self,
task: DownloadModel,

View File

@@ -8,10 +8,12 @@ from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadDraftModel,
DownloadModel,
ImageEdits,
ImageGeneration,
LoadModel,
SetDraftModel,
Shutdown,
StartWarmup,
Task,
@@ -40,6 +42,16 @@ from exo.shared.types.worker.runners import (
from exo.worker.runner.runner_supervisor import RunnerSupervisor
def _is_download_in_progress_or_complete(
model_id: ModelId,
download_status: Mapping[ModelId, DownloadProgress],
) -> bool:
"""Check if model download is in progress or complete."""
return model_id in download_status and isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
def plan(
node_id: NodeId,
# Runners is expected to be FRESH and so should not come from state
@@ -59,9 +71,11 @@ def plan(
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _draft_model_needs_download(runners, download_status, instances)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _load_model(runners, all_runners, global_download_status, download_status)
or _ready_to_warmup(runners, all_runners)
or _set_draft_model(runners, instances, download_status)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
)
@@ -119,12 +133,9 @@ def _model_needs_download(
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
):
if isinstance(
runner.status, RunnerIdle
) and not _is_download_in_progress_or_complete(model_id, download_status):
# We deliberately never re-validate download_status against disk, so a file deleted on disk won't retrigger a download
return DownloadModel(
instance_id=runner.bound_instance.instance.instance_id,
@@ -132,6 +143,43 @@ def _model_needs_download(
)
def _draft_model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
instances: Mapping[InstanceId, Instance],
) -> DownloadDraftModel | None:
"""Check if draft model needs download for rank 0 runner.
Triggers download when:
- RunnerIdle with draft model (initial setup)
- RunnerReady with new draft model (updated via API)
"""
rank_0_runner = next(
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
None,
)
if rank_0_runner is None:
return None
if not isinstance(rank_0_runner.status, (RunnerIdle, RunnerReady)):
return None
# Use current instance state (may have been updated via API)
instance_id = rank_0_runner.bound_instance.instance.instance_id
current_instance = instances.get(instance_id)
if current_instance is None:
return None
draft_model_id = current_instance.draft_model
if draft_model_id is None:
return None
if _is_download_in_progress_or_complete(draft_model_id, download_status):
return None
return DownloadDraftModel(
instance_id=instance_id,
model_id=str(draft_model_id),
)
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -186,10 +234,12 @@ def _load_model(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
download_status: Mapping[ModelId, DownloadProgress],
) -> LoadModel | None:
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
shard = runner.bound_instance.bound_shard
all_local_downloads_complete = all(
nid in global_download_status
@@ -203,6 +253,14 @@ def _load_model(
if not all_local_downloads_complete:
continue
# Rank 0 with draft model must wait for draft download before loading
if shard.device_rank == 0:
draft_model_id = instance.draft_model
if draft_model_id is not None and not isinstance(
download_status.get(draft_model_id), DownloadCompleted
):
continue
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
@@ -262,6 +320,53 @@ def _ready_to_warmup(
return None
def _set_draft_model(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
download_status: Mapping[ModelId, DownloadProgress],
) -> SetDraftModel | None:
"""Check if rank 0 runner needs to load or clear a draft model."""
rank_0_runner = next(
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
None,
)
if rank_0_runner is None:
return None
if not isinstance(rank_0_runner.status, RunnerReady):
return None
instance_id = rank_0_runner.bound_instance.instance.instance_id
current_instance = instances.get(instance_id)
if current_instance is None:
return None
# Compare runner's bound draft model vs current instance draft model
runner_draft_model = rank_0_runner.bound_instance.instance.draft_model
current_draft_model = current_instance.draft_model
if runner_draft_model == current_draft_model:
return None
# Draft model changed - need to update
if current_draft_model is None:
# Clear draft model
return SetDraftModel(
instance_id=instance_id,
model_id=None,
num_draft_tokens=4,
)
# Wait for draft model to be downloaded
if not isinstance(download_status.get(current_draft_model), DownloadCompleted):
return None
return SetDraftModel(
instance_id=instance_id,
model_id=str(current_draft_model),
num_draft_tokens=current_instance.num_draft_tokens,
)
def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
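A note on the planner additions above: _draft_model_needs_download and _set_draft_model both follow a desired-versus-actual reconcile pattern. The instance map holds what the API last requested, the rank-0 runner's bound instance holds what is actually loaded, and a task is emitted only when the two diverge (and, for a new draft model, only once its download has completed). A stripped-down sketch with hypothetical stand-in types (RunnerView and InstanceView are not from this codebase; only the decision logic mirrors _set_draft_model):

from dataclasses import dataclass

@dataclass
class RunnerView:              # hypothetical stand-in for RunnerSupervisor state
    ready: bool
    loaded_draft: str | None   # what the runner actually has loaded

@dataclass
class InstanceView:            # hypothetical stand-in for Instance
    desired_draft: str | None  # what the API last requested
    num_draft_tokens: int

def reconcile(runner: RunnerView, instance: InstanceView,
              downloaded: set[str]) -> tuple[str | None, int] | None:
    """Return (model_id, num_draft_tokens) to apply, or None for no-op."""
    if not runner.ready:
        return None                # only retarget a ready runner
    if runner.loaded_draft == instance.desired_draft:
        return None                # already in sync
    if instance.desired_draft is None:
        return (None, 4)           # clear the draft model
    if instance.desired_draft not in downloaded:
        return None                # wait for the download to finish first
    return (instance.desired_draft, instance.num_draft_tokens)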

View File

@@ -2,7 +2,7 @@ import base64
import time
from collections.abc import Generator
from functools import cache
from typing import Literal
from typing import Literal, cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -32,6 +32,7 @@ from exo.shared.types.tasks import (
ImageEdits,
ImageGeneration,
LoadModel,
SetDraftModel,
Shutdown,
StartWarmup,
Task,
@@ -71,6 +72,7 @@ from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
detect_thinking_prompt_suffix,
initialize_mlx,
load_draft_model,
load_mlx_items,
mlx_force_oom,
)
@@ -99,6 +101,7 @@ def main(
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
draft_model: Model | None = None # Loaded during the LoadModel phase if the instance has a draft_model (rank 0 only)
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -164,6 +167,16 @@ def main(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
)
# Load draft model for speculative decoding (rank 0 only)
if (
instance.draft_model is not None
and shard_metadata.device_rank == 0
):
logger.info(f"Loading draft model: {instance.draft_model}")
draft_model = cast(
Model, load_draft_model(str(instance.draft_model))
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
@@ -185,7 +198,8 @@ def main(
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
@@ -231,6 +245,8 @@ def main(
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
)
# GPT-OSS specific parsing to match other model formats.
@@ -412,6 +428,52 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case SetDraftModel(
model_id=draft_model_id, num_draft_tokens=num_tokens
) if isinstance(current_status, RunnerReady):
current_status = RunnerWarmingUp()
logger.info("runner warming up (setting draft model)")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert model is not None
assert tokenizer is not None
if draft_model_id is None:
# Clear draft model
logger.info("Clearing draft model")
draft_model = None
instance = instance.model_copy(
update={
"draft_model": None,
"num_draft_tokens": 4,
}
)
else:
# Load new draft model
logger.info(f"Loading draft model: {draft_model_id}")
draft_model = cast(
Model, load_draft_model(ModelId(draft_model_id))
)
instance = instance.model_copy(
update={
"draft_model": ModelId(draft_model_id),
"num_draft_tokens": num_tokens,
}
)
# Warm up with speculative decoding
logger.info("Warming up with new draft model")
warmup_inference(
model=cast(Model, model),
tokenizer=tokenizer,
draft_model=draft_model,
num_draft_tokens=num_tokens,
)
logger.info("Draft model loaded and warmed up")
current_status = RunnerReady()
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
@@ -432,7 +494,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del model, tokenizer, group, draft_model
mx.clear_cache()
import gc