Compare commits

..

6 Commits

Author SHA1 Message Date
Alex Cheema
3d4f8130c9 Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding 2026-02-13 09:39:11 -08:00
Alex Cheema
6762a0a8ba Add draft_model and num_draft_tokens fields to PlaceInstance command
PlaceInstance was missing these fields that are accessed in placement.py
when creating MlxJaccl and MlxRing instances with speculative decoding
support, causing type errors after merging main.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 05:51:11 -08:00
Alex Cheema
6524820e5b Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding
# Conflicts:
#	src/exo/shared/types/commands.py
2026-02-13 05:46:52 -08:00
Alex Cheema
90e2a20091 Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 06:09:50 -08:00
Alex Cheema
55152fa99d Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding
# Conflicts:
#	dashboard/src/routes/+page.svelte
#	src/exo/master/api.py
#	src/exo/shared/types/api.py
#	src/exo/shared/types/commands.py
#	src/exo/worker/engines/mlx/generator/generate.py
#	src/exo/worker/main.py
#	src/exo/worker/plan.py
#	src/exo/worker/runner/runner.py
2026-02-05 06:07:51 -08:00
Alex Cheema
e7f61c3494 Add speculative decoding support with draft models
This adds support for speculative decoding using draft models to accelerate
inference. Key changes:

- Add draft_model and num_draft_tokens fields to Instance for configuration
- Add SetDraftModel task to load/clear draft models on running instances
- Add InstanceDraftModelUpdated event to propagate draft model changes
- Add SetInstanceDraftModel command and API endpoint for runtime updates
- Update plan.py to download draft models in parallel with main model
- Update runner to load draft model during LoadModel phase
- Add draft model UI to dashboard instances panel (both views)

The draft model can be configured when creating an instance or updated on
a running instance via the dashboard or API.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 13:09:26 +00:00
15 changed files with 131 additions and 54 deletions

View File

@@ -72,6 +72,8 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
// Granular node state types from the new state structure

View File

