Mirror of https://github.com/exo-explore/exo.git, synced 2026-01-18 10:58:35 -05:00

Compare commits: 2 commits

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | f2857adf63 | |
| alexcheema | 3a161b4a3e | |
@@ -69,8 +69,6 @@ export interface Instance {
    runnerToShard?: Record<string, unknown>;
    nodeToRunner?: Record<string, string>;
  };
  draftModel?: string;
  numDraftTokens?: number;
}

interface RawNodeProfile {
@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);

// Instance launch state
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -58,8 +58,6 @@ const sidebarVisible = $derived(chatSidebarVisible());
  sharding: 'Pipeline' | 'Tensor';
  instanceType: InstanceMeta;
  minNodes: number;
  draftModel: string | null;
  numDraftTokens: number;
}

function saveLaunchDefaults(): void {
@@ -68,8 +66,6 @@ const sidebarVisible = $derived(chatSidebarVisible());
    sharding: selectedSharding,
    instanceType: selectedInstanceType,
    minNodes: selectedMinNodes,
    draftModel: selectedDraftModel,
    numDraftTokens: selectedNumDraftTokens,
  };
  try {
    localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
@@ -92,36 +88,24 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
  const defaults = loadLaunchDefaults();
  if (!defaults) return;

  // Apply sharding and instance type unconditionally
  selectedSharding = defaults.sharding;
  selectedInstanceType = defaults.instanceType;

  // Apply minNodes if valid (between 1 and maxNodes)
  if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
    selectedMinNodes = defaults.minNodes;
  }

  // Only apply model if it exists in the available models
  if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
    selectPreviewModel(defaults.modelId);
  }

  // Apply draft model if it exists in the available models (check against hugging_face_id)
  if (defaults.draftModel && availableModels.some(m => (m as {hugging_face_id?: string}).hugging_face_id === defaults.draftModel)) {
    selectedDraftModel = defaults.draftModel;
  }

  // Apply num draft tokens if valid
  if (defaults.numDraftTokens && defaults.numDraftTokens >= 1 && defaults.numDraftTokens <= 10) {
    selectedNumDraftTokens = defaults.numDraftTokens;
  }
}

let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let selectedDraftModel = $state<string | null>(null);
let selectedNumDraftTokens = $state<number>(4);
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
@@ -129,8 +113,6 @@ let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');
let isDraftModelDropdownOpen = $state(false);
let draftModelDropdownSearch = $state('');

// Slider dragging state
let isDraggingSlider = $state(false);
@@ -380,39 +362,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
  if (!modelId || launchingModelId) return;

  launchingModelId = modelId;

  try {
    // Use the specific preview if provided, otherwise fall back to filtered preview
    const preview = specificPreview ?? filteredPreview();

    let response: Response;

    // Use /place_instance endpoint - it handles placement and creation in one step
    // This also supports draft_model for speculative decoding
    const placePayload = {
      model_id: modelId,
      sharding: preview?.sharding ?? selectedSharding,
      instance_meta: preview?.instance_meta ?? selectedInstanceType,
      min_nodes: selectedMinNodes,
      draft_model: selectedDraftModel,
      num_draft_tokens: selectedDraftModel ? selectedNumDraftTokens : 4,
    };

    response = await fetch('/place_instance', {

    let instanceData: unknown;

    if (preview?.instance) {
      // Use the instance from the preview
      instanceData = preview.instance;
    } else {
      // Fallback: GET placement from API
      const placementResponse = await fetch(
        `/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
      );

      if (!placementResponse.ok) {
        const errorText = await placementResponse.text();
        console.error('Failed to get placement:', errorText);
        return;
      }

      instanceData = await placementResponse.json();
    }

    // POST the instance to create it
    const response = await fetch('/instance', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(placePayload)
      body: JSON.stringify({ instance: instanceData })
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.error('Failed to launch instance:', errorText);
    } else {
      // Always auto-select the newly launched model so the user chats to what they just launched
      setSelectedChatModel(modelId);

      // Scroll to the bottom of instances container to show the new instance
      // Use multiple attempts to ensure DOM has updated with the new instance
      const scrollToBottom = () => {
@@ -826,34 +816,30 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}

// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
  instanceType: string;
  sharding: string;
  nodeNames: string[];
  nodeIds: string[];
  nodeCount: number;
  draftModel: string | null;
  numDraftTokens: number | null;
} {
  const [instanceTag, instance] = getTagged(instanceWrapped);
  if (!instance || typeof instance !== 'object') {
    return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
    return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
  }

  // Instance type from tag
  let instanceType = 'Unknown';
  if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
  else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';

  const inst = instance as {
    shardAssignments?: {
      nodeToRunner?: Record<string, string>;
      runnerToShard?: Record<string, unknown>;
    };
    draftModel?: string;
    numDraftTokens?: number;
  };

  // Sharding strategy from first shard
  let sharding = 'Unknown';
  const runnerToShard = inst.shardAssignments?.runnerToShard || {};
@@ -864,7 +850,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
    else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
  }

  // Node names from topology
  const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
  const nodeIds = Object.keys(nodeToRunner);
@@ -872,12 +858,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
    const node = data?.nodes?.[nodeId];
    return node?.friendly_name || nodeId.slice(0, 8);
  });

  // Draft model for speculative decoding
  const draftModel = inst.draftModel ?? null;
  const numDraftTokens = inst.numDraftTokens ?? null;

  return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };

  return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
}

function formatLastUpdate(): string {
@@ -1363,9 +1345,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="pl-2">
  <div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
  <div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
  {#if instanceInfo.draftModel}
    <div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
  {/if}
  {#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
    <a
      class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -1699,80 +1678,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Draft Model (Speculative Decoding) -->
|
||||
<div>
|
||||
<div class="text-xs text-white/70 font-mono mb-2">Draft Model (Speculative):</div>
|
||||
<div class="relative">
|
||||
<button
|
||||
onclick={() => { isDraftModelDropdownOpen = !isDraftModelDropdownOpen; draftModelDropdownSearch = ''; }}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {selectedDraftModel ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="truncate">{selectedDraftModel ? selectedDraftModel.split('/').pop() : 'None'}</span>
|
||||
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftModelDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
|
||||
</svg>
|
||||
</button>
|
||||
{#if isDraftModelDropdownOpen}
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<div
|
||||
class="fixed inset-0 z-40"
|
||||
onclick={() => isDraftModelDropdownOpen = false}
|
||||
onkeydown={(e) => e.key === 'Escape' && (isDraftModelDropdownOpen = false)}
|
||||
></div>
|
||||
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
|
||||
<div class="p-2 border-b border-exo-medium-gray/30">
|
||||
<input
|
||||
type="text"
|
||||
bind:value={draftModelDropdownSearch}
|
||||
placeholder="Search models..."
|
||||
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-exo-yellow/50"
|
||||
/>
|
||||
</div>
|
||||
<div class="overflow-y-auto max-h-36">
|
||||
<!-- None option -->
|
||||
<button
|
||||
onclick={() => { selectedDraftModel = null; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {selectedDraftModel === null ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
|
||||
>
|
||||
<span>None</span>
|
||||
</button>
|
||||
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftModelDropdownSearch.toLowerCase()) && m.id !== selectedModelId) as model}
|
||||
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
|
||||
{@const modelHfId = model.hugging_face_id ?? model.id}
|
||||
<button
|
||||
onclick={() => { selectedDraftModel = modelHfId; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedDraftModel === modelHfId ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
|
||||
>
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
<span class="flex-shrink-0 text-xs text-white/50">
|
||||
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
|
||||
</span>
|
||||
</button>
|
||||
{:else}
|
||||
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
<!-- Draft Tokens (only show when draft model selected) -->
|
||||
{#if selectedDraftModel}
|
||||
<div class="flex items-center gap-2 mt-2">
|
||||
<span class="text-xs text-white/50 font-mono">Tokens:</span>
|
||||
<div class="flex items-center gap-1">
|
||||
{#each [2, 3, 4, 5, 6] as n}
|
||||
<button
|
||||
onclick={() => { selectedNumDraftTokens = n; saveLaunchDefaults(); }}
|
||||
class="w-6 h-6 text-xs font-mono rounded transition-all {selectedNumDraftTokens === n ? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/50' : 'text-white/50 hover:text-white/80 border border-transparent'}"
|
||||
>{n}</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Selected Model Preview -->
|
||||
<div class="space-y-3">
|
||||
{#if models.length === 0}
|
||||
|
||||
@@ -200,8 +200,6 @@ class API:
            sharding=payload.sharding,
            instance_meta=payload.instance_meta,
            min_nodes=payload.min_nodes,
            draft_model=payload.draft_model,
            num_draft_tokens=payload.num_draft_tokens,
        )
        await self._send(command)

@@ -151,8 +151,6 @@ def place_instance(
                shard_assignments=shard_assignments,
                ibv_devices=mlx_ibv_devices,
                jaccl_coordinators=mlx_jaccl_coordinators,
                draft_model=command.draft_model,
                num_draft_tokens=command.num_draft_tokens,
            )
        case InstanceMeta.MlxRing:
            ephemeral_port = random_ephemeral_port()
@@ -166,8 +164,6 @@ def place_instance(
                shard_assignments=shard_assignments,
                hosts_by_node=hosts_by_node,
                ephemeral_port=ephemeral_port,
                draft_model=command.draft_model,
                num_draft_tokens=command.num_draft_tokens,
            )

    return target_instances
@@ -161,8 +161,6 @@ class ChatCompletionTaskParams(BaseModel):
    tool_choice: str | dict[str, Any] | None = None
    parallel_tool_calls: bool | None = None
    user: str | None = None
    # Speculative decoding: tokens to draft per iteration (if instance has draft model)
    num_draft_tokens: int = 3


class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -174,8 +172,6 @@ class PlaceInstanceParams(BaseModel):
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1
    draft_model: ModelId | None = None  # For speculative decoding
    num_draft_tokens: int = 4  # Tokens to draft per iteration

    @field_validator("sharding", "instance_meta", mode="plain")
    @classmethod
@@ -2,7 +2,7 @@ from pydantic import Field

from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -25,8 +25,6 @@ class PlaceInstance(BaseCommand):
    sharding: Sharding
    instance_meta: InstanceMeta
    min_nodes: int
    draft_model: ModelId | None = None  # For speculative decoding
    num_draft_tokens: int = 4  # Tokens to draft per iteration


class CreateInstance(BaseCommand):
@@ -3,7 +3,6 @@ from enum import Enum
from pydantic import model_validator

from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -20,8 +19,6 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
    instance_id: InstanceId
    shard_assignments: ShardAssignments
    draft_model: ModelId | None = None  # For speculative decoding (rank 0 only)
    num_draft_tokens: int = 4  # Tokens to draft per iteration (when draft_model is set)

    def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
        return self.shard_assignments.runner_to_shard.get(runner_id, None)
@@ -13,3 +13,8 @@ KV_CACHE_BITS: int | None = None

# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

# Multi-Token Prediction (MTP) configuration for DeepSeek V3
# MTP enables speculative decoding using the model's built-in draft layer
MTP_ENABLED: bool = True  # Feature flag to enable/disable MTP
MTP_NUM_DRAFT_TOKENS: int = 1  # Number of tokens to draft (vLLM reports k=1 is optimal)
@@ -19,7 +19,13 @@ from exo.shared.types.worker.runner_response import (
    GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.constants import (
    KV_BITS,
    KV_GROUP_SIZE,
    MAX_TOKENS,
    MTP_ENABLED,
    MTP_NUM_DRAFT_TOKENS,
)
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
    make_kv_cache,
@@ -115,12 +121,15 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
    return eos


def _has_mtp_module(model: Model) -> bool:
    """Check if the model has an attached MTP module."""
    return hasattr(model, "mtp_module") and model.mtp_module is not None  # type: ignore[attr-defined]


def mlx_generate(
    model: Model,
    tokenizer: TokenizerWrapper,
    task: ChatCompletionTaskParams,
    draft_model: Model | None = None,
    num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
    # Ensure that generation stats only contains peak memory for this generation
    mx.reset_peak_memory()
@@ -137,6 +146,8 @@ def mlx_generate(
        chat_task_data=task,
    )

    caches = make_kv_cache(model=model)

    logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
    if is_bench:
        # Only sample length eos tokens
@@ -150,30 +161,55 @@ def mlx_generate(

    max_tokens = task.max_tokens or MAX_TOKENS

    # Build kwargs for stream_generate, conditionally adding draft model params
    generate_kwargs: dict[str, object] = {
        "model": model,
        "tokenizer": tokenizer,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "sampler": sampler,
        "logits_processors": logits_processors,
        "prefill_step_size": 2048,
        "kv_group_size": KV_GROUP_SIZE,
        "kv_bits": KV_BITS,
    }
    # Check if we should use MTP speculative decoding
    use_mtp = MTP_ENABLED and _has_mtp_module(model)

    # Add speculative decoding parameters if draft model is provided
    # Note: When using draft_model, we let mlx_lm create its own trimmable cache
    # as speculative decoding requires cache trimming capabilities
    if draft_model is not None:
        generate_kwargs["draft_model"] = draft_model
        generate_kwargs["num_draft_tokens"] = num_draft_tokens
    if use_mtp:
        logger.info("Using MTP speculative decoding")
        yield from _mlx_generate_with_mtp(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_tokens=max_tokens,
            sampler=sampler,
            logits_processors=logits_processors,
            prompt_cache=caches,
        )
    else:
        # Only use custom cache for non-speculative generation
        generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
        yield from _mlx_generate_standard(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_tokens=max_tokens,
            sampler=sampler,
            logits_processors=logits_processors,
            prompt_cache=caches,
        )

    for out in stream_generate(**generate_kwargs):  # type: ignore[arg-type]


def _mlx_generate_standard(
    model: Model,
    tokenizer: TokenizerWrapper,
    prompt: str,
    max_tokens: int,
    sampler: Callable[[mx.array], mx.array],
    logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
    prompt_cache: list[KVCache | Any],
) -> Generator[GenerationResponse]:
    """Standard generation path using mlx_lm stream_generate."""
    for out in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
        prompt_cache=prompt_cache,
        # TODO: Dynamically change prefill step size to be the maximum possible without timing out.
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE,
        kv_bits=KV_BITS,
    ):
        logger.info(out.text)

        stats: GenerationStats | None = None
@@ -203,4 +239,64 @@ def mlx_generate(
        if out.finish_reason is not None:
            break


def _mlx_generate_with_mtp(
    model: Model,
    tokenizer: TokenizerWrapper,
    prompt: str,
    max_tokens: int,
    sampler: Callable[[mx.array], mx.array],
    logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
    prompt_cache: list[KVCache | Any],
) -> Generator[GenerationResponse]:
    """MTP speculative decoding generation path.

    Uses the model's attached MTP module for speculative decoding,
    which can provide 1.5-2x speedup with ~81% acceptance rate.
    """
    from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate

    mtp_module = model.mtp_module  # type: ignore[attr-defined]

    for out in mtp_speculative_generate(
        model=model,
        mtp_module=mtp_module,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
        prompt_cache=prompt_cache,
        num_draft_tokens=MTP_NUM_DRAFT_TOKENS,
        prefill_step_size=2048,
        kv_group_size=KV_GROUP_SIZE if KV_GROUP_SIZE is not None else 64,
        kv_bits=KV_BITS,
    ):
        logger.info(f"{out.text} (from_draft={out.from_draft})")

        stats: GenerationStats | None = None
        if out.finish_reason is not None:
            stats = GenerationStats(
                prompt_tps=float(out.prompt_tps),
                generation_tps=float(out.generation_tps),
                prompt_tokens=int(out.prompt_tokens),
                generation_tokens=int(out.generation_tokens),
                peak_memory_usage=Memory.from_gb(out.peak_memory),
            )

            if out.finish_reason not in get_args(FinishReason):
                logger.warning(
                    f"Model generated unexpected finish_reason: {out.finish_reason}"
                )

        yield GenerationResponse(
            text=out.text,
            token=out.token,
            finish_reason=cast(FinishReason | None, out.finish_reason),
            stats=stats,
        )

        if out.finish_reason is not None:
            break

    # TODO: Do we want an mx_barrier?
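A rough sanity check on the speedup figures quoted in the docstring above: with k = 1 drafted token and acceptance rate a, each main-model verification step emits 1 + a tokens on average, so a of roughly 0.81 gives about 1.8 tokens per step. Decoding is memory-bandwidth bound, so verifying two positions costs about the same as one, which makes roughly 1.8x the ceiling and is consistent with the reported 1.5-2x end-to-end speedup once the MTP module's own forward pass and rejected drafts are accounted for. This is back-of-envelope reasoning only; the measured numbers come from the vLLM/SGLang reports cited in speculative_decode.py below.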
src/exo/worker/engines/mlx/mtp/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
"""Multi-Token Prediction (MTP) module for DeepSeek V3 speculative decoding."""

from exo.worker.engines.mlx.mtp.module import MTPModule
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate

__all__ = ["MTPModule", "mtp_speculative_generate"]
src/exo/worker/engines/mlx/mtp/module.py (new file, 207 lines)
@@ -0,0 +1,207 @@
|
||||
"""MTP Module for DeepSeek V3 Multi-Token Prediction.
|
||||
|
||||
The MTP architecture predicts one additional token ahead using:
|
||||
1. hnorm - RMSNorm for hidden state normalization
|
||||
2. enorm - RMSNorm for embedding normalization
|
||||
3. eh_proj - Linear(2*hidden_size -> hidden_size) projection
|
||||
4. transformer_block - Single decoder layer (attention + MLP)
|
||||
5. Shared embedding/lm_head from main model
|
||||
|
||||
Forward pass:
|
||||
h_norm = hnorm(hidden_state)
|
||||
e_norm = enorm(embed(token))
|
||||
projected = eh_proj(concat([h_norm, e_norm]))
|
||||
new_hidden = transformer_block(projected)
|
||||
logits = lm_head(output_norm(new_hidden))
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import (
|
||||
DeepseekV3Attention,
|
||||
DeepseekV3MLP,
|
||||
ModelArgs,
|
||||
)
|
||||
|
||||
MTP_LAYER_INDEX = 61
|
||||
|
||||
|
||||
class MTPModule(nn.Module):
|
||||
"""Multi-Token Prediction module for DeepSeek V3.
|
||||
|
||||
This module is initialized from the layer 61 weights that are normally
|
||||
discarded during model loading. It enables speculative decoding by
|
||||
predicting one token ahead using the hidden state from the main model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
shared_embedding: nn.Embedding,
|
||||
shared_lm_head: nn.Linear,
|
||||
output_norm: nn.RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# MTP-specific normalization layers
|
||||
self.hnorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.enorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
# Projection: concatenated [hidden, embedding] -> hidden_size
|
||||
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
|
||||
|
||||
# Single transformer block for MTP
|
||||
# Use a dense MLP since this is just a single layer
|
||||
self.transformer_block = MTPTransformerBlock(config)
|
||||
|
||||
# Share embedding and lm_head with main model
|
||||
self._shared_embedding = shared_embedding
|
||||
self._shared_lm_head = shared_lm_head
|
||||
self._output_norm = output_norm
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
cache: KVCache | None = None,
|
||||
mask: mx.array | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass for MTP.
|
||||
|
||||
Args:
|
||||
hidden_state: Hidden state from main model [batch, seq_len, hidden_size]
|
||||
draft_token: Token to embed and combine with hidden state [batch, seq_len]
|
||||
cache: Optional KV cache for the MTP transformer block
|
||||
mask: Optional attention mask
|
||||
|
||||
Returns:
|
||||
tuple of (logits, new_hidden_state)
|
||||
"""
|
||||
# Get embedding of draft token
|
||||
embedding = self._shared_embedding(draft_token)
|
||||
|
||||
# Normalize hidden state and embedding
|
||||
h_norm = self.hnorm(hidden_state)
|
||||
e_norm = self.enorm(embedding)
|
||||
|
||||
# Project concatenated representation
|
||||
concatenated = mx.concatenate([h_norm, e_norm], axis=-1)
|
||||
projected = self.eh_proj(concatenated)
|
||||
|
||||
# Pass through single transformer block
|
||||
new_hidden = self.transformer_block(projected, mask=mask, cache=cache)
|
||||
|
||||
# Apply output norm and get logits
|
||||
normed_hidden = self._output_norm(new_hidden)
|
||||
logits = self._shared_lm_head(normed_hidden)
|
||||
|
||||
return logits, new_hidden
|
||||
|
||||
|
||||
class MTPTransformerBlock(nn.Module):
|
||||
"""Single transformer block for MTP.
|
||||
|
||||
This is similar to DeepseekV3DecoderLayer but uses a dense MLP
|
||||
instead of MoE since this is just for the single MTP layer.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV3Attention(config)
|
||||
# MTP uses dense MLP, not MoE
|
||||
self.mlp = DeepseekV3MLP(config)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
cache: Any | None = None,
|
||||
) -> mx.array:
|
||||
"""Forward pass with residual connections."""
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
return h + r
|
||||
|
||||
|
||||
def extract_mtp_weights(weights: dict[str, mx.array]) -> dict[str, mx.array]:
|
||||
"""Extract MTP-specific weights from layer 61.
|
||||
|
||||
The MTP layer has these weight patterns:
|
||||
- model.layers.61.enorm.weight -> MTP embedding normalization
|
||||
- model.layers.61.hnorm.weight -> MTP hidden normalization
|
||||
- model.layers.61.eh_proj.weight -> MTP projection layer
|
||||
- model.layers.61.self_attn.* -> MTP attention
|
||||
- model.layers.61.input_layernorm.* -> MTP layer norms
|
||||
- model.layers.61.post_attention_layernorm.*
|
||||
- model.layers.61.mlp.* -> MTP MLP (dense, not MoE)
|
||||
|
||||
Args:
|
||||
weights: Full model weights dict
|
||||
|
||||
Returns:
|
||||
Dict of MTP-specific weights with keys renamed for MTPModule
|
||||
"""
|
||||
mtp_weights: dict[str, mx.array] = {}
|
||||
mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
|
||||
|
||||
for key, value in weights.items():
|
||||
if key.startswith(mtp_prefix):
|
||||
# Remove the layer prefix to get relative path
|
||||
new_key = key[len(mtp_prefix) :]
|
||||
mtp_weights[new_key] = value
|
||||
|
||||
return mtp_weights
|
||||
|
||||
|
||||
def load_mtp_weights_into_module(
|
||||
mtp_module: MTPModule,
|
||||
mtp_weights: dict[str, mx.array],
|
||||
) -> None:
|
||||
"""Load extracted MTP weights into the MTPModule.
|
||||
|
||||
Args:
|
||||
mtp_module: The MTPModule instance to load weights into
|
||||
mtp_weights: Extracted MTP weights from extract_mtp_weights()
|
||||
"""
|
||||
# Map weight names to module attributes
|
||||
weight_mapping: dict[str, str] = {
|
||||
"enorm.weight": "enorm.weight",
|
||||
"hnorm.weight": "hnorm.weight",
|
||||
"eh_proj.weight": "eh_proj.weight",
|
||||
}
|
||||
|
||||
# Load direct mappings
|
||||
for src_name, dst_name in weight_mapping.items():
|
||||
if src_name in mtp_weights:
|
||||
parts = dst_name.split(".")
|
||||
obj: Any = mtp_module
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], mtp_weights[src_name])
|
||||
|
||||
# Load transformer block weights (self_attn, mlp, layer norms)
|
||||
transformer_prefixes = [
|
||||
"self_attn",
|
||||
"mlp",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
]
|
||||
|
||||
for prefix in transformer_prefixes:
|
||||
for key, value in mtp_weights.items():
|
||||
if key.startswith(prefix):
|
||||
# Navigate to the correct attribute
|
||||
parts = key.split(".")
|
||||
obj = mtp_module.transformer_block
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], value)
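To make the forward pass described in the module docstring above concrete, here is a minimal, self-contained sketch of the MTP data flow. It is not the real MTPModule (the real transformer block is a DeepseekV3 attention plus MLP layer and shares the main model's embedding and lm_head); all layer choices and sizes below are toy stand-ins for illustration only.

```python
import mlx.core as mx
import mlx.nn as nn

# Toy stand-ins for the MTP components; sizes are made up for the example.
hidden_size, vocab = 64, 100
embed = nn.Embedding(vocab, hidden_size)                 # shared with the main model in the real code
lm_head = nn.Linear(hidden_size, vocab, bias=False)      # shared with the main model in the real code
hnorm = nn.RMSNorm(hidden_size)
enorm = nn.RMSNorm(hidden_size)
out_norm = nn.RMSNorm(hidden_size)
eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)
block = nn.Linear(hidden_size, hidden_size)              # stand-in for the single MTP transformer block

hidden_state = mx.random.normal((1, 1, hidden_size))     # last hidden state from the main model
token = mx.array([[42]])                                 # token the main model just sampled

# hnorm/enorm -> concat -> eh_proj -> transformer block -> output norm -> lm_head
projected = eh_proj(mx.concatenate([hnorm(hidden_state), enorm(embed(token))], axis=-1))
new_hidden = block(projected)
logits = lm_head(out_norm(new_hidden))
draft_token = mx.argmax(logits[:, -1, :], axis=-1)       # greedy draft; the real path uses the sampler
```

The shape conventions (hidden_state of [batch, seq, hidden], token of [batch, seq], returning logits plus a new hidden state) match the unit tests for MTPModule further down.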
src/exo/worker/engines/mlx/mtp/speculative_decode.py (new file, 506 lines)
@@ -0,0 +1,506 @@
|
||||
"""MTP Speculative Decoding for DeepSeek V3.
|
||||
|
||||
This module implements speculative decoding using the Multi-Token Prediction (MTP)
|
||||
layer from DeepSeek V3. The key difference from standard speculative decoding is
|
||||
that MTP requires hidden states from the main model, not just token predictions.
|
||||
|
||||
Based on vLLM/SGLang research:
|
||||
- 81-82% acceptance rate with k=1
|
||||
- 1.5-2x speedup at low QPS
|
||||
"""
|
||||
|
||||
import functools
|
||||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models import cache
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
# Generation stream for async operations
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
@dataclass
|
||||
class MTPGenerationResponse:
|
||||
"""Response from MTP speculative generation.
|
||||
|
||||
Attributes:
|
||||
text: The next segment of decoded text.
|
||||
token: The next token.
|
||||
logprobs: A vector of log probabilities.
|
||||
from_draft: Whether the token was generated by the MTP draft module.
|
||||
prompt_tokens: The number of tokens in the prompt.
|
||||
prompt_tps: The prompt processing tokens-per-second.
|
||||
generation_tokens: The number of generated tokens.
|
||||
generation_tps: The tokens-per-second for generation.
|
||||
peak_memory: The peak memory used so far in GB.
|
||||
finish_reason: The reason the response is being sent: "length", "stop" or None.
|
||||
"""
|
||||
|
||||
text: str
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
from_draft: bool
|
||||
prompt_tokens: int
|
||||
prompt_tps: float
|
||||
generation_tokens: int
|
||||
generation_tps: float
|
||||
peak_memory: float
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: list[Any],
|
||||
quantized_kv_start: int,
|
||||
kv_group_size: int,
|
||||
kv_bits: int | None,
|
||||
) -> None:
|
||||
"""Quantize KV cache entries if needed."""
|
||||
if kv_bits is None:
|
||||
return
|
||||
for e, c in enumerate(prompt_cache):
|
||||
if (
|
||||
hasattr(c, "to_quantized")
|
||||
and hasattr(c, "offset")
|
||||
and c.offset >= quantized_kv_start
|
||||
):
|
||||
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
|
||||
|
||||
|
||||
class ModelWithHiddenStates(nn.Module):
|
||||
"""Wrapper to extract hidden states before lm_head.
|
||||
|
||||
This wrapper allows capturing the hidden states from the transformer
|
||||
layers before the final lm_head projection, which is needed for MTP.
|
||||
"""
|
||||
|
||||
def __init__(self, base_model: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self._base = base_model
|
||||
|
||||
def forward_with_hidden(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass that returns both logits and hidden states.
|
||||
|
||||
Args:
|
||||
inputs: Input token ids
|
||||
model_cache: KV cache
|
||||
|
||||
Returns:
|
||||
Tuple of (logits, hidden_states)
|
||||
"""
|
||||
# Call the inner model (transformer layers + norm)
|
||||
hidden: mx.array = self._base.model(inputs, model_cache)
|
||||
# Get logits from lm_head
|
||||
logits: mx.array = self._base.lm_head(hidden)
|
||||
return logits, hidden
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> mx.array:
|
||||
"""Standard forward pass returning only logits."""
|
||||
return cast(mx.array, self._base(inputs, cache=model_cache))
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Module]:
|
||||
"""Access layers for cache creation."""
|
||||
return cast(list[nn.Module], self._base.layers)
|
||||
|
||||
|
||||
def mtp_speculative_generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
*,
|
||||
num_draft_tokens: int = 1,
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
mtp_cache: KVCache | None = None,
|
||||
prefill_step_size: int = 512,
|
||||
kv_bits: int | None = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
) -> Generator[tuple[int, mx.array, bool], None, None]:
|
||||
"""MTP speculative decoding generator.
|
||||
|
||||
Unlike standard speculative decoding where the draft model only needs tokens,
|
||||
MTP requires the hidden states from the main model. This generator:
|
||||
|
||||
1. Runs the main model to get logits AND hidden states
|
||||
2. Uses MTP module with hidden state + sampled token to predict next token
|
||||
3. Verifies MTP predictions with the main model
|
||||
4. Accepts/rejects based on matching
|
||||
|
||||
Args:
|
||||
prompt: The input prompt as token ids
|
||||
model: The main model (must support return_hidden=True)
|
||||
mtp_module: The MTP module for draft prediction
|
||||
num_draft_tokens: Number of tokens to draft (typically 1 for MTP)
|
||||
max_tokens: Maximum number of tokens to generate
|
||||
sampler: Optional sampler function for token selection
|
||||
logits_processors: Optional list of logits processors
|
||||
prompt_cache: KV cache for the main model
|
||||
mtp_cache: KV cache for the MTP module
|
||||
prefill_step_size: Step size for prompt processing
|
||||
kv_bits: Bits for KV cache quantization
|
||||
kv_group_size: Group size for KV cache quantization
|
||||
quantized_kv_start: Step to begin cache quantization
|
||||
|
||||
Yields:
|
||||
Tuple of (token, logprobs, from_draft)
|
||||
"""
|
||||
y = prompt.astype(mx.uint32)
|
||||
prev_tokens: mx.array | None = None
|
||||
|
||||
# Wrap model to get hidden states
|
||||
wrapped_model = (
|
||||
model
|
||||
if isinstance(model, ModelWithHiddenStates)
|
||||
else ModelWithHiddenStates(model)
|
||||
)
|
||||
|
||||
# Create caches if needed
|
||||
if prompt_cache is None:
|
||||
prompt_cache = cache.make_prompt_cache(model)
|
||||
if mtp_cache is None:
|
||||
mtp_cache = KVCache()
|
||||
|
||||
final_sampler = (
|
||||
sampler if sampler is not None else (lambda x: mx.argmax(x, axis=-1))
|
||||
)
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
def _process_and_sample(
|
||||
tokens: mx.array | None,
|
||||
logits: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Process logits and sample tokens."""
|
||||
nonlocal logits_processors
|
||||
processed_logits = logits
|
||||
if logits_processors:
|
||||
for processor in logits_processors:
|
||||
processed_logits = processor(
|
||||
tokens if tokens is not None else mx.array([]), processed_logits
|
||||
)
|
||||
|
||||
logprobs = processed_logits - mx.logsumexp(
|
||||
processed_logits, axis=-1, keepdims=True
|
||||
)
|
||||
sampled = final_sampler(logprobs)
|
||||
return sampled, logprobs
|
||||
|
||||
def _main_model_step_with_hidden(
|
||||
input_y: mx.array,
|
||||
) -> tuple[mx.array, mx.array, mx.array]:
|
||||
"""Run main model step with hidden state return."""
|
||||
nonlocal prev_tokens
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
logits, hidden = wrapped_model.forward_with_hidden(
|
||||
input_y[None], prompt_cache
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
if logits_processors:
|
||||
prev_tokens = (
|
||||
mx.concatenate([prev_tokens, input_y])
|
||||
if prev_tokens is not None
|
||||
else input_y
|
||||
)
|
||||
|
||||
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
|
||||
return sampled, logprobs_result.squeeze(0), hidden[:, -1:, :]
|
||||
|
||||
def _main_model_step(
|
||||
input_y: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run main model step without hidden state."""
|
||||
nonlocal prev_tokens
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
logits = wrapped_model.forward(input_y[None], prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
if logits_processors:
|
||||
prev_tokens = (
|
||||
mx.concatenate([prev_tokens, input_y])
|
||||
if prev_tokens is not None
|
||||
else input_y
|
||||
)
|
||||
|
||||
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
|
||||
return sampled, logprobs_result.squeeze(0)
|
||||
|
||||
def _mtp_draft(
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Generate draft token using MTP module."""
|
||||
with mx.stream(generation_stream):
|
||||
logits, new_hidden = mtp_module(
|
||||
hidden_state,
|
||||
draft_token,
|
||||
cache=mtp_cache,
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
sampled, _ = _process_and_sample(None, logits)
|
||||
return sampled, new_hidden
|
||||
|
||||
def _prefill(input_y: mx.array) -> mx.array:
|
||||
"""Prefill the prompt cache."""
|
||||
result_y = input_y
|
||||
while result_y.size > prefill_step_size:
|
||||
_ = wrapped_model.forward(result_y[:prefill_step_size][None], prompt_cache)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
mx.eval([c.state for c in prompt_cache])
|
||||
result_y = result_y[prefill_step_size:]
|
||||
mx.clear_cache()
|
||||
return result_y
|
||||
|
||||
def _rewind_cache(num_draft: int, num_accept: int) -> None:
|
||||
"""Rewind caches after rejection."""
|
||||
cache.trim_prompt_cache(prompt_cache, num_draft - num_accept)
|
||||
|
||||
# Prefill phase
|
||||
with mx.stream(generation_stream):
|
||||
y = _prefill(y)
|
||||
|
||||
ntoks = 0
|
||||
num_draft = 0
|
||||
n_accepted = 0
|
||||
last_hidden: mx.array | None = None
|
||||
|
||||
try:
|
||||
# Initial step to get first token and hidden state
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
while ntoks < max_tokens:
|
||||
# Draft phase: use MTP to predict next token
|
||||
num_draft = min(max_tokens - ntoks - 1, num_draft_tokens)
|
||||
|
||||
if num_draft > 0 and last_hidden is not None:
|
||||
# Use MTP to draft
|
||||
draft_token, draft_hidden = _mtp_draft(last_hidden, y)
|
||||
mx.eval(draft_token, draft_hidden)
|
||||
|
||||
# Verify with main model
|
||||
# Feed the drafted token to main model
|
||||
verify_input = mx.concatenate([y, draft_token.flatten()])
|
||||
verify_sampled, verify_logprobs, new_hidden = (
|
||||
_main_model_step_with_hidden(verify_input)
|
||||
)
|
||||
mx.eval(verify_sampled, verify_logprobs, new_hidden)
|
||||
|
||||
# Check if draft matches verification
|
||||
draft_token_val = int(draft_token.item())
|
||||
verify_token_val = (
|
||||
int(verify_sampled[0].item())
|
||||
if verify_sampled.shape[0] > 1
|
||||
else int(verify_sampled.item())
|
||||
)
|
||||
|
||||
# Yield the current token (not from draft)
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
if draft_token_val == verify_token_val:
|
||||
# Draft accepted
|
||||
n_accepted += 1
|
||||
ntoks += 1
|
||||
draft_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
yield draft_token_val, draft_logprobs, True
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
# Continue with the token after the draft
|
||||
y = (
|
||||
verify_sampled[-1:]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[-1]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = new_hidden
|
||||
else:
|
||||
# Draft rejected - rewind and use verified token
|
||||
_rewind_cache(1, 0)
|
||||
y = (
|
||||
verify_sampled[:1]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = (
|
||||
new_hidden[:, :1, :] if new_hidden is not None else None
|
||||
)
|
||||
else:
|
||||
# No drafting, just do normal generation
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
if ntoks % 256 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
finally:
|
||||
_rewind_cache(num_draft, n_accepted)
|
||||
|
||||
|
||||
def mtp_speculative_generate(
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str | mx.array | list[int],
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
num_draft_tokens: int = 1,
|
||||
prefill_step_size: int = 512,
|
||||
kv_group_size: int = 64,
|
||||
kv_bits: int | None = None,
|
||||
) -> Generator[MTPGenerationResponse, None, None]:
|
||||
"""High-level MTP speculative generation with text output.
|
||||
|
||||
Args:
|
||||
model: The main model
|
||||
mtp_module: The MTP module for draft prediction
|
||||
tokenizer: Tokenizer for encoding/decoding
|
||||
prompt: Input prompt (string, array, or token list)
|
||||
max_tokens: Maximum tokens to generate
|
||||
sampler: Optional sampler function
|
||||
logits_processors: Optional logits processors
|
||||
prompt_cache: Optional KV cache
|
||||
num_draft_tokens: Number of draft tokens
|
||||
prefill_step_size: Prefill step size
|
||||
kv_group_size: KV group size
|
||||
kv_bits: KV bits
|
||||
|
||||
Yields:
|
||||
MTPGenerationResponse objects with text and metadata
|
||||
"""
|
||||
if not isinstance(prompt, mx.array):
|
||||
if isinstance(prompt, str):
|
||||
bos_token = getattr(tokenizer, "bos_token", None)
|
||||
add_special_tokens = bos_token is None or not prompt.startswith(
|
||||
str(bos_token)
|
||||
)
|
||||
encoded: list[int] = tokenizer.encode(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
prompt = mx.array(encoded)
|
||||
else:
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
detokenizer = tokenizer.detokenizer
|
||||
eos_token_ids: list[int] = getattr(tokenizer, "eos_token_ids", [])
|
||||
|
||||
token_generator = mtp_speculative_generate_step(
|
||||
prompt,
|
||||
model,
|
||||
mtp_module,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=prompt_cache,
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
tic = time.perf_counter()
|
||||
prompt_tps = 0.0
|
||||
token = 0
|
||||
logprobs: mx.array = mx.array([0.0])
|
||||
from_draft = False
|
||||
n = 0
|
||||
|
||||
for n, (token, logprobs, from_draft) in enumerate(token_generator):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = float(prompt.size) / prompt_time
|
||||
tic = time.perf_counter()
|
||||
|
||||
if token in eos_token_ids:
|
||||
break
|
||||
|
||||
detokenizer.add_token(token)
|
||||
if (n + 1) == max_tokens:
|
||||
break
|
||||
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment),
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
detokenizer.finalize()
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment),
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason="stop" if token in eos_token_ids else "length",
|
||||
)
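For orientation, this is roughly how the high-level generator above is driven; it is the same call _mlx_generate_with_mtp makes in generate.py earlier in this diff. The concrete `model` (a DeepSeek V3 MLX model with an attached `mtp_module`) and `tokenizer` (an mlx_lm TokenizerWrapper) are assumed to be already loaded and are not shown here.

```python
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate

# Assumed: `model` and `tokenizer` are loaded elsewhere and model.mtp_module is set.
accepted = total = 0
for out in mtp_speculative_generate(
    model=model,
    mtp_module=model.mtp_module,
    tokenizer=tokenizer,
    prompt="Explain speculative decoding in one sentence.",
    max_tokens=64,
    num_draft_tokens=1,   # MTP drafts one token ahead per step
):
    total += 1
    accepted += int(out.from_draft)   # from_draft marks tokens produced by the MTP draft
    print(out.text, end="", flush=True)

print(f"\nMTP drafts accepted: {accepted}/{total} tokens")
```

Each loop iteration of the underlying step generator emits either two tokens (draft accepted because the main model's verification sampled the same token) or one token (draft rejected, with the prompt cache rewound via trim_prompt_cache), which is where the from_draft flag comes from.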
src/exo/worker/engines/mlx/mtp/tests/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Tests for MTP module."""
src/exo/worker/engines/mlx/mtp/tests/test_mtp_module.py (new file, 412 lines)
@@ -0,0 +1,412 @@
|
||||
"""Unit tests for MTP module components."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import (
|
||||
MTP_LAYER_INDEX,
|
||||
MTPModule,
|
||||
MTPTransformerBlock,
|
||||
extract_mtp_weights,
|
||||
load_mtp_weights_into_module,
|
||||
)
|
||||
|
||||
|
||||
class MockModelArgs:
|
||||
"""Mock ModelArgs for testing without importing deepseek_v3."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 256,
|
||||
intermediate_size: int = 512,
|
||||
num_attention_heads: int = 4,
|
||||
num_key_value_heads: int = 4,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
vocab_size: int = 1000,
|
||||
q_lora_rank: int | None = None,
|
||||
kv_lora_rank: int = 64,
|
||||
qk_rope_head_dim: int = 16,
|
||||
v_head_dim: int = 32,
|
||||
qk_nope_head_dim: int = 32,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_scaling: dict | None = None,
|
||||
attention_bias: bool = False,
|
||||
max_position_embeddings: int = 2048,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.vocab_size = vocab_size
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
|
||||
class TestExtractMTPWeights:
|
||||
"""Tests for extract_mtp_weights function."""
|
||||
|
||||
def test_extracts_layer_61_weights(self) -> None:
|
||||
"""Should extract only layer 61 weights."""
|
||||
weights = {
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.61.enorm.weight": mx.ones((10,)),
|
||||
"model.layers.61.hnorm.weight": mx.ones((10,)) * 2,
|
||||
"model.layers.61.eh_proj.weight": mx.ones((10, 20)),
|
||||
"model.layers.62.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.embed_tokens.weight": mx.zeros((100, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 3
|
||||
assert "enorm.weight" in mtp_weights
|
||||
assert "hnorm.weight" in mtp_weights
|
||||
assert "eh_proj.weight" in mtp_weights
|
||||
# Check values are preserved
|
||||
assert mx.allclose(mtp_weights["enorm.weight"], mx.ones((10,)))
|
||||
assert mx.allclose(mtp_weights["hnorm.weight"], mx.ones((10,)) * 2)
|
||||
|
||||
def test_returns_empty_dict_when_no_layer_61(self) -> None:
|
||||
"""Should return empty dict when layer 61 doesn't exist."""
|
||||
weights = {
|
||||
"model.layers.0.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 0
|
||||
|
||||
def test_handles_nested_layer_61_weights(self) -> None:
|
||||
"""Should handle nested weight paths like self_attn.q_proj.weight."""
|
||||
weights = {
|
||||
f"model.layers.{MTP_LAYER_INDEX}.self_attn.q_a_proj.weight": mx.zeros(
|
||||
(10, 10)
|
||||
),
|
||||
f"model.layers.{MTP_LAYER_INDEX}.mlp.gate_proj.weight": mx.zeros((20, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert "self_attn.q_a_proj.weight" in mtp_weights
|
||||
assert "mlp.gate_proj.weight" in mtp_weights
|
||||
|
||||
|
||||
class TestMTPTransformerBlock:
|
||||
"""Tests for MTPTransformerBlock."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64, intermediate_size=128, num_attention_heads=2
|
||||
)
|
||||
|
||||
def test_forward_shape(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should preserve input shape."""
|
||||
# Skip if deepseek_v3 imports fail (CI without mlx_lm)
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
|
||||
output = block(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
def test_forward_with_mask(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should work with attention mask."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
# Create causal mask
|
||||
mask = mx.triu(mx.full((5, 5), float("-inf")), k=1)
|
||||
|
||||
output = block(x, mask=mask)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
|
||||
class TestMTPModule:
|
||||
"""Tests for MTPModule."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def shared_components(
|
||||
self, config: MockModelArgs
|
||||
) -> tuple[nn.Embedding, nn.Linear, nn.RMSNorm]:
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
return embedding, lm_head, output_norm
|
||||
|
||||
def test_initialization(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should initialize with correct components."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
assert mtp.hnorm is not None
|
||||
assert mtp.enorm is not None
|
||||
assert mtp.eh_proj is not None
|
||||
assert mtp.transformer_block is not None
|
||||
|
||||
def test_forward_output_shapes(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""Forward pass should return correct output shapes."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
batch_size = 2
|
||||
seq_len = 1
|
||||
hidden_state = mx.random.normal((batch_size, seq_len, config.hidden_size))
|
||||
draft_token = mx.array([[5], [10]]) # [batch, seq_len]
|
||||
|
||||
logits, new_hidden = mtp(hidden_state, draft_token)
|
||||
|
||||
assert logits.shape == (batch_size, seq_len, config.vocab_size)
|
||||
assert new_hidden.shape == (batch_size, seq_len, config.hidden_size)
|
||||
|
||||
def test_shares_embedding_and_lm_head(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should use shared embedding and lm_head."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Verify they're the same objects
|
||||
assert mtp._shared_embedding is embedding
|
||||
assert mtp._shared_lm_head is lm_head
|
||||
assert mtp._output_norm is output_norm
|
||||
|
||||
|
||||
class TestLoadMTPWeights:
|
||||
"""Tests for load_mtp_weights_into_module."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
def test_loads_norm_weights(self, config: MockModelArgs) -> None:
|
||||
"""Should load enorm and hnorm weights."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Create test weights
|
||||
test_enorm = mx.ones((config.hidden_size,)) * 3.0
|
||||
test_hnorm = mx.ones((config.hidden_size,)) * 5.0
|
||||
mtp_weights = {
|
||||
"enorm.weight": test_enorm,
|
||||
"hnorm.weight": test_hnorm,
|
||||
}
|
||||
|
||||
load_mtp_weights_into_module(mtp, mtp_weights)
|
||||
|
||||
assert mx.allclose(mtp.enorm.weight, test_enorm)
|
||||
assert mx.allclose(mtp.hnorm.weight, test_hnorm)
|
||||
|
||||
|
||||
class TestSanitizePatch:
|
||||
"""Tests for the sanitize patching logic."""
|
||||
|
||||
def test_patch_preserves_layer_61(self) -> None:
|
||||
"""Patching sanitize should preserve layer 61 weights."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp,
|
||||
_restore_deepseek_sanitize,
|
||||
)
|
||||
|
||||
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
model_cls = deepseek_v3.Model
|
||||
|
||||
# Get original sanitize behavior
|
||||
original_sanitize = model_cls.sanitize
|
||||
|
||||
try:
|
||||
# Apply patch
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
|
||||
# Note: we can't easily test the full sanitize without a real model
|
||||
# This test verifies the patch is applied
|
||||
assert model_cls.sanitize is not original_sanitize
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
# Verify restore worked
|
||||
assert model_cls.sanitize is original_sanitize
|
||||
|
||||
def test_restore_sanitize(self) -> None:
|
||||
"""Restoring sanitize should return to original behavior."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp,
|
||||
_restore_deepseek_sanitize,
|
||||
)
|
||||
|
||||
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
model_cls = deepseek_v3.Model
|
||||
|
||||
original_sanitize = model_cls.sanitize
|
||||
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is not original_sanitize
|
||||
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize
|
||||
|
||||
def test_double_patch_is_safe(self) -> None:
|
||||
"""Calling patch twice should be safe (idempotent)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp,
|
||||
_restore_deepseek_sanitize,
|
||||
)
|
||||
|
||||
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
model_cls = deepseek_v3.Model
|
||||
|
||||
original_sanitize = model_cls.sanitize
|
||||
|
||||
try:
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
patched_sanitize = model_cls.sanitize
|
||||
|
||||
# Patch again - should be no-op
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is patched_sanitize
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize
|
||||
|
||||
|
||||
class TestModelIdDetection:
|
||||
"""Tests for DeepSeek V3 model ID detection."""
|
||||
|
||||
def test_detects_deepseek_v3(self) -> None:
|
||||
"""Should detect DeepSeek V3 model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-V3")
|
||||
assert _might_be_deepseek_v3("deepseek-ai/deepseek-v3-base")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-V3-4bit")
|
||||
|
||||
def test_detects_deepseek_r1(self) -> None:
|
||||
"""Should detect DeepSeek R1 model IDs (also uses MTP)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-R1")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-R1-4bit")
|
||||
|
||||
def test_rejects_non_deepseek(self) -> None:
|
||||
"""Should reject non-DeepSeek model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert not _might_be_deepseek_v3("meta-llama/Llama-3-70B")
|
||||
assert not _might_be_deepseek_v3("mistralai/Mixtral-8x7B")
|
||||
assert not _might_be_deepseek_v3("deepseek-ai/DeepSeek-V2") # V2, not V3
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
"""Detection should be case insensitive."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert _might_be_deepseek_v3("DEEPSEEK-AI/DEEPSEEK-V3")
|
||||
assert _might_be_deepseek_v3("DeepSeek-AI/deepseek-v3")
|
||||
|
||||
|
||||
class TestFlattenParams:
|
||||
"""Tests for parameter flattening utility."""
|
||||
|
||||
def test_flattens_nested_dict(self) -> None:
|
||||
"""Should flatten nested parameter dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _flatten_params
|
||||
|
||||
params = {
|
||||
"model": {
|
||||
"layers": {
|
||||
"0": {
|
||||
"weight": mx.zeros((10,)),
|
||||
}
|
||||
},
|
||||
"embed": mx.ones((5,)),
|
||||
}
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert "model.layers.0.weight" in flat
|
||||
assert "model.embed" in flat
|
||||
assert mx.allclose(flat["model.layers.0.weight"], mx.zeros((10,)))
|
||||
assert mx.allclose(flat["model.embed"], mx.ones((5,)))
|
||||
|
||||
def test_handles_flat_dict(self) -> None:
|
||||
"""Should handle already-flat dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _flatten_params
|
||||
|
||||
params = {
|
||||
"weight": mx.zeros((10,)),
|
||||
"bias": mx.ones((10,)),
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert flat == params
|
||||
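
# --- Editorial sketch, not part of this diff ---
# The tests above pin down the contract of extract_mtp_weights: keep only keys
# under the MTP layer prefix ("model.layers.61.") and return them with that
# prefix stripped. A minimal implementation consistent with those tests could
# look like the following; the real helper in exo.worker.engines.mlx.mtp.module
# may differ in details.
def extract_mtp_weights_sketch(weights: dict) -> dict:
    mtp_layer_index = 61  # mirrors MTP_LAYER_INDEX used by the tests
    prefix = f"model.layers.{mtp_layer_index}."
    return {k[len(prefix):]: v for k, v in weights.items() if k.startswith(prefix)}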
src/exo/worker/engines/mlx/mtp/tests/test_speculative_decode.py (new file, 253 lines)
@@ -0,0 +1,253 @@
"""Unit tests for MTP speculative decoding."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import (
|
||||
ModelWithHiddenStates,
|
||||
maybe_quantize_kv_cache,
|
||||
)
|
||||
|
||||
|
||||
class MockModel(nn.Module):
|
||||
"""Mock model for testing speculative decoding."""
|
||||
|
||||
def __init__(self, hidden_size: int = 64, vocab_size: int = 100) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# Create simple model components
|
||||
self.model = MockInnerModel(hidden_size)
|
||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
||||
self._layers = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list | None = None,
|
||||
) -> mx.array:
|
||||
hidden = self.model(inputs, cache)
|
||||
return self.lm_head(hidden)
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Module]:
|
||||
return self._layers
|
||||
|
||||
|
||||
class MockInnerModel(nn.Module):
|
||||
"""Mock inner model (like DeepseekV3Model)."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(100, hidden_size)
|
||||
self.norm = nn.RMSNorm(hidden_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list | None = None,
|
||||
) -> mx.array:
|
||||
# Simple embedding + norm
|
||||
embedded = self.embed_tokens(inputs)
|
||||
return self.norm(embedded)
|
||||
|
||||
|
||||
class TestModelWithHiddenStates:
|
||||
"""Tests for ModelWithHiddenStates wrapper."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(self) -> MockModel:
|
||||
return MockModel(hidden_size=64, vocab_size=100)
|
||||
|
||||
def test_forward_returns_logits(self, mock_model: MockModel) -> None:
|
||||
"""Standard forward should return logits."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits = wrapped.forward(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
|
||||
def test_forward_with_hidden_returns_tuple(self, mock_model: MockModel) -> None:
|
||||
"""Forward with hidden should return (logits, hidden)."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits, hidden = wrapped.forward_with_hidden(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
assert hidden.shape == (1, 3, mock_model.hidden_size)
|
||||
|
||||
def test_layers_property(self, mock_model: MockModel) -> None:
|
||||
"""Should expose layers property from base model."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
|
||||
assert wrapped.layers == mock_model.layers
|
||||
assert len(wrapped.layers) == 3
|
||||
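
# --- Editorial sketch, not part of this diff ---
# Based on MockModel above and the assertions in TestModelWithHiddenStates, the
# wrapper under test plausibly behaves like this: a plain forward returns logits
# only, while forward_with_hidden also returns the pre-lm_head hidden states so
# an MTP head can reuse them. The real ModelWithHiddenStates may differ.
class ModelWithHiddenStatesSketch:
    def __init__(self, model: MockModel) -> None:
        self.model = model

    def forward(self, inputs: mx.array, cache: list | None = None) -> mx.array:
        return self.model.lm_head(self.model.model(inputs, cache))

    def forward_with_hidden(
        self, inputs: mx.array, cache: list | None = None
    ) -> tuple[mx.array, mx.array]:
        hidden = self.model.model(inputs, cache)
        return self.model.lm_head(hidden), hidden

    @property
    def layers(self) -> list[nn.Module]:
        return self.model.layers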


class TestMaybeQuantizeKVCache:
    """Tests for KV cache quantization."""

    def test_no_quantization_when_bits_none(self) -> None:
        """Should not quantize when kv_bits is None."""
        cache = [MockCache(offset=100)]

        maybe_quantize_kv_cache(
            cache,
            quantized_kv_start=50,
            kv_group_size=64,
            kv_bits=None,
        )

        # Cache should be unchanged
        assert not hasattr(cache[0], "quantized")

    def test_respects_quantized_kv_start(self) -> None:
        """Should only quantize caches past the start threshold."""
        cache_below = MockCache(offset=30)
        cache_above = MockCache(offset=100)
        caches = [cache_below, cache_above]

        maybe_quantize_kv_cache(
            caches,
            quantized_kv_start=50,
            kv_group_size=64,
            kv_bits=4,
        )

        # Only cache_above should be quantized
        assert not getattr(cache_below, "was_quantized", False)
        assert getattr(caches[1], "was_quantized", False)


class MockCache:
    """Mock KV cache for testing."""

    def __init__(self, offset: int = 0) -> None:
        self.offset = offset
        self.was_quantized = False

    def to_quantized(self, group_size: int, bits: int) -> "MockCache":
        quantized = MockCache(self.offset)
        quantized.was_quantized = True
        return quantized
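
# --- Editorial sketch, not part of this diff ---
# The two tests above imply roughly this behavior for maybe_quantize_kv_cache:
# do nothing when kv_bits is None, otherwise replace each cache whose offset has
# grown past quantized_kv_start with its quantized form. The real helper may
# apply extra conditions (e.g. group-size alignment).
def maybe_quantize_kv_cache_sketch(
    caches: list, quantized_kv_start: int, kv_group_size: int, kv_bits: int | None
) -> None:
    if kv_bits is None:
        return
    for i, cache in enumerate(caches):
        if cache.offset > quantized_kv_start:
            caches[i] = cache.to_quantized(group_size=kv_group_size, bits=kv_bits)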


class TestSpeculativeDecodingLogic:
    """Tests for the core speculative decoding logic."""

    def test_draft_acceptance_identical_tokens(self) -> None:
        """When draft matches verification, both should be accepted."""
        # This tests the logic, not the full generator
        draft_token = 42
        verify_token = 42

        accepted = draft_token == verify_token
        assert accepted

    def test_draft_rejection_different_tokens(self) -> None:
        """When draft differs from verification, draft should be rejected."""
        draft_token = 42
        verify_token = 99

        accepted = draft_token == verify_token
        assert not accepted
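
# --- Editorial sketch, not part of this diff ---
# The two tests above capture the core accept/reject rule for greedy speculative
# decoding: draft tokens are accepted left to right until the first position at
# which the main model's verification token disagrees. The full generator (in
# exo.worker.engines.mlx.mtp.speculative_decode) also has to emit the correction
# token and rewind the KV cache, which this sketch omits.
def count_accepted_draft_tokens_sketch(
    draft_tokens: list[int], verify_tokens: list[int]
) -> int:
    accepted = 0
    for draft, verify in zip(draft_tokens, verify_tokens):
        if draft != verify:
            break
        accepted += 1
    return accepted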


class TestMTPGenerationResponse:
    """Tests for MTPGenerationResponse dataclass."""

    def test_response_creation(self) -> None:
        """Should create response with all fields."""
        from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse

        response = MTPGenerationResponse(
            text="Hello",
            token=42,
            logprobs=mx.array([0.1, 0.2]),
            from_draft=True,
            prompt_tokens=10,
            prompt_tps=100.0,
            generation_tokens=5,
            generation_tps=50.0,
            peak_memory=1.5,
            finish_reason=None,
        )

        assert response.text == "Hello"
        assert response.token == 42
        assert response.from_draft is True
        assert response.finish_reason is None

    def test_response_with_finish_reason(self) -> None:
        """Should handle finish_reason."""
        from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse

        response = MTPGenerationResponse(
            text="",
            token=0,
            logprobs=mx.array([0.0]),
            from_draft=False,
            prompt_tokens=10,
            prompt_tps=100.0,
            generation_tokens=100,
            generation_tps=50.0,
            peak_memory=1.5,
            finish_reason="length",
        )

        assert response.finish_reason == "length"


class TestIntegration:
    """Integration tests for the full MTP pipeline."""

    def test_mtp_module_with_mock_model(self) -> None:
        """Test MTP module can be created and run with mock components."""
        pytest.importorskip("mlx_lm.models.deepseek_v3")

        from exo.worker.engines.mlx.mtp.module import MTPModule

        # Create mock config
        class MockConfig:
            hidden_size = 64
            intermediate_size = 128
            num_attention_heads = 2
            num_key_value_heads = 2
            rms_norm_eps = 1e-6
            q_lora_rank = None
            kv_lora_rank = 32
            qk_rope_head_dim = 8
            v_head_dim = 16
            qk_nope_head_dim = 16
            rope_theta = 10000.0
            rope_scaling = None
            attention_bias = False
            max_position_embeddings = 2048

        config = MockConfig()
        embedding = nn.Embedding(100, config.hidden_size)
        lm_head = nn.Linear(config.hidden_size, 100, bias=False)
        output_norm = nn.RMSNorm(config.hidden_size)

        mtp = MTPModule(
            config=config,  # type: ignore[arg-type]
            shared_embedding=embedding,
            shared_lm_head=lm_head,
            output_norm=output_norm,
        )

        # Run forward pass
        hidden = mx.random.normal((1, 1, config.hidden_size))
        token = mx.array([[5]])

        logits, new_hidden = mtp(hidden, token)

        assert logits.shape == (1, 1, 100)
        assert new_hidden.shape == (1, 1, config.hidden_size)
        # Verify outputs are valid (not NaN)
        assert not mx.any(mx.isnan(logits))
        assert not mx.any(mx.isnan(new_hidden))
@@ -28,6 +28,7 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
    CACHE_GROUP_SIZE,
    KV_CACHE_BITS,
    MTP_ENABLED,
    TRUST_REMOTE_CODE,
)

@@ -69,6 +70,67 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))


# MTP (Multi-Token Prediction) support for DeepSeek V3
MTP_LAYER_INDEX = 61
_original_deepseek_sanitize: Callable[..., dict[str, Any]] | None = None


def _is_deepseek_v3_model(model: nn.Module) -> bool:
    """Check if the model is DeepSeek V3."""
    return hasattr(model, "model") and isinstance(model.model, DeepseekV3Model)


def _patch_deepseek_sanitize_for_mtp() -> None:
    """Patch DeepSeek V3 Model.sanitize to preserve MTP layer weights.

    The original sanitize() method filters out layer 61 (MTP layer) weights.
    This patch keeps them so we can extract and use the MTP module.
    """
    global _original_deepseek_sanitize
    from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model

    if _original_deepseek_sanitize is not None:
        # Already patched
        return

    _original_deepseek_sanitize = DeepSeekV3Model.sanitize

    def sanitize_with_mtp(
        self: DeepSeekV3Model, weights: dict[str, Any]
    ) -> dict[str, Any]:
        """Modified sanitize that keeps MTP layer weights."""
        # First, call the original sanitize to handle all the weight transformations
        # (dequantization, expert stacking, etc.)
        if _original_deepseek_sanitize is None:
            raise RuntimeError(
                "_original_deepseek_sanitize is None - patch not applied correctly"
            )
        original_result: dict[str, Any] = _original_deepseek_sanitize(self, weights)

        # Re-add the MTP layer weights that were filtered out
        mtp_weights = {
            k: v
            for k, v in weights.items()
            if k.startswith(f"model.layers.{MTP_LAYER_INDEX}")
        }

        return {**original_result, **mtp_weights}

    DeepSeekV3Model.sanitize = sanitize_with_mtp


def _restore_deepseek_sanitize() -> None:
    """Restore the original DeepSeek V3 sanitize method."""
    global _original_deepseek_sanitize
    if _original_deepseek_sanitize is None:
        return

    from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model

    DeepSeekV3Model.sanitize = _original_deepseek_sanitize
    _original_deepseek_sanitize = None


# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:

@@ -233,50 +295,162 @@ def load_mlx_items(
    group: Group | None,
    on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
    if group is None:
        logger.info(f"Single device used for {bound_instance.instance}")
        model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
        start_time = time.perf_counter()
        model, _ = load_model(model_path, strict=True)
        end_time = time.perf_counter()
        logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
        tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
    """Load MLX model and tokenizer.

    else:
        logger.info("Starting distributed init")
        start_time = time.perf_counter()
        model, tokenizer = shard_and_load(
            bound_instance.bound_shard, group=group, on_timeout=on_timeout
        )
        end_time = time.perf_counter()
        logger.info(
            f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
        )
    Returns:
        Tuple of (model, tokenizer)
    """
    model_id = bound_instance.bound_shard.model_meta.model_id
    mtp_module = None

    # Patch sanitize for MTP if this might be DeepSeek V3
    should_try_mtp = MTP_ENABLED and _might_be_deepseek_v3(model_id)
    if should_try_mtp:
        logger.info("Patching DeepSeek V3 sanitize for MTP weight preservation")
        _patch_deepseek_sanitize_for_mtp()

    try:
        if group is None:
            logger.info(f"Single device used for {bound_instance.instance}")
            model_path = build_model_path(model_id)
            start_time = time.perf_counter()
            model, _ = load_model(model_path, strict=not should_try_mtp)
            end_time = time.perf_counter()
            logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
            tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)

        else:
            logger.info("Starting distributed init")
            start_time = time.perf_counter()
            model, tokenizer = shard_and_load(
                bound_instance.bound_shard, group=group, on_timeout=on_timeout
            )
            end_time = time.perf_counter()
            logger.info(
                f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
            )

        # Extract MTP module if available
        if should_try_mtp and _is_deepseek_v3_model(model):
            mtp_module = _extract_mtp_module(model)
            if mtp_module is not None:
                logger.info("Successfully extracted MTP module from DeepSeek V3")

    finally:
        # Restore original sanitize
        if should_try_mtp:
            _restore_deepseek_sanitize()

    set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))

    # Store MTP module on the model for later access
    if mtp_module is not None:
        model.mtp_module = mtp_module  # noqa: B010

    return cast(Model, model), tokenizer


def load_draft_model(model_id: str) -> nn.Module:
    """Load a draft model for speculative decoding (rank 0 only).
def _might_be_deepseek_v3(model_id: str) -> bool:
    """Check if model ID suggests this might be DeepSeek V3."""
    model_id_lower = model_id.lower()
    return "deepseek" in model_id_lower and (
        "v3" in model_id_lower or "r1" in model_id_lower
    )

    Draft models are small models (typically 0.5B-2B parameters) used to
    generate candidate tokens quickly, which are then verified by the main
    model in a single forward pass.

    Assumes the model has already been downloaded by the worker.
def _flatten_params(
    params: dict[str, Any],
    prefix: str = "",
) -> dict[str, mx.array]:
    """Flatten nested parameter dict to flat dict with dot-separated keys."""
    result: dict[str, mx.array] = {}
    for key, value in params.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, mx.array):
            result[full_key] = value
        elif isinstance(value, dict):
            result.update(_flatten_params(value, full_key))
    return result

    Args:
        model_id: HuggingFace model ID for the draft model

def _extract_mtp_module(model: nn.Module) -> Any | None:
    """Extract MTP module from a loaded DeepSeek V3 model.

    The MTP weights are stored in model.model.layers at index 61 (if preserved).
    This function extracts them and creates an MTPModule.

    Returns:
        The loaded draft model
        MTPModule if MTP weights were found and extracted, None otherwise.
    """
    model_path = build_model_path(model_id)
    draft_model, _ = load_model(model_path, strict=True)
    logger.info(f"Loaded draft model from {model_path}")
    return draft_model
    from exo.worker.engines.mlx.mtp.module import (
        MTPModule,
        extract_mtp_weights,
        load_mtp_weights_into_module,
    )

    try:
        # Check if the model has the MTP layer
        inner_model = getattr(model, "model", None)
        if inner_model is None or not hasattr(inner_model, "layers"):
            logger.debug("Model doesn't have expected structure for MTP extraction")
            return None

        layers: list[nn.Module] = inner_model.layers  # type: ignore[assignment]
        if len(layers) <= MTP_LAYER_INDEX:
            logger.debug(
                f"Model has {len(layers)} layers, MTP layer {MTP_LAYER_INDEX} not found"
            )
            return None

        # Get model config
        config = getattr(model, "args", None)
        if config is None:
            logger.debug("Could not get model config for MTP module")
            return None

        # Create MTP module with shared weights
        embed_tokens = getattr(inner_model, "embed_tokens", None)
        lm_head = getattr(model, "lm_head", None)
        norm = getattr(inner_model, "norm", None)

        if embed_tokens is None or lm_head is None or norm is None:
            logger.debug("Could not get required model components for MTP")
            return None

        mtp_module = MTPModule(
            config=config,
            shared_embedding=embed_tokens,
            shared_lm_head=lm_head,
            output_norm=norm,
        )

        # Extract MTP layer weights from the model's parameters
        # The weights should be at model.model.layers.61.*
        # model.parameters() returns a nested dict, we need to flatten it
        raw_params: dict[str, Any] = dict(model.parameters())  # type: ignore[arg-type]
        model_weights = _flatten_params(raw_params)
        mtp_weights = extract_mtp_weights(model_weights)

        if not mtp_weights:
            logger.debug("No MTP weights found in model parameters")
            return None

        # Load weights into MTP module
        load_mtp_weights_into_module(mtp_module, mtp_weights)

        # Remove MTP layer from main model to avoid double computation
        # Create new layers list without the MTP layer
        new_layers = [layer for i, layer in enumerate(layers) if i != MTP_LAYER_INDEX]
        inner_model.layers = new_layers  # noqa: B010

        logger.info(
            f"Extracted MTP module, main model now has {len(new_layers)} layers"
        )
        return mtp_module

    except Exception as e:
        logger.warning(f"Failed to extract MTP module: {e}")
        return None
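
# --- Editorial sketch, not part of this diff ---
# One way the mtp_module stored on the model above could be consumed for a single
# greedy draft step. The MTPModule call signature (hidden, token) -> (logits,
# new_hidden) comes from the unit tests; the actual decoding loop lives in
# exo.worker.engines.mlx.mtp.speculative_decode and may differ.
def _draft_one_token_sketch(
    mtp_module: Any, last_hidden: mx.array, last_token: mx.array
) -> tuple[mx.array, mx.array]:
    logits, new_hidden = mtp_module(last_hidden, last_token)
    draft_token = mx.argmax(logits[:, -1, :], axis=-1)
    return draft_token, new_hidden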


def shard_and_load(

@@ -3,8 +3,7 @@
from collections.abc import Mapping, Sequence

from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
    ChatCompletion,
    ConnectToGroup,

@@ -36,7 +35,6 @@ from exo.shared.types.worker.runners import (
    RunnerStatus,
    RunnerWarmingUp,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor


@@ -59,7 +57,6 @@ def plan(
        or _model_needs_download(runners, download_status)
        or _init_distributed_backend(runners, all_runners)
        or _load_model(runners, all_runners, global_download_status)
        or _draft_model_needs_download(runners, download_status)
        or _ready_to_warmup(runners, all_runners)
        or _pending_tasks(runners, tasks, all_runners)
    )

@@ -131,57 +128,6 @@ def _model_needs_download(
    )


def _draft_model_needs_download(
    runners: Mapping[RunnerId, RunnerSupervisor],
    download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
    """Check if draft model needs download (for speculative decoding).

    Only rank 0 needs the draft model, and only after the main model is loaded.
    """
    for runner in runners.values():
        instance = runner.bound_instance.instance
        shard = runner.bound_instance.bound_shard

        # Only check when runner is loaded and ready for warmup
        if not isinstance(runner.status, RunnerLoaded):
            continue

        # Only rank 0 loads the draft model
        if shard.device_rank != 0:
            continue

        # Check if instance has a draft model configured
        draft_model_id = instance.draft_model
        if draft_model_id is None:
            continue

        # Check if draft model needs download
        if draft_model_id not in download_status or not isinstance(
            download_status[draft_model_id], (DownloadOngoing, DownloadCompleted)
        ):
            # Create minimal shard metadata for draft model download
            draft_shard = PipelineShardMetadata(
                model_meta=ModelMetadata(
                    model_id=draft_model_id,
                    pretty_name=str(draft_model_id),
                    storage_size=Memory.from_bytes(0),  # Unknown, will be determined during download
                    n_layers=1,  # Placeholder
                    hidden_size=1,  # Placeholder
                    supports_tensor=False,
                ),
                device_rank=0,
                world_size=1,
                start_layer=0,
                end_layer=1,
                n_layers=1,
            )
            return DownloadModel(
                instance_id=instance.instance_id,
                shard_metadata=draft_shard,
            )


def _init_distributed_backend(
    runners: Mapping[RunnerId, RunnerSupervisor],
    all_runners: Mapping[RunnerId, RunnerStatus],

@@ -56,7 +56,6 @@ from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
    initialize_mlx,
    load_draft_model,
    load_mlx_items,
    mlx_force_oom,
)

@@ -111,7 +110,6 @@ def main(
    model = None
    tokenizer = None
    group = None
    draft_model: Model | None = None  # Loaded during warmup if instance has draft_model

    current_status: RunnerStatus = RunnerIdle()
    logger.info("runner created")

@@ -180,20 +178,11 @@ def main(
                )
            )

            # Load draft model for speculative decoding (rank 0 only)
            if (
                instance.draft_model is not None
                and shard_metadata.device_rank == 0
            ):
                logger.info(f"Loading draft model: {instance.draft_model}")
                draft_model = cast(
                    Model, load_draft_model(str(instance.draft_model))
                )

            logger.info(f"warming up inference for instance: {instance}")
            toks = warmup_inference(
                model=cast(Model, model),
                tokenizer=tokenizer,
                # kv_prefix_cache=kv_prefix_cache,  # supply for warmup-time prefix caching
            )
            logger.info(f"warmed up by generating {toks} tokens")
            logger.info(

@@ -223,13 +212,11 @@ def main(
            assert task_params.messages[0].content is not None
            _check_for_debug_prompts(task_params.messages[0].content)

            # Generate responses (draft_model loaded at warmup if configured)
            # Generate responses using the actual MLX generation
            mlx_generator = mlx_generate(
                model=cast(Model, model),
                tokenizer=tokenizer,
                task=task_params,
                draft_model=draft_model,
                num_draft_tokens=instance.num_draft_tokens,
            )

            # GPT-OSS specific parsing to match other model formats.

@@ -278,7 +265,7 @@ def main(
                RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
            )
            if isinstance(current_status, RunnerShutdown):
                del model, tokenizer, group, draft_model
                del model, tokenizer, group
                mx.clear_cache()
                import gc