From a962a28afc7deb65ecb99d1f1cf83be1c5c10dee Mon Sep 17 00:00:00 2001 From: Alex Cheema <41707476+AlexCheema@users.noreply.github.com> Date: Tue, 17 Feb 2026 09:48:19 -0800 Subject: [PATCH] Add MetaInstance declarative layer (#1447) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Users currently manage instances directly, which means if a node disconnects or connections break, the instance dies and nothing recreates it. MetaInstance is a declarative primitive: "ensure an instance matching these parameters always exists." The reconciler watches for unhealthy or missing backing instances and re-places them automatically. ## Changes - **MetaInstance type** (`meta_instance.py`): declarative constraint with `model_id`, `min_nodes`, optional `node_ids`, and `sharding` - **Reconciler** (`reconcile.py`): `find_unsatisfied_meta_instances` checks which MetaInstances lack a healthy backing instance, `try_place_for_meta_instance` creates one - **Master loop** (`main.py`): periodically reconciles unsatisfied MetaInstances; immediate placement on `CreateMetaInstance` command - **API** (`api.py`): `create_meta_instance` / `delete_meta_instance` / `GET /meta_instances` endpoints; delete cascades to backing instances with task cancellation - **Binding via `meta_instance_id` on Instance** (`instances.py`): no separate binding event or backing map — the instance carries its parent MetaInstance ID directly, eliminating race conditions in the reconciler - **Dashboard**: sidebar shows MetaInstances with their backing instance status; orphan instances (created directly) still shown separately - **Tests**: constraint matching, connection health, unsatisfied detection, exclusive binding, cascade delete with task cancellation ### Recent improvements - **fix: cancel active tasks on cascade delete** — `DeleteMetaInstance` now emits `TaskStatusUpdated(Cancelled)` for any Pending/Running tasks on backing instances before emitting `InstanceDeleted`. Previously, cascade-deleting backing instances left orphaned task references in state. - **Lifecycle logging** — added `logger.info`/`logger.warning` for: `CreateMetaInstance` (model, min_nodes, sharding), `DeleteMetaInstance` (with cascade count), reconciler placement success/failure, and retry decisions with attempt counts in `InstanceHealthReconciler`. - **GET `/meta_instances` endpoint** — lists all meta-instances without needing to fetch full state. - **2 regression tests** — `test_cascade_delete_cancels_active_tasks` and `test_cascade_delete_skips_completed_tasks` verify the cascade-delete event sequence. ## Why It Works Putting `meta_instance_id` on `BaseInstance` makes binding inherent to instance creation. When the reconciler creates an instance for a MetaInstance, it tags it via `model_copy`. When the instance is deleted, the binding disappears with it. This avoids the two bugs that a separate binding mechanism would introduce: 1. Stale exclusion sets — the reconciler loop can't accidentally bind two MetaInstances to the same instance 2. Delete ordering race — no window between deleting an instance and its binding where the reconciler could re-place ## Test Plan ### Manual Testing - Created MetaInstance via dashboard, verified instance placed - Verified delete cascades (deleting MetaInstance removes backing instance) - Verified orphan instances still work independently ### Automated Testing - 30 tests in `test_meta_instance_edge_cases.py`: lifecycle, retry logic, error handling, concurrent operations, cascade delete with task cancellation - 24 tests in `test_reconcile.py`: constraint matching, connection health (single/multi-node, edge removal, IP changes), unsatisfied detection, exclusive binding, idempotency - All 261 tests pass - basedpyright 0 errors, ruff clean, dashboard builds --------- Co-authored-by: Claude Opus 4.6 --- .../src/lib/components/ChatSidebar.svelte | 6 +- dashboard/src/lib/components/ModelCard.svelte | 6 +- dashboard/src/lib/stores/app.svelte.ts | 27 +- dashboard/src/routes/+page.svelte | 782 +++++++++++++++--- src/exo/download/coordinator.py | 12 +- src/exo/main.py | 2 +- src/exo/master/api.py | 78 +- src/exo/master/main.py | 255 ++++-- src/exo/master/placement.py | 20 +- src/exo/master/placement_utils.py | 6 +- src/exo/master/process_managers/__init__.py | 12 + .../process_managers/instance_health.py | 62 ++ .../master/process_managers/meta_instance.py | 92 +++ .../master/process_managers/node_timeout.py | 27 + src/exo/master/reconcile.py | 244 ++++++ .../tests/test_meta_instance_edge_cases.py | 778 +++++++++++++++++ src/exo/master/tests/test_placement_utils.py | 12 +- src/exo/master/tests/test_reconcile.py | 742 +++++++++++++++++ src/exo/shared/apply.py | 126 ++- src/exo/shared/types/api.py | 22 +- src/exo/shared/types/commands.py | 13 +- src/exo/shared/types/common.py | 4 + src/exo/shared/types/events.py | 80 +- src/exo/shared/types/meta_instance.py | 25 + src/exo/shared/types/state.py | 4 +- src/exo/shared/types/tasks.py | 2 +- src/exo/shared/types/worker/instances.py | 3 +- src/exo/utils/channels.py | 4 +- src/exo/worker/engines/mlx/utils_mlx.py | 5 + src/exo/worker/main.py | 34 +- src/exo/worker/plan.py | 23 +- src/exo/worker/runner/bootstrap.py | 13 + src/exo/worker/runner/runner.py | 4 +- src/exo/worker/runner/runner_supervisor.py | 199 ++++- .../test_runner/test_event_ordering.py | 27 +- 35 files changed, 3457 insertions(+), 294 deletions(-) create mode 100644 src/exo/master/process_managers/__init__.py create mode 100644 src/exo/master/process_managers/instance_health.py create mode 100644 src/exo/master/process_managers/meta_instance.py create mode 100644 src/exo/master/process_managers/node_timeout.py create mode 100644 src/exo/master/reconcile.py create mode 100644 src/exo/master/tests/test_meta_instance_edge_cases.py create mode 100644 src/exo/master/tests/test_reconcile.py create mode 100644 src/exo/shared/types/meta_instance.py diff --git a/dashboard/src/lib/components/ChatSidebar.svelte b/dashboard/src/lib/components/ChatSidebar.svelte index b721b0339..6a822ddf6 100644 --- a/dashboard/src/lib/components/ChatSidebar.svelte +++ b/dashboard/src/lib/components/ChatSidebar.svelte @@ -185,11 +185,7 @@ let instanceType: string | null = null; if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring"; - else if ( - instanceTag === "MlxIbvInstance" || - instanceTag === "MlxJacclInstance" - ) - instanceType = "MLX RDMA"; + else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA"; let sharding: string | null = null; const inst = instance as { diff --git a/dashboard/src/lib/components/ModelCard.svelte b/dashboard/src/lib/components/ModelCard.svelte index 9046d2a60..561c325bb 100644 --- a/dashboard/src/lib/components/ModelCard.svelte +++ b/dashboard/src/lib/components/ModelCard.svelte @@ -21,7 +21,7 @@ } | null; nodes?: Record; sharding?: "Pipeline" | "Tensor"; - runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl"; + runtime?: "MlxRing" | "MlxJaccl"; onLaunch?: () => void; tags?: string[]; apiPreview?: PlacementPreview | null; @@ -348,7 +348,7 @@ // Debug mode state const isDebugMode = $derived(debugMode()); const topology = $derived(topologyData()); - const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl"); + const isRdma = $derived(runtime === "MlxJaccl"); // Get interface name for an IP from node data function getInterfaceForIp(nodeId: string, ip?: string): string | null { @@ -575,7 +575,7 @@ > {runtime === "MlxRing" ? "MLX Ring" - : runtime === "MlxIbv" || runtime === "MlxJaccl" + : runtime === "MlxJaccl" ? "MLX RDMA" : runtime} diff --git a/dashboard/src/lib/stores/app.svelte.ts b/dashboard/src/lib/stores/app.svelte.ts index e5dbf9029..ebb2d0df3 100644 --- a/dashboard/src/lib/stores/app.svelte.ts +++ b/dashboard/src/lib/stores/app.svelte.ts @@ -168,7 +168,7 @@ export interface ModelDownloadStatus { export interface PlacementPreview { model_id: string; sharding: "Pipeline" | "Tensor"; - instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl"; + instance_meta: "MlxRing" | "MlxJaccl"; instance: unknown | null; memory_delta_by_node: Record | null; error: string | null; @@ -219,7 +219,6 @@ interface RawStateResponse { string, { MlxRingInstance?: Instance; - MlxIbvInstance?: Instance; MlxJacclInstance?: Instance; } >; @@ -250,6 +249,20 @@ interface RawStateResponse { >; // Thunderbolt bridge cycles (nodes with bridge enabled forming loops) thunderboltBridgeCycles?: string[][]; + // MetaInstances (declarative instance constraints) + metaInstances?: Record; +} + +export interface MetaInstanceData { + metaInstanceId: string; + modelId: string; + sharding: string; + instanceMeta: string; + minNodes: number; + nodeIds: string[] | null; + placementError: string | null; + consecutiveFailures: number; + lastFailureError: string | null; } export interface MessageAttachment { @@ -537,6 +550,7 @@ class AppStore { previewNodeFilter = $state>(new Set()); lastUpdate = $state(null); nodeIdentities = $state>({}); + metaInstances = $state>({}); thunderboltBridgeCycles = $state([]); nodeThunderbolt = $state< Record< @@ -895,11 +909,7 @@ class AppStore { let instanceType: string | null = null; if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring"; - else if ( - instanceTag === "MlxIbvInstance" || - instanceTag === "MlxJacclInstance" - ) - instanceType = "MLX RDMA"; + else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA"; let sharding: string | null = null; const inst = instance as { @@ -1273,6 +1283,8 @@ class AppStore { this.nodeThunderbolt = data.nodeThunderbolt ?? {}; // RDMA ctl status per node this.nodeRdmaCtl = data.nodeRdmaCtl ?? {}; + // MetaInstances + this.metaInstances = data.metaInstances ?? {}; // Thunderbolt bridge cycles this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? []; // Thunderbolt bridge status per node @@ -3044,6 +3056,7 @@ export const tps = () => appStore.tps; export const totalTokens = () => appStore.totalTokens; export const topologyData = () => appStore.topologyData; export const instances = () => appStore.instances; +export const metaInstances = () => appStore.metaInstances; export const runners = () => appStore.runners; export const downloads = () => appStore.downloads; export const nodeDisk = () => appStore.nodeDisk; diff --git a/dashboard/src/routes/+page.svelte b/dashboard/src/routes/+page.svelte index 2fdeb8ab0..5f7dcf04a 100644 --- a/dashboard/src/routes/+page.svelte +++ b/dashboard/src/routes/+page.svelte @@ -44,11 +44,13 @@ toggleChatSidebarVisible, nodeThunderbolt, nodeRdmaCtl, + metaInstances, thunderboltBridgeCycles, nodeThunderboltBridge, nodeIdentities, type DownloadProgress, type PlacementPreview, + type MetaInstanceData, } from "$lib/stores/app.svelte"; import HeaderNav from "$lib/components/HeaderNav.svelte"; import { fade, fly } from "svelte/transition"; @@ -68,7 +70,70 @@ const debugEnabled = $derived(debugMode()); const topologyOnlyEnabled = $derived(topologyOnlyMode()); const sidebarVisible = $derived(chatSidebarVisible()); + const metaInstancesData = $derived(metaInstances()); const tbBridgeCycles = $derived(thunderboltBridgeCycles()); + + // Get status for a MetaInstance that has no backing instance yet + function getMetaInstancePlacingStatus(metaInstanceId: string) { + const meta = metaInstancesData[metaInstanceId]; + const placementError = meta?.placementError; + const failures = meta?.consecutiveFailures ?? 0; + const lastError = meta?.lastFailureError; + + if (placementError) { + return { + statusText: "PLACEMENT FAILED", + statusClass: "failed", + isDownloading: false as const, + isFailed: true, + progress: null, + perNode: [] as Array<{ + nodeId: string; + nodeName: string; + progress: DownloadProgress; + }>, + perNodeStatus: [] as PerNodeRunnerStatus[], + errorMessage: placementError, + }; + } + + if (failures > 0) { + const retryPosition = ((failures - 1) % 3) + 1; + const isRecreated = failures % 3 === 0; + return { + statusText: isRecreated ? "PLACING" : `RETRYING (${retryPosition}/3)`, + statusClass: "starting", + isDownloading: false as const, + isFailed: false, + progress: null, + perNode: [] as Array<{ + nodeId: string; + nodeName: string; + progress: DownloadProgress; + }>, + perNodeStatus: [] as PerNodeRunnerStatus[], + errorMessage: isRecreated + ? `Instance re-created due to failure: ${lastError}` + : `Previous failure: ${lastError}`, + }; + } + + return { + statusText: "PLACING", + statusClass: "starting", + isDownloading: false as const, + isFailed: false, + progress: null, + perNode: [] as Array<{ + nodeId: string; + nodeName: string; + progress: DownloadProgress; + }>, + perNodeStatus: [] as PerNodeRunnerStatus[], + errorMessage: null, + }; + } + const tbBridgeData = $derived(nodeThunderboltBridge()); const identitiesData = $derived(nodeIdentities()); const tbIdentifiers = $derived(nodeThunderbolt()); @@ -114,6 +179,17 @@ }); let tb5InfoDismissed = $state(false); + // Detect [jaccl] RDMA driver errors from MetaInstance failure errors + const jacclError = $derived.by(() => { + for (const mi of Object.values(metaInstancesData)) { + if (mi.lastFailureError?.includes("[jaccl]")) { + return mi.lastFailureError; + } + } + return null; + }); + let jacclDismissedError = $state(null); + // Helper to get friendly node name from node ID function getNodeName(nodeId: string): string { const node = data?.nodes?.[nodeId]; @@ -224,7 +300,7 @@ return model.tasks.includes("ImageToImage"); } let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline"); - type InstanceMeta = "MlxRing" | "MlxIbv" | "MlxJaccl"; + type InstanceMeta = "MlxRing" | "MlxJaccl"; // Launch defaults persistence const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults"; @@ -481,7 +557,7 @@ const matchesSelectedRuntime = (runtime: InstanceMeta): boolean => selectedInstanceType === "MlxRing" ? runtime === "MlxRing" - : runtime === "MlxIbv" || runtime === "MlxJaccl"; + : runtime === "MlxJaccl" || runtime === "MlxJaccl"; // Helper to check if a model can be launched (has valid placement with >= minNodes) function canModelFit(modelId: string): boolean { @@ -697,39 +773,30 @@ launchingModelId = modelId; try { - // Use the specific preview if provided, otherwise fall back to filtered preview const preview = specificPreview ?? filteredPreview(); - let instanceData: unknown; + // Extract node IDs from the preview the user is seeing + const previewNodeIds = preview?.memory_delta_by_node + ? Object.keys(preview.memory_delta_by_node) + : nodeFilter.size > 0 + ? Array.from(nodeFilter) + : undefined; - if (preview?.instance) { - // Use the instance from the preview - instanceData = preview.instance; - } else { - // Fallback: GET placement from API - const placementResponse = await fetch( - `/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`, - ); - - if (!placementResponse.ok) { - const errorText = await placementResponse.text(); - console.error("Failed to get placement:", errorText); - return; - } - - instanceData = await placementResponse.json(); - } - - // POST the instance to create it - const response = await fetch("/instance", { + const response = await fetch("/meta_instance", { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ instance: instanceData }), + body: JSON.stringify({ + model_id: modelId, + sharding: preview?.sharding ?? selectedSharding, + instance_meta: preview?.instance_meta ?? selectedInstanceType, + min_nodes: selectedMinNodes, + node_ids: previewNodeIds, + }), }); if (!response.ok) { const errorText = await response.text(); - console.error("Failed to launch instance:", errorText); + console.error("Failed to create meta instance:", errorText); } else { // Always auto-select the newly launched model so the user chats to what they just launched setSelectedChatModel(modelId); @@ -752,7 +819,7 @@ setTimeout(scrollToBottom, 1000); } } catch (error) { - console.error("Error launching instance:", error); + console.error("Error creating meta instance:", error); } finally { launchingModelId = null; } @@ -954,15 +1021,18 @@ nodeName: string; progress: DownloadProgress; }>; + perNodeStatus: PerNodeRunnerStatus[]; } { if (!downloadsData || Object.keys(downloadsData).length === 0) { + const statusInfo = deriveInstanceStatus(instanceWrapped); return { isDownloading: false, - isFailed: false, - errorMessage: null, + isFailed: statusInfo.statusText === "FAILED", + errorMessage: statusInfo.errorMessage, progress: null, - statusText: "RUNNING", + statusText: statusInfo.statusText, perNode: [], + perNodeStatus: statusInfo.perNodeStatus, }; } @@ -976,6 +1046,7 @@ progress: null, statusText: "PREPARING", perNode: [], + perNodeStatus: [], }; } @@ -1044,6 +1115,7 @@ progress: null, statusText: "FAILED", perNode: [], + perNodeStatus: [], }; } } @@ -1084,10 +1156,11 @@ return { isDownloading: false, isFailed: statusInfo.statusText === "FAILED", - errorMessage: null, + errorMessage: statusInfo.errorMessage, progress: null, statusText: statusInfo.statusText, perNode: [], + perNodeStatus: statusInfo.perNodeStatus, }; } @@ -1111,92 +1184,223 @@ }, statusText: "DOWNLOADING", perNode, + perNodeStatus: [], }; } // Derive instance status from runners // Get color class for a status function getStatusColor(statusText: string): string { - switch (statusText) { - case "FAILED": - return "text-red-400"; - case "SHUTDOWN": - return "text-gray-400"; - case "DOWNLOADING": - return "text-blue-400"; - case "LOADING": - case "WARMING UP": - case "WAITING": - case "INITIALIZING": - return "text-yellow-400"; - case "RUNNING": - return "text-teal-400"; - case "READY": - case "LOADED": - return "text-green-400"; - default: - return "text-exo-light-gray"; - } + if (statusText === "FAILED" || statusText === "PLACEMENT FAILED") + return "text-red-400"; + if (statusText.startsWith("RETRYING")) return "text-orange-400"; + if (statusText === "SHUTDOWN") return "text-gray-400"; + if (statusText === "DOWNLOADING") return "text-blue-400"; + if ( + statusText.startsWith("LOADING") || + statusText.startsWith("WARMING UP") || + statusText === "WAITING" || + statusText === "INITIALIZING" + ) + return "text-yellow-400"; + if (statusText === "RUNNING") return "text-teal-400"; + if (statusText === "READY" || statusText === "LOADED") + return "text-green-400"; + return "text-exo-light-gray"; + } + + const RUNNER_STATUS_MAP: Record = { + RunnerWaitingForInitialization: "WaitingForInitialization", + RunnerInitializingBackend: "InitializingBackend", + RunnerWaitingForModel: "WaitingForModel", + RunnerLoading: "Loading", + RunnerLoaded: "Loaded", + RunnerWarmingUp: "WarmingUp", + RunnerReady: "Ready", + RunnerRunning: "Running", + RunnerShutdown: "Shutdown", + RunnerFailed: "Failed", + }; + + // Friendly labels for display + const RUNNER_STATUS_DISPLAY: Record = { + WaitingForInitialization: "Initializing", + InitializingBackend: "Initializing", + WaitingForModel: "Waiting", + Loading: "Loading", + Loaded: "Loaded", + WarmingUp: "Warming Up", + Ready: "Ready", + Running: "Running", + Shutdown: "Shutdown", + Failed: "Failed", + }; + + interface PerNodeRunnerStatus { + nodeId: string; + nodeName: string; + status: string; // friendly display status } function deriveInstanceStatus(instanceWrapped: unknown): { statusText: string; statusClass: string; + perNodeStatus: PerNodeRunnerStatus[]; + errorMessage: string | null; } { const [, instance] = getTagged(instanceWrapped); if (!instance || typeof instance !== "object") { - return { statusText: "PREPARING", statusClass: "inactive" }; + return { + statusText: "PREPARING", + statusClass: "inactive", + perNodeStatus: [], + errorMessage: null, + }; } const inst = instance as { - shardAssignments?: { runnerToShard?: Record }; + shardAssignments?: { + runnerToShard?: Record; + nodeToRunner?: Record; + }; }; + const nodeToRunner = inst.shardAssignments?.nodeToRunner || {}; const runnerIds = Object.keys(inst.shardAssignments?.runnerToShard || {}); + const totalNodes = runnerIds.length; - const statuses = runnerIds - .map((rid) => { - const r = runnersData[rid]; - if (!r) return null; - const [kind] = getTagged(r); - const statusMap: Record = { - RunnerWaitingForInitialization: "WaitingForInitialization", - RunnerInitializingBackend: "InitializingBackend", - RunnerWaitingForModel: "WaitingForModel", - RunnerLoading: "Loading", - RunnerLoaded: "Loaded", - RunnerWarmingUp: "WarmingUp", - RunnerReady: "Ready", - RunnerRunning: "Running", - RunnerShutdown: "Shutdown", - RunnerFailed: "Failed", - }; - return kind ? statusMap[kind] || null : null; - }) - .filter((s): s is string => s !== null); + // Build per-node status and extract error messages from RunnerFailed + const perNodeStatus: PerNodeRunnerStatus[] = []; + const statuses: string[] = []; + const failedErrors: string[] = []; + for (const [nodeId, runnerId] of Object.entries(nodeToRunner)) { + const r = runnersData[runnerId]; + let status: string | null = null; + if (r) { + const [kind, runnerData] = getTagged(r); + status = kind ? RUNNER_STATUS_MAP[kind] || null : null; + // Extract error message from RunnerFailed + if ( + kind === "RunnerFailed" && + runnerData && + typeof runnerData === "object" + ) { + const rd = runnerData as { errorMessage?: string }; + if (rd.errorMessage) + failedErrors.push(`${getNodeName(nodeId)}: ${rd.errorMessage}`); + } + } + if (status) { + statuses.push(status); + perNodeStatus.push({ + nodeId, + nodeName: getNodeName(nodeId), + status: RUNNER_STATUS_DISPLAY[status] || status, + }); + } + } const has = (s: string) => statuses.includes(s); + const count = (s: string) => statuses.filter((v) => v === s).length; if (statuses.length === 0) - return { statusText: "PREPARING", statusClass: "inactive" }; - if (has("Failed")) return { statusText: "FAILED", statusClass: "failed" }; + return { + statusText: "PREPARING", + statusClass: "inactive", + perNodeStatus, + errorMessage: null, + }; + if (has("Failed")) + return { + statusText: "FAILED", + statusClass: "failed", + perNodeStatus, + errorMessage: failedErrors.length > 0 ? failedErrors.join("; ") : null, + }; if (has("Shutdown")) - return { statusText: "SHUTDOWN", statusClass: "inactive" }; - if (has("Loading")) - return { statusText: "LOADING", statusClass: "starting" }; - if (has("WarmingUp")) - return { statusText: "WARMING UP", statusClass: "starting" }; - if (has("Running")) - return { statusText: "RUNNING", statusClass: "running" }; - if (has("Ready")) return { statusText: "READY", statusClass: "loaded" }; - if (has("Loaded")) return { statusText: "LOADED", statusClass: "loaded" }; - if (has("WaitingForModel")) - return { statusText: "WAITING", statusClass: "starting" }; - if (has("InitializingBackend")) - return { statusText: "INITIALIZING", statusClass: "starting" }; - if (has("WaitingForInitialization")) - return { statusText: "INITIALIZING", statusClass: "starting" }; + return { + statusText: "SHUTDOWN", + statusClass: "inactive", + perNodeStatus, + errorMessage: null, + }; - return { statusText: "RUNNING", statusClass: "active" }; + // For loading/warming states, show node progress when multi-node + if (has("Loading")) { + const readyCount = count("Ready") + count("Running") + count("Loaded"); + const statusText = + totalNodes > 1 + ? `LOADING (${readyCount}/${totalNodes} nodes ready)` + : "LOADING"; + return { + statusText, + statusClass: "starting", + perNodeStatus, + errorMessage: null, + }; + } + if (has("WarmingUp")) { + const readyCount = count("Ready") + count("Running"); + const statusText = + totalNodes > 1 + ? `WARMING UP (${readyCount}/${totalNodes} nodes ready)` + : "WARMING UP"; + return { + statusText, + statusClass: "starting", + perNodeStatus, + errorMessage: null, + }; + } + + if (has("Running")) + return { + statusText: "RUNNING", + statusClass: "running", + perNodeStatus, + errorMessage: null, + }; + if (has("Ready")) + return { + statusText: "READY", + statusClass: "loaded", + perNodeStatus, + errorMessage: null, + }; + if (has("Loaded")) + return { + statusText: "LOADED", + statusClass: "loaded", + perNodeStatus, + errorMessage: null, + }; + if (has("WaitingForModel")) + return { + statusText: "WAITING", + statusClass: "starting", + perNodeStatus, + errorMessage: null, + }; + if (has("InitializingBackend")) + return { + statusText: "INITIALIZING", + statusClass: "starting", + perNodeStatus, + errorMessage: null, + }; + if (has("WaitingForInitialization")) + return { + statusText: "INITIALIZING", + statusClass: "starting", + perNodeStatus, + errorMessage: null, + }; + + return { + statusText: "RUNNING", + statusClass: "active", + perNodeStatus, + errorMessage: null, + }; } function getBytes(value: unknown): number { @@ -1255,6 +1459,75 @@ } } + async function deleteMetaInstance(metaInstanceId: string) { + const meta = metaInstancesData[metaInstanceId]; + const modelId = meta?.modelId ?? "unknown"; + if (!confirm(`Delete model ${modelId}?`)) return; + + const wasSelected = selectedChatModel() === modelId; + + try { + const response = await fetch(`/meta_instance/${metaInstanceId}`, { + method: "DELETE", + headers: { "Content-Type": "application/json" }, + }); + + if (!response.ok) { + console.error("Failed to delete meta instance:", response.status); + } else if (wasSelected) { + // Switch to another available model or clear selection + const remainingInstances = Object.entries(instanceData).filter( + ([id]) => id !== getBackingInstanceId(metaInstanceId), + ); + if (remainingInstances.length > 0) { + const [, lastInstance] = + remainingInstances[remainingInstances.length - 1]; + const newModelId = getInstanceModelId(lastInstance); + if ( + newModelId && + newModelId !== "Unknown" && + newModelId !== "Unknown Model" + ) { + setSelectedChatModel(newModelId); + } else { + setSelectedChatModel(""); + } + } else { + setSelectedChatModel(""); + } + } + } catch (error) { + console.error("Error deleting meta instance:", error); + } + } + + // Find the backing Instance ID for a MetaInstance by scanning instances + function getBackingInstanceId(metaInstanceId: string): string | null { + for (const [id, inst] of Object.entries(instanceData)) { + const [, inner] = getTagged(inst); + if ( + inner && + typeof inner === "object" && + (inner as Record).metaInstanceId === metaInstanceId + ) { + return id; + } + } + return null; + } + + // Get orphan Instance IDs (not backing any MetaInstance) + function getOrphanInstanceIds(): string[] { + return Object.keys(instanceData).filter((id) => { + const [, inner] = getTagged(instanceData[id]); + return ( + !inner || + typeof inner !== "object" || + !(inner as Record).metaInstanceId + ); + }); + } + // Helper to unwrap tagged unions like { MlxRingInstance: {...} } function getTagged(obj: unknown): [string | null, unknown] { if (!obj || typeof obj !== "object") return [null, null]; @@ -1295,11 +1568,7 @@ // Instance type from tag let instanceType = "Unknown"; if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring"; - else if ( - instanceTag === "MlxIbvInstance" || - instanceTag === "MlxJacclInstance" - ) - instanceType = "MLX RDMA"; + else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA"; const inst = instance as { shardAssignments?: { @@ -1647,7 +1916,51 @@ } const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0); - const instanceCount = $derived(Object.keys(instanceData).length); + const metaInstanceCount = $derived(Object.keys(metaInstancesData).length); + const orphanInstanceIds = $derived(getOrphanInstanceIds()); + const instanceCount = $derived(metaInstanceCount + orphanInstanceIds.length); + + // Unified display items: MetaInstances first, then orphan Instances + interface DisplayItem { + id: string; // MetaInstance ID or Instance ID (used as key and displayed) + modelId: string; + instance: unknown | null; // The backing/orphan instance (tagged union) or null if placing + instanceId: string | null; // The actual Instance ID (for topology hover) + isMetaInstance: boolean; + sharding: string | null; // From MetaInstance constraints (used when instance is null) + instanceMeta: string | null; // From MetaInstance constraints (used when instance is null) + } + + const unifiedDisplayItems = $derived.by((): DisplayItem[] => { + const items: DisplayItem[] = []; + // MetaInstances + for (const [metaId, meta] of Object.entries(metaInstancesData)) { + const backingId = getBackingInstanceId(metaId); + items.push({ + id: metaId, + modelId: meta.modelId, + instance: backingId ? instanceData[backingId] : null, + instanceId: backingId, + isMetaInstance: true, + sharding: meta.sharding, + instanceMeta: meta.instanceMeta, + }); + } + // Orphan Instances + for (const orphanId of getOrphanInstanceIds()) { + const inst = instanceData[orphanId]; + items.push({ + id: orphanId, + modelId: getInstanceModelId(inst), + instance: inst, + instanceId: orphanId, + isMetaInstance: false, + sharding: null, + instanceMeta: null, + }); + } + return items; + }); // Helper to get the number of nodes in a placement preview function getPreviewNodeCount(preview: PlacementPreview): number { @@ -1765,8 +2078,71 @@ {#snippet clusterWarnings()} - {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)} + {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (jacclError && jacclError !== jacclDismissedError)}
+ {#if jacclError && jacclError !== jacclDismissedError} + + {/if} + {#if tbBridgeCycles.length > 0} {@const cycle = tbBridgeCycles[0]} {@const serviceName = getTbBridgeServiceName(cycle)} @@ -1935,8 +2311,29 @@ {/snippet} {#snippet clusterWarningsCompact()} - {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)} + {#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (jacclError && jacclError !== jacclDismissedError)}
+ {#if jacclError && jacclError !== jacclDismissedError} +
+ + + + JACCL ERROR +
+ {/if} {#if tbBridgeCycles.length > 0}
- {#each Object.entries(instanceData) as [id, instance]} - {@const downloadInfo = getInstanceDownloadStatus( - id, - instance, - )} + {#each unifiedDisplayItems as item (item.id)} + {@const id = item.id} + {@const instance = item.instance} + {@const downloadInfo = instance + ? getInstanceDownloadStatus(item.instanceId ?? id, instance) + : getMetaInstancePlacingStatus(id)} + {@const metaData = item.isMetaInstance + ? metaInstancesData[id] + : null} + {@const retryError = + metaData?.lastFailureError && !downloadInfo.isFailed + ? metaData.consecutiveFailures > 0 + ? `(${((metaData.consecutiveFailures - 1) % 3) + 1}/3) ${metaData.lastFailureError}` + : metaData.lastFailureError + : null} {@const statusText = downloadInfo.statusText} {@const isDownloading = downloadInfo.isDownloading} - {@const isFailed = statusText === "FAILED"} + {@const isFailed = + statusText === "FAILED" || + statusText === "PLACEMENT FAILED"} {@const isLoading = - statusText === "LOADING" || - statusText === "WARMING UP" || - statusText === "WAITING"} + statusText.startsWith("LOADING") || + statusText.startsWith("WARMING UP") || + statusText === "WAITING" || + statusText === "PLACING" || + statusText.startsWith("RETRYING")} {@const isReady = statusText === "READY" || statusText === "LOADED"} {@const isRunning = statusText === "RUNNING"} - {@const instanceModelId = getInstanceModelId(instance)} - {@const instanceInfo = getInstanceInfo(instance)} - {@const instanceConnections = - getInstanceConnections(instance)} + {@const instanceModelId = item.modelId} + {@const instanceInfo = instance + ? getInstanceInfo(instance) + : { + instanceType: + item.instanceMeta === "MlxRing" + ? "MLX Ring" + : item.instanceMeta === "MlxJaccl" + ? "MLX RDMA" + : "Unknown", + sharding: item.sharding ?? "Unknown", + nodeNames: [] as string[], + nodeIds: [] as string[], + nodeCount: 0, + }} + {@const instanceConnections = instance + ? getInstanceConnections(instance) + : []}
(hoveredInstanceId = id)} + onmouseenter={() => + (hoveredInstanceId = item.instanceId ?? id)} onmouseleave={() => (hoveredInstanceId = null)} onclick={() => { if ( @@ -2438,7 +2864,10 @@ >
@@ -2884,21 +3337,21 @@
diff --git a/src/exo/download/coordinator.py b/src/exo/download/coordinator.py index db13ccef7..f661c8782 100644 --- a/src/exo/download/coordinator.py +++ b/src/exo/download/coordinator.py @@ -314,7 +314,17 @@ class DownloadCoordinator: ), ) elif progress.status in ["in_progress", "not_started"]: - if progress.downloaded_bytes_this_session.in_bytes == 0: + if ( + progress.downloaded_bytes.in_bytes + >= progress.total_bytes.in_bytes + > 0 + ): + status = DownloadCompleted( + node_id=self.node_id, + shard_metadata=progress.shard, + total_bytes=progress.total_bytes, + ) + elif progress.downloaded_bytes_this_session.in_bytes == 0: status = DownloadPending( node_id=self.node_id, shard_metadata=progress.shard, diff --git a/src/exo/main.py b/src/exo/main.py index 1d3589756..8f0c5a417 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -254,7 +254,7 @@ def main(): target = min(max(soft, 65535), hard) resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard)) - mp.set_start_method("spawn") + mp.set_start_method("spawn", force=True) # TODO: Refactor the current verbosity system logger_setup(EXO_LOG, args.verbosity) logger.info("Starting EXO") diff --git a/src/exo/master/api.py b/src/exo/master/api.py index b84763343..91c74f41d 100644 --- a/src/exo/master/api.py +++ b/src/exo/master/api.py @@ -71,8 +71,11 @@ from exo.shared.types.api import ( ChatCompletionResponse, CreateInstanceParams, CreateInstanceResponse, + CreateMetaInstanceParams, + CreateMetaInstanceResponse, DeleteDownloadResponse, DeleteInstanceResponse, + DeleteMetaInstanceResponse, ErrorInfo, ErrorResponse, FinishReason, @@ -115,8 +118,10 @@ from exo.shared.types.claude_api import ( from exo.shared.types.commands import ( Command, CreateInstance, + CreateMetaInstance, DeleteDownload, DeleteInstance, + DeleteMetaInstance, DownloadCommand, ForwarderCommand, ForwarderDownloadCommand, @@ -129,7 +134,7 @@ from exo.shared.types.commands import ( TaskFinished, TextGeneration, ) -from exo.shared.types.common import CommandId, Id, NodeId, SessionId +from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId from exo.shared.types.events import ( ChunkGenerated, Event, @@ -138,6 +143,7 @@ from exo.shared.types.events import ( TracesMerged, ) from exo.shared.types.memory import Memory +from exo.shared.types.meta_instance import MetaInstance from exo.shared.types.openai_responses import ( ResponsesRequest, ResponsesResponse, @@ -276,6 +282,9 @@ class API: self.app.get("/instance/previews")(self.get_placement_previews) self.app.get("/instance/{instance_id}")(self.get_instance) self.app.delete("/instance/{instance_id}")(self.delete_instance) + self.app.get("/meta_instances")(self.list_meta_instances) + self.app.post("/meta_instance")(self.create_meta_instance) + self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance) self.app.get("/models")(self.get_models) self.app.get("/v1/models")(self.get_models) self.app.post("/models/add")(self.add_custom_model) @@ -305,12 +314,27 @@ class API: self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw) async def place_instance(self, payload: PlaceInstanceParams): + model_card = await ModelCard.load(payload.model_id) command = PlaceInstance( - model_card=await ModelCard.load(payload.model_id), + model_card=model_card, sharding=payload.sharding, instance_meta=payload.instance_meta, min_nodes=payload.min_nodes, ) + + # Validate placement before sending — fail fast with a clear error + # instead of silently dropping the command in the master. + try: + get_instance_placements( + command, + topology=self.state.topology, + current_instances=self.state.instances, + node_memory=self.state.node_memory, + node_network=self.state.node_network, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + await self._send(command) return CreateInstanceResponse( @@ -522,6 +546,44 @@ class API: instance_id=instance_id, ) + def list_meta_instances(self) -> dict[MetaInstanceId, MetaInstance]: + return dict(self.state.meta_instances) + + async def create_meta_instance( + self, payload: CreateMetaInstanceParams + ) -> CreateMetaInstanceResponse: + meta_instance = MetaInstance( + model_id=payload.model_id, + sharding=payload.sharding, + instance_meta=payload.instance_meta, + min_nodes=payload.min_nodes, + node_ids=payload.node_ids, + ) + command = CreateMetaInstance(meta_instance=meta_instance) + await self._send(command) + return CreateMetaInstanceResponse( + message="Command received.", + command_id=command.command_id, + meta_instance_id=meta_instance.meta_instance_id, + ) + + async def delete_meta_instance( + self, meta_instance_id: MetaInstanceId + ) -> DeleteMetaInstanceResponse: + meta = self.state.meta_instances.get(meta_instance_id) + if not meta: + raise HTTPException(status_code=404, detail="MetaInstance not found") + + # Command processor handles cascade-deleting backing instances + command = DeleteMetaInstance(meta_instance_id=meta_instance_id) + await self._send(command) + + return DeleteMetaInstanceResponse( + message="Command received.", + command_id=command.command_id, + meta_instance_id=meta_instance_id, + ) + async def _token_chunk_stream( self, command_id: CommandId ) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]: @@ -541,10 +603,10 @@ class API: break except anyio.get_cancelled_exc_class(): - command = TaskCancelled(cancelled_command_id=command_id) + cancel_command = TaskCancelled(cancelled_command_id=command_id) with anyio.CancelScope(shield=True): await self.command_sender.send( - ForwarderCommand(origin=self.node_id, command=command) + ForwarderCommand(origin=self.node_id, command=cancel_command) ) raise finally: @@ -884,10 +946,10 @@ class API: del image_metadata[key] except anyio.get_cancelled_exc_class(): - command = TaskCancelled(cancelled_command_id=command_id) + cancel_command = TaskCancelled(cancelled_command_id=command_id) with anyio.CancelScope(shield=True): await self.command_sender.send( - ForwarderCommand(origin=self.node_id, command=command) + ForwarderCommand(origin=self.node_id, command=cancel_command) ) raise finally: @@ -970,10 +1032,10 @@ class API: return (images, stats if capture_stats else None) except anyio.get_cancelled_exc_class(): - command = TaskCancelled(cancelled_command_id=command_id) + cancel_command = TaskCancelled(cancelled_command_id=command_id) with anyio.CancelScope(shield=True): await self.command_sender.send( - ForwarderCommand(origin=self.node_id, command=command) + ForwarderCommand(origin=self.node_id, command=cancel_command) ) raise finally: diff --git a/src/exo/master/main.py b/src/exo/master/main.py index 9c7cf5785..405f6495b 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -1,4 +1,5 @@ -from datetime import datetime, timedelta, timezone +from collections.abc import Sequence +from datetime import datetime, timezone import anyio from anyio.abc import TaskGroup @@ -12,11 +13,22 @@ from exo.master.placement import ( get_transition_events, place_instance, ) +from exo.master.process_managers import ProcessManager +from exo.master.process_managers.instance_health import InstanceHealthReconciler +from exo.master.process_managers.meta_instance import MetaInstanceReconciler +from exo.master.process_managers.node_timeout import NodeTimeoutReconciler +from exo.master.reconcile import ( + find_unsatisfied_meta_instances, + try_place_for_meta_instance, +) from exo.shared.apply import apply from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED +from exo.shared.models.model_cards import ModelCard from exo.shared.types.commands import ( CreateInstance, + CreateMetaInstance, DeleteInstance, + DeleteMetaInstance, ForwarderCommand, ForwarderDownloadCommand, ImageEdits, @@ -36,8 +48,12 @@ from exo.shared.types.events import ( IndexedEvent, InputChunkReceived, InstanceDeleted, + JacclSideChannelData, + JacclSideChannelGathered, + MetaInstanceCreated, + MetaInstanceDeleted, + MetaInstancePlacementFailed, NodeGatheredInfo, - NodeTimedOut, TaskCreated, TaskDeleted, TaskStatusUpdated, @@ -60,7 +76,8 @@ from exo.shared.types.tasks import ( TextGeneration as TextGenerationTask, ) from exo.shared.types.worker.instances import InstanceId -from exo.utils.channels import Receiver, Sender, channel +from exo.shared.types.worker.runners import RunnerId +from exo.utils.channels import Receiver, Sender from exo.utils.event_buffer import MultiSourceBuffer @@ -84,16 +101,16 @@ class Master: self.local_event_receiver = local_event_receiver self.global_event_sender = global_event_sender self.download_command_sender = download_command_sender - send, recv = channel[Event]() - self.event_sender: Sender[Event] = send - self._loopback_event_receiver: Receiver[Event] = recv - self._loopback_event_sender: Sender[ForwarderEvent] = ( - local_event_receiver.clone_sender() - ) self._multi_buffer = MultiSourceBuffer[NodeId, Event]() self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master") self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {} self._expected_ranks: dict[TaskId, set[int]] = {} + self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {} + self._process_managers: Sequence[ProcessManager] = [ + InstanceHealthReconciler(), + NodeTimeoutReconciler(), + MetaInstanceReconciler(), + ] async def run(self): logger.info("Starting Master") @@ -102,15 +119,12 @@ class Master: async with self._tg as tg: tg.start_soon(self._event_processor) tg.start_soon(self._command_processor) - tg.start_soon(self._loopback_processor) - tg.start_soon(self._plan) + tg.start_soon(self._reconcile) finally: self._event_log.close() self.global_event_sender.close() self.local_event_receiver.close() self.command_receiver.close() - self._loopback_event_sender.close() - self._loopback_event_receiver.close() async def shutdown(self): logger.info("Stopping Master") @@ -292,6 +306,86 @@ class Master: ) ) generated_events.extend(transition_events) + case CreateMetaInstance(): + logger.info( + f"Creating MetaInstance for {command.meta_instance.model_id}" + f" (min_nodes={command.meta_instance.min_nodes}," + f" sharding={command.meta_instance.sharding})" + ) + # Apply immediately so self.state is fresh across + # the await below and the reconciler won't race. + await self._apply_and_broadcast( + MetaInstanceCreated(meta_instance=command.meta_instance) + ) + # Immediate placement attempt for responsiveness + model_card = await ModelCard.load( + command.meta_instance.model_id + ) + # Re-check: reconciler may have satisfied it during the await + meta_id = command.meta_instance.meta_instance_id + still_unsatisfied = any( + m.meta_instance_id == meta_id + for m in find_unsatisfied_meta_instances( + self.state.meta_instances, + self.state.instances, + self.state.topology, + ) + ) + if still_unsatisfied: + result = try_place_for_meta_instance( + command.meta_instance, + model_card, + self.state.topology, + self.state.instances, + self.state.node_memory, + self.state.node_network, + self.state.tasks, + ) + generated_events.extend(result.events) + if result.error is not None: + generated_events.append( + MetaInstancePlacementFailed( + meta_instance_id=meta_id, + reason=result.error, + ) + ) + case DeleteMetaInstance(): + backing_count = sum( + 1 + for inst in self.state.instances.values() + if inst.meta_instance_id == command.meta_instance_id + ) + logger.info( + f"Deleting MetaInstance {command.meta_instance_id}" + f" (cascade-deleting {backing_count} backing instance(s))" + ) + generated_events.append( + MetaInstanceDeleted( + meta_instance_id=command.meta_instance_id + ) + ) + # Cascade-delete backing instances atomically, + # cancelling any active tasks first. + for iid, inst in self.state.instances.items(): + if inst.meta_instance_id == command.meta_instance_id: + for task in self.state.tasks.values(): + if ( + task.instance_id == iid + and task.task_status + in ( + TaskStatus.Pending, + TaskStatus.Running, + ) + ): + generated_events.append( + TaskStatusUpdated( + task_status=TaskStatus.Cancelled, + task_id=task.task_id, + ) + ) + generated_events.append( + InstanceDeleted(instance_id=iid) + ) case PlaceInstance(): placement = place_instance( command, @@ -323,16 +417,19 @@ class Master: ) case TaskCancelled(): if ( - task_id := self.command_task_mapping.get( - command.cancelled_command_id - ) - ) is not None: + command.cancelled_command_id + in self.command_task_mapping + ): generated_events.append( - TaskStatusUpdated( - task_status=TaskStatus.Cancelled, - task_id=task_id, + TaskDeleted( + task_id=self.command_task_mapping[ + command.cancelled_command_id + ] ) ) + del self.command_task_mapping[ + command.cancelled_command_id + ] case TaskFinished(): generated_events.append( TaskDeleted( @@ -341,9 +438,10 @@ class Master: ] ) ) - self.command_task_mapping.pop( - command.finished_command_id, None - ) + if command.finished_command_id in self.command_task_mapping: + del self.command_task_mapping[ + command.finished_command_id + ] case RequestEventLog(): # We should just be able to send everything, since other buffers will ignore old messages # rate limit to 1000 at a time @@ -354,31 +452,32 @@ class Master: ): await self._send_event(IndexedEvent(idx=i, event=event)) for event in generated_events: - await self.event_sender.send(event) + await self._apply_and_broadcast(event) except ValueError as e: logger.opt(exception=e).warning("Error in command processor") - # These plan loops are the cracks showing in our event sourcing architecture - more things could be commands - async def _plan(self) -> None: + async def _apply_and_broadcast(self, event: Event) -> None: + """Apply event to state, persist to disk, and broadcast to workers. + + State is updated synchronously (before any await), so callers can + rely on ``self.state`` reflecting this event immediately after the + call. Python's cooperative scheduling guarantees no interleaving + between the state read and write. + """ + logger.debug(f"Master indexing event: {str(event)[:100]}") + indexed = IndexedEvent(event=event, idx=len(self._event_log)) + self.state = apply(self.state, indexed) + event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage] + self._event_log.append(event) + await self._send_event(indexed) + + async def _reconcile(self) -> None: while True: - # kill broken instances - connected_node_ids = set(self.state.topology.list_nodes()) - for instance_id, instance in self.state.instances.items(): - for node_id in instance.shard_assignments.node_to_runner: - if node_id not in connected_node_ids: - await self.event_sender.send( - InstanceDeleted(instance_id=instance_id) - ) - break - - # time out dead nodes - for node_id, time in self.state.last_seen.items(): - now = datetime.now(tz=timezone.utc) - if now - time > timedelta(seconds=30): - logger.info(f"Manually removing node {node_id} due to inactivity") - await self.event_sender.send(NodeTimedOut(node_id=node_id)) - - await anyio.sleep(10) + for pm in self._process_managers: + events = await pm.reconcile(self.state) + for event in events: + await self._apply_and_broadcast(event) + await anyio.sleep(1) async def _event_processor(self) -> None: with self.local_event_receiver as local_events: @@ -396,32 +495,15 @@ class Master: await self._handle_traces_collected(event) continue - logger.debug(f"Master indexing event: {str(event)[:100]}") - indexed = IndexedEvent(event=event, idx=len(self._event_log)) - self.state = apply(self.state, indexed) + if isinstance(event, JacclSideChannelData): + await self._apply_and_broadcast(event) + await self._handle_jaccl_side_channel(event) + continue - event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage] if isinstance(event, NodeGatheredInfo): event.when = str(datetime.now(tz=timezone.utc)) - self._event_log.append(event) - await self._send_event(indexed) - - async def _loopback_processor(self) -> None: - # this would ideally not be necessary. - # this is WAY less hacky than how I was working around this before - local_index = 0 - with self._loopback_event_receiver as events: - async for event in events: - await self._loopback_event_sender.send( - ForwarderEvent( - origin=NodeId(f"master_{self.node_id}"), - origin_idx=local_index, - session=self.session_id, - event=event, - ) - ) - local_index += 1 + await self._apply_and_broadcast(event) # This function is re-entrant, take care! async def _send_event(self, event: IndexedEvent): @@ -453,10 +535,49 @@ class Master: for trace_data in self._pending_traces[task_id].values(): all_trace_data.extend(trace_data) - await self.event_sender.send( + await self._apply_and_broadcast( TracesMerged(task_id=task_id, traces=all_trace_data) ) del self._pending_traces[task_id] if task_id in self._expected_ranks: del self._expected_ranks[task_id] + + async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None: + """Accumulate SideChannel contributions; when all runners for an instance + have submitted for the same sequence, emit JacclSideChannelGathered.""" + iid = event.instance_id + seq = event.sequence + + if iid not in self._jaccl_pending: + self._jaccl_pending[iid] = {} + if seq not in self._jaccl_pending[iid]: + self._jaccl_pending[iid][seq] = {} + self._jaccl_pending[iid][seq][event.runner_id] = event.data + + instance = self.state.instances.get(iid) + if instance is None: + logger.warning(f"JacclSideChannelData for unknown instance {iid}") + return + + expected_runners = set(instance.shard_assignments.runner_to_shard.keys()) + submitted = set(self._jaccl_pending[iid][seq].keys()) + + logger.info( + f"JACCL side channel: instance={iid} seq={seq} " + f"submitted={len(submitted)}/{len(expected_runners)}" + ) + + if submitted >= expected_runners: + gathered = dict(self._jaccl_pending[iid][seq]) + del self._jaccl_pending[iid][seq] + if not self._jaccl_pending[iid]: + del self._jaccl_pending[iid] + + await self._apply_and_broadcast( + JacclSideChannelGathered( + instance_id=iid, + sequence=seq, + gathered_data=gathered, + ) + ) diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index cf31ca789..ab886c3e7 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -6,11 +6,11 @@ from typing import Sequence from exo.master.placement_utils import ( Cycle, filter_cycles_by_memory, + get_largest_cycles, get_mlx_jaccl_coordinators, get_mlx_jaccl_devices_matrix, get_mlx_ring_hosts_by_node, get_shard_assignments, - get_smallest_cycles, ) from exo.shared.models.model_cards import ModelId from exo.shared.topology import Topology @@ -106,23 +106,27 @@ def place_instance( "Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)" ) - smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory) + largest_cycles = get_largest_cycles(cycles_with_sufficient_memory) - smallest_rdma_cycles = [ - cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle) + largest_rdma_cycles = [ + cycle for cycle in largest_cycles if topology.is_rdma_cycle(cycle) ] - if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []: - smallest_cycles = smallest_rdma_cycles + if command.instance_meta == InstanceMeta.MlxJaccl: + if not largest_rdma_cycles: + raise ValueError( + "Requested RDMA (MlxJaccl) but no RDMA-connected cycles available" + ) + largest_cycles = largest_rdma_cycles cycles_with_leaf_nodes: list[Cycle] = [ cycle - for cycle in smallest_cycles + for cycle in largest_cycles if any(topology.node_is_leaf(node_id) for node_id in cycle) ] selected_cycle = max( - cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles, + cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else largest_cycles, key=lambda cycle: sum( (node_memory[node_id].ram_available for node_id in cycle), start=Memory(), diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index b20a39ccd..d47c40c0b 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -37,11 +37,11 @@ def filter_cycles_by_memory( return filtered_cycles -def get_smallest_cycles( +def get_largest_cycles( cycles: list[Cycle], ) -> list[Cycle]: - min_nodes = min(len(cycle) for cycle in cycles) - return [cycle for cycle in cycles if len(cycle) == min_nodes] + max_nodes = max(len(cycle) for cycle in cycles) + return [cycle for cycle in cycles if len(cycle) == max_nodes] def allocate_layers_proportionally( diff --git a/src/exo/master/process_managers/__init__.py b/src/exo/master/process_managers/__init__.py new file mode 100644 index 000000000..f6a3ee0ac --- /dev/null +++ b/src/exo/master/process_managers/__init__.py @@ -0,0 +1,12 @@ +from collections.abc import Sequence +from typing import Protocol, runtime_checkable + +from exo.shared.types.events import Event +from exo.shared.types.state import State + + +@runtime_checkable +class ProcessManager(Protocol): + """A reconciliation step that examines state and returns corrective events.""" + + async def reconcile(self, state: State) -> Sequence[Event]: ... diff --git a/src/exo/master/process_managers/instance_health.py b/src/exo/master/process_managers/instance_health.py new file mode 100644 index 000000000..f5f8e922b --- /dev/null +++ b/src/exo/master/process_managers/instance_health.py @@ -0,0 +1,62 @@ +from collections.abc import Sequence +from typing import final + +from loguru import logger + +from exo.master.reconcile import instance_connections_healthy, instance_runners_failed +from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying +from exo.shared.types.state import State + +MAX_INSTANCE_RETRIES = 3 + + +@final +class InstanceHealthReconciler: + """Delete instances whose network connections are broken or whose runners have all failed.""" + + async def reconcile(self, state: State) -> Sequence[Event]: + events: list[Event] = [] + for instance_id, instance in state.instances.items(): + if not instance_connections_healthy(instance, state.topology): + events.append( + InstanceDeleted( + instance_id=instance_id, + failure_error="Network connection lost", + ) + ) + continue + + is_failed, error_message = instance_runners_failed( + instance, state.runners, state.node_identities + ) + if is_failed: + # Retry within the same instance if backed by a MetaInstance + mid = instance.meta_instance_id + mi = state.meta_instances.get(mid) if mid else None + if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES: + logger.info( + f"Instance {instance_id} failed (attempt" + f" {mi.consecutive_failures + 1}/{MAX_INSTANCE_RETRIES})," + f" retrying: {error_message}" + ) + events.append( + InstanceRetrying( + instance_id=instance_id, + meta_instance_id=mid, + failure_error=error_message or "Runner failed", + ) + ) + else: + if mid and mi: + logger.warning( + f"Instance {instance_id} exceeded retry limit" + f" ({MAX_INSTANCE_RETRIES}), deleting:" + f" {error_message}" + ) + events.append( + InstanceDeleted( + instance_id=instance_id, + failure_error=error_message, + ) + ) + return events diff --git a/src/exo/master/process_managers/meta_instance.py b/src/exo/master/process_managers/meta_instance.py new file mode 100644 index 000000000..93037ea8e --- /dev/null +++ b/src/exo/master/process_managers/meta_instance.py @@ -0,0 +1,92 @@ +from collections.abc import Sequence +from typing import final + +import anyio +from loguru import logger + +from exo.master.reconcile import ( + find_unsatisfied_meta_instances, + try_place_for_meta_instance, +) +from exo.shared.models.model_cards import ModelCard +from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed +from exo.shared.types.state import State +from exo.shared.types.worker.instances import Instance, InstanceId + +MODEL_CARD_LOAD_TIMEOUT_SECONDS = 10 + + +@final +class MetaInstanceReconciler: + """Place instances for unsatisfied MetaInstances.""" + + async def reconcile(self, state: State) -> Sequence[Event]: + all_events: list[Event] = [] + # Local copy for intermediate tracking — so placement of B + # sees A's instance and doesn't double-place on same resources. + current_instances: dict[InstanceId, Instance] = dict(state.instances) + + unsatisfied = find_unsatisfied_meta_instances( + state.meta_instances, + current_instances, + state.topology, + ) + for meta_instance in unsatisfied: + try: + with anyio.fail_after(MODEL_CARD_LOAD_TIMEOUT_SECONDS): + model_card = await ModelCard.load(meta_instance.model_id) + except TimeoutError: + logger.warning( + f"ModelCard.load timed out for {meta_instance.model_id}, skipping this cycle" + ) + continue + except Exception as exc: + logger.warning( + f"ModelCard.load failed for {meta_instance.model_id}: {exc}" + ) + error = f"Failed to load model card: {exc}" + if meta_instance.placement_error != error: + all_events.append( + MetaInstancePlacementFailed( + meta_instance_id=meta_instance.meta_instance_id, + reason=error, + ) + ) + continue + + result = try_place_for_meta_instance( + meta_instance, + model_card, + state.topology, + current_instances, + state.node_memory, + state.node_network, + state.tasks, + ) + # Update local instance map so next placement sees this one + for event in result.events: + if isinstance(event, InstanceCreated): + logger.info( + f"MetaInstance reconciler placed instance" + f" {event.instance.instance_id} for" + f" {meta_instance.model_id}" + ) + current_instances[event.instance.instance_id] = event.instance + all_events.extend(result.events) + + # Emit placement failure if error differs from what's already in state + if ( + result.error is not None + and meta_instance.placement_error != result.error + ): + logger.warning( + f"MetaInstance placement failed for" + f" {meta_instance.model_id}: {result.error}" + ) + all_events.append( + MetaInstancePlacementFailed( + meta_instance_id=meta_instance.meta_instance_id, + reason=result.error, + ) + ) + return all_events diff --git a/src/exo/master/process_managers/node_timeout.py b/src/exo/master/process_managers/node_timeout.py new file mode 100644 index 000000000..98045c251 --- /dev/null +++ b/src/exo/master/process_managers/node_timeout.py @@ -0,0 +1,27 @@ +from collections.abc import Sequence +from datetime import datetime, timedelta, timezone +from typing import final + +from loguru import logger + +from exo.shared.types.events import Event, NodeTimedOut +from exo.shared.types.state import State + +_DEFAULT_TIMEOUT = timedelta(seconds=30) + + +@final +class NodeTimeoutReconciler: + """Time out nodes that haven't been seen recently.""" + + def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None: + self.timeout = timeout + + async def reconcile(self, state: State) -> Sequence[Event]: + now = datetime.now(tz=timezone.utc) + events: list[Event] = [] + for node_id, last_seen in state.last_seen.items(): + if now - last_seen > self.timeout: + logger.info(f"Removing node {node_id} due to inactivity") + events.append(NodeTimedOut(node_id=node_id)) + return events diff --git a/src/exo/master/reconcile.py b/src/exo/master/reconcile.py new file mode 100644 index 000000000..4ca968b81 --- /dev/null +++ b/src/exo/master/reconcile.py @@ -0,0 +1,244 @@ +from collections.abc import Mapping, Sequence +from typing import NamedTuple + +from loguru import logger + +from exo.master.placement import get_transition_events, place_instance +from exo.shared.models.model_cards import ModelCard +from exo.shared.topology import Topology +from exo.shared.types.commands import PlaceInstance +from exo.shared.types.common import MetaInstanceId, NodeId +from exo.shared.types.events import Event +from exo.shared.types.meta_instance import MetaInstance +from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo +from exo.shared.types.tasks import Task, TaskId +from exo.shared.types.topology import RDMAConnection, SocketConnection +from exo.shared.types.worker.instances import ( + BaseInstance, + Instance, + InstanceId, + MlxJacclInstance, + MlxRingInstance, +) +from exo.shared.types.worker.runners import ( + RunnerFailed, + RunnerId, + RunnerShutdown, + RunnerStatus, +) + + +class PlacementResult(NamedTuple): + """Result of a placement attempt: events to apply and optional error reason.""" + + events: Sequence[Event] + error: str | None + + +def _get_ring_order(instance: BaseInstance) -> list[NodeId]: + """Reconstruct ring order from shard device_rank.""" + node_ranks: list[tuple[NodeId, int]] = [] + for node_id, runner_id in instance.shard_assignments.node_to_runner.items(): + shard = instance.shard_assignments.runner_to_shard[runner_id] + node_ranks.append((node_id, shard.device_rank)) + node_ranks.sort(key=lambda x: x[1]) + return [node_id for node_id, _ in node_ranks] + + +def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool: + """Check that the specific IPs used by a ring instance still exist in the topology.""" + ring = _get_ring_order(instance) + n = len(ring) + for node in ring: + hosts = instance.hosts_by_node[node] + for idx in range(n): + host = hosts[idx] + if host.ip in ("0.0.0.0", "198.51.100.1"): + continue # self or placeholder + # Real connection: node → ring[idx]. Check specific IP. + connections = topology.get_all_connections_between(node, ring[idx]) + if not any( + isinstance(c, SocketConnection) + and c.sink_multiaddr.ip_address == host.ip + for c in connections + ): + return False + return True + + +def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool: + """Check that the specific RDMA interfaces used by a JACCL instance still exist.""" + ring = _get_ring_order(instance) + n = len(ring) + for i in range(n): + for j in range(n): + iface = instance.jaccl_devices[i][j] + if iface is None: + continue + connections = topology.get_all_connections_between(ring[i], ring[j]) + if not any( + isinstance(c, RDMAConnection) and c.source_rdma_iface == iface + for c in connections + ): + return False + return True + + +def instance_connections_healthy(instance: Instance, topology: Topology) -> bool: + """Check that an instance's nodes and specific connections are still in the topology.""" + instance_nodes = set(instance.shard_assignments.node_to_runner.keys()) + if not all(topology.contains_node(n) for n in instance_nodes): + return False + if len(instance_nodes) <= 1: + return True + match instance: + case MlxRingInstance(): + return _ring_connections_healthy(instance, topology) + case MlxJacclInstance(): + return _jaccl_connections_healthy(instance, topology) + + +def instance_runners_failed( + instance: Instance, + runners: Mapping[RunnerId, RunnerStatus], + node_identities: Mapping[NodeId, NodeIdentity], +) -> tuple[bool, str | None]: + """Check if an instance's runners have all reached terminal failure states. + + Returns ``(True, error_message)`` when ALL runners are terminal + (``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``. + + Returns ``(False, None)`` when runners are still active, haven't reported + yet, or all gracefully shut down (no ``RunnerFailed``). + """ + instance_runner_ids = set(instance.shard_assignments.node_to_runner.values()) + + if not instance_runner_ids: + return False, None + + # Build reverse mapping: runner_id -> node_id + runner_to_node: dict[RunnerId, NodeId] = { + runner_id: node_id + for node_id, runner_id in instance.shard_assignments.node_to_runner.items() + } + + has_any_failed = False + error_messages: list[str] = [] + + for runner_id in instance_runner_ids: + status = runners.get(runner_id) + if status is None: + # Runner hasn't reported yet — instance is still starting + return False, None + if isinstance(status, RunnerFailed): + has_any_failed = True + if status.error_message: + node_id = runner_to_node.get(runner_id) + name = ( + node_identities[node_id].friendly_name + if node_id and node_id in node_identities + else node_id or "unknown" + ) + error_messages.append(f"{name}: {status.error_message}") + elif isinstance(status, RunnerShutdown): + pass # Terminal but not a failure indicator on its own + else: + # Runner is still active (connecting, loading, running, etc.) + return False, None + + if has_any_failed: + return True, "; ".join(error_messages) if error_messages else "Runner failed" + + # All runners are Shutdown but none Failed — graceful shutdown, not a failure + return False, None + + +def instance_satisfies_meta_instance( + meta_instance: MetaInstance, + instance: Instance, +) -> bool: + """Check if a single instance satisfies a meta-instance's constraints. + + This is a pure constraint check (model, min_nodes, node_ids). + Use ``instance_connections_healthy`` separately for topology health. + """ + if instance.shard_assignments.model_id != meta_instance.model_id: + return False + + instance_nodes = set(instance.shard_assignments.node_to_runner.keys()) + + if len(instance_nodes) < meta_instance.min_nodes: + return False + + return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset( + instance_nodes + ) + + +def find_unsatisfied_meta_instances( + meta_instances: Mapping[MetaInstanceId, MetaInstance], + instances: Mapping[InstanceId, Instance], + topology: Topology, +) -> Sequence[MetaInstance]: + """Return meta-instances that have no healthy backing instance.""" + unsatisfied: list[MetaInstance] = [] + for meta_id, meta_instance in meta_instances.items(): + has_healthy_backing = any( + instance.meta_instance_id == meta_id + and instance_connections_healthy(instance, topology) + for instance in instances.values() + ) + if not has_healthy_backing: + unsatisfied.append(meta_instance) + return unsatisfied + + +def try_place_for_meta_instance( + meta_instance: MetaInstance, + model_card: ModelCard, + topology: Topology, + current_instances: Mapping[InstanceId, Instance], + node_memory: Mapping[NodeId, MemoryUsage], + node_network: Mapping[NodeId, NodeNetworkInfo], + tasks: Mapping[TaskId, Task], +) -> PlacementResult: + """Try to place an instance satisfying the meta-instance constraints. + + Returns a :class:`PlacementResult` with events on success, or an error + reason on failure. + """ + command = PlaceInstance( + model_card=model_card, + sharding=meta_instance.sharding, + instance_meta=meta_instance.instance_meta, + min_nodes=meta_instance.min_nodes, + ) + try: + target_instances = place_instance( + command, + topology, + current_instances, + node_memory, + node_network, + required_nodes=( + set(meta_instance.node_ids) if meta_instance.node_ids else None + ), + ) + # Tag the new instance with meta_instance_id + new_instance_ids = set(target_instances.keys()) - set(current_instances.keys()) + if new_instance_ids: + new_id = next(iter(new_instance_ids)) + target_instances[new_id] = target_instances[new_id].model_copy( + update={"meta_instance_id": meta_instance.meta_instance_id} + ) + return PlacementResult( + events=list( + get_transition_events(current_instances, target_instances, tasks) + ), + error=None, + ) + except ValueError as e: + logger.debug( + f"MetaInstance placement not possible for {meta_instance.model_id}: {e}" + ) + return PlacementResult(events=[], error=str(e)) diff --git a/src/exo/master/tests/test_meta_instance_edge_cases.py b/src/exo/master/tests/test_meta_instance_edge_cases.py new file mode 100644 index 000000000..265548347 --- /dev/null +++ b/src/exo/master/tests/test_meta_instance_edge_cases.py @@ -0,0 +1,778 @@ +"""Edge-case and regression tests for MetaInstance lifecycle, concurrent operations, and error handling.""" + +import pytest + +from exo.master.process_managers.instance_health import ( + MAX_INSTANCE_RETRIES, + InstanceHealthReconciler, +) +from exo.master.process_managers.meta_instance import MetaInstanceReconciler +from exo.master.reconcile import ( + find_unsatisfied_meta_instances, + instance_connections_healthy, + instance_runners_failed, + instance_satisfies_meta_instance, +) +from exo.shared.apply import apply +from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask +from exo.shared.topology import Topology +from exo.shared.types.common import Host, MetaInstanceId, NodeId +from exo.shared.types.events import ( + IndexedEvent, + InstanceCreated, + InstanceDeleted, + InstanceRetrying, + MetaInstanceCreated, + MetaInstanceDeleted, + MetaInstancePlacementFailed, + TaskStatusUpdated, +) +from exo.shared.types.memory import Memory +from exo.shared.types.meta_instance import MetaInstance +from exo.shared.types.multiaddr import Multiaddr +from exo.shared.types.profiling import NodeIdentity +from exo.shared.types.state import State +from exo.shared.types.tasks import LoadModel, TaskId, TaskStatus +from exo.shared.types.topology import Connection, SocketConnection +from exo.shared.types.worker.instances import ( + InstanceId, + MlxRingInstance, +) +from exo.shared.types.worker.runners import ( + RunnerFailed, + RunnerId, + RunnerReady, + ShardAssignments, +) +from exo.shared.types.worker.shards import PipelineShardMetadata + +# --- Helpers (copied from test_reconcile.py for independence) --- + + +def _model_card(model_id: str = "test-org/test-model") -> ModelCard: + return ModelCard( + model_id=ModelId(model_id), + storage_size=Memory.from_kb(1000), + n_layers=10, + hidden_size=30, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + ) + + +def _topology(*node_ids: str, connect: bool = True) -> Topology: + t = Topology() + nodes = [NodeId(n) for n in node_ids] + for n in nodes: + t.add_node(n) + if connect and len(nodes) > 1: + for i in range(len(nodes)): + j = (i + 1) % len(nodes) + t.add_connection( + Connection( + source=nodes[i], + sink=nodes[j], + edge=SocketConnection( + sink_multiaddr=Multiaddr( + address=f"/ip4/10.0.0.{j + 1}/tcp/50000" + ) + ), + ) + ) + t.add_connection( + Connection( + source=nodes[j], + sink=nodes[i], + edge=SocketConnection( + sink_multiaddr=Multiaddr( + address=f"/ip4/10.0.0.{i + 1}/tcp/50000" + ) + ), + ) + ) + return t + + +def _meta_instance( + model_id: str = "test-org/test-model", + *, + min_nodes: int = 1, + node_ids: list[NodeId] | None = None, + meta_instance_id: MetaInstanceId | None = None, + consecutive_failures: int = 0, + last_failure_error: str | None = None, + placement_error: str | None = None, +) -> MetaInstance: + return MetaInstance( + meta_instance_id=meta_instance_id or MetaInstanceId(), + model_id=ModelId(model_id), + min_nodes=min_nodes, + node_ids=node_ids, + consecutive_failures=consecutive_failures, + last_failure_error=last_failure_error, + placement_error=placement_error, + ) + + +def _instance( + model_id: str = "test-org/test-model", + node_ids: list[str] | None = None, + instance_id: InstanceId | None = None, + meta_instance_id: MetaInstanceId | None = None, +) -> tuple[InstanceId, MlxRingInstance]: + iid = instance_id or InstanceId() + nodes = node_ids or ["node-a"] + n = len(nodes) + mc = _model_card(model_id) + ephemeral_port = 50000 + node_to_runner = {NodeId(nd): RunnerId() for nd in nodes} + runner_to_shard = { + runner_id: PipelineShardMetadata( + model_card=mc, + device_rank=i, + world_size=n, + start_layer=0, + end_layer=mc.n_layers, + n_layers=mc.n_layers, + ) + for i, runner_id in enumerate(node_to_runner.values()) + } + hosts_by_node: dict[NodeId, list[Host]] = {} + for r, node_str in enumerate(nodes): + hosts: list[Host] = [] + for idx in range(n): + if idx == r: + hosts.append(Host(ip="0.0.0.0", port=ephemeral_port)) + elif n > 1 and idx in ((r - 1) % n, (r + 1) % n): + hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port)) + else: + hosts.append(Host(ip="198.51.100.1", port=0)) + hosts_by_node[NodeId(node_str)] = hosts + return iid, MlxRingInstance( + instance_id=iid, + shard_assignments=ShardAssignments( + model_id=ModelId(model_id), + runner_to_shard=runner_to_shard, + node_to_runner=node_to_runner, + ), + hosts_by_node=hosts_by_node, + ephemeral_port=ephemeral_port, + meta_instance_id=meta_instance_id, + ) + + +# ============================================================================= +# 1. MetaInstance lifecycle edge cases +# ============================================================================= + + +def test_meta_instance_model_is_frozen(): + """MetaInstance should be immutable (frozen model).""" + meta = _meta_instance() + try: + meta.model_id = ModelId("something-else") + raise AssertionError("Should have raised") + except Exception: + pass # Expected — frozen model + + +def test_meta_instance_created_then_deleted_roundtrip(): + """Create and delete a MetaInstance through apply — state should be clean.""" + state = State() + meta = _meta_instance() + state = apply( + state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta)) + ) + assert meta.meta_instance_id in state.meta_instances + state = apply( + state, + IndexedEvent( + idx=1, event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id) + ), + ) + assert meta.meta_instance_id not in state.meta_instances + assert len(state.meta_instances) == 0 + + +def test_delete_nonexistent_meta_instance_is_safe(): + """Deleting a MetaInstance that doesn't exist should not crash.""" + state = State() + event = MetaInstanceDeleted(meta_instance_id=MetaInstanceId("nonexistent")) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + assert len(new_state.meta_instances) == 0 + + +def test_placement_failed_for_nonexistent_meta_instance_is_safe(): + """MetaInstancePlacementFailed for unknown ID should not crash.""" + state = State() + event = MetaInstancePlacementFailed( + meta_instance_id=MetaInstanceId("nonexistent"), + reason="test", + ) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + assert len(new_state.meta_instances) == 0 + + +def test_multiple_meta_instances_for_same_model(): + """Multiple MetaInstances for the same model are tracked independently.""" + state = State() + meta_a = _meta_instance("test-org/model-x") + meta_b = _meta_instance("test-org/model-x") + state = apply( + state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta_a)) + ) + state = apply( + state, IndexedEvent(idx=1, event=MetaInstanceCreated(meta_instance=meta_b)) + ) + assert len(state.meta_instances) == 2 + assert meta_a.meta_instance_id in state.meta_instances + assert meta_b.meta_instance_id in state.meta_instances + + +# ============================================================================= +# 2. Retry logic edge cases +# ============================================================================= + + +def test_retry_counter_resets_on_successful_instance_creation(): + """When a new instance is created for a meta-instance, failures should reset.""" + meta = _meta_instance(consecutive_failures=2, last_failure_error="old") + _, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + state = State(meta_instances={meta.meta_instance_id: meta}) + state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst))) + mi = state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == 0 + # last_failure_error is preserved (for UI display) + assert mi.last_failure_error == "old" + + +async def test_retry_count_increments_through_full_cycle(): + """Walk through MAX_INSTANCE_RETRIES worth of retries, then verify delete.""" + meta = _meta_instance() + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + topology = _topology("node-a") + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + topology=topology, + ) + + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + for idx, i in enumerate(range(MAX_INSTANCE_RETRIES)): + # Simulate runners failing + state_with_runners = state.model_copy( + update={"runners": {runner_ids[0]: RunnerFailed(error_message=f"fail-{i}")}} + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state_with_runners) + assert len(events) == 1 + assert isinstance(events[0], InstanceRetrying), f"iteration {i}" + state = apply(state, IndexedEvent(idx=idx, event=events[0])) + + # After MAX_INSTANCE_RETRIES retries, failure counter should be at max + mi = state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == MAX_INSTANCE_RETRIES + + # Next failure should result in deletion + state_with_runners = state.model_copy( + update={"runners": {runner_ids[0]: RunnerFailed(error_message="final")}} + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state_with_runners) + assert len(events) == 1 + assert isinstance(events[0], InstanceDeleted) + + +async def test_health_reconciler_respects_exact_limit(): + """At exactly MAX_INSTANCE_RETRIES, reconciler should delete, not retry.""" + meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES) + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + runners={runner_ids[0]: RunnerFailed(error_message="OOM")}, + topology=_topology("node-a"), + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state) + assert len(events) == 1 + assert isinstance(events[0], InstanceDeleted) + + +async def test_health_reconciler_at_limit_minus_one_retries(): + """At MAX_INSTANCE_RETRIES - 1, reconciler should still retry.""" + meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES - 1) + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + runners={runner_ids[0]: RunnerFailed(error_message="OOM")}, + topology=_topology("node-a"), + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state) + assert len(events) == 1 + assert isinstance(events[0], InstanceRetrying) + + +# ============================================================================= +# 3. Error handling edge cases +# ============================================================================= + + +def test_runners_failed_with_empty_error_message(): + """RunnerFailed with empty error_message should still report as failed.""" + _, inst = _instance(node_ids=["node-a"]) + runners = { + rid: RunnerFailed(error_message="") + for rid in inst.shard_assignments.node_to_runner.values() + } + is_failed, error = instance_runners_failed(inst, runners, {}) + assert is_failed is True + # Empty error message means we get the fallback + assert error == "Runner failed" + + +def test_runners_failed_with_none_error_message(): + """RunnerFailed with None error_message should still report as failed.""" + _, inst = _instance(node_ids=["node-a"]) + runners = { + rid: RunnerFailed(error_message=None) + for rid in inst.shard_assignments.node_to_runner.values() + } + is_failed, error = instance_runners_failed(inst, runners, {}) + assert is_failed is True + assert error == "Runner failed" + + +def test_runners_failed_collects_all_error_messages(): + """With multiple failed runners, all error messages should be collected.""" + _, inst = _instance(node_ids=["node-a", "node-b", "node-c"]) + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + runners = { + runner_ids[0]: RunnerFailed(error_message="OOM on GPU 0"), + runner_ids[1]: RunnerFailed(error_message="OOM on GPU 1"), + runner_ids[2]: RunnerFailed(error_message="OOM on GPU 2"), + } + is_failed, error = instance_runners_failed(inst, runners, {}) + assert is_failed is True + assert error is not None + assert "OOM on GPU 0" in error + assert "OOM on GPU 1" in error + assert "OOM on GPU 2" in error + + +def test_runners_failed_includes_friendly_name(): + """Error messages should include node friendly names when available.""" + _, inst = _instance(node_ids=["node-a"]) + node_id = NodeId("node-a") + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + runners = {runner_ids[0]: RunnerFailed(error_message="OOM")} + identities = {node_id: NodeIdentity(friendly_name="My Mac Studio")} + is_failed, error = instance_runners_failed(inst, runners, identities) + assert is_failed is True + assert error is not None + assert "My Mac Studio" in error + + +def test_instance_retrying_for_missing_instance_is_safe(): + """InstanceRetrying for an instance not in state should not crash. + + NOTE: When the instance is missing, the handler returns early WITHOUT + incrementing the MetaInstance failure counter. This means stale retry + events for already-deleted instances are silently dropped. This is + acceptable since the InstanceDeleted handler already increments failures. + """ + meta = _meta_instance() + state = State(meta_instances={meta.meta_instance_id: meta}) + event = InstanceRetrying( + instance_id=InstanceId("nonexistent"), + meta_instance_id=meta.meta_instance_id, + failure_error="crash", + ) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + # Does not crash, but failure count is NOT incremented (early return) + mi = new_state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == 0 + + +# ============================================================================= +# 4. Backward compatibility +# ============================================================================= + + +def test_instance_without_meta_instance_id_works(): + """Instances created without meta_instance_id should still function normally.""" + _, inst = _instance(node_ids=["node-a"]) + assert inst.meta_instance_id is None + topology = _topology("node-a") + assert instance_connections_healthy(inst, topology) is True + + +def test_instance_deleted_without_meta_does_not_affect_meta_instances(): + """Deleting an instance without meta_instance_id should not affect meta_instances.""" + meta = _meta_instance() + iid, inst = _instance(node_ids=["node-a"]) # no meta_instance_id + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + ) + event = InstanceDeleted(instance_id=iid, failure_error="crash") + new_state = apply(state, IndexedEvent(idx=0, event=event)) + mi = new_state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == 0 # unchanged + + +def test_satisfies_ignores_meta_instance_id_binding(): + """instance_satisfies_meta_instance checks constraints only, not binding.""" + meta = _meta_instance() + _, inst = _instance(node_ids=["node-a"]) # no meta_instance_id set + # Should match on constraints (model, min_nodes) regardless of binding + assert instance_satisfies_meta_instance(meta, inst) is True + + +def test_find_unsatisfied_uses_binding_not_constraints(): + """find_unsatisfied checks meta_instance_id binding, not just constraint matching.""" + meta = _meta_instance() + # Instance matches constraints but is NOT bound to this meta_instance + iid, inst = _instance(node_ids=["node-a"]) + topology = _topology("node-a") + result = find_unsatisfied_meta_instances( + {meta.meta_instance_id: meta}, {iid: inst}, topology + ) + # Should be unsatisfied because instance.meta_instance_id != meta.meta_instance_id + assert list(result) == [meta] + + +# ============================================================================= +# 5. Concurrent / multi-instance scenarios +# ============================================================================= + + +async def test_health_reconciler_handles_multiple_failing_instances(): + """Multiple instances failing simultaneously should each get their own event.""" + meta_a = _meta_instance() + meta_b = _meta_instance() + iid_a, inst_a = _instance( + node_ids=["node-a"], meta_instance_id=meta_a.meta_instance_id + ) + iid_b, inst_b = _instance( + node_ids=["node-b"], meta_instance_id=meta_b.meta_instance_id + ) + runner_ids_a = list(inst_a.shard_assignments.node_to_runner.values()) + runner_ids_b = list(inst_b.shard_assignments.node_to_runner.values()) + state = State( + meta_instances={ + meta_a.meta_instance_id: meta_a, + meta_b.meta_instance_id: meta_b, + }, + instances={iid_a: inst_a, iid_b: inst_b}, + runners={ + runner_ids_a[0]: RunnerFailed(error_message="OOM"), + runner_ids_b[0]: RunnerFailed(error_message="OOM"), + }, + topology=_topology("node-a", "node-b"), + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state) + assert len(events) == 2 + # Both should be InstanceRetrying since failures < MAX + assert all(isinstance(e, InstanceRetrying) for e in events) + instance_ids = {e.instance_id for e in events} # type: ignore[union-attr] + assert instance_ids == {iid_a, iid_b} + + +async def test_health_reconciler_mixed_healthy_and_failing(): + """Only failing instances should produce events; healthy ones should not.""" + meta_healthy = _meta_instance() + meta_failing = _meta_instance() + iid_h, inst_h = _instance( + node_ids=["node-a"], meta_instance_id=meta_healthy.meta_instance_id + ) + iid_f, inst_f = _instance( + node_ids=["node-b"], meta_instance_id=meta_failing.meta_instance_id + ) + runner_ids_h = list(inst_h.shard_assignments.node_to_runner.values()) + runner_ids_f = list(inst_f.shard_assignments.node_to_runner.values()) + state = State( + meta_instances={ + meta_healthy.meta_instance_id: meta_healthy, + meta_failing.meta_instance_id: meta_failing, + }, + instances={iid_h: inst_h, iid_f: inst_f}, + runners={ + runner_ids_h[0]: RunnerReady(), + runner_ids_f[0]: RunnerFailed(error_message="crash"), + }, + topology=_topology("node-a", "node-b"), + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state) + assert len(events) == 1 + assert isinstance(events[0], InstanceRetrying) + assert events[0].instance_id == iid_f + + +async def test_meta_instance_reconciler_empty_state(): + """MetaInstanceReconciler with no meta_instances should produce no events.""" + state = State() + reconciler = MetaInstanceReconciler() + events = await reconciler.reconcile(state) + assert len(events) == 0 + + +# ============================================================================= +# 6. Placement error tracking +# ============================================================================= + + +def test_placement_failed_sets_error(): + """MetaInstancePlacementFailed should set placement_error on the MetaInstance.""" + meta = _meta_instance() + state = State(meta_instances={meta.meta_instance_id: meta}) + event = MetaInstancePlacementFailed( + meta_instance_id=meta.meta_instance_id, + reason="Not enough memory", + ) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + mi = new_state.meta_instances[meta.meta_instance_id] + assert mi.placement_error == "Not enough memory" + + +def test_instance_created_clears_placement_error(): + """InstanceCreated should clear placement_error on the MetaInstance.""" + meta = _meta_instance(placement_error="Not enough memory") + _, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + state = State(meta_instances={meta.meta_instance_id: meta}) + state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst))) + mi = state.meta_instances[meta.meta_instance_id] + assert mi.placement_error is None + + +def test_placement_error_does_not_increment_failures(): + """Placement failures should only set placement_error, not increment consecutive_failures.""" + meta = _meta_instance() + state = State(meta_instances={meta.meta_instance_id: meta}) + event = MetaInstancePlacementFailed( + meta_instance_id=meta.meta_instance_id, + reason="No resources", + ) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + mi = new_state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == 0 + assert mi.placement_error == "No resources" + + +# ============================================================================= +# 7. State serialization roundtrip +# ============================================================================= + + +def test_state_with_meta_instances_serializes(): + """State with meta_instances should serialize and deserialize correctly.""" + meta = _meta_instance(consecutive_failures=2, last_failure_error="test") + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + ) + json_str = state.model_dump_json() + restored = State.model_validate_json(json_str) + assert meta.meta_instance_id in restored.meta_instances + mi = restored.meta_instances[meta.meta_instance_id] + assert mi.model_id == meta.model_id + assert mi.consecutive_failures == 2 + assert mi.last_failure_error == "test" + assert iid in restored.instances + assert restored.instances[iid].meta_instance_id == meta.meta_instance_id + + +# ============================================================================= +# 8. MetaInstanceReconciler error handling +# ============================================================================= + + +async def test_meta_instance_reconciler_model_load_error_emits_placement_failed( + monkeypatch: "pytest.MonkeyPatch", +): + """When ModelCard.load raises, reconciler emits MetaInstancePlacementFailed.""" + import exo.master.process_managers.meta_instance as mi_mod + + meta = _meta_instance() + topo = _topology("node-a") + state = State( + meta_instances={meta.meta_instance_id: meta}, + topology=topo, + ) + + async def _failing_load(_model_id: ModelId) -> ModelCard: + raise RuntimeError("Network error") + + monkeypatch.setattr( + mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)}) + ) + + reconciler = MetaInstanceReconciler() + events = await reconciler.reconcile(state) + + placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)] + assert len(placement_failed) == 1 + assert "Failed to load model card" in placement_failed[0].reason + assert meta.meta_instance_id == placement_failed[0].meta_instance_id + + +async def test_meta_instance_reconciler_model_load_error_skips_dedup( + monkeypatch: "pytest.MonkeyPatch", +): + """When ModelCard.load error matches existing placement_error, no duplicate event.""" + import exo.master.process_managers.meta_instance as mi_mod + + meta = _meta_instance(placement_error="Failed to load model card: Network error") + topo = _topology("node-a") + state = State( + meta_instances={meta.meta_instance_id: meta}, + topology=topo, + ) + + async def _failing_load(_model_id: ModelId) -> ModelCard: + raise RuntimeError("Network error") + + monkeypatch.setattr( + mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)}) + ) + + reconciler = MetaInstanceReconciler() + events = await reconciler.reconcile(state) + + # Error matches existing placement_error, so no duplicate event emitted + assert len(events) == 0 + + +async def test_meta_instance_reconciler_continues_after_error( + monkeypatch: "pytest.MonkeyPatch", +): + """Reconciler should continue to next meta-instance after one fails to load.""" + import exo.master.process_managers.meta_instance as mi_mod + + meta_a = _meta_instance(model_id="org/model-a") + meta_b = _meta_instance(model_id="org/model-b") + topo = _topology("node-a") + state = State( + meta_instances={ + meta_a.meta_instance_id: meta_a, + meta_b.meta_instance_id: meta_b, + }, + topology=topo, + ) + + call_count = 0 + + async def _load_second_fails(model_id: ModelId) -> ModelCard: + nonlocal call_count + call_count += 1 + raise RuntimeError(f"Cannot load {model_id}") + + monkeypatch.setattr( + mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_load_second_fails)}) + ) + + reconciler = MetaInstanceReconciler() + events = await reconciler.reconcile(state) + + # Both meta-instances should have been attempted (not short-circuited) + assert call_count == 2 + # Both should have placement failed events + placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)] + assert len(placement_failed) == 2 + + +# ============================================================================= +# 8. Cascade delete with task cancellation +# ============================================================================= + + +def test_cascade_delete_cancels_active_tasks(): + """Deleting a MetaInstance should cancel tasks on backing instances. + + Regression test: previously, cascade-deleting backing instances via + DeleteMetaInstance did not emit TaskStatusUpdated(Cancelled) for active + tasks, leaving orphaned task references in state. + """ + meta = _meta_instance() + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + task_id = TaskId() + task = LoadModel(task_id=task_id, instance_id=iid, task_status=TaskStatus.Running) + + # Build state with meta-instance, backing instance, and active task + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + tasks={task_id: task}, + topology=_topology("node-a"), + ) + + # Simulate the cascade-delete event sequence produced by main.py: + # 1. MetaInstanceDeleted + # 2. TaskStatusUpdated(Cancelled) for active tasks + # 3. InstanceDeleted + idx = 0 + state = apply( + state, + IndexedEvent( + idx=idx, + event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id), + ), + ) + idx += 1 + state = apply( + state, + IndexedEvent( + idx=idx, + event=TaskStatusUpdated(task_id=task_id, task_status=TaskStatus.Cancelled), + ), + ) + idx += 1 + state = apply( + state, + IndexedEvent(idx=idx, event=InstanceDeleted(instance_id=iid)), + ) + + # Verify everything is cleaned up + assert len(state.meta_instances) == 0 + assert len(state.instances) == 0 + assert state.tasks[task_id].task_status == TaskStatus.Cancelled + + +def test_cascade_delete_skips_completed_tasks(): + """Cascade delete should only cancel Pending/Running tasks, not completed ones.""" + meta = _meta_instance() + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + + running_task_id = TaskId() + completed_task_id = TaskId() + running_task = LoadModel( + task_id=running_task_id, instance_id=iid, task_status=TaskStatus.Running + ) + completed_task = LoadModel( + task_id=completed_task_id, instance_id=iid, task_status=TaskStatus.Complete + ) + + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + tasks={running_task_id: running_task, completed_task_id: completed_task}, + topology=_topology("node-a"), + ) + + # Only the running task should be cancelled — we verify the logic pattern + # by checking which tasks are Pending or Running + active_tasks = [ + t + for t in state.tasks.values() + if t.instance_id == iid + and t.task_status in (TaskStatus.Pending, TaskStatus.Running) + ] + assert len(active_tasks) == 1 + assert active_tasks[0].task_id == running_task_id diff --git a/src/exo/master/tests/test_placement_utils.py b/src/exo/master/tests/test_placement_utils.py index 245c4fd7e..9c7ebadab 100644 --- a/src/exo/master/tests/test_placement_utils.py +++ b/src/exo/master/tests/test_placement_utils.py @@ -3,10 +3,10 @@ import pytest from exo.master.placement_utils import ( allocate_layers_proportionally, filter_cycles_by_memory, + get_largest_cycles, get_mlx_jaccl_coordinators, get_shard_assignments, get_shard_assignments_for_pipeline_parallel, - get_smallest_cycles, ) from exo.master.tests.conftest import ( create_node_memory, @@ -143,7 +143,7 @@ def test_filter_multiple_cycles_by_memory(): } -def test_get_smallest_cycles(): +def test_get_largest_cycles(): # arrange node_a_id = NodeId() node_b_id = NodeId() @@ -175,12 +175,12 @@ def test_get_smallest_cycles(): cycles = [c for c in topology.get_cycles() if len(c) != 1] # ignore singletons # act - smallest_cycles = get_smallest_cycles(cycles) + largest_cycles = get_largest_cycles(cycles) # assert - assert len(smallest_cycles) == 1 - assert len(smallest_cycles[0]) == 2 - assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id} + assert len(largest_cycles) == 1 + assert len(largest_cycles[0]) == 3 + assert set(n for n in largest_cycles[0]) == {node_a_id, node_b_id, node_c_id} @pytest.mark.parametrize( diff --git a/src/exo/master/tests/test_reconcile.py b/src/exo/master/tests/test_reconcile.py new file mode 100644 index 000000000..e2d6e776c --- /dev/null +++ b/src/exo/master/tests/test_reconcile.py @@ -0,0 +1,742 @@ +from exo.master.process_managers.instance_health import InstanceHealthReconciler +from exo.master.reconcile import ( + find_unsatisfied_meta_instances, + instance_connections_healthy, + instance_runners_failed, + instance_satisfies_meta_instance, +) +from exo.shared.apply import apply +from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask +from exo.shared.topology import Topology +from exo.shared.types.common import Host, MetaInstanceId, NodeId +from exo.shared.types.events import ( + IndexedEvent, + InstanceCreated, + InstanceDeleted, + InstanceRetrying, + MetaInstanceCreated, + MetaInstanceDeleted, +) +from exo.shared.types.memory import Memory +from exo.shared.types.meta_instance import MetaInstance +from exo.shared.types.multiaddr import Multiaddr +from exo.shared.types.state import State +from exo.shared.types.topology import Connection, SocketConnection +from exo.shared.types.worker.instances import ( + InstanceId, + MlxRingInstance, +) +from exo.shared.types.worker.runners import ( + RunnerFailed, + RunnerId, + RunnerLoading, + RunnerReady, + RunnerShutdown, + ShardAssignments, +) +from exo.shared.types.worker.shards import PipelineShardMetadata + + +def _model_card(model_id: str = "test-org/test-model") -> ModelCard: + return ModelCard( + model_id=ModelId(model_id), + storage_size=Memory.from_kb(1000), + n_layers=10, + hidden_size=30, + supports_tensor=True, + tasks=[ModelTask.TextGeneration], + ) + + +def _topology(*node_ids: str, connect: bool = True) -> Topology: + """Build a topology with nodes connected in a bidirectional ring with unique IPs. + + Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions + between consecutive nodes (including wrap-around). + """ + t = Topology() + nodes = [NodeId(n) for n in node_ids] + for n in nodes: + t.add_node(n) + if connect and len(nodes) > 1: + for i in range(len(nodes)): + j = (i + 1) % len(nodes) + t.add_connection( + Connection( + source=nodes[i], + sink=nodes[j], + edge=SocketConnection( + sink_multiaddr=Multiaddr( + address=f"/ip4/10.0.0.{j + 1}/tcp/50000" + ) + ), + ) + ) + t.add_connection( + Connection( + source=nodes[j], + sink=nodes[i], + edge=SocketConnection( + sink_multiaddr=Multiaddr( + address=f"/ip4/10.0.0.{i + 1}/tcp/50000" + ) + ), + ) + ) + return t + + +def _meta_instance( + model_id: str = "test-org/test-model", + *, + min_nodes: int = 1, + node_ids: list[NodeId] | None = None, + meta_instance_id: MetaInstanceId | None = None, +) -> MetaInstance: + return MetaInstance( + meta_instance_id=meta_instance_id or MetaInstanceId(), + model_id=ModelId(model_id), + min_nodes=min_nodes, + node_ids=node_ids, + ) + + +def _instance( + model_id: str = "test-org/test-model", + node_ids: list[str] | None = None, + instance_id: InstanceId | None = None, + meta_instance_id: MetaInstanceId | None = None, +) -> tuple[InstanceId, MlxRingInstance]: + """Create a test instance with hosts_by_node matching ``_topology()`` IPs.""" + iid = instance_id or InstanceId() + nodes = node_ids or ["node-a"] + n = len(nodes) + mc = _model_card(model_id) + ephemeral_port = 50000 + node_to_runner = {NodeId(nd): RunnerId() for nd in nodes} + runner_to_shard = { + runner_id: PipelineShardMetadata( + model_card=mc, + device_rank=i, + world_size=n, + start_layer=0, + end_layer=mc.n_layers, + n_layers=mc.n_layers, + ) + for i, runner_id in enumerate(node_to_runner.values()) + } + # Build hosts_by_node with IPs matching _topology() convention: + # node at index idx has IP 10.0.0.{idx+1} + hosts_by_node: dict[NodeId, list[Host]] = {} + for r, node_str in enumerate(nodes): + hosts: list[Host] = [] + for idx in range(n): + if idx == r: + hosts.append(Host(ip="0.0.0.0", port=ephemeral_port)) + elif n > 1 and idx in ((r - 1) % n, (r + 1) % n): + hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port)) + else: + hosts.append(Host(ip="198.51.100.1", port=0)) + hosts_by_node[NodeId(node_str)] = hosts + return iid, MlxRingInstance( + instance_id=iid, + shard_assignments=ShardAssignments( + model_id=ModelId(model_id), + runner_to_shard=runner_to_shard, + node_to_runner=node_to_runner, + ), + hosts_by_node=hosts_by_node, + ephemeral_port=ephemeral_port, + meta_instance_id=meta_instance_id, + ) + + +# --- instance_satisfies_meta_instance (pure constraint matching) --- + + +def test_satisfies_matching_model(): + meta = _meta_instance() + _, inst = _instance(node_ids=["node-a"]) + assert instance_satisfies_meta_instance(meta, inst) is True + + +def test_not_satisfies_wrong_model(): + meta = _meta_instance("test-org/model-a") + _, inst = _instance("test-org/model-b") + assert instance_satisfies_meta_instance(meta, inst) is False + + +def test_not_satisfies_missing_required_node(): + meta = _meta_instance(node_ids=[NodeId("node-c")]) + _, inst = _instance(node_ids=["node-a", "node-b"]) + assert instance_satisfies_meta_instance(meta, inst) is False + + +def test_not_satisfies_fewer_than_min_nodes(): + meta = _meta_instance(min_nodes=3) + _, inst = _instance(node_ids=["node-a", "node-b"]) + assert instance_satisfies_meta_instance(meta, inst) is False + + +def test_satisfies_with_node_ids_specified(): + meta = _meta_instance(node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2) + _, inst = _instance(node_ids=["node-a", "node-b", "node-c"]) + assert instance_satisfies_meta_instance(meta, inst) is True + + +# --- instance_connections_healthy --- + + +def test_healthy_single_node_present(): + _, inst = _instance(node_ids=["node-a"]) + topology = _topology("node-a") + assert instance_connections_healthy(inst, topology) is True + + +def test_unhealthy_single_node_missing(): + _, inst = _instance(node_ids=["node-a"]) + topology = Topology() # empty + assert instance_connections_healthy(inst, topology) is False + + +def test_healthy_two_node_ring(): + _, inst = _instance(node_ids=["node-a", "node-b"]) + topology = _topology("node-a", "node-b") + assert instance_connections_healthy(inst, topology) is True + + +def test_unhealthy_two_node_edge_removed(): + """Nodes present but edge removed — ring broken.""" + _, inst = _instance(node_ids=["node-a", "node-b"]) + topology = _topology("node-a", "node-b", connect=False) + assert instance_connections_healthy(inst, topology) is False + + +def test_unhealthy_two_node_ip_changed(): + """Edge exists but with a different IP than instance was configured with.""" + _, inst = _instance(node_ids=["node-a", "node-b"]) + # Build topology with different IPs than _instance() expects + topology = Topology() + topology.add_node(NodeId("node-a")) + topology.add_node(NodeId("node-b")) + topology.add_connection( + Connection( + source=NodeId("node-a"), + sink=NodeId("node-b"), + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000") + ), + ) + ) + topology.add_connection( + Connection( + source=NodeId("node-b"), + sink=NodeId("node-a"), + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000") + ), + ) + ) + assert instance_connections_healthy(inst, topology) is False + + +def test_healthy_three_node_ring(): + _, inst = _instance(node_ids=["node-a", "node-b", "node-c"]) + topology = _topology("node-a", "node-b", "node-c") + assert instance_connections_healthy(inst, topology) is True + + +def test_unhealthy_three_node_one_edge_removed(): + """Remove one edge from a three-node ring — instance unhealthy.""" + _, inst = _instance(node_ids=["node-a", "node-b", "node-c"]) + # Build topology with one direction of one edge missing + topology = Topology() + nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")] + for n in nodes: + topology.add_node(n) + # Add all edges except node-a → node-b + topology.add_connection( + Connection( + source=nodes[1], + sink=nodes[0], + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000") + ), + ) + ) + topology.add_connection( + Connection( + source=nodes[1], + sink=nodes[2], + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000") + ), + ) + ) + topology.add_connection( + Connection( + source=nodes[2], + sink=nodes[1], + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000") + ), + ) + ) + topology.add_connection( + Connection( + source=nodes[2], + sink=nodes[0], + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000") + ), + ) + ) + topology.add_connection( + Connection( + source=nodes[0], + sink=nodes[2], + edge=SocketConnection( + sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000") + ), + ) + ) + # Missing: node-a → node-b (ip 10.0.0.2) + assert instance_connections_healthy(inst, topology) is False + + +def test_unhealthy_node_missing_from_topology(): + """Instance has a node that's not in the topology at all.""" + _, inst = _instance(node_ids=["node-a", "node-b"]) + topology = _topology("node-a") # node-b not present + assert instance_connections_healthy(inst, topology) is False + + +def test_healthy_extra_nodes_in_topology(): + """Extra nodes in topology don't affect instance health.""" + _, inst = _instance(node_ids=["node-a", "node-b"]) + topology = _topology("node-a", "node-b", "node-c") + assert instance_connections_healthy(inst, topology) is True + + +# --- find_unsatisfied_meta_instances --- + + +def test_unsatisfied_no_meta_instances(): + result = find_unsatisfied_meta_instances({}, {}, Topology()) + assert list(result) == [] + + +def test_unsatisfied_one_satisfied(): + meta = _meta_instance() + id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id) + topology = _topology("node-a") + result = find_unsatisfied_meta_instances( + {meta.meta_instance_id: meta}, + {id_a: inst_a}, + topology, + ) + assert list(result) == [] + + +def test_unsatisfied_one_not_satisfied(): + meta = _meta_instance("test-org/model-x") + id_a, inst_a = _instance("test-org/model-y") + topology = _topology("node-a") + result = find_unsatisfied_meta_instances( + {meta.meta_instance_id: meta}, {id_a: inst_a}, topology + ) + assert list(result) == [meta] + + +def test_unsatisfied_mix(): + meta_satisfied = _meta_instance("test-org/model-a") + meta_unsatisfied = _meta_instance("test-org/model-b") + id_a, inst_a = _instance( + "test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id + ) + topology = _topology("node-a") + result = find_unsatisfied_meta_instances( + { + meta_satisfied.meta_instance_id: meta_satisfied, + meta_unsatisfied.meta_instance_id: meta_unsatisfied, + }, + {id_a: inst_a}, + topology, + ) + assert list(result) == [meta_unsatisfied] + + +def test_unsatisfied_node_disconnect(): + meta = _meta_instance() + id_a, inst_a = _instance( + node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id + ) + topology = _topology("node-a") # node-b disconnected + result = find_unsatisfied_meta_instances( + {meta.meta_instance_id: meta}, + {id_a: inst_a}, + topology, + ) + assert list(result) == [meta] + + +def test_unsatisfied_edge_break(): + """Instance exists but its connections broke — meta-instance becomes unsatisfied.""" + meta = _meta_instance() + id_a, inst_a = _instance( + node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id + ) + topology = _topology("node-a", "node-b", connect=False) # nodes present, no edges + result = find_unsatisfied_meta_instances( + {meta.meta_instance_id: meta}, + {id_a: inst_a}, + topology, + ) + assert list(result) == [meta] + + +def test_unsatisfied_idempotent(): + meta = _meta_instance("test-org/model-x") + topology = _topology("node-a") + meta_instances = {meta.meta_instance_id: meta} + instances: dict[InstanceId, MlxRingInstance] = {} + result_1 = list( + find_unsatisfied_meta_instances(meta_instances, instances, topology) + ) + result_2 = list( + find_unsatisfied_meta_instances(meta_instances, instances, topology) + ) + assert result_1 == result_2 + + +def test_unsatisfied_exclusive_binding(): + """Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied.""" + meta_a = _meta_instance("test-org/model-x") + meta_b = _meta_instance("test-org/model-x") + id_inst, inst = _instance( + "test-org/model-x", meta_instance_id=meta_a.meta_instance_id + ) + topology = _topology("node-a") + result = find_unsatisfied_meta_instances( + { + meta_a.meta_instance_id: meta_a, + meta_b.meta_instance_id: meta_b, + }, + {id_inst: inst}, + topology, + ) + assert list(result) == [meta_b] + + +# --- apply handlers --- + + +def test_apply_meta_instance_created(): + state = State() + meta = _meta_instance() + event = MetaInstanceCreated(meta_instance=meta) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + assert meta.meta_instance_id in new_state.meta_instances + assert new_state.meta_instances[meta.meta_instance_id] == meta + + +def test_apply_meta_instance_deleted(): + meta = _meta_instance() + state = State(meta_instances={meta.meta_instance_id: meta}) + event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + assert meta.meta_instance_id not in new_state.meta_instances + + +def test_apply_meta_instance_deleted_clears_failure_info(): + meta = _meta_instance().model_copy( + update={"consecutive_failures": 2, "last_failure_error": "OOM"} + ) + state = State(meta_instances={meta.meta_instance_id: meta}) + event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + assert meta.meta_instance_id not in new_state.meta_instances + + +# --- instance_runners_failed --- + + +def test_runners_failed_all_failed(): + """All runners in RunnerFailed -> instance is failed.""" + _, inst = _instance(node_ids=["node-a", "node-b"]) + runners = { + rid: RunnerFailed(error_message="OOM") + for rid in inst.shard_assignments.node_to_runner.values() + } + is_failed, error = instance_runners_failed(inst, runners, {}) + assert is_failed is True + assert error is not None + assert "OOM" in error + + +def test_runners_failed_mixed_failed_shutdown(): + """One Failed + one Shutdown = failed.""" + _, inst = _instance(node_ids=["node-a", "node-b"]) + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + runners = { + runner_ids[0]: RunnerFailed(error_message="crash"), + runner_ids[1]: RunnerShutdown(), + } + is_failed, error = instance_runners_failed(inst, runners, {}) + assert is_failed is True + assert error is not None + assert "crash" in error + + +def test_runners_not_failed_all_shutdown(): + """All Shutdown (graceful) = not a failure.""" + _, inst = _instance(node_ids=["node-a"]) + runners = { + rid: RunnerShutdown() for rid in inst.shard_assignments.node_to_runner.values() + } + is_failed, _ = instance_runners_failed(inst, runners, {}) + assert is_failed is False + + +def test_runners_not_failed_still_active(): + """Some runners still active = not failed yet.""" + _, inst = _instance(node_ids=["node-a", "node-b"]) + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + runners = { + runner_ids[0]: RunnerFailed(error_message="OOM"), + runner_ids[1]: RunnerLoading(), + } + is_failed, _ = instance_runners_failed(inst, runners, {}) + assert is_failed is False + + +def test_runners_not_failed_no_status(): + """Runner not yet reported = not failed.""" + _, inst = _instance(node_ids=["node-a"]) + is_failed, _ = instance_runners_failed(inst, {}, {}) + assert is_failed is False + + +def test_runners_not_failed_healthy(): + """Runners in Ready state = not failed.""" + _, inst = _instance(node_ids=["node-a"]) + runners = { + rid: RunnerReady() for rid in inst.shard_assignments.node_to_runner.values() + } + is_failed, _ = instance_runners_failed(inst, runners, {}) + assert is_failed is False + + +# --- failure tracking in apply_instance_deleted --- + + +def test_apply_instance_deleted_tracks_failure(): + """InstanceDeleted with failure_error increments meta instance failure count.""" + meta = _meta_instance() + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + ) + event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM") + new_state = apply(state, IndexedEvent(idx=0, event=event)) + mi = new_state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == 1 + assert mi.last_failure_error == "Runner OOM" + + +def test_apply_instance_deleted_increments_failure(): + """Subsequent failures increment the counter.""" + meta = _meta_instance().model_copy( + update={"consecutive_failures": 2, "last_failure_error": "previous error"} + ) + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + ) + event = InstanceDeleted(instance_id=iid, failure_error="new error") + new_state = apply(state, IndexedEvent(idx=0, event=event)) + mi = new_state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == 3 + assert mi.last_failure_error == "new error" + + +def test_apply_instance_deleted_no_failure_no_tracking(): + """InstanceDeleted without failure_error does not track.""" + meta = _meta_instance() + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + ) + event = InstanceDeleted(instance_id=iid) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + mi = new_state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == 0 + + +def test_apply_instance_deleted_orphan_no_tracking(): + """InstanceDeleted for orphan instance (no meta_instance_id) does not track.""" + iid, inst = _instance(node_ids=["node-a"]) + state = State(instances={iid: inst}) + event = InstanceDeleted(instance_id=iid, failure_error="crash") + new_state = apply(state, IndexedEvent(idx=0, event=event)) + assert len(new_state.meta_instances) == 0 + + +# --- InstanceRetrying --- + + +def test_apply_instance_retrying_removes_runners(): + """InstanceRetrying removes the instance's runners from state but keeps the instance.""" + meta = _meta_instance() + iid, inst = _instance( + node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id + ) + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + runners = { + runner_ids[0]: RunnerFailed(error_message="OOM"), + runner_ids[1]: RunnerShutdown(), + } + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + runners=runners, + ) + event = InstanceRetrying( + instance_id=iid, + meta_instance_id=meta.meta_instance_id, + failure_error="OOM", + ) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + # Instance still exists + assert iid in new_state.instances + # Runners removed + assert runner_ids[0] not in new_state.runners + assert runner_ids[1] not in new_state.runners + + +def test_apply_instance_retrying_increments_failure(): + """InstanceRetrying increments consecutive_failures on the MetaInstance.""" + meta = _meta_instance() + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + ) + event = InstanceRetrying( + instance_id=iid, + meta_instance_id=meta.meta_instance_id, + failure_error="crash", + ) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + mi = new_state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == 1 + assert mi.last_failure_error == "crash" + + +def test_apply_instance_retrying_skips_missing_runners(): + """InstanceRetrying doesn't assert if runners haven't reported yet.""" + meta = _meta_instance() + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + # No runners in state at all + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + ) + event = InstanceRetrying( + instance_id=iid, + meta_instance_id=meta.meta_instance_id, + failure_error="crash", + ) + # Should not raise + new_state = apply(state, IndexedEvent(idx=0, event=event)) + assert iid in new_state.instances + + +def test_apply_instance_created_resets_failure_counter(): + """InstanceCreated resets consecutive_failures but preserves last_failure_error.""" + meta = _meta_instance().model_copy( + update={"consecutive_failures": 3, "last_failure_error": "old error"} + ) + _, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + state = State(meta_instances={meta.meta_instance_id: meta}) + event = InstanceCreated(instance=inst) + new_state = apply(state, IndexedEvent(idx=0, event=event)) + mi = new_state.meta_instances[meta.meta_instance_id] + assert mi.consecutive_failures == 0 + assert mi.last_failure_error == "old error" + assert mi.placement_error is None + + +# --- InstanceHealthReconciler retry-vs-delete --- + + +async def test_health_reconciler_retries_when_under_limit(): + """InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3.""" + meta = _meta_instance() + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + runners={runner_ids[0]: RunnerFailed(error_message="OOM")}, + topology=_topology("node-a"), + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state) + assert len(events) == 1 + assert isinstance(events[0], InstanceRetrying) + assert events[0].instance_id == iid + assert events[0].meta_instance_id == meta.meta_instance_id + + +async def test_health_reconciler_deletes_when_limit_reached(): + """InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3.""" + meta = _meta_instance().model_copy(update={"consecutive_failures": 3}) + iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id) + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + runners={runner_ids[0]: RunnerFailed(error_message="OOM")}, + topology=_topology("node-a"), + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state) + assert len(events) == 1 + assert isinstance(events[0], InstanceDeleted) + + +async def test_health_reconciler_deletes_without_meta_instance(): + """Instances without a MetaInstance are deleted immediately on runner failure.""" + iid, inst = _instance(node_ids=["node-a"]) + runner_ids = list(inst.shard_assignments.node_to_runner.values()) + state = State( + instances={iid: inst}, + runners={runner_ids[0]: RunnerFailed(error_message="crash")}, + topology=_topology("node-a"), + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state) + assert len(events) == 1 + assert isinstance(events[0], InstanceDeleted) + + +async def test_health_reconciler_network_failure_always_deletes(): + """Network failure always triggers InstanceDeleted regardless of retry count.""" + meta = _meta_instance() + iid, inst = _instance( + node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id + ) + state = State( + meta_instances={meta.meta_instance_id: meta}, + instances={iid: inst}, + topology=_topology("node-a"), # node-b missing + ) + reconciler = InstanceHealthReconciler() + events = await reconciler.reconcile(state) + assert len(events) == 1 + assert isinstance(events[0], InstanceDeleted) + assert events[0].failure_error == "Network connection lost" diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index 94869dfe8..f96c6b7fe 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -4,7 +4,7 @@ from datetime import datetime from loguru import logger -from exo.shared.types.common import NodeId +from exo.shared.types.common import MetaInstanceId, NodeId from exo.shared.types.events import ( ChunkGenerated, Event, @@ -12,6 +12,12 @@ from exo.shared.types.events import ( InputChunkReceived, InstanceCreated, InstanceDeleted, + InstanceRetrying, + JacclSideChannelData, + JacclSideChannelGathered, + MetaInstanceCreated, + MetaInstanceDeleted, + MetaInstancePlacementFailed, NodeDownloadProgress, NodeGatheredInfo, NodeTimedOut, @@ -28,6 +34,7 @@ from exo.shared.types.events import ( TracesCollected, TracesMerged, ) +from exo.shared.types.meta_instance import MetaInstance from exo.shared.types.profiling import ( NodeIdentity, NodeNetworkInfo, @@ -66,12 +73,22 @@ def event_apply(event: Event, state: State) -> State: | InputChunkReceived() | TracesCollected() | TracesMerged() + | JacclSideChannelData() + | JacclSideChannelGathered() ): # Pass-through events that don't modify state return state case InstanceCreated(): return apply_instance_created(event, state) case InstanceDeleted(): return apply_instance_deleted(event, state) + case InstanceRetrying(): + return apply_instance_retrying(event, state) + case MetaInstanceCreated(): + return apply_meta_instance_created(event, state) + case MetaInstanceDeleted(): + return apply_meta_instance_deleted(event, state) + case MetaInstancePlacementFailed(): + return apply_meta_instance_placement_failed(event, state) case NodeTimedOut(): return apply_node_timed_out(event, state) case NodeDownloadProgress(): @@ -174,20 +191,123 @@ def apply_task_failed(event: TaskFailed, state: State) -> State: return state.model_copy(update={"tasks": new_tasks}) +def _update_meta_instance( + state: State, mid: MetaInstanceId, **fields: object +) -> Mapping[MetaInstanceId, MetaInstance]: + mi = state.meta_instances[mid] + return {**state.meta_instances, mid: mi.model_copy(update=fields)} + + def apply_instance_created(event: InstanceCreated, state: State) -> State: instance = event.instance new_instances: Mapping[InstanceId, Instance] = { **state.instances, instance.instance_id: instance, } - return state.model_copy(update={"instances": new_instances}) + update: dict[str, object] = {"instances": new_instances} + # Reset failure tracking when a new instance is created for a meta-instance + if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances: + mi = state.meta_instances[instance.meta_instance_id] + if mi.placement_error is not None or mi.consecutive_failures > 0: + update["meta_instances"] = _update_meta_instance( + state, + instance.meta_instance_id, + placement_error=None, + consecutive_failures=0, + ) + return state.model_copy(update=update) def apply_instance_deleted(event: InstanceDeleted, state: State) -> State: + deleted_instance = state.instances.get(event.instance_id) new_instances: Mapping[InstanceId, Instance] = { iid: inst for iid, inst in state.instances.items() if iid != event.instance_id } - return state.model_copy(update={"instances": new_instances}) + update: dict[str, object] = {"instances": new_instances} + + # Track failure on the MetaInstance itself + if ( + event.failure_error + and deleted_instance + and deleted_instance.meta_instance_id + and deleted_instance.meta_instance_id in state.meta_instances + ): + mid = deleted_instance.meta_instance_id + mi = state.meta_instances[mid] + update["meta_instances"] = { + **state.meta_instances, + mid: mi.model_copy( + update={ + "consecutive_failures": mi.consecutive_failures + 1, + "last_failure_error": event.failure_error, + } + ), + } + + return state.model_copy(update=update) + + +def apply_instance_retrying(event: InstanceRetrying, state: State) -> State: + """Runners failed but retry limit not reached — remove runners, keep instance.""" + instance = state.instances.get(event.instance_id) + if instance is None: + # Instance was already deleted (e.g. cascade from DeleteMetaInstance). + # The InstanceDeleted handler already incremented consecutive_failures + # on the MetaInstance, so skipping here avoids double-counting. + return state + + # Remove all runners belonging to this instance from state + runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values()) + new_runners: Mapping[RunnerId, RunnerStatus] = { + rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove + } + + update: dict[str, object] = {"runners": new_runners} + + # Increment failure count on the MetaInstance + if event.meta_instance_id in state.meta_instances: + update["meta_instances"] = _update_meta_instance( + state, + event.meta_instance_id, + consecutive_failures=state.meta_instances[ + event.meta_instance_id + ].consecutive_failures + + 1, + last_failure_error=event.failure_error, + ) + + return state.model_copy(update=update) + + +def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State: + new_meta: Mapping[MetaInstanceId, MetaInstance] = { + **state.meta_instances, + event.meta_instance.meta_instance_id: event.meta_instance, + } + return state.model_copy(update={"meta_instances": new_meta}) + + +def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State: + new_meta: Mapping[MetaInstanceId, MetaInstance] = { + mid: mi + for mid, mi in state.meta_instances.items() + if mid != event.meta_instance_id + } + return state.model_copy(update={"meta_instances": new_meta}) + + +def apply_meta_instance_placement_failed( + event: MetaInstancePlacementFailed, state: State +) -> State: + if event.meta_instance_id not in state.meta_instances: + return state + return state.model_copy( + update={ + "meta_instances": _update_meta_instance( + state, event.meta_instance_id, placement_error=event.reason + ) + } + ) def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State: diff --git a/src/exo/shared/types/api.py b/src/exo/shared/types/api.py index 2756f0d4b..2f0fafa8f 100644 --- a/src/exo/shared/types/api.py +++ b/src/exo/shared/types/api.py @@ -6,7 +6,7 @@ from uuid import uuid4 from pydantic import BaseModel, Field from exo.shared.models.model_cards import ModelCard, ModelId -from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.common import CommandId, MetaInstanceId, NodeId from exo.shared.types.memory import Memory from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta from exo.shared.types.worker.shards import Sharding, ShardMetadata @@ -262,6 +262,26 @@ class DeleteInstanceResponse(BaseModel): instance_id: InstanceId +class CreateMetaInstanceParams(BaseModel): + model_id: ModelId + sharding: Sharding = Sharding.Pipeline + instance_meta: InstanceMeta = InstanceMeta.MlxRing + min_nodes: int = 1 + node_ids: list[NodeId] | None = None + + +class CreateMetaInstanceResponse(BaseModel): + message: str + command_id: CommandId + meta_instance_id: MetaInstanceId + + +class DeleteMetaInstanceResponse(BaseModel): + message: str + command_id: CommandId + meta_instance_id: MetaInstanceId + + class AdvancedImageParams(BaseModel): seed: Annotated[int, Field(ge=0)] | None = None num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index 09c135aa1..8697a6c2d 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -6,7 +6,8 @@ from exo.shared.types.api import ( ImageGenerationTaskParams, ) from exo.shared.types.chunks import InputImageChunk -from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.common import CommandId, MetaInstanceId, NodeId +from exo.shared.types.meta_instance import MetaInstance from exo.shared.types.text_generation import TextGenerationTaskParams from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta from exo.shared.types.worker.shards import Sharding, ShardMetadata @@ -52,6 +53,14 @@ class TaskCancelled(BaseCommand): cancelled_command_id: CommandId +class CreateMetaInstance(BaseCommand): + meta_instance: MetaInstance + + +class DeleteMetaInstance(BaseCommand): + meta_instance_id: MetaInstanceId + + class TaskFinished(BaseCommand): finished_command_id: CommandId @@ -94,6 +103,8 @@ Command = ( | CreateInstance | DeleteInstance | TaskCancelled + | CreateMetaInstance + | DeleteMetaInstance | TaskFinished | SendInputChunk ) diff --git a/src/exo/shared/types/common.py b/src/exo/shared/types/common.py index 5db51cef1..51806de27 100644 --- a/src/exo/shared/types/common.py +++ b/src/exo/shared/types/common.py @@ -42,6 +42,10 @@ class CommandId(Id): pass +class MetaInstanceId(Id): + """Identifier for a MetaInstance.""" + + class Host(CamelCaseModel): ip: str port: int diff --git a/src/exo/shared/types/events.py b/src/exo/shared/types/events.py index 5cf93d0c6..dd28d0709 100644 --- a/src/exo/shared/types/events.py +++ b/src/exo/shared/types/events.py @@ -1,11 +1,14 @@ +import base64 +from collections.abc import Mapping from datetime import datetime -from typing import final +from typing import Annotated, final -from pydantic import Field +from pydantic import BeforeValidator, Field, PlainSerializer from exo.shared.topology import Connection from exo.shared.types.chunks import GenerationChunk, InputImageChunk -from exo.shared.types.common import CommandId, Id, NodeId, SessionId +from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId +from exo.shared.types.meta_instance import MetaInstance from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.worker.downloads import DownloadProgress from exo.shared.types.worker.instances import Instance, InstanceId @@ -14,6 +17,28 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel +def _decode_base64_bytes(v: bytes | str) -> bytes: + if isinstance(v, bytes): + return v + return base64.b64decode(v) + + +def _encode_base64_bytes(v: bytes) -> str: + return base64.b64encode(v).decode("ascii") + + +Base64Bytes = Annotated[ + bytes, + BeforeValidator(_decode_base64_bytes), + PlainSerializer(_encode_base64_bytes, return_type=str), +] +"""bytes that serialize to/from base64 strings in JSON. + +Needed because TaggedModel's wrap validator converts JSON→Python validation +context, which breaks strict-mode bytes deserialization from JSON strings. +""" + + class EventId(Id): """ Newtype around `ID` @@ -66,6 +91,30 @@ class InstanceCreated(BaseEvent): class InstanceDeleted(BaseEvent): instance_id: InstanceId + failure_error: str | None = None + + +class MetaInstanceCreated(BaseEvent): + meta_instance: MetaInstance + + +class MetaInstanceDeleted(BaseEvent): + meta_instance_id: MetaInstanceId + + +@final +class MetaInstancePlacementFailed(BaseEvent): + meta_instance_id: MetaInstanceId + reason: str + + +@final +class InstanceRetrying(BaseEvent): + """Runners failed but retry count is below the limit — restart runners, keep instance.""" + + instance_id: InstanceId + meta_instance_id: MetaInstanceId + failure_error: str class RunnerStatusUpdated(BaseEvent): @@ -132,6 +181,25 @@ class TracesMerged(BaseEvent): traces: list[TraceEventData] +@final +class JacclSideChannelData(BaseEvent): + """A runner's local contribution to a JACCL SideChannel all_gather round.""" + + instance_id: InstanceId + runner_id: RunnerId + sequence: int + data: Base64Bytes + + +@final +class JacclSideChannelGathered(BaseEvent): + """Gathered result of a JACCL SideChannel all_gather round.""" + + instance_id: InstanceId + sequence: int + gathered_data: Mapping[RunnerId, Base64Bytes] + + Event = ( TestEvent | TaskCreated @@ -141,6 +209,10 @@ Event = ( | TaskAcknowledged | InstanceCreated | InstanceDeleted + | InstanceRetrying + | MetaInstanceCreated + | MetaInstanceDeleted + | MetaInstancePlacementFailed | RunnerStatusUpdated | RunnerDeleted | NodeTimedOut @@ -152,6 +224,8 @@ Event = ( | TopologyEdgeDeleted | TracesCollected | TracesMerged + | JacclSideChannelData + | JacclSideChannelGathered ) diff --git a/src/exo/shared/types/meta_instance.py b/src/exo/shared/types/meta_instance.py new file mode 100644 index 000000000..63052184b --- /dev/null +++ b/src/exo/shared/types/meta_instance.py @@ -0,0 +1,25 @@ +from typing import final + +from pydantic import Field + +from exo.shared.models.model_cards import ModelId +from exo.shared.types.common import MetaInstanceId, NodeId +from exo.shared.types.worker.instances import InstanceMeta +from exo.shared.types.worker.shards import Sharding +from exo.utils.pydantic_ext import FrozenModel + + +@final +class MetaInstance(FrozenModel): + """Declarative constraint: ensure an instance matching these parameters always exists.""" + + meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId) + model_id: ModelId + sharding: Sharding = Sharding.Pipeline + instance_meta: InstanceMeta = InstanceMeta.MlxRing + min_nodes: int = 1 + node_ids: list[NodeId] | None = None + # Failure tracking + placement_error: str | None = None + consecutive_failures: int = 0 + last_failure_error: str | None = None diff --git a/src/exo/shared/types/state.py b/src/exo/shared/types/state.py index 7350cfb0f..4ff1d4f73 100644 --- a/src/exo/shared/types/state.py +++ b/src/exo/shared/types/state.py @@ -6,7 +6,8 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator from pydantic.alias_generators import to_camel from exo.shared.topology import Topology, TopologySnapshot -from exo.shared.types.common import NodeId +from exo.shared.types.common import MetaInstanceId, NodeId +from exo.shared.types.meta_instance import MetaInstance from exo.shared.types.profiling import ( DiskUsage, MemoryUsage, @@ -41,6 +42,7 @@ class State(CamelCaseModel): arbitrary_types_allowed=True, ) instances: Mapping[InstanceId, Instance] = {} + meta_instances: Mapping[MetaInstanceId, MetaInstance] = {} runners: Mapping[RunnerId, RunnerStatus] = {} downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {} tasks: Mapping[TaskId, Task] = {} diff --git a/src/exo/shared/types/tasks.py b/src/exo/shared/types/tasks.py index 8d8664560..cb88d4019 100644 --- a/src/exo/shared/types/tasks.py +++ b/src/exo/shared/types/tasks.py @@ -61,7 +61,7 @@ class TextGeneration(BaseTask): # emitted by Master error_message: str | None = Field(default=None) -class CancelTask(BaseTask): +class CancelTask(BaseTask): # emitted by Worker when master cancels a task cancelled_task_id: TaskId runner_id: RunnerId diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index cda11ffaa..4254b9984 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -2,7 +2,7 @@ from enum import Enum from pydantic import model_validator -from exo.shared.types.common import Host, Id, NodeId +from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel @@ -19,6 +19,7 @@ class InstanceMeta(str, Enum): class BaseInstance(TaggedModel): instance_id: InstanceId shard_assignments: ShardAssignments + meta_instance_id: MetaInstanceId | None = None def shard(self, runner_id: RunnerId) -> ShardMetadata | None: return self.shard_assignments.runner_to_shard.get(runner_id, None) diff --git a/src/exo/utils/channels.py b/src/exo/utils/channels.py index 646ac8f6a..ebf0165fe 100644 --- a/src/exo/utils/channels.py +++ b/src/exo/utils/channels.py @@ -125,9 +125,7 @@ class MpSender[T]: self._state.buffer.put(item, block=True) async def send_async(self, item: T) -> None: - await to_thread.run_sync( - self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True - ) + await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1)) def close(self) -> None: if not self._state.closed.is_set(): diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 3ed65eccc..670847eb0 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -574,6 +574,11 @@ def mlx_cleanup( def mx_any(bool_: bool, group: Group | None) -> bool: + """Synchronize a boolean across all distributed nodes. + + Returns True if any node has bool_=True. Uses all_sum so every + node participates in the collective — preventing GPU deadlocks. + """ if group is None: return bool_ num_true = mx.distributed.all_sum( diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index af105652e..6b2a94757 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -24,6 +24,7 @@ from exo.shared.types.events import ( ForwarderEvent, IndexedEvent, InputChunkReceived, + JacclSideChannelGathered, NodeGatheredInfo, TaskCreated, TaskStatusUpdated, @@ -33,7 +34,6 @@ from exo.shared.types.events import ( from exo.shared.types.multiaddr import Multiaddr from exo.shared.types.state import State from exo.shared.types.tasks import ( - CancelTask, CreateRunner, DownloadModel, ImageEdits, @@ -159,6 +159,15 @@ class Worker: for idx, event in indexed_events: self.state = apply(self.state, IndexedEvent(idx=idx, event=event)) + # Dispatch JACCL gathered events to the relevant RunnerSupervisor + if isinstance(event, JacclSideChannelGathered): + for runner in self.runners.values(): + if ( + runner.bound_instance.instance.instance_id + == event.instance_id + ): + runner.notify_gathered(event) + # Buffer input image chunks for image editing if isinstance(event, InputChunkReceived): cmd_id = event.command_id @@ -225,22 +234,15 @@ class Worker: ) ) case Shutdown(runner_id=runner_id): - runner = self.runners.pop(runner_id) try: with fail_after(3): - await runner.start_task(task) + await self.runners.pop(runner_id).start_task(task) except TimeoutError: await self.event_sender.send( TaskStatusUpdated( task_id=task.task_id, task_status=TaskStatus.TimedOut ) ) - finally: - runner.shutdown() - case CancelTask( - cancelled_task_id=cancelled_task_id, runner_id=runner_id - ): - await self.runners[runner_id].cancel_task(cancelled_task_id) case ImageEdits() if task.task_params.total_input_chunks > 0: # Assemble image from chunks and inject into task cmd_id = task.command_id @@ -278,18 +280,18 @@ class Worker: del self.input_chunk_buffer[cmd_id] if cmd_id in self.input_chunk_counts: del self.input_chunk_counts[cmd_id] - await self._start_runner_task(modified_task) + await self.runners[self._task_to_runner_id(task)].start_task( + modified_task + ) case task: - await self._start_runner_task(task) + await self.runners[self._task_to_runner_id(task)].start_task(task) def shutdown(self): self._tg.cancel_scope.cancel() - async def _start_runner_task(self, task: Task): - if (instance := self.state.instances.get(task.instance_id)) is not None: - await self.runners[ - instance.shard_assignments.node_to_runner[self.node_id] - ].start_task(task) + def _task_to_runner_id(self, task: Task): + instance = self.state.instances[task.instance_id] + return instance.shard_assignments.node_to_runner[self.node_id] async def _nack_request(self, since_idx: int) -> None: # We request all events after (and including) the missing index. diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py index ce2eb4a96..4107d5535 100644 --- a/src/exo/worker/plan.py +++ b/src/exo/worker/plan.py @@ -35,6 +35,7 @@ from exo.shared.types.worker.runners import ( RunnerLoading, RunnerReady, RunnerRunning, + RunnerShutdown, RunnerStatus, RunnerWarmingUp, ) @@ -56,7 +57,7 @@ def plan( return ( _cancel_tasks(runners, tasks) or _kill_runner(runners, all_runners, instances) - or _create_runner(node_id, runners, instances) + or _create_runner(node_id, runners, instances, all_runners) or _model_needs_download(node_id, runners, global_download_status) or _init_distributed_backend(runners, all_runners) or _load_model(runners, all_runners, global_download_status) @@ -75,6 +76,12 @@ def _kill_runner( if (instance_id := runner.bound_instance.instance.instance_id) not in instances: return Shutdown(instance_id=instance_id, runner_id=runner_id) + # Master removed our runner from state (retry signal) and process is dead + if runner_id not in all_runners and isinstance( + runner.status, (RunnerFailed, RunnerShutdown) + ): + return Shutdown(instance_id=instance_id, runner_id=runner_id) + for ( global_runner_id ) in runner.bound_instance.instance.shard_assignments.node_to_runner.values(): @@ -92,6 +99,7 @@ def _create_runner( node_id: NodeId, runners: Mapping[RunnerId, RunnerSupervisor], instances: Mapping[InstanceId, Instance], + all_runners: Mapping[RunnerId, RunnerStatus], ) -> CreateRunner | None: for instance in instances.values(): runner_id = instance.shard_assignments.node_to_runner.get(node_id, None) @@ -101,6 +109,16 @@ def _create_runner( if runner_id in runners: continue + # Don't create while any peer runner is in a terminal state — wait for + # the master to emit InstanceRetrying which removes them from state. + has_terminal_peer = any( + isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown)) + for peer_rid in instance.shard_assignments.node_to_runner.values() + if peer_rid != runner_id + ) + if has_terminal_peer: + continue + shard = instance.shard(runner_id) assert shard is not None @@ -310,7 +328,8 @@ def _pending_tasks( def _cancel_tasks( runners: Mapping[RunnerId, RunnerSupervisor], tasks: Mapping[TaskId, Task], -) -> Task | None: +) -> CancelTask | None: + """Find a cancelled task that hasn't been sent to the runner yet.""" for task in tasks.values(): if task.task_status != TaskStatus.Cancelled: continue diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py index ed420aab3..69ef6c720 100644 --- a/src/exo/worker/runner/bootstrap.py +++ b/src/exo/worker/runner/bootstrap.py @@ -17,6 +17,7 @@ def entrypoint( task_receiver: MpReceiver[Task], cancel_receiver: MpReceiver[TaskId], _logger: "loguru.Logger", + pipe_fifo_paths: tuple[str, str] | None = None, ) -> None: fast_synch_override = os.environ.get("EXO_FAST_SYNCH") if fast_synch_override == "on" or ( @@ -30,6 +31,16 @@ def entrypoint( else: os.environ["MLX_METAL_FAST_SYNCH"] = "0" + # Open JACCL FIFOs by path and set env vars for C++ SideChannel. + # Named pipes (FIFOs) work across multiprocessing spawn (macOS default). + if pipe_fifo_paths is not None: + fifo_c2p, fifo_p2c = pipe_fifo_paths + # C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT) + pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY) + pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY) + os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd) + os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd) + global logger logger = _logger @@ -56,7 +67,9 @@ def entrypoint( try: event_sender.close() task_receiver.close() + cancel_receiver.close() finally: event_sender.join() task_receiver.join() + cancel_receiver.join() logger.info("bye from the runner") diff --git a/src/exo/worker/runner/runner.py b/src/exo/worker/runner/runner.py index e55456d32..818bd9be2 100644 --- a/src/exo/worker/runner/runner.py +++ b/src/exo/worker/runner/runner.py @@ -243,7 +243,7 @@ def main( assert inference_model assert tokenizer - t = time.monotonic() + t = time.perf_counter() toks = warmup_inference( model=inference_model, tokenizer=tokenizer, @@ -251,7 +251,7 @@ def main( ) logger.info(f"warmed up by generating {toks} tokens") check_for_cancel_every = min( - math.ceil(toks / min(time.monotonic() - t, 0.001)), 100 + math.ceil(toks / max(time.perf_counter() - t, 0.001)), 100 ) if group is not None: check_for_cancel_every = int( diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py index 5d39a881d..519d7b072 100644 --- a/src/exo/worker/runner/runner_supervisor.py +++ b/src/exo/worker/runner/runner_supervisor.py @@ -1,6 +1,10 @@ import contextlib +import os import signal +import struct +import tempfile from dataclasses import dataclass, field +from functools import partial from multiprocessing import Process from typing import Self @@ -14,12 +18,14 @@ from loguru import logger from exo.shared.types.events import ( Event, + JacclSideChannelData, + JacclSideChannelGathered, RunnerStatusUpdated, TaskAcknowledged, TaskStatusUpdated, ) from exo.shared.types.tasks import Task, TaskId, TaskStatus -from exo.shared.types.worker.instances import BoundInstance +from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance from exo.shared.types.worker.runners import ( RunnerConnecting, RunnerFailed, @@ -34,6 +40,26 @@ from exo.shared.types.worker.shards import ShardMetadata from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel from exo.worker.runner.bootstrap import entrypoint + +def _pipe_read_exact(fd: int, n: int) -> bytes | None: + """Read exactly n bytes from a file descriptor. Returns None on EOF.""" + data = b"" + while len(data) < n: + chunk = os.read(fd, n - len(data)) + if not chunk: + return None + data += chunk + return data + + +def _pipe_write_all(fd: int, data: bytes) -> None: + """Write all bytes to a file descriptor.""" + view = memoryview(data) + while view: + written = os.write(fd, view) + view = view[written:] + + PREFILL_TIMEOUT_SECONDS = 60 DECODE_TIMEOUT_SECONDS = 5 @@ -46,12 +72,21 @@ class RunnerSupervisor: initialize_timeout: float _ev_recv: MpReceiver[Event] _task_sender: MpSender[Task] - _event_sender: Sender[Event] _cancel_sender: MpSender[TaskId] + _event_sender: Sender[Event] + _pipe_read_fd: int | None = None # Python reads runner's pipe output + _pipe_write_fd: int | None = None # Python writes gathered data to runner + _child_pipe_fds: tuple[int, int] | None = None # fds to close after fork + _fifo_dir: str | None = None # Temp dir for FIFO files (for cleanup) + _fifo_c2p: str | None = None # FIFO path: C++ writes → Python reads + _fifo_p2c: str | None = None # FIFO path: Python writes → C++ reads status: RunnerStatus = field(default_factory=RunnerIdle, init=False) pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False) completed: set[TaskId] = field(default_factory=set, init=False) cancelled: set[TaskId] = field(default_factory=set, init=False) + _gathered_waiters: dict[ + int, tuple[anyio.Event, JacclSideChannelGathered | None] + ] = field(default_factory=dict, init=False) @classmethod def create( @@ -65,6 +100,23 @@ class RunnerSupervisor: task_sender, task_recv = mp_channel[Task]() cancel_sender, cancel_recv = mp_channel[TaskId]() + # For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay. + # Named pipes work across multiprocessing.Process spawn (macOS default). + # FIFO c2p: C++ writes local data → Python reads it + # FIFO p2c: Python writes gathered data → C++ reads it + fifo_dir: str | None = None + fifo_c2p: str | None = None + fifo_p2c: str | None = None + pipe_fifo_paths: tuple[str, str] | None = None + + if isinstance(bound_instance.instance, MlxJacclInstance): + fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_") + fifo_c2p = os.path.join(fifo_dir, "c2p") # C++ → Python + fifo_p2c = os.path.join(fifo_dir, "p2c") # Python → C++ + os.mkfifo(fifo_c2p) + os.mkfifo(fifo_p2c) + pipe_fifo_paths = (fifo_c2p, fifo_p2c) + runner_process = Process( target=entrypoint, args=( @@ -73,6 +125,7 @@ class RunnerSupervisor: task_recv, cancel_recv, logger, + pipe_fifo_paths, ), daemon=True, ) @@ -88,21 +141,54 @@ class RunnerSupervisor: _task_sender=task_sender, _cancel_sender=cancel_sender, _event_sender=event_sender, + _fifo_dir=fifo_dir, + _fifo_c2p=fifo_c2p, + _fifo_p2c=fifo_p2c, ) return self async def run(self): self.runner_process.start() - await self._forward_events() + + if self._fifo_c2p is not None and self._fifo_p2c is not None: + # Open FIFOs from parent side. These block until child opens the other end, + # so we run them in threads concurrently to avoid deadlock. + fifo_c2p = self._fifo_c2p + fifo_p2c = self._fifo_p2c + + async def open_read() -> None: + self._pipe_read_fd = await to_thread.run_sync( + partial(os.open, fifo_c2p, os.O_RDONLY) + ) + + async def open_write() -> None: + self._pipe_write_fd = await to_thread.run_sync( + partial(os.open, fifo_p2c, os.O_WRONLY) + ) + + async with anyio.create_task_group() as open_tg: + open_tg.start_soon(open_read) + open_tg.start_soon(open_write) + + logger.info( + f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})" + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(self._pipe_relay) + tg.start_soon(self._forward_events) + else: + await self._forward_events() def shutdown(self): logger.info("Runner supervisor shutting down") self._ev_recv.close() self._task_sender.close() - self._event_sender.close() self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK")) self._cancel_sender.close() + self._event_sender.close() + self._close_pipe_fds() self.runner_process.join(1) if not self.runner_process.is_alive(): logger.info("Runner process succesfully terminated") @@ -140,6 +226,7 @@ class RunnerSupervisor: await event.wait() async def cancel_task(self, task_id: TaskId): + """Send a cancellation signal to the runner process.""" if task_id in self.completed: logger.info(f"Unable to cancel {task_id} as it has been completed") return @@ -181,6 +268,110 @@ class RunnerSupervisor: for tid in self.pending: self.pending[tid].set() + def _close_pipe_fds(self) -> None: + if self._pipe_read_fd is not None: + with contextlib.suppress(OSError): + os.close(self._pipe_read_fd) + self._pipe_read_fd = None + if self._pipe_write_fd is not None: + with contextlib.suppress(OSError): + os.close(self._pipe_write_fd) + self._pipe_write_fd = None + if self._child_pipe_fds is not None: + for fd in self._child_pipe_fds: + with contextlib.suppress(OSError): + os.close(fd) + self._child_pipe_fds = None + # Clean up FIFO files + if self._fifo_c2p is not None: + with contextlib.suppress(OSError): + os.unlink(self._fifo_c2p) + self._fifo_c2p = None + if self._fifo_p2c is not None: + with contextlib.suppress(OSError): + os.unlink(self._fifo_p2c) + self._fifo_p2c = None + if self._fifo_dir is not None: + with contextlib.suppress(OSError): + os.rmdir(self._fifo_dir) + self._fifo_dir = None + + async def _pipe_relay(self) -> None: + """Relay JACCL SideChannel all_gather rounds between runner pipes and exo events.""" + assert self._pipe_read_fd is not None + assert self._pipe_write_fd is not None + read_fd = self._pipe_read_fd + write_fd = self._pipe_write_fd + sequence = 0 + + try: + while True: + # 1. Read local data from runner: [uint32 size][size bytes] + header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4)) + if header is None: + logger.info("JACCL pipe relay: runner closed pipe (EOF)") + break + data_size: int = struct.unpack(" None: + """Called by the worker when a JacclSideChannelGathered event arrives.""" + seq = event.sequence + if seq not in self._gathered_waiters: + logger.warning(f"JACCL: received gathered event for unknown sequence {seq}") + return + waiter, _ = self._gathered_waiters[seq] + self._gathered_waiters[seq] = (waiter, event) + waiter.set() + def __del__(self) -> None: if self.runner_process.is_alive(): logger.warning("RunnerSupervisor was not stopped cleanly.") diff --git a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py index 38a0a921b..878aea990 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py +++ b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py @@ -1,9 +1,7 @@ # Check tasks are complete before runner is ever ready. -import unittest.mock from collections.abc import Iterable from typing import Callable -import mlx.core as mx import pytest import exo.worker.runner.runner as mlx_runner @@ -117,6 +115,12 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1)) monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin) monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False)) + + # Mock mx.distributed.all_gather so MockGroup doesn't hit real MLX C++ bindings. + def _mock_all_gather(x: object, **_kw: object) -> object: + return x + + monkeypatch.setattr(mlx_runner.mx.distributed, "all_gather", _mock_all_gather) # Mock apply_chat_template since we're using a fake tokenizer (integer 1). # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None. monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt")) @@ -178,16 +182,15 @@ def _run(tasks: Iterable[Task]): # this is some c++ nonsense task_receiver.close = nothin task_receiver.join = nothin - with unittest.mock.patch( - "exo.worker.runner.runner.mx.distributed.all_gather", - make_nothin(mx.array([1])), - ): - mlx_runner.main( - bound_instance, - event_sender, # pyright: ignore[reportArgumentType] - task_receiver, - cancel_receiver, - ) + cancel_receiver.close = nothin + cancel_receiver.join = nothin + + mlx_runner.main( + bound_instance, + event_sender, # pyright: ignore[reportArgumentType] + task_receiver, + cancel_receiver, + ) return event_sender.events