Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
c93376f0fb Add speculative decoding support with draft models
Implements speculative decoding using MLX-LM's built-in stream_generate(draft_model=...)
to accelerate inference. A small draft model generates candidate tokens which are
verified by the main model in a single forward pass.

Key changes:
- Add draft_model and num_draft_tokens to instance configuration
- Auto-download draft models during warmup if not present
- Dashboard UI for selecting draft model and token count
- Display draft model info on running instance cards

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 02:42:22 +00:00
18 changed files with 284 additions and 1767 deletions

View File

@@ -69,6 +69,8 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
interface RawNodeProfile {

View File

@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -58,6 +58,8 @@ const sidebarVisible = $derived(chatSidebarVisible());
sharding: 'Pipeline' | 'Tensor';
instanceType: InstanceMeta;
minNodes: number;
draftModel: string | null;
numDraftTokens: number;
}
function saveLaunchDefaults(): void {
@@ -66,6 +68,8 @@ const sidebarVisible = $derived(chatSidebarVisible());
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
draftModel: selectedDraftModel,
numDraftTokens: selectedNumDraftTokens,
};
try {
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
@@ -88,24 +92,36 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
}
// Apply draft model if it exists in the available models (check against hugging_face_id)
if (defaults.draftModel && availableModels.some(m => (m as {hugging_face_id?: string}).hugging_face_id === defaults.draftModel)) {
selectedDraftModel = defaults.draftModel;
}
// Apply num draft tokens if valid
if (defaults.numDraftTokens && defaults.numDraftTokens >= 1 && defaults.numDraftTokens <= 10) {
selectedNumDraftTokens = defaults.numDraftTokens;
}
}
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let selectedDraftModel = $state<string | null>(null);
let selectedNumDraftTokens = $state<number>(4);
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
@@ -113,6 +129,8 @@ let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');
let isDraftModelDropdownOpen = $state(false);
let draftModelDropdownSearch = $state('');
// Slider dragging state
let isDraggingSlider = $state(false);
@@ -362,47 +380,39 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
if (!modelId || launchingModelId) return;
launchingModelId = modelId;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
const preview = specificPreview ?? filteredPreview();
let instanceData: unknown;
if (preview?.instance) {
// Use the instance from the preview
instanceData = preview.instance;
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
);
if (!placementResponse.ok) {
const errorText = await placementResponse.text();
console.error('Failed to get placement:', errorText);
return;
}
instanceData = await placementResponse.json();
}
// POST the instance to create it
const response = await fetch('/instance', {
let response: Response;
// Use /place_instance endpoint - it handles placement and creation in one step
// This also supports draft_model for speculative decoding
const placePayload = {
model_id: modelId,
sharding: preview?.sharding ?? selectedSharding,
instance_meta: preview?.instance_meta ?? selectedInstanceType,
min_nodes: selectedMinNodes,
draft_model: selectedDraftModel,
num_draft_tokens: selectedDraftModel ? selectedNumDraftTokens : 4,
};
response = await fetch('/place_instance', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ instance: instanceData })
body: JSON.stringify(placePayload)
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
const scrollToBottom = () => {
@@ -816,30 +826,34 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
nodeNames: string[];
nodeIds: string[];
nodeCount: number;
draftModel: string | null;
numDraftTokens: number | null;
} {
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
}
// Instance type from tag
let instanceType = 'Unknown';
if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
runnerToShard?: Record<string, unknown>;
}
};
draftModel?: string;
numDraftTokens?: number;
};
// Sharding strategy from first shard
let sharding = 'Unknown';
const runnerToShard = inst.shardAssignments?.runnerToShard || {};
@@ -850,7 +864,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
}
// Node names from topology
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
const nodeIds = Object.keys(nodeToRunner);
@@ -858,8 +872,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
});
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
// Draft model for speculative decoding
const draftModel = inst.draftModel ?? null;
const numDraftTokens = inst.numDraftTokens ?? null;
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
}
function formatLastUpdate(): string {
@@ -1345,6 +1363,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -1678,8 +1699,80 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{/each}
</div>
</div>
<!-- Draft Model (Speculative Decoding) -->
<div>
<div class="text-xs text-white/70 font-mono mb-2">Draft Model (Speculative):</div>
<div class="relative">
<button
onclick={() => { isDraftModelDropdownOpen = !isDraftModelDropdownOpen; draftModelDropdownSearch = ''; }}
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {selectedDraftModel ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="truncate">{selectedDraftModel ? selectedDraftModel.split('/').pop() : 'None'}</span>
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftModelDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
</svg>
</button>
{#if isDraftModelDropdownOpen}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="fixed inset-0 z-40"
onclick={() => isDraftModelDropdownOpen = false}
onkeydown={(e) => e.key === 'Escape' && (isDraftModelDropdownOpen = false)}
></div>
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
<div class="p-2 border-b border-exo-medium-gray/30">
<input
type="text"
bind:value={draftModelDropdownSearch}
placeholder="Search models..."
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-exo-yellow/50"
/>
</div>
<div class="overflow-y-auto max-h-36">
<!-- None option -->
<button
onclick={() => { selectedDraftModel = null; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {selectedDraftModel === null ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
>
<span>None</span>
</button>
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftModelDropdownSearch.toLowerCase()) && m.id !== selectedModelId) as model}
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
{@const modelHfId = model.hugging_face_id ?? model.id}
<button
onclick={() => { selectedDraftModel = modelHfId; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedDraftModel === modelHfId ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex-shrink-0 text-xs text-white/50">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
{/each}
</div>
</div>
{/if}
</div>
</div>
<!-- Draft Tokens (only show when draft model selected) -->
{#if selectedDraftModel}
<div class="flex items-center gap-2 mt-2">
<span class="text-xs text-white/50 font-mono">Tokens:</span>
<div class="flex items-center gap-1">
{#each [2, 3, 4, 5, 6] as n}
<button
onclick={() => { selectedNumDraftTokens = n; saveLaunchDefaults(); }}
class="w-6 h-6 text-xs font-mono rounded transition-all {selectedNumDraftTokens === n ? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/50' : 'text-white/50 hover:text-white/80 border border-transparent'}"
>{n}</button>
{/each}
</div>
</div>
{/if}
</div>
<!-- Selected Model Preview -->
<div class="space-y-3">
{#if models.length === 0}

View File

@@ -200,6 +200,8 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)

View File

@@ -151,6 +151,8 @@ def place_instance(
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -164,6 +166,8 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -161,6 +161,8 @@ class ChatCompletionTaskParams(BaseModel):
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
num_draft_tokens: int = 3
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -172,6 +174,8 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod

View File

@@ -2,7 +2,7 @@ from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -25,6 +25,8 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
class CreateInstance(BaseCommand):

View File

@@ -3,6 +3,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +20,8 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -13,8 +13,3 @@ KV_CACHE_BITS: int | None = None
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True
# Multi-Token Prediction (MTP) configuration for DeepSeek V3
# MTP enables speculative decoding using the model's built-in draft layer
MTP_ENABLED: bool = True # Feature flag to enable/disable MTP
MTP_NUM_DRAFT_TOKENS: int = 1 # Number of tokens to draft (vLLM reports k=1 is optimal)

View File

@@ -19,13 +19,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import (
KV_BITS,
KV_GROUP_SIZE,
MAX_TOKENS,
MTP_ENABLED,
MTP_NUM_DRAFT_TOKENS,
)
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
@@ -121,15 +115,12 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def _has_mtp_module(model: Model) -> bool:
"""Check if the model has an attached MTP module."""
return hasattr(model, "mtp_module") and model.mtp_module is not None # type: ignore[attr-defined]
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -146,8 +137,6 @@ def mlx_generate(
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -161,55 +150,30 @@ def mlx_generate(
max_tokens = task.max_tokens or MAX_TOKENS
# Check if we should use MTP speculative decoding
use_mtp = MTP_ENABLED and _has_mtp_module(model)
# Build kwargs for stream_generate, conditionally adding draft model params
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": prompt,
"max_tokens": max_tokens,
"sampler": sampler,
"logits_processors": logits_processors,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
if use_mtp:
logger.info("Using MTP speculative decoding")
yield from _mlx_generate_with_mtp(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
)
# Add speculative decoding parameters if draft model is provided
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
# as speculative decoding requires cache trimming capabilities
if draft_model is not None:
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
yield from _mlx_generate_standard(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
)
# Only use custom cache for non-speculative generation
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
def _mlx_generate_standard(
model: Model,
tokenizer: TokenizerWrapper,
prompt: str,
max_tokens: int,
sampler: Callable[[mx.array], mx.array],
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
prompt_cache: list[KVCache | Any],
) -> Generator[GenerationResponse]:
"""Standard generation path using mlx_lm stream_generate."""
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
logger.info(out.text)
stats: GenerationStats | None = None
@@ -239,64 +203,4 @@ def _mlx_generate_standard(
if out.finish_reason is not None:
break
def _mlx_generate_with_mtp(
model: Model,
tokenizer: TokenizerWrapper,
prompt: str,
max_tokens: int,
sampler: Callable[[mx.array], mx.array],
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
prompt_cache: list[KVCache | Any],
) -> Generator[GenerationResponse]:
"""MTP speculative decoding generation path.
Uses the model's attached MTP module for speculative decoding,
which can provide 1.5-2x speedup with ~81% acceptance rate.
"""
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
mtp_module = model.mtp_module # type: ignore[attr-defined]
for out in mtp_speculative_generate(
model=model,
mtp_module=mtp_module,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
num_draft_tokens=MTP_NUM_DRAFT_TOKENS,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE if KV_GROUP_SIZE is not None else 64,
kv_bits=KV_BITS,
):
logger.info(f"{out.text} (from_draft={out.from_draft})")
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
)
if out.finish_reason is not None:
break
# TODO: Do we want an mx_barrier?

View File

@@ -1,6 +0,0 @@
"""Multi-Token Prediction (MTP) module for DeepSeek V3 speculative decoding."""
from exo.worker.engines.mlx.mtp.module import MTPModule
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
__all__ = ["MTPModule", "mtp_speculative_generate"]

View File

@@ -1,207 +0,0 @@
"""MTP Module for DeepSeek V3 Multi-Token Prediction.
The MTP architecture predicts one additional token ahead using:
1. hnorm - RMSNorm for hidden state normalization
2. enorm - RMSNorm for embedding normalization
3. eh_proj - Linear(2*hidden_size -> hidden_size) projection
4. transformer_block - Single decoder layer (attention + MLP)
5. Shared embedding/lm_head from main model
Forward pass:
h_norm = hnorm(hidden_state)
e_norm = enorm(embed(token))
projected = eh_proj(concat([h_norm, e_norm]))
new_hidden = transformer_block(projected)
logits = lm_head(output_norm(new_hidden))
"""
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
from mlx_lm.models.deepseek_v3 import (
DeepseekV3Attention,
DeepseekV3MLP,
ModelArgs,
)
MTP_LAYER_INDEX = 61
class MTPModule(nn.Module):
"""Multi-Token Prediction module for DeepSeek V3.
This module is initialized from the layer 61 weights that are normally
discarded during model loading. It enables speculative decoding by
predicting one token ahead using the hidden state from the main model.
"""
def __init__(
self,
config: ModelArgs,
shared_embedding: nn.Embedding,
shared_lm_head: nn.Linear,
output_norm: nn.RMSNorm,
) -> None:
super().__init__()
self.config = config
# MTP-specific normalization layers
self.hnorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.enorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Projection: concatenated [hidden, embedding] -> hidden_size
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
# Single transformer block for MTP
# Use a dense MLP since this is just a single layer
self.transformer_block = MTPTransformerBlock(config)
# Share embedding and lm_head with main model
self._shared_embedding = shared_embedding
self._shared_lm_head = shared_lm_head
self._output_norm = output_norm
def __call__(
self,
hidden_state: mx.array,
draft_token: mx.array,
cache: KVCache | None = None,
mask: mx.array | None = None,
) -> tuple[mx.array, mx.array]:
"""Forward pass for MTP.
Args:
hidden_state: Hidden state from main model [batch, seq_len, hidden_size]
draft_token: Token to embed and combine with hidden state [batch, seq_len]
cache: Optional KV cache for the MTP transformer block
mask: Optional attention mask
Returns:
tuple of (logits, new_hidden_state)
"""
# Get embedding of draft token
embedding = self._shared_embedding(draft_token)
# Normalize hidden state and embedding
h_norm = self.hnorm(hidden_state)
e_norm = self.enorm(embedding)
# Project concatenated representation
concatenated = mx.concatenate([h_norm, e_norm], axis=-1)
projected = self.eh_proj(concatenated)
# Pass through single transformer block
new_hidden = self.transformer_block(projected, mask=mask, cache=cache)
# Apply output norm and get logits
normed_hidden = self._output_norm(new_hidden)
logits = self._shared_lm_head(normed_hidden)
return logits, new_hidden
class MTPTransformerBlock(nn.Module):
"""Single transformer block for MTP.
This is similar to DeepseekV3DecoderLayer but uses a dense MLP
instead of MoE since this is just for the single MTP layer.
"""
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.self_attn = DeepseekV3Attention(config)
# MTP uses dense MLP, not MoE
self.mlp = DeepseekV3MLP(config)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
cache: Any | None = None,
) -> mx.array:
"""Forward pass with residual connections."""
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
return h + r
def extract_mtp_weights(weights: dict[str, mx.array]) -> dict[str, mx.array]:
"""Extract MTP-specific weights from layer 61.
The MTP layer has these weight patterns:
- model.layers.61.enorm.weight -> MTP embedding normalization
- model.layers.61.hnorm.weight -> MTP hidden normalization
- model.layers.61.eh_proj.weight -> MTP projection layer
- model.layers.61.self_attn.* -> MTP attention
- model.layers.61.input_layernorm.* -> MTP layer norms
- model.layers.61.post_attention_layernorm.*
- model.layers.61.mlp.* -> MTP MLP (dense, not MoE)
Args:
weights: Full model weights dict
Returns:
Dict of MTP-specific weights with keys renamed for MTPModule
"""
mtp_weights: dict[str, mx.array] = {}
mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
for key, value in weights.items():
if key.startswith(mtp_prefix):
# Remove the layer prefix to get relative path
new_key = key[len(mtp_prefix) :]
mtp_weights[new_key] = value
return mtp_weights
def load_mtp_weights_into_module(
mtp_module: MTPModule,
mtp_weights: dict[str, mx.array],
) -> None:
"""Load extracted MTP weights into the MTPModule.
Args:
mtp_module: The MTPModule instance to load weights into
mtp_weights: Extracted MTP weights from extract_mtp_weights()
"""
# Map weight names to module attributes
weight_mapping: dict[str, str] = {
"enorm.weight": "enorm.weight",
"hnorm.weight": "hnorm.weight",
"eh_proj.weight": "eh_proj.weight",
}
# Load direct mappings
for src_name, dst_name in weight_mapping.items():
if src_name in mtp_weights:
parts = dst_name.split(".")
obj: Any = mtp_module
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], mtp_weights[src_name])
# Load transformer block weights (self_attn, mlp, layer norms)
transformer_prefixes = [
"self_attn",
"mlp",
"input_layernorm",
"post_attention_layernorm",
]
for prefix in transformer_prefixes:
for key, value in mtp_weights.items():
if key.startswith(prefix):
# Navigate to the correct attribute
parts = key.split(".")
obj = mtp_module.transformer_block
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], value)

View File

@@ -1,506 +0,0 @@
"""MTP Speculative Decoding for DeepSeek V3.
This module implements speculative decoding using the Multi-Token Prediction (MTP)
layer from DeepSeek V3. The key difference from standard speculative decoding is
that MTP requires hidden states from the main model, not just token predictions.
Based on vLLM/SGLang research:
- 81-82% acceptance rate with k=1
- 1.5-2x speedup at low QPS
"""
import functools
import time
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import Any, cast
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models import cache
from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.mtp.module import MTPModule
# Generation stream for async operations
generation_stream = mx.new_stream(mx.default_device())
@dataclass
class MTPGenerationResponse:
"""Response from MTP speculative generation.
Attributes:
text: The next segment of decoded text.
token: The next token.
logprobs: A vector of log probabilities.
from_draft: Whether the token was generated by the MTP draft module.
prompt_tokens: The number of tokens in the prompt.
prompt_tps: The prompt processing tokens-per-second.
generation_tokens: The number of generated tokens.
generation_tps: The tokens-per-second for generation.
peak_memory: The peak memory used so far in GB.
finish_reason: The reason the response is being sent: "length", "stop" or None.
"""
text: str
token: int
logprobs: mx.array
from_draft: bool
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
finish_reason: str | None = None
def maybe_quantize_kv_cache(
prompt_cache: list[Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
"""Quantize KV cache entries if needed."""
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized")
and hasattr(c, "offset")
and c.offset >= quantized_kv_start
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
class ModelWithHiddenStates(nn.Module):
"""Wrapper to extract hidden states before lm_head.
This wrapper allows capturing the hidden states from the transformer
layers before the final lm_head projection, which is needed for MTP.
"""
def __init__(self, base_model: nn.Module) -> None:
super().__init__()
self._base = base_model
def forward_with_hidden(
self,
inputs: mx.array,
model_cache: list[Any] | None = None,
) -> tuple[mx.array, mx.array]:
"""Forward pass that returns both logits and hidden states.
Args:
inputs: Input token ids
model_cache: KV cache
Returns:
Tuple of (logits, hidden_states)
"""
# Call the inner model (transformer layers + norm)
hidden: mx.array = self._base.model(inputs, model_cache)
# Get logits from lm_head
logits: mx.array = self._base.lm_head(hidden)
return logits, hidden
def forward(
self,
inputs: mx.array,
model_cache: list[Any] | None = None,
) -> mx.array:
"""Standard forward pass returning only logits."""
return cast(mx.array, self._base(inputs, cache=model_cache))
@property
def layers(self) -> list[nn.Module]:
"""Access layers for cache creation."""
return cast(list[nn.Module], self._base.layers)
def mtp_speculative_generate_step(
prompt: mx.array,
model: nn.Module,
mtp_module: MTPModule,
*,
num_draft_tokens: int = 1,
max_tokens: int = 256,
sampler: Callable[[mx.array], mx.array] | None = None,
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
prompt_cache: list[Any] | None = None,
mtp_cache: KVCache | None = None,
prefill_step_size: int = 512,
kv_bits: int | None = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[tuple[int, mx.array, bool], None, None]:
"""MTP speculative decoding generator.
Unlike standard speculative decoding where the draft model only needs tokens,
MTP requires the hidden states from the main model. This generator:
1. Runs the main model to get logits AND hidden states
2. Uses MTP module with hidden state + sampled token to predict next token
3. Verifies MTP predictions with the main model
4. Accepts/rejects based on matching
Args:
prompt: The input prompt as token ids
model: The main model (must support return_hidden=True)
mtp_module: The MTP module for draft prediction
num_draft_tokens: Number of tokens to draft (typically 1 for MTP)
max_tokens: Maximum number of tokens to generate
sampler: Optional sampler function for token selection
logits_processors: Optional list of logits processors
prompt_cache: KV cache for the main model
mtp_cache: KV cache for the MTP module
prefill_step_size: Step size for prompt processing
kv_bits: Bits for KV cache quantization
kv_group_size: Group size for KV cache quantization
quantized_kv_start: Step to begin cache quantization
Yields:
Tuple of (token, logprobs, from_draft)
"""
y = prompt.astype(mx.uint32)
prev_tokens: mx.array | None = None
# Wrap model to get hidden states
wrapped_model = (
model
if isinstance(model, ModelWithHiddenStates)
else ModelWithHiddenStates(model)
)
# Create caches if needed
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(model)
if mtp_cache is None:
mtp_cache = KVCache()
final_sampler = (
sampler if sampler is not None else (lambda x: mx.argmax(x, axis=-1))
)
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _process_and_sample(
tokens: mx.array | None,
logits: mx.array,
) -> tuple[mx.array, mx.array]:
"""Process logits and sample tokens."""
nonlocal logits_processors
processed_logits = logits
if logits_processors:
for processor in logits_processors:
processed_logits = processor(
tokens if tokens is not None else mx.array([]), processed_logits
)
logprobs = processed_logits - mx.logsumexp(
processed_logits, axis=-1, keepdims=True
)
sampled = final_sampler(logprobs)
return sampled, logprobs
def _main_model_step_with_hidden(
input_y: mx.array,
) -> tuple[mx.array, mx.array, mx.array]:
"""Run main model step with hidden state return."""
nonlocal prev_tokens
with mx.stream(generation_stream):
logits, hidden = wrapped_model.forward_with_hidden(
input_y[None], prompt_cache
)
logits = logits[:, -1, :]
quantize_cache_fn(prompt_cache)
if logits_processors:
prev_tokens = (
mx.concatenate([prev_tokens, input_y])
if prev_tokens is not None
else input_y
)
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
return sampled, logprobs_result.squeeze(0), hidden[:, -1:, :]
def _main_model_step(
input_y: mx.array,
) -> tuple[mx.array, mx.array]:
"""Run main model step without hidden state."""
nonlocal prev_tokens
with mx.stream(generation_stream):
logits = wrapped_model.forward(input_y[None], prompt_cache)
logits = logits[:, -1, :]
quantize_cache_fn(prompt_cache)
if logits_processors:
prev_tokens = (
mx.concatenate([prev_tokens, input_y])
if prev_tokens is not None
else input_y
)
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
return sampled, logprobs_result.squeeze(0)
def _mtp_draft(
hidden_state: mx.array,
draft_token: mx.array,
) -> tuple[mx.array, mx.array]:
"""Generate draft token using MTP module."""
with mx.stream(generation_stream):
logits, new_hidden = mtp_module(
hidden_state,
draft_token,
cache=mtp_cache,
)
logits = logits[:, -1, :]
sampled, _ = _process_and_sample(None, logits)
return sampled, new_hidden
def _prefill(input_y: mx.array) -> mx.array:
"""Prefill the prompt cache."""
result_y = input_y
while result_y.size > prefill_step_size:
_ = wrapped_model.forward(result_y[:prefill_step_size][None], prompt_cache)
quantize_cache_fn(prompt_cache)
mx.eval([c.state for c in prompt_cache])
result_y = result_y[prefill_step_size:]
mx.clear_cache()
return result_y
def _rewind_cache(num_draft: int, num_accept: int) -> None:
"""Rewind caches after rejection."""
cache.trim_prompt_cache(prompt_cache, num_draft - num_accept)
# Prefill phase
with mx.stream(generation_stream):
y = _prefill(y)
ntoks = 0
num_draft = 0
n_accepted = 0
last_hidden: mx.array | None = None
try:
# Initial step to get first token and hidden state
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
mx.eval(sampled, logprobs, last_hidden)
y = sampled
current_logprobs = logprobs
while ntoks < max_tokens:
# Draft phase: use MTP to predict next token
num_draft = min(max_tokens - ntoks - 1, num_draft_tokens)
if num_draft > 0 and last_hidden is not None:
# Use MTP to draft
draft_token, draft_hidden = _mtp_draft(last_hidden, y)
mx.eval(draft_token, draft_hidden)
# Verify with main model
# Feed the drafted token to main model
verify_input = mx.concatenate([y, draft_token.flatten()])
verify_sampled, verify_logprobs, new_hidden = (
_main_model_step_with_hidden(verify_input)
)
mx.eval(verify_sampled, verify_logprobs, new_hidden)
# Check if draft matches verification
draft_token_val = int(draft_token.item())
verify_token_val = (
int(verify_sampled[0].item())
if verify_sampled.shape[0] > 1
else int(verify_sampled.item())
)
# Yield the current token (not from draft)
ntoks += 1
yield int(y.item()), current_logprobs, False
if ntoks >= max_tokens:
break
if draft_token_val == verify_token_val:
# Draft accepted
n_accepted += 1
ntoks += 1
draft_logprobs = (
verify_logprobs[0]
if verify_logprobs.ndim > 1
else verify_logprobs
)
yield draft_token_val, draft_logprobs, True
if ntoks >= max_tokens:
break
# Continue with the token after the draft
y = (
verify_sampled[-1:]
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
else verify_sampled
)
current_logprobs = (
verify_logprobs[-1]
if verify_logprobs.ndim > 1
else verify_logprobs
)
last_hidden = new_hidden
else:
# Draft rejected - rewind and use verified token
_rewind_cache(1, 0)
y = (
verify_sampled[:1]
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
else verify_sampled
)
current_logprobs = (
verify_logprobs[0]
if verify_logprobs.ndim > 1
else verify_logprobs
)
last_hidden = (
new_hidden[:, :1, :] if new_hidden is not None else None
)
else:
# No drafting, just do normal generation
ntoks += 1
yield int(y.item()), current_logprobs, False
if ntoks >= max_tokens:
break
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
mx.eval(sampled, logprobs, last_hidden)
y = sampled
current_logprobs = logprobs
if ntoks % 256 == 0:
mx.clear_cache()
finally:
_rewind_cache(num_draft, n_accepted)
def mtp_speculative_generate(
model: nn.Module,
mtp_module: MTPModule,
tokenizer: TokenizerWrapper,
prompt: str | mx.array | list[int],
max_tokens: int = 256,
sampler: Callable[[mx.array], mx.array] | None = None,
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
prompt_cache: list[Any] | None = None,
num_draft_tokens: int = 1,
prefill_step_size: int = 512,
kv_group_size: int = 64,
kv_bits: int | None = None,
) -> Generator[MTPGenerationResponse, None, None]:
"""High-level MTP speculative generation with text output.
Args:
model: The main model
mtp_module: The MTP module for draft prediction
tokenizer: Tokenizer for encoding/decoding
prompt: Input prompt (string, array, or token list)
max_tokens: Maximum tokens to generate
sampler: Optional sampler function
logits_processors: Optional logits processors
prompt_cache: Optional KV cache
num_draft_tokens: Number of draft tokens
prefill_step_size: Prefill step size
kv_group_size: KV group size
kv_bits: KV bits
Yields:
MTPGenerationResponse objects with text and metadata
"""
if not isinstance(prompt, mx.array):
if isinstance(prompt, str):
bos_token = getattr(tokenizer, "bos_token", None)
add_special_tokens = bos_token is None or not prompt.startswith(
str(bos_token)
)
encoded: list[int] = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
prompt = mx.array(encoded)
else:
prompt = mx.array(prompt)
detokenizer = tokenizer.detokenizer
eos_token_ids: list[int] = getattr(tokenizer, "eos_token_ids", [])
token_generator = mtp_speculative_generate_step(
prompt,
model,
mtp_module,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
num_draft_tokens=num_draft_tokens,
prefill_step_size=prefill_step_size,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
tic = time.perf_counter()
prompt_tps = 0.0
token = 0
logprobs: mx.array = mx.array([0.0])
from_draft = False
n = 0
for n, (token, logprobs, from_draft) in enumerate(token_generator):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = float(prompt.size) / prompt_time
tic = time.perf_counter()
if token in eos_token_ids:
break
detokenizer.add_token(token)
if (n + 1) == max_tokens:
break
yield MTPGenerationResponse(
text=str(detokenizer.last_segment),
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=int(prompt.size),
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
finish_reason=None,
)
detokenizer.finalize()
yield MTPGenerationResponse(
text=str(detokenizer.last_segment),
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=int(prompt.size),
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
finish_reason="stop" if token in eos_token_ids else "length",
)

View File

@@ -1 +0,0 @@
"""Tests for MTP module."""

View File

@@ -1,412 +0,0 @@
"""Unit tests for MTP module components."""
import mlx.core as mx
import mlx.nn as nn
import pytest
from exo.worker.engines.mlx.mtp.module import (
MTP_LAYER_INDEX,
MTPModule,
MTPTransformerBlock,
extract_mtp_weights,
load_mtp_weights_into_module,
)
class MockModelArgs:
"""Mock ModelArgs for testing without importing deepseek_v3."""
def __init__(
self,
hidden_size: int = 256,
intermediate_size: int = 512,
num_attention_heads: int = 4,
num_key_value_heads: int = 4,
rms_norm_eps: float = 1e-6,
vocab_size: int = 1000,
q_lora_rank: int | None = None,
kv_lora_rank: int = 64,
qk_rope_head_dim: int = 16,
v_head_dim: int = 32,
qk_nope_head_dim: int = 32,
rope_theta: float = 10000.0,
rope_scaling: dict | None = None,
attention_bias: bool = False,
max_position_embeddings: int = 2048,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.vocab_size = vocab_size
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.max_position_embeddings = max_position_embeddings
class TestExtractMTPWeights:
"""Tests for extract_mtp_weights function."""
def test_extracts_layer_61_weights(self) -> None:
"""Should extract only layer 61 weights."""
weights = {
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
"model.layers.61.enorm.weight": mx.ones((10,)),
"model.layers.61.hnorm.weight": mx.ones((10,)) * 2,
"model.layers.61.eh_proj.weight": mx.ones((10, 20)),
"model.layers.62.self_attn.weight": mx.zeros((10, 10)),
"model.embed_tokens.weight": mx.zeros((100, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert len(mtp_weights) == 3
assert "enorm.weight" in mtp_weights
assert "hnorm.weight" in mtp_weights
assert "eh_proj.weight" in mtp_weights
# Check values are preserved
assert mx.allclose(mtp_weights["enorm.weight"], mx.ones((10,)))
assert mx.allclose(mtp_weights["hnorm.weight"], mx.ones((10,)) * 2)
def test_returns_empty_dict_when_no_layer_61(self) -> None:
"""Should return empty dict when layer 61 doesn't exist."""
weights = {
"model.layers.0.self_attn.weight": mx.zeros((10, 10)),
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert len(mtp_weights) == 0
def test_handles_nested_layer_61_weights(self) -> None:
"""Should handle nested weight paths like self_attn.q_proj.weight."""
weights = {
f"model.layers.{MTP_LAYER_INDEX}.self_attn.q_a_proj.weight": mx.zeros(
(10, 10)
),
f"model.layers.{MTP_LAYER_INDEX}.mlp.gate_proj.weight": mx.zeros((20, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert "self_attn.q_a_proj.weight" in mtp_weights
assert "mlp.gate_proj.weight" in mtp_weights
class TestMTPTransformerBlock:
"""Tests for MTPTransformerBlock."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64, intermediate_size=128, num_attention_heads=2
)
def test_forward_shape(self, config: MockModelArgs) -> None:
"""Forward pass should preserve input shape."""
# Skip if deepseek_v3 imports fail (CI without mlx_lm)
pytest.importorskip("mlx_lm.models.deepseek_v3")
block = MTPTransformerBlock(config) # type: ignore[arg-type]
x = mx.random.normal((1, 5, config.hidden_size))
output = block(x)
assert output.shape == x.shape
def test_forward_with_mask(self, config: MockModelArgs) -> None:
"""Forward pass should work with attention mask."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
block = MTPTransformerBlock(config) # type: ignore[arg-type]
x = mx.random.normal((1, 5, config.hidden_size))
# Create causal mask
mask = mx.triu(mx.full((5, 5), float("-inf")), k=1)
output = block(x, mask=mask)
assert output.shape == x.shape
class TestMTPModule:
"""Tests for MTPModule."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64,
intermediate_size=128,
num_attention_heads=2,
vocab_size=100,
)
@pytest.fixture
def shared_components(
self, config: MockModelArgs
) -> tuple[nn.Embedding, nn.Linear, nn.RMSNorm]:
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
return embedding, lm_head, output_norm
def test_initialization(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""MTPModule should initialize with correct components."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
assert mtp.hnorm is not None
assert mtp.enorm is not None
assert mtp.eh_proj is not None
assert mtp.transformer_block is not None
def test_forward_output_shapes(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""Forward pass should return correct output shapes."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
batch_size = 2
seq_len = 1
hidden_state = mx.random.normal((batch_size, seq_len, config.hidden_size))
draft_token = mx.array([[5], [10]]) # [batch, seq_len]
logits, new_hidden = mtp(hidden_state, draft_token)
assert logits.shape == (batch_size, seq_len, config.vocab_size)
assert new_hidden.shape == (batch_size, seq_len, config.hidden_size)
def test_shares_embedding_and_lm_head(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""MTPModule should use shared embedding and lm_head."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Verify they're the same objects
assert mtp._shared_embedding is embedding
assert mtp._shared_lm_head is lm_head
assert mtp._output_norm is output_norm
class TestLoadMTPWeights:
"""Tests for load_mtp_weights_into_module."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64,
intermediate_size=128,
num_attention_heads=2,
vocab_size=100,
)
def test_loads_norm_weights(self, config: MockModelArgs) -> None:
"""Should load enorm and hnorm weights."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Create test weights
test_enorm = mx.ones((config.hidden_size,)) * 3.0
test_hnorm = mx.ones((config.hidden_size,)) * 5.0
mtp_weights = {
"enorm.weight": test_enorm,
"hnorm.weight": test_hnorm,
}
load_mtp_weights_into_module(mtp, mtp_weights)
assert mx.allclose(mtp.enorm.weight, test_enorm)
assert mx.allclose(mtp.hnorm.weight, test_hnorm)
class TestSanitizePatch:
"""Tests for the sanitize patching logic."""
def test_patch_preserves_layer_61(self) -> None:
"""Patching sanitize should preserve layer 61 weights."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
# Get original sanitize behavior
original_sanitize = model_cls.sanitize
try:
# Apply patch
_patch_deepseek_sanitize_for_mtp()
# Note: we can't easily test the full sanitize without a real model
# This test verifies the patch is applied
assert model_cls.sanitize is not original_sanitize
finally:
_restore_deepseek_sanitize()
# Verify restore worked
assert model_cls.sanitize is original_sanitize
def test_restore_sanitize(self) -> None:
"""Restoring sanitize should return to original behavior."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
original_sanitize = model_cls.sanitize
_patch_deepseek_sanitize_for_mtp()
assert model_cls.sanitize is not original_sanitize
_restore_deepseek_sanitize()
assert model_cls.sanitize is original_sanitize
def test_double_patch_is_safe(self) -> None:
"""Calling patch twice should be safe (idempotent)."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
original_sanitize = model_cls.sanitize
try:
_patch_deepseek_sanitize_for_mtp()
patched_sanitize = model_cls.sanitize
# Patch again - should be no-op
_patch_deepseek_sanitize_for_mtp()
assert model_cls.sanitize is patched_sanitize
finally:
_restore_deepseek_sanitize()
assert model_cls.sanitize is original_sanitize
class TestModelIdDetection:
"""Tests for DeepSeek V3 model ID detection."""
def test_detects_deepseek_v3(self) -> None:
"""Should detect DeepSeek V3 model IDs."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-V3")
assert _might_be_deepseek_v3("deepseek-ai/deepseek-v3-base")
assert _might_be_deepseek_v3("mlx-community/DeepSeek-V3-4bit")
def test_detects_deepseek_r1(self) -> None:
"""Should detect DeepSeek R1 model IDs (also uses MTP)."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-R1")
assert _might_be_deepseek_v3("mlx-community/DeepSeek-R1-4bit")
def test_rejects_non_deepseek(self) -> None:
"""Should reject non-DeepSeek model IDs."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert not _might_be_deepseek_v3("meta-llama/Llama-3-70B")
assert not _might_be_deepseek_v3("mistralai/Mixtral-8x7B")
assert not _might_be_deepseek_v3("deepseek-ai/DeepSeek-V2") # V2, not V3
def test_case_insensitive(self) -> None:
"""Detection should be case insensitive."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("DEEPSEEK-AI/DEEPSEEK-V3")
assert _might_be_deepseek_v3("DeepSeek-AI/deepseek-v3")
class TestFlattenParams:
"""Tests for parameter flattening utility."""
def test_flattens_nested_dict(self) -> None:
"""Should flatten nested parameter dict."""
from exo.worker.engines.mlx.utils_mlx import _flatten_params
params = {
"model": {
"layers": {
"0": {
"weight": mx.zeros((10,)),
}
},
"embed": mx.ones((5,)),
}
}
flat = _flatten_params(params)
assert "model.layers.0.weight" in flat
assert "model.embed" in flat
assert mx.allclose(flat["model.layers.0.weight"], mx.zeros((10,)))
assert mx.allclose(flat["model.embed"], mx.ones((5,)))
def test_handles_flat_dict(self) -> None:
"""Should handle already-flat dict."""
from exo.worker.engines.mlx.utils_mlx import _flatten_params
params = {
"weight": mx.zeros((10,)),
"bias": mx.ones((10,)),
}
flat = _flatten_params(params)
assert flat == params

View File

@@ -1,253 +0,0 @@
"""Unit tests for MTP speculative decoding."""
import mlx.core as mx
import mlx.nn as nn
import pytest
from exo.worker.engines.mlx.mtp.speculative_decode import (
ModelWithHiddenStates,
maybe_quantize_kv_cache,
)
class MockModel(nn.Module):
"""Mock model for testing speculative decoding."""
def __init__(self, hidden_size: int = 64, vocab_size: int = 100) -> None:
super().__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
# Create simple model components
self.model = MockInnerModel(hidden_size)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
self._layers = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
def __call__(
self,
inputs: mx.array,
cache: list | None = None,
) -> mx.array:
hidden = self.model(inputs, cache)
return self.lm_head(hidden)
@property
def layers(self) -> list[nn.Module]:
return self._layers
class MockInnerModel(nn.Module):
"""Mock inner model (like DeepseekV3Model)."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.embed_tokens = nn.Embedding(100, hidden_size)
self.norm = nn.RMSNorm(hidden_size)
def __call__(
self,
inputs: mx.array,
cache: list | None = None,
) -> mx.array:
# Simple embedding + norm
embedded = self.embed_tokens(inputs)
return self.norm(embedded)
class TestModelWithHiddenStates:
"""Tests for ModelWithHiddenStates wrapper."""
@pytest.fixture
def mock_model(self) -> MockModel:
return MockModel(hidden_size=64, vocab_size=100)
def test_forward_returns_logits(self, mock_model: MockModel) -> None:
"""Standard forward should return logits."""
wrapped = ModelWithHiddenStates(mock_model)
inputs = mx.array([[1, 2, 3]])
logits = wrapped.forward(inputs)
assert logits.shape == (1, 3, mock_model.vocab_size)
def test_forward_with_hidden_returns_tuple(self, mock_model: MockModel) -> None:
"""Forward with hidden should return (logits, hidden)."""
wrapped = ModelWithHiddenStates(mock_model)
inputs = mx.array([[1, 2, 3]])
logits, hidden = wrapped.forward_with_hidden(inputs)
assert logits.shape == (1, 3, mock_model.vocab_size)
assert hidden.shape == (1, 3, mock_model.hidden_size)
def test_layers_property(self, mock_model: MockModel) -> None:
"""Should expose layers property from base model."""
wrapped = ModelWithHiddenStates(mock_model)
assert wrapped.layers == mock_model.layers
assert len(wrapped.layers) == 3
class TestMaybeQuantizeKVCache:
"""Tests for KV cache quantization."""
def test_no_quantization_when_bits_none(self) -> None:
"""Should not quantize when kv_bits is None."""
cache = [MockCache(offset=100)]
maybe_quantize_kv_cache(
cache,
quantized_kv_start=50,
kv_group_size=64,
kv_bits=None,
)
# Cache should be unchanged
assert not hasattr(cache[0], "quantized")
def test_respects_quantized_kv_start(self) -> None:
"""Should only quantize caches past the start threshold."""
cache_below = MockCache(offset=30)
cache_above = MockCache(offset=100)
caches = [cache_below, cache_above]
maybe_quantize_kv_cache(
caches,
quantized_kv_start=50,
kv_group_size=64,
kv_bits=4,
)
# Only cache_above should be quantized
assert not getattr(cache_below, "was_quantized", False)
assert getattr(caches[1], "was_quantized", False)
class MockCache:
"""Mock KV cache for testing."""
def __init__(self, offset: int = 0) -> None:
self.offset = offset
self.was_quantized = False
def to_quantized(self, group_size: int, bits: int) -> "MockCache":
quantized = MockCache(self.offset)
quantized.was_quantized = True
return quantized
class TestSpeculativeDecodingLogic:
"""Tests for the core speculative decoding logic."""
def test_draft_acceptance_identical_tokens(self) -> None:
"""When draft matches verification, both should be accepted."""
# This tests the logic, not the full generator
draft_token = 42
verify_token = 42
accepted = draft_token == verify_token
assert accepted
def test_draft_rejection_different_tokens(self) -> None:
"""When draft differs from verification, draft should be rejected."""
draft_token = 42
verify_token = 99
accepted = draft_token == verify_token
assert not accepted
class TestMTPGenerationResponse:
"""Tests for MTPGenerationResponse dataclass."""
def test_response_creation(self) -> None:
"""Should create response with all fields."""
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
response = MTPGenerationResponse(
text="Hello",
token=42,
logprobs=mx.array([0.1, 0.2]),
from_draft=True,
prompt_tokens=10,
prompt_tps=100.0,
generation_tokens=5,
generation_tps=50.0,
peak_memory=1.5,
finish_reason=None,
)
assert response.text == "Hello"
assert response.token == 42
assert response.from_draft is True
assert response.finish_reason is None
def test_response_with_finish_reason(self) -> None:
"""Should handle finish_reason."""
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
response = MTPGenerationResponse(
text="",
token=0,
logprobs=mx.array([0.0]),
from_draft=False,
prompt_tokens=10,
prompt_tps=100.0,
generation_tokens=100,
generation_tps=50.0,
peak_memory=1.5,
finish_reason="length",
)
assert response.finish_reason == "length"
class TestIntegration:
"""Integration tests for the full MTP pipeline."""
def test_mtp_module_with_mock_model(self) -> None:
"""Test MTP module can be created and run with mock components."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
from exo.worker.engines.mlx.mtp.module import MTPModule
# Create mock config
class MockConfig:
hidden_size = 64
intermediate_size = 128
num_attention_heads = 2
num_key_value_heads = 2
rms_norm_eps = 1e-6
q_lora_rank = None
kv_lora_rank = 32
qk_rope_head_dim = 8
v_head_dim = 16
qk_nope_head_dim = 16
rope_theta = 10000.0
rope_scaling = None
attention_bias = False
max_position_embeddings = 2048
config = MockConfig()
embedding = nn.Embedding(100, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, 100, bias=False)
output_norm = nn.RMSNorm(config.hidden_size)
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Run forward pass
hidden = mx.random.normal((1, 1, config.hidden_size))
token = mx.array([[5]])
logits, new_hidden = mtp(hidden, token)
assert logits.shape == (1, 1, 100)
assert new_hidden.shape == (1, 1, config.hidden_size)
# Verify outputs are valid (not NaN)
assert not mx.any(mx.isnan(logits))
assert not mx.any(mx.isnan(new_hidden))

View File

@@ -28,7 +28,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
MTP_ENABLED,
TRUST_REMOTE_CODE,
)
@@ -70,67 +69,6 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# MTP (Multi-Token Prediction) support for DeepSeek V3
MTP_LAYER_INDEX = 61
_original_deepseek_sanitize: Callable[..., dict[str, Any]] | None = None
def _is_deepseek_v3_model(model: nn.Module) -> bool:
"""Check if the model is DeepSeek V3."""
return hasattr(model, "model") and isinstance(model.model, DeepseekV3Model)
def _patch_deepseek_sanitize_for_mtp() -> None:
"""Patch DeepSeek V3 Model.sanitize to preserve MTP layer weights.
The original sanitize() method filters out layer 61 (MTP layer) weights.
This patch keeps them so we can extract and use the MTP module.
"""
global _original_deepseek_sanitize
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
if _original_deepseek_sanitize is not None:
# Already patched
return
_original_deepseek_sanitize = DeepSeekV3Model.sanitize
def sanitize_with_mtp(
self: DeepSeekV3Model, weights: dict[str, Any]
) -> dict[str, Any]:
"""Modified sanitize that keeps MTP layer weights."""
# First, call the original sanitize to handle all the weight transformations
# (dequantization, expert stacking, etc.)
if _original_deepseek_sanitize is None:
raise RuntimeError(
"_original_deepseek_sanitize is None - patch not applied correctly"
)
original_result: dict[str, Any] = _original_deepseek_sanitize(self, weights)
# Re-add the MTP layer weights that were filtered out
mtp_weights = {
k: v
for k, v in weights.items()
if k.startswith(f"model.layers.{MTP_LAYER_INDEX}")
}
return {**original_result, **mtp_weights}
DeepSeekV3Model.sanitize = sanitize_with_mtp
def _restore_deepseek_sanitize() -> None:
"""Restore the original DeepSeek V3 sanitize method."""
global _original_deepseek_sanitize
if _original_deepseek_sanitize is None:
return
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
DeepSeekV3Model.sanitize = _original_deepseek_sanitize
_original_deepseek_sanitize = None
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
@@ -295,162 +233,50 @@ def load_mlx_items(
group: Group | None,
on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
"""Load MLX model and tokenizer.
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
Returns:
Tuple of (model, tokenizer)
"""
model_id = bound_instance.bound_shard.model_meta.model_id
mtp_module = None
# Patch sanitize for MTP if this might be DeepSeek V3
should_try_mtp = MTP_ENABLED and _might_be_deepseek_v3(model_id)
if should_try_mtp:
logger.info("Patching DeepSeek V3 sanitize for MTP weight preservation")
_patch_deepseek_sanitize_for_mtp()
try:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=not should_try_mtp)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
# Extract MTP module if available
if should_try_mtp and _is_deepseek_v3_model(model):
mtp_module = _extract_mtp_module(model)
if mtp_module is not None:
logger.info("Successfully extracted MTP module from DeepSeek V3")
finally:
# Restore original sanitize
if should_try_mtp:
_restore_deepseek_sanitize()
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
# Store MTP module on the model for later access
if mtp_module is not None:
model.mtp_module = mtp_module # noqa: B010
return cast(Model, model), tokenizer
def _might_be_deepseek_v3(model_id: str) -> bool:
"""Check if model ID suggests this might be DeepSeek V3."""
model_id_lower = model_id.lower()
return "deepseek" in model_id_lower and (
"v3" in model_id_lower or "r1" in model_id_lower
)
def load_draft_model(model_id: str) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
def _flatten_params(
params: dict[str, Any],
prefix: str = "",
) -> dict[str, mx.array]:
"""Flatten nested parameter dict to flat dict with dot-separated keys."""
result: dict[str, mx.array] = {}
for key, value in params.items():
full_key = f"{prefix}.{key}" if prefix else key
if isinstance(value, mx.array):
result[full_key] = value
elif isinstance(value, dict):
result.update(_flatten_params(value, full_key))
return result
Assumes the model has already been downloaded by the worker.
def _extract_mtp_module(model: nn.Module) -> Any | None:
"""Extract MTP module from a loaded DeepSeek V3 model.
The MTP weights are stored in model.model.layers at index 61 (if preserved).
This function extracts them and creates an MTPModule.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
MTPModule if MTP weights were found and extracted, None otherwise.
The loaded draft model
"""
from exo.worker.engines.mlx.mtp.module import (
MTPModule,
extract_mtp_weights,
load_mtp_weights_into_module,
)
try:
# Check if the model has the MTP layer
inner_model = getattr(model, "model", None)
if inner_model is None or not hasattr(inner_model, "layers"):
logger.debug("Model doesn't have expected structure for MTP extraction")
return None
layers: list[nn.Module] = inner_model.layers # type: ignore[assignment]
if len(layers) <= MTP_LAYER_INDEX:
logger.debug(
f"Model has {len(layers)} layers, MTP layer {MTP_LAYER_INDEX} not found"
)
return None
# Get model config
config = getattr(model, "args", None)
if config is None:
logger.debug("Could not get model config for MTP module")
return None
# Create MTP module with shared weights
embed_tokens = getattr(inner_model, "embed_tokens", None)
lm_head = getattr(model, "lm_head", None)
norm = getattr(inner_model, "norm", None)
if embed_tokens is None or lm_head is None or norm is None:
logger.debug("Could not get required model components for MTP")
return None
mtp_module = MTPModule(
config=config,
shared_embedding=embed_tokens,
shared_lm_head=lm_head,
output_norm=norm,
)
# Extract MTP layer weights from the model's parameters
# The weights should be at model.model.layers.61.*
# model.parameters() returns a nested dict, we need to flatten it
raw_params: dict[str, Any] = dict(model.parameters()) # type: ignore[arg-type]
model_weights = _flatten_params(raw_params)
mtp_weights = extract_mtp_weights(model_weights)
if not mtp_weights:
logger.debug("No MTP weights found in model parameters")
return None
# Load weights into MTP module
load_mtp_weights_into_module(mtp_module, mtp_weights)
# Remove MTP layer from main model to avoid double computation
# Create new layers list without the MTP layer
new_layers = [layer for i, layer in enumerate(layers) if i != MTP_LAYER_INDEX]
inner_model.layers = new_layers # noqa: B010
logger.info(
f"Extracted MTP module, main model now has {len(new_layers)} layers"
)
return mtp_module
except Exception as e:
logger.warning(f"Failed to extract MTP module: {e}")
return None
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(

View File

@@ -3,7 +3,8 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -35,6 +36,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -57,6 +59,7 @@ def plan(
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _draft_model_needs_download(runners, download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners)
)
@@ -128,6 +131,57 @@ def _model_needs_download(
)
def _draft_model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
"""Check if draft model needs download (for speculative decoding).
Only rank 0 needs the draft model, and only after the main model is loaded.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
shard = runner.bound_instance.bound_shard
# Only check when runner is loaded and ready for warmup
if not isinstance(runner.status, RunnerLoaded):
continue
# Only rank 0 loads the draft model
if shard.device_rank != 0:
continue
# Check if instance has a draft model configured
draft_model_id = instance.draft_model
if draft_model_id is None:
continue
# Check if draft model needs download
if draft_model_id not in download_status or not isinstance(
download_status[draft_model_id], (DownloadOngoing, DownloadCompleted)
):
# Create minimal shard metadata for draft model download
draft_shard = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=draft_model_id,
pretty_name=str(draft_model_id),
storage_size=Memory.from_bytes(0), # Unknown, will be determined during download
n_layers=1, # Placeholder
hidden_size=1, # Placeholder
supports_tensor=False,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
)
return DownloadModel(
instance_id=instance.instance_id,
shard_metadata=draft_shard,
)
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],

View File

@@ -56,6 +56,7 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_draft_model,
load_mlx_items,
mlx_force_oom,
)
@@ -110,6 +111,7 @@ def main(
model = None
tokenizer = None
group = None
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -178,11 +180,20 @@ def main(
)
)
# Load draft model for speculative decoding (rank 0 only)
if (
instance.draft_model is not None
and shard_metadata.device_rank == 0
):
logger.info(f"Loading draft model: {instance.draft_model}")
draft_model = cast(
Model, load_draft_model(str(instance.draft_model))
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=cast(Model, model),
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
@@ -212,11 +223,13 @@ def main(
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
# Generate responses (draft_model loaded at warmup if configured)
mlx_generator = mlx_generate(
model=cast(Model, model),
tokenizer=tokenizer,
task=task_params,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
)
# GPT-OSS specific parsing to match other model formats.
@@ -265,7 +278,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del model, tokenizer, group, draft_model
mx.clear_cache()
import gc