Compare commits

..

3 Commits

Author SHA1 Message Date
Alex Cheema
c4d24a24e4 Handle MessageTooLarge error in router to prevent node crash
When a message exceeds gossipsub's 1 MiB limit, the resulting
RuntimeError("MessageTooLarge") was unhandled in _networking_publish,
causing the entire TaskGroup to fail and crash the node. Now catch
this error and log it, dropping the oversized message gracefully.

Fixes #1296

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 10:34:22 -08:00
rltakashige
83af8c63fa Revert "Use custom fork that resolves GPU locks" (#1502)
Reverts exo-explore/exo#1489

Goddammit Claude...
2026-02-17 18:18:54 +00:00
Evan Quiney
eccc6298d1 Revert "Add MetaInstance declarative layer (#1447)"
This reverts commit a962a28afc.
2026-02-17 18:11:47 +00:00
47 changed files with 400 additions and 4598 deletions

View File

@@ -194,40 +194,3 @@ GitHub's API doesn't support direct image upload for PR comments. Workaround:
git push origin <branch>
```
The images still render in the PR comment because they reference the permanent commit SHA.
## Running exo Remotely via SSH (macOS mDNS)
**CRITICAL: On macOS, mDNS multicast (used for peer discovery) only works when the process runs in a proper macOS user session.** Background processes started via `nohup ... &`, `screen`, or plain SSH commands will NOT send mDNS packets and nodes will never discover each other.
### The Problem
When you SSH into a Mac and run `nohup uv run exo &`, the process runs in a detached session without access to macOS multicast networking. The exo node will start but will never discover peers, even if they're on the same network.
### The Solution: Use `open` with a `.command` wrapper
Create a `.command` script that `open` will execute in the proper macOS GUI session context:
```bash
# 1. Create wrapper script on the remote machine
ssh user@remote-mac "cat > /tmp/run_exo.command << 'SCRIPT'
#!/bin/bash
export PATH=/opt/homebrew/bin:\$HOME/.local/bin:\$PATH
export EXO_LIBP2P_NAMESPACE=your-namespace # must match across all nodes
cd ~/path/to/exo
exec uv run exo -vv 2>&1 | tee /tmp/exo.log
SCRIPT
chmod +x /tmp/run_exo.command"
# 2. Launch it via `open` (runs in macOS GUI session with proper mDNS)
ssh user@remote-mac "open /tmp/run_exo.command"
# 3. Check logs
ssh user@remote-mac "tail -f /tmp/exo.log"
```
### Key Details
- **`EXO_LIBP2P_NAMESPACE`**: All nodes in a cluster MUST use the same namespace value. The EXO.app uses a build-specific namespace (check with `ps eww <pid> | grep NAMESPACE`). If mixing dev builds with EXO.app, set the dev build's namespace to match.
- **`open *.command`**: This is the macOS equivalent of double-clicking the script in Finder. It runs in the user's GUI session with full network access.
- **Do NOT use**: `nohup ... &`, `screen -dm`, `tmux new-session -d`, or `sshpass`. These all create detached sessions where mDNS won't work.
- **Killing**: `ssh user@remote-mac "pkill -f 'python.*exo'"` works fine for stopping.
- **Dashboard**: Must be built before running: `cd dashboard && npm install && npm run build && cd ..`. Node.js is at `/opt/homebrew/bin/node` on Apple Silicon Macs.
- **Verifying cluster**: `curl -s http://localhost:52415/state | python3 -c "import json,sys; s=json.load(sys.stdin); print(len(s['topology']['nodes']), 'nodes')"` — should show 2+ nodes.

View File

@@ -72,23 +72,16 @@ There are two ways to run exo:
### Run from Source (macOS)
If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):
```bash
nix run .#exo
```
**Prerequisites:**
- [Xcode](https://developer.apple.com/xcode/) (provides the Metal ToolChain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)
```bash
brew install uv macmon node
```

View File

@@ -185,7 +185,11 @@
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {

View File

@@ -21,7 +21,7 @@
} | null;
nodes?: Record<string, NodeInfo>;
sharding?: "Pipeline" | "Tensor";
runtime?: "MlxRing" | "MlxJaccl";
runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
onLaunch?: () => void;
tags?: string[];
apiPreview?: PlacementPreview | null;
@@ -348,7 +348,7 @@
// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === "MlxJaccl");
const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");
// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
@@ -575,7 +575,7 @@
>
{runtime === "MlxRing"
? "MLX Ring"
: runtime === "MlxJaccl"
: runtime === "MlxIbv" || runtime === "MlxJaccl"
? "MLX RDMA"
: runtime}
</span>

View File

@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
export interface PlacementPreview {
model_id: string;
sharding: "Pipeline" | "Tensor";
instance_meta: "MlxRing" | "MlxJaccl";
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
instance: unknown | null;
memory_delta_by_node: Record<string, number> | null;
error: string | null;
@@ -219,6 +219,7 @@ interface RawStateResponse {
string,
{
MlxRingInstance?: Instance;
MlxIbvInstance?: Instance;
MlxJacclInstance?: Instance;
}
>;
@@ -249,20 +250,6 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// MetaInstances (declarative instance constraints)
metaInstances?: Record<string, MetaInstanceData>;
}
export interface MetaInstanceData {
metaInstanceId: string;
modelId: string;
sharding: string;
instanceMeta: string;
minNodes: number;
nodeIds: string[] | null;
placementError: string | null;
consecutiveFailures: number;
lastFailureError: string | null;
}
export interface MessageAttachment {
@@ -550,7 +537,6 @@ class AppStore {
previewNodeFilter = $state<Set<string>>(new Set());
lastUpdate = $state<number | null>(null);
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
metaInstances = $state<Record<string, MetaInstanceData>>({});
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderbolt = $state<
Record<
@@ -909,7 +895,11 @@ class AppStore {
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {
@@ -1283,8 +1273,6 @@ class AppStore {
this.nodeThunderbolt = data.nodeThunderbolt ?? {};
// RDMA ctl status per node
this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
// MetaInstances
this.metaInstances = data.metaInstances ?? {};
// Thunderbolt bridge cycles
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
// Thunderbolt bridge status per node
@@ -3056,7 +3044,6 @@ export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;

View File

File diff suppressed because it is too large Load Diff

View File

@@ -115,7 +115,7 @@
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{

View File

@@ -41,16 +41,16 @@ let
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.7.dev20260217+50487b41"; in
version = let v = "0.30.6"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "rltakashige";
repo = "mlx-jaccl-fix-small-recv";
rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
};
patches = [

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx; sys_platform == 'darwin'",
"mlx==0.30.6; sys_platform == 'darwin'",
"mlx[cpu]==0.30.6; sys_platform == 'linux'",
"mlx-lm==0.30.6",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
@@ -64,7 +64,6 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }

View File

@@ -58,21 +58,6 @@
lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
(lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
mlx = ignoreMissing prev.mlx;
mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
buildInputs = (old.buildInputs or [ ]) ++ [
final.nvidia-cublas
final.nvidia-cuda-nvrtc
final.nvidia-cudnn-cu13
final.nvidia-nccl-cu13
];
preFixup = ''
addAutoPatchelfSearchPath ${final.nvidia-cublas}
addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
'';
autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
});
torch = ignoreMissing prev.torch;
triton = ignoreMissing prev.triton;
}
@@ -89,25 +74,14 @@
linuxOverlay
]
);
# mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
# mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
"lib/python3.13/site-packages/mlx*"
"lib/python3.13/site-packages/nvidia*"
];
exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
)).overrideAttrs {
venvIgnoreCollisions = venvCollisionPaths;
};
);
mkPythonScript = name: path: pkgs.writeShellApplication {
inherit name;

View File

@@ -314,17 +314,7 @@ class DownloadCoordinator:
),
)
elif progress.status in ["in_progress", "not_started"]:
if (
progress.downloaded_bytes.in_bytes
>= progress.total_bytes.in_bytes
> 0
):
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.downloaded_bytes_this_session.in_bytes == 0:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id,
shard_metadata=progress.shard,

View File

@@ -254,7 +254,7 @@ def main():
target = min(max(soft, 65535), hard)
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
mp.set_start_method("spawn", force=True)
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")

View File