@@ -17,7 +17,6 @@ from exo.shared.types.api import (
LogprobsContentItem,
StreamingChoiceResponse,
ToolCall,
Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
@@ -126,8 +125,6 @@ async def generate_chat_stream(
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
@@ -141,8 +138,6 @@ async def generate_chat_stream(
yield "data: [DONE]\n\n"
return
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
@@ -166,15 +161,12 @@ async def generate_chat_stream(
finish_reason="tool_calls",
)
],
usage=last_usage,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
if chunk.finish_reason is not None:
chunk_response = chunk_response.model_copy(update={"usage": last_usage})
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
@@ -192,7 +184,6 @@ async def collect_chat_response(
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
last_usage: Usage | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
@@ -202,8 +193,6 @@ async def collect_chat_response(
if model is None:
model = chunk.model
last_usage = chunk.usage or last_usage
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
@@ -252,5 +241,4 @@ async def collect_chat_response(
finish_reason=finish_reason,
)
],
usage=last_usage,
)

View File

@@ -4,7 +4,7 @@ import json
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.claude_api import (
ClaudeContentBlock,
@@ -166,7 +166,7 @@ async def collect_claude_response(
text_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_usage: Usage | None = None
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
@@ -174,8 +174,6 @@ async def collect_claude_response(
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
tool_use_blocks.append(
@@ -185,10 +183,12 @@ async def collect_claude_response(
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
)
)
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
continue
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -208,9 +208,9 @@ async def collect_claude_response(
if not content:
content.append(ClaudeTextBlock(text=""))
# Use actual usage data if available
input_tokens = last_usage.prompt_tokens if last_usage else 0
output_tokens = last_usage.completion_tokens if last_usage else 0
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
return ClaudeMessagesResponse(
id=f"msg_{command_id}",
@@ -249,7 +249,7 @@ async def generate_claude_stream(
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_usage: Usage | None = None
last_stats = None
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
@@ -257,9 +257,8 @@ async def generate_claude_stream(
# Close text block and bail
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
# Emit tool_use content blocks
@@ -291,6 +290,7 @@ async def generate_claude_stream(
continue
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
@@ -302,9 +302,9 @@ async def generate_claude_stream(
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from usage if available
if last_usage is not None:
output_tokens = last_usage.completion_tokens
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)

View File

@@ -4,7 +4,6 @@ from collections.abc import AsyncGenerator
from itertools import count
from typing import Any
from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
@@ -128,7 +127,7 @@ async def collect_responses_response(
item_id = f"item_{command_id}"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_usage: Usage | None = None
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
@@ -136,8 +135,6 @@ async def collect_responses_response(
error_message = chunk.error_message or "Internal server error"
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
function_call_items.append(
@@ -148,20 +145,22 @@ async def collect_responses_response(
arguments=tool.arguments,
)
)
last_stats = chunk.stats or last_stats
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from usage data if available
# Create usage from stats if available
usage = None
if last_usage is not None:
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_usage.prompt_tokens,
output_tokens=last_usage.completion_tokens,
total_tokens=last_usage.total_tokens,
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
output: list[ResponseItem] = [
@@ -236,16 +235,15 @@ async def generate_responses_stream(
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_usage: Usage | None = None
last_stats = None
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
break
last_usage = chunk.usage or last_usage
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
for tool in chunk.tool_calls:
fc_id = f"fc_{tool.id}"
call_id = f"call_{tool.id}"
@@ -304,6 +302,7 @@ async def generate_responses_stream(
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
@@ -347,13 +346,13 @@ async def generate_responses_stream(
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from usage data if available
# Create usage from stats if available
usage = None
if last_usage is not None:
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_usage.prompt_tokens,
output_tokens=last_usage.completion_tokens,
total_tokens=last_usage.total_tokens,
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed

View File

@@ -24,6 +24,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
SetInstanceDraftModel,
TaskFinished,
TestCommand,
TextGeneration,
@@ -35,6 +36,7 @@ from exo.shared.types.events import (
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
@@ -319,6 +321,14 @@ class Master:
chunk=chunk,
)
)
case SetInstanceDraftModel():
generated_events.append(
InstanceDraftModelUpdated(
instance_id=command.instance_id,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -153,6 +153,8 @@ def place_instance(
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -167,6 +169,8 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -12,6 +12,7 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -72,6 +73,8 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceDraftModelUpdated():
return apply_instance_draft_model_updated(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -190,6 +193,25 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
def apply_instance_draft_model_updated(
event: InstanceDraftModelUpdated, state: State
) -> State:
if event.instance_id not in state.instances:
return state
instance = state.instances[event.instance_id]
updated_instance = instance.model_copy(
update={
"draft_model": event.draft_model,
"num_draft_tokens": event.num_draft_tokens,
}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
event.instance_id: updated_instance,
}
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
**state.runners,

View File

@@ -38,6 +38,8 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None
num_draft_tokens: int = 4
class CreateInstance(BaseCommand):
@@ -72,6 +74,14 @@ class DeleteDownload(BaseCommand):
model_id: ModelId
class SetInstanceDraftModel(BaseCommand):
"""Set or update the draft model for an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None # None to disable speculative decoding
num_draft_tokens: int = 4
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
@@ -89,6 +99,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| SetInstanceDraftModel
| TaskFinished
| SendInputChunk
)

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -68,6 +68,14 @@ class InstanceDeleted(BaseEvent):
instance_id: InstanceId
class InstanceDraftModelUpdated(BaseEvent):
"""Draft model updated on an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None
num_draft_tokens: int
class RunnerStatusUpdated(BaseEvent):
runner_id: RunnerId
runner_status: RunnerStatus
@@ -141,6 +149,7 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceDraftModelUpdated
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut

View File

@@ -40,6 +40,12 @@ class DownloadModel(BaseTask): # emitted by Worker
shard_metadata: ShardMetadata
class DownloadDraftModel(BaseTask): # emitted by Worker
"""Download a draft model for speculative decoding (rank 0 only)."""
model_id: str # HuggingFace model ID
class LoadModel(BaseTask): # emitted by Worker
pass
@@ -80,9 +86,17 @@ class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
class SetDraftModel(BaseTask): # emitted by Worker
"""Load or clear a draft model on an already-running instance."""
model_id: str | None # HuggingFace model ID, or None to clear
num_draft_tokens: int = 4
Task = (
CreateRunner
| DownloadModel
| DownloadDraftModel
| ConnectToGroup
| LoadModel
| StartWarmup
@@ -90,4 +104,5 @@ Task = (
| ImageGeneration
| ImageEdits
| Shutdown
| SetDraftModel
)

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.common import Host, Id, ModelId, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +19,8 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -62,7 +62,6 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
stats: GenerationStats | None = None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -393,11 +393,10 @@ def mlx_generate(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
total_prompt_tokens = len(all_prompt_tokens)
usage = Usage(
prompt_tokens=total_prompt_tokens,
prompt_tokens=int(out.prompt_tokens),
completion_tokens=completion_tokens,
total_tokens=total_prompt_tokens + completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),

View File

@@ -223,6 +223,27 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: ModelId) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,

View File

@@ -364,7 +364,6 @@ def main(
tool_calls=response.tool_calls,
model=shard_metadata.model_card.model_id,
usage=response.usage,
stats=response.stats,
),
)
)
@@ -765,9 +764,7 @@ def parse_tool_calls(
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(
tool_calls=tools, usage=response.usage, stats=response.stats
)
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
except (
json.JSONDecodeError,
@@ -798,8 +795,7 @@ def parse_tool_calls(
text=tool_call_start + "".join(tool_call_text_parts),
token=0,
finish_reason=response.finish_reason,
usage=response.usage,
stats=response.stats,
usage=None,
)
continue
# fallthrough