Compare commits


1 Commit

Author: Alex Cheema
SHA1: e3465afae3
Message: Split NodePerformanceProfile state storage into separate mappings
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Date: 2026-01-17 21:28:42 +00:00
22 changed files with 460 additions and 1979 deletions
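
At a glance, the commit replaces the master's single `node_profiles` mapping with four mappings keyed by update cadence, and splits the old `NodePerformanceMeasured` event accordingly. A simplified sketch of the new `State` shape (field names taken from the `state.py` hunk further down; the value types here are stand-ins):

```python
from collections.abc import Mapping

NodeId = str  # stand-in for exo's NodeId

class State:
    # before: node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
    node_identities: Mapping[NodeId, object] = {}  # NodeIdentity, emitted once at startup
    node_memories: Mapping[NodeId, object] = {}    # MemoryPerformanceProfile, ~0.5s cadence
    node_systems: Mapping[NodeId, object] = {}     # SystemPerformanceProfile, ~1s cadence
    node_networks: Mapping[NodeId, object] = {}    # list[NetworkInterfaceInfo], ~30s cadence
```

A frequent memory sample now touches only `node_memories`; `NodePerformanceProfile` survives as an all-optional type that the reducer reconstructs for the topology view.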

View File

@@ -40,31 +40,6 @@ uv run ruff check
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check`, which verifies formatting and linting and runs the Rust tests.
## Architecture
### Node Composition

View File

@@ -71,35 +71,36 @@ export interface Instance {
};
}
interface RawNodeProfile {
modelId?: string;
chipId?: string;
friendlyName?: string;
networkInterfaces?: Array<{
name?: string;
ipAddress?: string;
addresses?: Array<{ address?: string } | string>;
ipv4?: string;
ipv6?: string;
ipAddresses?: string[];
ips?: string[];
}>;
memory?: {
ramTotal?: { inBytes: number };
ramAvailable?: { inBytes: number };
swapTotal?: { inBytes: number };
swapAvailable?: { inBytes: number };
};
system?: {
gpuUsage?: number;
temp?: number;
sysPower?: number;
};
// Split state interfaces
interface RawNodeIdentity {
modelId: string;
chipId: string;
friendlyName: string;
}
interface RawNodeMemory {
ramTotal: { inBytes: number };
ramAvailable: { inBytes: number };
swapTotal: { inBytes: number };
swapAvailable: { inBytes: number };
}
interface RawNodeSystem {
gpuUsage?: number;
temp?: number;
sysPower?: number;
pcpuUsage?: number;
ecpuUsage?: number;
anePower?: number;
}
interface RawNetworkInterface {
name: string;
ipAddress: string;
}
interface RawTopologyNode {
nodeId: string;
nodeProfile: RawNodeProfile;
}
interface RawTopologyConnection {
@@ -115,8 +116,6 @@ interface RawTopology {
connections?: RawTopologyConnection[];
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
export interface DownloadProgress {
totalBytes: number;
downloadedBytes: number;
@@ -171,7 +170,11 @@ interface RawStateResponse {
>;
runners?: Record<string, unknown>;
downloads?: Record<string, unknown[]>;
nodeProfiles?: RawNodeProfiles;
// Split state fields
nodeIdentities?: Record<string, RawNodeIdentity>;
nodeMemories?: Record<string, RawNodeMemory>;
nodeSystems?: Record<string, RawNodeSystem>;
nodeNetworks?: Record<string, RawNetworkInterface[]>;
}
export interface MessageAttachment {
@@ -208,66 +211,41 @@ const STORAGE_KEY = "exo-conversations";
function transformTopology(
raw: RawTopology,
profiles?: RawNodeProfiles,
identities?: Record<string, RawNodeIdentity>,
memories?: Record<string, RawNodeMemory>,
systems?: Record<string, RawNodeSystem>,
networks?: Record<string, RawNetworkInterface[]>,
): TopologyData {
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
for (const node of raw.nodes || []) {
const mergedProfile = profiles?.[node.nodeId];
const profile = { ...(node.nodeProfile ?? {}), ...(mergedProfile ?? {}) };
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
// Get split state fields (may be undefined if events haven't arrived yet)
const identity = identities?.[node.nodeId];
const memory = memories?.[node.nodeId];
const system = systems?.[node.nodeId];
const network = networks?.[node.nodeId];
const ramTotal = memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
const networkInterfaces = (profile?.networkInterfaces || []).map(
(iface) => {
const addresses: string[] = [];
if (iface.ipAddress && typeof iface.ipAddress === "string") {
addresses.push(iface.ipAddress);
}
if (Array.isArray(iface.addresses)) {
for (const addr of iface.addresses) {
if (typeof addr === "string") addresses.push(addr);
else if (addr && typeof addr === "object" && addr.address)
addresses.push(addr.address);
}
}
if (Array.isArray(iface.ipAddresses)) {
addresses.push(
...iface.ipAddresses.filter(
(a): a is string => typeof a === "string",
),
);
}
if (Array.isArray(iface.ips)) {
addresses.push(
...iface.ips.filter((a): a is string => typeof a === "string"),
);
}
if (iface.ipv4 && typeof iface.ipv4 === "string")
addresses.push(iface.ipv4);
if (iface.ipv6 && typeof iface.ipv6 === "string")
addresses.push(iface.ipv6);
return {
name: iface.name,
addresses: Array.from(new Set(addresses)),
};
},
);
const networkInterfaces = (network ?? []).map((iface) => ({
name: iface.name,
addresses: [iface.ipAddress],
}));
const ipToInterface: Record<string, string> = {};
for (const iface of networkInterfaces) {
for (const addr of iface.addresses || []) {
ipToInterface[addr] = iface.name ?? "";
for (const addr of iface.addresses) {
ipToInterface[addr] = iface.name;
}
}
nodes[node.nodeId] = {
system_info: {
model_id: profile?.modelId ?? "Unknown",
chip: profile?.chipId,
model_id: identity?.modelId ?? "Unknown",
chip: identity?.chipId,
memory: ramTotal,
},
network_interfaces: networkInterfaces,
@@ -278,17 +256,15 @@ function transformTopology(
ram_total: ramTotal,
},
temp:
profile?.system?.temp !== undefined
? { gpu_temp_avg: profile.system.temp }
system?.temp !== undefined
? { gpu_temp_avg: system.temp }
: undefined,
gpu_usage:
profile?.system?.gpuUsage !== undefined
? [0, profile.system.gpuUsage]
: undefined,
sys_power: profile?.system?.sysPower,
system?.gpuUsage !== undefined ? [0, system.gpuUsage] : undefined,
sys_power: system?.sysPower,
},
last_macmon_update: Date.now() / 1000,
friendly_name: profile?.friendlyName,
friendly_name: identity?.friendlyName,
};
}
@@ -868,7 +844,13 @@ class AppStore {
const data: RawStateResponse = await response.json();
if (data.topology) {
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
this.topologyData = transformTopology(
data.topology,
data.nodeIdentities,
data.nodeMemories,
data.nodeSystems,
data.nodeNetworks,
);
}
if (data.instances) {
this.instances = data.instances;

View File

@@ -599,9 +599,8 @@ class API:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for node in self.state.topology.list_nodes():
if node.node_profile is not None:
total_available += node.node_profile.memory.ram_available
for memory in self.state.node_memories.values():
total_available += memory.ram_available
return total_available

View File

@@ -113,6 +113,7 @@ def place_instance(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile is not None
and node.node_profile.memory is not None
),
start=Memory(),
),

View File

@@ -25,7 +25,10 @@ class NodeWithProfile(BaseModel):
def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
return all(node.node_profile is not None for node in nodes)
return all(
node.node_profile is not None and node.node_profile.memory is not None
for node in nodes
)
def filter_cycles_by_memory(
@@ -36,8 +39,14 @@ def filter_cycles_by_memory(
if not narrow_all_nodes(cycle):
continue
# narrow_all_nodes guarantees memory is not None
total_mem = sum(
(node.node_profile.memory.ram_available for node in cycle), start=Memory()
(
node.node_profile.memory.ram_available
for node in cycle
if node.node_profile.memory is not None
),
start=Memory(),
)
if total_mem >= required_memory:
filtered_cycles.append(cast(list[NodeInfo], cycle))
@@ -53,8 +62,13 @@ def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
# NodeWithProfile guarantees memory is not None
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
(
node.node_profile.memory.ram_available
for node in selected_cycle
if node.node_profile.memory is not None
),
start=Memory(),
)
total_layers = model_meta.n_layers
@@ -67,6 +81,8 @@ def get_shard_assignments_for_pipeline_parallel(
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
# NodeWithProfile guarantees memory is not None
assert node.node_profile.memory is not None
node_layers = round(
total_layers
* (
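
Why the hunks above add `memory is not None` filters and an `assert` even though `narrow_all_nodes` now checks memory: later in this diff `NodePerformanceProfile.memory` becomes `MemoryPerformanceProfile | None`, and a `TypeGuard` only narrows the list's element type, not a nested optional attribute, so the type checker still needs the inner narrowing. A minimal sketch with simplified stand-in models (not the real exo classes):

```python
from typing import TypeGuard

from pydantic import BaseModel


class MemoryProfile(BaseModel):
    ram_available: int


class Profile(BaseModel):
    memory: MemoryProfile | None = None  # optional after the split


class NodeInfo(BaseModel):
    node_profile: Profile | None = None


class NodeWithProfile(BaseModel):
    node_profile: Profile  # profile is present, but memory may still be None


def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
    return all(
        n.node_profile is not None and n.node_profile.memory is not None
        for n in nodes
    )


def total_ram(nodes: list[NodeInfo]) -> int:
    if not narrow_all_nodes(nodes):
        return 0
    # The runtime guard above already checked memory, but statically the nodes
    # are only NodeWithProfile, whose memory is still Optional, hence the filter.
    return sum(
        n.node_profile.memory.ram_available
        for n in nodes
        if n.node_profile.memory is not None
    )
```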

View File

@@ -19,16 +19,13 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceCreated,
NodePerformanceMeasured,
NodeIdentityMeasured,
NodeMemoryMeasured,
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.worker.instances import (
@@ -75,29 +72,39 @@ async def test_master():
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodePerformanceProfile event
logger.info("inject a NodePerformanceProfile event")
# inject NodeIdentityMeasured and NodeMemoryMeasured events
logger.info("inject NodeIdentityMeasured event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
session=session_id,
event=(
NodePerformanceMeasured(
NodeIdentityMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
node_profile=NodePerformanceProfile(
model_id="maccy",
chip_id="arm",
friendly_name="test",
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
network_interfaces=[],
system=SystemPerformanceProfile(),
model_id="maccy",
chip_id="arm",
friendly_name="test",
)
),
)
)
logger.info("inject NodeMemoryMeasured event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=1,
origin=sender_node_id,
session=session_id,
event=(
NodeMemoryMeasured(
when=str(datetime.now(tz=timezone.utc)),
node_id=node_id,
memory=MemoryPerformanceProfile(
ram_total=Memory.from_bytes(678948 * 1024),
ram_available=Memory.from_bytes(678948 * 1024),
swap_total=Memory.from_bytes(0),
swap_available=Memory.from_bytes(0),
),
)
),
@@ -108,7 +115,7 @@ async def test_master():
logger.info("wait for initial topology event")
while len(list(master.state.topology.list_nodes())) == 0:
await anyio.sleep(0.001)
while len(master.state.node_profiles) == 0:
while len(master.state.node_identities) == 0:
await anyio.sleep(0.001)
logger.info("inject a CreateInstance Command")
@@ -155,17 +162,19 @@ async def test_master():
),
)
)
while len(_get_events()) < 3:
while len(_get_events()) < 4:
await anyio.sleep(0.01)
events = _get_events()
assert len(events) == 3
assert len(events) == 4
assert events[0].idx == 0
assert events[1].idx == 1
assert events[2].idx == 2
assert isinstance(events[0].event, NodePerformanceMeasured)
assert isinstance(events[1].event, InstanceCreated)
created_instance = events[1].event.instance
assert events[3].idx == 3
assert isinstance(events[0].event, NodeIdentityMeasured)
assert isinstance(events[1].event, NodeMemoryMeasured)
assert isinstance(events[2].event, InstanceCreated)
created_instance = events[2].event.instance
assert isinstance(created_instance, MlxRingInstance)
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
# Validate the shard assignments
@@ -197,10 +206,10 @@ async def test_master():
assert len(created_instance.hosts_by_node[node_id]) == 1
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
assert created_instance.ephemeral_port > 0
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, ChatCompletionTask)
assert events[2].event.task.task_params == ChatCompletionTaskParams(
assert isinstance(events[3].event, TaskCreated)
assert events[3].event.task.task_status == TaskStatus.Pending
assert isinstance(events[3].event.task, ChatCompletionTask)
assert events[3].event.task.task_params == ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(role="user", content="Hello, how are you?")

View File

@@ -13,8 +13,10 @@ from exo.shared.types.events import (
InstanceDeleted,
NodeCreated,
NodeDownloadProgress,
NodeIdentityMeasured,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeNetworkMeasured,
NodeSystemMeasured,
NodeTimedOut,
RunnerDeleted,
RunnerStatusUpdated,
@@ -27,7 +29,13 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NetworkInterfaceInfo,
NodeIdentity,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import NodeInfo
@@ -51,8 +59,12 @@ def event_apply(event: Event, state: State) -> State:
return apply_topology_node_created(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodePerformanceMeasured():
return apply_node_performance_measured(event, state)
case NodeIdentityMeasured():
return apply_node_identity_measured(event, state)
case NodeSystemMeasured():
return apply_node_system_measured(event, state)
case NodeNetworkMeasured():
return apply_node_network_measured(event, state)
case NodeDownloadProgress():
return apply_node_download_progress(event, state)
case NodeMemoryMeasured():
@@ -190,8 +202,19 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.copy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memories = {
key: value for key, value in state.node_memories.items() if key != event.node_id
}
node_systems = {
key: value for key, value in state.node_systems.items() if key != event.node_id
}
node_networks = {
key: value for key, value in state.node_networks.items() if key != event.node_id
}
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
@@ -199,32 +222,120 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
return state.model_copy(
update={
"topology": topology,
"node_profiles": node_profiles,
"node_identities": node_identities,
"node_memories": node_memories,
"node_systems": node_systems,
"node_networks": node_networks,
"last_seen": last_seen,
}
)
def apply_node_performance_measured(
event: NodePerformanceMeasured, state: State
) -> State:
new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: event.node_profile,
def _reconstruct_profile(
node_id: NodeId,
state: State,
*,
identity: NodeIdentity | None = None,
memory: MemoryPerformanceProfile | None = None,
system: SystemPerformanceProfile | None = None,
network_interfaces: list[NetworkInterfaceInfo] | None = None,
) -> NodePerformanceProfile:
"""Reconstruct a NodePerformanceProfile from split state storage.
Uses provided overrides, falling back to state values.
"""
ident = identity or state.node_identities.get(node_id)
mem = memory or state.node_memories.get(node_id)
sys = system or state.node_systems.get(node_id)
nets = (
network_interfaces
if network_interfaces is not None
else state.node_networks.get(node_id, [])
)
return NodePerformanceProfile(
model_id=ident.model_id if ident else None,
chip_id=ident.chip_id if ident else None,
friendly_name=ident.friendly_name if ident else None,
memory=mem,
network_interfaces=nets,
system=sys,
)
def apply_node_identity_measured(event: NodeIdentityMeasured, state: State) -> State:
topology = copy.copy(state.topology)
identity = NodeIdentity(
model_id=event.model_id,
chip_id=event.chip_id,
friendly_name=event.friendly_name,
)
new_identities: Mapping[NodeId, NodeIdentity] = {
**state.node_identities,
event.node_id: identity,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
state = state.model_copy(update={"node_profiles": new_profiles})
topology = copy.copy(state.topology)
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, event.node_profile)
reconstructed = _reconstruct_profile(event.node_id, state, identity=identity)
topology.update_node_profile(event.node_id, reconstructed)
return state.model_copy(
update={
"node_profiles": new_profiles,
"node_identities": new_identities,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_system_measured(event: NodeSystemMeasured, state: State) -> State:
topology = copy.copy(state.topology)
new_systems: Mapping[NodeId, SystemPerformanceProfile] = {
**state.node_systems,
event.node_id: event.system,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
reconstructed = _reconstruct_profile(event.node_id, state, system=event.system)
topology.update_node_profile(event.node_id, reconstructed)
return state.model_copy(
update={
"node_systems": new_systems,
"topology": topology,
"last_seen": last_seen,
}
)
def apply_node_network_measured(event: NodeNetworkMeasured, state: State) -> State:
topology = copy.copy(state.topology)
new_networks: Mapping[NodeId, list[NetworkInterfaceInfo]] = {
**state.node_networks,
event.node_id: event.network_interfaces,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
reconstructed = _reconstruct_profile(
event.node_id, state, network_interfaces=event.network_interfaces
)
topology.update_node_profile(event.node_id, reconstructed)
return state.model_copy(
update={
"node_networks": new_networks,
"topology": topology,
"last_seen": last_seen,
}
@@ -232,57 +343,26 @@ def apply_node_performance_measured(
def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
existing = state.node_profiles.get(event.node_id)
topology = copy.copy(state.topology)
if existing is None:
created = NodePerformanceProfile(
model_id="unknown",
chip_id="unknown",
friendly_name="Unknown",
memory=event.memory,
network_interfaces=[],
system=SystemPerformanceProfile(
# TODO: flops_fp16=0.0,
gpu_usage=0.0,
temp=0.0,
sys_power=0.0,
pcpu_usage=0.0,
ecpu_usage=0.0,
ane_power=0.0,
),
)
created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: created,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
# TODO: NodeCreated
topology.update_node_profile(event.node_id, created)
return state.model_copy(
update={
"node_profiles": created_profiles,
"topology": topology,
"last_seen": last_seen,
}
)
updated = existing.model_copy(update={"memory": event.memory})
updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
**state.node_profiles,
event.node_id: updated,
new_memories: Mapping[NodeId, MemoryPerformanceProfile] = {
**state.node_memories,
event.node_id: event.memory,
}
last_seen: Mapping[NodeId, datetime] = {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
}
# TODO: NodeCreated
if not topology.contains_node(event.node_id):
topology.add_node(NodeInfo(node_id=event.node_id))
topology.update_node_profile(event.node_id, updated)
reconstructed = _reconstruct_profile(event.node_id, state, memory=event.memory)
topology.update_node_profile(event.node_id, reconstructed)
return state.model_copy(
update={"node_profiles": updated_profiles, "topology": topology}
update={
"node_memories": new_memories,
"topology": topology,
"last_seen": last_seen,
}
)
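
One behavioural consequence of the reducer changes above: `apply_node_memory_measured` no longer fabricates a placeholder profile with `model_id="unknown"`; `_reconstruct_profile` leaves identity fields as `None` until a `NodeIdentityMeasured` event arrives, and the dashboard renders "Unknown" on its side. A hedged, dict-based restatement of that fallback (plain stand-ins, not the real `State` model):

```python
from dataclasses import dataclass, field


@dataclass
class Profile:  # simplified stand-in for NodePerformanceProfile
    model_id: str | None = None
    chip_id: str | None = None
    friendly_name: str | None = None
    memory: str | None = None            # stand-in for MemoryPerformanceProfile
    network_interfaces: list[str] = field(default_factory=list)
    system: str | None = None            # stand-in for SystemPerformanceProfile


def reconstruct(node_id, identities, memories, systems, networks,
                identity=None, memory=None, system=None, network_interfaces=None):
    # Overrides from the event being applied win; otherwise fall back to state.
    ident = identity or identities.get(node_id)
    return Profile(
        model_id=ident["model_id"] if ident else None,
        chip_id=ident["chip_id"] if ident else None,
        friendly_name=ident["friendly_name"] if ident else None,
        memory=memory or memories.get(node_id),
        system=system or systems.get(node_id),
        network_interfaces=(network_interfaces if network_interfaces is not None
                            else networks.get(node_id, [])),
    )


# Only a memory sample has arrived for node-a: identity fields stay None instead
# of the old reducer's fabricated "unknown" values.
p = reconstruct("node-a", identities={}, memories={"node-a": "64 GiB"}, systems={}, networks={})
assert p.model_id is None and p.memory == "64 GiB"
```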

View File

@@ -2,10 +2,14 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NetworkInterfaceInfo,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -85,13 +89,35 @@ class NodeTimedOut(BaseEvent):
node_id: NodeId
class NodePerformanceMeasured(BaseEvent):
class NodeIdentityMeasured(BaseEvent):
"""Static identity info - emitted once at startup."""
node_id: NodeId
when: str  # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
node_profile: NodePerformanceProfile
model_id: str
chip_id: str
friendly_name: str
class NodeSystemMeasured(BaseEvent):
"""Dynamic system metrics (GPU, temp, power) - emitted at 1s intervals."""
node_id: NodeId
when: str  # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
system: SystemPerformanceProfile
class NodeNetworkMeasured(BaseEvent):
"""Semi-static network interface info - emitted at 30s intervals."""
node_id: NodeId
when: str  # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
network_interfaces: list[NetworkInterfaceInfo]
class NodeMemoryMeasured(BaseEvent):
"""Dynamic memory metrics - emitted at 0.5s intervals."""
node_id: NodeId
when: str  # this is a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
memory: MemoryPerformanceProfile
@@ -127,7 +153,9 @@ Event = (
| RunnerDeleted
| NodeCreated
| NodeTimedOut
| NodePerformanceMeasured
| NodeIdentityMeasured
| NodeSystemMeasured
| NodeNetworkMeasured
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
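
The docstrings above encode the intended cadences (identity once, memory ~0.5s, system ~1s, network ~30s). The worker-side emitter is not part of this diff; the following is only a hypothetical sketch of emitting the split events at those intervals, where `send` and the `probe_*` callables are placeholders:

```python
from datetime import datetime, timezone

import anyio

from exo.shared.types.events import (
    NodeIdentityMeasured,
    NodeMemoryMeasured,
    NodeNetworkMeasured,
    NodeSystemMeasured,
)


def _now() -> str:
    return str(datetime.now(tz=timezone.utc))


async def emit_node_events(send, node_id, probe_identity, probe_memory, probe_system, probe_network):
    ident = probe_identity()  # hypothetical helper returning model_id/chip_id/friendly_name
    # Static identity: emitted once at startup.
    await send(NodeIdentityMeasured(node_id=node_id, when=_now(), **ident))

    async def every(interval: float, make_event) -> None:
        while True:
            await send(make_event())
            await anyio.sleep(interval)

    async with anyio.create_task_group() as tg:
        tg.start_soon(every, 0.5, lambda: NodeMemoryMeasured(node_id=node_id, when=_now(), memory=probe_memory()))
        tg.start_soon(every, 1.0, lambda: NodeSystemMeasured(node_id=node_id, when=_now(), system=probe_system()))
        tg.start_soon(every, 30.0, lambda: NodeNetworkMeasured(node_id=node_id, when=_now(), network_interfaces=probe_network()))
```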

View File

@@ -52,13 +52,21 @@ class NetworkInterfaceInfo(CamelCaseModel):
ip_address: str
class NodePerformanceProfile(CamelCaseModel):
class NodeIdentity(CamelCaseModel):
"""Static identity info for a node."""
model_id: str
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
class NodePerformanceProfile(CamelCaseModel):
model_id: str | None = None
chip_id: str | None = None
friendly_name: str | None = None
memory: MemoryPerformanceProfile | None = None
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile
system: SystemPerformanceProfile | None = None
class ConnectionProfile(CamelCaseModel):

View File

@@ -7,7 +7,12 @@ from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NetworkInterfaceInfo,
NodeIdentity,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -35,7 +40,10 @@ class State(CamelCaseModel):
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
node_identities: Mapping[NodeId, NodeIdentity] = {}
node_memories: Mapping[NodeId, MemoryPerformanceProfile] = {}
node_systems: Mapping[NodeId, SystemPerformanceProfile] = {}
node_networks: Mapping[NodeId, list[NetworkInterfaceInfo]] = {}
last_seen: Mapping[NodeId, datetime] = {}
topology: Topology = Field(default_factory=Topology)
last_event_applied_idx: int = Field(default=-1, ge=-1)
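
Since `State` extends `CamelCaseModel` and this module imports `to_camel`, the new mappings should serialize under the camelCase keys the dashboard's `RawStateResponse` reads. A quick sanity check of the key names (assuming `to_camel` is the alias generator applied on serialization, which the frontend field names suggest):

```python
from pydantic.alias_generators import to_camel

for name in ("node_identities", "node_memories", "node_systems", "node_networks"):
    print(to_camel(name))
# nodeIdentities, nodeMemories, nodeSystems, nodeNetworks
```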

View File

@@ -13,8 +13,3 @@ KV_CACHE_BITS: int | None = None
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True
# Multi-Token Prediction (MTP) configuration for DeepSeek V3
# MTP enables speculative decoding using the model's built-in draft layer
MTP_ENABLED: bool = True # Feature flag to enable/disable MTP
MTP_NUM_DRAFT_TOKENS: int = 1 # Number of tokens to draft (vLLM reports k=1 is optimal)

View File

@@ -19,13 +19,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import (
KV_BITS,
KV_GROUP_SIZE,
MAX_TOKENS,
MTP_ENABLED,
MTP_NUM_DRAFT_TOKENS,
)
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
@@ -121,11 +115,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def _has_mtp_module(model: Model) -> bool:
"""Check if the model has an attached MTP module."""
return hasattr(model, "mtp_module") and model.mtp_module is not None # type: ignore[attr-defined]
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
@@ -160,43 +149,6 @@ def mlx_generate(
)
max_tokens = task.max_tokens or MAX_TOKENS
# Check if we should use MTP speculative decoding
use_mtp = MTP_ENABLED and _has_mtp_module(model)
if use_mtp:
logger.info("Using MTP speculative decoding")
yield from _mlx_generate_with_mtp(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
)
else:
yield from _mlx_generate_standard(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
)
def _mlx_generate_standard(
model: Model,
tokenizer: TokenizerWrapper,
prompt: str,
max_tokens: int,
sampler: Callable[[mx.array], mx.array],
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
prompt_cache: list[KVCache | Any],
) -> Generator[GenerationResponse]:
"""Standard generation path using mlx_lm stream_generate."""
for out in stream_generate(
model=model,
tokenizer=tokenizer,
@@ -204,7 +156,7 @@ def _mlx_generate_standard(
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
@@ -239,64 +191,4 @@ def _mlx_generate_standard(
if out.finish_reason is not None:
break
def _mlx_generate_with_mtp(
model: Model,
tokenizer: TokenizerWrapper,
prompt: str,
max_tokens: int,
sampler: Callable[[mx.array], mx.array],
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
prompt_cache: list[KVCache | Any],
) -> Generator[GenerationResponse]:
"""MTP speculative decoding generation path.
Uses the model's attached MTP module for speculative decoding,
which can provide 1.5-2x speedup with ~81% acceptance rate.
"""
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
mtp_module = model.mtp_module # type: ignore[attr-defined]
for out in mtp_speculative_generate(
model=model,
mtp_module=mtp_module,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
num_draft_tokens=MTP_NUM_DRAFT_TOKENS,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE if KV_GROUP_SIZE is not None else 64,
kv_bits=KV_BITS,
):
logger.info(f"{out.text} (from_draft={out.from_draft})")
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
)
if out.finish_reason is not None:
break
# TODO: Do we want an mx_barrier?

View File

@@ -1,6 +0,0 @@
"""Multi-Token Prediction (MTP) module for DeepSeek V3 speculative decoding."""
from exo.worker.engines.mlx.mtp.module import MTPModule
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
__all__ = ["MTPModule", "mtp_speculative_generate"]

View File

@@ -1,207 +0,0 @@
"""MTP Module for DeepSeek V3 Multi-Token Prediction.
The MTP architecture predicts one additional token ahead using:
1. hnorm - RMSNorm for hidden state normalization
2. enorm - RMSNorm for embedding normalization
3. eh_proj - Linear(2*hidden_size -> hidden_size) projection
4. transformer_block - Single decoder layer (attention + MLP)
5. Shared embedding/lm_head from main model
Forward pass:
h_norm = hnorm(hidden_state)
e_norm = enorm(embed(token))
projected = eh_proj(concat([h_norm, e_norm]))
new_hidden = transformer_block(projected)
logits = lm_head(output_norm(new_hidden))
"""
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
from mlx_lm.models.deepseek_v3 import (
DeepseekV3Attention,
DeepseekV3MLP,
ModelArgs,
)
MTP_LAYER_INDEX = 61
class MTPModule(nn.Module):
"""Multi-Token Prediction module for DeepSeek V3.
This module is initialized from the layer 61 weights that are normally
discarded during model loading. It enables speculative decoding by
predicting one token ahead using the hidden state from the main model.
"""
def __init__(
self,
config: ModelArgs,
shared_embedding: nn.Embedding,
shared_lm_head: nn.Linear,
output_norm: nn.RMSNorm,
) -> None:
super().__init__()
self.config = config
# MTP-specific normalization layers
self.hnorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.enorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Projection: concatenated [hidden, embedding] -> hidden_size
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
# Single transformer block for MTP
# Use a dense MLP since this is just a single layer
self.transformer_block = MTPTransformerBlock(config)
# Share embedding and lm_head with main model
self._shared_embedding = shared_embedding
self._shared_lm_head = shared_lm_head
self._output_norm = output_norm
def __call__(
self,
hidden_state: mx.array,
draft_token: mx.array,
cache: KVCache | None = None,
mask: mx.array | None = None,
) -> tuple[mx.array, mx.array]:
"""Forward pass for MTP.
Args:
hidden_state: Hidden state from main model [batch, seq_len, hidden_size]
draft_token: Token to embed and combine with hidden state [batch, seq_len]
cache: Optional KV cache for the MTP transformer block
mask: Optional attention mask
Returns:
tuple of (logits, new_hidden_state)
"""
# Get embedding of draft token
embedding = self._shared_embedding(draft_token)
# Normalize hidden state and embedding
h_norm = self.hnorm(hidden_state)
e_norm = self.enorm(embedding)
# Project concatenated representation
concatenated = mx.concatenate([h_norm, e_norm], axis=-1)
projected = self.eh_proj(concatenated)
# Pass through single transformer block
new_hidden = self.transformer_block(projected, mask=mask, cache=cache)
# Apply output norm and get logits
normed_hidden = self._output_norm(new_hidden)
logits = self._shared_lm_head(normed_hidden)
return logits, new_hidden
class MTPTransformerBlock(nn.Module):
"""Single transformer block for MTP.
This is similar to DeepseekV3DecoderLayer but uses a dense MLP
instead of MoE since this is just for the single MTP layer.
"""
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.self_attn = DeepseekV3Attention(config)
# MTP uses dense MLP, not MoE
self.mlp = DeepseekV3MLP(config)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
cache: Any | None = None,
) -> mx.array:
"""Forward pass with residual connections."""
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
return h + r
def extract_mtp_weights(weights: dict[str, mx.array]) -> dict[str, mx.array]:
"""Extract MTP-specific weights from layer 61.
The MTP layer has these weight patterns:
- model.layers.61.enorm.weight -> MTP embedding normalization
- model.layers.61.hnorm.weight -> MTP hidden normalization
- model.layers.61.eh_proj.weight -> MTP projection layer
- model.layers.61.self_attn.* -> MTP attention
- model.layers.61.input_layernorm.* -> MTP layer norms
- model.layers.61.post_attention_layernorm.*
- model.layers.61.mlp.* -> MTP MLP (dense, not MoE)
Args:
weights: Full model weights dict
Returns:
Dict of MTP-specific weights with keys renamed for MTPModule
"""
mtp_weights: dict[str, mx.array] = {}
mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
for key, value in weights.items():
if key.startswith(mtp_prefix):
# Remove the layer prefix to get relative path
new_key = key[len(mtp_prefix) :]
mtp_weights[new_key] = value
return mtp_weights
def load_mtp_weights_into_module(
mtp_module: MTPModule,
mtp_weights: dict[str, mx.array],
) -> None:
"""Load extracted MTP weights into the MTPModule.
Args:
mtp_module: The MTPModule instance to load weights into
mtp_weights: Extracted MTP weights from extract_mtp_weights()
"""
# Map weight names to module attributes
weight_mapping: dict[str, str] = {
"enorm.weight": "enorm.weight",
"hnorm.weight": "hnorm.weight",
"eh_proj.weight": "eh_proj.weight",
}
# Load direct mappings
for src_name, dst_name in weight_mapping.items():
if src_name in mtp_weights:
parts = dst_name.split(".")
obj: Any = mtp_module
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], mtp_weights[src_name])
# Load transformer block weights (self_attn, mlp, layer norms)
transformer_prefixes = [
"self_attn",
"mlp",
"input_layernorm",
"post_attention_layernorm",
]
for prefix in transformer_prefixes:
for key, value in mtp_weights.items():
if key.startswith(prefix):
# Navigate to the correct attribute
parts = key.split(".")
obj = mtp_module.transformer_block
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], value)

View File

@@ -1,506 +0,0 @@
"""MTP Speculative Decoding for DeepSeek V3.
This module implements speculative decoding using the Multi-Token Prediction (MTP)
layer from DeepSeek V3. The key difference from standard speculative decoding is
that MTP requires hidden states from the main model, not just token predictions.
Based on vLLM/SGLang research:
- 81-82% acceptance rate with k=1
- 1.5-2x speedup at low QPS
"""
import functools
import time
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import Any, cast
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models import cache
from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.mtp.module import MTPModule
# Generation stream for async operations
generation_stream = mx.new_stream(mx.default_device())
@dataclass
class MTPGenerationResponse:
"""Response from MTP speculative generation.
Attributes:
text: The next segment of decoded text.
token: The next token.
logprobs: A vector of log probabilities.
from_draft: Whether the token was generated by the MTP draft module.
prompt_tokens: The number of tokens in the prompt.
prompt_tps: The prompt processing tokens-per-second.
generation_tokens: The number of generated tokens.
generation_tps: The tokens-per-second for generation.
peak_memory: The peak memory used so far in GB.
finish_reason: The reason the response is being sent: "length", "stop" or None.
"""
text: str
token: int
logprobs: mx.array
from_draft: bool
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
finish_reason: str | None = None
def maybe_quantize_kv_cache(
prompt_cache: list[Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
"""Quantize KV cache entries if needed."""
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized")
and hasattr(c, "offset")
and c.offset >= quantized_kv_start
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
class ModelWithHiddenStates(nn.Module):
"""Wrapper to extract hidden states before lm_head.
This wrapper allows capturing the hidden states from the transformer
layers before the final lm_head projection, which is needed for MTP.
"""
def __init__(self, base_model: nn.Module) -> None:
super().__init__()
self._base = base_model
def forward_with_hidden(
self,
inputs: mx.array,
model_cache: list[Any] | None = None,
) -> tuple[mx.array, mx.array]:
"""Forward pass that returns both logits and hidden states.
Args:
inputs: Input token ids
model_cache: KV cache
Returns:
Tuple of (logits, hidden_states)
"""
# Call the inner model (transformer layers + norm)
hidden: mx.array = self._base.model(inputs, model_cache)
# Get logits from lm_head
logits: mx.array = self._base.lm_head(hidden)
return logits, hidden
def forward(
self,
inputs: mx.array,
model_cache: list[Any] | None = None,
) -> mx.array:
"""Standard forward pass returning only logits."""
return cast(mx.array, self._base(inputs, cache=model_cache))
@property
def layers(self) -> list[nn.Module]:
"""Access layers for cache creation."""
return cast(list[nn.Module], self._base.layers)
def mtp_speculative_generate_step(
prompt: mx.array,
model: nn.Module,
mtp_module: MTPModule,
*,
num_draft_tokens: int = 1,
max_tokens: int = 256,
sampler: Callable[[mx.array], mx.array] | None = None,
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
prompt_cache: list[Any] | None = None,
mtp_cache: KVCache | None = None,
prefill_step_size: int = 512,
kv_bits: int | None = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[tuple[int, mx.array, bool], None, None]:
"""MTP speculative decoding generator.
Unlike standard speculative decoding where the draft model only needs tokens,
MTP requires the hidden states from the main model. This generator:
1. Runs the main model to get logits AND hidden states
2. Uses MTP module with hidden state + sampled token to predict next token
3. Verifies MTP predictions with the main model
4. Accepts/rejects based on matching
Args:
prompt: The input prompt as token ids
model: The main model (must support return_hidden=True)
mtp_module: The MTP module for draft prediction
num_draft_tokens: Number of tokens to draft (typically 1 for MTP)
max_tokens: Maximum number of tokens to generate
sampler: Optional sampler function for token selection
logits_processors: Optional list of logits processors
prompt_cache: KV cache for the main model
mtp_cache: KV cache for the MTP module
prefill_step_size: Step size for prompt processing
kv_bits: Bits for KV cache quantization
kv_group_size: Group size for KV cache quantization
quantized_kv_start: Step to begin cache quantization
Yields:
Tuple of (token, logprobs, from_draft)
"""
y = prompt.astype(mx.uint32)
prev_tokens: mx.array | None = None
# Wrap model to get hidden states
wrapped_model = (
model
if isinstance(model, ModelWithHiddenStates)
else ModelWithHiddenStates(model)
)
# Create caches if needed
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(model)
if mtp_cache is None:
mtp_cache = KVCache()
final_sampler = (
sampler if sampler is not None else (lambda x: mx.argmax(x, axis=-1))
)
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _process_and_sample(
tokens: mx.array | None,
logits: mx.array,
) -> tuple[mx.array, mx.array]:
"""Process logits and sample tokens."""
nonlocal logits_processors
processed_logits = logits
if logits_processors:
for processor in logits_processors:
processed_logits = processor(
tokens if tokens is not None else mx.array([]), processed_logits
)
logprobs = processed_logits - mx.logsumexp(
processed_logits, axis=-1, keepdims=True
)
sampled = final_sampler(logprobs)
return sampled, logprobs
def _main_model_step_with_hidden(
input_y: mx.array,
) -> tuple[mx.array, mx.array, mx.array]:
"""Run main model step with hidden state return."""
nonlocal prev_tokens
with mx.stream(generation_stream):
logits, hidden = wrapped_model.forward_with_hidden(
input_y[None], prompt_cache
)
logits = logits[:, -1, :]
quantize_cache_fn(prompt_cache)
if logits_processors:
prev_tokens = (
mx.concatenate([prev_tokens, input_y])
if prev_tokens is not None
else input_y
)
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
return sampled, logprobs_result.squeeze(0), hidden[:, -1:, :]
def _main_model_step(
input_y: mx.array,
) -> tuple[mx.array, mx.array]:
"""Run main model step without hidden state."""
nonlocal prev_tokens
with mx.stream(generation_stream):
logits = wrapped_model.forward(input_y[None], prompt_cache)
logits = logits[:, -1, :]
quantize_cache_fn(prompt_cache)
if logits_processors:
prev_tokens = (
mx.concatenate([prev_tokens, input_y])
if prev_tokens is not None
else input_y
)
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
return sampled, logprobs_result.squeeze(0)
def _mtp_draft(
hidden_state: mx.array,
draft_token: mx.array,
) -> tuple[mx.array, mx.array]:
"""Generate draft token using MTP module."""
with mx.stream(generation_stream):
logits, new_hidden = mtp_module(
hidden_state,
draft_token,
cache=mtp_cache,
)
logits = logits[:, -1, :]
sampled, _ = _process_and_sample(None, logits)
return sampled, new_hidden
def _prefill(input_y: mx.array) -> mx.array:
"""Prefill the prompt cache."""
result_y = input_y
while result_y.size > prefill_step_size:
_ = wrapped_model.forward(result_y[:prefill_step_size][None], prompt_cache)
quantize_cache_fn(prompt_cache)
mx.eval([c.state for c in prompt_cache])
result_y = result_y[prefill_step_size:]
mx.clear_cache()
return result_y
def _rewind_cache(num_draft: int, num_accept: int) -> None:
"""Rewind caches after rejection."""
cache.trim_prompt_cache(prompt_cache, num_draft - num_accept)
# Prefill phase
with mx.stream(generation_stream):
y = _prefill(y)
ntoks = 0
num_draft = 0
n_accepted = 0
last_hidden: mx.array | None = None
try:
# Initial step to get first token and hidden state
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
mx.eval(sampled, logprobs, last_hidden)
y = sampled
current_logprobs = logprobs
while ntoks < max_tokens:
# Draft phase: use MTP to predict next token
num_draft = min(max_tokens - ntoks - 1, num_draft_tokens)
if num_draft > 0 and last_hidden is not None:
# Use MTP to draft
draft_token, draft_hidden = _mtp_draft(last_hidden, y)
mx.eval(draft_token, draft_hidden)
# Verify with main model
# Feed the drafted token to main model
verify_input = mx.concatenate([y, draft_token.flatten()])
verify_sampled, verify_logprobs, new_hidden = (
_main_model_step_with_hidden(verify_input)
)
mx.eval(verify_sampled, verify_logprobs, new_hidden)
# Check if draft matches verification
draft_token_val = int(draft_token.item())
verify_token_val = (
int(verify_sampled[0].item())
if verify_sampled.shape[0] > 1
else int(verify_sampled.item())
)
# Yield the current token (not from draft)
ntoks += 1
yield int(y.item()), current_logprobs, False
if ntoks >= max_tokens:
break
if draft_token_val == verify_token_val:
# Draft accepted
n_accepted += 1
ntoks += 1
draft_logprobs = (
verify_logprobs[0]
if verify_logprobs.ndim > 1
else verify_logprobs
)
yield draft_token_val, draft_logprobs, True
if ntoks >= max_tokens:
break
# Continue with the token after the draft
y = (
verify_sampled[-1:]
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
else verify_sampled
)
current_logprobs = (
verify_logprobs[-1]
if verify_logprobs.ndim > 1
else verify_logprobs
)
last_hidden = new_hidden
else:
# Draft rejected - rewind and use verified token
_rewind_cache(1, 0)
y = (
verify_sampled[:1]
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
else verify_sampled
)
current_logprobs = (
verify_logprobs[0]
if verify_logprobs.ndim > 1
else verify_logprobs
)
last_hidden = (
new_hidden[:, :1, :] if new_hidden is not None else None
)
else:
# No drafting, just do normal generation
ntoks += 1
yield int(y.item()), current_logprobs, False
if ntoks >= max_tokens:
break
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
mx.eval(sampled, logprobs, last_hidden)
y = sampled
current_logprobs = logprobs
if ntoks % 256 == 0:
mx.clear_cache()
finally:
_rewind_cache(num_draft, n_accepted)
def mtp_speculative_generate(
model: nn.Module,
mtp_module: MTPModule,
tokenizer: TokenizerWrapper,
prompt: str | mx.array | list[int],
max_tokens: int = 256,
sampler: Callable[[mx.array], mx.array] | None = None,
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
prompt_cache: list[Any] | None = None,
num_draft_tokens: int = 1,
prefill_step_size: int = 512,
kv_group_size: int = 64,
kv_bits: int | None = None,
) -> Generator[MTPGenerationResponse, None, None]:
"""High-level MTP speculative generation with text output.
Args:
model: The main model
mtp_module: The MTP module for draft prediction
tokenizer: Tokenizer for encoding/decoding
prompt: Input prompt (string, array, or token list)
max_tokens: Maximum tokens to generate
sampler: Optional sampler function
logits_processors: Optional logits processors
prompt_cache: Optional KV cache
num_draft_tokens: Number of draft tokens
prefill_step_size: Prefill step size
kv_group_size: KV group size
kv_bits: KV bits
Yields:
MTPGenerationResponse objects with text and metadata
"""
if not isinstance(prompt, mx.array):
if isinstance(prompt, str):
bos_token = getattr(tokenizer, "bos_token", None)
add_special_tokens = bos_token is None or not prompt.startswith(
str(bos_token)
)
encoded: list[int] = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
prompt = mx.array(encoded)
else:
prompt = mx.array(prompt)
detokenizer = tokenizer.detokenizer
eos_token_ids: list[int] = getattr(tokenizer, "eos_token_ids", [])
token_generator = mtp_speculative_generate_step(
prompt,
model,
mtp_module,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
num_draft_tokens=num_draft_tokens,
prefill_step_size=prefill_step_size,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
tic = time.perf_counter()
prompt_tps = 0.0
token = 0
logprobs: mx.array = mx.array([0.0])
from_draft = False
n = 0
for n, (token, logprobs, from_draft) in enumerate(token_generator):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = float(prompt.size) / prompt_time
tic = time.perf_counter()
if token in eos_token_ids:
break
detokenizer.add_token(token)
if (n + 1) == max_tokens:
break
yield MTPGenerationResponse(
text=str(detokenizer.last_segment),
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=int(prompt.size),
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
finish_reason=None,
)
detokenizer.finalize()
yield MTPGenerationResponse(
text=str(detokenizer.last_segment),
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=int(prompt.size),
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
finish_reason="stop" if token in eos_token_ids else "length",
)

View File

@@ -1 +0,0 @@
"""Tests for MTP module."""

View File

@@ -1,412 +0,0 @@
"""Unit tests for MTP module components."""
import mlx.core as mx
import mlx.nn as nn
import pytest
from exo.worker.engines.mlx.mtp.module import (
MTP_LAYER_INDEX,
MTPModule,
MTPTransformerBlock,
extract_mtp_weights,
load_mtp_weights_into_module,
)
class MockModelArgs:
"""Mock ModelArgs for testing without importing deepseek_v3."""
def __init__(
self,
hidden_size: int = 256,
intermediate_size: int = 512,
num_attention_heads: int = 4,
num_key_value_heads: int = 4,
rms_norm_eps: float = 1e-6,
vocab_size: int = 1000,
q_lora_rank: int | None = None,
kv_lora_rank: int = 64,
qk_rope_head_dim: int = 16,
v_head_dim: int = 32,
qk_nope_head_dim: int = 32,
rope_theta: float = 10000.0,
rope_scaling: dict | None = None,
attention_bias: bool = False,
max_position_embeddings: int = 2048,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.vocab_size = vocab_size
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.max_position_embeddings = max_position_embeddings
class TestExtractMTPWeights:
"""Tests for extract_mtp_weights function."""
def test_extracts_layer_61_weights(self) -> None:
"""Should extract only layer 61 weights."""
weights = {
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
"model.layers.61.enorm.weight": mx.ones((10,)),
"model.layers.61.hnorm.weight": mx.ones((10,)) * 2,
"model.layers.61.eh_proj.weight": mx.ones((10, 20)),
"model.layers.62.self_attn.weight": mx.zeros((10, 10)),
"model.embed_tokens.weight": mx.zeros((100, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert len(mtp_weights) == 3
assert "enorm.weight" in mtp_weights
assert "hnorm.weight" in mtp_weights
assert "eh_proj.weight" in mtp_weights
# Check values are preserved
assert mx.allclose(mtp_weights["enorm.weight"], mx.ones((10,)))
assert mx.allclose(mtp_weights["hnorm.weight"], mx.ones((10,)) * 2)
def test_returns_empty_dict_when_no_layer_61(self) -> None:
"""Should return empty dict when layer 61 doesn't exist."""
weights = {
"model.layers.0.self_attn.weight": mx.zeros((10, 10)),
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert len(mtp_weights) == 0
def test_handles_nested_layer_61_weights(self) -> None:
"""Should handle nested weight paths like self_attn.q_proj.weight."""
weights = {
f"model.layers.{MTP_LAYER_INDEX}.self_attn.q_a_proj.weight": mx.zeros(
(10, 10)
),
f"model.layers.{MTP_LAYER_INDEX}.mlp.gate_proj.weight": mx.zeros((20, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert "self_attn.q_a_proj.weight" in mtp_weights
assert "mlp.gate_proj.weight" in mtp_weights
class TestMTPTransformerBlock:
"""Tests for MTPTransformerBlock."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64, intermediate_size=128, num_attention_heads=2
)
def test_forward_shape(self, config: MockModelArgs) -> None:
"""Forward pass should preserve input shape."""
# Skip if deepseek_v3 imports fail (CI without mlx_lm)
pytest.importorskip("mlx_lm.models.deepseek_v3")
block = MTPTransformerBlock(config) # type: ignore[arg-type]
x = mx.random.normal((1, 5, config.hidden_size))
output = block(x)
assert output.shape == x.shape
def test_forward_with_mask(self, config: MockModelArgs) -> None:
"""Forward pass should work with attention mask."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
block = MTPTransformerBlock(config) # type: ignore[arg-type]
x = mx.random.normal((1, 5, config.hidden_size))
# Create causal mask
mask = mx.triu(mx.full((5, 5), float("-inf")), k=1)
output = block(x, mask=mask)
assert output.shape == x.shape
class TestMTPModule:
"""Tests for MTPModule."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64,
intermediate_size=128,
num_attention_heads=2,
vocab_size=100,
)
@pytest.fixture
def shared_components(
self, config: MockModelArgs
) -> tuple[nn.Embedding, nn.Linear, nn.RMSNorm]:
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
return embedding, lm_head, output_norm
def test_initialization(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""MTPModule should initialize with correct components."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
assert mtp.hnorm is not None
assert mtp.enorm is not None
assert mtp.eh_proj is not None
assert mtp.transformer_block is not None
def test_forward_output_shapes(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""Forward pass should return correct output shapes."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
batch_size = 2
seq_len = 1
hidden_state = mx.random.normal((batch_size, seq_len, config.hidden_size))
draft_token = mx.array([[5], [10]]) # [batch, seq_len]
logits, new_hidden = mtp(hidden_state, draft_token)
assert logits.shape == (batch_size, seq_len, config.vocab_size)
assert new_hidden.shape == (batch_size, seq_len, config.hidden_size)
def test_shares_embedding_and_lm_head(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""MTPModule should use shared embedding and lm_head."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Verify they're the same objects
assert mtp._shared_embedding is embedding
assert mtp._shared_lm_head is lm_head
assert mtp._output_norm is output_norm
class TestLoadMTPWeights:
"""Tests for load_mtp_weights_into_module."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64,
intermediate_size=128,
num_attention_heads=2,
vocab_size=100,
)
def test_loads_norm_weights(self, config: MockModelArgs) -> None:
"""Should load enorm and hnorm weights."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Create test weights
test_enorm = mx.ones((config.hidden_size,)) * 3.0
test_hnorm = mx.ones((config.hidden_size,)) * 5.0
mtp_weights = {
"enorm.weight": test_enorm,
"hnorm.weight": test_hnorm,
}
load_mtp_weights_into_module(mtp, mtp_weights)
assert mx.allclose(mtp.enorm.weight, test_enorm)
assert mx.allclose(mtp.hnorm.weight, test_hnorm)
class TestSanitizePatch:
"""Tests for the sanitize patching logic."""
def test_patch_preserves_layer_61(self) -> None:
"""Patching sanitize should preserve layer 61 weights."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
# Get original sanitize behavior
original_sanitize = model_cls.sanitize
try:
# Apply patch
_patch_deepseek_sanitize_for_mtp()
# Note: we can't easily test the full sanitize without a real model
# This test verifies the patch is applied
assert model_cls.sanitize is not original_sanitize
finally:
_restore_deepseek_sanitize()
# Verify restore worked
assert model_cls.sanitize is original_sanitize
def test_restore_sanitize(self) -> None:
"""Restoring sanitize should return to original behavior."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
original_sanitize = model_cls.sanitize
_patch_deepseek_sanitize_for_mtp()
assert model_cls.sanitize is not original_sanitize
_restore_deepseek_sanitize()
assert model_cls.sanitize is original_sanitize
def test_double_patch_is_safe(self) -> None:
"""Calling patch twice should be safe (idempotent)."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
original_sanitize = model_cls.sanitize
try:
_patch_deepseek_sanitize_for_mtp()
patched_sanitize = model_cls.sanitize
# Patch again - should be no-op
_patch_deepseek_sanitize_for_mtp()
assert model_cls.sanitize is patched_sanitize
finally:
_restore_deepseek_sanitize()
assert model_cls.sanitize is original_sanitize
class TestModelIdDetection:
"""Tests for DeepSeek V3 model ID detection."""
def test_detects_deepseek_v3(self) -> None:
"""Should detect DeepSeek V3 model IDs."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-V3")
assert _might_be_deepseek_v3("deepseek-ai/deepseek-v3-base")
assert _might_be_deepseek_v3("mlx-community/DeepSeek-V3-4bit")
def test_detects_deepseek_r1(self) -> None:
"""Should detect DeepSeek R1 model IDs (also uses MTP)."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-R1")
assert _might_be_deepseek_v3("mlx-community/DeepSeek-R1-4bit")
def test_rejects_non_deepseek(self) -> None:
"""Should reject non-DeepSeek model IDs."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert not _might_be_deepseek_v3("meta-llama/Llama-3-70B")
assert not _might_be_deepseek_v3("mistralai/Mixtral-8x7B")
assert not _might_be_deepseek_v3("deepseek-ai/DeepSeek-V2") # V2, not V3
def test_case_insensitive(self) -> None:
"""Detection should be case insensitive."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("DEEPSEEK-AI/DEEPSEEK-V3")
assert _might_be_deepseek_v3("DeepSeek-AI/deepseek-v3")
class TestFlattenParams:
"""Tests for parameter flattening utility."""
def test_flattens_nested_dict(self) -> None:
"""Should flatten nested parameter dict."""
from exo.worker.engines.mlx.utils_mlx import _flatten_params
params = {
"model": {
"layers": {
"0": {
"weight": mx.zeros((10,)),
}
},
"embed": mx.ones((5,)),
}
}
flat = _flatten_params(params)
assert "model.layers.0.weight" in flat
assert "model.embed" in flat
assert mx.allclose(flat["model.layers.0.weight"], mx.zeros((10,)))
assert mx.allclose(flat["model.embed"], mx.ones((5,)))
def test_handles_flat_dict(self) -> None:
"""Should handle already-flat dict."""
from exo.worker.engines.mlx.utils_mlx import _flatten_params
params = {
"weight": mx.zeros((10,)),
"bias": mx.ones((10,)),
}
flat = _flatten_params(params)
assert flat == params

View File

@@ -1,253 +0,0 @@
"""Unit tests for MTP speculative decoding."""
import mlx.core as mx
import mlx.nn as nn
import pytest
from exo.worker.engines.mlx.mtp.speculative_decode import (
ModelWithHiddenStates,
maybe_quantize_kv_cache,
)
class MockModel(nn.Module):
"""Mock model for testing speculative decoding."""
def __init__(self, hidden_size: int = 64, vocab_size: int = 100) -> None:
super().__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
# Create simple model components
self.model = MockInnerModel(hidden_size)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
self._layers = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
def __call__(
self,
inputs: mx.array,
cache: list | None = None,
) -> mx.array:
hidden = self.model(inputs, cache)
return self.lm_head(hidden)
@property
def layers(self) -> list[nn.Module]:
return self._layers
class MockInnerModel(nn.Module):
"""Mock inner model (like DeepseekV3Model)."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.embed_tokens = nn.Embedding(100, hidden_size)
self.norm = nn.RMSNorm(hidden_size)
def __call__(
self,
inputs: mx.array,
cache: list | None = None,
) -> mx.array:
# Simple embedding + norm
embedded = self.embed_tokens(inputs)
return self.norm(embedded)
class TestModelWithHiddenStates:
"""Tests for ModelWithHiddenStates wrapper."""
@pytest.fixture
def mock_model(self) -> MockModel:
return MockModel(hidden_size=64, vocab_size=100)
def test_forward_returns_logits(self, mock_model: MockModel) -> None:
"""Standard forward should return logits."""
wrapped = ModelWithHiddenStates(mock_model)
inputs = mx.array([[1, 2, 3]])
logits = wrapped.forward(inputs)
assert logits.shape == (1, 3, mock_model.vocab_size)
def test_forward_with_hidden_returns_tuple(self, mock_model: MockModel) -> None:
"""Forward with hidden should return (logits, hidden)."""
wrapped = ModelWithHiddenStates(mock_model)
inputs = mx.array([[1, 2, 3]])
logits, hidden = wrapped.forward_with_hidden(inputs)
assert logits.shape == (1, 3, mock_model.vocab_size)
assert hidden.shape == (1, 3, mock_model.hidden_size)
def test_layers_property(self, mock_model: MockModel) -> None:
"""Should expose layers property from base model."""
wrapped = ModelWithHiddenStates(mock_model)
assert wrapped.layers == mock_model.layers
assert len(wrapped.layers) == 3
class TestMaybeQuantizeKVCache:
"""Tests for KV cache quantization."""
def test_no_quantization_when_bits_none(self) -> None:
"""Should not quantize when kv_bits is None."""
cache = [MockCache(offset=100)]
maybe_quantize_kv_cache(
cache,
quantized_kv_start=50,
kv_group_size=64,
kv_bits=None,
)
# Cache should be unchanged (MockCache starts with was_quantized=False)
assert not cache[0].was_quantized
def test_respects_quantized_kv_start(self) -> None:
"""Should only quantize caches past the start threshold."""
cache_below = MockCache(offset=30)
cache_above = MockCache(offset=100)
caches = [cache_below, cache_above]
maybe_quantize_kv_cache(
caches,
quantized_kv_start=50,
kv_group_size=64,
kv_bits=4,
)
# Only cache_above should be quantized
assert not getattr(cache_below, "was_quantized", False)
assert getattr(caches[1], "was_quantized", False)
class MockCache:
"""Mock KV cache for testing."""
def __init__(self, offset: int = 0) -> None:
self.offset = offset
self.was_quantized = False
def to_quantized(self, group_size: int, bits: int) -> "MockCache":
quantized = MockCache(self.offset)
quantized.was_quantized = True
return quantized
class TestSpeculativeDecodingLogic:
"""Tests for the core speculative decoding logic."""
def test_draft_acceptance_identical_tokens(self) -> None:
"""When draft matches verification, both should be accepted."""
# This tests the logic, not the full generator
draft_token = 42
verify_token = 42
accepted = draft_token == verify_token
assert accepted
def test_draft_rejection_different_tokens(self) -> None:
"""When draft differs from verification, draft should be rejected."""
draft_token = 42
verify_token = 99
accepted = draft_token == verify_token
assert not accepted
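# The two checks above exercise only single-token agreement. As a minimal
# sketch of the batched rule a speculative decoder typically applies (a
# hypothetical helper, not part of exo's API): accept the longest prefix of
# draft tokens that the verification pass reproduces, then fall back to the
# verifier's token at the first mismatch.
def _accepted_prefix_length(draft_tokens: list[int], verify_tokens: list[int]) -> int:
    """Count leading positions where draft and verification tokens agree."""
    accepted = 0
    for draft, verify in zip(draft_tokens, verify_tokens):
        if draft != verify:
            break
        accepted += 1
    return accepted
# Example: the first two draft tokens are kept, the third is replaced.
assert _accepted_prefix_length([42, 7, 9], [42, 7, 11]) == 2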
class TestMTPGenerationResponse:
"""Tests for MTPGenerationResponse dataclass."""
def test_response_creation(self) -> None:
"""Should create response with all fields."""
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
response = MTPGenerationResponse(
text="Hello",
token=42,
logprobs=mx.array([0.1, 0.2]),
from_draft=True,
prompt_tokens=10,
prompt_tps=100.0,
generation_tokens=5,
generation_tps=50.0,
peak_memory=1.5,
finish_reason=None,
)
assert response.text == "Hello"
assert response.token == 42
assert response.from_draft is True
assert response.finish_reason is None
def test_response_with_finish_reason(self) -> None:
"""Should handle finish_reason."""
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
response = MTPGenerationResponse(
text="",
token=0,
logprobs=mx.array([0.0]),
from_draft=False,
prompt_tokens=10,
prompt_tps=100.0,
generation_tokens=100,
generation_tps=50.0,
peak_memory=1.5,
finish_reason="length",
)
assert response.finish_reason == "length"
class TestIntegration:
"""Integration tests for the full MTP pipeline."""
def test_mtp_module_with_mock_model(self) -> None:
"""Test MTP module can be created and run with mock components."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
from exo.worker.engines.mlx.mtp.module import MTPModule
# Create mock config
class MockConfig:
hidden_size = 64
intermediate_size = 128
num_attention_heads = 2
num_key_value_heads = 2
rms_norm_eps = 1e-6
q_lora_rank = None
kv_lora_rank = 32
qk_rope_head_dim = 8
v_head_dim = 16
qk_nope_head_dim = 16
rope_theta = 10000.0
rope_scaling = None
attention_bias = False
max_position_embeddings = 2048
config = MockConfig()
embedding = nn.Embedding(100, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, 100, bias=False)
output_norm = nn.RMSNorm(config.hidden_size)
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Run forward pass
hidden = mx.random.normal((1, 1, config.hidden_size))
token = mx.array([[5]])
logits, new_hidden = mtp(hidden, token)
assert logits.shape == (1, 1, 100)
assert new_hidden.shape == (1, 1, config.hidden_size)
# Verify outputs are valid (not NaN)
assert not mx.any(mx.isnan(logits))
assert not mx.any(mx.isnan(new_hidden))

View File

@@ -28,7 +28,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
MTP_ENABLED,
TRUST_REMOTE_CODE,
)
@@ -70,67 +69,6 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# MTP (Multi-Token Prediction) support for DeepSeek V3
MTP_LAYER_INDEX = 61
_original_deepseek_sanitize: Callable[..., dict[str, Any]] | None = None
def _is_deepseek_v3_model(model: nn.Module) -> bool:
"""Check if the model is DeepSeek V3."""
return hasattr(model, "model") and isinstance(model.model, DeepseekV3Model)
def _patch_deepseek_sanitize_for_mtp() -> None:
"""Patch DeepSeek V3 Model.sanitize to preserve MTP layer weights.
The original sanitize() method filters out layer 61 (MTP layer) weights.
This patch keeps them so we can extract and use the MTP module.
"""
global _original_deepseek_sanitize
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
if _original_deepseek_sanitize is not None:
# Already patched
return
_original_deepseek_sanitize = DeepSeekV3Model.sanitize
def sanitize_with_mtp(
self: DeepSeekV3Model, weights: dict[str, Any]
) -> dict[str, Any]:
"""Modified sanitize that keeps MTP layer weights."""
# First, call the original sanitize to handle all the weight transformations
# (dequantization, expert stacking, etc.)
if _original_deepseek_sanitize is None:
raise RuntimeError(
"_original_deepseek_sanitize is None - patch not applied correctly"
)
original_result: dict[str, Any] = _original_deepseek_sanitize(self, weights)
# Re-add the MTP layer weights that were filtered out
mtp_weights = {
k: v
for k, v in weights.items()
if k.startswith(f"model.layers.{MTP_LAYER_INDEX}")
}
return {**original_result, **mtp_weights}
DeepSeekV3Model.sanitize = sanitize_with_mtp
def _restore_deepseek_sanitize() -> None:
"""Restore the original DeepSeek V3 sanitize method."""
global _original_deepseek_sanitize
if _original_deepseek_sanitize is None:
return
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
DeepSeekV3Model.sanitize = _original_deepseek_sanitize
_original_deepseek_sanitize = None
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
@@ -295,164 +233,31 @@ def load_mlx_items(
group: Group | None,
on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
"""Load MLX model and tokenizer.
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
Returns:
Tuple of (model, tokenizer)
"""
model_id = bound_instance.bound_shard.model_meta.model_id
mtp_module = None
# Patch sanitize for MTP if this might be DeepSeek V3
should_try_mtp = MTP_ENABLED and _might_be_deepseek_v3(model_id)
if should_try_mtp:
logger.info("Patching DeepSeek V3 sanitize for MTP weight preservation")
_patch_deepseek_sanitize_for_mtp()
try:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=not should_try_mtp)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
# Extract MTP module if available
if should_try_mtp and _is_deepseek_v3_model(model):
mtp_module = _extract_mtp_module(model)
if mtp_module is not None:
logger.info("Successfully extracted MTP module from DeepSeek V3")
finally:
# Restore original sanitize
if should_try_mtp:
_restore_deepseek_sanitize()
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
# Store MTP module on the model for later access
if mtp_module is not None:
model.mtp_module = mtp_module # noqa: B010
return cast(Model, model), tokenizer
def _might_be_deepseek_v3(model_id: str) -> bool:
"""Check if model ID suggests this might be DeepSeek V3."""
model_id_lower = model_id.lower()
return "deepseek" in model_id_lower and (
"v3" in model_id_lower or "r1" in model_id_lower
)
def _flatten_params(
params: dict[str, Any],
prefix: str = "",
) -> dict[str, mx.array]:
"""Flatten nested parameter dict to flat dict with dot-separated keys."""
result: dict[str, mx.array] = {}
for key, value in params.items():
full_key = f"{prefix}.{key}" if prefix else key
if isinstance(value, mx.array):
result[full_key] = value
elif isinstance(value, dict):
result.update(_flatten_params(value, full_key))
return result
def _extract_mtp_module(model: nn.Module) -> Any | None:
"""Extract MTP module from a loaded DeepSeek V3 model.
The MTP weights are stored in model.model.layers at index 61 (if preserved).
This function extracts them and creates an MTPModule.
Returns:
MTPModule if MTP weights were found and extracted, None otherwise.
"""
from exo.worker.engines.mlx.mtp.module import (
MTPModule,
extract_mtp_weights,
load_mtp_weights_into_module,
)
try:
# Check if the model has the MTP layer
inner_model = getattr(model, "model", None)
if inner_model is None or not hasattr(inner_model, "layers"):
logger.debug("Model doesn't have expected structure for MTP extraction")
return None
layers: list[nn.Module] = inner_model.layers # type: ignore[assignment]
if len(layers) <= MTP_LAYER_INDEX:
logger.debug(
f"Model has {len(layers)} layers, MTP layer {MTP_LAYER_INDEX} not found"
)
return None
# Get model config
config = getattr(model, "args", None)
if config is None:
logger.debug("Could not get model config for MTP module")
return None
# Create MTP module with shared weights
embed_tokens = getattr(inner_model, "embed_tokens", None)
lm_head = getattr(model, "lm_head", None)
norm = getattr(inner_model, "norm", None)
if embed_tokens is None or lm_head is None or norm is None:
logger.debug("Could not get required model components for MTP")
return None
mtp_module = MTPModule(
config=config,
shared_embedding=embed_tokens,
shared_lm_head=lm_head,
output_norm=norm,
)
# Extract MTP layer weights from the model's parameters
# The weights should be at model.model.layers.61.*
# model.parameters() returns a nested dict, we need to flatten it
raw_params: dict[str, Any] = dict(model.parameters()) # type: ignore[arg-type]
model_weights = _flatten_params(raw_params)
mtp_weights = extract_mtp_weights(model_weights)
if not mtp_weights:
logger.debug("No MTP weights found in model parameters")
return None
# Load weights into MTP module
load_mtp_weights_into_module(mtp_module, mtp_weights)
# Remove MTP layer from main model to avoid double computation
# Create new layers list without the MTP layer
new_layers = [layer for i, layer in enumerate(layers) if i != MTP_LAYER_INDEX]
inner_model.layers = new_layers # noqa: B010
logger.info(
f"Extracted MTP module, main model now has {len(new_layers)} layers"
)
return mtp_module
except Exception as e:
logger.warning(f"Failed to extract MTP module: {e}")
return None
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,

View File

@@ -16,8 +16,10 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeIdentityMeasured,
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeNetworkMeasured,
NodeSystemMeasured,
TaskCreated,
TaskStatusUpdated,
TopologyEdgeCreated,
@@ -25,7 +27,11 @@ from exo.shared.types.events import (
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NetworkInterfaceInfo,
SystemPerformanceProfile,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
@@ -51,7 +57,13 @@ from exo.worker.download.download_utils import (
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
from exo.worker.utils import (
IdentityMetrics,
start_polling_identity_metrics,
start_polling_memory_metrics,
start_polling_network_metrics,
start_polling_system_metrics,
)
from exo.worker.utils.net_profile import check_reachable
@@ -98,37 +110,51 @@ class Worker:
async def run(self):
logger.info("Starting Worker")
# TODO: CLEANUP HEADER
async def resource_monitor_callback(
node_performance_profile: NodePerformanceProfile,
) -> None:
async def identity_callback(identity: IdentityMetrics) -> None:
await self.event_sender.send(
NodePerformanceMeasured(
NodeIdentityMeasured(
node_id=self.node_id,
node_profile=node_performance_profile,
model_id=identity.model_id,
chip_id=identity.chip_id,
friendly_name=identity.friendly_name,
when=str(datetime.now(tz=timezone.utc)),
),
)
async def memory_monitor_callback(
memory_profile: MemoryPerformanceProfile,
) -> None:
async def system_callback(system: SystemPerformanceProfile) -> None:
await self.event_sender.send(
NodeSystemMeasured(
node_id=self.node_id,
system=system,
when=str(datetime.now(tz=timezone.utc)),
),
)
async def network_callback(interfaces: list[NetworkInterfaceInfo]) -> None:
await self.event_sender.send(
NodeNetworkMeasured(
node_id=self.node_id,
network_interfaces=interfaces,
when=str(datetime.now(tz=timezone.utc)),
),
)
async def memory_callback(memory: MemoryPerformanceProfile) -> None:
await self.event_sender.send(
NodeMemoryMeasured(
node_id=self.node_id,
memory=memory_profile,
memory=memory,
when=str(datetime.now(tz=timezone.utc)),
)
)
# END CLEANUP
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self.plan_step)
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(start_polling_identity_metrics, identity_callback)
tg.start_soon(start_polling_system_metrics, system_callback)
tg.start_soon(start_polling_network_metrics, network_callback)
tg.start_soon(start_polling_memory_metrics, memory_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)

View File

@@ -1,6 +1,15 @@
from .profile import start_polling_memory_metrics, start_polling_node_metrics
from .profile import (
IdentityMetrics,
start_polling_identity_metrics,
start_polling_memory_metrics,
start_polling_network_metrics,
start_polling_system_metrics,
)
__all__ = [
"start_polling_node_metrics",
"IdentityMetrics",
"start_polling_identity_metrics",
"start_polling_memory_metrics",
"start_polling_network_metrics",
"start_polling_system_metrics",
]

View File

@@ -1,6 +1,7 @@
import asyncio
import os
import platform
from dataclasses import dataclass
from typing import Any, Callable, Coroutine
import anyio
@@ -9,7 +10,7 @@ from loguru import logger
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
NetworkInterfaceInfo,
SystemPerformanceProfile,
)
@@ -27,6 +28,13 @@ from .system_info import (
)
@dataclass(frozen=True)
class IdentityMetrics:
model_id: str
chip_id: str
friendly_name: str
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
@@ -67,48 +75,73 @@ async def start_polling_memory_metrics(
await anyio.sleep(poll_interval_s)
async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
async def start_polling_identity_metrics(
callback: Callable[[IdentityMetrics], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 30.0,
) -> None:
"""Continuously poll and emit identity metrics at 30s intervals."""
while True:
try:
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
await callback(
IdentityMetrics(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
)
)
except Exception as e:
logger.opt(exception=e).error("Failed to emit identity metrics")
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_system_metrics(
callback: Callable[[SystemPerformanceProfile], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 1.0,
) -> None:
"""Continuously poll and emit system metrics (GPU, temp, power) at 1s intervals."""
while True:
try:
metrics = await get_metrics_async()
if metrics is None:
return
network_interfaces = get_network_interfaces()
# these awaits could be gathered concurrently, but in practice the values should be cached
model_id, chip_id = await get_model_and_chip()
friendly_name = await get_friendly_name()
# take the memory profile last so the reading is fresh and does not clash with the dedicated memory polling loop
memory_profile = get_memory_profile()
await callback(
NodePerformanceProfile(
model_id=model_id,
chip_id=chip_id,
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
),
SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,
sys_power=metrics.sys_power,
pcpu_usage=metrics.pcpu_usage[1],
ecpu_usage=metrics.ecpu_usage[1],
ane_power=metrics.ane_power,
)
)
except asyncio.TimeoutError:
logger.warning(
"[resource_monitor] Operation timed out after 30s, skipping this cycle."
"[system_monitor] Operation timed out after 30s, skipping this cycle."
)
except MacMonError as e:
logger.opt(exception=e).error("Resource Monitor encountered error")
logger.opt(exception=e).error("System Monitor encountered error")
return
finally:
await anyio.sleep(poll_interval_s)
async def start_polling_network_metrics(
callback: Callable[[list[NetworkInterfaceInfo]], Coroutine[Any, Any, None]],
*,
poll_interval_s: float = 30.0,
) -> None:
"""Continuously poll and emit network interface info at 30s intervals."""
while True:
try:
network_interfaces = get_network_interfaces()
await callback(network_interfaces)
except Exception as e:
logger.opt(exception=e).error("Network Monitor encountered error")
finally:
await anyio.sleep(poll_interval_s)
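# A minimal wiring sketch for the split pollers above, assuming the import
# paths shown in this diff; the print callbacks are placeholders.
import anyio

from exo.shared.types.profiling import (
    MemoryPerformanceProfile,
    NetworkInterfaceInfo,
    SystemPerformanceProfile,
)
from exo.worker.utils import (
    IdentityMetrics,
    start_polling_identity_metrics,
    start_polling_memory_metrics,
    start_polling_network_metrics,
    start_polling_system_metrics,
)


async def poll_all_metrics() -> None:
    async def on_identity(identity: IdentityMetrics) -> None:
        print("identity:", identity.friendly_name)

    async def on_system(system: SystemPerformanceProfile) -> None:
        print("gpu usage:", system.gpu_usage)

    async def on_network(interfaces: list[NetworkInterfaceInfo]) -> None:
        print("network interfaces:", len(interfaces))

    async def on_memory(memory: MemoryPerformanceProfile) -> None:
        print("memory:", memory)

    # Each poller loops forever at its own interval, mirroring how Worker.run
    # passes its callbacks to tg.start_soon.
    async with anyio.create_task_group() as tg:
        tg.start_soon(start_polling_identity_metrics, on_identity)
        tg.start_soon(start_polling_system_metrics, on_system)
        tg.start_soon(start_polling_network_metrics, on_network)
        tg.start_soon(start_polling_memory_metrics, on_memory)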