Compare commits

...

3 Commits

Author SHA1 Message Date
Alex Cheema
90e2a20091 Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 06:09:50 -08:00
Alex Cheema
55152fa99d Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding
# Conflicts:
#	dashboard/src/routes/+page.svelte
#	src/exo/master/api.py
#	src/exo/shared/types/api.py
#	src/exo/shared/types/commands.py
#	src/exo/worker/engines/mlx/generator/generate.py
#	src/exo/worker/main.py
#	src/exo/worker/plan.py
#	src/exo/worker/runner/runner.py
2026-02-05 06:07:51 -08:00
Alex Cheema
e7f61c3494 Add speculative decoding support with draft models
This adds support for speculative decoding using draft models to accelerate
inference. Key changes:

- Add draft_model and num_draft_tokens fields to Instance for configuration
- Add SetDraftModel task to load/clear draft models on running instances
- Add InstanceDraftModelUpdated event to propagate draft model changes
- Add SetInstanceDraftModel command and API endpoint for runtime updates
- Update plan.py to download draft models in parallel with main model
- Update runner to load draft model during LoadModel phase
- Add draft model UI to dashboard instances panel (both views)

The draft model can be configured when creating an instance or updated on
a running instance via the dashboard or API.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 13:09:26 +00:00
9 changed files with 96 additions and 2 deletions

View File

@@ -69,6 +69,8 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
// Granular node state types from the new state structure

View File

@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
SetInstanceDraftModel,
TaskFinished,
TestCommand,
TextGeneration,
@@ -32,6 +33,7 @@ from exo.shared.types.events import (
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
@@ -308,6 +310,14 @@ class Master:
chunk=chunk,
)
)
case SetInstanceDraftModel():
generated_events.append(
InstanceDraftModelUpdated(
instance_id=command.instance_id,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -147,6 +147,8 @@ def place_instance(
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -161,6 +163,8 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -12,6 +12,7 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -69,6 +70,8 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceDraftModelUpdated():
return apply_instance_draft_model_updated(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -187,6 +190,25 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
def apply_instance_draft_model_updated(
event: InstanceDraftModelUpdated, state: State
) -> State:
if event.instance_id not in state.instances:
return state
instance = state.instances[event.instance_id]
updated_instance = instance.model_copy(
update={
"draft_model": event.draft_model,
"num_draft_tokens": event.num_draft_tokens,
}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
event.instance_id: updated_instance,
}
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
**state.runners,

View File

@@ -72,6 +72,14 @@ class DeleteDownload(BaseCommand):
model_id: ModelId
class SetInstanceDraftModel(BaseCommand):
"""Set or update the draft model for an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None # None to disable speculative decoding
num_draft_tokens: int = 4
DownloadCommand = StartDownload | DeleteDownload
@@ -84,6 +92,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| SetInstanceDraftModel
| TaskFinished
| SendInputChunk
)

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -68,6 +68,14 @@ class InstanceDeleted(BaseEvent):
instance_id: InstanceId
class InstanceDraftModelUpdated(BaseEvent):
"""Draft model updated on an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None
num_draft_tokens: int
class RunnerStatusUpdated(BaseEvent):
runner_id: RunnerId
runner_status: RunnerStatus
@@ -141,6 +149,7 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceDraftModelUpdated
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut

View File

@@ -40,6 +40,12 @@ class DownloadModel(BaseTask): # emitted by Worker
shard_metadata: ShardMetadata
class DownloadDraftModel(BaseTask): # emitted by Worker
"""Download a draft model for speculative decoding (rank 0 only)."""
model_id: str # HuggingFace model ID
class LoadModel(BaseTask): # emitted by Worker
pass
@@ -80,9 +86,17 @@ class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
class SetDraftModel(BaseTask): # emitted by Worker
"""Load or clear a draft model on an already-running instance."""
model_id: str | None # HuggingFace model ID, or None to clear
num_draft_tokens: int = 4
Task = (
CreateRunner
| DownloadModel
| DownloadDraftModel
| ConnectToGroup
| LoadModel
| StartWarmup
@@ -90,4 +104,5 @@ Task = (
| ImageGeneration
| ImageEdits
| Shutdown
| SetDraftModel
)

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.common import Host, Id, ModelId, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +19,8 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -226,6 +226,27 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: ModelId) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,