Compare commits

..

2 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
67deec88ca Test stuff 2026-01-20 11:30:35 +00:00
Ryuichi Leo Takashige
209d618d5a Load model layers individually but eagerly 2026-01-19 22:00:31 +00:00
47 changed files with 1739 additions and 1621 deletions

View File

@@ -863,6 +863,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -902,6 +903,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,6 +1520,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1527,6 +1530,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1939,6 +1943,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2646,6 +2651,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2833,6 +2839,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2977,6 +2984,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2998,6 +3006,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -71,46 +71,44 @@ export interface Instance {
};
}
// Granular node state types from the new state structure
interface RawNodeIdentity {
interface RawNodeProfile {
modelId?: string;
chipId?: string;
friendlyName?: string;
networkInterfaces?: Array<{
name?: string;
ipAddress?: string;
addresses?: Array<{ address?: string } | string>;
ipv4?: string;
ipv6?: string;
ipAddresses?: string[];
ips?: string[];
}>;
memory?: {
ramTotal?: { inBytes: number };
ramAvailable?: { inBytes: number };
swapTotal?: { inBytes: number };
swapAvailable?: { inBytes: number };
};
system?: {
gpuUsage?: number;
temp?: number;
sysPower?: number;
};
}
interface RawMemoryUsage {
ramTotal?: { inBytes: number };
ramAvailable?: { inBytes: number };
swapTotal?: { inBytes: number };
swapAvailable?: { inBytes: number };
}
interface RawSystemPerformanceProfile {
gpuUsage?: number;
temp?: number;
sysPower?: number;
pcpuUsage?: number;
ecpuUsage?: number;
}
interface RawNetworkInterfaceInfo {
name?: string;
ipAddress?: string;
addresses?: Array<{ address?: string } | string>;
ipv4?: string;
ipv6?: string;
ipAddresses?: string[];
ips?: string[];
}
interface RawNodeNetworkInfo {
interfaces?: RawNetworkInterfaceInfo[];
interface RawTopologyNode {
nodeId: string;
nodeProfile?: RawNodeProfile;
}
// New connection edge types from Python SocketConnection/RDMAConnection
interface RawSocketConnection {
sinkMultiaddr?: {
address?: string;
// Multiaddr uses snake_case (no camelCase alias)
ip_address?: string;
ipAddress?: string; // fallback in case it changes
address_type?: string;
port?: number;
};
@@ -127,10 +125,14 @@ type RawConnectionEdge = RawSocketConnection | RawRDMAConnection;
type RawConnectionsMap = Record<string, Record<string, RawConnectionEdge[]>>;
interface RawTopology {
nodes: string[];
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
// New nested mapping format
connections?: RawConnectionsMap;
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
export interface DownloadProgress {
totalBytes: number;
downloadedBytes: number;
@@ -185,11 +187,7 @@ interface RawStateResponse {
>;
runners?: Record<string, unknown>;
downloads?: Record<string, unknown[]>;
// New granular node state fields
nodeIdentities?: Record<string, RawNodeIdentity>;
nodeMemory?: Record<string, RawMemoryUsage>;
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
nodeProfiles?: RawNodeProfiles;
}
export interface MessageAttachment {
@@ -224,69 +222,65 @@ export interface Conversation {
const STORAGE_KEY = "exo-conversations";
interface GranularNodeState {
nodeIdentities?: Record<string, RawNodeIdentity>;
nodeMemory?: Record<string, RawMemoryUsage>;
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
}
function transformNetworkInterface(iface: RawNetworkInterfaceInfo): {
name?: string;
addresses: string[];
} {
const addresses: string[] = [];
if (iface.ipAddress && typeof iface.ipAddress === "string") {
addresses.push(iface.ipAddress);
}
if (Array.isArray(iface.addresses)) {
for (const addr of iface.addresses) {
if (typeof addr === "string") addresses.push(addr);
else if (addr && typeof addr === "object" && addr.address)
addresses.push(addr.address);
}
}
if (Array.isArray(iface.ipAddresses)) {
addresses.push(
...iface.ipAddresses.filter((a): a is string => typeof a === "string"),
);
}
if (Array.isArray(iface.ips)) {
addresses.push(
...iface.ips.filter((a): a is string => typeof a === "string"),
);
}
if (iface.ipv4 && typeof iface.ipv4 === "string") addresses.push(iface.ipv4);
if (iface.ipv6 && typeof iface.ipv6 === "string") addresses.push(iface.ipv6);
return {
name: iface.name,
addresses: Array.from(new Set(addresses)),
};
}
function transformTopology(
raw: RawTopology,
granularState: GranularNodeState,
profiles?: RawNodeProfiles,
): TopologyData {
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
for (const nodeId of raw.nodes || []) {
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === "string" ? node : node.nodeId;
if (!nodeId) continue;
// Get data from granular state mappings
const identity = granularState.nodeIdentities?.[nodeId];
const memory = granularState.nodeMemory?.[nodeId];
const system = granularState.nodeSystem?.[nodeId];
const network = granularState.nodeNetwork?.[nodeId];
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode =
typeof node === "object" ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
const ramTotal = memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = memory?.ramAvailable?.inBytes ?? 0;
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
const rawInterfaces = network?.interfaces || [];
const networkInterfaces = rawInterfaces.map(transformNetworkInterface);
const networkInterfaces = (profile?.networkInterfaces || []).map(
(iface) => {
const addresses: string[] = [];
if (iface.ipAddress && typeof iface.ipAddress === "string") {
addresses.push(iface.ipAddress);
}
if (Array.isArray(iface.addresses)) {
for (const addr of iface.addresses) {
if (typeof addr === "string") addresses.push(addr);
else if (addr && typeof addr === "object" && addr.address)
addresses.push(addr.address);
}
}
if (Array.isArray(iface.ipAddresses)) {
addresses.push(
...iface.ipAddresses.filter(
(a): a is string => typeof a === "string",
),
);
}
if (Array.isArray(iface.ips)) {
addresses.push(
...iface.ips.filter((a): a is string => typeof a === "string"),
);
}
if (iface.ipv4 && typeof iface.ipv4 === "string")
addresses.push(iface.ipv4);
if (iface.ipv6 && typeof iface.ipv6 === "string")
addresses.push(iface.ipv6);
return {
name: iface.name,
addresses: Array.from(new Set(addresses)),
};
},
);
const ipToInterface: Record<string, string> = {};
for (const iface of networkInterfaces) {
@@ -297,8 +291,8 @@ function transformTopology(
nodes[nodeId] = {
system_info: {
model_id: identity?.modelId ?? "Unknown",
chip: identity?.chipId,
model_id: profile?.modelId ?? "Unknown",
chip: profile?.chipId,
memory: ramTotal,
},
network_interfaces: networkInterfaces,
@@ -309,15 +303,17 @@ function transformTopology(
ram_total: ramTotal,
},
temp:
system?.temp !== undefined
? { gpu_temp_avg: system.temp }
profile?.system?.temp !== undefined
? { gpu_temp_avg: profile.system.temp }
: undefined,
gpu_usage:
system?.gpuUsage !== undefined ? [0, system.gpuUsage] : undefined,
sys_power: system?.sysPower,
profile?.system?.gpuUsage !== undefined
? [0, profile.system.gpuUsage]
: undefined,
sys_power: profile?.system?.sysPower,
},
last_macmon_update: Date.now() / 1000,
friendly_name: identity?.friendlyName,
friendly_name: profile?.friendlyName,
};
}
@@ -329,15 +325,19 @@ function transformTopology(
for (const [sink, edgeList] of Object.entries(sinks)) {
if (!Array.isArray(edgeList)) continue;
for (const edge of edgeList) {
// Extract IP from SocketConnection (uses snake_case: ip_address)
let sendBackIp: string | undefined;
if (edge && typeof edge === "object" && "sinkMultiaddr" in edge) {
const multiaddr = edge.sinkMultiaddr;
if (multiaddr) {
// Try both snake_case (actual) and camelCase (in case it changes)
sendBackIp =
multiaddr.ip_address ||
multiaddr.ipAddress ||
extractIpFromMultiaddr(multiaddr.address);
}
}
// RDMAConnection (sourceRdmaIface/sinkRdmaIface) has no IP - edge just shows connection exists
if (nodes[source] && nodes[sink] && source !== sink) {
edges.push({ source, target: sink, sendBackIp });
@@ -898,12 +898,7 @@ class AppStore {
const data: RawStateResponse = await response.json();
if (data.topology) {
this.topologyData = transformTopology(data.topology, {
nodeIdentities: data.nodeIdentities,
nodeMemory: data.nodeMemory,
nodeSystem: data.nodeSystem,
nodeNetwork: data.nodeNetwork,
});
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
}
if (data.instances) {
this.instances = data.instances;

View File

@@ -434,8 +434,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;

View File

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_card ?? shardData.modelCard;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;

View File

@@ -17,8 +17,8 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx==0.30.1; sys_platform == 'darwin'",
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",

View File

@@ -19,8 +19,8 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -59,6 +59,7 @@ from exo.shared.types.events import (
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -86,12 +87,12 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: str) -> ModelCard:
async def resolve_model_meta(model_id: str) -> ModelMetadata:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
return model_card.metadata
else:
return await get_model_card(model_id)
return await get_model_meta(model_id)
class API:
@@ -196,7 +197,7 @@ class API:
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_meta=await resolve_model_meta(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -206,15 +207,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=command.model_card,
model_meta=command.model_meta,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -231,7 +232,7 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=model_card,
model_meta=model_meta,
)
async def get_placement(
@@ -241,18 +242,17 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_meta = await resolve_model_meta(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -279,7 +279,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -297,33 +297,33 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for model_card in cards:
for card in cards:
model_meta = card.metadata
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
node_profiles=self.state.node_profiles,
topology=self.state.topology,
current_instances=self.state.instances,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -334,17 +334,17 @@ class API:
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -353,7 +353,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_card.storage_size.in_bytes
total_bytes = model_meta.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -361,14 +361,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -376,7 +376,7 @@ class API:
error=None,
)
)
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -551,8 +551,8 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -578,8 +578,8 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -602,8 +602,8 @@ class API:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for memory in self.state.node_memory.values():
total_available += memory.ram_available
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
return total_available
@@ -612,13 +612,13 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.model_id,
id=card.short_id,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
)
for card in MODEL_CARDS.values()
]

View File

@@ -159,8 +159,7 @@ class Master:
command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement

View File

@@ -14,7 +14,6 @@ from exo.master.placement_utils import (
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
@@ -24,7 +23,8 @@ from exo.shared.types.commands import (
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -54,33 +54,32 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_card.storage_size
candidate_cycles, node_profiles, command.model_meta.storage_size
)
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_card.supports_tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_card.hidden_size % len(cycle) == 0
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -105,13 +104,13 @@ def place_instance(
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node_memory[node_id].ram_available for node_id in cycle),
(node_profiles[node_id].memory.ram_available for node_id in cycle),
start=Memory(),
),
)
shard_assignments = get_shard_assignments(
command.model_card, selected_cycle, command.sharding, node_memory
command.model_meta, selected_cycle, command.sharding, node_profiles
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
@@ -137,7 +136,7 @@ def place_instance(
coordinator=selected_cycle.node_ids[0],
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
node_network=node_network,
node_profiles=node_profiles,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
@@ -151,7 +150,7 @@ def place_instance(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
node_network=node_network,
node_profiles=node_profiles,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,

View File

@@ -2,11 +2,11 @@ from collections.abc import Generator, Mapping
from loguru import logger
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
@@ -19,16 +19,16 @@ from exo.shared.types.worker.shards import (
def filter_cycles_by_memory(
cycles: list[Cycle],
node_memory: Mapping[NodeId, MemoryUsage],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
required_memory: Memory,
) -> list[Cycle]:
filtered_cycles: list[Cycle] = []
for cycle in cycles:
if not all(node in node_memory for node in cycle):
if not all(node in node_profiles for node in cycle):
continue
total_mem = sum(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if total_mem >= required_memory:
@@ -75,21 +75,22 @@ def allocate_layers_proportionally(
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
model_meta: ModelMetadata,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
):
if not cycle.node_ids:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_card.n_layers
total_layers = model_meta.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
@@ -97,18 +98,18 @@ def get_shard_assignments_for_pipeline_parallel(
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
node_profiles[node_id].memory.ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_memory[node_id].ram_available.in_bytes
available_memory = node_profiles[node_id].memory.ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
@@ -123,7 +124,7 @@ def get_shard_assignments_for_pipeline_parallel(
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -136,7 +137,7 @@ def get_shard_assignments_for_pipeline_parallel(
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -145,17 +146,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_card: ModelCard,
model_meta: ModelMetadata,
cycle: Cycle,
):
total_layers = model_card.n_layers
total_layers = model_meta.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node_id in enumerate(cycle):
shard = TensorShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -169,7 +170,7 @@ def get_shard_assignments_for_tensor_parallel(
node_to_runner[node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -178,21 +179,21 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_card: ModelCard,
model_meta: ModelMetadata,
cycle: Cycle,
sharding: Sharding,
node_memory: Mapping[NodeId, MemoryUsage],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> ShardAssignments:
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_card=model_card,
model_meta=model_meta,
cycle=cycle,
node_memory=node_memory,
node_profiles=node_profiles,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_card=model_card,
model_meta=model_meta,
cycle=cycle,
)
@@ -287,10 +288,10 @@ def _find_connection_ip(
def _find_interface_name_for_ip(
ip_address: str, node_network: NodeNetworkInfo
ip_address: str, node_profile: NodePerformanceProfile
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
for interface in node_network.interfaces:
for interface in node_profile.network_interfaces:
if interface.ip_address == ip_address:
return interface.name
@@ -301,7 +302,7 @@ def _find_ip_prioritised(
node_id: NodeId,
other_node_id: NodeId,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
@@ -315,9 +316,7 @@ def _find_ip_prioritised(
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {
_find_interface_name_for_ip(
ip, node_network.get(other_node_id, NodeNetworkInfo())
): ip
_find_interface_name_for_ip(ip, node_profiles[other_node_id]): ip
for ip, _ in ips
}
@@ -346,7 +345,7 @@ def get_mlx_ring_hosts_by_node(
selected_cycle: Cycle,
cycle_digraph: Topology,
ephemeral_port: int,
node_network: Mapping[NodeId, NodeNetworkInfo],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
@@ -378,7 +377,7 @@ def get_mlx_ring_hosts_by_node(
continue
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_network
node_id, other_node_id, cycle_digraph, node_profiles
)
if connection_ip is None:
logger.warning(
@@ -399,7 +398,7 @@ def get_mlx_jaccl_coordinators(
coordinator: NodeId,
coordinator_port: int,
cycle_digraph: Topology,
node_network: Mapping[NodeId, NodeNetworkInfo],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
@@ -412,7 +411,7 @@ def get_mlx_jaccl_coordinators(
if n == coordinator:
return "0.0.0.0"
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_profiles)
if ip is not None:
return ip

View File

@@ -2,26 +2,28 @@ from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
NodeNetworkInfo,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
def create_node_memory(memory: int) -> MemoryUsage:
return MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
)
def create_node_network() -> NodeNetworkInfo:
return NodeNetworkInfo(
interfaces=[
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[
NetworkInterfaceInfo(name="en0", ip_address=f"169.254.0.{i}")
for i in range(10)
]
],
system=SystemPerformanceProfile(),
)

View File

@@ -7,7 +7,6 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -24,6 +23,7 @@ from exo.shared.types.events import (
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
)
@@ -73,8 +73,8 @@ async def test_master():
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
# inject a NodePerformanceProfile event
logger.info("inject a NodePerformanceProfile event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
@@ -99,7 +99,7 @@ async def test_master():
logger.info("wait for initial topology event")
while len(list(master.state.topology.list_nodes())) == 0:
await anyio.sleep(0.001)
while len(master.state.node_memory) == 0:
while len(master.state.node_profiles) == 0:
await anyio.sleep(0.001)
logger.info("inject a CreateInstance Command")
@@ -109,8 +109,9 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -166,8 +167,9 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -5,19 +5,18 @@ from exo.master.placement import (
place_instance,
)
from exo.master.tests.conftest import (
create_node_memory,
create_node_network,
create_node_profile,
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
Instance,
@@ -43,20 +42,21 @@ def instance() -> Instance:
@pytest.fixture
def model_card() -> ModelCard:
return ModelCard(
def model_meta() -> ModelMetadata:
return ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -75,16 +75,16 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
model_card: ModelCard,
model_meta: ModelMetadata,
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_card)
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
@@ -109,15 +109,10 @@ def test_get_instance_placements_create_instance(
source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)
)
node_memory = {
node_id_a: create_node_memory(available_memory[0]),
node_id_b: create_node_memory(available_memory[1]),
node_id_c: create_node_memory(available_memory[2]),
}
node_network = {
node_id_a: create_node_network(),
node_id_b: create_node_network(),
node_id_c: create_node_network(),
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
@@ -130,13 +125,13 @@ def test_get_instance_placements_create_instance(
topology.add_connection(conn_b_a)
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(cic, topology, {}, profiles)
# assert
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_card.model_id
assert instance.shard_assignments.model_id == model_meta.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -160,18 +155,18 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
profiles = {node_id: create_node_profile(1000 * 1024)}
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -186,18 +181,18 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
node_memory = {node_id: create_node_memory(1001 * 1024)}
node_network = {node_id: create_node_network()}
profiles = {node_id: create_node_profile(1001 * 1024)}
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(cic, topology, {}, profiles)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -212,12 +207,12 @@ def test_get_instance_placements_one_node_not_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
profiles = {node_id: create_node_profile(1000 * 1024)}
cic = place_instance_command(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -225,7 +220,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {}, node_memory, node_network)
place_instance(cic, topology, {}, profiles)
def test_get_transition_events_no_change(instance: Instance):
@@ -271,31 +266,23 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_selects_leaf_nodes(
model_card: ModelCard,
model_meta: ModelMetadata,
):
# arrange
topology = Topology()
# Model requires more than any single node but fits within a 3-node cycle
model_card.storage_size.in_bytes = 1500
model_card.n_layers = 12
model_meta.storage_size = Memory.from_bytes(1000)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
node_id_d = NodeId()
node_memory = {
node_id_a: create_node_memory(500),
node_id_b: create_node_memory(600),
node_id_c: create_node_memory(600),
node_id_d: create_node_memory(500),
}
node_network = {
node_id_a: create_node_network(),
node_id_b: create_node_network(),
node_id_c: create_node_network(),
node_id_d: create_node_network(),
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
}
topology.add_node(node_id_a)
@@ -323,10 +310,10 @@ def test_placement_selects_leaf_nodes(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
)
cic = place_instance_command(model_card=model_card)
cic = place_instance_command(model_meta=model_meta)
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(cic, topology, {}, profiles)
# assert
assert len(placements) == 1
@@ -342,21 +329,21 @@ def test_placement_selects_leaf_nodes(
def test_tensor_rdma_backend_connectivity_matrix(
model_card: ModelCard,
model_meta: ModelMetadata,
):
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
node_c = NodeId()
node_memory = {
node_a: create_node_memory(500),
node_b: create_node_memory(500),
node_c: create_node_memory(500),
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
}
ethernet_interface = NetworkInterfaceInfo(
@@ -367,11 +354,9 @@ def test_tensor_rdma_backend_connectivity_matrix(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
)
node_network = {
node_a: NodeNetworkInfo(interfaces=[ethernet_interface]),
node_b: NodeNetworkInfo(interfaces=[ethernet_interface]),
node_c: NodeNetworkInfo(interfaces=[ethernet_interface]),
}
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
topology.add_node(node_a)
topology.add_node(node_b)
@@ -409,12 +394,12 @@ def test_tensor_rdma_backend_connectivity_matrix(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
min_nodes=1,
)
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
placements = place_instance(cic, topology, {}, profiles)
# assert
assert len(placements) == 1

View File

@@ -1,3 +1,5 @@
from copy import copy
import pytest
from exo.master.placement_utils import (
@@ -8,17 +10,16 @@ from exo.master.placement_utils import (
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
create_node_memory,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.master.tests.conftest import create_node_profile, create_socket_connection
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
NodeNetworkInfo,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.shards import Sharding
@@ -35,9 +36,9 @@ def test_filter_cycles_by_memory():
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1_mem = create_node_memory(1000 * 1024)
node2_mem = create_node_memory(1000 * 1024)
node_memory = {node1_id: node1_mem, node2_id: node2_mem}
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
topology = Topology()
topology.add_node(node1_id)
@@ -50,7 +51,9 @@ def test_filter_cycles_by_memory():
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_bytes(1))
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
# assert
assert len(filtered_cycles) == 1
@@ -69,9 +72,9 @@ def test_filter_cycles_by_insufficient_memory():
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1_mem = create_node_memory(1000 * 1024)
node2_mem = create_node_memory(1000 * 1024)
node_memory = {node1_id: node1_mem, node2_id: node2_mem}
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
topology = Topology()
topology.add_node(node1_id)
@@ -81,7 +84,7 @@ def test_filter_cycles_by_insufficient_memory():
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), node_memory, Memory.from_kb(2001)
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
)
# assert
@@ -106,13 +109,13 @@ def test_filter_multiple_cycles_by_memory():
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
node_a_mem = create_node_memory(500 * 1024)
node_b_mem = create_node_memory(500 * 1024)
node_c_mem = create_node_memory(1000 * 1024)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
topology = Topology()
@@ -127,7 +130,9 @@ def test_filter_multiple_cycles_by_memory():
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_kb(1500))
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
# assert
assert len(filtered_cycles) == 1
@@ -223,17 +228,18 @@ def test_get_shard_assignments(
topology.add_connection(connection3)
topology.add_connection(connection4)
node_a_mem = create_node_memory(available_memory[0] * 1024)
node_b_mem = create_node_memory(available_memory[1] * 1024)
node_c_mem = create_node_memory(available_memory[2] * 1024)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
model_card = ModelCard(
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -247,7 +253,7 @@ def test_get_shard_assignments(
# act
shard_assignments = get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory=node_memory
model_meta, selected_cycle, Sharding.Pipeline, node_profiles=node_profiles
)
# assert
@@ -337,28 +343,38 @@ def test_get_mlx_jaccl_coordinators():
source=node_a_id, sink=node_c_id, edge=create_socket_connection(6)
)
network_a = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
]
npp = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=0,
ram_available=0,
swap_total=0,
swap_available=0,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
)
network_b = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
]
)
network_c = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
]
)
node_network = {
node_a_id: network_a,
node_b_id: network_b,
node_c_id: network_c,
npp_a = copy(npp)
npp_a.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
]
npp_b = copy(npp)
npp_b.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
]
npp_c = copy(npp)
npp_c.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
]
node_profiles = {
node_a_id: npp_a,
node_b_id: npp_b,
node_c_id: npp_c,
}
topology = Topology()
@@ -378,7 +394,7 @@ def test_get_mlx_jaccl_coordinators():
node_a_id,
coordinator_port=5000,
cycle_digraph=topology,
node_network=node_network,
node_profiles=node_profiles,
)
# assert
@@ -480,9 +496,9 @@ def test_get_shard_assignments_insufficient_memory_raises():
topology = Topology()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a_mem = create_node_memory(900 * 1024)
node_b_mem = create_node_memory(50 * 1024)
node_c_mem = create_node_memory(10 * 1024) # Insufficient memory
node_a = create_node_profile(900 * 1024)
node_b = create_node_profile(50 * 1024)
node_c = create_node_profile(10 * 1024) # Insufficient memory
topology.add_node(node_a_id)
topology.add_node(node_b_id)
@@ -505,14 +521,15 @@ def test_get_shard_assignments_insufficient_memory_raises():
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
}
model_card = ModelCard(
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -522,6 +539,4 @@ def test_get_shard_assignments_insufficient_memory_raises():
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory
)
get_shard_assignments(model_meta, selected_cycle, Sharding.Pipeline, profiles)

View File

@@ -3,6 +3,11 @@ import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
@@ -18,6 +23,22 @@ def socket_connection() -> SocketConnection:
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryUsage.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=memory_profile,
network_interfaces=[],
system=system_profile,
)
def test_add_node(topology: Topology):
# arrange
node_id = NodeId()

View File

@@ -25,11 +25,7 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import Connection, RDMAConnection
@@ -197,43 +193,22 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.remove_node(event.node_id)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
}
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
node_system = {
key: value for key, value in state.node_system.items() if key != event.node_id
}
node_network = {
key: value for key, value in state.node_network.items() if key != event.node_id
}
node_thunderbolt = {
key: value
for key, value in state.node_thunderbolt.items()
if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_system": node_system,
"node_network": node_network,
"node_thunderbolt": node_thunderbolt,
}
)
@@ -242,60 +217,29 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
# Build update dict with only the mappings that change
update: dict[str, object] = {
"last_seen": {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
},
"topology": topology,
}
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
match info:
case MacmonMetrics():
update["node_system"] = {
**state.node_system,
event.node_id: info.system_profile,
}
update["node_memory"] = {**state.node_memory, event.node_id: info.memory}
profile.system = info.system_profile
profile.memory = info.memory
case MemoryUsage():
update["node_memory"] = {**state.node_memory, event.node_id: info}
profile.memory = info
case NodeConfig():
pass
case MiscData():
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
new_identity = current_identity.model_copy(
update={"friendly_name": info.friendly_name}
)
update["node_identities"] = {
**state.node_identities,
event.node_id: new_identity,
}
profile.friendly_name = info.friendly_name
case StaticNodeInformation():
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
new_identity = current_identity.model_copy(
update={"model_id": info.model, "chip_id": info.chip}
)
update["node_identities"] = {
**state.node_identities,
event.node_id: new_identity,
}
profile.model_id = info.model
profile.chip_id = info.chip
case NodeNetworkInterfaces():
update["node_network"] = {
**state.node_network,
event.node_id: NodeNetworkInfo(interfaces=info.ifaces),
}
profile.network_interfaces = info.ifaces
case MacThunderboltIdentifiers():
update["node_thunderbolt"] = {
**state.node_thunderbolt,
event.node_id: NodeThunderboltInfo(interfaces=info.idents),
}
profile.tb_interfaces = info.idents
case MacThunderboltConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_thunderbolt
for tb_ident in state.node_thunderbolt[nid].interfaces
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
}
as_rdma_conns = [
Connection(
@@ -312,7 +256,15 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
return state.model_copy(update=update)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
return state.model_copy(
update={
"node_profiles": new_profiles,
"last_seen": last_seen,
"topology": topology,
}
)
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:

View File

@@ -1,310 +1,552 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class ModelCard(CamelCaseModel):
short_id: str
model_id: ModelId
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
name: str
description: str
tags: list[str]
metadata: ModelMetadata
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
}

View File

@@ -6,8 +6,9 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
@@ -91,18 +92,18 @@ async def get_safetensors_size(model_id: str) -> Memory:
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
_model_meta_cache: dict[str, ModelMetadata] = {}
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_card(model_id: str) -> ModelCard:
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
@@ -112,11 +113,14 @@ async def _get_model_card(model_id: str) -> ModelCard:
None,
)
return ModelCard(
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,8 +31,9 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,

View File

@@ -4,9 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -206,7 +206,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_card: ModelCard
model_meta: ModelMetadata
class DeleteInstanceResponse(BaseModel):

View File

@@ -1,10 +1,10 @@
from enum import Enum
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .models import ModelId
class ChunkType(str, Enum):

View File

@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -21,7 +21,7 @@ class ChatCompletion(BaseCommand):
class PlaceInstance(BaseCommand):
model_card: ModelCard
model_meta: ModelMetadata
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int

View File

@@ -16,9 +16,7 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
return core_schema.str_schema()
class NodeId(Id):

View File

@@ -0,0 +1,18 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -53,21 +53,13 @@ class NetworkInterfaceInfo(CamelCaseModel):
ip_address: str
class NodeIdentity(CamelCaseModel):
"""Static and slow-changing node identification data."""
class NodePerformanceProfile(CamelCaseModel):
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
class NodeNetworkInfo(CamelCaseModel):
"""Network interface information for a node."""
interfaces: Sequence[NetworkInterfaceInfo] = []
class NodeThunderboltInfo(CamelCaseModel):
"""Thunderbolt interface identifiers for a node."""
interfaces: Sequence[ThunderboltIdentifier] = []
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[ThunderboltIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()

View File

@@ -7,13 +7,7 @@ from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import (
MemoryUsage,
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
SystemPerformanceProfile,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -41,17 +35,11 @@ class State(CamelCaseModel):
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
last_seen: Mapping[NodeId, datetime] = {}
topology: Topology = Field(default_factory=Topology)
last_event_applied_idx: int = Field(default=-1, ge=-1)
# Granular node state mappings (update independently at different frequencies)
node_identities: Mapping[NodeId, NodeIdentity] = {}
node_memory: Mapping[NodeId, MemoryUsage] = {}
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
@field_serializer("topology", mode="plain")
def _encode_topology(self, value: Topology) -> TopologySnapshot:
return value.to_snapshot()

View File

@@ -2,8 +2,8 @@ from collections.abc import Mapping
from pydantic import model_validator
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.models import ModelMetadata
from exo.utils.pydantic_ext import TaggedModel
@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
Replaces previous `Shard` object.
"""
model_card: ModelCard
model_meta: ModelMetadata
device_rank: int
world_size: int
@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
def __hash__(self) -> int:
return hash(
(
self.model_card.model_id,
self.model_meta.model_id,
self.start_layer,
self.end_layer,
self.n_layers,

View File

@@ -7,7 +7,7 @@ from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodeNetworkInfo
from exo.shared.types.profiling import NodePerformanceProfile
REACHABILITY_ATTEMPTS = 3
@@ -79,7 +79,7 @@ async def check_reachability(
async def check_reachable(
topology: Topology,
self_node_id: NodeId,
node_network: Mapping[NodeId, NodeNetworkInfo],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
@@ -98,11 +98,11 @@ async def check_reachable(
create_task_group() as tg,
):
for node_id in topology.list_nodes():
if node_id not in node_network:
if node_id not in node_profiles:
continue
if node_id == self_node_id:
continue
for iface in node_network[node_id].interfaces:
for iface in node_profiles[node_id].network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,

View File

@@ -460,10 +460,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_card.model_id))
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(traceback.format_exc())
return ["*"]
@@ -532,18 +532,18 @@ async def download_shard(
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=}")
logger.info(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_card.model_id)):
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
if await aios.path.exists(str(shard.model_meta.model_id)):
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_card.model_id), shard, local_path
str(shard.model_meta.model_id), shard, local_path
)
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
"/", "--"
)
if not skip_download:
@@ -552,13 +552,13 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_card.model_id), revision, recursive=True
str(shard.model_meta.model_id), revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -592,7 +592,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -609,7 +609,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
str(shard.model_card.model_id),
str(shard.model_meta.model_id),
revision,
file_progress,
all_start_time,
@@ -619,7 +619,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -643,7 +643,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
str(shard.model_card.model_id),
str(shard.model_meta.model_id),
revision,
file.path,
target_dir,
@@ -657,7 +657,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -20,21 +20,21 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: str) -> ShardMetadata:
model_card = await get_model_card(model_id)
model_meta = await get_model_meta(model_id)
return PipelineShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=model_card.n_layers,
n_layers=model_card.n_layers,
end_layer=model_meta.n_layers,
n_layers=model_meta.n_layers,
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_card=base_shard.model_card,
model_meta=base_shard.model_meta,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,
@@ -93,11 +93,11 @@ class CachedShardDownloader(ShardDownloader):
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
if (shard.model_card.model_id, shard) in self.cache:
return self.cache[(shard.model_card.model_id, shard)]
if (shard.model_meta.model_id, shard) in self.cache:
return self.cache[(shard.model_meta.model_id, shard)]
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_card.model_id, shard)] = target_dir
self.cache[(shard.model_meta.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(

View File

@@ -5,8 +5,8 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -86,8 +86,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,

View File

@@ -1,3 +1,5 @@
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
@@ -15,3 +17,27 @@ class Model(nn.Module):
cache: list[KVCache] | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...
class Detokenizer:
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
@property
def last_segment(self) -> str: ...
class TokenizerWrapper:
bos_token: str | None
eos_token_ids: list[int]
detokenizer: Detokenizer
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
def apply_chat_template(
self,
messages_dicts: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = True,
) -> str: ...

View File

@@ -248,9 +248,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
"cache", None
)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
# # Add dependency to last cache entry to ensure distributed ops are evaluated
# if cache is not None:
# cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
@@ -566,7 +566,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
layer.block_sparse_moe.sharding_group = self.group
return model
@@ -661,7 +661,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
layer.mlp.sharding_group = self.group
return model

View File

@@ -119,7 +119,6 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -131,6 +130,11 @@ def mlx_generate(
if task.seed is not None:
mx.random.seed(task.seed)
prompt = apply_chat_template(
tokenizer=tokenizer,
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []

View File

@@ -75,7 +75,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
/ model_shard_meta.n_layers
* model_shard_meta.model_card.storage_size.in_kb
* model_shard_meta.model_meta.storage_size.in_kb
/ (
1
if isinstance(model_shard_meta, PipelineShardMetadata)
@@ -169,10 +169,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
@@ -206,7 +206,7 @@ def load_mlx_items(
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
@@ -234,7 +234,7 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_card.model_id)
model_path = build_model_path(shard_metadata.model_meta.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
@@ -293,7 +293,7 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
@@ -312,9 +312,6 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-4.7-flash" in model_id_lower:
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
return None
@@ -399,16 +396,6 @@ def apply_chat_template(
return prompt
def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> bool:
"""
Detect if prompt ends with a thinking opening tag that should be
prepended to the output stream.
"""
think_token = tokenizer.think_start
return think_token is not None and prompt.rstrip().endswith(think_token)
class NullKVCache(KVCache):
"""
A KVCache that pretends to exist but holds zero tokens.

View File

@@ -8,7 +8,6 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
@@ -23,6 +22,7 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -186,11 +186,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_card.model_id not in self.download_status:
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -205,7 +205,7 @@ class Worker:
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -339,7 +339,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -356,7 +356,7 @@ class Worker:
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
self.download_status[shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -376,7 +376,7 @@ class Worker:
progress
),
)
self.download_status[shard.model_card.model_id] = status
self.download_status[shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -409,7 +409,7 @@ class Worker:
conns = await check_reachable(
self.state.topology,
self.node_id,
self.state.node_network,
self.state.node_profiles,
)
for nid in conns:
for ip in conns[nid]:
@@ -478,7 +478,7 @@ class Worker:
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)

View File

@@ -2,8 +2,8 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -114,7 +114,7 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
@@ -191,7 +191,7 @@ def _load_model(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner

View File

@@ -4,7 +4,6 @@ from functools import cache
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
@@ -51,8 +50,6 @@ from exo.shared.types.worker.runners import (
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
detect_thinking_prompt_suffix,
initialize_mlx,
load_mlx_items,
mlx_force_oom,
@@ -180,28 +177,17 @@ def main(
try:
_check_for_debug_prompts(task_params.messages[0].content)
# Build prompt once - used for both generation and thinking detection
prompt = apply_chat_template(tokenizer, task_params)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# For other thinking models (GLM, etc.), check if we need to
# prepend the thinking tag that was consumed by the chat template
if detect_thinking_prompt_suffix(prompt, tokenizer):
mlx_generator = parse_thinking_models(
mlx_generator, tokenizer
)
# TODO: Add tool call parser here
for response in mlx_generator:
@@ -213,7 +199,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_card.model_id,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
@@ -230,7 +216,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_card.model_id,
model=shard_metadata.model_meta.model_id,
text="",
token_id=0,
finish_reason="error",
@@ -307,28 +293,6 @@ def parse_gpt_oss(
break
def parse_thinking_models(
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
can properly parse thinking content.
"""
first = True
for response in responses:
if first:
first = False
yield response.model_copy(
update={
"text": tokenizer.think_start,
"token": tokenizer.think_start_id, # type: ignore
}
)
yield response
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -1,7 +1,7 @@
from typing import Final
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import TaskId
from exo.shared.types.worker.instances import InstanceId, RunnerId

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -32,8 +32,9 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,

View File

@@ -11,9 +11,9 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
@@ -81,8 +81,9 @@ def run_gpt_oss_pipeline_device(
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
@@ -150,8 +151,9 @@ def run_gpt_oss_tensor_parallel_device(
# For tensor parallelism, all devices run all layers
shard_meta = TensorShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,

View File

@@ -76,13 +76,13 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for _, card in MODEL_CARDS.items():
for short_id, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
parts = short_id.split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (card.model_id.short(), card)
families[family] = (short_id, card)
return list(families.values())

View File

@@ -1,7 +1,7 @@
import exo.worker.plan as plan_mod
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance

View File

@@ -114,10 +114,6 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")

View File

@@ -82,7 +82,7 @@ async def tb_detection():
send, recv = channel[GatheredInfo]()
ig = InfoGatherer(send)
with anyio.move_on_after(1):
await ig._monitor_system_profiler_thunderbolt_data() # pyright: ignore[reportPrivateUsage]
await ig._monitor_system_profiler() # pyright: ignore[reportPrivateUsage]
with recv:
return recv.collect()
@@ -135,7 +135,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
meta = MODEL_CARDS[test.model_id].metadata
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
@@ -145,15 +145,15 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
model_card=card,
model_meta=meta,
device_rank=i,
world_size=world_size,
start_layer=(card.n_layers // world_size) * i,
start_layer=(meta.n_layers // world_size) * i,
end_layer=min(
card.n_layers, (card.n_layers // world_size) * (i + 1)
meta.n_layers, (meta.n_layers // world_size) * (i + 1)
),
n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
- (card.n_layers // world_size) * i,
n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
- (meta.n_layers // world_size) * i,
)
for i in range(world_size)
},
@@ -224,7 +224,7 @@ async def jaccl_backend(test: Tests):
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
meta = MODEL_CARDS[test.model_id].metadata
world_size = len(test.devs)
return MlxJacclInstance(
@@ -239,12 +239,12 @@ def jaccl_instance(test: Tests, iid: InstanceId):
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_card=card,
model_meta=meta,
device_rank=i,
world_size=world_size,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
start_layer=meta.n_layers,
end_layer=meta.n_layers,
n_layers=meta.n_layers,
)
for i in range(world_size)
},

View File

@@ -1,84 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
PREFS="${PREFS:-/Library/Preferences/SystemConfiguration/preferences.plist}"
tmpdir="$(mktemp -d)"
trap 'rm -rf "$tmpdir"' EXIT
injson="$tmpdir/in.json"
outjson="$tmpdir/out.json"
plutil -convert json -o "$injson" "$PREFS"
perl -Mstrict -Mwarnings -MJSON::PP -e '
my ($in, $out) = @ARGV;
open my $fh, "<", $in or die "open $in: $!";
local $/;
my $txt = <$fh>;
close $fh;
my $json = JSON::PP->new->utf8->relaxed(1);
my $d = $json->decode($txt);
if (ref($d->{VirtualNetworkInterfaces}) eq "HASH"
&& ref($d->{VirtualNetworkInterfaces}{Bridge}) eq "HASH") {
delete $d->{VirtualNetworkInterfaces}{Bridge}{bridge0};
}
my @bridge_svcs;
if (ref($d->{NetworkServices}) eq "HASH") {
for my $k (keys %{ $d->{NetworkServices} }) {
my $svc = $d->{NetworkServices}{$k};
next unless ref($svc) eq "HASH";
my $iface = $svc->{Interface};
next unless ref($iface) eq "HASH";
my $dev = $iface->{DeviceName};
if (defined $dev && $dev eq "bridge0") {
push @bridge_svcs, $k;
}
}
delete @{ $d->{NetworkServices} }{ @bridge_svcs } if @bridge_svcs;
}
my %is_bridge = map { $_ => 1 } @bridge_svcs;
if (ref($d->{Sets}) eq "HASH") {
for my $setk (keys %{ $d->{Sets} }) {
my $set = $d->{Sets}{$setk};
next unless ref($set) eq "HASH";
my $net = $set->{Network};
next unless ref($net) eq "HASH";
if (ref($net->{Interface}) eq "HASH") {
delete $net->{Interface}{bridge0};
}
if (ref($net->{Service}) eq "HASH" && @bridge_svcs) {
for my $svc (@bridge_svcs) {
delete $net->{Service}{$svc};
}
}
my $g = $net->{Global};
if (ref($g) eq "HASH"
&& ref($g->{IPv4}) eq "HASH"
&& ref($g->{IPv4}{ServiceOrder}) eq "ARRAY"
&& @bridge_svcs) {
my @so = @{ $g->{IPv4}{ServiceOrder} };
@so = grep { !defined($_) || !$is_bridge{$_} } @so;
$g->{IPv4}{ServiceOrder} = \@so;
}
}
}
open my $oh, ">", $out or die "open $out: $!";
print $oh JSON::PP->new->utf8->canonical(1)->pretty(1)->encode($d);
close $oh;
' "$injson" "$outjson"
# Convert JSON -> plist (write back as binary1; change to xml1 if you prefer)
plutil -convert xml1 -o "$PREFS" "$outjson"
# Ask configd to reload SystemConfiguration state
killall -HUP configd 2>/dev/null || true

1496
uv.lock generated
View File

File diff suppressed because it is too large Load Diff