Compare commits

...

8 Commits

Author SHA1 Message Date
Alex Cheema
3d4f8130c9 Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding 2026-02-13 09:39:11 -08:00
Alex Cheema
1c3cc699d4 fix: add missing getModelFitStatus prop to Recent tab (#1470)
## Summary
- Clicking the **Recent** tab in the Model Picker crashed with
`TypeError: e.getModelFitStatus is not a function`
- The `ModelPickerGroup` component in the Recent tab section was missing
the `{getModelFitStatus}` prop, while all other tabs (e.g., the main
model list) passed it correctly
- Added the missing `{getModelFitStatus}` prop so the Recent tab renders
without errors, matching the behavior of the other tabs

## Test plan
- [ ] Open the dashboard and click **SELECT MODEL**
- [ ] Switch to the **Recent** tab — verify it renders without crashing
- [ ] Confirm model fit status indicators display correctly on recent
models
- [ ] Verify the other tabs (All, Favorites) still work as before

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 16:37:32 +00:00
Alex Cheema
6762a0a8ba Add draft_model and num_draft_tokens fields to PlaceInstance command
PlaceInstance was missing these fields that are accessed in placement.py
when creating MlxJaccl and MlxRing instances with speculative decoding
support, causing type errors after merging main.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 05:51:11 -08:00
Alex Cheema
6524820e5b Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding
# Conflicts:
#	src/exo/shared/types/commands.py
2026-02-13 05:46:52 -08:00
rltakashige
5a28642790 Add support for Step 3.5 flash! (#1460)
## Motivation

Working version of #1366 

## Changes

Add Step 3.5 Flash

## Test Plan

### Manual Testing
Works!

### Automated Testing
Running two processes tensor/pipeline sharded gives same logits as
single process.
2026-02-13 12:10:18 +00:00
Alex Cheema
90e2a20091 Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 06:09:50 -08:00
Alex Cheema
55152fa99d Merge remote-tracking branch 'origin/main' into alexcheema/speculative-decoding
# Conflicts:
#	dashboard/src/routes/+page.svelte
#	src/exo/master/api.py
#	src/exo/shared/types/api.py
#	src/exo/shared/types/commands.py
#	src/exo/worker/engines/mlx/generator/generate.py
#	src/exo/worker/main.py
#	src/exo/worker/plan.py
#	src/exo/worker/runner/runner.py
2026-02-05 06:07:51 -08:00
Alex Cheema
e7f61c3494 Add speculative decoding support with draft models
This adds support for speculative decoding using draft models to accelerate
inference. Key changes:

- Add draft_model and num_draft_tokens fields to Instance for configuration
- Add SetDraftModel task to load/clear draft models on running instances
- Add InstanceDraftModelUpdated event to propagate draft model changes
- Add SetInstanceDraftModel command and API endpoint for runtime updates
- Update plan.py to download draft models in parallel with main model
- Update runner to load draft model during LoadModel phase
- Add draft model UI to dashboard instances panel (both views)

The draft model can be configured when creating an instance or updated on
a running instance via the dashboard or API.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-22 13:09:26 +00:00
16 changed files with 354 additions and 2 deletions

View File

@@ -0,0 +1,151 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
vocab_size: int
num_attention_heads: int
num_attention_groups: int
head_dim: int
intermediate_size: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
max_position_embeddings: int
sliding_window: int
layer_types: Optional[List[str]]
yarn_only_types: Optional[List[str]]
partial_rotary_factors: Optional[List[float]]
attention_other_setting: Optional[Dict[str, Any]]
use_head_wise_attn_gate: bool
moe_num_experts: int
moe_top_k: int
moe_intermediate_size: int
share_expert_dim: int
moe_layers_enum: Optional[str]
moe_router_scaling_factor: float
norm_expert_weight: bool
swiglu_limits: Optional[List[float]]
swiglu_limits_shared: Optional[List[float]]
tie_word_embeddings: bool
class Step3p5MLP(nn.Module):
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
limit: Optional[float]
def __init__(
self, args: ModelArgs, intermediate_size: int, swiglu_limit: float = 0
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5MoEGate(nn.Module):
top_k: int
n_routed_experts: int
routed_scaling_factor: float
norm_topk_prob: bool
gate: nn.Linear
router_bias: mx.array
def __init__(self, args: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class Step3p5MoE(nn.Module):
gate: Step3p5MoEGate
switch_mlp: SwitchGLU
share_expert: Step3p5MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class Step3p5Attention(nn.Module):
is_sliding: bool
num_heads: int
num_kv_heads: int
head_dim: int
scale: float
q_proj: nn.Linear
k_proj: nn.Linear
v_proj: nn.Linear
o_proj: nn.Linear
q_norm: nn.Module
k_norm: nn.Module
use_head_wise_attn_gate: bool
g_proj: nn.Linear
rope: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5DecoderLayer(nn.Module):
self_attn: Step3p5Attention
is_sliding: bool
is_moe_layer: bool
mlp: Step3p5MLP | Step3p5MoE
input_layernorm: nn.Module
post_attention_layernorm: nn.Module
def __init__(self, args: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class Step3p5Model(nn.Module):
args: ModelArgs
vocab_size: int
num_layers: int
embed_tokens: nn.Embedding
layers: list[Step3p5DecoderLayer]
norm: nn.Module
_swa_idx: Optional[int]
_full_idx: Optional[int]
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
class Model(nn.Module):
args: ModelArgs
model_type: str
model: Step3p5Model
lm_head: nn.Linear
def __init__(self, args: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Any]] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
def shard(self, group: Optional[mx.distributed.Group] = None) -> None: ...
@property
def layers(self) -> list[Step3p5DecoderLayer]: ...
def make_cache(self) -> list[Any]: ...
@property
def cast_predicate(self) -> Any: ...
@property
def quant_predicate(self) -> Any: ...

View File

@@ -806,6 +806,7 @@
isFavorite={favorites.has(group.id)}
{selectedModelId}
{canModelFit}
{getModelFitStatus}
onToggleExpand={() => toggleGroupExpanded(group.id)}
onSelectModel={handleSelect}
{onToggleFavorite}

View File

@@ -72,6 +72,8 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
// Granular node state types from the new state structure

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-4bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "4bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 114572190076

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-6bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "6bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 159039627774

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/Step-3.5-Flash-8Bit"
n_layers = 45
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
family = "step"
quantization = "8bit"
base_model = "Step 3.5 Flash"
capabilities = ["text", "thinking"]
[storage_size]
in_bytes = 209082699847

View File

@@ -24,6 +24,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
SetInstanceDraftModel,
TaskFinished,
TestCommand,
TextGeneration,
@@ -35,6 +36,7 @@ from exo.shared.types.events import (
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
@@ -319,6 +321,14 @@ class Master:
chunk=chunk,
)
)
case SetInstanceDraftModel():
generated_events.append(
InstanceDraftModelUpdated(
instance_id=command.instance_id,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -153,6 +153,8 @@ def place_instance(
shard_assignments=shard_assignments,
jaccl_devices=mlx_jaccl_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -167,6 +169,8 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -12,6 +12,7 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -72,6 +73,8 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceDraftModelUpdated():
return apply_instance_draft_model_updated(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -190,6 +193,25 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
def apply_instance_draft_model_updated(
event: InstanceDraftModelUpdated, state: State
) -> State:
if event.instance_id not in state.instances:
return state
instance = state.instances[event.instance_id]
updated_instance = instance.model_copy(
update={
"draft_model": event.draft_model,
"num_draft_tokens": event.num_draft_tokens,
}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
event.instance_id: updated_instance,
}
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
**state.runners,

View File

@@ -189,6 +189,7 @@ class ConfigData(BaseModel):
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
["Step3p5ForCausalLM"],
]
@model_validator(mode="before")

View File

@@ -38,6 +38,8 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None
num_draft_tokens: int = 4
class CreateInstance(BaseCommand):
@@ -72,6 +74,14 @@ class DeleteDownload(BaseCommand):
model_id: ModelId
class SetInstanceDraftModel(BaseCommand):
"""Set or update the draft model for an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None # None to disable speculative decoding
num_draft_tokens: int = 4
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
@@ -89,6 +99,7 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| SetInstanceDraftModel
| TaskFinished
| SendInputChunk
)

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, ModelId, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -68,6 +68,14 @@ class InstanceDeleted(BaseEvent):
instance_id: InstanceId
class InstanceDraftModelUpdated(BaseEvent):
"""Draft model updated on an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None
num_draft_tokens: int
class RunnerStatusUpdated(BaseEvent):
runner_id: RunnerId
runner_status: RunnerStatus
@@ -141,6 +149,7 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceDraftModelUpdated
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut

View File

@@ -40,6 +40,12 @@ class DownloadModel(BaseTask): # emitted by Worker
shard_metadata: ShardMetadata
class DownloadDraftModel(BaseTask): # emitted by Worker
"""Download a draft model for speculative decoding (rank 0 only)."""
model_id: str # HuggingFace model ID
class LoadModel(BaseTask): # emitted by Worker
pass
@@ -80,9 +86,17 @@ class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
class SetDraftModel(BaseTask): # emitted by Worker
"""Load or clear a draft model on an already-running instance."""
model_id: str | None # HuggingFace model ID, or None to clear
num_draft_tokens: int = 4
Task = (
CreateRunner
| DownloadModel
| DownloadDraftModel
| ConnectToGroup
| LoadModel
| StartWarmup
@@ -90,4 +104,5 @@ Task = (
| ImageGeneration
| ImageEdits
| Shutdown
| SetDraftModel
)

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.common import Host, Id, ModelId, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +19,8 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -35,6 +35,9 @@ from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step35Model
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
@@ -264,6 +267,19 @@ def pipeline_auto_parallel(
)
)
if isinstance(inner_model_instance, Step35InnerModel):
inner_model_instance.num_layers = len(layers)
sliding_layers = [
i for i, layer in enumerate(layers) if getattr(layer, "is_sliding", False)
]
full_layers = [
i
for i, layer in enumerate(layers)
if not getattr(layer, "is_sliding", True)
]
inner_model_instance._swa_idx = 0 if not sliding_layers else sliding_layers[0]
inner_model_instance._full_idx = 0 if not full_layers else full_layers[0]
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -427,6 +443,14 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Step35Model):
tensor_parallel_sharding_strategy = Step35ShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -981,3 +1005,46 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
mx.eval(layer)
return model
class Step35ShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Step35Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_heads //= self.N
layer.self_attn.num_kv_heads //= self.N
if getattr(layer.self_attn, "use_head_wise_attn_gate", False):
layer.self_attn.g_proj = self.all_to_sharded_linear(
layer.self_attn.g_proj
)
if isinstance(layer.mlp, Step35MLP):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
else:
layer.mlp.sharding_group = self.group
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.share_expert.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.share_expert.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
mx.eval(layer)
return model

View File

@@ -223,6 +223,27 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: ModelId) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,