Compare commits

28 Commits

Author SHA1 Message Date
Alex Cheema
f9ffdaef5f Preserve last_failure_error across instance recreation, fix RDMA banner wording
- apply_instance_created no longer clears last_failure_error so the
  error context persists while the new instance starts up
- Dashboard retryError shows the error without (N/3) prefix when
  consecutiveFailures is 0 (instance was recreated)
- Jaccl warning tooltip now says "experimental RDMA driver in macOS"

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:48:30 -08:00
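The retryError behaviour described in this commit can be sketched roughly as follows. This is a minimal illustration, not the dashboard's actual code: `formatRetryError` and its `maxRetries` parameter are hypothetical names, and the prefix format is assumed from the `(N/3)` wording in the message.

```typescript
// Hypothetical helper; the real retryError code is not shown in this listing.
function formatRetryError(
  consecutiveFailures: number,
  lastFailureError: string | null,
  maxRetries = 3, // assumed limit, taken from the "(N/3)" wording
): string | null {
  if (!lastFailureError) return null;
  // After recreation the counter is back at 0, so show the bare error.
  if (consecutiveFailures === 0) return lastFailureError;
  return `(${consecutiveFailures}/${maxRetries}) ${lastFailureError}`;
}
```

Because `last_failure_error` now survives instance recreation, the first branch is what keeps the message visible while the new instance starts up.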
Alex Cheema
8c2416c9ea chore: remove temporary screenshot files
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:40:16 -08:00
Alex Cheema
e5007f619a temp: add jaccl warning screenshots for PR comment
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:38:53 -08:00
Alex Cheema
a627f67253 dashboard: show warning banner for [jaccl] RDMA driver errors
Detect errors containing "[jaccl]" in MetaInstance failure errors and
display a red dismissible alert banner. The tooltip explains this is a
macOS RDMA driver issue and that the affected machine needs to be
restarted. Alert re-appears if a new error arrives after dismissal.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 16:38:42 -08:00
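A small sketch of the detection and dismissal rule, modelled on the `jacclError` and `jacclDismissedError` state that appears in the dashboard diff further down. The standalone function names and the `MetaInstanceLike` type are hypothetical; the real code uses Svelte `$derived.by` and `$state`.

```typescript
// Only the field needed for detection; a stand-in for MetaInstanceData.
interface MetaInstanceLike {
  lastFailureError: string | null;
}

// Return the first failure error that mentions the [jaccl] driver, if any.
function findJacclError(
  metaInstances: Record<string, MetaInstanceLike>,
): string | null {
  for (const mi of Object.values(metaInstances)) {
    if (mi.lastFailureError?.includes("[jaccl]")) return mi.lastFailureError;
  }
  return null;
}

// The banner is hidden only while the current error equals the dismissed one,
// so a different error arriving later makes it reappear.
function shouldShowJacclBanner(
  current: string | null,
  dismissed: string | null,
): boolean {
  return current !== null && current !== dismissed;
}
```

Storing the dismissed error text rather than a boolean is what lets the alert come back when a new error arrives.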
Alex Cheema
f189222bfc Merge remote-tracking branch 'origin/main' into alexcheema/meta-instance
# Conflicts:
#	dashboard/src/lib/stores/app.svelte.ts
#	dashboard/src/routes/+page.svelte
2026-02-11 15:59:50 -08:00
Alex Cheema
ad6d35d68a Retry runners within the same Instance instead of recreating
When runners fail for a MetaInstance-backed Instance, retry up to 3
times by restarting runners within the same Instance rather than
deleting and recreating it each time. After 3 failures, delete the
Instance so MetaInstanceReconciler can create a fresh one.

- Add InstanceRetrying event that removes runners from state (signaling
  workers to restart) and increments consecutive_failures on MetaInstance
- InstanceHealthReconciler emits InstanceRetrying when under retry limit,
  InstanceDeleted when exhausted or no MetaInstance
- Worker _kill_runner detects retry signal (runner deleted from state +
  terminal supervisor) and cleans up for _create_runner to recreate
- Worker _create_runner guards against oscillation by blocking creation
  while any peer runner has explicit terminal status
- InstanceCreated resets consecutive_failures for fresh starts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 14:21:11 -08:00
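The retry-or-delete decision in this commit might look something like the sketch below. The backend is Python and is not shown in this listing, so this TypeScript version is illustrative only: `onRunnersFailed` and `HealthEvent` are hypothetical names, and the strict `<` comparison against the limit is an assumption.

```typescript
type HealthEvent =
  | { kind: "InstanceRetrying"; instanceId: string }
  | { kind: "InstanceDeleted"; instanceId: string };

const MAX_RETRIES = 3; // limit stated in the commit message

// Decide what to emit once an instance's runners have failed.
function onRunnersFailed(
  instanceId: string,
  metaInstanceId: string | null,
  consecutiveFailures: number,
): HealthEvent {
  // Retry in place only for MetaInstance-backed instances under the limit
  // (assumed comparison); otherwise delete so a fresh one can be placed.
  const canRetry =
    metaInstanceId !== null && consecutiveFailures < MAX_RETRIES;
  return canRetry
    ? { kind: "InstanceRetrying", instanceId }
    : { kind: "InstanceDeleted", instanceId };
}
```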
Alex Cheema
c236d62caf Remove timestamp-based retry cooldown
Remove last_failure_at field and RETRY_COOLDOWN_SECONDS logic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:59:39 -08:00
Alex Cheema
a8069e8a30 Consolidate failure state onto MetaInstance, add 5s retry cooldown
Move placement_error, consecutive_failures, last_failure_error, and
last_failure_at directly onto the MetaInstance model instead of keeping
them as separate State mappings (meta_instance_errors, InstanceFailureInfo,
meta_instance_failure_info). Adds a 5-second cooldown between retry attempts
to prevent rapid instance churn when runners fail instantly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:55:47 -08:00
Alex Cheema
84ce555d55 Show retry attempt count with error message, e.g. (2/3)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:43:20 -08:00
Alex Cheema
b78ea438bc Include node friendly names in runner error messages
Each error in the combined message is now prefixed with the node's friendly
name (e.g. "MacBook Pro: OOM; Mac Studio: connection reset") so the root
cause node is easily identifiable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:41:10 -08:00
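As a rough illustration of the combined message format, the sketch below prefixes each error with the node's friendly name and joins them with "; ". The `RunnerError` type and `combineRunnerErrors` name are hypothetical, and falling back to the raw node ID when no friendly name is known is an assumption.

```typescript
interface RunnerError {
  nodeId: string;
  message: string;
}

// Build one message such as "MacBook Pro: OOM; Mac Studio: connection reset".
function combineRunnerErrors(
  errors: RunnerError[],
  friendlyNames: Record<string, string>,
): string | null {
  if (errors.length === 0) return null;
  return errors
    .map((e) => `${friendlyNames[e.nodeId] ?? e.nodeId}: ${e.message}`)
    .join("; ");
}
```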
Alex Cheema
1960b16f9f Remove permanent retry blocking, allow continuous retry batches
The dashboard % 3 logic already handles displaying retry progress in batches
(RETRYING 1/3, 2/3, 3/3, then PLACING with error, repeat). No need to
permanently block placement after 3 failures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:35:03 -08:00
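The batch arithmetic referred to here can be sketched as below, following the `retryPosition` and `isRecreated` expressions in the dashboard diff further down. `retryStatusText` is a hypothetical standalone name. Note that with these expressions every third failure displays as "PLACING" rather than "RETRYING (3/3)".

```typescript
// Map a running failure count onto a repeating batch of three.
function retryStatusText(failures: number): string {
  if (failures <= 0) return "PLACING";
  const retryPosition = ((failures - 1) % 3) + 1; // 1, 2, 3, 1, 2, 3, ...
  const isRecreated = failures % 3 === 0; // every third failure ends a batch
  return isRecreated ? "PLACING" : `RETRYING (${retryPosition}/3)`;
}
```

Since the count is only ever reduced modulo 3 for display, the backend does not need to block placement permanently to keep the progress indicator bounded.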
Alex Cheema
c6838c8fd8 Show retry count in exceeded retry limit message (3/3)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:28:17 -08:00
Alex Cheema
420d9b9e76 Collect all runner error messages instead of just the last one
When multiple runners fail, concatenate all error messages with "; " so the
real error isn't hidden by generic side-effect failures from other runners.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:27:49 -08:00
Alex Cheema
13f1e9c489 Stop infinite retries after 3 failures, show errors persistently in dashboard
MetaInstanceReconciler now checks failure count before placement — after 3
consecutive failures it emits MetaInstancePlacementFailed instead of retrying
forever. Dashboard shows "Retrying after error: <msg>" in orange throughout
the retry cycle, not just during the brief window with no backing instance.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:21:11 -08:00
Alex Cheema
451a06b3d8 Add instance retry logic with max 3 retries and failure tracking
- Extend InstanceDeleted with failure_error field for runner crash info
- Add InstanceFailureInfo model tracking consecutive failures per MetaInstance
- InstanceHealthReconciler now detects runner failures (all terminal with
  at least one RunnerFailed) in addition to connection failures
- apply_instance_deleted increments failure counter for meta-bound instances
- Dashboard shows RETRYING (N/3) status with error messages, and
  "Instance re-created due to failure" after 3 consecutive failures
- Extract and display RunnerFailed error messages in instance status

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 12:09:42 -08:00
Alex Cheema
94b55d66f4 Fix MetaInstance.node_ids frozenset failing JSON deserialization
frozenset serializes to a JSON array but cannot be deserialized back
in strict mode through the TaggedModel wrap validator (list → frozenset
coercion is rejected). Changed to list[NodeId] since the model is
already frozen/immutable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:54:56 -08:00
Alex Cheema
2b68b931c5 Send node_ids from placement preview when launching instances
The dashboard now extracts node IDs from the selected preview's
memory_delta_by_node, ensuring the backend places on exactly the
nodes the user was shown. Also reverts incorrect RDMA min_nodes >= 2
enforcement since single-node RDMA is valid.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:49:37 -08:00
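A sketch of the node ID extraction, based on the `previewNodeIds` expression in the dashboard diff further down. The `PreviewLike` type and the function name are hypothetical simplifications.

```typescript
interface PreviewLike {
  memory_delta_by_node: Record<string, number> | null;
}

// Prefer the nodes shown in the preview; fall back to the user's node filter.
function nodeIdsForLaunch(
  preview: PreviewLike | null | undefined,
  nodeFilter: Set<string>,
): string[] | undefined {
  if (preview?.memory_delta_by_node) {
    return Object.keys(preview.memory_delta_by_node);
  }
  return nodeFilter.size > 0 ? Array.from(nodeFilter) : undefined;
}
```

Returning `undefined` when neither source is available means `JSON.stringify` omits `node_ids` from the request body, leaving placement unconstrained.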
Alex Cheema
4aecaa7748 Enforce min_nodes >= 2 for RDMA (MlxJaccl) instances
RDMA requires at least 2 nodes — a single-node RDMA instance is
nonsensical. Enforce this in both the dashboard (when building the
launch request) and the backend placement (when filtering cycles).
Previously, selecting RDMA would still place on 1 node because
min_nodes defaulted to 1 and the placement silently switched to Ring.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:42:21 -08:00
Alex Cheema
25e2891c30 Ensure min_nodes >= node filter size when launching
When user selects specific nodes via the filter, min_nodes should be at
least the number of filtered nodes to prevent placement from picking a
smaller cycle.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:36:51 -08:00
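The rule in this commit amounts to taking the larger of the two values. A minimal sketch, with `effectiveMinNodes` as a hypothetical name (the actual dashboard change is not shown in this listing):

```typescript
// min_nodes must cover every node the user explicitly selected.
function effectiveMinNodes(
  selectedMinNodes: number,
  nodeFilter: Set<string>,
): number {
  return Math.max(selectedMinNodes, nodeFilter.size);
}
```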
Alex Cheema
16345e0ffa Send node_ids from dashboard, error on RDMA when unavailable
Dashboard was not including the user's node filter in the POST to
/meta_instance, so placement ignored which nodes the user selected.
Also, placement silently fell back to Ring when RDMA was requested but
no RDMA-connected cycles were available — now raises an error that
surfaces via MetaInstancePlacementFailed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:26:29 -08:00
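The change from silent fallback to an explicit error could be sketched as follows. The placement code itself is Python and not shown here, so the `Cycle` shape, the `selectCycles` name, and the error text are all hypothetical.

```typescript
type InstanceMeta = "MlxRing" | "MlxJaccl";

interface Cycle {
  nodeIds: string[];
  rdmaConnected: boolean; // assumed flag for "every hop has an RDMA link"
}

// Filter candidate cycles for the requested runtime, failing loudly for RDMA.
function selectCycles(requested: InstanceMeta, cycles: Cycle[]): Cycle[] {
  if (requested !== "MlxJaccl") return cycles;
  const rdmaCycles = cycles.filter((c) => c.rdmaConnected);
  if (rdmaCycles.length === 0) {
    // Previously this case fell back to Ring; now the caller sees the error.
    throw new Error("RDMA requested but no RDMA-connected cycles available");
  }
  return rdmaCycles;
}
```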
Alex Cheema
3a845f90b0 Fix use_default validator silently ignoring sharding/instance_meta
The mode="plain" validator bypassed Pydantic's string-to-enum coercion,
so JSON strings like "Tensor" and "MlxJaccl" from the dashboard failed
the isinstance check and silently fell back to Pipeline/MlxRing defaults.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:05:00 -08:00
Alex Cheema
dccf2440ba Add placement error feedback and per-node loading status
Show why MetaInstance placement fails instead of stuck "PLACING", and
show per-node runner status during loading for multi-node instances.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 10:01:07 -08:00
Alex Cheema
f96f3f2c0f Show MetaInstance sharding/type while PLACING, fix MlxIbv references
When a MetaInstance has no backing instance yet, derive the strategy
display from the MetaInstance's own sharding and instanceMeta fields
rather than showing "Unknown (Unknown)".

Also clean up all stale MlxIbv references across the dashboard —
the backend enum is MlxJaccl.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:23:44 -08:00
Alex Cheema
7d54e468d5 Extract reconciler into ProcessManager protocol, fix RDMA instance type
- Replace inline _plan() with ProcessManager loop (_reconcile), tick
  every 1s instead of 10s — safe because all PMs are idempotent
- Fix dashboard sending "MlxIbv" instead of "MlxJaccl" for RDMA
  instance type, which silently fell back to MlxRing default
- Remove all stale MlxIbv references from dashboard

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:19:13 -08:00
Alex Cheema
124d504f95 Extract reconciler into ProcessManager protocol
Replace inline _plan() steps with a list of ProcessManagers, each
implementing async reconcile(State) -> Sequence[Event]. Tick every
1s instead of 10s — safe because all PMs are idempotent against state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 09:06:41 -08:00
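To make the protocol shape concrete, here is a TypeScript rendering of the `async reconcile(State) -> Sequence[Event]` idea. It is only a sketch under stated assumptions: the real implementation is Python, and the `apply` reducer and `reconcileOnce` name are hypothetical.

```typescript
// Each manager inspects the state and returns the events it wants applied.
interface ProcessManager<S, E> {
  reconcile(state: S): Promise<readonly E[]>;
}

// One tick: run every manager in order, folding its events into the state.
async function reconcileOnce<S, E>(
  managers: readonly ProcessManager<S, E>[],
  state: S,
  apply: (state: S, event: E) => S, // hypothetical reducer
): Promise<S> {
  let current = state;
  for (const pm of managers) {
    for (const event of await pm.reconcile(current)) {
      current = apply(current, event);
    }
  }
  return current;
}
```

An idempotent manager returns no events once the state already satisfies it, which is why ticking every second instead of every ten is described as safe.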
Alex Cheema
9ab4a40989 Simplify MetaInstance binding: put meta_instance_id on Instance
The separate MetaInstanceBound event + meta_instance_backing map
introduced two bugs: stale exclusion sets in the reconciler loop and
a delete ordering race. Embedding meta_instance_id directly on
BaseInstance eliminates the binding mechanism entirely — when an
instance is created for a MetaInstance it carries the ID, when
deleted the binding is gone. No separate map, no cleanup, no races.

Also fixes delete_meta_instance to cascade-delete backing instances.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 16:15:29 -08:00
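With the ID carried on the instance, finding the binding becomes a plain scan, as in the `getBackingInstanceId` and `getOrphanInstanceIds` helpers in the dashboard diff further down. The sketch below simplifies them by assuming already-unwrapped instances; `InstanceLike` is a hypothetical type.

```typescript
interface InstanceLike {
  metaInstanceId?: string | null;
}

// The backing instance is whichever one carries the MetaInstance's ID.
function findBackingInstanceId(
  instances: Record<string, InstanceLike>,
  metaInstanceId: string,
): string | null {
  for (const [id, inst] of Object.entries(instances)) {
    if (inst.metaInstanceId === metaInstanceId) return id;
  }
  return null;
}

// Orphans are instances that carry no MetaInstance ID at all.
function orphanInstanceIds(
  instances: Record<string, InstanceLike>,
): string[] {
  return Object.keys(instances).filter(
    (id) => !instances[id]?.metaInstanceId,
  );
}
```

Deleting an instance removes the only record of the binding, so there is no separate map to clean up.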
Alex Cheema
f4329c72c2 Add explicit MetaInstance binding, slim MetaInstance to use ModelId
- Add MetaInstanceBound event and meta_instance_backing State field
  for explicit MetaInstance → Instance binding (prevents ambiguous
  linking when two MetaInstances have identical constraints)
- Replace model_card: ModelCard with model_id: ModelId on MetaInstance
  (load ModelCard on-demand at placement time)
- Add MetaInstance API endpoints (POST /meta_instance, DELETE)
- Update dashboard to use MetaInstances as primary primitive with
  unified display items merging MetaInstances and orphan instances
- Dashboard launches via MetaInstance instead of direct Instance creation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 15:53:07 -08:00
Alex Cheema
ceb76b8f6c Add MetaInstance declarative layer with connection health checking
Introduces MetaInstance as a declarative constraint ensuring an instance
matching given parameters (model, sharding, min_nodes) always exists.
The master's reconciliation loop continuously checks for unsatisfied
meta-instances and attempts placement. Connection health checking
verifies that specific IPs (MlxRing) and RDMA interfaces (MlxJaccl)
stored on instances still exist as topology edges, enabling automatic
recovery when cables are swapped or interfaces change.

Also eliminates the master's loopback event path, unifying all event
emission through _apply_and_broadcast for simpler control flow.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 13:43:50 -08:00
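The "unsatisfied meta-instance" check at the heart of this commit might be sketched like this. The master's reconciler is Python and is not part of this listing; the matching rule below (same model, same sharding, at least `minNodes` nodes) is an assumption drawn from the parameters named in the message, and all type and function names are hypothetical.

```typescript
interface MetaInstanceSpec {
  modelId: string;
  sharding: string;
  minNodes: number;
}

interface RunningInstance {
  modelId: string;
  sharding: string;
  nodeIds: string[];
}

// Assumed matching rule: model and sharding equal, enough nodes.
function satisfies(inst: RunningInstance, meta: MetaInstanceSpec): boolean {
  return (
    inst.modelId === meta.modelId &&
    inst.sharding === meta.sharding &&
    inst.nodeIds.length >= meta.minNodes
  );
}

// IDs of MetaInstances with no matching instance; these need placement.
function unsatisfiedMetaInstances(
  metas: Record<string, MetaInstanceSpec>,
  instances: RunningInstance[],
): string[] {
  return Object.entries(metas)
    .filter(([, meta]) => !instances.some((inst) => satisfies(inst, meta)))
    .map(([id]) => id);
}
```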
22 changed files with 2098 additions and 230 deletions

View File

@@ -185,11 +185,7 @@
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {

View File

@@ -21,7 +21,7 @@
} | null;
nodes?: Record<string, NodeInfo>;
sharding?: "Pipeline" | "Tensor";
runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
runtime?: "MlxRing" | "MlxJaccl";
onLaunch?: () => void;
tags?: string[];
apiPreview?: PlacementPreview | null;
@@ -348,7 +348,7 @@
// Debug mode state
const isDebugMode = $derived(debugMode());
const topology = $derived(topologyData());
const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");
const isRdma = $derived(runtime === "MlxJaccl");
// Get interface name for an IP from node data
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
@@ -575,7 +575,7 @@
>
{runtime === "MlxRing"
? "MLX Ring"
: runtime === "MlxIbv" || runtime === "MlxJaccl"
: runtime === "MlxJaccl"
? "MLX RDMA"
: runtime}
</span>

View File

@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
export interface PlacementPreview {
model_id: string;
sharding: "Pipeline" | "Tensor";
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
instance_meta: "MlxRing" | "MlxJaccl";
instance: unknown | null;
memory_delta_by_node: Record<string, number> | null;
error: string | null;
@@ -219,7 +219,6 @@ interface RawStateResponse {
string,
{
MlxRingInstance?: Instance;
MlxIbvInstance?: Instance;
MlxJacclInstance?: Instance;
}
>;
@@ -250,6 +249,20 @@ interface RawStateResponse {
>;
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
thunderboltBridgeCycles?: string[][];
// MetaInstances (declarative instance constraints)
metaInstances?: Record<string, MetaInstanceData>;
}
export interface MetaInstanceData {
metaInstanceId: string;
modelId: string;
sharding: string;
instanceMeta: string;
minNodes: number;
nodeIds: string[] | null;
placementError: string | null;
consecutiveFailures: number;
lastFailureError: string | null;
}
export interface MessageAttachment {
@@ -535,6 +548,7 @@ class AppStore {
isLoadingPreviews = $state(false);
previewNodeFilter = $state<Set<string>>(new Set());
lastUpdate = $state<number | null>(null);
metaInstances = $state<Record<string, MetaInstanceData>>({});
nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
thunderboltBridgeCycles = $state<string[][]>([]);
nodeThunderbolt = $state<
@@ -891,11 +905,7 @@ class AppStore {
let instanceType: string | null = null;
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
let sharding: string | null = null;
const inst = instance as {
@@ -1260,6 +1270,8 @@ class AppStore {
if (data.downloads) {
this.downloads = data.downloads;
}
// MetaInstances
this.metaInstances = data.metaInstances ?? {};
if (data.nodeDisk) {
this.nodeDisk = data.nodeDisk;
}
@@ -3019,6 +3031,7 @@ export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;

View File

@@ -42,6 +42,7 @@
toggleTopologyOnlyMode,
chatSidebarVisible,
toggleChatSidebarVisible,
metaInstances,
nodeThunderbolt,
nodeRdmaCtl,
thunderboltBridgeCycles,
@@ -49,6 +50,7 @@
nodeIdentities,
type DownloadProgress,
type PlacementPreview,
type MetaInstanceData,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
import { fade, fly } from "svelte/transition";
@@ -68,7 +70,72 @@
const debugEnabled = $derived(debugMode());
const topologyOnlyEnabled = $derived(topologyOnlyMode());
const sidebarVisible = $derived(chatSidebarVisible());
const metaInstancesData = $derived(metaInstances());
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
// Get status for a MetaInstance that has no backing instance yet
function getMetaInstancePlacingStatus(metaInstanceId: string) {
const meta = metaInstancesData[metaInstanceId];
const placementError = meta?.placementError;
const failures = meta?.consecutiveFailures ?? 0;
const lastError = meta?.lastFailureError;
if (placementError) {
return {
statusText: "PLACEMENT FAILED",
statusClass: "failed",
isDownloading: false as const,
isFailed: true,
progress: null,
perNode: [] as Array<{
nodeId: string;
nodeName: string;
progress: DownloadProgress;
}>,
perNodeStatus: [] as PerNodeRunnerStatus[],
errorMessage: placementError,
};
}
if (failures > 0) {
const retryPosition = ((failures - 1) % 3) + 1;
const isRecreated = failures % 3 === 0;
return {
statusText: isRecreated ? "PLACING" : `RETRYING (${retryPosition}/3)`,
statusClass: "starting",
isDownloading: false as const,
isFailed: false,
progress: null,
perNode: [] as Array<{
nodeId: string;
nodeName: string;
progress: DownloadProgress;
}>,
perNodeStatus: [] as PerNodeRunnerStatus[],
errorMessage: !lastError
? null
: isRecreated
? `Instance re-created due to failure: ${lastError}`
: `Previous failure: ${lastError}`,
};
}
return {
statusText: "PLACING",
statusClass: "starting",
isDownloading: false as const,
isFailed: false,
progress: null,
perNode: [] as Array<{
nodeId: string;
nodeName: string;
progress: DownloadProgress;
}>,
perNodeStatus: [] as PerNodeRunnerStatus[],
errorMessage: null,
};
}
const tbBridgeData = $derived(nodeThunderboltBridge());
const identitiesData = $derived(nodeIdentities());
const tbIdentifiers = $derived(nodeThunderbolt());
@@ -114,6 +181,17 @@
});
let tb5InfoDismissed = $state(false);
// Detect [jaccl] RDMA driver errors from MetaInstance failure errors
const jacclError = $derived.by(() => {
for (const mi of Object.values(metaInstancesData)) {
if (mi.lastFailureError?.includes("[jaccl]")) {
return mi.lastFailureError;
}
}
return null;
});
let jacclDismissedError = $state<string | null>(null);
// Helper to get friendly node name from node ID
function getNodeName(nodeId: string): string {
const node = data?.nodes?.[nodeId];
@@ -211,7 +289,7 @@
return model.tasks.includes("ImageToImage");
}
let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline");
type InstanceMeta = "MlxRing" | "MlxIbv" | "MlxJaccl";
type InstanceMeta = "MlxRing" | "MlxJaccl";
// Launch defaults persistence
const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults";
@@ -468,7 +546,7 @@
const matchesSelectedRuntime = (runtime: InstanceMeta): boolean =>
selectedInstanceType === "MlxRing"
? runtime === "MlxRing"
: runtime === "MlxIbv" || runtime === "MlxJaccl";
: runtime === "MlxJaccl";
// Helper to check if a model can be launched (has valid placement with >= minNodes)
function canModelFit(modelId: string): boolean {
@@ -684,39 +762,30 @@
launchingModelId = modelId;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
const preview = specificPreview ?? filteredPreview();
let instanceData: unknown;
// Extract node IDs from the preview the user is seeing
const previewNodeIds = preview?.memory_delta_by_node
? Object.keys(preview.memory_delta_by_node)
: nodeFilter.size > 0
? Array.from(nodeFilter)
: undefined;
if (preview?.instance) {
// Use the instance from the preview
instanceData = preview.instance;
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
);
if (!placementResponse.ok) {
const errorText = await placementResponse.text();
console.error("Failed to get placement:", errorText);
return;
}
instanceData = await placementResponse.json();
}
// POST the instance to create it
const response = await fetch("/instance", {
const response = await fetch("/meta_instance", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ instance: instanceData }),
body: JSON.stringify({
model_id: modelId,
sharding: preview?.sharding ?? selectedSharding,
instance_meta: preview?.instance_meta ?? selectedInstanceType,
min_nodes: selectedMinNodes,
node_ids: previewNodeIds,
}),
});
if (!response.ok) {
const errorText = await response.text();
console.error("Failed to launch instance:", errorText);
console.error("Failed to create meta instance:", errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
@@ -739,7 +808,7 @@
setTimeout(scrollToBottom, 1000);
}
} catch (error) {
console.error("Error launching instance:", error);
console.error("Error creating meta instance:", error);
} finally {
launchingModelId = null;
}
@@ -941,15 +1010,18 @@
nodeName: string;
progress: DownloadProgress;
}>;
perNodeStatus: PerNodeRunnerStatus[];
} {
if (!downloadsData || Object.keys(downloadsData).length === 0) {
const statusInfo = deriveInstanceStatus(instanceWrapped);
return {
isDownloading: false,
isFailed: false,
errorMessage: null,
isFailed: statusInfo.statusText === "FAILED",
errorMessage: statusInfo.errorMessage,
progress: null,
statusText: "RUNNING",
statusText: statusInfo.statusText,
perNode: [],
perNodeStatus: statusInfo.perNodeStatus,
};
}
@@ -963,6 +1035,7 @@
progress: null,
statusText: "PREPARING",
perNode: [],
perNodeStatus: [],
};
}
@@ -1031,6 +1104,7 @@
progress: null,
statusText: "FAILED",
perNode: [],
perNodeStatus: [],
};
}
}
@@ -1071,10 +1145,11 @@
return {
isDownloading: false,
isFailed: statusInfo.statusText === "FAILED",
errorMessage: null,
errorMessage: statusInfo.errorMessage,
progress: null,
statusText: statusInfo.statusText,
perNode: [],
perNodeStatus: statusInfo.perNodeStatus,
};
}
@@ -1098,92 +1173,172 @@
},
statusText: "DOWNLOADING",
perNode,
perNodeStatus: [],
};
}
// Derive instance status from runners
// Get color class for a status
function getStatusColor(statusText: string): string {
switch (statusText) {
case "FAILED":
return "text-red-400";
case "SHUTDOWN":
return "text-gray-400";
case "DOWNLOADING":
return "text-blue-400";
case "LOADING":
case "WARMING UP":
case "WAITING":
case "INITIALIZING":
return "text-yellow-400";
case "RUNNING":
return "text-teal-400";
case "READY":
case "LOADED":
return "text-green-400";
default:
return "text-exo-light-gray";
}
if (statusText === "FAILED" || statusText === "PLACEMENT FAILED")
return "text-red-400";
if (statusText.startsWith("RETRYING")) return "text-orange-400";
if (statusText === "SHUTDOWN") return "text-gray-400";
if (statusText === "DOWNLOADING") return "text-blue-400";
if (
statusText.startsWith("LOADING") ||
statusText.startsWith("WARMING UP") ||
statusText === "WAITING" ||
statusText === "INITIALIZING"
)
return "text-yellow-400";
if (statusText === "RUNNING") return "text-teal-400";
if (statusText === "READY" || statusText === "LOADED")
return "text-green-400";
return "text-exo-light-gray";
}
const RUNNER_STATUS_MAP: Record<string, string> = {
RunnerWaitingForInitialization: "WaitingForInitialization",
RunnerInitializingBackend: "InitializingBackend",
RunnerWaitingForModel: "WaitingForModel",
RunnerLoading: "Loading",
RunnerLoaded: "Loaded",
RunnerWarmingUp: "WarmingUp",
RunnerReady: "Ready",
RunnerRunning: "Running",
RunnerShutdown: "Shutdown",
RunnerFailed: "Failed",
};
// Friendly labels for display
const RUNNER_STATUS_DISPLAY: Record<string, string> = {
WaitingForInitialization: "Initializing",
InitializingBackend: "Initializing",
WaitingForModel: "Waiting",
Loading: "Loading",
Loaded: "Loaded",
WarmingUp: "Warming Up",
Ready: "Ready",
Running: "Running",
Shutdown: "Shutdown",
Failed: "Failed",
};
interface PerNodeRunnerStatus {
nodeId: string;
nodeName: string;
status: string; // friendly display status
}
function deriveInstanceStatus(instanceWrapped: unknown): {
statusText: string;
statusClass: string;
perNodeStatus: PerNodeRunnerStatus[];
errorMessage: string | null;
} {
const [, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== "object") {
return { statusText: "PREPARING", statusClass: "inactive" };
return {
statusText: "PREPARING",
statusClass: "inactive",
perNodeStatus: [],
errorMessage: null,
};
}
const inst = instance as {
shardAssignments?: { runnerToShard?: Record<string, unknown> };
shardAssignments?: {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
};
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
const runnerIds = Object.keys(inst.shardAssignments?.runnerToShard || {});
const totalNodes = runnerIds.length;
const statuses = runnerIds
.map((rid) => {
const r = runnersData[rid];
if (!r) return null;
const [kind] = getTagged(r);
const statusMap: Record<string, string> = {
RunnerWaitingForInitialization: "WaitingForInitialization",
RunnerInitializingBackend: "InitializingBackend",
RunnerWaitingForModel: "WaitingForModel",
RunnerLoading: "Loading",
RunnerLoaded: "Loaded",
RunnerWarmingUp: "WarmingUp",
RunnerReady: "Ready",
RunnerRunning: "Running",
RunnerShutdown: "Shutdown",
RunnerFailed: "Failed",
};
return kind ? statusMap[kind] || null : null;
})
.filter((s): s is string => s !== null);
// Build per-node status and extract error messages from RunnerFailed
const perNodeStatus: PerNodeRunnerStatus[] = [];
const statuses: string[] = [];
const failedErrors: string[] = [];
for (const [nodeId, runnerId] of Object.entries(nodeToRunner)) {
const r = runnersData[runnerId];
let status: string | null = null;
if (r) {
const [kind, runnerData] = getTagged(r);
status = kind ? RUNNER_STATUS_MAP[kind] || null : null;
// Extract error message from RunnerFailed
if (
kind === "RunnerFailed" &&
runnerData &&
typeof runnerData === "object"
) {
const rd = runnerData as { errorMessage?: string };
if (rd.errorMessage) failedErrors.push(`${getNodeName(nodeId)}: ${rd.errorMessage}`);
}
}
if (status) {
statuses.push(status);
perNodeStatus.push({
nodeId,
nodeName: getNodeName(nodeId),
status: RUNNER_STATUS_DISPLAY[status] || status,
});
}
}
const has = (s: string) => statuses.includes(s);
const count = (s: string) => statuses.filter((v) => v === s).length;
if (statuses.length === 0)
return { statusText: "PREPARING", statusClass: "inactive" };
if (has("Failed")) return { statusText: "FAILED", statusClass: "failed" };
return {
statusText: "PREPARING",
statusClass: "inactive",
perNodeStatus,
errorMessage: null,
};
if (has("Failed"))
return {
statusText: "FAILED",
statusClass: "failed",
perNodeStatus,
errorMessage: failedErrors.length > 0 ? failedErrors.join("; ") : null,
};
if (has("Shutdown"))
return { statusText: "SHUTDOWN", statusClass: "inactive" };
if (has("Loading"))
return { statusText: "LOADING", statusClass: "starting" };
if (has("WarmingUp"))
return { statusText: "WARMING UP", statusClass: "starting" };
if (has("Running"))
return { statusText: "RUNNING", statusClass: "running" };
if (has("Ready")) return { statusText: "READY", statusClass: "loaded" };
if (has("Loaded")) return { statusText: "LOADED", statusClass: "loaded" };
if (has("WaitingForModel"))
return { statusText: "WAITING", statusClass: "starting" };
if (has("InitializingBackend"))
return { statusText: "INITIALIZING", statusClass: "starting" };
if (has("WaitingForInitialization"))
return { statusText: "INITIALIZING", statusClass: "starting" };
return { statusText: "SHUTDOWN", statusClass: "inactive", perNodeStatus, errorMessage: null };
return { statusText: "RUNNING", statusClass: "active" };
// For loading/warming states, show node progress when multi-node
if (has("Loading")) {
const readyCount = count("Ready") + count("Running") + count("Loaded");
const statusText =
totalNodes > 1
? `LOADING (${readyCount}/${totalNodes} nodes ready)`
: "LOADING";
return { statusText, statusClass: "starting", perNodeStatus, errorMessage: null };
}
if (has("WarmingUp")) {
const readyCount = count("Ready") + count("Running");
const statusText =
totalNodes > 1
? `WARMING UP (${readyCount}/${totalNodes} nodes ready)`
: "WARMING UP";
return { statusText, statusClass: "starting", perNodeStatus, errorMessage: null };
}
if (has("Running"))
return { statusText: "RUNNING", statusClass: "running", perNodeStatus, errorMessage: null };
if (has("Ready"))
return { statusText: "READY", statusClass: "loaded", perNodeStatus, errorMessage: null };
if (has("Loaded"))
return { statusText: "LOADED", statusClass: "loaded", perNodeStatus, errorMessage: null };
if (has("WaitingForModel"))
return { statusText: "WAITING", statusClass: "starting", perNodeStatus, errorMessage: null };
if (has("InitializingBackend"))
return { statusText: "INITIALIZING", statusClass: "starting", perNodeStatus, errorMessage: null };
if (has("WaitingForInitialization"))
return { statusText: "INITIALIZING", statusClass: "starting", perNodeStatus, errorMessage: null };
return { statusText: "RUNNING", statusClass: "active", perNodeStatus, errorMessage: null };
}
function getBytes(value: unknown): number {
@@ -1242,6 +1397,75 @@
}
}
async function deleteMetaInstance(metaInstanceId: string) {
const meta = metaInstancesData[metaInstanceId];
const modelId = meta?.modelId ?? "unknown";
if (!confirm(`Delete model ${modelId}?`)) return;
const wasSelected = selectedChatModel() === modelId;
try {
const response = await fetch(`/meta_instance/${metaInstanceId}`, {
method: "DELETE",
headers: { "Content-Type": "application/json" },
});
if (!response.ok) {
console.error("Failed to delete meta instance:", response.status);
} else if (wasSelected) {
// Switch to another available model or clear selection
const remainingInstances = Object.entries(instanceData).filter(
([id]) => id !== getBackingInstanceId(metaInstanceId),
);
if (remainingInstances.length > 0) {
const [, lastInstance] =
remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (
newModelId &&
newModelId !== "Unknown" &&
newModelId !== "Unknown Model"
) {
setSelectedChatModel(newModelId);
} else {
setSelectedChatModel("");
}
} else {
setSelectedChatModel("");
}
}
} catch (error) {
console.error("Error deleting meta instance:", error);
}
}
// Find the backing Instance ID for a MetaInstance by scanning instances
function getBackingInstanceId(metaInstanceId: string): string | null {
for (const [id, inst] of Object.entries(instanceData)) {
const [, inner] = getTagged(inst);
if (
inner &&
typeof inner === "object" &&
(inner as Record<string, unknown>).metaInstanceId === metaInstanceId
) {
return id;
}
}
return null;
}
// Get orphan Instance IDs (not backing any MetaInstance)
function getOrphanInstanceIds(): string[] {
return Object.keys(instanceData).filter((id) => {
const [, inner] = getTagged(instanceData[id]);
return (
!inner ||
typeof inner !== "object" ||
!(inner as Record<string, unknown>).metaInstanceId
);
});
}
// Helper to unwrap tagged unions like { MlxRingInstance: {...} }
function getTagged(obj: unknown): [string | null, unknown] {
if (!obj || typeof obj !== "object") return [null, null];
@@ -1282,11 +1506,7 @@
// Instance type from tag
let instanceType = "Unknown";
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
else if (
instanceTag === "MlxIbvInstance" ||
instanceTag === "MlxJacclInstance"
)
instanceType = "MLX RDMA";
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
const inst = instance as {
shardAssignments?: {
@@ -1634,7 +1854,51 @@
}
const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);
const instanceCount = $derived(Object.keys(instanceData).length);
const metaInstanceCount = $derived(Object.keys(metaInstancesData).length);
const orphanInstanceIds = $derived(getOrphanInstanceIds());
const instanceCount = $derived(metaInstanceCount + orphanInstanceIds.length);
// Unified display items: MetaInstances first, then orphan Instances
interface DisplayItem {
id: string; // MetaInstance ID or Instance ID (used as key and displayed)
modelId: string;
instance: unknown | null; // The backing/orphan instance (tagged union) or null if placing
instanceId: string | null; // The actual Instance ID (for topology hover)
isMetaInstance: boolean;
sharding: string | null; // From MetaInstance constraints (used when instance is null)
instanceMeta: string | null; // From MetaInstance constraints (used when instance is null)
}
const unifiedDisplayItems = $derived((): DisplayItem[] => {
const items: DisplayItem[] = [];
// MetaInstances
for (const [metaId, meta] of Object.entries(metaInstancesData)) {
const backingId = getBackingInstanceId(metaId);
items.push({
id: metaId,
modelId: meta.modelId,
instance: backingId ? instanceData[backingId] : null,
instanceId: backingId,
isMetaInstance: true,
sharding: meta.sharding,
instanceMeta: meta.instanceMeta,
});
}
// Orphan Instances
for (const orphanId of getOrphanInstanceIds()) {
const inst = instanceData[orphanId];
items.push({
id: orphanId,
modelId: getInstanceModelId(inst),
instance: inst,
instanceId: orphanId,
isMetaInstance: false,
sharding: null,
instanceMeta: null,
});
}
return items;
});
// Helper to get the number of nodes in a placement preview
function getPreviewNodeCount(preview: PlacementPreview): number {
@@ -1752,8 +2016,71 @@
</script>
{#snippet clusterWarnings()}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (jacclError && jacclError !== jacclDismissedError)}
<div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
{#if jacclError && jacclError !== jacclDismissedError}
<div class="group relative" role="alert">
<div
class="flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help"
>
<svg
class="w-5 h-5 text-red-400 flex-shrink-0"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d={warningIconPath}
/>
</svg>
<span class="text-sm font-mono text-red-200">
JACCL RDMA ERROR
</span>
<button
type="button"
onclick={() => (jacclDismissedError = jacclError)}
class="ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer"
title="Dismiss"
>
<svg
class="w-4 h-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
</div>
<!-- Tooltip on hover -->
<div
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-red-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
>
<p class="text-xs text-white/80 mb-2">
A macOS RDMA driver error was detected. This is a known issue
with the experimental RDMA driver in macOS.
</p>
<p class="text-xs text-white/60 mb-2">
<span class="text-red-300">Error:</span>
{jacclError}
</p>
<p class="text-xs text-white/60">
<span class="text-red-300">To fix:</span> Restart the affected machine.
There is currently no other workaround for this issue.
</p>
</div>
</div>
{/if}
{#if tbBridgeCycles.length > 0}
{@const cycle = tbBridgeCycles[0]}
{@const serviceName = getTbBridgeServiceName(cycle)}
@@ -1922,8 +2249,29 @@
{/snippet}
{#snippet clusterWarningsCompact()}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (jacclError && jacclError !== jacclDismissedError)}
<div class="absolute top-2 left-2 flex flex-col gap-1">
{#if jacclError && jacclError !== jacclDismissedError}
<div
class="flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm"
title="JACCL RDMA driver error — restart affected machine"
>
<svg
class="w-3.5 h-3.5 text-red-400"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d={warningIconPath}
/>
</svg>
<span class="text-[10px] font-mono text-red-200">JACCL ERROR</span>
</div>
{/if}
{#if tbBridgeCycles.length > 0}
<div
class="flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
@@ -2301,31 +2649,57 @@
bind:this={instancesContainerRef}
class="max-h-72 xl:max-h-96 space-y-3 overflow-y-auto overflow-x-hidden py-px"
>
{#each Object.entries(instanceData) as [id, instance]}
{@const downloadInfo = getInstanceDownloadStatus(
id,
instance,
)}
{#each unifiedDisplayItems() as item (item.id)}
{@const id = item.id}
{@const instance = item.instance}
{@const downloadInfo = instance
? getInstanceDownloadStatus(item.instanceId ?? id, instance)
: getMetaInstancePlacingStatus(id)}
{@const metaData = item.isMetaInstance ? metaInstancesData[id] : null}
{@const retryError = metaData?.lastFailureError && !downloadInfo.isFailed
? metaData.consecutiveFailures > 0
? `(${((metaData.consecutiveFailures - 1) % 3) + 1}/3) ${metaData.lastFailureError}`
: metaData.lastFailureError
: null}
{@const statusText = downloadInfo.statusText}
{@const isDownloading = downloadInfo.isDownloading}
{@const isFailed = statusText === "FAILED"}
{@const isFailed =
statusText === "FAILED" ||
statusText === "PLACEMENT FAILED"}
{@const isLoading =
statusText === "LOADING" ||
statusText === "WARMING UP" ||
statusText === "WAITING"}
statusText.startsWith("LOADING") ||
statusText.startsWith("WARMING UP") ||
statusText === "WAITING" ||
statusText === "PLACING" ||
statusText.startsWith("RETRYING")}
{@const isReady =
statusText === "READY" || statusText === "LOADED"}
{@const isRunning = statusText === "RUNNING"}
<!-- Instance Card -->
{@const instanceModelId = getInstanceModelId(instance)}
{@const instanceInfo = getInstanceInfo(instance)}
{@const instanceConnections =
getInstanceConnections(instance)}
{@const instanceModelId = item.modelId}
{@const instanceInfo = instance
? getInstanceInfo(instance)
: {
instanceType:
item.instanceMeta === "MlxRing"
? "MLX Ring"
: item.instanceMeta === "MlxJaccl"
? "MLX RDMA"
: "Unknown",
sharding: item.sharding ?? "Unknown",
nodeNames: [] as string[],
nodeIds: [] as string[],
nodeCount: 0,
}}
{@const instanceConnections = instance
? getInstanceConnections(instance)
: []}
<div
class="relative group cursor-pointer"
role="button"
tabindex="0"
onmouseenter={() => (hoveredInstanceId = id)}
onmouseenter={() =>
(hoveredInstanceId = item.instanceId ?? id)}
onmouseleave={() => (hoveredInstanceId = null)}
onclick={() => {
if (
@@ -2424,7 +2798,10 @@
>
</div>
<button
onclick={() => deleteInstance(id)}
onclick={() =>
item.isMetaInstance
? deleteMetaInstance(id)
: deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
@@ -2434,7 +2811,7 @@
<div
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
>
{getInstanceModelId(instance)}
{instanceModelId}
</div>
<div class="text-white/60 text-xs font-mono">
Strategy: <span class="text-white/80"
@@ -2702,6 +3079,30 @@
>
{downloadInfo.errorMessage}
</div>
{:else if retryError}
<div
class="text-xs text-orange-400/80 font-mono mt-1 break-words"
>
Retrying after error: {retryError}
</div>
{/if}
{#if downloadInfo.perNodeStatus.length > 1 && (statusText.startsWith("LOADING") || statusText.startsWith("WARMING UP") || statusText === "WAITING" || statusText === "INITIALIZING")}
<div class="mt-1.5 space-y-0.5">
{#each downloadInfo.perNodeStatus as node}
<div
class="flex items-center justify-between text-[10px] font-mono"
>
<span class="text-white/60 truncate pr-2"
>{node.nodeName}</span
>
<span
class={getStatusColor(
node.status.toUpperCase(),
)}>{node.status}</span
>
</div>
{/each}
</div>
{/if}
{/if}
</div>
@@ -2870,21 +3271,21 @@
</button>
<button
onclick={() => {
selectedInstanceType = "MlxIbv";
selectedInstanceType = "MlxJaccl";
saveLaunchDefaults();
}}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType ===
'MlxIbv'
'MlxJaccl'
? 'bg-transparent text-exo-yellow border-exo-yellow'
: 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span
class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType ===
'MlxIbv'
'MlxJaccl'
? 'border-exo-yellow'
: 'border-exo-medium-gray'}"
>
{#if selectedInstanceType === "MlxIbv"}
{#if selectedInstanceType === "MlxJaccl"}
<span class="w-2 h-2 rounded-full bg-exo-yellow"></span>
{/if}
</span>
@@ -3113,31 +3514,60 @@
<div
class="space-y-3 max-h-72 xl:max-h-96 overflow-y-auto overflow-x-hidden py-px pr-1"
>
{#each Object.entries(instanceData) as [id, instance]}
{@const downloadInfo = getInstanceDownloadStatus(
id,
instance,
)}
{#each unifiedDisplayItems() as item (item.id)}
{@const id = item.id}
{@const instance = item.instance}
{@const downloadInfo = instance
? getInstanceDownloadStatus(
item.instanceId ?? id,
instance,
)
: getMetaInstancePlacingStatus(id)}
{@const metaData = item.isMetaInstance ? metaInstancesData[id] : null}
{@const retryError = metaData?.lastFailureError && !downloadInfo.isFailed
? metaData.consecutiveFailures > 0
? `(${((metaData.consecutiveFailures - 1) % 3) + 1}/3) ${metaData.lastFailureError}`
: metaData.lastFailureError
: null}
{@const statusText = downloadInfo.statusText}
{@const isDownloading = downloadInfo.isDownloading}
{@const isFailed = statusText === "FAILED"}
{@const isFailed =
statusText === "FAILED" ||
statusText === "PLACEMENT FAILED"}
{@const isLoading =
statusText === "LOADING" ||
statusText === "WARMING UP" ||
statusText === "WAITING"}
statusText.startsWith("LOADING") ||
statusText.startsWith("WARMING UP") ||
statusText === "WAITING" ||
statusText === "PLACING" ||
statusText.startsWith("RETRYING")}
{@const isReady =
statusText === "READY" || statusText === "LOADED"}
{@const isRunning = statusText === "RUNNING"}
<!-- Instance Card -->
{@const instanceModelId = getInstanceModelId(instance)}
{@const instanceInfo = getInstanceInfo(instance)}
{@const instanceConnections =
getInstanceConnections(instance)}
{@const instanceModelId = item.modelId}
{@const instanceInfo = instance
? getInstanceInfo(instance)
: {
instanceType:
item.instanceMeta === "MlxRing"
? "MLX Ring"
: item.instanceMeta === "MlxJaccl"
? "MLX RDMA"
: "Unknown",
sharding: item.sharding ?? "Unknown",
nodeNames: [] as string[],
nodeIds: [] as string[],
nodeCount: 0,
}}
{@const instanceConnections = instance
? getInstanceConnections(instance)
: []}
<div
class="relative group cursor-pointer"
role="button"
tabindex="0"
onmouseenter={() => (hoveredInstanceId = id)}
onmouseenter={() =>
(hoveredInstanceId = item.instanceId ?? id)}
onmouseleave={() => (hoveredInstanceId = null)}
onclick={() => {
if (
@@ -3236,7 +3666,10 @@
>
</div>
<button
onclick={() => deleteInstance(id)}
onclick={() =>
item.isMetaInstance
? deleteMetaInstance(id)
: deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
@@ -3246,7 +3679,7 @@
<div
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
>
{getInstanceModelId(instance)}
{instanceModelId}
</div>
<div class="text-white/60 text-xs font-mono">
Strategy: <span class="text-white/80"
@@ -3524,6 +3957,30 @@
>
{downloadInfo.errorMessage}
</div>
{:else if retryError}
<div
class="text-xs text-orange-400/80 font-mono mt-1 break-words"
>
Retrying after error: {retryError}
</div>
{/if}
{#if downloadInfo.perNodeStatus.length > 1 && (statusText.startsWith("LOADING") || statusText.startsWith("WARMING UP") || statusText === "WAITING" || statusText === "INITIALIZING")}
<div class="mt-1.5 space-y-0.5">
{#each downloadInfo.perNodeStatus as node}
<div
class="flex items-center justify-between text-[10px] font-mono"
>
<span class="text-white/60 truncate pr-2"
>{node.nodeName}</span
>
<span
class={getStatusColor(
node.status.toUpperCase(),
)}>{node.status}</span
>
</div>
{/each}
</div>
{/if}
{/if}
</div>
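The `retryError` label in the two card loops above wraps the failure counter into a 1-to-3 attempt number, and drops the prefix when the counter is 0. A minimal sketch of that arithmetic, written in Python for illustration only (the dashboard code is Svelte/TypeScript; `retry_label` and `MAX_RETRIES` are hypothetical names, and the value 3 is taken from the `/3` in the template string):

```python
MAX_RETRIES = 3  # assumption: matches the "/3" shown in the dashboard label


def retry_label(consecutive_failures: int, last_error: str) -> str:
    """Mirror the dashboard's retryError formatting (illustrative only)."""
    if consecutive_failures <= 0:
        # Counter is 0 (for example after the instance was recreated):
        # show the error without an attempt prefix.
        return last_error
    # Wrap the counter into 1..MAX_RETRIES so 4 failures reads as attempt 1 again.
    attempt = (consecutive_failures - 1) % MAX_RETRIES + 1
    return f"({attempt}/{MAX_RETRIES}) {last_error}"
```

Under this reading, a fourth consecutive failure is displayed as `(1/3)` again, which matches a counter that keeps growing across recreated instances while the retry budget restarts.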


@@ -71,8 +71,11 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
CreateMetaInstanceParams,
CreateMetaInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
DeleteMetaInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
@@ -115,8 +118,10 @@ from exo.shared.types.claude_api import (
from exo.shared.types.commands import (
Command,
CreateInstance,
CreateMetaInstance,
DeleteDownload,
DeleteInstance,
DeleteMetaInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
@@ -128,7 +133,7 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -137,6 +142,7 @@ from exo.shared.types.events import (
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
@@ -275,6 +281,8 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.post("/meta_instance")(self.create_meta_instance)
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/models/add")(self.add_custom_model)
@@ -521,6 +529,46 @@ class API:
instance_id=instance_id,
)
async def create_meta_instance(
self, payload: CreateMetaInstanceParams
) -> CreateMetaInstanceResponse:
meta_instance = MetaInstance(
model_id=payload.model_id,
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
node_ids=payload.node_ids,
)
command = CreateMetaInstance(meta_instance=meta_instance)
await self._send(command)
return CreateMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance.meta_instance_id,
)
async def delete_meta_instance(
self, meta_instance_id: MetaInstanceId
) -> DeleteMetaInstanceResponse:
meta = self.state.meta_instances.get(meta_instance_id)
if not meta:
raise HTTPException(status_code=404, detail="MetaInstance not found")
# Delete MetaInstance first to prevent reconciler from re-placing
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
await self._send(command)
# Then cascade-delete any backing instances
for instance_id, instance in self.state.instances.items():
if instance.meta_instance_id == meta_instance_id:
await self._send(DeleteInstance(instance_id=instance_id))
return DeleteMetaInstanceResponse(
message="Command received.",
command_id=command.command_id,
meta_instance_id=meta_instance_id,
)
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
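The `delete_meta_instance` handler above sends the parent deletion before the cascade so that the reconciler cannot re-place the MetaInstance in between. A minimal sketch of that ordering as a pure function, assuming simplified stand-in command classes (`DeleteMetaInstance` and `DeleteInstance` here are hypothetical dataclasses, not the real `exo` commands, and `plan_deletion` does not exist in the PR):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DeleteMetaInstance:  # hypothetical stand-in for the real command
    meta_instance_id: str


@dataclass(frozen=True)
class DeleteInstance:  # hypothetical stand-in for the real command
    instance_id: str


def plan_deletion(meta_instance_id: str, instance_owners: dict[str, Optional[str]]) -> list:
    """Return commands in send order: parent first, then its backing instances.

    instance_owners maps instance_id -> owning meta_instance_id (None for orphans).
    """
    commands: list = [DeleteMetaInstance(meta_instance_id)]
    for instance_id, owner in instance_owners.items():
        if owner == meta_instance_id:
            commands.append(DeleteInstance(instance_id))
    return commands
```

Orphan instances and instances owned by other MetaInstances are left alone; only the matching backing instances follow the parent deletion.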


@@ -1,4 +1,5 @@
from datetime import datetime, timedelta, timezone
from collections.abc import Sequence
from datetime import datetime, timezone
import anyio
from anyio.abc import TaskGroup
@@ -12,11 +13,19 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.master.process_managers import ProcessManager
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
from exo.master.reconcile import try_place_for_meta_instance
from exo.shared.apply import apply
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.commands import (
CreateInstance,
CreateMetaInstance,
DeleteInstance,
DeleteMetaInstance,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
@@ -34,9 +43,9 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
MetaInstanceCreated,
MetaInstanceDeleted,
NodeGatheredInfo,
NodeTimedOut,
TaskCreated,
TaskDeleted,
TraceEventData,
@@ -58,7 +67,7 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.channels import Receiver, Sender
from exo.utils.event_buffer import MultiSourceBuffer
@@ -82,16 +91,15 @@ class Master:
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
self._process_managers: Sequence[ProcessManager] = [
InstanceHealthReconciler(),
NodeTimeoutReconciler(),
MetaInstanceReconciler(),
]
async def run(self):
logger.info("Starting Master")
@@ -100,15 +108,12 @@ class Master:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
tg.start_soon(self._reconcile)
finally:
self._event_log.close()
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -290,6 +295,29 @@ class Master:
)
)
generated_events.extend(transition_events)
case CreateMetaInstance():
generated_events.append(
MetaInstanceCreated(meta_instance=command.meta_instance)
)
# Immediate placement attempt for responsiveness
model_card = await ModelCard.load(
command.meta_instance.model_id
)
result = try_place_for_meta_instance(
command.meta_instance,
model_card,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
generated_events.extend(result.events)
case DeleteMetaInstance():
generated_events.append(
MetaInstanceDeleted(
meta_instance_id=command.meta_instance_id
)
)
case PlaceInstance():
placement = place_instance(
command,
@@ -341,31 +369,32 @@ class Master:
):
await self._send_event(IndexedEvent(idx=i, event=event))
for event in generated_events:
await self.event_sender.send(event)
await self._apply_and_broadcast(event)
except ValueError as e:
logger.opt(exception=e).warning("Error in command processor")
# These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
async def _plan(self) -> None:
async def _apply_and_broadcast(self, event: Event) -> None:
"""Apply event to state, persist to disk, and broadcast to workers.
State is updated synchronously (before any await), so callers can
rely on ``self.state`` reflecting this event immediately after the
call. Python's cooperative scheduling guarantees no interleaving
between the state read and write.
"""
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
self._event_log.append(event)
await self._send_event(indexed)
async def _reconcile(self) -> None:
while True:
# kill broken instances
connected_node_ids = set(self.state.topology.list_nodes())
for instance_id, instance in self.state.instances.items():
for node_id in instance.shard_assignments.node_to_runner:
if node_id not in connected_node_ids:
await self.event_sender.send(
InstanceDeleted(instance_id=instance_id)
)
break
# time out dead nodes
for node_id, time in self.state.last_seen.items():
now = datetime.now(tz=timezone.utc)
if now - time > timedelta(seconds=30):
logger.info(f"Manually removing node {node_id} due to inactivity")
await self.event_sender.send(NodeTimedOut(node_id=node_id))
await anyio.sleep(10)
for pm in self._process_managers:
events = await pm.reconcile(self.state)
for event in events:
await self._apply_and_broadcast(event)
await anyio.sleep(1)
async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
@@ -383,32 +412,10 @@ class Master:
await self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
event._master_time_stamp = datetime.now(tz=timezone.utc) # pyright: ignore[reportPrivateUsage]
if isinstance(event, NodeGatheredInfo):
event.when = str(datetime.now(tz=timezone.utc))
self._event_log.append(event)
await self._send_event(indexed)
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
local_index += 1
await self._apply_and_broadcast(event)
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
@@ -440,7 +447,7 @@ class Master:
for trace_data in self._pending_traces[task_id].values():
all_trace_data.extend(trace_data)
await self.event_sender.send(
await self._apply_and_broadcast(
TracesMerged(task_id=task_id, traces=all_trace_data)
)
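The `_apply_and_broadcast` docstring in this file relies on the state update happening before the first `await`. A minimal sketch of that property, using `asyncio` instead of `anyio` and a heavily simplified stand-in class (`MiniMaster` and all of its attributes are hypothetical; the real method also stamps the event and writes to a disk log):

```python
import asyncio


class MiniMaster:  # hypothetical, heavily simplified stand-in for Master
    def __init__(self) -> None:
        self.state: list[str] = []  # stands in for State
        self.log: list[str] = []  # stands in for the on-disk event log
        self.sent: list[tuple[int, str]] = []  # what was broadcast

    async def _send(self, idx: int, event: str) -> None:
        await asyncio.sleep(0)  # yield to the loop, as a real send might
        self.sent.append((idx, event))

    async def apply_and_broadcast(self, event: str) -> None:
        idx = len(self.log)  # index is the current log length
        self.state = [*self.state, event]  # state updated synchronously
        self.log.append(event)
        await self._send(idx, event)  # first suspension point


async def demo() -> MiniMaster:
    master = MiniMaster()
    # Two concurrent callers: each finishes its synchronous part before yielding.
    await asyncio.gather(
        master.apply_and_broadcast("a"),
        master.apply_and_broadcast("b"),
    )
    return master


master = asyncio.run(demo())
```

Because no `await` sits between reading the log length and appending to the log, the two concurrent callers get distinct indices even though their sends interleave.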


@@ -63,7 +63,9 @@ def place_instance(
required_nodes: set[NodeId] | None = None,
) -> dict[InstanceId, Instance]:
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
candidate_cycles = list(
filter(lambda it: len(it) >= command.min_nodes, cycles)
)
# Filter to cycles containing all required nodes (subset matching)
if required_nodes:
@@ -106,7 +108,11 @@ def place_instance(
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
]
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
if command.instance_meta == InstanceMeta.MlxJaccl:
if not smallest_rdma_cycles:
raise ValueError(
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
)
smallest_cycles = smallest_rdma_cycles
cycles_with_leaf_nodes: list[Cycle] = [
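The hunk above changes the RDMA branch from a silent fallback to an explicit failure: a request for `MlxJaccl` with no RDMA cycle now raises instead of quietly using non-RDMA cycles. A minimal sketch of the new control flow, assuming plain lists in place of the real `Cycle` type (`select_cycles` and its parameter names are hypothetical; the error text is copied from the diff):

```python
def select_cycles(want_rdma: bool, smallest_cycles: list, rdma_cycles: list) -> list:
    """Pick candidate cycles, failing loudly when RDMA was requested but is unavailable."""
    if want_rdma:
        if not rdma_cycles:
            raise ValueError(
                "Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
            )
        return rdma_cycles
    # Non-RDMA requests keep using the smallest cycles unchanged.
    return smallest_cycles
```

The `ValueError` lines up with the `except ValueError` in the command processor shown earlier, so a failed RDMA request surfaces as a logged warning rather than a mis-typed instance.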


@@ -0,0 +1,12 @@
from collections.abc import Sequence
from typing import Protocol, runtime_checkable
from exo.shared.types.events import Event
from exo.shared.types.state import State
@runtime_checkable
class ProcessManager(Protocol):
"""A reconciliation step that examines state and returns corrective events."""
async def reconcile(self, state: State) -> Sequence[Event]: ...
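Any class with a matching async `reconcile` method satisfies this protocol structurally, without inheriting from it. A minimal sketch of how such reconcilers could be driven, assuming a plain `dict` for state and strings for events (`DropStaleNodes` and `reconcile_once` are hypothetical; unlike the `_reconcile` loop in the Master, this sketch does not apply events between managers):

```python
import asyncio
from collections.abc import Sequence
from typing import Protocol, runtime_checkable


@runtime_checkable
class ProcessManager(Protocol):
    async def reconcile(self, state: dict) -> Sequence[str]: ...


class DropStaleNodes:  # hypothetical reconciler, not part of the PR
    async def reconcile(self, state: dict) -> Sequence[str]:
        # Emit one corrective event per node flagged as stale.
        return [f"NodeTimedOut({n})" for n, stale in state["nodes"].items() if stale]


async def reconcile_once(managers: Sequence[ProcessManager], state: dict) -> list[str]:
    """Run every manager once and collect the events it proposes."""
    collected: list[str] = []
    for manager in managers:
        collected.extend(await manager.reconcile(state))
    return collected


events = asyncio.run(
    reconcile_once([DropStaleNodes()], {"nodes": {"a": False, "b": True}})
)
```

Keeping reconcilers as pure "state in, events out" steps makes each one testable in isolation, which is presumably why the inline loop bodies were split into separate classes.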


@@ -0,0 +1,49 @@
from collections.abc import Sequence
from typing import final
from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
from exo.shared.types.state import State
MAX_INSTANCE_RETRIES = 3
@final
class InstanceHealthReconciler:
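"""Delete instances whose network connections are broken; retry (MetaInstance-backed, up to MAX_INSTANCE_RETRIES) or delete instances whose runners have all terminated with at least one failure."""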
async def reconcile(self, state: State) -> Sequence[Event]:
events: list[Event] = []
for instance_id, instance in state.instances.items():
if not instance_connections_healthy(instance, state.topology):
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error="Network connection lost",
)
)
continue
is_failed, error_message = instance_runners_failed(
instance, state.runners, state.node_identities
)
if is_failed:
# Retry within the same instance if backed by a MetaInstance
mid = instance.meta_instance_id
mi = state.meta_instances.get(mid) if mid else None
if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
events.append(
InstanceRetrying(
instance_id=instance_id,
meta_instance_id=mid,
failure_error=error_message or "Runner failed",
)
)
else:
events.append(
InstanceDeleted(
instance_id=instance_id,
failure_error=error_message,
)
)
return events
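The branching in `InstanceHealthReconciler.reconcile` reduces to a small decision table. A minimal sketch, assuming string outcomes in place of the real events (`decide` and its return values are hypothetical; `MAX_INSTANCE_RETRIES = 3` is taken from the diff, and `consecutive_failures=None` stands for "no backing MetaInstance found in state"):

```python
from typing import Optional

MAX_INSTANCE_RETRIES = 3  # value taken from the diff above


def decide(
    connections_healthy: bool,
    runners_failed: bool,
    consecutive_failures: Optional[int],
) -> str:
    """Return "keep", "retry", or "delete" for one instance."""
    if not connections_healthy:
        return "delete"  # broken network always deletes, no retry
    if not runners_failed:
        return "keep"
    if consecutive_failures is not None and consecutive_failures < MAX_INSTANCE_RETRIES:
        return "retry"  # restart runners inside the same instance
    return "delete"  # orphan instance, or retry budget exhausted
```

Note that a broken connection bypasses the retry path entirely: retries are reserved for runner failures on a still-healthy topology.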


@@ -0,0 +1,53 @@
from collections.abc import Sequence
from typing import final
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
try_place_for_meta_instance,
)
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
from exo.shared.types.state import State
from exo.shared.types.worker.instances import Instance, InstanceId
@final
class MetaInstanceReconciler:
"""Place instances for unsatisfied MetaInstances."""
async def reconcile(self, state: State) -> Sequence[Event]:
all_events: list[Event] = []
# Local copy for intermediate tracking — so placement of B
# sees A's instance and doesn't double-place on same resources.
current_instances: dict[InstanceId, Instance] = dict(state.instances)
unsatisfied = find_unsatisfied_meta_instances(
state.meta_instances,
current_instances,
state.topology,
)
for meta_instance in unsatisfied:
model_card = await ModelCard.load(meta_instance.model_id)
result = try_place_for_meta_instance(
meta_instance,
model_card,
state.topology,
current_instances,
state.node_memory,
state.node_network,
)
# Update local instance map so next placement sees this one
for event in result.events:
if isinstance(event, InstanceCreated):
current_instances[event.instance.instance_id] = event.instance
all_events.extend(result.events)
# Emit placement failure if error differs from what's already in state
if result.error is not None and meta_instance.placement_error != result.error:
all_events.append(
MetaInstancePlacementFailed(
meta_instance_id=meta_instance.meta_instance_id,
reason=result.error,
)
)
return all_events
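The local `current_instances` copy lets later placements in the same pass see earlier ones before any event has been applied to real state. A minimal sketch of the same pattern on a toy problem, assuming integer "slots" per node in place of the real topology and memory checks (`place_all` and its arguments are hypothetical and much simpler than `try_place_for_meta_instance`):

```python
from typing import Optional


def place_all(requests: list[tuple[str, int]], free_slots: dict[str, int]) -> dict[str, str]:
    """Greedy placement that tracks intermediate results in a local copy."""
    remaining = dict(free_slots)  # local copy; the caller's mapping stays untouched
    placements: dict[str, str] = {}
    for name, need in requests:
        node: Optional[str] = next(
            (n for n, free in remaining.items() if free >= need), None
        )
        if node is None:
            continue  # nothing fits; the real code would record a placement error
        remaining[node] -= need  # later requests see this placement
        placements[name] = node
    return placements
```

Without the update to `remaining`, both requests in a pass would be placed on the same node, which is the double-placement the comment in the reconciler warns about.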


@@ -0,0 +1,27 @@
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from typing import final
from loguru import logger
from exo.shared.types.events import Event, NodeTimedOut
from exo.shared.types.state import State
_DEFAULT_TIMEOUT = timedelta(seconds=30)
@final
class NodeTimeoutReconciler:
"""Time out nodes that haven't been seen recently."""
def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
self.timeout = timeout
async def reconcile(self, state: State) -> Sequence[Event]:
now = datetime.now(tz=timezone.utc)
events: list[Event] = []
for node_id, last_seen in state.last_seen.items():
if now - last_seen > self.timeout:
logger.info(f"Removing node {node_id} due to inactivity")
events.append(NodeTimedOut(node_id=node_id))
return events
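The timeout check is a strict comparison against the last-seen timestamp, so a node seen exactly 30 seconds ago is not yet timed out. A minimal sketch with an injectable clock for testing, assuming string node IDs (`timed_out_nodes` is a hypothetical helper; the reconciler itself calls `datetime.now` internally):

```python
from datetime import datetime, timedelta, timezone


def timed_out_nodes(
    last_seen: dict[str, datetime],
    now: datetime,
    timeout: timedelta = timedelta(seconds=30),
) -> list[str]:
    """Return IDs of nodes whose last sighting is strictly older than the timeout."""
    return [node for node, seen in last_seen.items() if now - seen > timeout]


now = datetime(2026, 2, 11, 12, 0, 0, tzinfo=timezone.utc)
stale = timed_out_nodes(
    {
        "fresh": now - timedelta(seconds=5),
        "edge": now - timedelta(seconds=30),
        "gone": now - timedelta(seconds=31),
    },
    now,
)
```

Passing `now` as a parameter is a design choice of this sketch only; it keeps the comparison deterministic in tests.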

src/exo/master/reconcile.py (new file, 236 lines)

@@ -0,0 +1,236 @@
from collections.abc import Mapping, Sequence
from typing import NamedTuple
from loguru import logger
from exo.master.placement import get_transition_events, place_instance
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import Event
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.instances import (
BaseInstance,
Instance,
InstanceId,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
RunnerStatus,
)
class PlacementResult(NamedTuple):
"""Result of a placement attempt: events to apply and optional error reason."""
events: Sequence[Event]
error: str | None
def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
"""Reconstruct ring order from shard device_rank."""
node_ranks: list[tuple[NodeId, int]] = []
for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
shard = instance.shard_assignments.runner_to_shard[runner_id]
node_ranks.append((node_id, shard.device_rank))
node_ranks.sort(key=lambda x: x[1])
return [node_id for node_id, _ in node_ranks]
def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
"""Check that the specific IPs used by a ring instance still exist in the topology."""
ring = _get_ring_order(instance)
n = len(ring)
for node in ring:
hosts = instance.hosts_by_node[node]
for idx in range(n):
host = hosts[idx]
if host.ip in ("0.0.0.0", "198.51.100.1"):
continue # self or placeholder
# Real connection: node → ring[idx]. Check specific IP.
connections = topology.get_all_connections_between(node, ring[idx])
if not any(
isinstance(c, SocketConnection)
and c.sink_multiaddr.ip_address == host.ip
for c in connections
):
return False
return True
def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
"""Check that the specific RDMA interfaces used by a JACCL instance still exist."""
ring = _get_ring_order(instance)
n = len(ring)
for i in range(n):
for j in range(n):
iface = instance.jaccl_devices[i][j]
if iface is None:
continue
connections = topology.get_all_connections_between(ring[i], ring[j])
if not any(
isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
for c in connections
):
return False
return True
def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
"""Check that an instance's nodes and specific connections are still in the topology."""
instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
if not all(topology.contains_node(n) for n in instance_nodes):
return False
if len(instance_nodes) <= 1:
return True
match instance:
case MlxRingInstance():
return _ring_connections_healthy(instance, topology)
case MlxJacclInstance():
return _jaccl_connections_healthy(instance, topology)
def instance_runners_failed(
    instance: Instance,
    runners: Mapping[RunnerId, RunnerStatus],
    node_identities: Mapping[NodeId, NodeIdentity],
) -> tuple[bool, str | None]:
    """Check if an instance's runners have all reached terminal failure states.

    Returns ``(True, error_message)`` when ALL runners are terminal
    (``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.
    Returns ``(False, None)`` when runners are still active, haven't reported
    yet, or all gracefully shut down (no ``RunnerFailed``).
    """
    instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())
    if not instance_runner_ids:
        return False, None
    # Build reverse mapping: runner_id -> node_id
    runner_to_node: dict[RunnerId, NodeId] = {
        runner_id: node_id
        for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
    }
    has_any_failed = False
    error_messages: list[str] = []
    for runner_id in instance_runner_ids:
        status = runners.get(runner_id)
        if status is None:
            # Runner hasn't reported yet, so the instance is still starting
            return False, None
        if isinstance(status, RunnerFailed):
            has_any_failed = True
            if status.error_message:
                node_id = runner_to_node.get(runner_id)
                name = (
                    node_identities[node_id].friendly_name
                    if node_id and node_id in node_identities
                    else node_id or "unknown"
                )
                error_messages.append(f"{name}: {status.error_message}")
        elif isinstance(status, RunnerShutdown):
            pass  # Terminal but not a failure indicator on its own
        else:
            # Runner is still active (connecting, loading, running, etc.)
            return False, None
    if has_any_failed:
        return True, "; ".join(error_messages) if error_messages else "Runner failed"
    # All runners are Shutdown but none Failed: graceful shutdown, not a failure
    return False, None
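Review note: the terminal-state rule in `instance_runners_failed` can be tried in isolation. The sketch below is a simplified model, not the project's code: `Failed`, `Shutdown`, `Loading`, and `runners_failed` are hypothetical stand-ins for `RunnerFailed`, `RunnerShutdown`, any non-terminal status, and the function above, and runner ids are plain strings.

```python
from dataclasses import dataclass
from typing import Mapping, Optional, Tuple, Union


@dataclass(frozen=True)
class Failed:  # hypothetical stand-in for RunnerFailed
    error_message: Optional[str] = None


@dataclass(frozen=True)
class Shutdown:  # hypothetical stand-in for RunnerShutdown
    pass


@dataclass(frozen=True)
class Loading:  # hypothetical stand-in for any non-terminal status
    pass


Status = Union[Failed, Shutdown, Loading]


def runners_failed(
    runner_ids: list, statuses: Mapping[str, Status]
) -> Tuple[bool, Optional[str]]:
    """Return (True, message) only if every runner is terminal and one failed."""
    errors = []
    any_failed = False
    for rid in runner_ids:
        status = statuses.get(rid)
        # A missing or still-active runner means the instance is not failed yet.
        if status is None or isinstance(status, Loading):
            return False, None
        if isinstance(status, Failed):
            any_failed = True
            if status.error_message:
                errors.append(f"{rid}: {status.error_message}")
    if not any_failed:
        # Every runner shut down gracefully.
        return False, None
    return True, "; ".join(errors) if errors else "Runner failed"
```

One design point this makes visible: a single active or unreported runner short-circuits the whole check, so a partially failed instance is not reported as failed until every runner has reached a terminal state.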
def instance_satisfies_meta_instance(
    meta_instance: MetaInstance,
    instance: Instance,
) -> bool:
    """Check if a single instance satisfies a meta-instance's constraints.

    This is a pure constraint check (model, min_nodes, node_ids).
    Use ``instance_connections_healthy`` separately for topology health.
    """
    if instance.shard_assignments.model_id != meta_instance.model_id:
        return False
    instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
    if len(instance_nodes) < meta_instance.min_nodes:
        return False
    return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
        instance_nodes
    )
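Review note: the three constraints checked by `instance_satisfies_meta_instance` (model, minimum node count, required nodes as a subset) can be illustrated with plain sets. This is a minimal sketch; `Wanted` and `satisfies` are hypothetical names, and model and node ids are plain strings rather than the project's id types.

```python
from dataclasses import dataclass
from typing import FrozenSet, Optional, Set


@dataclass(frozen=True)
class Wanted:  # hypothetical stand-in for MetaInstance
    model_id: str
    min_nodes: int = 1
    node_ids: Optional[FrozenSet[str]] = None


def satisfies(wanted: Wanted, model_id: str, nodes: Set[str]) -> bool:
    """Pure constraint check: model match, node count, required-node subset."""
    if model_id != wanted.model_id:
        return False
    if len(nodes) < wanted.min_nodes:
        return False
    # Required nodes must all be present; extra nodes are allowed.
    return wanted.node_ids is None or wanted.node_ids <= nodes
```

Note that the subset test permits an instance to span more nodes than the ones requested, which matches `test_satisfies_with_node_ids_specified` in the test file.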
def find_unsatisfied_meta_instances(
    meta_instances: Mapping[MetaInstanceId, MetaInstance],
    instances: Mapping[InstanceId, Instance],
    topology: Topology,
) -> Sequence[MetaInstance]:
    """Return meta-instances that have no healthy backing instance."""
    unsatisfied: list[MetaInstance] = []
    for meta_id, meta_instance in meta_instances.items():
        has_healthy_backing = any(
            instance.meta_instance_id == meta_id
            and instance_connections_healthy(instance, topology)
            for instance in instances.values()
        )
        if not has_healthy_backing:
            unsatisfied.append(meta_instance)
    return unsatisfied
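Review note: as written, `find_unsatisfied_meta_instances` decides backing purely by the explicit `meta_instance_id` binding plus topology health; it does not call `instance_satisfies_meta_instance`. A minimal sketch of that binding rule follows, with hypothetical names (`unsatisfied_meta_ids`, `instance_owner`, `is_healthy`) and string ids standing in for the real types.

```python
from typing import Callable, List, Mapping, Optional


def unsatisfied_meta_ids(
    meta_ids: List[str],
    instance_owner: Mapping[str, Optional[str]],
    is_healthy: Callable[[str], bool],
) -> List[str]:
    """Return meta ids with no healthy instance explicitly bound to them."""
    return [
        meta_id
        for meta_id in meta_ids
        if not any(
            owner == meta_id and is_healthy(instance_id)
            for instance_id, owner in instance_owner.items()
        )
    ]
```

Because the match is on the owner id rather than on the model, two meta-instances requesting the same model never share one instance, which is the behaviour `test_unsatisfied_exclusive_binding` exercises.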
def try_place_for_meta_instance(
    meta_instance: MetaInstance,
    model_card: ModelCard,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
    node_memory: Mapping[NodeId, MemoryUsage],
    node_network: Mapping[NodeId, NodeNetworkInfo],
) -> PlacementResult:
    """Try to place an instance satisfying the meta-instance constraints.

    Returns a :class:`PlacementResult` with events on success, or an error
    reason on failure.
    """
    command = PlaceInstance(
        model_card=model_card,
        sharding=meta_instance.sharding,
        instance_meta=meta_instance.instance_meta,
        min_nodes=meta_instance.min_nodes,
    )
    try:
        target_instances = place_instance(
            command,
            topology,
            current_instances,
            node_memory,
            node_network,
            required_nodes=(
                set(meta_instance.node_ids) if meta_instance.node_ids else None
            ),
        )
        # Tag the new instance with meta_instance_id
        new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
        if new_instance_ids:
            new_id = next(iter(new_instance_ids))
            target_instances[new_id] = target_instances[new_id].model_copy(
                update={"meta_instance_id": meta_instance.meta_instance_id}
            )
        return PlacementResult(
            events=list(get_transition_events(current_instances, target_instances)),
            error=None,
        )
    except ValueError as e:
        logger.debug(
            f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
        )
        return PlacementResult(events=[], error=str(e))
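Review note: the tagging step in `try_place_for_meta_instance` finds the newly placed instance as the key-set difference between the target and current mappings. The sketch below shows that technique on plain dictionaries; `tag_new_instance` is a hypothetical helper, instances are modelled as dicts, and unlike the function above it copies the target mapping instead of assigning into it.

```python
from typing import Dict


def tag_new_instance(
    current: Dict[str, dict], target: Dict[str, dict], meta_id: str
) -> Dict[str, dict]:
    """Return a copy of target where one newly added instance carries meta_id."""
    new_ids = set(target) - set(current)
    tagged = dict(target)
    if new_ids:
        # Only one id is tagged, mirroring the single next(iter(...)) call.
        new_id = next(iter(new_ids))
        tagged[new_id] = {**tagged[new_id], "meta_instance_id": meta_id}
    return tagged
```

If placement ever produced more than one new instance, only one of them would be tagged and the choice would depend on set iteration order, so the approach relies on placement adding at most one instance per call.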


@@ -0,0 +1,750 @@
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.reconcile import (
find_unsatisfied_meta_instances,
instance_connections_healthy,
instance_runners_failed,
instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
InstanceId,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerLoading,
RunnerReady,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
return ModelCard(
model_id=ModelId(model_id),
storage_size=Memory.from_kb(1000),
n_layers=10,
hidden_size=30,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
def _topology(*node_ids: str, connect: bool = True) -> Topology:
"""Build a topology with nodes connected in a bidirectional ring with unique IPs.
Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
between consecutive nodes (including wrap-around).
"""
t = Topology()
nodes = [NodeId(n) for n in node_ids]
for n in nodes:
t.add_node(n)
if connect and len(nodes) > 1:
for i in range(len(nodes)):
j = (i + 1) % len(nodes)
t.add_connection(
Connection(
source=nodes[i],
sink=nodes[j],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
)
),
)
)
t.add_connection(
Connection(
source=nodes[j],
sink=nodes[i],
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
)
),
)
)
return t
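Review note: the edge layout that `_topology()` builds can be listed without the `Topology` class. In this sketch `ring_edges` is a hypothetical helper that returns `(source, sink, sink_multiaddr)` tuples using the same index and IP convention as the test helper.

```python
from typing import List, Tuple


def ring_edges(nodes: List[str]) -> List[Tuple[str, str, str]]:
    """List directed edges of a bidirectional ring; node i has IP 10.0.0.{i+1}."""
    edges = []
    if len(nodes) > 1:
        for i in range(len(nodes)):
            j = (i + 1) % len(nodes)
            # Forward edge i -> j addresses the sink j.
            edges.append((nodes[i], nodes[j], f"/ip4/10.0.0.{j + 1}/tcp/50000"))
            # Reverse edge j -> i addresses the sink i.
            edges.append((nodes[j], nodes[i], f"/ip4/10.0.0.{i + 1}/tcp/50000"))
    return edges
```

With exactly two nodes the loop visits the same pair twice, so each directed edge is emitted twice. Whether `Topology.add_connection` deduplicates such repeats is not shown in this diff.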
def _meta_instance(
model_id: str = "test-org/test-model",
*,
min_nodes: int = 1,
node_ids: list[NodeId] | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> MetaInstance:
return MetaInstance(
meta_instance_id=meta_instance_id or MetaInstanceId(),
model_id=ModelId(model_id),
min_nodes=min_nodes,
node_ids=node_ids,
)
def _instance(
model_id: str = "test-org/test-model",
node_ids: list[str] | None = None,
instance_id: InstanceId | None = None,
meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
"""Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
iid = instance_id or InstanceId()
nodes = node_ids or ["node-a"]
n = len(nodes)
mc = _model_card(model_id)
ephemeral_port = 50000
node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
runner_to_shard = {
runner_id: PipelineShardMetadata(
model_card=mc,
device_rank=i,
world_size=n,
start_layer=0,
end_layer=mc.n_layers,
n_layers=mc.n_layers,
)
for i, runner_id in enumerate(node_to_runner.values())
}
# Build hosts_by_node with IPs matching _topology() convention:
# node at index idx has IP 10.0.0.{idx+1}
hosts_by_node: dict[NodeId, list[Host]] = {}
for r, node_str in enumerate(nodes):
hosts: list[Host] = []
for idx in range(n):
if idx == r:
hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
else:
hosts.append(Host(ip="198.51.100.1", port=0))
hosts_by_node[NodeId(node_str)] = hosts
return iid, MlxRingInstance(
instance_id=iid,
shard_assignments=ShardAssignments(
model_id=ModelId(model_id),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
),
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
meta_instance_id=meta_instance_id,
)
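Review note: the `hosts_by_node` layout in `_instance()` is easier to see for a single rank. The sketch below reproduces the per-rank host list under the same convention; `hosts_for` is a hypothetical helper and hosts are `(ip, port)` tuples instead of `Host` models.

```python
from typing import List, Tuple


def hosts_for(rank: int, n: int, port: int = 50000) -> List[Tuple[str, int]]:
    """Host list seen by one rank: self, ring neighbours, placeholders."""
    hosts = []
    for idx in range(n):
        if idx == rank:
            # The node's own slot is the wildcard listen address.
            hosts.append(("0.0.0.0", port))
        elif n > 1 and idx in ((rank - 1) % n, (rank + 1) % n):
            # Ring neighbours get the IP that _topology() assigns them.
            hosts.append((f"10.0.0.{idx + 1}", port))
        else:
            # Non-neighbours get a placeholder address with port 0.
            hosts.append(("198.51.100.1", 0))
    return hosts
```

For rank 0 in a four-node ring this gives the wildcard address, `10.0.0.2`, a placeholder, and `10.0.0.4`, so only the two ring neighbours carry addresses that the health check can compare against topology edges.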
# --- instance_satisfies_meta_instance (pure constraint matching) ---
def test_satisfies_matching_model():
meta = _meta_instance()
_, inst = _instance(node_ids=["node-a"])
assert instance_satisfies_meta_instance(meta, inst) is True
def test_not_satisfies_wrong_model():
meta = _meta_instance("test-org/model-a")
_, inst = _instance("test-org/model-b")
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_missing_required_node():
meta = _meta_instance(node_ids=[NodeId("node-c")])
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_not_satisfies_fewer_than_min_nodes():
meta = _meta_instance(min_nodes=3)
_, inst = _instance(node_ids=["node-a", "node-b"])
assert instance_satisfies_meta_instance(meta, inst) is False
def test_satisfies_with_node_ids_specified():
meta = _meta_instance(
node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2
)
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
assert instance_satisfies_meta_instance(meta, inst) is True
# --- instance_connections_healthy ---
def test_healthy_single_node_present():
_, inst = _instance(node_ids=["node-a"])
topology = _topology("node-a")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_single_node_missing():
_, inst = _instance(node_ids=["node-a"])
topology = Topology() # empty
assert instance_connections_healthy(inst, topology) is False
def test_healthy_two_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_two_node_edge_removed():
"""Nodes present but edge removed — ring broken."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", connect=False)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_two_node_ip_changed():
"""Edge exists but with a different IP than instance was configured with."""
_, inst = _instance(node_ids=["node-a", "node-b"])
# Build topology with different IPs than _instance() expects
topology = Topology()
topology.add_node(NodeId("node-a"))
topology.add_node(NodeId("node-b"))
topology.add_connection(
Connection(
source=NodeId("node-a"),
sink=NodeId("node-b"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=NodeId("node-b"),
sink=NodeId("node-a"),
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000")
),
)
)
assert instance_connections_healthy(inst, topology) is False
def test_healthy_three_node_ring():
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
def test_unhealthy_three_node_one_edge_removed():
"""Remove one edge from a three-node ring — instance unhealthy."""
_, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
# Build topology with one direction of one edge missing
topology = Topology()
nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")]
for n in nodes:
topology.add_node(n)
# Add all edges except node-a → node-b
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[1],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[1],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[2],
sink=nodes[0],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
),
)
)
topology.add_connection(
Connection(
source=nodes[0],
sink=nodes[2],
edge=SocketConnection(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
),
)
)
# Missing: node-a → node-b (ip 10.0.0.2)
assert instance_connections_healthy(inst, topology) is False
def test_unhealthy_node_missing_from_topology():
"""Instance has a node that's not in the topology at all."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a") # node-b not present
assert instance_connections_healthy(inst, topology) is False
def test_healthy_extra_nodes_in_topology():
"""Extra nodes in topology don't affect instance health."""
_, inst = _instance(node_ids=["node-a", "node-b"])
topology = _topology("node-a", "node-b", "node-c")
assert instance_connections_healthy(inst, topology) is True
# --- find_unsatisfied_meta_instances ---
def test_unsatisfied_no_meta_instances():
result = find_unsatisfied_meta_instances({}, {}, Topology())
assert list(result) == []
def test_unsatisfied_one_satisfied():
meta = _meta_instance()
id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == []
def test_unsatisfied_one_not_satisfied():
meta = _meta_instance("test-org/model-x")
id_a, inst_a = _instance("test-org/model-y")
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta}, {id_a: inst_a}, topology
)
assert list(result) == [meta]
def test_unsatisfied_mix():
meta_satisfied = _meta_instance("test-org/model-a")
meta_unsatisfied = _meta_instance("test-org/model-b")
id_a, inst_a = _instance(
"test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_satisfied.meta_instance_id: meta_satisfied,
meta_unsatisfied.meta_instance_id: meta_unsatisfied,
},
{id_a: inst_a},
topology,
)
assert list(result) == [meta_unsatisfied]
def test_unsatisfied_node_disconnect():
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a") # node-b disconnected
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_edge_break():
"""Instance exists but its connections broke — meta-instance becomes unsatisfied."""
meta = _meta_instance()
id_a, inst_a = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
topology = _topology("node-a", "node-b", connect=False) # nodes present, no edges
result = find_unsatisfied_meta_instances(
{meta.meta_instance_id: meta},
{id_a: inst_a},
topology,
)
assert list(result) == [meta]
def test_unsatisfied_idempotent():
meta = _meta_instance("test-org/model-x")
topology = _topology("node-a")
meta_instances = {meta.meta_instance_id: meta}
instances: dict[InstanceId, MlxRingInstance] = {}
result_1 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
result_2 = list(
find_unsatisfied_meta_instances(meta_instances, instances, topology)
)
assert result_1 == result_2
def test_unsatisfied_exclusive_binding():
"""Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied."""
meta_a = _meta_instance("test-org/model-x")
meta_b = _meta_instance("test-org/model-x")
id_inst, inst = _instance(
"test-org/model-x", meta_instance_id=meta_a.meta_instance_id
)
topology = _topology("node-a")
result = find_unsatisfied_meta_instances(
{
meta_a.meta_instance_id: meta_a,
meta_b.meta_instance_id: meta_b,
},
{id_inst: inst},
topology,
)
assert list(result) == [meta_b]
# --- apply handlers ---
def test_apply_meta_instance_created():
state = State()
meta = _meta_instance()
event = MetaInstanceCreated(meta_instance=meta)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id in new_state.meta_instances
assert new_state.meta_instances[meta.meta_instance_id] == meta
def test_apply_meta_instance_deleted():
meta = _meta_instance()
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
def test_apply_meta_instance_deleted_clears_failure_info():
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "OOM"}
)
state = State(meta_instances={meta.meta_instance_id: meta})
event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert meta.meta_instance_id not in new_state.meta_instances
# --- instance_runners_failed ---
def test_runners_failed_all_failed():
"""All runners in RunnerFailed -> instance is failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runners = {
rid: RunnerFailed(error_message="OOM")
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "OOM" in error
def test_runners_failed_mixed_failed_shutdown():
"""One Failed + one Shutdown = failed."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="crash"),
runner_ids[1]: RunnerShutdown(),
}
is_failed, error = instance_runners_failed(inst, runners, {})
assert is_failed is True
assert error is not None
assert "crash" in error
def test_runners_not_failed_all_shutdown():
"""All Shutdown (graceful) = not a failure."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerShutdown()
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_still_active():
"""Some runners still active = not failed yet."""
_, inst = _instance(node_ids=["node-a", "node-b"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerLoading(),
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
def test_runners_not_failed_no_status():
"""Runner not yet reported = not failed."""
_, inst = _instance(node_ids=["node-a"])
is_failed, _ = instance_runners_failed(inst, {}, {})
assert is_failed is False
def test_runners_not_failed_healthy():
"""Runners in Ready state = not failed."""
_, inst = _instance(node_ids=["node-a"])
runners = {
rid: RunnerReady()
for rid in inst.shard_assignments.node_to_runner.values()
}
is_failed, _ = instance_runners_failed(inst, runners, {})
assert is_failed is False
# --- failure tracking in apply_instance_deleted ---
def test_apply_instance_deleted_tracks_failure():
"""InstanceDeleted with failure_error increments meta instance failure count."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "Runner OOM"
def test_apply_instance_deleted_increments_failure():
"""Subsequent failures increment the counter."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 2, "last_failure_error": "previous error"}
)
iid, inst = _instance(
node_ids=["node-a"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid, failure_error="new error")
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 3
assert mi.last_failure_error == "new error"
def test_apply_instance_deleted_no_failure_no_tracking():
"""InstanceDeleted without failure_error does not track."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceDeleted(instance_id=iid)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
def test_apply_instance_deleted_orphan_no_tracking():
"""InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
iid, inst = _instance(node_ids=["node-a"])
state = State(instances={iid: inst})
event = InstanceDeleted(instance_id=iid, failure_error="crash")
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert len(new_state.meta_instances) == 0
# --- InstanceRetrying ---
def test_apply_instance_retrying_removes_runners():
"""InstanceRetrying removes the instance's runners from state but keeps the instance."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
runners = {
runner_ids[0]: RunnerFailed(error_message="OOM"),
runner_ids[1]: RunnerShutdown(),
}
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners=runners,
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="OOM",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
# Instance still exists
assert iid in new_state.instances
# Runners removed
assert runner_ids[0] not in new_state.runners
assert runner_ids[1] not in new_state.runners
def test_apply_instance_retrying_increments_failure():
"""InstanceRetrying increments consecutive_failures on the MetaInstance."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 1
assert mi.last_failure_error == "crash"
def test_apply_instance_retrying_skips_missing_runners():
"""InstanceRetrying does not raise when the instance's runners are absent from state."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
# No runners in state at all
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
)
event = InstanceRetrying(
instance_id=iid,
meta_instance_id=meta.meta_instance_id,
failure_error="crash",
)
# Should not raise
new_state = apply(state, IndexedEvent(idx=0, event=event))
assert iid in new_state.instances
def test_apply_instance_created_resets_failure_counter():
"""InstanceCreated resets consecutive_failures but preserves last_failure_error."""
meta = _meta_instance().model_copy(
update={"consecutive_failures": 3, "last_failure_error": "old error"}
)
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
state = State(meta_instances={meta.meta_instance_id: meta})
event = InstanceCreated(instance=inst)
new_state = apply(state, IndexedEvent(idx=0, event=event))
mi = new_state.meta_instances[meta.meta_instance_id]
assert mi.consecutive_failures == 0
assert mi.last_failure_error == "old error"
assert mi.placement_error is None
# --- InstanceHealthReconciler retry-vs-delete ---
async def test_health_reconciler_retries_when_under_limit():
"""InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
meta = _meta_instance()
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceRetrying)
assert events[0].instance_id == iid
assert events[0].meta_instance_id == meta.meta_instance_id
async def test_health_reconciler_deletes_when_limit_reached():
"""InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_deletes_without_meta_instance():
"""Instances without a MetaInstance are deleted immediately on runner failure."""
iid, inst = _instance(node_ids=["node-a"])
runner_ids = list(inst.shard_assignments.node_to_runner.values())
state = State(
instances={iid: inst},
runners={runner_ids[0]: RunnerFailed(error_message="crash")},
topology=_topology("node-a"),
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
async def test_health_reconciler_network_failure_always_deletes():
"""Network failure always triggers InstanceDeleted regardless of retry count."""
meta = _meta_instance()
iid, inst = _instance(
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
)
state = State(
meta_instances={meta.meta_instance_id: meta},
instances={iid: inst},
topology=_topology("node-a"), # node-b missing
)
reconciler = InstanceHealthReconciler()
events = await reconciler.reconcile(state)
assert len(events) == 1
assert isinstance(events[0], InstanceDeleted)
assert events[0].failure_error == "Network connection lost"
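Review note: the reconciler implementation is not part of this section, so the decision table below is inferred only from the four tests above and may differ from the real `InstanceHealthReconciler`. `decide` and `MAX_RETRIES` are hypothetical names; `consecutive_failures` is `None` for an instance with no MetaInstance.

```python
from typing import Optional

MAX_RETRIES = 3  # assumed from the test docstrings ("consecutive_failures < 3")


def decide(
    connections_healthy: bool,
    runners_failed: bool,
    consecutive_failures: Optional[int],
) -> str:
    """Return 'keep', 'retry', or 'delete' for one instance."""
    if not connections_healthy:
        # Network loss is never retried in place.
        return "delete"
    if not runners_failed:
        return "keep"
    if consecutive_failures is None:
        # No MetaInstance to track retries against.
        return "delete"
    return "retry" if consecutive_failures < MAX_RETRIES else "delete"
```

The ordering matters: checking connectivity first is what makes a network failure delete the instance even when the retry budget is untouched.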


@@ -4,7 +4,7 @@ from datetime import datetime
from loguru import logger
-from exo.shared.types.common import NodeId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -12,6 +12,10 @@ from exo.shared.types.events import (
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
InstanceRetrying,
MetaInstanceCreated,
MetaInstanceDeleted,
MetaInstancePlacementFailed,
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
@@ -28,6 +32,7 @@ from exo.shared.types.events import (
TracesCollected,
TracesMerged,
)
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
@@ -72,6 +77,14 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceRetrying():
return apply_instance_retrying(event, state)
case MetaInstanceCreated():
return apply_meta_instance_created(event, state)
case MetaInstanceDeleted():
return apply_meta_instance_deleted(event, state)
case MetaInstancePlacementFailed():
return apply_meta_instance_placement_failed(event, state)
case NodeTimedOut():
return apply_node_timed_out(event, state)
case NodeDownloadProgress():
@@ -174,20 +187,119 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
return state.model_copy(update={"tasks": new_tasks})
def _update_meta_instance(
state: State, mid: MetaInstanceId, **fields: object
) -> Mapping[MetaInstanceId, MetaInstance]:
mi = state.meta_instances[mid]
return {**state.meta_instances, mid: mi.model_copy(update=fields)}
def apply_instance_created(event: InstanceCreated, state: State) -> State:
instance = event.instance
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
instance.instance_id: instance,
}
-return state.model_copy(update={"instances": new_instances})
update: dict[str, object] = {"instances": new_instances}
# Reset failure tracking when a new instance is created for a meta-instance
if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
mi = state.meta_instances[instance.meta_instance_id]
if mi.placement_error is not None or mi.consecutive_failures > 0:
update["meta_instances"] = _update_meta_instance(
state,
instance.meta_instance_id,
placement_error=None,
consecutive_failures=0,
)
return state.model_copy(update=update)
def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
deleted_instance = state.instances.get(event.instance_id)
new_instances: Mapping[InstanceId, Instance] = {
iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
}
-return state.model_copy(update={"instances": new_instances})
update: dict[str, object] = {"instances": new_instances}
# Track failure on the MetaInstance itself
if (
event.failure_error
and deleted_instance
and deleted_instance.meta_instance_id
and deleted_instance.meta_instance_id in state.meta_instances
):
mid = deleted_instance.meta_instance_id
mi = state.meta_instances[mid]
update["meta_instances"] = {
**state.meta_instances,
mid: mi.model_copy(
update={
"consecutive_failures": mi.consecutive_failures + 1,
"last_failure_error": event.failure_error,
}
),
}
return state.model_copy(update=update)
def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
"""Runners failed but retry limit not reached — remove runners, keep instance."""
instance = state.instances.get(event.instance_id)
if instance is None:
return state
# Remove all runners belonging to this instance from state
runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
new_runners: Mapping[RunnerId, RunnerStatus] = {
rid: rs
for rid, rs in state.runners.items()
if rid not in runner_ids_to_remove
}
update: dict[str, object] = {"runners": new_runners}
# Increment failure count on the MetaInstance
if event.meta_instance_id in state.meta_instances:
update["meta_instances"] = _update_meta_instance(
state,
event.meta_instance_id,
consecutive_failures=state.meta_instances[event.meta_instance_id].consecutive_failures + 1,
last_failure_error=event.failure_error,
)
return state.model_copy(update=update)
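Review note: the failure-counter lifecycle spread across `apply_instance_created` and `apply_instance_retrying` can be summarised on a small immutable record. This is a sketch with hypothetical names (`Meta`, `on_retrying`, `on_created`), using `dataclasses.replace` where the handlers above use pydantic's `model_copy(update=...)`.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class Meta:  # hypothetical stand-in for MetaInstance failure fields
    consecutive_failures: int = 0
    last_failure_error: Optional[str] = None
    placement_error: Optional[str] = None


def on_retrying(meta: Meta, error: str) -> Meta:
    """A retry bumps the counter and records the latest error."""
    return replace(
        meta,
        consecutive_failures=meta.consecutive_failures + 1,
        last_failure_error=error,
    )


def on_created(meta: Meta) -> Meta:
    """A fresh instance resets the counter but keeps the last error text."""
    return replace(meta, consecutive_failures=0, placement_error=None)
```

Keeping `last_failure_error` across `on_created` corresponds to the commit message at the top of the PR: the error context stays visible while the replacement instance starts up.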
def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
**state.meta_instances,
event.meta_instance.meta_instance_id: event.meta_instance,
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
new_meta: Mapping[MetaInstanceId, MetaInstance] = {
mid: mi
for mid, mi in state.meta_instances.items()
if mid != event.meta_instance_id
}
return state.model_copy(update={"meta_instances": new_meta})
def apply_meta_instance_placement_failed(
event: MetaInstancePlacementFailed, state: State
) -> State:
if event.meta_instance_id not in state.meta_instances:
return state
return state.model_copy(
update={
"meta_instances": _update_meta_instance(
state, event.meta_instance_id, placement_error=event.reason
)
}
)
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:


@@ -3,11 +3,10 @@ from collections.abc import Generator
from typing import Annotated, Any, Literal
from uuid import uuid4
-from pydantic import BaseModel, Field, field_validator
-from pydantic_core import PydanticUseDefault
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelCard, ModelId
-from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -227,13 +226,6 @@ class PlaceInstanceParams(BaseModel):
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
-@field_validator("sharding", "instance_meta", mode="plain")
-@classmethod
-def use_default(cls, v: object):
-if not v or not isinstance(v, (Sharding, InstanceMeta)):
-raise PydanticUseDefault()
-return v
class CreateInstanceParams(BaseModel):
instance: Instance
@@ -269,6 +261,26 @@ class DeleteInstanceResponse(BaseModel):
instance_id: InstanceId
+class CreateMetaInstanceParams(BaseModel):
+model_id: ModelId
+sharding: Sharding = Sharding.Pipeline
+instance_meta: InstanceMeta = InstanceMeta.MlxRing
+min_nodes: int = 1
+node_ids: list[NodeId] | None = None
+class CreateMetaInstanceResponse(BaseModel):
+message: str
+command_id: CommandId
+meta_instance_id: MetaInstanceId
+class DeleteMetaInstanceResponse(BaseModel):
+message: str
+command_id: CommandId
+meta_instance_id: MetaInstanceId
class AdvancedImageParams(BaseModel):
seed: Annotated[int, Field(ge=0)] | None = None
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None

@@ -6,7 +6,8 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
-from exo.shared.types.common import CommandId, NodeId
+from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
+from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -48,6 +49,14 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
+class CreateMetaInstance(BaseCommand):
+meta_instance: MetaInstance
+class DeleteMetaInstance(BaseCommand):
+meta_instance_id: MetaInstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -89,6 +98,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
+| CreateMetaInstance
+| DeleteMetaInstance
| TaskFinished
| SendInputChunk
)

@@ -42,6 +42,10 @@ class CommandId(Id):
pass
+class MetaInstanceId(Id):
+"""Identifier for a MetaInstance."""
class Host(CamelCaseModel):
ip: str
port: int

@@ -5,7 +5,8 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
-from exo.shared.types.common import CommandId, Id, NodeId, SessionId
+from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
+from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -66,6 +67,30 @@ class InstanceCreated(BaseEvent):
class InstanceDeleted(BaseEvent):
instance_id: InstanceId
+failure_error: str | None = None
+class MetaInstanceCreated(BaseEvent):
+meta_instance: MetaInstance
+class MetaInstanceDeleted(BaseEvent):
+meta_instance_id: MetaInstanceId
+@final
+class MetaInstancePlacementFailed(BaseEvent):
+meta_instance_id: MetaInstanceId
+reason: str
+@final
+class InstanceRetrying(BaseEvent):
+"""Runners failed but retry count is below the limit — restart runners, keep instance."""
+instance_id: InstanceId
+meta_instance_id: MetaInstanceId
+failure_error: str
class RunnerStatusUpdated(BaseEvent):
@@ -141,6 +166,10 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
+| InstanceRetrying
+| MetaInstanceCreated
+| MetaInstanceDeleted
+| MetaInstancePlacementFailed
| RunnerStatusUpdated
| RunnerDeleted
| NodeTimedOut
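
Reviewer sketch of how `InstanceRetrying` and `InstanceDeleted` might be chosen between. The commit message says runners are retried up to 3 times within the same Instance before the Instance is deleted. The event classes, `MAX_RETRIES`, and `on_runners_failed` below are hypothetical names, and the exact boundary (whether deletion happens on the third or the fourth failure) is an assumption, since the master-side logic is not part of this hunk.

```python
from dataclasses import dataclass
from typing import Optional, Union

MAX_RETRIES = 3  # Assumed value, taken from "up to 3 times" in the commit message.


@dataclass(frozen=True)
class InstanceRetryingSketch:
    """Hypothetical stand-in for the InstanceRetrying event."""

    instance_id: str
    meta_instance_id: str
    failure_error: str


@dataclass(frozen=True)
class InstanceDeletedSketch:
    """Hypothetical stand-in for InstanceDeleted with its new failure_error field."""

    instance_id: str
    failure_error: Optional[str] = None


def on_runners_failed(
    instance_id: str, meta_instance_id: str, consecutive_failures: int, error: str
) -> Union[InstanceRetryingSketch, InstanceDeletedSketch]:
    # consecutive_failures counts failures recorded before this one (assumption).
    if consecutive_failures < MAX_RETRIES:
        # Below the limit: keep the instance and ask for the runners to restart.
        return InstanceRetryingSketch(instance_id, meta_instance_id, error)
    # Limit reached: delete the instance so a fresh one can be placed.
    return InstanceDeletedSketch(instance_id, error)


events = [type(on_runners_failed("i1", "m1", n, "boom")).__name__ for n in range(4)]
print(events)
```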

@@ -0,0 +1,25 @@
+from typing import final
+from pydantic import Field
+from exo.shared.models.model_cards import ModelId
+from exo.shared.types.common import MetaInstanceId, NodeId
+from exo.shared.types.worker.instances import InstanceMeta
+from exo.shared.types.worker.shards import Sharding
+from exo.utils.pydantic_ext import FrozenModel
+@final
+class MetaInstance(FrozenModel):
+"""Declarative constraint: ensure an instance matching these parameters always exists."""
+meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
+model_id: ModelId
+sharding: Sharding = Sharding.Pipeline
+instance_meta: InstanceMeta = InstanceMeta.MlxRing
+min_nodes: int = 1
+node_ids: list[NodeId] | None = None
+# Failure tracking
+placement_error: str | None = None
+consecutive_failures: int = 0
+last_failure_error: str | None = None
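
Reviewer sketch of how the failure-tracking fields might evolve. It uses a frozen stdlib dataclass as a stand-in for `FrozenModel`. `record_retry` and `after_recreation` are hypothetical helpers. The reset-on-recreation behaviour is inferred from the latest commit message (`last_failure_error` is preserved across recreation, and the dashboard sees `consecutiveFailures` at 0), not from code shown here.

```python
from dataclasses import dataclass, field, replace
from typing import Optional
from uuid import uuid4


@dataclass(frozen=True)
class MetaInstanceSketch:
    """Hypothetical stand-in for MetaInstance, reduced to the tracking fields."""

    model_id: str
    meta_instance_id: str = field(default_factory=lambda: str(uuid4()))
    placement_error: Optional[str] = None
    consecutive_failures: int = 0
    last_failure_error: Optional[str] = None


def record_retry(mi: MetaInstanceSketch, error: str) -> MetaInstanceSketch:
    # Each retry bumps the counter and remembers the most recent error.
    return replace(
        mi,
        consecutive_failures=mi.consecutive_failures + 1,
        last_failure_error=error,
    )


def after_recreation(mi: MetaInstanceSketch) -> MetaInstanceSketch:
    # Assumed: the counter resets for the fresh instance, the error text is kept.
    return replace(mi, consecutive_failures=0)


mi = MetaInstanceSketch(model_id="example-model")
failed = record_retry(record_retry(mi, "first"), "second")
fresh = after_recreation(failed)
print(failed.consecutive_failures, fresh.consecutive_failures, fresh.last_failure_error)
```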

@@ -6,7 +6,8 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
-from exo.shared.types.common import NodeId
+from exo.shared.types.common import MetaInstanceId, NodeId
+from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
DiskUsage,
MemoryUsage,
@@ -41,6 +42,7 @@ class State(CamelCaseModel):
arbitrary_types_allowed=True,
)
instances: Mapping[InstanceId, Instance] = {}
+meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import model_validator
-from exo.shared.types.common import Host, Id, NodeId
+from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +19,7 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
+meta_instance_id: MetaInstanceId | None = None
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

@@ -34,6 +34,7 @@ from exo.shared.types.worker.runners import (
RunnerLoading,
RunnerReady,
RunnerRunning,
+RunnerShutdown,
RunnerStatus,
RunnerWarmingUp,
)
@@ -54,7 +55,7 @@ def plan(
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
-or _create_runner(node_id, runners, instances)
+or _create_runner(node_id, runners, instances, all_runners)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
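
Reviewer sketch of the short-circuit chain the comment in `plan` relies on: each step runs only if every earlier step returned a falsy value, so the first step that yields a task wins. The three functions below are hypothetical placeholders, not the project's planners.

```python
from typing import List, Optional

calls: List[str] = []


def kill_runner() -> Optional[str]:
    calls.append("kill")
    return None  # Nothing to kill, so evaluation continues.


def create_runner() -> Optional[str]:
    calls.append("create")
    return "CreateRunner"  # First truthy result stops the chain.


def load_model() -> Optional[str]:
    calls.append("load")
    return "LoadModel"


def plan_sketch() -> Optional[str]:
    # `or` evaluates left to right and returns the first truthy operand.
    return kill_runner() or create_runner() or load_model()


task = plan_sketch()
print(task, calls)
```

Note that `or` tests truthiness rather than `is not None`, so the pattern only works if every task object is truthy.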
@@ -73,6 +74,12 @@ def _kill_runner(
if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
return Shutdown(instance_id=instance_id, runner_id=runner_id)
+# Master removed our runner from state (retry signal) and process is dead
+if runner_id not in all_runners and isinstance(
+runner.status, (RunnerFailed, RunnerShutdown)
+):
+return Shutdown(instance_id=instance_id, runner_id=runner_id)
for (
global_runner_id
) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
@@ -90,6 +97,7 @@ def _create_runner(
node_id: NodeId,
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
+all_runners: Mapping[RunnerId, RunnerStatus],
) -> CreateRunner | None:
for instance in instances.values():
runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
@@ -99,6 +107,16 @@ def _create_runner(
if runner_id in runners:
continue
+# Don't create while any peer runner is in a terminal state — wait for
+# the master to emit InstanceRetrying which removes them from state.
+has_terminal_peer = any(
+isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown))
+for peer_rid in instance.shard_assignments.node_to_runner.values()
+if peer_rid != runner_id
+)
+if has_terminal_peer:
+continue
shard = instance.shard(runner_id)
assert shard is not None