@@ -71,13 +71,8 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
CreateMetaInstanceParams,
CreateMetaInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
DeleteMetaInstanceResponse,
DistributeModelParams,
DistributeModelResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
@@ -120,11 +115,8 @@ from exo.shared.types.claude_api import (
from exo.shared.types.commands import (
Command,
CreateInstance,
CreateMetaInstance,
DeleteDownload,
DeleteInstance,
DeleteMetaInstance,
DistributeModel,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
@@ -137,7 +129,7 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -146,13 +138,11 @@ from exo.shared.types.events import (
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -286,9 +276,6 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.get("/meta_instances")(self.list_meta_instances)
self.app.post("/meta_instance")(self.create_meta_instance)
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
@@ -312,34 +299,18 @@ class API:
self.app.get("/events")(self.stream_events)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
self.app.post("/v1/models/{model_id:path}/distribute")(self.distribute_model)
self.app.get("/v1/traces")(self.list_traces)
self.app.get("/v1/traces/{task_id}")(self.get_trace)
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
async def place_instance(self, payload: PlaceInstanceParams):
model_card = await ModelCard.load(payload.model_id)
command = PlaceInstance(
model_card=model_card,
model_card=await ModelCard.load(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
)
# Validate placement before sending — fail fast with a clear error
# instead of silently dropping the command in the master.
try:
get_instance_placements(
command,
topology=self.state.topology,
current_instances=self.state.instances,
node_memory=self.state.node_memory,
node_network=self.state.node_network,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
await self._send(command)
return CreateInstanceResponse(
@@ -551,44 +522,6 @@ class API:
instance_id=instance_id,
)
def list_meta_instances(self) -> dict[MetaInstanceId, MetaInstance]:
return dict(self.state.meta_instances)
async def create_meta_instance(
self, payload: CreateMetaInstanceParams
) -> CreateMetaInstanceResponse:
meta_instance = MetaInstance(
model_id=payload.model_id,
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
node_ids=payload.node_ids,
)
command = CreateMetaInstance(meta_instance=meta_instance)
await self._send(command)
return CreateMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance.meta_instance_id,
)
async def delete_meta_instance(
self, meta_instance_id: MetaInstanceId
) -> DeleteMetaInstanceResponse:
meta = self.state.meta_instances.get(meta_instance_id)
if not meta:
raise HTTPException(status_code=404, detail="MetaInstance not found")
# Command processor handles cascade-deleting backing instances
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
await self._send(command)
return DeleteMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance_id,
)
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
@@ -608,10 +541,10 @@ class API:
break
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -951,10 +884,10 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -1037,10 +970,10 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
cancel_command = TaskCancelled(cancelled_command_id=command_id)
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=cancel_command)
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
@@ -1562,57 +1495,6 @@ class API:
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)
async def distribute_model(
self, model_id: ModelId, payload: DistributeModelParams
) -> DistributeModelResponse:
"""Distribute model files from one node to others via MLX distributed."""
# Find a source node that has the model downloaded
source_node_id: NodeId | None = None
for nid, downloads in self.state.downloads.items():
for dp in downloads:
if (
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == model_id
):
source_node_id = nid
break
if source_node_id is not None:
break
if source_node_id is None:
raise HTTPException(
status_code=404,
detail=f"No node has model {model_id} downloaded",
)
# Determine target nodes
if payload.target_node_ids is not None:
target_node_ids = [
nid for nid in payload.target_node_ids if nid != source_node_id
]
else:
target_node_ids = [
nid for nid in self.state.topology.list_nodes() if nid != source_node_id
]
if not target_node_ids:
raise HTTPException(
status_code=400,
detail="No target nodes to distribute to",
)
command = DistributeModel(
model_id=model_id,
source_node_id=source_node_id,
target_node_ids=target_node_ids,
)
await self._send(command)
return DistributeModelResponse(
command_id=command.command_id,
message=f"Distributing {model_id} from {source_node_id} to {len(target_node_ids)} node(s)",
)
def _get_trace_path(self, task_id: str) -> Path:
return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"

View File

@@ -1,5 +1,4 @@
from collections.abc import Sequence
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
import anyio
from anyio.abc import TaskGroup
@@ -13,23 +12,11 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.process_managers import ProcessManager
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.commands import (
CreateInstance,
CreateMetaInstance,
DeleteInstance,
DeleteMetaInstance,
DistributeModel,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
@@ -49,12 +36,8 @@ from exo.shared.types.events import (
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
JacclSideChannelData,
JacclSideChannelGathered,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
@@ -77,8 +60,7 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
@@ -102,16 +84,16 @@ class Master:
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
self._process_managers: Sequence[ProcessManager] = [
InstanceHealthReconciler(),
NodeTimeoutReconciler(),
MetaInstanceReconciler(),
]
async def run(self):
logger.info("Starting Master")
@@ -120,12 +102,15 @@ class Master:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._reconcile)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self._event_log.close()
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -307,86 +292,6 @@ class Master:
)
)
generated_events.extend(transition_events)
case CreateMetaInstance():
logger.info(
f"Creating MetaInstance for {command.meta_instance.model_id}"
f" (min_nodes={command.meta_instance.min_nodes},"
f" sharding={command.meta_instance.sharding})"
)
# Apply immediately so self.state is fresh across
# the await below and the reconciler won't race.
await self._apply_and_broadcast(
MetaInstanceCreated(meta_instance=command.meta_instance)
)
# Immediate placement attempt for responsiveness
model_card = await ModelCard.load(
command.meta_instance.model_id
)
# Re-check: reconciler may have satisfied it during the await
meta_id = command.meta_instance.meta_instance_id
still_unsatisfied = any(
m.meta_instance_id == meta_id
for m in find_unsatisfied_meta_instances(
self.state.meta_instances,
self.state.instances,
self.state.topology,
)
)
if still_unsatisfied:
result = try_place_for_meta_instance(
command.meta_instance,
model_card,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
self.state.tasks,
)
generated_events.extend(result.events)
if result.error is not None:
generated_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_id,
reason=result.error,
)
)
case DeleteMetaInstance():
backing_count = sum(
1
for inst in self.state.instances.values()
if inst.meta_instance_id == command.meta_instance_id
)
logger.info(
f"Deleting MetaInstance {command.meta_instance_id}"
f" (cascade-deleting {backing_count} backing instance(s))"
)
generated_events.append(
MetaInstanceDeleted(
meta_instance_id=command.meta_instance_id
)
)
# Cascade-delete backing instances atomically,
# cancelling any active tasks first.
for iid, inst in self.state.instances.items():
if inst.meta_instance_id == command.meta_instance_id:
for task in self.state.tasks.values():
if (
task.instance_id == iid
and task.task_status
in (
TaskStatus.Pending,
TaskStatus.Running,
)
):
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task.task_id,
)
)
generated_events.append(
InstanceDeleted(instance_id=iid)
)
case PlaceInstance():
placement = place_instance(
command,
@@ -409,36 +314,6 @@ class Master:
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case DistributeModel():
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
model_card = await ModelCard.load(command.model_id)
all_node_ids = set(
[command.source_node_id] + list(command.target_node_ids)
)
place_command = PlaceInstance(
model_card=model_card,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=len(all_node_ids),
)
placement = place_instance(
place_command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
required_nodes=all_node_ids,
)
# Mark new instances as transfer-only
for instance_id, instance in placement.items():
if instance_id not in self.state.instances:
instance.shard_assignments.transfer_only = True
transition_events = get_transition_events(
self.state.instances, placement, self.state.tasks
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(
@@ -448,19 +323,16 @@ class Master:
)
case TaskCancelled():
if (
command.cancelled_command_id
in self.command_task_mapping
):
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.cancelled_command_id
]
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
del self.command_task_mapping[
command.cancelled_command_id
]
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -469,10 +341,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
# rate limit to 1000 at a time
@@ -483,32 +354,31 @@ class Master:
):
await self._send_event(IndexedEvent(idx=i, event=event))
for event in generated_events:
await self._apply_and_broadcast(event)
await self.event_sender.send(event)
except ValueError as e:
logger.opt(exception=e).warning("Error in command processor")
async def _apply_and_broadcast(self, event: Event) -> None:
"""Apply event to state, persist to disk, and broadcast to workers.
State is updated synchronously (before any await), so callers can
rely on ``self.state`` reflecting this event immediately after the
call. Python's cooperative scheduling guarantees no interleaving
between the state read and write.
"""
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
self._event_log.append(event)
await self._send_event(indexed)
async def _reconcile(self) -> None:
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
async def _plan(self) -> None:
while True:
for pm in self._process_managers:
events = await pm.reconcile(self.state)
for event in events:
await self._apply_and_broadcast(event)
await anyio.sleep(1)
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
await self.event_sender.send(
InstanceDeleted(instance_id=instance_id)
)
break
# time out dead nodes
for node_id, time in self.state.last_seen.items():
now = datetime.now(tz=timezone.utc)
if now - time > timedelta(seconds=30):
logger.info(f"Manually removing node {node_id} due to inactivity")
await self.event_sender.send(NodeTimedOut(node_id=node_id))
await anyio.sleep(10)
async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
@@ -526,15 +396,32 @@ class Master:
await self._handle_traces_collected(event)
continue
if isinstance(event, JacclSideChannelData):
await self._apply_and_broadcast(event)
await self._handle_jaccl_side_channel(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
await self._apply_and_broadcast(event)
self._event_log.append(event)
await self._send_event(indexed)
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
local_index += 1
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
@@ -566,49 +453,10 @@ class Master:
for trace_data in self._pending_traces[task_id].values():
all_trace_data.extend(trace_data)
await self._apply_and_broadcast(
await self.event_sender.send(
TracesMerged(task_id=task_id, traces=all_trace_data)
)
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]
async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
"""Accumulate SideChannel contributions; when all runners for an instance
have submitted for the same sequence, emit JacclSideChannelGathered."""
iid = event.instance_id
seq = event.sequence
if iid not in self._jaccl_pending:
self._jaccl_pending[iid] = {}
if seq not in self._jaccl_pending[iid]:
self._jaccl_pending[iid][seq] = {}
self._jaccl_pending[iid][seq][event.runner_id] = event.data
instance = self.state.instances.get(iid)
if instance is None:
logger.warning(f"JacclSideChannelData for unknown instance {iid}")
return
expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
submitted = set(self._jaccl_pending[iid][seq].keys())
logger.info(
f"JACCL side channel: instance={iid} seq={seq} "
f"submitted={len(submitted)}/{len(expected_runners)}"
)
if submitted >= expected_runners:
gathered = dict(self._jaccl_pending[iid][seq])
del self._jaccl_pending[iid][seq]
if not self._jaccl_pending[iid]:
del self._jaccl_pending[iid]
await self._apply_and_broadcast(
JacclSideChannelGathered(
instance_id=iid,
sequence=seq,
gathered_data=gathered,
)
)

View File

@@ -6,11 +6,11 @@ from typing import Sequence
from exo.master.placement_utils import (
Cycle,
filter_cycles_by_memory,
get_largest_cycles,
get_mlx_jaccl_coordinators,
get_mlx_jaccl_devices_matrix,
get_mlx_ring_hosts_by_node,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
@@ -106,27 +106,23 @@ def place_instance(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
largest_cycles = get_largest_cycles(cycles_with_sufficient_memory)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
largest_rdma_cycles = [
cycle for cycle in largest_cycles if topology.is_rdma_cycle(cycle)
smallest_rdma_cycles = [
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
]
if command.instance_meta == InstanceMeta.MlxJaccl:
if not largest_rdma_cycles:
raise ValueError(
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
)
largest_cycles = largest_rdma_cycles
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
smallest_cycles = smallest_rdma_cycles
cycles_with_leaf_nodes: list[Cycle] = [
cycle
for cycle in largest_cycles
for cycle in smallest_cycles
if any(topology.node_is_leaf(node_id) for node_id in cycle)
]
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else largest_cycles,
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node_memory[node_id].ram_available for node_id in cycle),
start=Memory(),

View File

@@ -37,11 +37,11 @@ def filter_cycles_by_memory(
return filtered_cycles
def get_largest_cycles(
def get_smallest_cycles(
cycles: list[Cycle],
) -> list[Cycle]:
max_nodes = max(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == max_nodes]
min_nodes = min(len(cycle) for cycle in cycles)
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(

View File

@@ -1,12 +0,0 @@
from collections.abc import Sequence
from typing import Protocol, runtime_checkable
from exo.shared.types.events import Event
from exo.shared.types.state import State
@runtime_checkable
class ProcessManager(Protocol):
"""A reconciliation step that examines state and returns corrective events."""
async def reconcile(self, state: State) -> Sequence[Event]: ...

View File

@@ -1,62 +0,0 @@
from collections.abc import Sequence
from typing import final
from loguru import logger
from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
from exo.shared.types.state import State
MAX_INSTANCE_RETRIES = 3
@final
class InstanceHealthReconciler:
"""Delete instances whose network connections are broken or whose runners have all failed."""
async def reconcile(self, state: State) -> Sequence[Event]:
events: list[Event] = []
for instance_id, instance in state.instances.items():
if not instance_connections_healthy(instance, state.topology):
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error="Network connection lost",
)
)
continue
is_failed, error_message = instance_runners_failed(
instance, state.runners, state.node_identities
)
if is_failed:
# Retry within the same instance if backed by a MetaInstance
mid = instance.meta_instance_id
mi = state.meta_instances.get(mid) if mid else None
if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
logger.info(
f"Instance {instance_id} failed (attempt"
f" {mi.consecutive_failures + 1}/{MAX_INSTANCE_RETRIES}),"
f" retrying: {error_message}"
)
events.append(
InstanceRetrying(
instance_id=instance_id,
meta_instance_id=mid,
failure_error=error_message or "Runner failed",
)
)
else:
if mid and mi:
logger.warning(
f"Instance {instance_id} exceeded retry limit"
f" ({MAX_INSTANCE_RETRIES}), deleting:"
f" {error_message}"
)
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error=error_message,
)
)
return events

View File

@@ -1,92 +0,0 @@
from collections.abc import Sequence
from typing import final
import anyio
from loguru import logger
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
from exo.shared.types.state import State
from exo.shared.types.worker.instances import Instance, InstanceId
MODEL_CARD_LOAD_TIMEOUT_SECONDS = 10
@final
class MetaInstanceReconciler:
"""Place instances for unsatisfied MetaInstances."""
async def reconcile(self, state: State) -> Sequence[Event]:
all_events: list[Event] = []
# Local copy for intermediate tracking — so placement of B
# sees A's instance and doesn't double-place on same resources.
current_instances: dict[InstanceId, Instance] = dict(state.instances)
unsatisfied = find_unsatisfied_meta_instances(
state.meta_instances,
current_instances,
state.topology,
)
for meta_instance in unsatisfied:
try:
with anyio.fail_after(MODEL_CARD_LOAD_TIMEOUT_SECONDS):
model_card = await ModelCard.load(meta_instance.model_id)
except TimeoutError:
logger.warning(
f"ModelCard.load timed out for {meta_instance.model_id}, skipping this cycle"
)
continue
except Exception as exc:
logger.warning(
f"ModelCard.load failed for {meta_instance.model_id}: {exc}"
)
error = f"Failed to load model card: {exc}"
if meta_instance.placement_error != error:
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=error,
)
)
continue
result = try_place_for_meta_instance(
meta_instance,
model_card,
state.topology,
current_instances,
state.node_memory,
state.node_network,
state.tasks,
)
# Update local instance map so next placement sees this one
for event in result.events:
if isinstance(event, InstanceCreated):
logger.info(
f"MetaInstance reconciler placed instance"
f" {event.instance.instance_id} for"
f" {meta_instance.model_id}"
)
current_instances[event.instance.instance_id] = event.instance
all_events.extend(result.events)
# Emit placement failure if error differs from what's already in state
if (
result.error is not None
and meta_instance.placement_error != result.error
):
logger.warning(
f"MetaInstance placement failed for"
f" {meta_instance.model_id}: {result.error}"
)
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=result.error,
)
)
return all_events

View File

@@ -1,27 +0,0 @@
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from typing import final
from loguru import logger
from exo.shared.types.events import Event, NodeTimedOut
from exo.shared.types.state import State
_DEFAULT_TIMEOUT = timedelta(seconds=30)
@final
class NodeTimeoutReconciler:
"""Time out nodes that haven't been seen recently."""
def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
self.timeout = timeout
async def reconcile(self, state: State) -> Sequence[Event]:
now = datetime.now(tz=timezone.utc)
events: list[Event] = []
for node_id, last_seen in state.last_seen.items():
if now - last_seen > self.timeout:
logger.info(f"Removing node {node_id} due to inactivity")
events.append(NodeTimedOut(node_id=node_id))
return events

View File

@@ -1,244 +0,0 @@
from collections.abc import Mapping, Sequence
from typing import NamedTuple
from loguru import logger
from exo.master.placement import get_transition_events, place_instance
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import Event
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.instances import (
BaseInstance,
Instance,
InstanceId,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
RunnerStatus,
)
class PlacementResult(NamedTuple):
"""Result of a placement attempt: events to apply and optional error reason."""
events: Sequence[Event]
error: str | None
def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
"""Reconstruct ring order from shard device_rank."""
node_ranks: list[tuple[NodeId, int]] = []
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
shard = instance.shard_assignments.runner_to_shard[runner_id]
node_ranks.append((node_id, shard.device_rank))
node_ranks.sort(key=lambda x: x[1])
return [node_id for node_id, _ in node_ranks]
def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
"""Check that the specific IPs used by a ring instance still exist in the topology."""
ring = _get_ring_order(instance)
n = len(ring)
for node in ring:
hosts = instance.hosts_by_node[node]
for idx in range(n):
host = hosts[idx]
if host.ip in ("0.0.0.0", "198.51.100.1"):
continue # self or placeholder
# Real connection: node → ring[idx]. Check specific IP.
connections = topology.get_all_connections_between(node, ring[idx])
if not any(
isinstance(c, SocketConnection)
and c.sink_multiaddr.ip_address == host.ip
for c in connections
):
return False
return True
def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
"""Check that the specific RDMA interfaces used by a JACCL instance still exist."""
ring = _get_ring_order(instance)
n = len(ring)
for i in range(n):
for j in range(n):
iface = instance.jaccl_devices[i][j]
if iface is None:
continue
connections = topology.get_all_connections_between(ring[i], ring[j])
if not any(
isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
for c in connections
):
return False
return True
def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
"""Check that an instance's nodes and specific connections are still in the topology."""
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if not all(topology.contains_node(n) for n in instance_nodes):
return False
if len(instance_nodes) <= 1:
return True
match instance:
case MlxRingInstance():
return _ring_connections_healthy(instance, topology)
case MlxJacclInstance():
return _jaccl_connections_healthy(instance, topology)
def instance_runners_failed(
instance: Instance,
runners: Mapping[RunnerId, RunnerStatus],
node_identities: Mapping[NodeId, NodeIdentity],
) -> tuple[bool, str | None]:
"""Check if an instance's runners have all reached terminal failure states.
Returns ``(True, error_message)`` when ALL runners are terminal
(``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.
Returns ``(False, None)`` when runners are still active, haven't reported
yet, or all gracefully shut down (no ``RunnerFailed``).
"""
instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())
if not instance_runner_ids:
return False, None
# Build reverse mapping: runner_id -> node_id
runner_to_node: dict[RunnerId, NodeId] = {
runner_id: node_id
for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
}
has_any_failed = False
error_messages: list[str] = []
for runner_id in instance_runner_ids:
status = runners.get(runner_id)
if status is None:
# Runner hasn't reported yet — instance is still starting
return False, None
if isinstance(status, RunnerFailed):
has_any_failed = True
if status.error_message:
node_id = runner_to_node.get(runner_id)
name = (
node_identities[node_id].friendly_name
if node_id and node_id in node_identities
else node_id or "unknown"
)
error_messages.append(f"{name}: {status.error_message}")
elif isinstance(status, RunnerShutdown):
pass # Terminal but not a failure indicator on its own
else:
# Runner is still active (connecting, loading, running, etc.)
return False, None
if has_any_failed:
return True, "; ".join(error_messages) if error_messages else "Runner failed"
# All runners are Shutdown but none Failed — graceful shutdown, not a failure
return False, None
def instance_satisfies_meta_instance(
meta_instance: MetaInstance,
instance: Instance,
) -> bool:
"""Check if a single instance satisfies a meta-instance's constraints.
This is a pure constraint check (model, min_nodes, node_ids).
Use ``instance_connections_healthy`` separately for topology health.
"""
if instance.shard_assignments.model_id != meta_instance.model_id:
return False
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if len(instance_nodes) < meta_instance.min_nodes:
return False
return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
instance_nodes
)
def find_unsatisfied_meta_instances(
meta_instances: Mapping[MetaInstanceId, MetaInstance],
instances: Mapping[InstanceId, Instance],
topology: Topology,
) -> Sequence[MetaInstance]:
"""Return meta-instances that have no healthy backing instance."""
unsatisfied: list[MetaInstance] = []
for meta_id, meta_instance in meta_instances.items():
has_healthy_backing = any(
instance.meta_instance_id == meta_id
and instance_connections_healthy(instance, topology)
for instance in instances.values()
)
if not has_healthy_backing:
unsatisfied.append(meta_instance)
return unsatisfied
def try_place_for_meta_instance(
meta_instance: MetaInstance,
model_card: ModelCard,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
tasks: Mapping[TaskId, Task],
) -> PlacementResult:
"""Try to place an instance satisfying the meta-instance constraints.
Returns a :class:`PlacementResult` with events on success, or an error
reason on failure.
"""
command = PlaceInstance(
model_card=model_card,
sharding=meta_instance.sharding,
instance_meta=meta_instance.instance_meta,
min_nodes=meta_instance.min_nodes,
)
try:
target_instances = place_instance(
command,
topology,
current_instances,
node_memory,
node_network,
required_nodes=(
set(meta_instance.node_ids) if meta_instance.node_ids else None
),
)
# Tag the new instance with meta_instance_id
new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
if new_instance_ids:
new_id = next(iter(new_instance_ids))
target_instances[new_id] = target_instances[new_id].model_copy(
update={"meta_instance_id": meta_instance.meta_instance_id}
)
return PlacementResult(
events=list(
get_transition_events(current_instances, target_instances, tasks)
),
error=None,
)
except ValueError as e:
logger.debug(
f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
)
return PlacementResult(events=[], error=str(e))

View File

@@ -1,778 +0,0 @@
"""Edge-case and regression tests for MetaInstance lifecycle, concurrent operations, and error handling."""
import pytest
from exo.master.process_managers.instance_health import (
MAX_INSTANCE_RETRIES,
InstanceHealthReconciler,
)
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
instance_connections_healthy,
instance_runners_failed,
instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NodeIdentity
from exo.shared.types.state import State
from exo.shared.types.tasks import LoadModel, TaskId, TaskStatus
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
InstanceId,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerReady,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
# --- Helpers (copied from test_reconcile.py for independence) ---
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
return ModelCard(
model_id=ModelId(model_id),
storage_size=Memory.from_kb(1000),
n_layers=10,
hidden_size=30,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
def _topology(*node_ids: str, connect: bool = True) -> Topology:
t = Topology()
nodes = [NodeId(n) for n in node_ids]
for n in nodes:
t.add_node(n)
if connect and len(nodes) > 1:
for i in range(len(nodes)):
j = (i + 1) % len(nodes)
t.add_connection(
Connection(
source=nodes[i],
sink=nodes[j],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
)
),
)
)
t.add_connection(
Connection(
source=nodes[j],
sink=nodes[i],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
)
),
)
)
return t
def _meta_instance(
model_id: str = "test-org/test-model",
*,
min_nodes: int = 1,
node_ids: list[NodeId] | None = None,
meta_instance_id: MetaInstanceId | None = None,
consecutive_failures: int = 0,
last_failure_error: str | None = None,
placement_error: str | None = None,
) -> MetaInstance:
return MetaInstance(
meta_instance_id=meta_instance_id or MetaInstanceId(),
model_id=ModelId(model_id),
min_nodes=min_nodes,
node_ids=node_ids,
consecutive_failures=consecutive_failures,
last_failure_error=last_failure_error,
placement_error=placement_error,
)
def _instance(
model_id: str = "test-org/test-model",
node_ids: list[str] | None = None,
instance_id: InstanceId | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
iid = instance_id or InstanceId()
nodes = node_ids or ["node-a"]
n = len(nodes)
mc = _model_card(model_id)
ephemeral_port = 50000
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
runner_to_shard = {
runner_id: PipelineShardMetadata(
model_card=mc,
device_rank=i,
world_size=n,
start_layer=0,
end_layer=mc.n_layers,
n_layers=mc.n_layers,
)
for i, runner_id in enumerate(node_to_runner.values())
}
hosts_by_node: dict[NodeId, list[Host]] = {}
for r, node_str in enumerate(nodes):
hosts: list[Host] = []
for idx in range(n):
if idx == r:
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
else:
hosts.append(Host(ip="198.51.100.1", port=0))
hosts_by_node[NodeId(node_str)] = hosts
return iid, MlxRingInstance(
instance_id=iid,
shard_assignments=ShardAssignments(
model_id=ModelId(model_id),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
),
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
meta_instance_id=meta_instance_id,
)
# =============================================================================
# 1. MetaInstance lifecycle edge cases
# =============================================================================
def test_meta_instance_model_is_frozen():
"""MetaInstance should be immutable (frozen model)."""
meta = _meta_instance()
try:
meta.model_id = ModelId("something-else")
raise AssertionError("Should have raised")
except Exception:
pass # Expected — frozen model
def test_meta_instance_created_then_deleted_roundtrip():
"""Create and delete a MetaInstance through apply — state should be clean."""
state = State()
meta = _meta_instance()
state = apply(
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta))
)
assert meta.meta_instance_id in state.meta_instances
state = apply(
state,
IndexedEvent(
idx=1, event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
),
)
assert meta.meta_instance_id not in state.meta_instances
assert len(state.meta_instances) == 0
def test_delete_nonexistent_meta_instance_is_safe():
"""Deleting a MetaInstance that doesn't exist should not crash."""
state = State()
event = MetaInstanceDeleted(meta_instance_id=MetaInstanceId("nonexistent"))
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
def test_placement_failed_for_nonexistent_meta_instance_is_safe():
"""MetaInstancePlacementFailed for unknown ID should not crash."""
state = State()
event = MetaInstancePlacementFailed(
meta_instance_id=MetaInstanceId("nonexistent"),
reason="test",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
def test_multiple_meta_instances_for_same_model():
"""Multiple MetaInstances for the same model are tracked independently."""
state = State()
meta_a = _meta_instance("test-org/model-x")
meta_b = _meta_instance("test-org/model-x")
state = apply(
state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta_a))
)
state = apply(
state, IndexedEvent(idx=1, event=MetaInstanceCreated(meta_instance=meta_b))
)
assert len(state.meta_instances) == 2
assert meta_a.meta_instance_id in state.meta_instances
assert meta_b.meta_instance_id in state.meta_instances
# =============================================================================
# 2. Retry logic edge cases
# =============================================================================
def test_retry_counter_resets_on_successful_instance_creation():
"""When a new instance is created for a meta-instance, failures should reset."""
meta = _meta_instance(consecutive_failures=2, last_failure_error="old")
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
mi = state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
# last_failure_error is preserved (for UI display)
assert mi.last_failure_error == "old"
async def test_retry_count_increments_through_full_cycle():
"""Walk through MAX_INSTANCE_RETRIES worth of retries, then verify delete."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
topology = _topology("node-a")
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
topology=topology,
)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
for idx, i in enumerate(range(MAX_INSTANCE_RETRIES)):
# Simulate runners failing
state_with_runners = state.model_copy(
update={"runners": {runner_ids[0]: RunnerFailed(error_message=f"fail-{i}")}}
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state_with_runners)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying), f"iteration {i}"
state = apply(state, IndexedEvent(idx=idx, event=events[0]))
# After MAX_INSTANCE_RETRIES retries, failure counter should be at max
mi = state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == MAX_INSTANCE_RETRIES
# Next failure should result in deletion
state_with_runners = state.model_copy(
update={"runners": {runner_ids[0]: RunnerFailed(error_message="final")}}
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state_with_runners)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_respects_exact_limit():
"""At exactly MAX_INSTANCE_RETRIES, reconciler should delete, not retry."""
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES)
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_at_limit_minus_one_retries():
"""At MAX_INSTANCE_RETRIES - 1, reconciler should still retry."""
meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES - 1)
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
# =============================================================================
# 3. Error handling edge cases
# =============================================================================
def test_runners_failed_with_empty_error_message():
"""RunnerFailed with empty error_message should still report as failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerFailed(error_message="")
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
# Empty error message means we get the fallback
assert error == "Runner failed"
def test_runners_failed_with_none_error_message():
"""RunnerFailed with None error_message should still report as failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerFailed(error_message=None)
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error == "Runner failed"
def test_runners_failed_collects_all_error_messages():
"""With multiple failed runners, all error messages should be collected."""
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM on GPU 0"),
runner_ids[1]: RunnerFailed(error_message="OOM on GPU 1"),
runner_ids[2]: RunnerFailed(error_message="OOM on GPU 2"),
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "OOM on GPU 0" in error
assert "OOM on GPU 1" in error
assert "OOM on GPU 2" in error
def test_runners_failed_includes_friendly_name():
"""Error messages should include node friendly names when available."""
_, inst = _instance(node_ids=["node-a"])
node_id = NodeId("node-a")
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {runner_ids[0]: RunnerFailed(error_message="OOM")}
identities = {node_id: NodeIdentity(friendly_name="My Mac Studio")}
is_failed, error = instance_runners_failed(inst, runners, identities)
assert is_failed is True
assert error is not None
assert "My Mac Studio" in error
def test_instance_retrying_for_missing_instance_is_safe():
"""InstanceRetrying for an instance not in state should not crash.
NOTE: When the instance is missing, the handler returns early WITHOUT
incrementing the MetaInstance failure counter. This means stale retry
events for already-deleted instances are silently dropped. This is
acceptable since the InstanceDeleted handler already increments failures.
"""
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = InstanceRetrying(
instance_id=InstanceId("nonexistent"),
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
# Does not crash, but failure count is NOT incremented (early return)
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
# =============================================================================
# 4. Backward compatibility
# =============================================================================
def test_instance_without_meta_instance_id_works():
"""Instances created without meta_instance_id should still function normally."""
_, inst = _instance(node_ids=["node-a"])
assert inst.meta_instance_id is None
topology = _topology("node-a")
assert instance_connections_healthy(inst, topology) is True
def test_instance_deleted_without_meta_does_not_affect_meta_instances():
"""Deleting an instance without meta_instance_id should not affect meta_instances."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"]) # no meta_instance_id
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="crash")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0 # unchanged
def test_satisfies_ignores_meta_instance_id_binding():
"""instance_satisfies_meta_instance checks constraints only, not binding."""
meta = _meta_instance()
_, inst = _instance(node_ids=["node-a"]) # no meta_instance_id set
# Should match on constraints (model, min_nodes) regardless of binding
assert instance_satisfies_meta_instance(meta, inst) is True
def test_find_unsatisfied_uses_binding_not_constraints():
"""find_unsatisfied checks meta_instance_id binding, not just constraint matching."""
meta = _meta_instance()
# Instance matches constraints but is NOT bound to this meta_instance
iid, inst = _instance(node_ids=["node-a"])
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta}, {iid: inst}, topology
)
# Should be unsatisfied because instance.meta_instance_id != meta.meta_instance_id
assert list(result) == [meta]
# =============================================================================
# 5. Concurrent / multi-instance scenarios
# =============================================================================
async def test_health_reconciler_handles_multiple_failing_instances():
"""Multiple instances failing simultaneously should each get their own event."""
meta_a = _meta_instance()
meta_b = _meta_instance()
iid_a, inst_a = _instance(
node_ids=["node-a"], meta_instance_id=meta_a.meta_instance_id
)
iid_b, inst_b = _instance(
node_ids=["node-b"], meta_instance_id=meta_b.meta_instance_id
)
runner_ids_a = list(inst_a.shard_assignments.node_to_runner.values())
runner_ids_b = list(inst_b.shard_assignments.node_to_runner.values())
state = State(
meta_instances={
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
instances={iid_a: inst_a, iid_b: inst_b},
runners={
runner_ids_a[0]: RunnerFailed(error_message="OOM"),
runner_ids_b[0]: RunnerFailed(error_message="OOM"),
},
topology=_topology("node-a", "node-b"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 2
# Both should be InstanceRetrying since failures < MAX
assert all(isinstance(e, InstanceRetrying) for e in events)
instance_ids = {e.instance_id for e in events} # type: ignore[union-attr]
assert instance_ids == {iid_a, iid_b}
async def test_health_reconciler_mixed_healthy_and_failing():
"""Only failing instances should produce events; healthy ones should not."""
meta_healthy = _meta_instance()
meta_failing = _meta_instance()
iid_h, inst_h = _instance(
node_ids=["node-a"], meta_instance_id=meta_healthy.meta_instance_id
)
iid_f, inst_f = _instance(
node_ids=["node-b"], meta_instance_id=meta_failing.meta_instance_id
)
runner_ids_h = list(inst_h.shard_assignments.node_to_runner.values())
runner_ids_f = list(inst_f.shard_assignments.node_to_runner.values())
state = State(
meta_instances={
meta_healthy.meta_instance_id: meta_healthy,
meta_failing.meta_instance_id: meta_failing,
},
instances={iid_h: inst_h, iid_f: inst_f},
runners={
runner_ids_h[0]: RunnerReady(),
runner_ids_f[0]: RunnerFailed(error_message="crash"),
},
topology=_topology("node-a", "node-b"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
assert events[0].instance_id == iid_f
async def test_meta_instance_reconciler_empty_state():
"""MetaInstanceReconciler with no meta_instances should produce no events."""
state = State()
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 0
# =============================================================================
# 6. Placement error tracking
# =============================================================================
def test_placement_failed_sets_error():
"""MetaInstancePlacementFailed should set placement_error on the MetaInstance."""
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstancePlacementFailed(
meta_instance_id=meta.meta_instance_id,
reason="Not enough memory",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.placement_error == "Not enough memory"
def test_instance_created_clears_placement_error():
"""InstanceCreated should clear placement_error on the MetaInstance."""
meta = _meta_instance(placement_error="Not enough memory")
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
mi = state.meta_instances[meta.meta_instance_id]
assert mi.placement_error is None
def test_placement_error_does_not_increment_failures():
"""Placement failures should only set placement_error, not increment consecutive_failures."""
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstancePlacementFailed(
meta_instance_id=meta.meta_instance_id,
reason="No resources",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
assert mi.placement_error == "No resources"
# =============================================================================
# 7. State serialization roundtrip
# =============================================================================
def test_state_with_meta_instances_serializes():
"""State with meta_instances should serialize and deserialize correctly."""
meta = _meta_instance(consecutive_failures=2, last_failure_error="test")
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
json_str = state.model_dump_json()
restored = State.model_validate_json(json_str)
assert meta.meta_instance_id in restored.meta_instances
mi = restored.meta_instances[meta.meta_instance_id]
assert mi.model_id == meta.model_id
assert mi.consecutive_failures == 2
assert mi.last_failure_error == "test"
assert iid in restored.instances
assert restored.instances[iid].meta_instance_id == meta.meta_instance_id
# =============================================================================
# 8. MetaInstanceReconciler error handling
# =============================================================================
async def test_meta_instance_reconciler_model_load_error_emits_placement_failed(
monkeypatch: "pytest.MonkeyPatch",
):
"""When ModelCard.load raises, reconciler emits MetaInstancePlacementFailed."""
import exo.master.process_managers.meta_instance as mi_mod
meta = _meta_instance()
topo = _topology("node-a")
state = State(
meta_instances={meta.meta_instance_id: meta},
topology=topo,
)
async def _failing_load(_model_id: ModelId) -> ModelCard:
raise RuntimeError("Network error")
monkeypatch.setattr(
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
)
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
assert len(placement_failed) == 1
assert "Failed to load model card" in placement_failed[0].reason
assert meta.meta_instance_id == placement_failed[0].meta_instance_id
async def test_meta_instance_reconciler_model_load_error_skips_dedup(
monkeypatch: "pytest.MonkeyPatch",
):
"""When ModelCard.load error matches existing placement_error, no duplicate event."""
import exo.master.process_managers.meta_instance as mi_mod
meta = _meta_instance(placement_error="Failed to load model card: Network error")
topo = _topology("node-a")
state = State(
meta_instances={meta.meta_instance_id: meta},
topology=topo,
)
async def _failing_load(_model_id: ModelId) -> ModelCard:
raise RuntimeError("Network error")
monkeypatch.setattr(
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
)
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
# Error matches existing placement_error, so no duplicate event emitted
assert len(events) == 0
async def test_meta_instance_reconciler_continues_after_error(
monkeypatch: "pytest.MonkeyPatch",
):
"""Reconciler should continue to next meta-instance after one fails to load."""
import exo.master.process_managers.meta_instance as mi_mod
meta_a = _meta_instance(model_id="org/model-a")
meta_b = _meta_instance(model_id="org/model-b")
topo = _topology("node-a")
state = State(
meta_instances={
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
topology=topo,
)
call_count = 0
async def _load_second_fails(model_id: ModelId) -> ModelCard:
nonlocal call_count
call_count += 1
raise RuntimeError(f"Cannot load {model_id}")
monkeypatch.setattr(
mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_load_second_fails)})
)
reconciler = MetaInstanceReconciler()
events = await reconciler.reconcile(state)
# Both meta-instances should have been attempted (not short-circuited)
assert call_count == 2
# Both should have placement failed events
placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
assert len(placement_failed) == 2
# =============================================================================
# 8. Cascade delete with task cancellation
# =============================================================================
def test_cascade_delete_cancels_active_tasks():
"""Deleting a MetaInstance should cancel tasks on backing instances.
Regression test: previously, cascade-deleting backing instances via
DeleteMetaInstance did not emit TaskStatusUpdated(Cancelled) for active
tasks, leaving orphaned task references in state.
"""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
task_id = TaskId()
task = LoadModel(task_id=task_id, instance_id=iid, task_status=TaskStatus.Running)
# Build state with meta-instance, backing instance, and active task
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
tasks={task_id: task},
topology=_topology("node-a"),
)
# Simulate the cascade-delete event sequence produced by main.py:
# 1. MetaInstanceDeleted
# 2. TaskStatusUpdated(Cancelled) for active tasks
# 3. InstanceDeleted
idx = 0
state = apply(
state,
IndexedEvent(
idx=idx,
event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id),
),
)
idx += 1
state = apply(
state,
IndexedEvent(
idx=idx,
event=TaskStatusUpdated(task_id=task_id, task_status=TaskStatus.Cancelled),
),
)
idx += 1
state = apply(
state,
IndexedEvent(idx=idx, event=InstanceDeleted(instance_id=iid)),
)
# Verify everything is cleaned up
assert len(state.meta_instances) == 0
assert len(state.instances) == 0
assert state.tasks[task_id].task_status == TaskStatus.Cancelled
def test_cascade_delete_skips_completed_tasks():
"""Cascade delete should only cancel Pending/Running tasks, not completed ones."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
running_task_id = TaskId()
completed_task_id = TaskId()
running_task = LoadModel(
task_id=running_task_id, instance_id=iid, task_status=TaskStatus.Running
)
completed_task = LoadModel(
task_id=completed_task_id, instance_id=iid, task_status=TaskStatus.Complete
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
tasks={running_task_id: running_task, completed_task_id: completed_task},
topology=_topology("node-a"),
)
# Only the running task should be cancelled — we verify the logic pattern
# by checking which tasks are Pending or Running
active_tasks = [
t
for t in state.tasks.values()
if t.instance_id == iid
and t.task_status in (TaskStatus.Pending, TaskStatus.Running)
]
assert len(active_tasks) == 1
assert active_tasks[0].task_id == running_task_id

View File

@@ -3,10 +3,10 @@ import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_largest_cycles,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_shard_assignments_for_pipeline_parallel,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
create_node_memory,
@@ -143,7 +143,7 @@ def test_filter_multiple_cycles_by_memory():
}
def test_get_largest_cycles():
def test_get_smallest_cycles():
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
@@ -175,12 +175,12 @@ def test_get_largest_cycles():
cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons
# act
largest_cycles = get_largest_cycles(cycles)
smallest_cycles = get_smallest_cycles(cycles)
# assert
assert len(largest_cycles) == 1
assert len(largest_cycles[0]) == 3
assert set(n for n in largest_cycles[0]) == {node_a_id, node_b_id, node_c_id}
assert len(smallest_cycles) == 1
assert len(smallest_cycles[0]) == 2
assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
@pytest.mark.parametrize(

View File

@@ -1,742 +0,0 @@
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
instance_connections_healthy,
instance_runners_failed,
instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
InstanceId,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerLoading,
RunnerReady,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
return ModelCard(
model_id=ModelId(model_id),
storage_size=Memory.from_kb(1000),
n_layers=10,
hidden_size=30,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
def _topology(*node_ids: str, connect: bool = True) -> Topology:
"""Build a topology with nodes connected in a bidirectional ring with unique IPs.
Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
between consecutive nodes (including wrap-around).
"""
t = Topology()
nodes = [NodeId(n) for n in node_ids]
for n in nodes:
t.add_node(n)
if connect and len(nodes) > 1:
for i in range(len(nodes)):
j = (i + 1) % len(nodes)
t.add_connection(
Connection(
source=nodes[i],
sink=nodes[j],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
)
),
)
)
t.add_connection(
Connection(
source=nodes[j],
sink=nodes[i],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
)
),
)
)
return t
def _meta_instance(
model_id: str = "test-org/test-model",
*,
min_nodes: int = 1,
node_ids: list[NodeId] | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> MetaInstance:
return MetaInstance(
meta_instance_id=meta_instance_id or MetaInstanceId(),
model_id=ModelId(model_id),
min_nodes=min_nodes,
node_ids=node_ids,
)
def _instance(
model_id: str = "test-org/test-model",
node_ids: list[str] | None = None,
instance_id: InstanceId | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
"""Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
iid = instance_id or InstanceId()
nodes = node_ids or ["node-a"]
n = len(nodes)
mc = _model_card(model_id)
ephemeral_port = 50000
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
runner_to_shard = {
runner_id: PipelineShardMetadata(
model_card=mc,
device_rank=i,
world_size=n,
start_layer=0,
end_layer=mc.n_layers,
n_layers=mc.n_layers,
)
for i, runner_id in enumerate(node_to_runner.values())
}
# Build hosts_by_node with IPs matching _topology() convention:
# node at index idx has IP 10.0.0.{idx+1}
hosts_by_node: dict[NodeId, list[Host]] = {}
for r, node_str in enumerate(nodes):
hosts: list[Host] = []
for idx in range(n):
if idx == r:
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
else:
hosts.append(Host(ip="198.51.100.1", port=0))
hosts_by_node[NodeId(node_str)] = hosts
return iid, MlxRingInstance(
instance_id=iid,
shard_assignments=ShardAssignments(
model_id=ModelId(model_id),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
),
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
meta_instance_id=meta_instance_id,
)
# --- instance_satisfies_meta_instance (pure constraint matching) ---
def test_satisfies_matching_model():
meta = _meta_instance()
_, inst = _instance(node_ids=["node-a"])
assert instance_satisfies_meta_instance(meta, inst) is True
def test_not_satisfies_wrong_model():
meta = _meta_instance("test-org/model-a")
_, inst = _instance("test-org/model-b")
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_missing_required_node():
meta = _meta_instance(node_ids=[NodeId("node-c")])
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_fewer_than_min_nodes():
meta = _meta_instance(min_nodes=3)
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_satisfies_with_node_ids_specified():
meta = _meta_instance(node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2)
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
assert instance_satisfies_meta_instance(meta, inst) is True
# --- instance_connections_healthy ---
def test_healthy_single_node_present():
_, inst = _instance(node_ids=["node-a"])
topology = _topology("node-a")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_single_node_missing():
_, inst = _instance(node_ids=["node-a"])
topology = Topology() # empty
assert instance_connections_healthy(inst, topology) is False
def test_healthy_two_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_two_node_edge_removed():
"""Nodes present but edge removed — ring broken."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", connect=False)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_two_node_ip_changed():
"""Edge exists but with a different IP than instance was configured with."""
_, inst = _instance(node_ids=["node-a", "node-b"])
# Build topology with different IPs than _instance() expects
topology = Topology()
topology.add_node(NodeId("node-a"))
topology.add_node(NodeId("node-b"))
topology.add_connection(
Connection(
source=NodeId("node-a"),
sink=NodeId("node-b"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=NodeId("node-b"),
sink=NodeId("node-a"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000")
),
)
)
assert instance_connections_healthy(inst, topology) is False
def test_healthy_three_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_three_node_one_edge_removed():
"""Remove one edge from a three-node ring — instance unhealthy."""
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
# Build topology with one direction of one edge missing
topology = Topology()
nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")]
for n in nodes:
topology.add_node(n)
# Add all edges except node-a → node-b
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[1],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[0],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
# Missing: node-a → node-b (ip 10.0.0.2)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_node_missing_from_topology():
"""Instance has a node that's not in the topology at all."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a") # node-b not present
assert instance_connections_healthy(inst, topology) is False
def test_healthy_extra_nodes_in_topology():
"""Extra nodes in topology don't affect instance health."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
# --- find_unsatisfied_meta_instances ---
def test_unsatisfied_no_meta_instances():
result = find_unsatisfied_meta_instances({}, {}, Topology())
assert list(result) == []
def test_unsatisfied_one_satisfied():
meta = _meta_instance()
id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == []
def test_unsatisfied_one_not_satisfied():
meta = _meta_instance("test-org/model-x")
id_a, inst_a = _instance("test-org/model-y")
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta}, {id_a: inst_a}, topology
)
assert list(result) == [meta]
def test_unsatisfied_mix():
meta_satisfied = _meta_instance("test-org/model-a")
meta_unsatisfied = _meta_instance("test-org/model-b")
id_a, inst_a = _instance(
"test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_satisfied.meta_instance_id: meta_satisfied,
meta_unsatisfied.meta_instance_id: meta_unsatisfied,
},
{id_a: inst_a},
topology,
)
assert list(result) == [meta_unsatisfied]
def test_unsatisfied_node_disconnect():
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a") # node-b disconnected
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_edge_break():
"""Instance exists but its connections broke — meta-instance becomes unsatisfied."""
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a", "node-b", connect=False) # nodes present, no edges
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_idempotent():
meta = _meta_instance("test-org/model-x")
topology = _topology("node-a")
meta_instances = {meta.meta_instance_id: meta}
instances: dict[InstanceId, MlxRingInstance] = {}
result_1 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
result_2 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
assert result_1 == result_2
def test_unsatisfied_exclusive_binding():
"""Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied."""
meta_a = _meta_instance("test-org/model-x")
meta_b = _meta_instance("test-org/model-x")
id_inst, inst = _instance(
"test-org/model-x", meta_instance_id=meta_a.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
{id_inst: inst},
topology,
)
assert list(result) == [meta_b]
# --- apply handlers ---
def test_apply_meta_instance_created():
state = State()
meta = _meta_instance()
event = MetaInstanceCreated(meta_instance=meta)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id in new_state.meta_instances
assert new_state.meta_instances[meta.meta_instance_id] == meta
def test_apply_meta_instance_deleted():
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
def test_apply_meta_instance_deleted_clears_failure_info():
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "OOM"}
)
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
# --- instance_runners_failed ---
def test_runners_failed_all_failed():
"""All runners in RunnerFailed -> instance is failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runners = {
rid: RunnerFailed(error_message="OOM")
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "OOM" in error
def test_runners_failed_mixed_failed_shutdown():
"""One Failed + one Shutdown = failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="crash"),
runner_ids[1]: RunnerShutdown(),
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "crash" in error
def test_runners_not_failed_all_shutdown():
"""All Shutdown (graceful) = not a failure."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerShutdown() for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_still_active():
"""Some runners still active = not failed yet."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerLoading(),
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_no_status():
"""Runner not yet reported = not failed."""
_, inst = _instance(node_ids=["node-a"])
is_failed, _ = instance_runners_failed(inst, {}, {})
assert is_failed is False
def test_runners_not_failed_healthy():
"""Runners in Ready state = not failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerReady() for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
# --- failure tracking in apply_instance_deleted ---
def test_apply_instance_deleted_tracks_failure():
"""InstanceDeleted with failure_error increments meta instance failure count."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "Runner OOM"
def test_apply_instance_deleted_increments_failure():
"""Subsequent failures increment the counter."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "previous error"}
)
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="new error")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 3
assert mi.last_failure_error == "new error"
def test_apply_instance_deleted_no_failure_no_tracking():
"""InstanceDeleted without failure_error does not track."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
def test_apply_instance_deleted_orphan_no_tracking():
"""InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
iid, inst = _instance(node_ids=["node-a"])
state = State(instances={iid: inst})
event = InstanceDeleted(instance_id=iid, failure_error="crash")
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
# --- InstanceRetrying ---
def test_apply_instance_retrying_removes_runners():
"""InstanceRetrying removes the instance's runners from state but keeps the instance."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerShutdown(),
}
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners=runners,
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="OOM",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
# Instance still exists
assert iid in new_state.instances
# Runners removed
assert runner_ids[0] not in new_state.runners
assert runner_ids[1] not in new_state.runners
def test_apply_instance_retrying_increments_failure():
"""InstanceRetrying increments consecutive_failures on the MetaInstance."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "crash"
def test_apply_instance_retrying_skips_missing_runners():
"""InstanceRetrying doesn't assert if runners haven't reported yet."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
# No runners in state at all
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
# Should not raise
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert iid in new_state.instances
def test_apply_instance_created_resets_failure_counter():
"""InstanceCreated resets consecutive_failures but preserves last_failure_error."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 3, "last_failure_error": "old error"}
)
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
event = InstanceCreated(instance=inst)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
assert mi.last_failure_error == "old error"
assert mi.placement_error is None
# --- InstanceHealthReconciler retry-vs-delete ---
async def test_health_reconciler_retries_when_under_limit():
"""InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
assert events[0].instance_id == iid
assert events[0].meta_instance_id == meta.meta_instance_id
async def test_health_reconciler_deletes_when_limit_reached():
"""InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_deletes_without_meta_instance():
"""Instances without a MetaInstance are deleted immediately on runner failure."""
iid, inst = _instance(node_ids=["node-a"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="crash")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_network_failure_always_deletes():
"""Network failure always triggers InstanceDeleted regardless of retry count."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
topology=_topology("node-a"), # node-b missing
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
assert events[0].failure_error == "Network connection lost"

View File

@@ -211,6 +211,14 @@ class Router:
pass
except AllQueuesFullError:
logger.warning(f"All peer queues full, dropping message on {topic}")
except RuntimeError as e:
if "MessageTooLarge" in str(e):
logger.error(
f"Message too large for gossipsub on topic {topic} "
f"({len(data)} bytes), dropping message"
)
else:
raise
def get_node_id_keypair(

View File

@@ -4,7 +4,7 @@ from datetime import datetime
from loguru import logger
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -12,12 +12,6 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
JacclSideChannelData,
JacclSideChannelGathered,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -34,7 +28,6 @@ from exo.shared.types.events import (
TracesCollected,
TracesMerged,
)
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
@@ -73,22 +66,12 @@ def event_apply(event: Event, state: State) -> State:
| InputChunkReceived()
| TracesCollected()
| TracesMerged()
| JacclSideChannelData()
| JacclSideChannelGathered()
): # Pass-through events that don't modify state
return state
case InstanceCreated():
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceRetrying():
return apply_instance_retrying(event, state)
case MetaInstanceCreated():
return apply_meta_instance_created(event, state)
case MetaInstanceDeleted():
return apply_meta_instance_deleted(event, state)
case MetaInstancePlacementFailed():
return apply_meta_instance_placement_failed(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -191,123 +174,20 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
return state.model_copy(update={"tasks": new_tasks})
def _update_meta_instance(
state: State, mid: MetaInstanceId, **fields: object
) -> Mapping[MetaInstanceId, MetaInstance]:
mi = state.meta_instances[mid]
return {**state.meta_instances, mid: mi.model_copy(update=fields)}
def apply_instance_created(event: InstanceCreated, state: State) -> State:
instance = event.instance
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
instance.instance_id: instance,
}
update: dict[str, object] = {"instances": new_instances}
# Reset failure tracking when a new instance is created for a meta-instance
if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
mi = state.meta_instances[instance.meta_instance_id]
if mi.placement_error is not None or mi.consecutive_failures > 0:
update["meta_instances"] = _update_meta_instance(
state,
instance.meta_instance_id,
placement_error=None,
consecutive_failures=0,
)
return state.model_copy(update=update)
return state.model_copy(update={"instances": new_instances})
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
deleted_instance = state.instances.get(event.instance_id)
new_instances: Mapping[InstanceId, Instance] = {
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
}
update: dict[str, object] = {"instances": new_instances}
# Track failure on the MetaInstance itself
if (
event.failure_error
and deleted_instance
and deleted_instance.meta_instance_id
and deleted_instance.meta_instance_id in state.meta_instances
):
mid = deleted_instance.meta_instance_id
mi = state.meta_instances[mid]
update["meta_instances"] = {
**state.meta_instances,
mid: mi.model_copy(
update={
"consecutive_failures": mi.consecutive_failures + 1,
"last_failure_error": event.failure_error,
}
),
}
return state.model_copy(update=update)
def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
"""Runners failed but retry limit not reached — remove runners, keep instance."""
instance = state.instances.get(event.instance_id)
if instance is None:
# Instance was already deleted (e.g. cascade from DeleteMetaInstance).
# The InstanceDeleted handler already incremented consecutive_failures
# on the MetaInstance, so skipping here avoids double-counting.
return state
# Remove all runners belonging to this instance from state
runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
new_runners: Mapping[RunnerId, RunnerStatus] = {
rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
}
update: dict[str, object] = {"runners": new_runners}
# Increment failure count on the MetaInstance
if event.meta_instance_id in state.meta_instances:
update["meta_instances"] = _update_meta_instance(
state,
event.meta_instance_id,
consecutive_failures=state.meta_instances[
event.meta_instance_id
].consecutive_failures
+ 1,
last_failure_error=event.failure_error,
)
return state.model_copy(update=update)
def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
**state.meta_instances,
event.meta_instance.meta_instance_id: event.meta_instance,
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
mid: mi
for mid, mi in state.meta_instances.items()
if mid != event.meta_instance_id
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_placement_failed(
event: MetaInstancePlacementFailed, state: State
) -> State:
if event.meta_instance_id not in state.meta_instances:
return state
return state.model_copy(
update={
"meta_instances": _update_meta_instance(
state, event.meta_instance_id, placement_error=event.reason
)
}
)
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:

View File

@@ -6,7 +6,7 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -262,26 +262,6 @@ class DeleteInstanceResponse(BaseModel):
instance_id: InstanceId
class CreateMetaInstanceParams(BaseModel):
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
node_ids: list[NodeId] | None = None
class CreateMetaInstanceResponse(BaseModel):
message: str
command_id: CommandId
meta_instance_id: MetaInstanceId
class DeleteMetaInstanceResponse(BaseModel):
message: str
command_id: CommandId
meta_instance_id: MetaInstanceId
class AdvancedImageParams(BaseModel):
seed: Annotated[int, Field(ge=0)] | None = None
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
@@ -386,15 +366,6 @@ class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId
class DistributeModelParams(CamelCaseModel):
target_node_ids: list[NodeId] | None = None # None = all connected nodes
class DistributeModelResponse(CamelCaseModel):
command_id: CommandId
message: str
class TraceEventResponse(CamelCaseModel):
name: str
start_us: int

View File

@@ -6,8 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -53,14 +52,6 @@ class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class CreateMetaInstance(BaseCommand):
meta_instance: MetaInstance
class DeleteMetaInstance(BaseCommand):
meta_instance_id: MetaInstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -90,14 +81,6 @@ class CancelDownload(BaseCommand):
model_id: ModelId
class DistributeModel(BaseCommand):
"""Distribute model files from one node to others via MLX distributed."""
model_id: ModelId
source_node_id: NodeId
target_node_ids: list[NodeId]
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
@@ -111,11 +94,8 @@ Command = (
| CreateInstance
| DeleteInstance
| TaskCancelled
| CreateMetaInstance
| DeleteMetaInstance
| TaskFinished
| SendInputChunk
| DistributeModel
)

View File

@@ -42,10 +42,6 @@ class CommandId(Id):
pass
class MetaInstanceId(Id):
"""Identifier for a MetaInstance."""
class Host(CamelCaseModel):
ip: str
port: int

View File

@@ -1,14 +1,11 @@
import base64
from collections.abc import Mapping
from datetime import datetime
from typing import Annotated, final
from typing import final
from pydantic import BeforeValidator, Field, PlainSerializer
from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -17,28 +14,6 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
def _decode_base64_bytes(v: bytes | str) -> bytes:
if isinstance(v, bytes):
return v
return base64.b64decode(v)
def _encode_base64_bytes(v: bytes) -> str:
return base64.b64encode(v).decode("ascii")
Base64Bytes = Annotated[
bytes,
BeforeValidator(_decode_base64_bytes),
PlainSerializer(_encode_base64_bytes, return_type=str),
]
"""bytes that serialize to/from base64 strings in JSON.
Needed because TaggedModel's wrap validator converts JSON→Python validation
context, which breaks strict-mode bytes deserialization from JSON strings.
"""
class EventId(Id):
"""
Newtype around `ID`
@@ -91,30 +66,6 @@ class InstanceCreated(BaseEvent):
class InstanceDeleted(BaseEvent):
instance_id: InstanceId
failure_error: str | None = None
class MetaInstanceCreated(BaseEvent):
meta_instance: MetaInstance
class MetaInstanceDeleted(BaseEvent):
meta_instance_id: MetaInstanceId
@final
class MetaInstancePlacementFailed(BaseEvent):
meta_instance_id: MetaInstanceId
reason: str
@final
class InstanceRetrying(BaseEvent):
"""Runners failed but retry count is below the limit — restart runners, keep instance."""
instance_id: InstanceId
meta_instance_id: MetaInstanceId
failure_error: str
class RunnerStatusUpdated(BaseEvent):
@@ -181,25 +132,6 @@ class TracesMerged(BaseEvent):
traces: list[TraceEventData]
@final
class JacclSideChannelData(BaseEvent):
"""A runner's local contribution to a JACCL SideChannel all_gather round."""
instance_id: InstanceId
runner_id: RunnerId
sequence: int
data: Base64Bytes
@final
class JacclSideChannelGathered(BaseEvent):
"""Gathered result of a JACCL SideChannel all_gather round."""
instance_id: InstanceId
sequence: int
gathered_data: Mapping[RunnerId, Base64Bytes]
Event = (
TestEvent
| TaskCreated
@@ -209,10 +141,6 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceRetrying
| MetaInstanceCreated
| MetaInstanceDeleted
| MetaInstancePlacementFailed
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut
@@ -224,8 +152,6 @@ Event = (
| TopologyEdgeDeleted
| TracesCollected
| TracesMerged
| JacclSideChannelData
| JacclSideChannelGathered
)

View File

@@ -1,25 +0,0 @@
from typing import final
from pydantic import Field
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import FrozenModel
@final
class MetaInstance(FrozenModel):
"""Declarative constraint: ensure an instance matching these parameters always exists."""
meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
node_ids: list[NodeId] | None = None
# Failure tracking
placement_error: str | None = None
consecutive_failures: int = 0
last_failure_error: str | None = None

View File

@@ -6,8 +6,7 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import (
DiskUsage,
MemoryUsage,
@@ -42,7 +41,6 @@ class State(CamelCaseModel):
arbitrary_types_allowed=True,
)
instances: Mapping[InstanceId, Instance] = {}
meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}

View File

@@ -42,7 +42,7 @@ class DownloadModel(BaseTask): # emitted by Worker
class LoadModel(BaseTask): # emitted by Worker
has_local_model: bool = Field(default=True)
pass
class ConnectToGroup(BaseTask): # emitted by Worker
@@ -61,7 +61,7 @@ class TextGeneration(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask): # emitted by Worker when master cancels a task
class CancelTask(BaseTask):
cancelled_task_id: TaskId
runner_id: RunnerId
@@ -82,13 +82,6 @@ class ImageEdits(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class TransferModelToDisk(BaseTask): # emitted by Worker
"""Transfer all model files from source to receivers' disk via MLX distributed."""
shard_metadata: ShardMetadata
has_local_model: bool = Field(default=True)
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -98,7 +91,6 @@ Task = (
| DownloadModel
| ConnectToGroup
| LoadModel
| TransferModelToDisk
| StartWarmup
| TextGeneration
| CancelTask

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,7 +19,6 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
meta_instance_id: MetaInstanceId | None = None
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -84,7 +84,6 @@ class ShardAssignments(CamelCaseModel):
model_id: ModelId
runner_to_shard: Mapping[RunnerId, ShardMetadata]
node_to_runner: Mapping[NodeId, RunnerId]
transfer_only: bool = False
@model_validator(mode="after")
def validate_runners_exist(self) -> "ShardAssignments":

View File

@@ -125,7 +125,9 @@ class MpSender[T]:
self._state.buffer.put(item, block=True)
async def send_async(self, item: T) -> None:
await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
await to_thread.run_sync(
self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
)
def close(self) -> None:
if not self._state.closed.is_set():

View File

@@ -47,7 +47,6 @@ if TYPE_CHECKING:
from mlx_lm.models.cache import Cache
TimeoutCallback = Callable[[], None]
WeightLoader = Callable[[nn.Module, int], None] | None
def eval_with_timeout(
@@ -347,7 +346,6 @@ def tensor_auto_parallel(
group: mx.distributed.Group,
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
weight_loader: WeightLoader = None,
) -> nn.Module:
all_to_sharded_linear = partial(
shard_linear,
@@ -457,7 +455,7 @@ def tensor_auto_parallel(
raise ValueError(f"Unsupported model type: {type(model)}")
model = tensor_parallel_sharding_strategy.shard_model(
model, timeout_seconds, on_timeout, weight_loader
model, timeout_seconds, on_timeout
)
return patch_tensor_model(model)
@@ -484,7 +482,6 @@ class TensorParallelShardingStrategy(ABC):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module: ...
@@ -494,12 +491,9 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(LlamaModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -551,12 +545,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -629,12 +620,9 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(GLM4MoeLiteModel, model)
for i, layer in enumerate(model.layers): # type: ignore
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers: # type: ignore
layer = cast(Glm4MoeLiteDecoderLayer, layer)
eval_with_timeout(
layer.parameters(),
@@ -774,12 +762,9 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -817,12 +802,9 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -944,12 +926,9 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(Glm4MoeModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -993,13 +972,10 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(GptOssMoeModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -1037,13 +1013,10 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(Step35Model, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)

View File

@@ -1,507 +0,0 @@
"""
Model transfer via MLX distributed all_sum.
Three transfer modes:
1. Metadata file transfer: broadcast small files (config.json, tokenizer, etc.) to disk
2. Weight tensor broadcast: stream weight tensors directly into memory via all_sum
3. Full file transfer: broadcast all files (including safetensors) to disk
All functions are collective operations — every rank in the group must call them.
Protocol relies on all_sum: source has real data, receivers have zeros.
all_sum(source + zeros) = source data on all ranks.
"""
from __future__ import annotations
import json
import os
import re
import shutil
import tempfile
from functools import partial
from pathlib import Path
from typing import Any, Final, cast
import mlx.core as mx
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
CHUNK_SIZE: Final[int] = 100 * 1024 * 1024 # 100 MB
_LAYER_RE: Final[re.Pattern[str]] = re.compile(r"(?:^|\.)(layers|h)\.(\d+)\.")
def _all_sum_cpu(x: mx.array, group: Group) -> mx.array:
"""all_sum on CPU stream to avoid GPU memory pressure."""
return mx.distributed.all_sum(
x, stream=mx.default_stream(mx.Device(mx.cpu)), group=group
)
def _is_metadata_file(filename: str) -> bool:
"""A metadata file is anything that isn't a weight file or weight index.
Weight indices (.safetensors.index.json) reference safetensors shard paths.
Transferring them to a receiver that has no safetensors files is harmless
today (load_model's glob doesn't match them), but excluding them avoids
stale references and keeps the transfer minimal.
"""
if filename.endswith(".safetensors"):
return False
return not filename.endswith(".safetensors.index.json")
def model_path_for_id(model_id: ModelId) -> Path:
"""Get model path without requiring directory to exist (unlike build_model_path)."""
return EXO_MODELS_DIR / model_id.normalize()
def coordinate_transfer(group: Group, has_local_model: bool) -> tuple[bool, int]:
"""
Determine if a transfer is needed and which rank is the source.
All ranks must call this function (uses collective all_sum).
Returns:
(needs_transfer, source_rank) — source_rank is the lowest rank
that has the model. needs_transfer is True if any rank is missing it.
"""
all_sum = partial(_all_sum_cpu, group=group)
world_size = group.size()
# Each rank broadcasts a one-hot vector at its position if it has the model
bitmask = mx.zeros(world_size, dtype=mx.int32)
if has_local_model:
bitmask = bitmask.at[group.rank()].add(1)
summed = all_sum(bitmask)
mx.eval(summed)
has_model_flags: list[int] = summed.tolist() # type: ignore[assignment]
total_have = sum(has_model_flags)
if total_have == 0:
raise RuntimeError(
"No rank has the model files — cannot transfer. "
"At least one node must have downloaded the model."
)
if total_have == world_size:
logger.info("All ranks have model files, no transfer needed")
return False, 0
source_rank = next(i for i, flag in enumerate(has_model_flags) if flag > 0)
logger.info(
f"Transfer needed: source_rank={source_rank}, "
f"{total_have}/{world_size} ranks have model"
)
return True, source_rank
def _broadcast_json(obj: object, group: Group, is_source: bool) -> object:
"""Broadcast a JSON-serializable object from source to all ranks."""
all_sum = partial(_all_sum_cpu, group=group)
data = json.dumps(obj, separators=(",", ":")).encode("utf-8") if is_source else b""
# Broadcast length
len_arr = mx.array([len(data) if is_source else 0], dtype=mx.int64)
len_result = all_sum(len_arr)
mx.eval(len_result)
length = int(len_result.item())
if length == 0:
return None
# Broadcast payload
if is_source:
arr = mx.array(list(data), dtype=mx.uint8)
else:
arr = mx.zeros(length, dtype=mx.uint8)
result = all_sum(arr)
mx.eval(result)
return json.loads(bytes(cast(list[int], result.tolist()))) # pyright: ignore[reportAny]
def _build_manifest(
model_path: Path, metadata_only: bool = False
) -> list[dict[str, str | int]]:
"""Build a list of files in the model directory with their relative paths and sizes."""
manifest: list[dict[str, str | int]] = []
for root, _dirs, files in os.walk(model_path):
for fname in sorted(files):
if metadata_only and not _is_metadata_file(fname):
continue
full_path = Path(root) / fname
rel_path = str(full_path.relative_to(model_path))
manifest.append(
{
"path": rel_path,
"size": full_path.stat().st_size,
}
)
return manifest
def _transfer_file_to_disk(
source_path: Path,
rel_path: str,
file_size: int,
group: Group,
is_source: bool,
dest_path: Path,
) -> None:
"""Transfer a single file chunk-by-chunk via all_sum. Source reads from disk, receivers write to dest_path."""
all_sum = partial(_all_sum_cpu, group=group)
if is_source:
src_file = source_path / rel_path
with open(src_file, "rb") as f:
offset = 0
while offset < file_size:
chunk_bytes = min(CHUNK_SIZE, file_size - offset)
data = f.read(chunk_bytes)
if not data:
break
size_arr = mx.array([len(data)], dtype=mx.int64)
mx.eval(all_sum(size_arr))
chunk_arr = mx.array(list(data), dtype=mx.uint8)
result = all_sum(chunk_arr)
mx.eval(result)
offset += len(data)
# Signal end of file
mx.eval(all_sum(mx.array([0], dtype=mx.int64)))
else:
dst_file = dest_path / rel_path
os.makedirs(dst_file.parent, exist_ok=True)
with open(dst_file, "wb") as f:
while True:
size_arr = all_sum(mx.zeros(1, dtype=mx.int64))
mx.eval(size_arr)
chunk_size = int(size_arr.item())
if chunk_size == 0:
break
chunk_data = all_sum(mx.zeros(chunk_size, dtype=mx.uint8))
mx.eval(chunk_data)
f.write(bytes(cast(list[int], chunk_data.tolist())))
def _transfer_files_to_disk(
model_path: Path,
group: Group,
is_source: bool,
metadata_only: bool = False,
) -> None:
"""
Transfer files from source to all receivers' disk.
Source broadcasts a manifest then each file. Receivers write to a temp dir
then atomically move files to model_path.
"""
if is_source:
source_manifest = _build_manifest(model_path, metadata_only=metadata_only)
else:
source_manifest = []
manifest = cast(
list[dict[str, str | int]],
_broadcast_json(source_manifest if is_source else None, group, is_source),
)
if not manifest:
logger.info("No files to transfer")
return
logger.info(
f"Transferring {len(manifest)} files ({'metadata only' if metadata_only else 'all'})"
)
temp_dir: Path | None = None
if not is_source:
os.makedirs(model_path.parent, exist_ok=True)
temp_dir = Path(
tempfile.mkdtemp(
dir=model_path.parent,
prefix=f".transfer_{model_path.name}_",
)
)
try:
for entry in manifest:
rel_path = str(entry["path"])
file_size = int(entry["size"])
logger.info(f" {rel_path} ({file_size} bytes)")
_transfer_file_to_disk(
source_path=model_path,
rel_path=rel_path,
file_size=file_size,
group=group,
is_source=is_source,
dest_path=temp_dir if temp_dir is not None else model_path,
)
if temp_dir is not None:
os.makedirs(model_path, exist_ok=True)
for entry in manifest:
rel_path = str(entry["path"])
src = temp_dir / rel_path
dst = model_path / rel_path
os.makedirs(dst.parent, exist_ok=True)
os.replace(src, dst)
logger.info(
f"Transfer complete: {len(manifest)} files moved to {model_path}"
)
finally:
if temp_dir is not None and temp_dir.exists():
shutil.rmtree(temp_dir, ignore_errors=True)
def transfer_metadata_files(model_path: Path, group: Group, is_source: bool) -> None:
"""
Transfer metadata files (config.json, tokenizer files, etc.) to receivers' disk.
All ranks must call this function (collective operation).
Only the designated source (is_source=True) should send; all others receive.
"""
_transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=True)
def transfer_all_files(model_path: Path, group: Group, is_source: bool) -> None:
"""
Transfer ALL model files (including safetensors) to receivers' disk.
All ranks must call this function (collective operation).
Only the designated source (is_source=True) should send; all others receive.
"""
_transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=False)
def _parse_mx_dtype(dtype_str: str) -> mx.Dtype:
"""Convert a dtype string like 'float16' or 'mlx.core.float16' to mx.Dtype."""
name = dtype_str.split(".")[-1]
dtype = getattr(mx, name, None)
if dtype is None:
raise ValueError(f"Unknown MLX dtype: {dtype_str}")
return dtype # type: ignore[return-value]
def _extract_layer_index(name: str) -> int | None:
"""Extract layer index from a weight name, or None for non-layer weights.
Matches patterns like ``model.layers.5.self_attn.q_proj.weight``
or ``transformer.h.12.mlp.gate_proj.scales``.
"""
m = _LAYER_RE.search(name)
return int(m.group(2)) if m else None
class WeightBroadcastState:
"""Holds state for layer-by-layer weight broadcasting.
Created by :func:`prepare_weight_broadcast`. Callers stream weights
incrementally via :meth:`broadcast_non_layer_weights` and
:meth:`broadcast_layer` so that at most one layer's worth of un-sharded
weight data is resident at a time.
"""
def __init__(
self,
meta: dict[str, dict[str, Any]],
source_weights: dict[str, mx.array] | None,
group: Group,
is_source: bool,
) -> None:
self.meta = meta
self.source_weights = source_weights
self.group = group
self.is_source = is_source
# Partition weight names into layer vs. non-layer
self.layer_names: dict[int, list[str]] = {}
self.non_layer_names: list[str] = []
for name in sorted(meta.keys()):
layer_idx = _extract_layer_index(name)
if layer_idx is not None:
self.layer_names.setdefault(layer_idx, []).append(name)
else:
self.non_layer_names.append(name)
logger.info(
f"WeightBroadcastState: {len(self.non_layer_names)} non-layer weights, "
f"{len(self.layer_names)} layers"
)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _broadcast_names(self, names: list[str]) -> dict[str, mx.array]:
"""Broadcast a specific set of weight tensors by name."""
all_sum = partial(_all_sum_cpu, group=self.group)
result: dict[str, mx.array] = {}
for name in names:
info = self.meta[name]
shape = cast(list[int], info["s"])
dtype = _parse_mx_dtype(cast(str, info["d"]))
if self.is_source:
assert self.source_weights is not None
tensor = self.source_weights.pop(name)
mx.eval(tensor) # loads from disk (lazy)
else:
tensor = mx.zeros(shape, dtype=dtype)
broadcasted = all_sum(tensor)
mx.eval(broadcasted)
result[name] = broadcasted
return result
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def broadcast_non_layer_weights(self) -> dict[str, mx.array]:
"""Broadcast non-layer weights (embeddings, norms, lm_head)."""
if not self.non_layer_names:
return {}
logger.info(
f"Broadcasting {len(self.non_layer_names)} non-layer weight tensors"
)
return self._broadcast_names(self.non_layer_names)
def broadcast_layer(self, layer_idx: int) -> dict[str, mx.array]:
"""Broadcast weights for a single transformer layer."""
names = self.layer_names.get(layer_idx, [])
if not names:
return {}
return self._broadcast_names(names)
def prepare_weight_broadcast(
model_path: Path,
group: Group,
is_source: bool,
) -> WeightBroadcastState:
"""Prepare for layer-by-layer weight broadcasting.
Source loads safetensors lazily and broadcasts weight metadata (names,
shapes, dtypes) as JSON. Returns a :class:`WeightBroadcastState` that
can then stream weights incrementally via ``broadcast_layer()``.
All ranks must call this function (collective operation).
"""
source_weights: dict[str, mx.array] | None = None
if is_source:
source_weights = {}
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
weight_files = sorted(model_path.glob("**/*.safetensors"))
for wf in weight_files:
try:
loaded = cast(
dict[str, mx.array],
mx.load(str(wf), lazy=True), # pyright: ignore[reportCallIssue]
)
except TypeError:
loaded = cast(dict[str, mx.array], mx.load(str(wf)))
source_weights.update(loaded)
logger.info(
f"Source loaded {len(source_weights)} weight tensors (lazy) "
f"from {len(weight_files)} files"
)
# Broadcast metadata
if is_source and source_weights is not None:
source_meta: dict[str, dict[str, Any]] = {
name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
for name, tensor in source_weights.items()
}
else:
source_meta = {}
meta = cast(
dict[str, dict[str, Any]],
_broadcast_json(source_meta if is_source else None, group, is_source),
)
logger.info(f"Weight broadcast prepared: {len(meta)} tensors")
return WeightBroadcastState(meta, source_weights, group, is_source)
def broadcast_model_weights(
model_path: Path,
group: Group,
is_source: bool,
) -> dict[str, mx.array]:
"""
Broadcast model weight tensors from source rank to all receivers' memory.
Source loads weights from .safetensors files on disk and broadcasts each
tensor via all_sum. Receivers receive tensors directly as mx.arrays in
memory — no disk write for weight data.
All ranks must call this function (collective operation).
Only the designated source (is_source=True) should send; all others receive.
Returns:
dict mapping weight names to mx.arrays (on all ranks).
"""
all_sum = partial(_all_sum_cpu, group=group)
# Source loads weights (lazy if supported, so only one tensor in memory at a time)
weights: dict[str, mx.array] = {}
if is_source:
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
weight_files = sorted(model_path.glob("**/*.safetensors"))
for wf in weight_files:
try:
loaded = cast(dict[str, mx.array], mx.load(str(wf), lazy=True)) # pyright: ignore[reportCallIssue]
except TypeError:
loaded = cast(dict[str, mx.array], mx.load(str(wf)))
weights.update(loaded)
logger.info(
f"Source loaded {len(weights)} weight tensors from {len(weight_files)} files"
)
# Broadcast weight metadata: {name: {shape, dtype}}
if is_source:
source_meta: dict[str, dict[str, Any]] = {
name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
for name, tensor in weights.items()
}
else:
source_meta = {}
meta = cast(
dict[str, dict[str, Any]],
_broadcast_json(source_meta if is_source else None, group, is_source),
)
logger.info(f"Broadcasting {len(meta)} weight tensors")
# Broadcast each tensor in sorted order (deterministic across ranks).
# Source loads one tensor at a time from disk (lazy), broadcasts it,
# then drops the reference so only one tensor is in flight at a time.
result: dict[str, mx.array] = {}
for i, name in enumerate(sorted(meta.keys())):
info = meta[name]
shape = cast(list[int], info["s"])
dtype_str = cast(str, info["d"])
dtype = _parse_mx_dtype(dtype_str)
if is_source:
tensor = weights.pop(name) # pop to free lazy ref after broadcast
mx.eval(tensor) # loads from disk
else:
tensor = mx.zeros(shape, dtype=dtype)
broadcasted = all_sum(tensor)
mx.eval(broadcasted)
result[name] = broadcasted
if (i + 1) % 100 == 0:
logger.info(f" Broadcast {i + 1}/{len(meta)} tensors")
logger.info(f"Weight broadcast complete: {len(result)} tensors")
return result

View File

@@ -2,7 +2,6 @@ import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -60,13 +59,6 @@ from exo.worker.engines.mlx.auto_parallel import (
pipeline_auto_parallel,
tensor_auto_parallel,
)
from exo.worker.engines.mlx.model_transfer import (
WeightBroadcastState,
coordinate_transfer,
model_path_for_id,
prepare_weight_broadcast,
transfer_metadata_files,
)
from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
@@ -179,7 +171,6 @@ def load_mlx_items(
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
has_local_model: bool = True,
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
@@ -194,10 +185,7 @@ def load_mlx_items(
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard,
group=group,
on_timeout=on_timeout,
has_local_model=has_local_model,
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
@@ -213,89 +201,30 @@ def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
has_local_model: bool = True,
) -> tuple[nn.Module, TokenizerWrapper]:
model_id = shard_metadata.model_card.model_id
model_path = model_path_for_id(model_id)
model_path = build_model_path(shard_metadata.model_card.model_id)
# Coordinate: does any rank need a transfer?
needs_transfer, source_rank = coordinate_transfer(group, has_local_model)
is_source = group.rank() == source_rank
# Step 1: Always ensure all nodes have metadata files (config, tokenizer, etc.).
# This is cheap (~20MB, ~1s) and guarantees config.json is present for load_model().
transfer_metadata_files(model_path, group, is_source)
# Step 2: Only broadcast weights if some rank is missing the model
broadcast_state: WeightBroadcastState | None = None
if needs_transfer:
logger.info(
f"Model transfer needed (source_rank={source_rank}, "
f"is_source={is_source}, local_weights={has_local_model})"
)
broadcast_state = prepare_weight_broadcast(model_path, group, is_source)
# Create model architecture (all ranks have config.json on disk now).
# Always use lazy=True when we have broadcast state: load_model's internal
# nn.quantize skips quantization when weights dict is empty (no safetensors),
# leaving the model un-quantized. lazy=False would then mx.eval() the full
# fp16 model (~72GB for a 36B-param model), causing OOM on the receiver.
# We handle quantization ourselves below before loading broadcast weights.
use_lazy = has_local_model or broadcast_state is not None
model, _ = load_model(model_path, lazy=use_lazy, strict=False)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
# TODO: See if we should quantize the model.
# def is_attention_layer(path: str) -> bool:
# path = path.lower()
# return "self_attn" in path and "layernorm" not in path
# def quant_predicate(path: str, module: nn.Module):
# if not isinstance(module, nn.Linear):
# return False
# return is_attention_layer(path)
# model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
# )
assert isinstance(model, nn.Module)
if broadcast_state is not None:
# When receiver has no weight files, load_model skips quantization
# (its class_predicate checks `f"{p}.scales" in weights`, which is
# always False when weights is empty). Apply quantization explicitly
# using the broadcast metadata to determine which layers are quantized,
# matching load_model's selective quantization logic exactly.
if not has_local_model:
config_path = model_path / "config.json"
with open(config_path) as f:
config = json.load(f) # pyright: ignore[reportAny]
quant_config: dict[str, Any] | None = config.get( # pyright: ignore[reportAny]
"quantization", None
)
if quant_config is not None:
logger.info(f"Applying quantization to receiver model: {quant_config}")
broadcast_weight_names = set(broadcast_state.meta.keys())
def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
# Per-layer overrides from config (e.g. "lm_head": false)
assert quant_config is not None
if p in quant_config:
return quant_config[p] # pyright: ignore[reportAny]
if not hasattr(m, "to_quantized"):
return False
# Only quantize layers whose .scales exist in broadcast weights
return f"{p}.scales" in broadcast_weight_names
group_size = int(quant_config.get("group_size", 64)) # pyright: ignore[reportAny]
bits = int(quant_config.get("bits", 4)) # pyright: ignore[reportAny]
mode: str = quant_config.get("mode", "affine") # pyright: ignore[reportAny]
nn.quantize( # pyright: ignore[reportUnknownMemberType]
model,
group_size=group_size,
bits=bits,
mode=mode,
class_predicate=_class_predicate,
)
# Broadcast and load non-layer weights (embeddings, norms, lm_head) upfront.
# These are small (~600MB) and needed before the sharding loop.
non_layer_weights = broadcast_state.broadcast_non_layer_weights()
if non_layer_weights:
model.load_weights(list(non_layer_weights.items()), strict=False)
logger.info(f"Loaded {len(non_layer_weights)} non-layer weight tensors")
del non_layer_weights
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
@@ -309,43 +238,12 @@ def shard_and_load(
f"(model size: {model_size_gb:.1f}GB)"
)
# Build per-layer weight loader for streaming broadcast during sharding.
# Each layer's weights are broadcast via all_sum just before that layer is
# sharded, so at most one un-sharded layer is in memory at a time.
weight_loader_fn: Callable[[nn.Module, int], None] | None = None
if broadcast_state is not None:
_state = broadcast_state # capture for closure
def _load_layer_weights(mdl: nn.Module, layer_idx: int) -> None:
layer_weights = _state.broadcast_layer(layer_idx)
if layer_weights:
mdl.load_weights(list(layer_weights.items()), strict=False)
weight_loader_fn = _load_layer_weights
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(
model, group, timeout_seconds, on_timeout, weight_loader_fn
)
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
# Broadcast all layers (all_sum is collective — all ranks must
# participate) but only load weights for layers this node will
# keep after pipeline slicing. Out-of-range results are discarded,
# keeping peak memory proportional to this node's layer count.
if broadcast_state is not None:
for layer_idx in sorted(broadcast_state.layer_names.keys()):
layer_weights = broadcast_state.broadcast_layer(layer_idx)
if (
shard_metadata.start_layer
<= layer_idx
< shard_metadata.end_layer
and layer_weights
):
model.load_weights(list(layer_weights.items()), strict=False)
del layer_weights
model = pipeline_auto_parallel(model, group, shard_metadata)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
case CfgShardMetadata():
@@ -354,8 +252,6 @@ def shard_and_load(
"this metadata type is only for image generation models"
)
del broadcast_state
# TODO: Do we need this?
mx.eval(model)
@@ -678,11 +574,6 @@ def mlx_cleanup(
def mx_any(bool_: bool, group: Group | None) -> bool:
"""Synchronize a boolean across all distributed nodes.
Returns True if any node has bool_=True. Uses all_sum so every
node participates in the collective — preventing GPU deadlocks.
"""
if group is None:
return bool_
num_true = mx.distributed.all_sum(

View File

@@ -24,7 +24,6 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
JacclSideChannelGathered,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -34,6 +33,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -159,15 +159,6 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Dispatch JACCL gathered events to the relevant RunnerSupervisor
if isinstance(event, JacclSideChannelGathered):
for runner in self.runners.values():
if (
runner.bound_instance.instance.instance_id
== event.instance_id
):
runner.notify_gathered(event)
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
@@ -234,15 +225,22 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await self.runners.pop(runner_id).start_task(task)
await runner.start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
runner.shutdown()
case CancelTask(
cancelled_task_id=cancelled_task_id, runner_id=runner_id
):
await self.runners[runner_id].cancel_task(cancelled_task_id)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -280,18 +278,18 @@ class Worker:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
await self._start_runner_task(modified_task)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
await self._start_runner_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
def _task_to_runner_id(self, task: Task):
instance = self.state.instances[task.instance_id]
return instance.shard_assignments.node_to_runner[self.node_id]
async def _start_runner_task(self, task: Task):
if (instance := self.state.instances.get(task.instance_id)) is not None:
await self.runners[
instance.shard_assignments.node_to_runner[self.node_id]
].start_task(task)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.

View File

@@ -2,7 +2,6 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
@@ -18,7 +17,6 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
TextGeneration,
TransferModelToDisk,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
@@ -37,11 +35,8 @@ from exo.shared.types.worker.runners import (
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
ShardAssignments,
)
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -61,10 +56,9 @@ def plan(
return (
_cancel_tasks(runners, tasks)
or _kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances, all_runners)
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _transfer_model_to_disk(runners, all_runners, global_download_status)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
@@ -81,12 +75,6 @@ def _kill_runner(
if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
return Shutdown(instance_id=instance_id, runner_id=runner_id)
# Master removed our runner from state (retry signal) and process is dead
if runner_id not in all_runners and isinstance(
runner.status, (RunnerFailed, RunnerShutdown)
):
return Shutdown(instance_id=instance_id, runner_id=runner_id)
for (
global_runner_id
) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
@@ -104,7 +92,6 @@ def _create_runner(
node_id: NodeId,
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus],
) -> CreateRunner | None:
for instance in instances.values():
runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
@@ -114,16 +101,6 @@ def _create_runner(
if runner_id in runners:
continue
# Don't create while any peer runner is in a terminal state — wait for
# the master to emit InstanceRetrying which removes them from state.
has_terminal_peer = any(
isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown))
for peer_rid in instance.shard_assignments.node_to_runner.values()
if peer_rid != runner_id
)
if has_terminal_peer:
continue
shard = instance.shard(runner_id)
assert shard is not None
@@ -146,10 +123,6 @@ def _model_needs_download(
}
for runner in runners.values():
# Transfer-only instances don't need downloads
if runner.bound_instance.instance.shard_assignments.transfer_only:
continue
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
@@ -158,15 +131,6 @@ def _model_needs_download(
(DownloadOngoing, DownloadCompleted, DownloadFailed),
)
):
# For multi-node instances, skip download if a peer already has the model.
# The model will be transferred via MLX distributed during LoadModel.
instance = runner.bound_instance.instance
is_multi_node = len(instance.shard_assignments.node_to_runner) > 1
if is_multi_node and _any_peer_has_model(
node_id, model_id, instance, global_download_status
):
continue
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
instance_id=runner.bound_instance.instance.instance_id,
@@ -224,43 +188,6 @@ def _init_distributed_backend(
return None
def _transfer_model_to_disk(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> TransferModelToDisk | None:
"""For transfer-only instances: after all ranks are connected, emit TransferModelToDisk."""
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
if not shard_assignments.transfer_only:
continue
is_runner_connected = isinstance(runner.status, RunnerConnected)
all_connected_or_further = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerShuttingDown, RunnerShutdown),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if is_runner_connected and all_connected_or_further:
has_local = _node_has_download(
runner.bound_instance.bound_node_id,
shard_assignments.model_id,
global_download_status,
)
return TransferModelToDisk(
instance_id=instance.instance_id,
shard_metadata=runner.bound_instance.bound_shard,
has_local_model=has_local,
)
return None
def _load_model(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -270,97 +197,38 @@ def _load_model(
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
# Transfer-only instances don't load models for inference
if shard_assignments.transfer_only:
all_local_downloads_complete = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner
)
if not all_local_downloads_complete:
continue
is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
if is_single_node_instance:
# Single-node: require local download complete
if not _all_downloads_complete(shard_assignments, global_download_status):
continue
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id, has_local_model=True)
else:
# Multi-node: require at least one node to have the model downloaded.
# Nodes without the model will receive it via MLX distributed transfer
# during model loading.
if not _any_download_complete(shard_assignments, global_download_status):
continue
is_runner_waiting = isinstance(runner.status, RunnerConnected)
is_runner_waiting = isinstance(runner.status, RunnerConnected)
all_ready_for_model = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
all_ready_for_model = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if is_runner_waiting and all_ready_for_model:
has_local = _node_has_download(
runner.bound_instance.bound_node_id,
shard_assignments.model_id,
global_download_status,
)
return LoadModel(
instance_id=instance.instance_id,
has_local_model=has_local,
)
if is_runner_waiting and all_ready_for_model:
return LoadModel(instance_id=instance.instance_id)
return None
def _node_has_download(
nid: NodeId,
model_id: ModelId,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if a specific node has completed downloading the given model."""
return any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == model_id
for dp in global_download_status.get(nid, [])
)
def _any_peer_has_model(
node_id: NodeId,
model_id: ModelId,
instance: Instance,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if any other node in the instance already has the model downloaded."""
return any(
_node_has_download(nid, model_id, global_download_status)
for nid in instance.shard_assignments.node_to_runner
if nid != node_id
)
def _all_downloads_complete(
shard_assignments: ShardAssignments,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if ALL nodes in the instance have completed downloading the model."""
return all(
_node_has_download(nid, shard_assignments.model_id, global_download_status)
for nid in shard_assignments.node_to_runner
)
def _any_download_complete(
shard_assignments: ShardAssignments,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if at least one node in the instance has completed downloading the model."""
return any(
_node_has_download(nid, shard_assignments.model_id, global_download_status)
for nid in shard_assignments.node_to_runner
)
def _ready_to_warmup(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -368,11 +236,6 @@ def _ready_to_warmup(
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
# Transfer-only instances don't go through warmup
if shard_assignments.transfer_only:
continue
shard = runner.bound_instance.bound_shard
device_rank = shard.device_rank
runner_id = runner.bound_instance.bound_runner_id
@@ -447,8 +310,7 @@ def _pending_tasks(
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> CancelTask | None:
"""Find a cancelled task that hasn't been sent to the runner yet."""
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue

View File

@@ -17,7 +17,6 @@ def entrypoint(
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
pipe_fifo_paths: tuple[str, str] | None = None,
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
@@ -31,16 +30,6 @@ def entrypoint(
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
# Open JACCL FIFOs by path and set env vars for C++ SideChannel.
# Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
if pipe_fifo_paths is not None:
fifo_c2p, fifo_p2c = pipe_fifo_paths
# C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)
global logger
logger = _logger
@@ -67,9 +56,7 @@ def entrypoint(
try:
event_sender.close()
task_receiver.close()
cancel_receiver.close()
finally:
event_sender.join()
task_receiver.join()
cancel_receiver.join()
logger.info("bye from the runner")

View File

@@ -42,7 +42,6 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
TextGeneration,
TransferModelToDisk,
)
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import BoundInstance
@@ -82,11 +81,6 @@ from exo.worker.engines.image import (
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.model_transfer import (
coordinate_transfer,
model_path_for_id,
transfer_all_files,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
detect_thinking_prompt_suffix,
@@ -207,10 +201,7 @@ def main(
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
inference_model, tokenizer = load_mlx_items(
bound_instance,
group,
on_timeout=on_model_load_timeout,
has_local_model=task.has_local_model,
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
@@ -252,7 +243,7 @@ def main(
assert inference_model
assert tokenizer
t = time.perf_counter()
t = time.monotonic()
toks = warmup_inference(
model=inference_model,
tokenizer=tokenizer,
@@ -260,7 +251,7 @@ def main(
)
logger.info(f"warmed up by generating {toks} tokens")
check_for_cancel_every = min(
math.ceil(toks / max(time.perf_counter() - t, 0.001)), 100
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
)
if group is not None:
check_for_cancel_every = int(
@@ -544,27 +535,6 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case TransferModelToDisk() if (
isinstance(current_status, RunnerConnected) and group is not None
):
logger.info("starting disk-to-disk model transfer")
event_sender.send(TaskAcknowledged(task_id=task.task_id))
model_path = model_path_for_id(
task.shard_metadata.model_card.model_id
)
_, source_rank = coordinate_transfer(group, task.has_local_model)
is_source = group.rank() == source_rank
transfer_all_files(model_path, group, is_source)
logger.info("disk-to-disk model transfer complete")
current_status = RunnerShuttingDown()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")

View File

@@ -1,10 +1,6 @@
import contextlib
import os
import signal
import struct
import tempfile
from dataclasses import dataclass, field
from functools import partial
from multiprocessing import Process
from typing import Self
@@ -18,14 +14,12 @@ from loguru import logger
from exo.shared.types.events import (
Event,
JacclSideChannelData,
JacclSideChannelGathered,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnecting,
RunnerFailed,
@@ -40,26 +34,6 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.worker.runner.bootstrap import entrypoint
def _pipe_read_exact(fd: int, n: int) -> bytes | None:
"""Read exactly n bytes from a file descriptor. Returns None on EOF."""
data = b""
while len(data) < n:
chunk = os.read(fd, n - len(data))
if not chunk:
return None
data += chunk
return data
def _pipe_write_all(fd: int, data: bytes) -> None:
"""Write all bytes to a file descriptor."""
view = memoryview(data)
while view:
written = os.write(fd, view)
view = view[written:]
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
@@ -72,21 +46,12 @@ class RunnerSupervisor:
initialize_timeout: float
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_cancel_sender: MpSender[TaskId]
_event_sender: Sender[Event]
_pipe_read_fd: int | None = None # Python reads runner's pipe output
_pipe_write_fd: int | None = None # Python writes gathered data to runner
_child_pipe_fds: tuple[int, int] | None = None # fds to close after fork
_fifo_dir: str | None = None # Temp dir for FIFO files (for cleanup)
_fifo_c2p: str | None = None # FIFO path: C++ writes → Python reads
_fifo_p2c: str | None = None # FIFO path: Python writes → C++ reads
_cancel_sender: MpSender[TaskId]
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
_gathered_waiters: dict[
int, tuple[anyio.Event, JacclSideChannelGathered | None]
] = field(default_factory=dict, init=False)
@classmethod
def create(
@@ -100,23 +65,6 @@ class RunnerSupervisor:
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
# For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
# Named pipes work across multiprocessing.Process spawn (macOS default).
# FIFO c2p: C++ writes local data → Python reads it
# FIFO p2c: Python writes gathered data → C++ reads it
fifo_dir: str | None = None
fifo_c2p: str | None = None
fifo_p2c: str | None = None
pipe_fifo_paths: tuple[str, str] | None = None
if isinstance(bound_instance.instance, MlxJacclInstance):
fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python
fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++
os.mkfifo(fifo_c2p)
os.mkfifo(fifo_p2c)
pipe_fifo_paths = (fifo_c2p, fifo_p2c)
runner_process = Process(
target=entrypoint,
args=(
@@ -125,7 +73,6 @@ class RunnerSupervisor:
task_recv,
cancel_recv,
logger,
pipe_fifo_paths,
),
daemon=True,
)
@@ -141,54 +88,21 @@ class RunnerSupervisor:
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
_fifo_dir=fifo_dir,
_fifo_c2p=fifo_c2p,
_fifo_p2c=fifo_p2c,
)
return self
async def run(self):
self.runner_process.start()
if self._fifo_c2p is not None and self._fifo_p2c is not None:
# Open FIFOs from parent side. These block until child opens the other end,
# so we run them in threads concurrently to avoid deadlock.
fifo_c2p = self._fifo_c2p
fifo_p2c = self._fifo_p2c
async def open_read() -> None:
self._pipe_read_fd = await to_thread.run_sync(
partial(os.open, fifo_c2p, os.O_RDONLY)
)
async def open_write() -> None:
self._pipe_write_fd = await to_thread.run_sync(
partial(os.open, fifo_p2c, os.O_WRONLY)
)
async with anyio.create_task_group() as open_tg:
open_tg.start_soon(open_read)
open_tg.start_soon(open_write)
logger.info(
f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
)
async with anyio.create_task_group() as tg:
tg.start_soon(self._pipe_relay)
tg.start_soon(self._forward_events)
else:
await self._forward_events()
await self._forward_events()
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
self._event_sender.close()
self._close_pipe_fds()
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
@@ -226,7 +140,6 @@ class RunnerSupervisor:
await event.wait()
async def cancel_task(self, task_id: TaskId):
"""Send a cancellation signal to the runner process."""
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
@@ -268,110 +181,6 @@ class RunnerSupervisor:
for tid in self.pending:
self.pending[tid].set()
def _close_pipe_fds(self) -> None:
if self._pipe_read_fd is not None:
with contextlib.suppress(OSError):
os.close(self._pipe_read_fd)
self._pipe_read_fd = None
if self._pipe_write_fd is not None:
with contextlib.suppress(OSError):
os.close(self._pipe_write_fd)
self._pipe_write_fd = None
if self._child_pipe_fds is not None:
for fd in self._child_pipe_fds:
with contextlib.suppress(OSError):
os.close(fd)
self._child_pipe_fds = None
# Clean up FIFO files
if self._fifo_c2p is not None:
with contextlib.suppress(OSError):
os.unlink(self._fifo_c2p)
self._fifo_c2p = None
if self._fifo_p2c is not None:
with contextlib.suppress(OSError):
os.unlink(self._fifo_p2c)
self._fifo_p2c = None
if self._fifo_dir is not None:
with contextlib.suppress(OSError):
os.rmdir(self._fifo_dir)
self._fifo_dir = None
async def _pipe_relay(self) -> None:
"""Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
assert self._pipe_read_fd is not None
assert self._pipe_write_fd is not None
read_fd = self._pipe_read_fd
write_fd = self._pipe_write_fd
sequence = 0
try:
while True:
# 1. Read local data from runner: [uint32 size][size bytes]
header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
if header is None:
logger.info("JACCL pipe relay: runner closed pipe (EOF)")
break
data_size: int = struct.unpack("<I", header)[0] # pyright: ignore[reportAny]
local_data = await to_thread.run_sync(
partial(_pipe_read_exact, read_fd, data_size)
)
if local_data is None:
logger.warning("JACCL pipe relay: EOF reading data payload")
break
logger.info(
f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
)
# 2. Emit JacclSideChannelData event
waiter = anyio.Event()
self._gathered_waiters[sequence] = (waiter, None)
await self._event_sender.send(
JacclSideChannelData(
instance_id=self.bound_instance.instance.instance_id,
runner_id=self.bound_instance.bound_runner_id,
sequence=sequence,
data=local_data,
)
)
# 3. Wait for gathered result
await waiter.wait()
_, gathered_event = self._gathered_waiters.pop(sequence)
assert gathered_event is not None
# 4. Order gathered data by runner rank and concatenate
instance = self.bound_instance.instance
assert isinstance(instance, MlxJacclInstance)
runner_order = list(instance.shard_assignments.runner_to_shard.keys())
ordered_data = b"".join(
gathered_event.gathered_data[rid] for rid in runner_order
)
# 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
total_size = len(ordered_data)
response = struct.pack("<I", total_size) + ordered_data
await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))
logger.info(
f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
)
sequence += 1
except OSError as e:
logger.warning(f"JACCL pipe relay: OS error: {e}")
except Exception as e:
logger.opt(exception=e).error("JACCL pipe relay: unexpected error")
def notify_gathered(self, event: JacclSideChannelGathered) -> None:
"""Called by the worker when a JacclSideChannelGathered event arrives."""
seq = event.sequence
if seq not in self._gathered_waiters:
logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
return
waiter, _ = self._gathered_waiters[seq]
self._gathered_waiters[seq] = (waiter, event)
waiter.set()
def __del__(self) -> None:
if self.runner_process.is_alive():
logger.warning("RunnerSupervisor was not stopped cleanly.")

View File

@@ -112,7 +112,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.has_local_model is True
def test_plan_does_not_request_download_when_shard_already_downloaded():
@@ -158,11 +157,10 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
assert not isinstance(result, plan_mod.DownloadModel)
def test_plan_loads_model_when_any_node_has_download_for_multi_node():
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
"""
For multi-node instances, LoadModel should be emitted when at least one
node has the model downloaded. Nodes without the model will receive it
via MLX distributed transfer during model loading.
LoadModel should not be emitted while some shards are still missing from
the global_download_status.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -187,7 +185,6 @@ def test_plan_loads_model_when_any_node_has_download_for_multi_node():
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A has the model — LoadModel should still fire
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -206,42 +203,19 @@ def test_plan_loads_model_when_any_node_has_download_for_multi_node():
tasks={},
)
assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.has_local_model is True
assert result is None
def test_plan_does_not_load_model_when_no_node_has_download():
"""
LoadModel should not be emitted when no node has the model downloaded.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}
# No node has the model
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [],
NODE_B: [],
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
], # NODE_B has no downloads completed yet
}
result = plan_mod.plan(
@@ -253,57 +227,4 @@ def test_plan_does_not_load_model_when_no_node_has_download():
tasks={},
)
assert result is None
def test_plan_load_model_has_local_model_false_when_node_missing_download():
"""
For multi-node instances, when the local node does NOT have the model
but a peer does, LoadModel should be emitted with has_local_model=False.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
# NODE_B is the local node (bound_node_id=NODE_B), it does NOT have the model
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A has the model, NODE_B does not
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [],
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.has_local_model is False
assert result is not None

View File

@@ -1,7 +1,9 @@
# Check tasks are complete before runner is ever ready.
import unittest.mock
from collections.abc import Iterable
from typing import Callable
import mlx.core as mx
import pytest
import exo.worker.runner.runner as mlx_runner
@@ -115,12 +117,6 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
# Mock mx.distributed.all_gather so MockGroup doesn't hit real MLX C++ bindings.
def _mock_all_gather(x: object, **_kw: object) -> object:
return x
monkeypatch.setattr(mlx_runner.mx.distributed, "all_gather", _mock_all_gather)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
@@ -182,15 +178,16 @@ def _run(tasks: Iterable[Task]):
# this is some c++ nonsense
task_receiver.close = nothin
task_receiver.join = nothin
cancel_receiver.close = nothin
cancel_receiver.join = nothin
mlx_runner.main(
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,
cancel_receiver,
)
with unittest.mock.patch(
"exo.worker.runner.runner.mx.distributed.all_gather",
make_nothin(mx.array([1])),
):
mlx_runner.main(
bound_instance,
event_sender, # pyright: ignore[reportArgumentType]
task_receiver,
cancel_receiver,
)
return event_sender.events

40
uv.lock generated
View File

@@ -377,8 +377,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -416,7 +416,7 @@ requires-dist = [
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.5" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
{ name = "mlx-lm", specifier = "==0.30.6" },
{ name = "msgspec", specifier = ">=0.19.0" },
@@ -1020,8 +1020,8 @@ dependencies = [
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1048,12 +1048,18 @@ wheels = [
name = "mlx"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"sys_platform == 'linux'",
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
{ url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
{ url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
{ url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
{ url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
{ url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
{ url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
{ url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
{ url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
{ url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
]
@@ -1066,14 +1072,6 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.7.dev20260217+50487b41"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.6"
@@ -1104,7 +1102,7 @@ version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1116,6 +1114,16 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]
[[package]]
name = "mlx-metal"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
{ url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
{ url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
]
[[package]]
name = "more-itertools"
version = "10.8.0"