Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-12 07:01:06 -05:00)

Compare commits: 28 commits on main, authored by alexcheema.

f9ffdaef5f, 8c2416c9ea, e5007f619a, a627f67253, f189222bfc, ad6d35d68a, c236d62caf,
a8069e8a30, 84ce555d55, b78ea438bc, 1960b16f9f, c6838c8fd8, 420d9b9e76, 13f1e9c489,
451a06b3d8, 94b55d66f4, 2b68b931c5, 4aecaa7748, 25e2891c30, 16345e0ffa, 3a845f90b0,
dccf2440ba, f96f3f2c0f, 7d54e468d5, 124d504f95, 9ab4a40989, f4329c72c2, ceb76b8f6c
@@ -185,11 +185,7 @@

   let instanceType: string | null = null;
   if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
-  else if (
-    instanceTag === "MlxIbvInstance" ||
-    instanceTag === "MlxJacclInstance"
-  )
-    instanceType = "MLX RDMA";
+  else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";

   let sharding: string | null = null;
   const inst = instance as {
@@ -21,7 +21,7 @@
   } | null;
   nodes?: Record<string, NodeInfo>;
   sharding?: "Pipeline" | "Tensor";
-  runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
+  runtime?: "MlxRing" | "MlxJaccl";
   onLaunch?: () => void;
   tags?: string[];
   apiPreview?: PlacementPreview | null;
@@ -348,7 +348,7 @@
   // Debug mode state
   const isDebugMode = $derived(debugMode());
   const topology = $derived(topologyData());
-  const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");
+  const isRdma = $derived(runtime === "MlxJaccl");

   // Get interface name for an IP from node data
   function getInterfaceForIp(nodeId: string, ip?: string): string | null {
@@ -575,7 +575,7 @@
   >
     {runtime === "MlxRing"
       ? "MLX Ring"
-      : runtime === "MlxIbv" || runtime === "MlxJaccl"
+      : runtime === "MlxJaccl"
         ? "MLX RDMA"
         : runtime}
   </span>
@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
 export interface PlacementPreview {
   model_id: string;
   sharding: "Pipeline" | "Tensor";
-  instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
+  instance_meta: "MlxRing" | "MlxJaccl";
   instance: unknown | null;
   memory_delta_by_node: Record<string, number> | null;
   error: string | null;
@@ -219,7 +219,6 @@ interface RawStateResponse {
     string,
     {
       MlxRingInstance?: Instance;
-      MlxIbvInstance?: Instance;
       MlxJacclInstance?: Instance;
     }
   >;
@@ -250,6 +249,20 @@ interface RawStateResponse {
   >;
   // Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
   thunderboltBridgeCycles?: string[][];
+  // MetaInstances (declarative instance constraints)
+  metaInstances?: Record<string, MetaInstanceData>;
 }
+
+export interface MetaInstanceData {
+  metaInstanceId: string;
+  modelId: string;
+  sharding: string;
+  instanceMeta: string;
+  minNodes: number;
+  nodeIds: string[] | null;
+  placementError: string | null;
+  consecutiveFailures: number;
+  lastFailureError: string | null;
+}

 export interface MessageAttachment {
@@ -535,6 +548,7 @@ class AppStore {
   isLoadingPreviews = $state(false);
   previewNodeFilter = $state<Set<string>>(new Set());
   lastUpdate = $state<number | null>(null);
+  metaInstances = $state<Record<string, MetaInstanceData>>({});
   nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
   thunderboltBridgeCycles = $state<string[][]>([]);
   nodeThunderbolt = $state<
@@ -891,11 +905,7 @@ class AppStore {

       let instanceType: string | null = null;
       if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
-      else if (
-        instanceTag === "MlxIbvInstance" ||
-        instanceTag === "MlxJacclInstance"
-      )
-        instanceType = "MLX RDMA";
+      else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";

       let sharding: string | null = null;
       const inst = instance as {
@@ -1260,6 +1270,8 @@ class AppStore {
     if (data.downloads) {
       this.downloads = data.downloads;
     }
+    // MetaInstances
+    this.metaInstances = data.metaInstances ?? {};
     if (data.nodeDisk) {
       this.nodeDisk = data.nodeDisk;
     }
@@ -3019,6 +3031,7 @@ export const tps = () => appStore.tps;
 export const totalTokens = () => appStore.totalTokens;
 export const topologyData = () => appStore.topologyData;
 export const instances = () => appStore.instances;
+export const metaInstances = () => appStore.metaInstances;
 export const runners = () => appStore.runners;
 export const downloads = () => appStore.downloads;
 export const nodeDisk = () => appStore.nodeDisk;
@@ -42,6 +42,7 @@
   toggleTopologyOnlyMode,
   chatSidebarVisible,
   toggleChatSidebarVisible,
+  metaInstances,
   nodeThunderbolt,
   nodeRdmaCtl,
   thunderboltBridgeCycles,
@@ -49,6 +50,7 @@
   nodeIdentities,
   type DownloadProgress,
   type PlacementPreview,
+  type MetaInstanceData,
 } from "$lib/stores/app.svelte";
 import HeaderNav from "$lib/components/HeaderNav.svelte";
 import { fade, fly } from "svelte/transition";
@@ -68,7 +70,72 @@
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
const metaInstancesData = $derived(metaInstances());
|
||||
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
|
||||
|
||||
// Get status for a MetaInstance that has no backing instance yet
|
||||
function getMetaInstancePlacingStatus(metaInstanceId: string) {
|
||||
const meta = metaInstancesData[metaInstanceId];
|
||||
const placementError = meta?.placementError;
|
||||
const failures = meta?.consecutiveFailures ?? 0;
|
||||
const lastError = meta?.lastFailureError;
|
||||
|
||||
if (placementError) {
|
||||
return {
|
||||
statusText: "PLACEMENT FAILED",
|
||||
statusClass: "failed",
|
||||
isDownloading: false as const,
|
||||
isFailed: true,
|
||||
progress: null,
|
||||
perNode: [] as Array<{
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
progress: DownloadProgress;
|
||||
}>,
|
||||
perNodeStatus: [] as PerNodeRunnerStatus[],
|
||||
errorMessage: placementError,
|
||||
};
|
||||
}
|
||||
|
||||
if (failures > 0) {
|
||||
const retryPosition = ((failures - 1) % 3) + 1;
|
||||
const isRecreated = failures % 3 === 0;
|
||||
return {
|
||||
statusText: isRecreated ? "PLACING" : `RETRYING (${retryPosition}/3)`,
|
||||
statusClass: "starting",
|
||||
isDownloading: false as const,
|
||||
isFailed: false,
|
||||
progress: null,
|
||||
perNode: [] as Array<{
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
progress: DownloadProgress;
|
||||
}>,
|
||||
perNodeStatus: [] as PerNodeRunnerStatus[],
|
||||
errorMessage: isRecreated
|
||||
? `Instance re-created due to failure: ${lastError}`
|
||||
: lastError
|
||||
? `Previous failure: ${lastError}`
|
||||
: null,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
statusText: "PLACING",
|
||||
statusClass: "starting",
|
||||
isDownloading: false as const,
|
||||
isFailed: false,
|
||||
progress: null,
|
||||
perNode: [] as Array<{
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
progress: DownloadProgress;
|
||||
}>,
|
||||
perNodeStatus: [] as PerNodeRunnerStatus[],
|
||||
errorMessage: null,
|
||||
};
|
||||
}
|
||||
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
const identitiesData = $derived(nodeIdentities());
|
||||
const tbIdentifiers = $derived(nodeThunderbolt());
|
||||
@@ -114,6 +181,17 @@
|
||||
});
|
||||
let tb5InfoDismissed = $state(false);
|
||||
|
||||
// Detect [jaccl] RDMA driver errors from MetaInstance failure errors
|
||||
const jacclError = $derived.by(() => {
|
||||
for (const mi of Object.values(metaInstancesData)) {
|
||||
if (mi.lastFailureError?.includes("[jaccl]")) {
|
||||
return mi.lastFailureError;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
});
|
||||
let jacclDismissedError = $state<string | null>(null);
|
||||
|
||||
// Helper to get friendly node name from node ID
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -211,7 +289,7 @@
     return model.tasks.includes("ImageToImage");
   }
   let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline");
-  type InstanceMeta = "MlxRing" | "MlxIbv" | "MlxJaccl";
+  type InstanceMeta = "MlxRing" | "MlxJaccl";

   // Launch defaults persistence
   const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults";
@@ -468,7 +546,7 @@
   const matchesSelectedRuntime = (runtime: InstanceMeta): boolean =>
     selectedInstanceType === "MlxRing"
      ? runtime === "MlxRing"
-      : runtime === "MlxIbv" || runtime === "MlxJaccl";
+      : runtime === "MlxJaccl" || runtime === "MlxJaccl";

   // Helper to check if a model can be launched (has valid placement with >= minNodes)
   function canModelFit(modelId: string): boolean {
@@ -684,39 +762,30 @@
|
||||
launchingModelId = modelId;
|
||||
|
||||
try {
|
||||
// Use the specific preview if provided, otherwise fall back to filtered preview
|
||||
const preview = specificPreview ?? filteredPreview();
|
||||
|
||||
let instanceData: unknown;
|
||||
// Extract node IDs from the preview the user is seeing
|
||||
const previewNodeIds = preview?.memory_delta_by_node
|
||||
? Object.keys(preview.memory_delta_by_node)
|
||||
: nodeFilter.size > 0
|
||||
? Array.from(nodeFilter)
|
||||
: undefined;
|
||||
|
||||
if (preview?.instance) {
|
||||
// Use the instance from the preview
|
||||
instanceData = preview.instance;
|
||||
} else {
|
||||
// Fallback: GET placement from API
|
||||
const placementResponse = await fetch(
|
||||
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
|
||||
);
|
||||
|
||||
if (!placementResponse.ok) {
|
||||
const errorText = await placementResponse.text();
|
||||
console.error("Failed to get placement:", errorText);
|
||||
return;
|
||||
}
|
||||
|
||||
instanceData = await placementResponse.json();
|
||||
}
|
||||
|
||||
// POST the instance to create it
|
||||
const response = await fetch("/instance", {
|
||||
const response = await fetch("/meta_instance", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ instance: instanceData }),
|
||||
body: JSON.stringify({
|
||||
model_id: modelId,
|
||||
sharding: preview?.sharding ?? selectedSharding,
|
||||
instance_meta: preview?.instance_meta ?? selectedInstanceType,
|
||||
min_nodes: selectedMinNodes,
|
||||
node_ids: previewNodeIds,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
console.error("Failed to launch instance:", errorText);
|
||||
console.error("Failed to create meta instance:", errorText);
|
||||
} else {
|
||||
// Always auto-select the newly launched model so the user chats to what they just launched
|
||||
setSelectedChatModel(modelId);
|
||||
@@ -739,7 +808,7 @@
|
||||
setTimeout(scrollToBottom, 1000);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error launching instance:", error);
|
||||
console.error("Error creating meta instance:", error);
|
||||
} finally {
|
||||
launchingModelId = null;
|
||||
}
|
||||
@@ -941,15 +1010,18 @@
|
||||
nodeName: string;
|
||||
progress: DownloadProgress;
|
||||
}>;
|
||||
perNodeStatus: PerNodeRunnerStatus[];
|
||||
} {
|
||||
if (!downloadsData || Object.keys(downloadsData).length === 0) {
|
||||
const statusInfo = deriveInstanceStatus(instanceWrapped);
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
isFailed: statusInfo.statusText === "FAILED",
|
||||
errorMessage: statusInfo.errorMessage,
|
||||
progress: null,
|
||||
statusText: "RUNNING",
|
||||
statusText: statusInfo.statusText,
|
||||
perNode: [],
|
||||
perNodeStatus: statusInfo.perNodeStatus,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -963,6 +1035,7 @@
|
||||
progress: null,
|
||||
statusText: "PREPARING",
|
||||
perNode: [],
|
||||
perNodeStatus: [],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1031,6 +1104,7 @@
|
||||
progress: null,
|
||||
statusText: "FAILED",
|
||||
perNode: [],
|
||||
perNodeStatus: [],
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1071,10 +1145,11 @@
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: statusInfo.statusText === "FAILED",
|
||||
errorMessage: null,
|
||||
errorMessage: statusInfo.errorMessage,
|
||||
progress: null,
|
||||
statusText: statusInfo.statusText,
|
||||
perNode: [],
|
||||
perNodeStatus: statusInfo.perNodeStatus,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1098,92 +1173,172 @@
|
||||
},
|
||||
statusText: "DOWNLOADING",
|
||||
perNode,
|
||||
perNodeStatus: [],
|
||||
};
|
||||
}
|
||||
|
||||
// Derive instance status from runners
|
||||
// Get color class for a status
|
||||
function getStatusColor(statusText: string): string {
|
||||
switch (statusText) {
|
||||
case "FAILED":
|
||||
return "text-red-400";
|
||||
case "SHUTDOWN":
|
||||
return "text-gray-400";
|
||||
case "DOWNLOADING":
|
||||
return "text-blue-400";
|
||||
case "LOADING":
|
||||
case "WARMING UP":
|
||||
case "WAITING":
|
||||
case "INITIALIZING":
|
||||
return "text-yellow-400";
|
||||
case "RUNNING":
|
||||
return "text-teal-400";
|
||||
case "READY":
|
||||
case "LOADED":
|
||||
return "text-green-400";
|
||||
default:
|
||||
return "text-exo-light-gray";
|
||||
}
|
||||
if (statusText === "FAILED" || statusText === "PLACEMENT FAILED")
|
||||
return "text-red-400";
|
||||
if (statusText.startsWith("RETRYING")) return "text-orange-400";
|
||||
if (statusText === "SHUTDOWN") return "text-gray-400";
|
||||
if (statusText === "DOWNLOADING") return "text-blue-400";
|
||||
if (
|
||||
statusText.startsWith("LOADING") ||
|
||||
statusText.startsWith("WARMING UP") ||
|
||||
statusText === "WAITING" ||
|
||||
statusText === "INITIALIZING"
|
||||
)
|
||||
return "text-yellow-400";
|
||||
if (statusText === "RUNNING") return "text-teal-400";
|
||||
if (statusText === "READY" || statusText === "LOADED")
|
||||
return "text-green-400";
|
||||
return "text-exo-light-gray";
|
||||
}
|
||||
|
||||
const RUNNER_STATUS_MAP: Record<string, string> = {
|
||||
RunnerWaitingForInitialization: "WaitingForInitialization",
|
||||
RunnerInitializingBackend: "InitializingBackend",
|
||||
RunnerWaitingForModel: "WaitingForModel",
|
||||
RunnerLoading: "Loading",
|
||||
RunnerLoaded: "Loaded",
|
||||
RunnerWarmingUp: "WarmingUp",
|
||||
RunnerReady: "Ready",
|
||||
RunnerRunning: "Running",
|
||||
RunnerShutdown: "Shutdown",
|
||||
RunnerFailed: "Failed",
|
||||
};
|
||||
|
||||
// Friendly labels for display
|
||||
const RUNNER_STATUS_DISPLAY: Record<string, string> = {
|
||||
WaitingForInitialization: "Initializing",
|
||||
InitializingBackend: "Initializing",
|
||||
WaitingForModel: "Waiting",
|
||||
Loading: "Loading",
|
||||
Loaded: "Loaded",
|
||||
WarmingUp: "Warming Up",
|
||||
Ready: "Ready",
|
||||
Running: "Running",
|
||||
Shutdown: "Shutdown",
|
||||
Failed: "Failed",
|
||||
};
|
||||
|
||||
interface PerNodeRunnerStatus {
|
||||
nodeId: string;
|
||||
nodeName: string;
|
||||
status: string; // friendly display status
|
||||
}
|
||||
|
||||
function deriveInstanceStatus(instanceWrapped: unknown): {
|
||||
statusText: string;
|
||||
statusClass: string;
|
||||
perNodeStatus: PerNodeRunnerStatus[];
|
||||
errorMessage: string | null;
|
||||
} {
|
||||
const [, instance] = getTagged(instanceWrapped);
|
||||
if (!instance || typeof instance !== "object") {
|
||||
return { statusText: "PREPARING", statusClass: "inactive" };
|
||||
return {
|
||||
statusText: "PREPARING",
|
||||
statusClass: "inactive",
|
||||
perNodeStatus: [],
|
||||
errorMessage: null,
|
||||
};
|
||||
}
|
||||
|
||||
const inst = instance as {
|
||||
shardAssignments?: { runnerToShard?: Record<string, unknown> };
|
||||
shardAssignments?: {
|
||||
runnerToShard?: Record<string, unknown>;
|
||||
nodeToRunner?: Record<string, string>;
|
||||
};
|
||||
};
|
||||
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
|
||||
const runnerIds = Object.keys(inst.shardAssignments?.runnerToShard || {});
|
||||
const totalNodes = runnerIds.length;
|
||||
|
||||
const statuses = runnerIds
|
||||
.map((rid) => {
|
||||
const r = runnersData[rid];
|
||||
if (!r) return null;
|
||||
const [kind] = getTagged(r);
|
||||
const statusMap: Record<string, string> = {
|
||||
RunnerWaitingForInitialization: "WaitingForInitialization",
|
||||
RunnerInitializingBackend: "InitializingBackend",
|
||||
RunnerWaitingForModel: "WaitingForModel",
|
||||
RunnerLoading: "Loading",
|
||||
RunnerLoaded: "Loaded",
|
||||
RunnerWarmingUp: "WarmingUp",
|
||||
RunnerReady: "Ready",
|
||||
RunnerRunning: "Running",
|
||||
RunnerShutdown: "Shutdown",
|
||||
RunnerFailed: "Failed",
|
||||
};
|
||||
return kind ? statusMap[kind] || null : null;
|
||||
})
|
||||
.filter((s): s is string => s !== null);
|
||||
// Build per-node status and extract error messages from RunnerFailed
|
||||
const perNodeStatus: PerNodeRunnerStatus[] = [];
|
||||
const statuses: string[] = [];
|
||||
const failedErrors: string[] = [];
|
||||
for (const [nodeId, runnerId] of Object.entries(nodeToRunner)) {
|
||||
const r = runnersData[runnerId];
|
||||
let status: string | null = null;
|
||||
if (r) {
|
||||
const [kind, runnerData] = getTagged(r);
|
||||
status = kind ? RUNNER_STATUS_MAP[kind] || null : null;
|
||||
// Extract error message from RunnerFailed
|
||||
if (
|
||||
kind === "RunnerFailed" &&
|
||||
runnerData &&
|
||||
typeof runnerData === "object"
|
||||
) {
|
||||
const rd = runnerData as { errorMessage?: string };
|
||||
if (rd.errorMessage) failedErrors.push(`${getNodeName(nodeId)}: ${rd.errorMessage}`);
|
||||
}
|
||||
}
|
||||
if (status) {
|
||||
statuses.push(status);
|
||||
perNodeStatus.push({
|
||||
nodeId,
|
||||
nodeName: getNodeName(nodeId),
|
||||
status: RUNNER_STATUS_DISPLAY[status] || status,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const has = (s: string) => statuses.includes(s);
|
||||
const count = (s: string) => statuses.filter((v) => v === s).length;
|
||||
|
||||
if (statuses.length === 0)
|
||||
return { statusText: "PREPARING", statusClass: "inactive" };
|
||||
if (has("Failed")) return { statusText: "FAILED", statusClass: "failed" };
|
||||
return {
|
||||
statusText: "PREPARING",
|
||||
statusClass: "inactive",
|
||||
perNodeStatus,
|
||||
errorMessage: null,
|
||||
};
|
||||
if (has("Failed"))
|
||||
return {
|
||||
statusText: "FAILED",
|
||||
statusClass: "failed",
|
||||
perNodeStatus,
|
||||
errorMessage: failedErrors.length > 0 ? failedErrors.join("; ") : null,
|
||||
};
|
||||
if (has("Shutdown"))
|
||||
return { statusText: "SHUTDOWN", statusClass: "inactive" };
|
||||
if (has("Loading"))
|
||||
return { statusText: "LOADING", statusClass: "starting" };
|
||||
if (has("WarmingUp"))
|
||||
return { statusText: "WARMING UP", statusClass: "starting" };
|
||||
if (has("Running"))
|
||||
return { statusText: "RUNNING", statusClass: "running" };
|
||||
if (has("Ready")) return { statusText: "READY", statusClass: "loaded" };
|
||||
if (has("Loaded")) return { statusText: "LOADED", statusClass: "loaded" };
|
||||
if (has("WaitingForModel"))
|
||||
return { statusText: "WAITING", statusClass: "starting" };
|
||||
if (has("InitializingBackend"))
|
||||
return { statusText: "INITIALIZING", statusClass: "starting" };
|
||||
if (has("WaitingForInitialization"))
|
||||
return { statusText: "INITIALIZING", statusClass: "starting" };
|
||||
return { statusText: "SHUTDOWN", statusClass: "inactive", perNodeStatus, errorMessage: null };
|
||||
|
||||
return { statusText: "RUNNING", statusClass: "active" };
|
||||
// For loading/warming states, show node progress when multi-node
|
||||
if (has("Loading")) {
|
||||
const readyCount = count("Ready") + count("Running") + count("Loaded");
|
||||
const statusText =
|
||||
totalNodes > 1
|
||||
? `LOADING (${readyCount}/${totalNodes} nodes ready)`
|
||||
: "LOADING";
|
||||
return { statusText, statusClass: "starting", perNodeStatus, errorMessage: null };
|
||||
}
|
||||
if (has("WarmingUp")) {
|
||||
const readyCount = count("Ready") + count("Running");
|
||||
const statusText =
|
||||
totalNodes > 1
|
||||
? `WARMING UP (${readyCount}/${totalNodes} nodes ready)`
|
||||
: "WARMING UP";
|
||||
return { statusText, statusClass: "starting", perNodeStatus, errorMessage: null };
|
||||
}
|
||||
|
||||
if (has("Running"))
|
||||
return { statusText: "RUNNING", statusClass: "running", perNodeStatus, errorMessage: null };
|
||||
if (has("Ready"))
|
||||
return { statusText: "READY", statusClass: "loaded", perNodeStatus, errorMessage: null };
|
||||
if (has("Loaded"))
|
||||
return { statusText: "LOADED", statusClass: "loaded", perNodeStatus, errorMessage: null };
|
||||
if (has("WaitingForModel"))
|
||||
return { statusText: "WAITING", statusClass: "starting", perNodeStatus, errorMessage: null };
|
||||
if (has("InitializingBackend"))
|
||||
return { statusText: "INITIALIZING", statusClass: "starting", perNodeStatus, errorMessage: null };
|
||||
if (has("WaitingForInitialization"))
|
||||
return { statusText: "INITIALIZING", statusClass: "starting", perNodeStatus, errorMessage: null };
|
||||
|
||||
return { statusText: "RUNNING", statusClass: "active", perNodeStatus, errorMessage: null };
|
||||
}
|
||||
|
||||
function getBytes(value: unknown): number {
|
||||
@@ -1242,6 +1397,75 @@
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteMetaInstance(metaInstanceId: string) {
|
||||
const meta = metaInstancesData[metaInstanceId];
|
||||
const modelId = meta?.modelId ?? "unknown";
|
||||
if (!confirm(`Delete model ${modelId}?`)) return;
|
||||
|
||||
const wasSelected = selectedChatModel() === modelId;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/meta_instance/${metaInstanceId}`, {
|
||||
method: "DELETE",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.error("Failed to delete meta instance:", response.status);
|
||||
} else if (wasSelected) {
|
||||
// Switch to another available model or clear selection
|
||||
const remainingInstances = Object.entries(instanceData).filter(
|
||||
([id]) => id !== getBackingInstanceId(metaInstanceId),
|
||||
);
|
||||
if (remainingInstances.length > 0) {
|
||||
const [, lastInstance] =
|
||||
remainingInstances[remainingInstances.length - 1];
|
||||
const newModelId = getInstanceModelId(lastInstance);
|
||||
if (
|
||||
newModelId &&
|
||||
newModelId !== "Unknown" &&
|
||||
newModelId !== "Unknown Model"
|
||||
) {
|
||||
setSelectedChatModel(newModelId);
|
||||
} else {
|
||||
setSelectedChatModel("");
|
||||
}
|
||||
} else {
|
||||
setSelectedChatModel("");
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error deleting meta instance:", error);
|
||||
}
|
||||
}
|
||||
|
||||
// Find the backing Instance ID for a MetaInstance by scanning instances
|
||||
function getBackingInstanceId(metaInstanceId: string): string | null {
|
||||
for (const [id, inst] of Object.entries(instanceData)) {
|
||||
const [, inner] = getTagged(inst);
|
||||
if (
|
||||
inner &&
|
||||
typeof inner === "object" &&
|
||||
(inner as Record<string, unknown>).metaInstanceId === metaInstanceId
|
||||
) {
|
||||
return id;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// Get orphan Instance IDs (not backing any MetaInstance)
|
||||
function getOrphanInstanceIds(): string[] {
|
||||
return Object.keys(instanceData).filter((id) => {
|
||||
const [, inner] = getTagged(instanceData[id]);
|
||||
return (
|
||||
!inner ||
|
||||
typeof inner !== "object" ||
|
||||
!(inner as Record<string, unknown>).metaInstanceId
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
// Helper to unwrap tagged unions like { MlxRingInstance: {...} }
|
||||
function getTagged(obj: unknown): [string | null, unknown] {
|
||||
if (!obj || typeof obj !== "object") return [null, null];
|
||||
@@ -1282,11 +1506,7 @@
|
||||
// Instance type from tag
|
||||
let instanceType = "Unknown";
|
||||
if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
|
||||
else if (
|
||||
instanceTag === "MlxIbvInstance" ||
|
||||
instanceTag === "MlxJacclInstance"
|
||||
)
|
||||
instanceType = "MLX RDMA";
|
||||
else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
|
||||
|
||||
const inst = instance as {
|
||||
shardAssignments?: {
|
||||
@@ -1634,7 +1854,51 @@
|
||||
}
|
||||
|
||||
const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);
|
||||
const instanceCount = $derived(Object.keys(instanceData).length);
|
||||
const metaInstanceCount = $derived(Object.keys(metaInstancesData).length);
|
||||
const orphanInstanceIds = $derived(getOrphanInstanceIds());
|
||||
const instanceCount = $derived(metaInstanceCount + orphanInstanceIds.length);
|
||||
|
||||
// Unified display items: MetaInstances first, then orphan Instances
|
||||
interface DisplayItem {
|
||||
id: string; // MetaInstance ID or Instance ID (used as key and displayed)
|
||||
modelId: string;
|
||||
instance: unknown | null; // The backing/orphan instance (tagged union) or null if placing
|
||||
instanceId: string | null; // The actual Instance ID (for topology hover)
|
||||
isMetaInstance: boolean;
|
||||
sharding: string | null; // From MetaInstance constraints (used when instance is null)
|
||||
instanceMeta: string | null; // From MetaInstance constraints (used when instance is null)
|
||||
}
|
||||
|
||||
const unifiedDisplayItems = $derived((): DisplayItem[] => {
|
||||
const items: DisplayItem[] = [];
|
||||
// MetaInstances
|
||||
for (const [metaId, meta] of Object.entries(metaInstancesData)) {
|
||||
const backingId = getBackingInstanceId(metaId);
|
||||
items.push({
|
||||
id: metaId,
|
||||
modelId: meta.modelId,
|
||||
instance: backingId ? instanceData[backingId] : null,
|
||||
instanceId: backingId,
|
||||
isMetaInstance: true,
|
||||
sharding: meta.sharding,
|
||||
instanceMeta: meta.instanceMeta,
|
||||
});
|
||||
}
|
||||
// Orphan Instances
|
||||
for (const orphanId of getOrphanInstanceIds()) {
|
||||
const inst = instanceData[orphanId];
|
||||
items.push({
|
||||
id: orphanId,
|
||||
modelId: getInstanceModelId(inst),
|
||||
instance: inst,
|
||||
instanceId: orphanId,
|
||||
isMetaInstance: false,
|
||||
sharding: null,
|
||||
instanceMeta: null,
|
||||
});
|
||||
}
|
||||
return items;
|
||||
});
|
||||
|
||||
// Helper to get the number of nodes in a placement preview
|
||||
function getPreviewNodeCount(preview: PlacementPreview): number {
|
||||
@@ -1752,8 +2016,71 @@
|
||||
</script>
|
||||
|
||||
{#snippet clusterWarnings()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (jacclError && jacclError !== jacclDismissedError)}
|
||||
<div class="absolute top-4 left-4 flex flex-col gap-2 z-40">
|
||||
{#if jacclError && jacclError !== jacclDismissedError}
|
||||
<div class="group relative" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-red-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-red-200">
|
||||
JACCL RDMA ERROR
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => (jacclDismissedError = jacclError)}
|
||||
class="ml-1 text-red-300/60 hover:text-red-200 transition-colors cursor-pointer"
|
||||
title="Dismiss"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-red-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A macOS RDMA driver error was detected. This is a known issue
|
||||
with the experimental RDMA driver in macOS.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-red-300">Error:</span>
|
||||
{jacclError}
|
||||
</p>
|
||||
<p class="text-xs text-white/60">
|
||||
<span class="text-red-300">To fix:</span> Restart the affected machine.
|
||||
There is currently no other workaround for this issue.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
@@ -1922,8 +2249,29 @@
|
||||
{/snippet}
|
||||
|
||||
{#snippet clusterWarningsCompact()}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed)}
|
||||
{#if tbBridgeCycles.length > 0 || macosVersionMismatch || (tb5WithoutRdma && !tb5InfoDismissed) || (jacclError && jacclError !== jacclDismissedError)}
|
||||
<div class="absolute top-2 left-2 flex flex-col gap-1">
|
||||
{#if jacclError && jacclError !== jacclDismissedError}
|
||||
<div
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded border border-red-500/50 bg-red-500/10 backdrop-blur-sm"
|
||||
title="JACCL RDMA driver error — restart affected machine"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-red-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d={warningIconPath}
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-red-200">JACCL ERROR</span>
|
||||
</div>
|
||||
{/if}
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
class="flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
|
||||
@@ -2301,31 +2649,57 @@
|
||||
bind:this={instancesContainerRef}
|
||||
class="max-h-72 xl:max-h-96 space-y-3 overflow-y-auto overflow-x-hidden py-px"
|
||||
>
|
||||
{#each Object.entries(instanceData) as [id, instance]}
|
||||
{@const downloadInfo = getInstanceDownloadStatus(
|
||||
id,
|
||||
instance,
|
||||
)}
|
||||
{#each unifiedDisplayItems() as item (item.id)}
|
||||
{@const id = item.id}
|
||||
{@const instance = item.instance}
|
||||
{@const downloadInfo = instance
|
||||
? getInstanceDownloadStatus(item.instanceId ?? id, instance)
|
||||
: getMetaInstancePlacingStatus(id)}
|
||||
{@const metaData = item.isMetaInstance ? metaInstancesData[id] : null}
|
||||
{@const retryError = metaData?.lastFailureError && !downloadInfo.isFailed
|
||||
? metaData.consecutiveFailures > 0
|
||||
? `(${((metaData.consecutiveFailures - 1) % 3) + 1}/3) ${metaData.lastFailureError}`
|
||||
: metaData.lastFailureError
|
||||
: null}
|
||||
{@const statusText = downloadInfo.statusText}
|
||||
{@const isDownloading = downloadInfo.isDownloading}
|
||||
{@const isFailed = statusText === "FAILED"}
|
||||
{@const isFailed =
|
||||
statusText === "FAILED" ||
|
||||
statusText === "PLACEMENT FAILED"}
|
||||
{@const isLoading =
|
||||
statusText === "LOADING" ||
|
||||
statusText === "WARMING UP" ||
|
||||
statusText === "WAITING"}
|
||||
statusText.startsWith("LOADING") ||
|
||||
statusText.startsWith("WARMING UP") ||
|
||||
statusText === "WAITING" ||
|
||||
statusText === "PLACING" ||
|
||||
statusText.startsWith("RETRYING")}
|
||||
{@const isReady =
|
||||
statusText === "READY" || statusText === "LOADED"}
|
||||
{@const isRunning = statusText === "RUNNING"}
|
||||
<!-- Instance Card -->
|
||||
{@const instanceModelId = getInstanceModelId(instance)}
|
||||
{@const instanceInfo = getInstanceInfo(instance)}
|
||||
{@const instanceConnections =
|
||||
getInstanceConnections(instance)}
|
||||
{@const instanceModelId = item.modelId}
|
||||
{@const instanceInfo = instance
|
||||
? getInstanceInfo(instance)
|
||||
: {
|
||||
instanceType:
|
||||
item.instanceMeta === "MlxRing"
|
||||
? "MLX Ring"
|
||||
: item.instanceMeta === "MlxJaccl"
|
||||
? "MLX RDMA"
|
||||
: "Unknown",
|
||||
sharding: item.sharding ?? "Unknown",
|
||||
nodeNames: [] as string[],
|
||||
nodeIds: [] as string[],
|
||||
nodeCount: 0,
|
||||
}}
|
||||
{@const instanceConnections = instance
|
||||
? getInstanceConnections(instance)
|
||||
: []}
|
||||
<div
|
||||
class="relative group cursor-pointer"
|
||||
role="button"
|
||||
tabindex="0"
|
||||
onmouseenter={() => (hoveredInstanceId = id)}
|
||||
onmouseenter={() =>
|
||||
(hoveredInstanceId = item.instanceId ?? id)}
|
||||
onmouseleave={() => (hoveredInstanceId = null)}
|
||||
onclick={() => {
|
||||
if (
|
||||
@@ -2424,7 +2798,10 @@
|
||||
>
|
||||
</div>
|
||||
<button
|
||||
onclick={() => deleteInstance(id)}
|
||||
onclick={() =>
|
||||
item.isMetaInstance
|
||||
? deleteMetaInstance(id)
|
||||
: deleteInstance(id)}
|
||||
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
|
||||
>
|
||||
DELETE
|
||||
@@ -2434,7 +2811,7 @@
|
||||
<div
|
||||
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
|
||||
>
|
||||
{getInstanceModelId(instance)}
|
||||
{instanceModelId}
|
||||
</div>
|
||||
<div class="text-white/60 text-xs font-mono">
|
||||
Strategy: <span class="text-white/80"
|
||||
@@ -2702,6 +3079,30 @@
|
||||
>
|
||||
{downloadInfo.errorMessage}
|
||||
</div>
|
||||
{:else if retryError}
|
||||
<div
|
||||
class="text-xs text-orange-400/80 font-mono mt-1 break-words"
|
||||
>
|
||||
Retrying after error: {retryError}
|
||||
</div>
|
||||
{/if}
|
||||
{#if downloadInfo.perNodeStatus.length > 1 && (statusText.startsWith("LOADING") || statusText.startsWith("WARMING UP") || statusText === "WAITING" || statusText === "INITIALIZING")}
|
||||
<div class="mt-1.5 space-y-0.5">
|
||||
{#each downloadInfo.perNodeStatus as node}
|
||||
<div
|
||||
class="flex items-center justify-between text-[10px] font-mono"
|
||||
>
|
||||
<span class="text-white/60 truncate pr-2"
|
||||
>{node.nodeName}</span
|
||||
>
|
||||
<span
|
||||
class={getStatusColor(
|
||||
node.status.toUpperCase(),
|
||||
)}>{node.status}</span
|
||||
>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
@@ -2870,21 +3271,21 @@
                 </button>
                 <button
                   onclick={() => {
-                    selectedInstanceType = "MlxIbv";
+                    selectedInstanceType = "MlxJaccl";
                     saveLaunchDefaults();
                   }}
                   class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType ===
-                  'MlxIbv'
+                  'MlxJaccl'
                     ? 'bg-transparent text-exo-yellow border-exo-yellow'
                     : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
                 >
                   <span
                     class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType ===
-                    'MlxIbv'
+                    'MlxJaccl'
                       ? 'border-exo-yellow'
                       : 'border-exo-medium-gray'}"
                   >
-                    {#if selectedInstanceType === "MlxIbv"}
+                    {#if selectedInstanceType === "MlxJaccl"}
                       <span class="w-2 h-2 rounded-full bg-exo-yellow"></span>
                     {/if}
                   </span>
@@ -3113,31 +3514,60 @@
|
||||
<div
|
||||
class="space-y-3 max-h-72 xl:max-h-96 overflow-y-auto overflow-x-hidden py-px pr-1"
|
||||
>
|
||||
{#each Object.entries(instanceData) as [id, instance]}
|
||||
{@const downloadInfo = getInstanceDownloadStatus(
|
||||
id,
|
||||
instance,
|
||||
)}
|
||||
{#each unifiedDisplayItems() as item (item.id)}
|
||||
{@const id = item.id}
|
||||
{@const instance = item.instance}
|
||||
{@const downloadInfo = instance
|
||||
? getInstanceDownloadStatus(
|
||||
item.instanceId ?? id,
|
||||
instance,
|
||||
)
|
||||
: getMetaInstancePlacingStatus(id)}
|
||||
{@const metaData = item.isMetaInstance ? metaInstancesData[id] : null}
|
||||
{@const retryError = metaData?.lastFailureError && !downloadInfo.isFailed
|
||||
? metaData.consecutiveFailures > 0
|
||||
? `(${((metaData.consecutiveFailures - 1) % 3) + 1}/3) ${metaData.lastFailureError}`
|
||||
: metaData.lastFailureError
|
||||
: null}
|
||||
{@const statusText = downloadInfo.statusText}
|
||||
{@const isDownloading = downloadInfo.isDownloading}
|
||||
{@const isFailed = statusText === "FAILED"}
|
||||
{@const isFailed =
|
||||
statusText === "FAILED" ||
|
||||
statusText === "PLACEMENT FAILED"}
|
||||
{@const isLoading =
|
||||
statusText === "LOADING" ||
|
||||
statusText === "WARMING UP" ||
|
||||
statusText === "WAITING"}
|
||||
statusText.startsWith("LOADING") ||
|
||||
statusText.startsWith("WARMING UP") ||
|
||||
statusText === "WAITING" ||
|
||||
statusText === "PLACING" ||
|
||||
statusText.startsWith("RETRYING")}
|
||||
{@const isReady =
|
||||
statusText === "READY" || statusText === "LOADED"}
|
||||
{@const isRunning = statusText === "RUNNING"}
|
||||
<!-- Instance Card -->
|
||||
{@const instanceModelId = getInstanceModelId(instance)}
|
||||
{@const instanceInfo = getInstanceInfo(instance)}
|
||||
{@const instanceConnections =
|
||||
getInstanceConnections(instance)}
|
||||
{@const instanceModelId = item.modelId}
|
||||
{@const instanceInfo = instance
|
||||
? getInstanceInfo(instance)
|
||||
: {
|
||||
instanceType:
|
||||
item.instanceMeta === "MlxRing"
|
||||
? "MLX Ring"
|
||||
: item.instanceMeta === "MlxJaccl"
|
||||
? "MLX RDMA"
|
||||
: "Unknown",
|
||||
sharding: item.sharding ?? "Unknown",
|
||||
nodeNames: [] as string[],
|
||||
nodeIds: [] as string[],
|
||||
nodeCount: 0,
|
||||
}}
|
||||
{@const instanceConnections = instance
|
||||
? getInstanceConnections(instance)
|
||||
: []}
|
||||
<div
|
||||
class="relative group cursor-pointer"
|
||||
role="button"
|
||||
tabindex="0"
|
||||
onmouseenter={() => (hoveredInstanceId = id)}
|
||||
onmouseenter={() =>
|
||||
(hoveredInstanceId = item.instanceId ?? id)}
|
||||
onmouseleave={() => (hoveredInstanceId = null)}
|
||||
onclick={() => {
|
||||
if (
|
||||
@@ -3236,7 +3666,10 @@
|
||||
>
|
||||
</div>
|
||||
<button
|
||||
onclick={() => deleteInstance(id)}
|
||||
onclick={() =>
|
||||
item.isMetaInstance
|
||||
? deleteMetaInstance(id)
|
||||
: deleteInstance(id)}
|
||||
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
|
||||
>
|
||||
DELETE
|
||||
@@ -3246,7 +3679,7 @@
|
||||
<div
|
||||
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
|
||||
>
|
||||
{getInstanceModelId(instance)}
|
||||
{instanceModelId}
|
||||
</div>
|
||||
<div class="text-white/60 text-xs font-mono">
|
||||
Strategy: <span class="text-white/80"
|
||||
@@ -3524,6 +3957,30 @@
|
||||
>
|
||||
{downloadInfo.errorMessage}
|
||||
</div>
|
||||
{:else if retryError}
|
||||
<div
|
||||
class="text-xs text-orange-400/80 font-mono mt-1 break-words"
|
||||
>
|
||||
Retrying after error: {retryError}
|
||||
</div>
|
||||
{/if}
|
||||
{#if downloadInfo.perNodeStatus.length > 1 && (statusText.startsWith("LOADING") || statusText.startsWith("WARMING UP") || statusText === "WAITING" || statusText === "INITIALIZING")}
|
||||
<div class="mt-1.5 space-y-0.5">
|
||||
{#each downloadInfo.perNodeStatus as node}
|
||||
<div
|
||||
class="flex items-center justify-between text-[10px] font-mono"
|
||||
>
|
||||
<span class="text-white/60 truncate pr-2"
|
||||
>{node.nodeName}</span
|
||||
>
|
||||
<span
|
||||
class={getStatusColor(
|
||||
node.status.toUpperCase(),
|
||||
)}>{node.status}</span
|
||||
>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -71,8 +71,11 @@ from exo.shared.types.api import (
     ChatCompletionResponse,
     CreateInstanceParams,
     CreateInstanceResponse,
+    CreateMetaInstanceParams,
+    CreateMetaInstanceResponse,
     DeleteDownloadResponse,
     DeleteInstanceResponse,
+    DeleteMetaInstanceResponse,
     ErrorInfo,
     ErrorResponse,
     FinishReason,
@@ -115,8 +118,10 @@ from exo.shared.types.claude_api import (
 from exo.shared.types.commands import (
     Command,
     CreateInstance,
+    CreateMetaInstance,
     DeleteDownload,
     DeleteInstance,
+    DeleteMetaInstance,
     DownloadCommand,
     ForwarderCommand,
     ForwarderDownloadCommand,
@@ -128,7 +133,7 @@ from exo.shared.types.commands import (
     TaskFinished,
     TextGeneration,
 )
-from exo.shared.types.common import CommandId, Id, NodeId, SessionId
+from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
 from exo.shared.types.events import (
     ChunkGenerated,
     Event,
@@ -137,6 +142,7 @@ from exo.shared.types.events import (
     TracesMerged,
 )
 from exo.shared.types.memory import Memory
+from exo.shared.types.meta_instance import MetaInstance
 from exo.shared.types.openai_responses import (
     ResponsesRequest,
     ResponsesResponse,
@@ -275,6 +281,8 @@ class API:
         self.app.get("/instance/previews")(self.get_placement_previews)
         self.app.get("/instance/{instance_id}")(self.get_instance)
         self.app.delete("/instance/{instance_id}")(self.delete_instance)
+        self.app.post("/meta_instance")(self.create_meta_instance)
+        self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
         self.app.get("/models")(self.get_models)
         self.app.get("/v1/models")(self.get_models)
         self.app.post("/models/add")(self.add_custom_model)
@@ -521,6 +529,46 @@
             instance_id=instance_id,
         )

+    async def create_meta_instance(
+        self, payload: CreateMetaInstanceParams
+    ) -> CreateMetaInstanceResponse:
+        meta_instance = MetaInstance(
+            model_id=payload.model_id,
+            sharding=payload.sharding,
+            instance_meta=payload.instance_meta,
+            min_nodes=payload.min_nodes,
+            node_ids=payload.node_ids,
+        )
+        command = CreateMetaInstance(meta_instance=meta_instance)
+        await self._send(command)
+        return CreateMetaInstanceResponse(
+            message="Command received.",
+            command_id=command.command_id,
+            meta_instance_id=meta_instance.meta_instance_id,
+        )
+
+    async def delete_meta_instance(
+        self, meta_instance_id: MetaInstanceId
+    ) -> DeleteMetaInstanceResponse:
+        meta = self.state.meta_instances.get(meta_instance_id)
+        if not meta:
+            raise HTTPException(status_code=404, detail="MetaInstance not found")
+
+        # Delete MetaInstance first to prevent reconciler from re-placing
+        command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
+        await self._send(command)
+
+        # Then cascade-delete any backing instances
+        for instance_id, instance in self.state.instances.items():
+            if instance.meta_instance_id == meta_instance_id:
+                await self._send(DeleteInstance(instance_id=instance_id))
+
+        return DeleteMetaInstanceResponse(
+            message="Command received.",
+            command_id=command.command_id,
+            meta_instance_id=meta_instance_id,
+        )
+
     async def _token_chunk_stream(
         self, command_id: CommandId
     ) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
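For reference, a minimal client-side sketch of the two endpoints registered above. The request fields mirror CreateMetaInstanceParams as shown in this diff; the base URL, port, and helper names below are assumptions for illustration only and should be adjusted to your deployment.

# Hedged sketch: exercising the new /meta_instance endpoints (field names from the diff above).
import json
import urllib.request

BASE_URL = "http://localhost:8000"  # assumed API address; not specified in this diff

def create_meta_instance(model_id: str, min_nodes: int = 1) -> dict:
    payload = {
        "model_id": model_id,
        "sharding": "Pipeline",      # or "Tensor"
        "instance_meta": "MlxRing",  # or "MlxJaccl" for RDMA
        "min_nodes": min_nodes,
        "node_ids": None,            # let the reconciler choose nodes
    }
    req = urllib.request.Request(
        f"{BASE_URL}/meta_instance",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)  # includes command_id and meta_instance_id

def delete_meta_instance(meta_instance_id: str) -> dict:
    req = urllib.request.Request(
        f"{BASE_URL}/meta_instance/{meta_instance_id}", method="DELETE"
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)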
@@ -1,4 +1,5 @@
-from datetime import datetime, timedelta, timezone
+from collections.abc import Sequence
+from datetime import datetime, timezone

 import anyio
 from anyio.abc import TaskGroup
@@ -12,11 +13,19 @@ from exo.master.placement import (
     get_transition_events,
     place_instance,
 )
+from exo.master.process_managers import ProcessManager
+from exo.master.process_managers.instance_health import InstanceHealthReconciler
+from exo.master.process_managers.meta_instance import MetaInstanceReconciler
+from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
+from exo.master.reconcile import try_place_for_meta_instance
 from exo.shared.apply import apply
 from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
+from exo.shared.models.model_cards import ModelCard
 from exo.shared.types.commands import (
     CreateInstance,
+    CreateMetaInstance,
     DeleteInstance,
+    DeleteMetaInstance,
     ForwarderCommand,
     ForwarderDownloadCommand,
     ImageEdits,
@@ -34,9 +43,9 @@ from exo.shared.types.events import (
     ForwarderEvent,
     IndexedEvent,
     InputChunkReceived,
-    InstanceDeleted,
+    MetaInstanceCreated,
+    MetaInstanceDeleted,
     NodeGatheredInfo,
-    NodeTimedOut,
     TaskCreated,
     TaskDeleted,
     TraceEventData,
@@ -58,7 +67,7 @@ from exo.shared.types.tasks import (
     TextGeneration as TextGenerationTask,
 )
 from exo.shared.types.worker.instances import InstanceId
-from exo.utils.channels import Receiver, Sender, channel
+from exo.utils.channels import Receiver, Sender
 from exo.utils.event_buffer import MultiSourceBuffer
@@ -82,16 +91,15 @@ class Master:
         self.local_event_receiver = local_event_receiver
         self.global_event_sender = global_event_sender
         self.download_command_sender = download_command_sender
-        send, recv = channel[Event]()
-        self.event_sender: Sender[Event] = send
-        self._loopback_event_receiver: Receiver[Event] = recv
-        self._loopback_event_sender: Sender[ForwarderEvent] = (
-            local_event_receiver.clone_sender()
-        )
         self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
         self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
         self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
         self._expected_ranks: dict[TaskId, set[int]] = {}
+        self._process_managers: Sequence[ProcessManager] = [
+            InstanceHealthReconciler(),
+            NodeTimeoutReconciler(),
+            MetaInstanceReconciler(),
+        ]

     async def run(self):
         logger.info("Starting Master")
@@ -100,15 +108,12 @@
             async with self._tg as tg:
                 tg.start_soon(self._event_processor)
                 tg.start_soon(self._command_processor)
-                tg.start_soon(self._loopback_processor)
-                tg.start_soon(self._plan)
+                tg.start_soon(self._reconcile)
         finally:
             self._event_log.close()
             self.global_event_sender.close()
             self.local_event_receiver.close()
             self.command_receiver.close()
-            self._loopback_event_sender.close()
-            self._loopback_event_receiver.close()

     async def shutdown(self):
         logger.info("Stopping Master")
@@ -290,6 +295,29 @@
                         )
                     )
                     generated_events.extend(transition_events)
+                case CreateMetaInstance():
+                    generated_events.append(
+                        MetaInstanceCreated(meta_instance=command.meta_instance)
+                    )
+                    # Immediate placement attempt for responsiveness
+                    model_card = await ModelCard.load(
+                        command.meta_instance.model_id
+                    )
+                    result = try_place_for_meta_instance(
+                        command.meta_instance,
+                        model_card,
+                        self.state.topology,
+                        self.state.instances,
+                        self.state.node_memory,
+                        self.state.node_network,
+                    )
+                    generated_events.extend(result.events)
+                case DeleteMetaInstance():
+                    generated_events.append(
+                        MetaInstanceDeleted(
+                            meta_instance_id=command.meta_instance_id
+                        )
+                    )
                 case PlaceInstance():
                     placement = place_instance(
                         command,
@@ -341,31 +369,32 @@
                 ):
                     await self._send_event(IndexedEvent(idx=i, event=event))
                 for event in generated_events:
-                    await self.event_sender.send(event)
+                    await self._apply_and_broadcast(event)
             except ValueError as e:
                 logger.opt(exception=e).warning("Error in command processor")

-    # These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
-    async def _plan(self) -> None:
+    async def _apply_and_broadcast(self, event: Event) -> None:
+        """Apply event to state, persist to disk, and broadcast to workers.
+
+        State is updated synchronously (before any await), so callers can
+        rely on ``self.state`` reflecting this event immediately after the
+        call. Python's cooperative scheduling guarantees no interleaving
+        between the state read and write.
+        """
+        logger.debug(f"Master indexing event: {str(event)[:100]}")
+        indexed = IndexedEvent(event=event, idx=len(self._event_log))
+        self.state = apply(self.state, indexed)
+        event._master_time_stamp = datetime.now(tz=timezone.utc)  # pyright: ignore[reportPrivateUsage]
+        self._event_log.append(event)
+        await self._send_event(indexed)
+
+    async def _reconcile(self) -> None:
         while True:
-            # kill broken instances
-            connected_node_ids = set(self.state.topology.list_nodes())
-            for instance_id, instance in self.state.instances.items():
-                for node_id in instance.shard_assignments.node_to_runner:
-                    if node_id not in connected_node_ids:
-                        await self.event_sender.send(
-                            InstanceDeleted(instance_id=instance_id)
-                        )
-                        break
-
-            # time out dead nodes
-            for node_id, time in self.state.last_seen.items():
-                now = datetime.now(tz=timezone.utc)
-                if now - time > timedelta(seconds=30):
-                    logger.info(f"Manually removing node {node_id} due to inactivity")
-                    await self.event_sender.send(NodeTimedOut(node_id=node_id))
-
-            await anyio.sleep(10)
+            for pm in self._process_managers:
+                events = await pm.reconcile(self.state)
+                for event in events:
+                    await self._apply_and_broadcast(event)
+            await anyio.sleep(1)

     async def _event_processor(self) -> None:
         with self.local_event_receiver as local_events:
@@ -383,32 +412,10 @@
                     await self._handle_traces_collected(event)
                     continue

-                logger.debug(f"Master indexing event: {str(event)[:100]}")
-                indexed = IndexedEvent(event=event, idx=len(self._event_log))
-                self.state = apply(self.state, indexed)
-
-                event._master_time_stamp = datetime.now(tz=timezone.utc)  # pyright: ignore[reportPrivateUsage]
                 if isinstance(event, NodeGatheredInfo):
                     event.when = str(datetime.now(tz=timezone.utc))

-                self._event_log.append(event)
-                await self._send_event(indexed)
-
-    async def _loopback_processor(self) -> None:
-        # this would ideally not be necessary.
-        # this is WAY less hacky than how I was working around this before
-        local_index = 0
-        with self._loopback_event_receiver as events:
-            async for event in events:
-                await self._loopback_event_sender.send(
-                    ForwarderEvent(
-                        origin=NodeId(f"master_{self.node_id}"),
-                        origin_idx=local_index,
-                        session=self.session_id,
-                        event=event,
-                    )
-                )
-                local_index += 1
+                await self._apply_and_broadcast(event)

     # This function is re-entrant, take care!
     async def _send_event(self, event: IndexedEvent):
@@ -440,7 +447,7 @@
         for trace_data in self._pending_traces[task_id].values():
             all_trace_data.extend(trace_data)

-        await self.event_sender.send(
+        await self._apply_and_broadcast(
             TracesMerged(task_id=task_id, traces=all_trace_data)
         )
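The hunks above collapse the old _plan/_loopback paths into one flow: every event, whether it comes from a command, a worker, or a reconciler, goes through _apply_and_broadcast, and pluggable process managers are polled for corrective events. A toy, self-contained sketch of that control flow follows; all names and types here are illustrative stand-ins, not exo's real classes.

# Illustrative sketch of the apply-then-broadcast + reconciler-polling pattern (hypothetical names).
import asyncio
from dataclasses import dataclass, field

@dataclass
class Event:
    kind: str

@dataclass
class State:
    applied: list = field(default_factory=list)

def apply(state: State, event: Event) -> State:
    # Pure state transition: the event log drives the state.
    return State(applied=state.applied + [event])

class NoopReconciler:
    async def reconcile(self, state: State) -> list:
        # Examine state and return corrective events (none in this toy example).
        return []

class ToyMaster:
    def __init__(self) -> None:
        self.state = State()
        self.process_managers = [NoopReconciler()]

    async def apply_and_broadcast(self, event: Event) -> None:
        # State is updated before any await, so callers see it immediately.
        self.state = apply(self.state, event)
        await asyncio.sleep(0)  # stand-in for broadcasting to workers

    async def reconcile_loop(self, iterations: int = 3) -> None:
        for _ in range(iterations):
            for pm in self.process_managers:
                for event in await pm.reconcile(self.state):
                    await self.apply_and_broadcast(event)
            await asyncio.sleep(0.1)

asyncio.run(ToyMaster().reconcile_loop())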
@@ -63,7 +63,9 @@ def place_instance(
     required_nodes: set[NodeId] | None = None,
 ) -> dict[InstanceId, Instance]:
     cycles = topology.get_cycles()
-    candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
+    candidate_cycles = list(
+        filter(lambda it: len(it) >= command.min_nodes, cycles)
+    )

     # Filter to cycles containing all required nodes (subset matching)
     if required_nodes:
@@ -106,7 +108,11 @@
         cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
     ]

-    if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
+    if command.instance_meta == InstanceMeta.MlxJaccl:
+        if not smallest_rdma_cycles:
+            raise ValueError(
+                "Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
+            )
         smallest_cycles = smallest_rdma_cycles

     cycles_with_leaf_nodes: list[Cycle] = [
src/exo/master/process_managers/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+from collections.abc import Sequence
+from typing import Protocol, runtime_checkable
+
+from exo.shared.types.events import Event
+from exo.shared.types.state import State
+
+
+@runtime_checkable
+class ProcessManager(Protocol):
+    """A reconciliation step that examines state and returns corrective events."""
+
+    async def reconcile(self, state: State) -> Sequence[Event]: ...
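Because ProcessManager is a runtime_checkable Protocol, any object with a matching async reconcile method satisfies it structurally. A hedged sketch of a custom reconciler (the class below is hypothetical, not part of this diff) that could be appended to Master._process_managers:

# Hypothetical reconciler satisfying the ProcessManager protocol defined above.
from collections.abc import Sequence
from typing import final

from loguru import logger

from exo.master.process_managers import ProcessManager
from exo.shared.types.events import Event
from exo.shared.types.state import State


@final
class InstanceCountLogger:
    """Example-only reconciler: observes state and emits no corrective events."""

    async def reconcile(self, state: State) -> Sequence[Event]:
        logger.debug(f"{len(state.instances)} instance(s) currently placed")
        return []


# Structural check enabled by @runtime_checkable on the Protocol.
assert isinstance(InstanceCountLogger(), ProcessManager)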
src/exo/master/process_managers/instance_health.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+from collections.abc import Sequence
+from typing import final
+
+from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
+from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
+from exo.shared.types.state import State
+
+MAX_INSTANCE_RETRIES = 3
+
+
+@final
+class InstanceHealthReconciler:
+    """Delete instances whose network connections are broken or whose runners have all failed."""
+
+    async def reconcile(self, state: State) -> Sequence[Event]:
+        events: list[Event] = []
+        for instance_id, instance in state.instances.items():
+            if not instance_connections_healthy(instance, state.topology):
+                events.append(
+                    InstanceDeleted(
+                        instance_id=instance_id,
+                        failure_error="Network connection lost",
+                    )
+                )
+                continue
+
+            is_failed, error_message = instance_runners_failed(
+                instance, state.runners, state.node_identities
+            )
+            if is_failed:
+                # Retry within the same instance if backed by a MetaInstance
+                mid = instance.meta_instance_id
+                mi = state.meta_instances.get(mid) if mid else None
+                if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
+                    events.append(
+                        InstanceRetrying(
+                            instance_id=instance_id,
+                            meta_instance_id=mid,
+                            failure_error=error_message or "Runner failed",
+                        )
+                    )
+                else:
+                    events.append(
+                        InstanceDeleted(
+                            instance_id=instance_id,
+                            failure_error=error_message,
+                        )
+                    )
+        return events

src/exo/master/process_managers/meta_instance.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from collections.abc import Sequence
from typing import final

from exo.master.reconcile import (
    find_unsatisfied_meta_instances,
    try_place_for_meta_instance,
)
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
from exo.shared.types.state import State
from exo.shared.types.worker.instances import Instance, InstanceId


@final
class MetaInstanceReconciler:
    """Place instances for unsatisfied MetaInstances."""

    async def reconcile(self, state: State) -> Sequence[Event]:
        all_events: list[Event] = []
        # Local copy for intermediate tracking — so placement of B
        # sees A's instance and doesn't double-place on same resources.
        current_instances: dict[InstanceId, Instance] = dict(state.instances)

        unsatisfied = find_unsatisfied_meta_instances(
            state.meta_instances,
            current_instances,
            state.topology,
        )
        for meta_instance in unsatisfied:
            model_card = await ModelCard.load(meta_instance.model_id)
            result = try_place_for_meta_instance(
                meta_instance,
                model_card,
                state.topology,
                current_instances,
                state.node_memory,
                state.node_network,
            )
            # Update local instance map so next placement sees this one
            for event in result.events:
                if isinstance(event, InstanceCreated):
                    current_instances[event.instance.instance_id] = event.instance
            all_events.extend(result.events)

            # Emit placement failure if error differs from what's already in state
            if result.error is not None and meta_instance.placement_error != result.error:
                all_events.append(
                    MetaInstancePlacementFailed(
                        meta_instance_id=meta_instance.meta_instance_id,
                        reason=result.error,
                    )
                )
        return all_events
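
This reconciler is the imperative half of a declarative API: a `MetaInstance` (added later in this diff under `src/exo/shared/types/meta_instance.py`) only records the constraint, and nothing is placed until this pass runs. A rough sketch of the intended flow, with the event plumbing assumed rather than shown:

```python
# Sketch of the declarative flow this reconciler completes. How the event
# reaches apply() is assumed here; MetaInstance and MetaInstanceCreated are
# defined elsewhere in this diff.
from exo.shared.models.model_cards import ModelId
from exo.shared.types.events import MetaInstanceCreated
from exo.shared.types.meta_instance import MetaInstance

meta = MetaInstance(model_id=ModelId("test-org/test-model"), min_nodes=2)
event = MetaInstanceCreated(meta_instance=meta)
# Once applied, state.meta_instances holds the constraint. On the next
# reconcile() pass, MetaInstanceReconciler either emits InstanceCreated
# (tagged with this meta_instance_id) or MetaInstancePlacementFailed with a
# reason, which apply() records as placement_error on the MetaInstance.
```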

src/exo/master/process_managers/node_timeout.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from typing import final

from loguru import logger

from exo.shared.types.events import Event, NodeTimedOut
from exo.shared.types.state import State

_DEFAULT_TIMEOUT = timedelta(seconds=30)


@final
class NodeTimeoutReconciler:
    """Time out nodes that haven't been seen recently."""

    def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
        self.timeout = timeout

    async def reconcile(self, state: State) -> Sequence[Event]:
        now = datetime.now(tz=timezone.utc)
        events: list[Event] = []
        for node_id, last_seen in state.last_seen.items():
            if now - last_seen > self.timeout:
                logger.info(f"Removing node {node_id} due to inactivity")
                events.append(NodeTimedOut(node_id=node_id))
        return events

src/exo/master/reconcile.py (new file, 236 lines)
@@ -0,0 +1,236 @@
from collections.abc import Mapping, Sequence
from typing import NamedTuple

from loguru import logger

from exo.master.placement import get_transition_events, place_instance
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import Event
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.instances import (
    BaseInstance,
    Instance,
    InstanceId,
    MlxJacclInstance,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerShutdown,
    RunnerStatus,
)


class PlacementResult(NamedTuple):
    """Result of a placement attempt: events to apply and optional error reason."""

    events: Sequence[Event]
    error: str | None


def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
    """Reconstruct ring order from shard device_rank."""
    node_ranks: list[tuple[NodeId, int]] = []
    for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
        shard = instance.shard_assignments.runner_to_shard[runner_id]
        node_ranks.append((node_id, shard.device_rank))
    node_ranks.sort(key=lambda x: x[1])
    return [node_id for node_id, _ in node_ranks]


def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
    """Check that the specific IPs used by a ring instance still exist in the topology."""
    ring = _get_ring_order(instance)
    n = len(ring)
    for node in ring:
        hosts = instance.hosts_by_node[node]
        for idx in range(n):
            host = hosts[idx]
            if host.ip in ("0.0.0.0", "198.51.100.1"):
                continue  # self or placeholder
            # Real connection: node → ring[idx]. Check specific IP.
            connections = topology.get_all_connections_between(node, ring[idx])
            if not any(
                isinstance(c, SocketConnection)
                and c.sink_multiaddr.ip_address == host.ip
                for c in connections
            ):
                return False
    return True


def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
    """Check that the specific RDMA interfaces used by a JACCL instance still exist."""
    ring = _get_ring_order(instance)
    n = len(ring)
    for i in range(n):
        for j in range(n):
            iface = instance.jaccl_devices[i][j]
            if iface is None:
                continue
            connections = topology.get_all_connections_between(ring[i], ring[j])
            if not any(
                isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
                for c in connections
            ):
                return False
    return True


def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
    """Check that an instance's nodes and specific connections are still in the topology."""
    instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
    if not all(topology.contains_node(n) for n in instance_nodes):
        return False
    if len(instance_nodes) <= 1:
        return True
    match instance:
        case MlxRingInstance():
            return _ring_connections_healthy(instance, topology)
        case MlxJacclInstance():
            return _jaccl_connections_healthy(instance, topology)


def instance_runners_failed(
    instance: Instance,
    runners: Mapping[RunnerId, RunnerStatus],
    node_identities: Mapping[NodeId, NodeIdentity],
) -> tuple[bool, str | None]:
    """Check if an instance's runners have all reached terminal failure states.

    Returns ``(True, error_message)`` when ALL runners are terminal
    (``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.

    Returns ``(False, None)`` when runners are still active, haven't reported
    yet, or all gracefully shut down (no ``RunnerFailed``).
    """
    instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())

    if not instance_runner_ids:
        return False, None

    # Build reverse mapping: runner_id -> node_id
    runner_to_node: dict[RunnerId, NodeId] = {
        runner_id: node_id
        for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
    }

    has_any_failed = False
    error_messages: list[str] = []

    for runner_id in instance_runner_ids:
        status = runners.get(runner_id)
        if status is None:
            # Runner hasn't reported yet — instance is still starting
            return False, None
        if isinstance(status, RunnerFailed):
            has_any_failed = True
            if status.error_message:
                node_id = runner_to_node.get(runner_id)
                name = (
                    node_identities[node_id].friendly_name
                    if node_id and node_id in node_identities
                    else node_id or "unknown"
                )
                error_messages.append(f"{name}: {status.error_message}")
        elif isinstance(status, RunnerShutdown):
            pass  # Terminal but not a failure indicator on its own
        else:
            # Runner is still active (connecting, loading, running, etc.)
            return False, None

    if has_any_failed:
        return True, "; ".join(error_messages) if error_messages else "Runner failed"

    # All runners are Shutdown but none Failed — graceful shutdown, not a failure
    return False, None


def instance_satisfies_meta_instance(
    meta_instance: MetaInstance,
    instance: Instance,
) -> bool:
    """Check if a single instance satisfies a meta-instance's constraints.

    This is a pure constraint check (model, min_nodes, node_ids).
    Use ``instance_connections_healthy`` separately for topology health.
    """
    if instance.shard_assignments.model_id != meta_instance.model_id:
        return False

    instance_nodes = set(instance.shard_assignments.node_to_runner.keys())

    if len(instance_nodes) < meta_instance.min_nodes:
        return False

    return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
        instance_nodes
    )


def find_unsatisfied_meta_instances(
    meta_instances: Mapping[MetaInstanceId, MetaInstance],
    instances: Mapping[InstanceId, Instance],
    topology: Topology,
) -> Sequence[MetaInstance]:
    """Return meta-instances that have no healthy backing instance."""
    unsatisfied: list[MetaInstance] = []
    for meta_id, meta_instance in meta_instances.items():
        has_healthy_backing = any(
            instance.meta_instance_id == meta_id
            and instance_connections_healthy(instance, topology)
            for instance in instances.values()
        )
        if not has_healthy_backing:
            unsatisfied.append(meta_instance)
    return unsatisfied


def try_place_for_meta_instance(
    meta_instance: MetaInstance,
    model_card: ModelCard,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
    node_memory: Mapping[NodeId, MemoryUsage],
    node_network: Mapping[NodeId, NodeNetworkInfo],
) -> PlacementResult:
    """Try to place an instance satisfying the meta-instance constraints.

    Returns a :class:`PlacementResult` with events on success, or an error
    reason on failure.
    """
    command = PlaceInstance(
        model_card=model_card,
        sharding=meta_instance.sharding,
        instance_meta=meta_instance.instance_meta,
        min_nodes=meta_instance.min_nodes,
    )
    try:
        target_instances = place_instance(
            command,
            topology,
            current_instances,
            node_memory,
            node_network,
            required_nodes=(
                set(meta_instance.node_ids) if meta_instance.node_ids else None
            ),
        )
        # Tag the new instance with meta_instance_id
        new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
        if new_instance_ids:
            new_id = next(iter(new_instance_ids))
            target_instances[new_id] = target_instances[new_id].model_copy(
                update={"meta_instance_id": meta_instance.meta_instance_id}
            )
        return PlacementResult(
            events=list(get_transition_events(current_instances, target_instances)),
            error=None,
        )
    except ValueError as e:
        logger.debug(
            f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
        )
        return PlacementResult(events=[], error=str(e))

src/exo/master/tests/test_reconcile.py (new file, 750 lines)
@@ -0,0 +1,750 @@
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.reconcile import (
    find_unsatisfied_meta_instances,
    instance_connections_healthy,
    instance_runners_failed,
    instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
    IndexedEvent,
    InstanceCreated,
    InstanceDeleted,
    InstanceRetrying,
    MetaInstanceCreated,
    MetaInstanceDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
    InstanceId,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerLoading,
    RunnerReady,
    RunnerShutdown,
    ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata


def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
    return ModelCard(
        model_id=ModelId(model_id),
        storage_size=Memory.from_kb(1000),
        n_layers=10,
        hidden_size=30,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
    )


def _topology(*node_ids: str, connect: bool = True) -> Topology:
    """Build a topology with nodes connected in a bidirectional ring with unique IPs.

    Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
    between consecutive nodes (including wrap-around).
    """
    t = Topology()
    nodes = [NodeId(n) for n in node_ids]
    for n in nodes:
        t.add_node(n)
    if connect and len(nodes) > 1:
        for i in range(len(nodes)):
            j = (i + 1) % len(nodes)
            t.add_connection(
                Connection(
                    source=nodes[i],
                    sink=nodes[j],
                    edge=SocketConnection(
                        sink_multiaddr=Multiaddr(
                            address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
                        )
                    ),
                )
            )
            t.add_connection(
                Connection(
                    source=nodes[j],
                    sink=nodes[i],
                    edge=SocketConnection(
                        sink_multiaddr=Multiaddr(
                            address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
                        )
                    ),
                )
            )
    return t


def _meta_instance(
    model_id: str = "test-org/test-model",
    *,
    min_nodes: int = 1,
    node_ids: list[NodeId] | None = None,
    meta_instance_id: MetaInstanceId | None = None,
) -> MetaInstance:
    return MetaInstance(
        meta_instance_id=meta_instance_id or MetaInstanceId(),
        model_id=ModelId(model_id),
        min_nodes=min_nodes,
        node_ids=node_ids,
    )


def _instance(
    model_id: str = "test-org/test-model",
    node_ids: list[str] | None = None,
    instance_id: InstanceId | None = None,
    meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
    """Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
    iid = instance_id or InstanceId()
    nodes = node_ids or ["node-a"]
    n = len(nodes)
    mc = _model_card(model_id)
    ephemeral_port = 50000
    node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
    runner_to_shard = {
        runner_id: PipelineShardMetadata(
            model_card=mc,
            device_rank=i,
            world_size=n,
            start_layer=0,
            end_layer=mc.n_layers,
            n_layers=mc.n_layers,
        )
        for i, runner_id in enumerate(node_to_runner.values())
    }
    # Build hosts_by_node with IPs matching _topology() convention:
    # node at index idx has IP 10.0.0.{idx+1}
    hosts_by_node: dict[NodeId, list[Host]] = {}
    for r, node_str in enumerate(nodes):
        hosts: list[Host] = []
        for idx in range(n):
            if idx == r:
                hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
            elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
                hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
            else:
                hosts.append(Host(ip="198.51.100.1", port=0))
        hosts_by_node[NodeId(node_str)] = hosts
    return iid, MlxRingInstance(
        instance_id=iid,
        shard_assignments=ShardAssignments(
            model_id=ModelId(model_id),
            runner_to_shard=runner_to_shard,
            node_to_runner=node_to_runner,
        ),
        hosts_by_node=hosts_by_node,
        ephemeral_port=ephemeral_port,
        meta_instance_id=meta_instance_id,
    )


# --- instance_satisfies_meta_instance (pure constraint matching) ---


def test_satisfies_matching_model():
    meta = _meta_instance()
    _, inst = _instance(node_ids=["node-a"])
    assert instance_satisfies_meta_instance(meta, inst) is True


def test_not_satisfies_wrong_model():
    meta = _meta_instance("test-org/model-a")
    _, inst = _instance("test-org/model-b")
    assert instance_satisfies_meta_instance(meta, inst) is False


def test_not_satisfies_missing_required_node():
    meta = _meta_instance(node_ids=[NodeId("node-c")])
    _, inst = _instance(node_ids=["node-a", "node-b"])
    assert instance_satisfies_meta_instance(meta, inst) is False


def test_not_satisfies_fewer_than_min_nodes():
    meta = _meta_instance(min_nodes=3)
    _, inst = _instance(node_ids=["node-a", "node-b"])
    assert instance_satisfies_meta_instance(meta, inst) is False


def test_satisfies_with_node_ids_specified():
    meta = _meta_instance(
        node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2
    )
    _, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
    assert instance_satisfies_meta_instance(meta, inst) is True


# --- instance_connections_healthy ---


def test_healthy_single_node_present():
    _, inst = _instance(node_ids=["node-a"])
    topology = _topology("node-a")
    assert instance_connections_healthy(inst, topology) is True


def test_unhealthy_single_node_missing():
    _, inst = _instance(node_ids=["node-a"])
    topology = Topology()  # empty
    assert instance_connections_healthy(inst, topology) is False


def test_healthy_two_node_ring():
    _, inst = _instance(node_ids=["node-a", "node-b"])
    topology = _topology("node-a", "node-b")
    assert instance_connections_healthy(inst, topology) is True


def test_unhealthy_two_node_edge_removed():
    """Nodes present but edge removed — ring broken."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    topology = _topology("node-a", "node-b", connect=False)
    assert instance_connections_healthy(inst, topology) is False


def test_unhealthy_two_node_ip_changed():
    """Edge exists but with a different IP than instance was configured with."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    # Build topology with different IPs than _instance() expects
    topology = Topology()
    topology.add_node(NodeId("node-a"))
    topology.add_node(NodeId("node-b"))
    topology.add_connection(
        Connection(
            source=NodeId("node-a"),
            sink=NodeId("node-b"),
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/192.168.99.99/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=NodeId("node-b"),
            sink=NodeId("node-a"),
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/192.168.99.98/tcp/50000")
            ),
        )
    )
    assert instance_connections_healthy(inst, topology) is False


def test_healthy_three_node_ring():
    _, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
    topology = _topology("node-a", "node-b", "node-c")
    assert instance_connections_healthy(inst, topology) is True


def test_unhealthy_three_node_one_edge_removed():
    """Remove one edge from a three-node ring — instance unhealthy."""
    _, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
    # Build topology with one direction of one edge missing
    topology = Topology()
    nodes = [NodeId("node-a"), NodeId("node-b"), NodeId("node-c")]
    for n in nodes:
        topology.add_node(n)
    # Add all edges except node-a → node-b
    topology.add_connection(
        Connection(
            source=nodes[1],
            sink=nodes[0],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=nodes[1],
            sink=nodes[2],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=nodes[2],
            sink=nodes[1],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.2/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=nodes[2],
            sink=nodes[0],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/50000")
            ),
        )
    )
    topology.add_connection(
        Connection(
            source=nodes[0],
            sink=nodes[2],
            edge=SocketConnection(
                sink_multiaddr=Multiaddr(address="/ip4/10.0.0.3/tcp/50000")
            ),
        )
    )
    # Missing: node-a → node-b (ip 10.0.0.2)
    assert instance_connections_healthy(inst, topology) is False


def test_unhealthy_node_missing_from_topology():
    """Instance has a node that's not in the topology at all."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    topology = _topology("node-a")  # node-b not present
    assert instance_connections_healthy(inst, topology) is False


def test_healthy_extra_nodes_in_topology():
    """Extra nodes in topology don't affect instance health."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    topology = _topology("node-a", "node-b", "node-c")
    assert instance_connections_healthy(inst, topology) is True


# --- find_unsatisfied_meta_instances ---


def test_unsatisfied_no_meta_instances():
    result = find_unsatisfied_meta_instances({}, {}, Topology())
    assert list(result) == []


def test_unsatisfied_one_satisfied():
    meta = _meta_instance()
    id_a, inst_a = _instance(meta_instance_id=meta.meta_instance_id)
    topology = _topology("node-a")
    result = find_unsatisfied_meta_instances(
        {meta.meta_instance_id: meta},
        {id_a: inst_a},
        topology,
    )
    assert list(result) == []


def test_unsatisfied_one_not_satisfied():
    meta = _meta_instance("test-org/model-x")
    id_a, inst_a = _instance("test-org/model-y")
    topology = _topology("node-a")
    result = find_unsatisfied_meta_instances(
        {meta.meta_instance_id: meta}, {id_a: inst_a}, topology
    )
    assert list(result) == [meta]


def test_unsatisfied_mix():
    meta_satisfied = _meta_instance("test-org/model-a")
    meta_unsatisfied = _meta_instance("test-org/model-b")
    id_a, inst_a = _instance(
        "test-org/model-a", meta_instance_id=meta_satisfied.meta_instance_id
    )
    topology = _topology("node-a")
    result = find_unsatisfied_meta_instances(
        {
            meta_satisfied.meta_instance_id: meta_satisfied,
            meta_unsatisfied.meta_instance_id: meta_unsatisfied,
        },
        {id_a: inst_a},
        topology,
    )
    assert list(result) == [meta_unsatisfied]


def test_unsatisfied_node_disconnect():
    meta = _meta_instance()
    id_a, inst_a = _instance(
        node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
    )
    topology = _topology("node-a")  # node-b disconnected
    result = find_unsatisfied_meta_instances(
        {meta.meta_instance_id: meta},
        {id_a: inst_a},
        topology,
    )
    assert list(result) == [meta]


def test_unsatisfied_edge_break():
    """Instance exists but its connections broke — meta-instance becomes unsatisfied."""
    meta = _meta_instance()
    id_a, inst_a = _instance(
        node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
    )
    topology = _topology("node-a", "node-b", connect=False)  # nodes present, no edges
    result = find_unsatisfied_meta_instances(
        {meta.meta_instance_id: meta},
        {id_a: inst_a},
        topology,
    )
    assert list(result) == [meta]


def test_unsatisfied_idempotent():
    meta = _meta_instance("test-org/model-x")
    topology = _topology("node-a")
    meta_instances = {meta.meta_instance_id: meta}
    instances: dict[InstanceId, MlxRingInstance] = {}
    result_1 = list(
        find_unsatisfied_meta_instances(meta_instances, instances, topology)
    )
    result_2 = list(
        find_unsatisfied_meta_instances(meta_instances, instances, topology)
    )
    assert result_1 == result_2


def test_unsatisfied_exclusive_binding():
    """Two MetaInstances for the same model: one is bound via meta_instance_id, the other is unsatisfied."""
    meta_a = _meta_instance("test-org/model-x")
    meta_b = _meta_instance("test-org/model-x")
    id_inst, inst = _instance(
        "test-org/model-x", meta_instance_id=meta_a.meta_instance_id
    )
    topology = _topology("node-a")
    result = find_unsatisfied_meta_instances(
        {
            meta_a.meta_instance_id: meta_a,
            meta_b.meta_instance_id: meta_b,
        },
        {id_inst: inst},
        topology,
    )
    assert list(result) == [meta_b]


# --- apply handlers ---


def test_apply_meta_instance_created():
    state = State()
    meta = _meta_instance()
    event = MetaInstanceCreated(meta_instance=meta)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert meta.meta_instance_id in new_state.meta_instances
    assert new_state.meta_instances[meta.meta_instance_id] == meta


def test_apply_meta_instance_deleted():
    meta = _meta_instance()
    state = State(meta_instances={meta.meta_instance_id: meta})
    event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert meta.meta_instance_id not in new_state.meta_instances


def test_apply_meta_instance_deleted_clears_failure_info():
    meta = _meta_instance().model_copy(
        update={"consecutive_failures": 2, "last_failure_error": "OOM"}
    )
    state = State(meta_instances={meta.meta_instance_id: meta})
    event = MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert meta.meta_instance_id not in new_state.meta_instances


# --- instance_runners_failed ---


def test_runners_failed_all_failed():
    """All runners in RunnerFailed -> instance is failed."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    runners = {
        rid: RunnerFailed(error_message="OOM")
        for rid in inst.shard_assignments.node_to_runner.values()
    }
    is_failed, error = instance_runners_failed(inst, runners, {})
    assert is_failed is True
    assert error is not None
    assert "OOM" in error


def test_runners_failed_mixed_failed_shutdown():
    """One Failed + one Shutdown = failed."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    runners = {
        runner_ids[0]: RunnerFailed(error_message="crash"),
        runner_ids[1]: RunnerShutdown(),
    }
    is_failed, error = instance_runners_failed(inst, runners, {})
    assert is_failed is True
    assert error is not None
    assert "crash" in error


def test_runners_not_failed_all_shutdown():
    """All Shutdown (graceful) = not a failure."""
    _, inst = _instance(node_ids=["node-a"])
    runners = {
        rid: RunnerShutdown()
        for rid in inst.shard_assignments.node_to_runner.values()
    }
    is_failed, _ = instance_runners_failed(inst, runners, {})
    assert is_failed is False


def test_runners_not_failed_still_active():
    """Some runners still active = not failed yet."""
    _, inst = _instance(node_ids=["node-a", "node-b"])
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    runners = {
        runner_ids[0]: RunnerFailed(error_message="OOM"),
        runner_ids[1]: RunnerLoading(),
    }
    is_failed, _ = instance_runners_failed(inst, runners, {})
    assert is_failed is False


def test_runners_not_failed_no_status():
    """Runner not yet reported = not failed."""
    _, inst = _instance(node_ids=["node-a"])
    is_failed, _ = instance_runners_failed(inst, {}, {})
    assert is_failed is False


def test_runners_not_failed_healthy():
    """Runners in Ready state = not failed."""
    _, inst = _instance(node_ids=["node-a"])
    runners = {
        rid: RunnerReady()
        for rid in inst.shard_assignments.node_to_runner.values()
    }
    is_failed, _ = instance_runners_failed(inst, runners, {})
    assert is_failed is False


# --- failure tracking in apply_instance_deleted ---


def test_apply_instance_deleted_tracks_failure():
    """InstanceDeleted with failure_error increments meta instance failure count."""
    meta = _meta_instance()
    iid, inst = _instance(
        node_ids=["node-a"], meta_instance_id=meta.meta_instance_id
    )
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 1
    assert mi.last_failure_error == "Runner OOM"


def test_apply_instance_deleted_increments_failure():
    """Subsequent failures increment the counter."""
    meta = _meta_instance().model_copy(
        update={"consecutive_failures": 2, "last_failure_error": "previous error"}
    )
    iid, inst = _instance(
        node_ids=["node-a"], meta_instance_id=meta.meta_instance_id
    )
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceDeleted(instance_id=iid, failure_error="new error")
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 3
    assert mi.last_failure_error == "new error"


def test_apply_instance_deleted_no_failure_no_tracking():
    """InstanceDeleted without failure_error does not track."""
    meta = _meta_instance()
    iid, inst = _instance(
        node_ids=["node-a"], meta_instance_id=meta.meta_instance_id
    )
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceDeleted(instance_id=iid)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 0


def test_apply_instance_deleted_orphan_no_tracking():
    """InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
    iid, inst = _instance(node_ids=["node-a"])
    state = State(instances={iid: inst})
    event = InstanceDeleted(instance_id=iid, failure_error="crash")
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert len(new_state.meta_instances) == 0


# --- InstanceRetrying ---


def test_apply_instance_retrying_removes_runners():
    """InstanceRetrying removes the instance's runners from state but keeps the instance."""
    meta = _meta_instance()
    iid, inst = _instance(
        node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
    )
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    runners = {
        runner_ids[0]: RunnerFailed(error_message="OOM"),
        runner_ids[1]: RunnerShutdown(),
    }
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        runners=runners,
    )
    event = InstanceRetrying(
        instance_id=iid,
        meta_instance_id=meta.meta_instance_id,
        failure_error="OOM",
    )
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    # Instance still exists
    assert iid in new_state.instances
    # Runners removed
    assert runner_ids[0] not in new_state.runners
    assert runner_ids[1] not in new_state.runners


def test_apply_instance_retrying_increments_failure():
    """InstanceRetrying increments consecutive_failures on the MetaInstance."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceRetrying(
        instance_id=iid,
        meta_instance_id=meta.meta_instance_id,
        failure_error="crash",
    )
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 1
    assert mi.last_failure_error == "crash"


def test_apply_instance_retrying_skips_missing_runners():
    """InstanceRetrying doesn't assert if runners haven't reported yet."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    # No runners in state at all
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceRetrying(
        instance_id=iid,
        meta_instance_id=meta.meta_instance_id,
        failure_error="crash",
    )
    # Should not raise
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert iid in new_state.instances


def test_apply_instance_created_resets_failure_counter():
    """InstanceCreated resets consecutive_failures but preserves last_failure_error."""
    meta = _meta_instance().model_copy(
        update={"consecutive_failures": 3, "last_failure_error": "old error"}
    )
    _, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(meta_instances={meta.meta_instance_id: meta})
    event = InstanceCreated(instance=inst)
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 0
    assert mi.last_failure_error == "old error"
    assert mi.placement_error is None


# --- InstanceHealthReconciler retry-vs-delete ---


async def test_health_reconciler_retries_when_under_limit():
    """InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
        topology=_topology("node-a"),
    )
    reconciler = InstanceHealthReconciler()
    events = await reconciler.reconcile(state)
    assert len(events) == 1
    assert isinstance(events[0], InstanceRetrying)
    assert events[0].instance_id == iid
    assert events[0].meta_instance_id == meta.meta_instance_id


async def test_health_reconciler_deletes_when_limit_reached():
    """InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
    meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
        topology=_topology("node-a"),
    )
    reconciler = InstanceHealthReconciler()
    events = await reconciler.reconcile(state)
    assert len(events) == 1
    assert isinstance(events[0], InstanceDeleted)


async def test_health_reconciler_deletes_without_meta_instance():
    """Instances without a MetaInstance are deleted immediately on runner failure."""
    iid, inst = _instance(node_ids=["node-a"])
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    state = State(
        instances={iid: inst},
        runners={runner_ids[0]: RunnerFailed(error_message="crash")},
        topology=_topology("node-a"),
    )
    reconciler = InstanceHealthReconciler()
    events = await reconciler.reconcile(state)
    assert len(events) == 1
    assert isinstance(events[0], InstanceDeleted)


async def test_health_reconciler_network_failure_always_deletes():
    """Network failure always triggers InstanceDeleted regardless of retry count."""
    meta = _meta_instance()
    iid, inst = _instance(
        node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
    )
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        topology=_topology("node-a"),  # node-b missing
    )
    reconciler = InstanceHealthReconciler()
    events = await reconciler.reconcile(state)
    assert len(events) == 1
    assert isinstance(events[0], InstanceDeleted)
    assert events[0].failure_error == "Network connection lost"

@@ -4,7 +4,7 @@ from datetime import datetime

from loguru import logger

from exo.shared.types.common import NodeId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import (
    ChunkGenerated,
    Event,
@@ -12,6 +12,10 @@ from exo.shared.types.events import (
    InputChunkReceived,
    InstanceCreated,
    InstanceDeleted,
    InstanceRetrying,
    MetaInstanceCreated,
    MetaInstanceDeleted,
    MetaInstancePlacementFailed,
    NodeDownloadProgress,
    NodeGatheredInfo,
    NodeTimedOut,
@@ -28,6 +32,7 @@ from exo.shared.types.events import (
    TracesCollected,
    TracesMerged,
)
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
    NodeIdentity,
    NodeNetworkInfo,
@@ -72,6 +77,14 @@ def event_apply(event: Event, state: State) -> State:
            return apply_instance_created(event, state)
        case InstanceDeleted():
            return apply_instance_deleted(event, state)
        case InstanceRetrying():
            return apply_instance_retrying(event, state)
        case MetaInstanceCreated():
            return apply_meta_instance_created(event, state)
        case MetaInstanceDeleted():
            return apply_meta_instance_deleted(event, state)
        case MetaInstancePlacementFailed():
            return apply_meta_instance_placement_failed(event, state)
        case NodeTimedOut():
            return apply_node_timed_out(event, state)
        case NodeDownloadProgress():
@@ -174,20 +187,119 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
    return state.model_copy(update={"tasks": new_tasks})


def _update_meta_instance(
    state: State, mid: MetaInstanceId, **fields: object
) -> Mapping[MetaInstanceId, MetaInstance]:
    mi = state.meta_instances[mid]
    return {**state.meta_instances, mid: mi.model_copy(update=fields)}


def apply_instance_created(event: InstanceCreated, state: State) -> State:
    instance = event.instance
    new_instances: Mapping[InstanceId, Instance] = {
        **state.instances,
        instance.instance_id: instance,
    }
    return state.model_copy(update={"instances": new_instances})
    update: dict[str, object] = {"instances": new_instances}
    # Reset failure tracking when a new instance is created for a meta-instance
    if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
        mi = state.meta_instances[instance.meta_instance_id]
        if mi.placement_error is not None or mi.consecutive_failures > 0:
            update["meta_instances"] = _update_meta_instance(
                state,
                instance.meta_instance_id,
                placement_error=None,
                consecutive_failures=0,
            )
    return state.model_copy(update=update)


def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
    deleted_instance = state.instances.get(event.instance_id)
    new_instances: Mapping[InstanceId, Instance] = {
        iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
    }
    return state.model_copy(update={"instances": new_instances})
    update: dict[str, object] = {"instances": new_instances}

    # Track failure on the MetaInstance itself
    if (
        event.failure_error
        and deleted_instance
        and deleted_instance.meta_instance_id
        and deleted_instance.meta_instance_id in state.meta_instances
    ):
        mid = deleted_instance.meta_instance_id
        mi = state.meta_instances[mid]
        update["meta_instances"] = {
            **state.meta_instances,
            mid: mi.model_copy(
                update={
                    "consecutive_failures": mi.consecutive_failures + 1,
                    "last_failure_error": event.failure_error,
                }
            ),
        }

    return state.model_copy(update=update)


def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
    """Runners failed but retry limit not reached — remove runners, keep instance."""
    instance = state.instances.get(event.instance_id)
    if instance is None:
        return state

    # Remove all runners belonging to this instance from state
    runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
    new_runners: Mapping[RunnerId, RunnerStatus] = {
        rid: rs
        for rid, rs in state.runners.items()
        if rid not in runner_ids_to_remove
    }

    update: dict[str, object] = {"runners": new_runners}

    # Increment failure count on the MetaInstance
    if event.meta_instance_id in state.meta_instances:
        update["meta_instances"] = _update_meta_instance(
            state,
            event.meta_instance_id,
            consecutive_failures=state.meta_instances[
                event.meta_instance_id
            ].consecutive_failures + 1,
            last_failure_error=event.failure_error,
        )

    return state.model_copy(update=update)


def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
    new_meta: Mapping[MetaInstanceId, MetaInstance] = {
        **state.meta_instances,
        event.meta_instance.meta_instance_id: event.meta_instance,
    }
    return state.model_copy(update={"meta_instances": new_meta})


def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
    new_meta: Mapping[MetaInstanceId, MetaInstance] = {
        mid: mi
        for mid, mi in state.meta_instances.items()
        if mid != event.meta_instance_id
    }
    return state.model_copy(update={"meta_instances": new_meta})


def apply_meta_instance_placement_failed(
    event: MetaInstancePlacementFailed, state: State
) -> State:
    if event.meta_instance_id not in state.meta_instances:
        return state
    return state.model_copy(
        update={
            "meta_instances": _update_meta_instance(
                state, event.meta_instance_id, placement_error=event.reason
            )
        }
    )


def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:

@@ -3,11 +3,10 @@ from collections.abc import Generator
from typing import Annotated, Any, Literal
from uuid import uuid4

from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -227,13 +226,6 @@ class PlaceInstanceParams(BaseModel):
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1

    @field_validator("sharding", "instance_meta", mode="plain")
    @classmethod
    def use_default(cls, v: object):
        if not v or not isinstance(v, (Sharding, InstanceMeta)):
            raise PydanticUseDefault()
        return v


class CreateInstanceParams(BaseModel):
    instance: Instance
@@ -269,6 +261,26 @@ class DeleteInstanceResponse(BaseModel):
    instance_id: InstanceId


class CreateMetaInstanceParams(BaseModel):
    model_id: ModelId
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1
    node_ids: list[NodeId] | None = None


class CreateMetaInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    meta_instance_id: MetaInstanceId


class DeleteMetaInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    meta_instance_id: MetaInstanceId


class AdvancedImageParams(BaseModel):
    seed: Annotated[int, Field(ge=0)] | None = None
    num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
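
For illustration, the request body implied by `CreateMetaInstanceParams` would look like the following; the field names and defaults come straight from the model above, while the HTTP route that accepts it is not shown in this diff:

```python
# Sketch: serialising CreateMetaInstanceParams as an API client might send it.
# Only the model itself is in this diff; the endpoint accepting it is assumed.
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import CreateMetaInstanceParams
from exo.shared.types.common import NodeId

params = CreateMetaInstanceParams(
    model_id=ModelId("test-org/test-model"),
    min_nodes=2,  # sharding and instance_meta keep their declared defaults
    node_ids=[NodeId("node-a"), NodeId("node-b")],
)
print(params.model_dump_json())
```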

@@ -6,7 +6,8 @@ from exo.shared.types.api (
    ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -48,6 +49,14 @@ class DeleteInstance(BaseCommand):
    instance_id: InstanceId


class CreateMetaInstance(BaseCommand):
    meta_instance: MetaInstance


class DeleteMetaInstance(BaseCommand):
    meta_instance_id: MetaInstanceId


class TaskFinished(BaseCommand):
    finished_command_id: CommandId

@@ -89,6 +98,8 @@ Command = (
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | CreateMetaInstance
    | DeleteMetaInstance
    | TaskFinished
    | SendInputChunk
)

@@ -42,6 +42,10 @@ class CommandId(Id):
    pass


class MetaInstanceId(Id):
    """Identifier for a MetaInstance."""


class Host(CamelCaseModel):
    ip: str
    port: int

@@ -5,7 +5,8 @@ from pydantic import Field

from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -66,6 +67,30 @@ class InstanceCreated(BaseEvent):

class InstanceDeleted(BaseEvent):
    instance_id: InstanceId
    failure_error: str | None = None


class MetaInstanceCreated(BaseEvent):
    meta_instance: MetaInstance


class MetaInstanceDeleted(BaseEvent):
    meta_instance_id: MetaInstanceId


@final
class MetaInstancePlacementFailed(BaseEvent):
    meta_instance_id: MetaInstanceId
    reason: str


@final
class InstanceRetrying(BaseEvent):
    """Runners failed but retry count is below the limit — restart runners, keep instance."""

    instance_id: InstanceId
    meta_instance_id: MetaInstanceId
    failure_error: str


class RunnerStatusUpdated(BaseEvent):
@@ -141,6 +166,10 @@ Event = (
    | TaskAcknowledged
    | InstanceCreated
    | InstanceDeleted
    | InstanceRetrying
    | MetaInstanceCreated
    | MetaInstanceDeleted
    | MetaInstancePlacementFailed
    | RunnerStatusUpdated
    | RunnerDeleted
    | NodeTimedOut

src/exo/shared/types/meta_instance.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from typing import final

from pydantic import Field

from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import FrozenModel


@final
class MetaInstance(FrozenModel):
    """Declarative constraint: ensure an instance matching these parameters always exists."""

    meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
    model_id: ModelId
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1
    node_ids: list[NodeId] | None = None
    # Failure tracking
    placement_error: str | None = None
    consecutive_failures: int = 0
    last_failure_error: str | None = None
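
Since `MetaInstance` is a `FrozenModel`, the failure-tracking fields are never mutated in place; the apply handlers shown earlier replace the stored entry with a `model_copy`. A small illustration with invented values:

```python
# Sketch: how a frozen MetaInstance is "updated" -- always by copy, as in
# apply_instance_deleted / apply_meta_instance_placement_failed above.
from exo.shared.models.model_cards import ModelId
from exo.shared.types.meta_instance import MetaInstance

mi = MetaInstance(model_id=ModelId("test-org/test-model"))
failed = mi.model_copy(
    update={
        "consecutive_failures": mi.consecutive_failures + 1,
        "last_failure_error": "Runner OOM",
    }
)
assert mi.consecutive_failures == 0  # original is untouched
assert failed.consecutive_failures == 1
```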

@@ -6,7 +6,8 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
from pydantic.alias_generators import to_camel

from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
    DiskUsage,
    MemoryUsage,
@@ -41,6 +42,7 @@ class State(CamelCaseModel):
        arbitrary_types_allowed=True,
    )
    instances: Mapping[InstanceId, Instance] = {}
    meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
    runners: Mapping[RunnerId, RunnerStatus] = {}
    downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
    tasks: Mapping[TaskId, Task] = {}

@@ -2,7 +2,7 @@ from enum import Enum

from pydantic import model_validator

from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -19,6 +19,7 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
    instance_id: InstanceId
    shard_assignments: ShardAssignments
    meta_instance_id: MetaInstanceId | None = None

    def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
        return self.shard_assignments.runner_to_shard.get(runner_id, None)

@@ -34,6 +34,7 @@ from exo.shared.types.worker.runners import (
    RunnerLoading,
    RunnerReady,
    RunnerRunning,
    RunnerShutdown,
    RunnerStatus,
    RunnerWarmingUp,
)
@@ -54,7 +55,7 @@ def plan(
    # Python short circuiting OR logic should evaluate these sequentially.
    return (
        _kill_runner(runners, all_runners, instances)
        or _create_runner(node_id, runners, instances)
        or _create_runner(node_id, runners, instances, all_runners)
        or _model_needs_download(node_id, runners, global_download_status)
        or _init_distributed_backend(runners, all_runners)
        or _load_model(runners, all_runners, global_download_status)
@@ -73,6 +74,12 @@ def _kill_runner(
    if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
        return Shutdown(instance_id=instance_id, runner_id=runner_id)

    # Master removed our runner from state (retry signal) and process is dead
    if runner_id not in all_runners and isinstance(
        runner.status, (RunnerFailed, RunnerShutdown)
    ):
        return Shutdown(instance_id=instance_id, runner_id=runner_id)

    for (
        global_runner_id
    ) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
@@ -90,6 +97,7 @@ def _create_runner(
    node_id: NodeId,
    runners: Mapping[RunnerId, RunnerSupervisor],
    instances: Mapping[InstanceId, Instance],
    all_runners: Mapping[RunnerId, RunnerStatus],
) -> CreateRunner | None:
    for instance in instances.values():
        runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
@@ -99,6 +107,16 @@ def _create_runner(
        if runner_id in runners:
            continue

        # Don't create while any peer runner is in a terminal state — wait for
        # the master to emit InstanceRetrying which removes them from state.
        has_terminal_peer = any(
            isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown))
            for peer_rid in instance.shard_assignments.node_to_runner.values()
            if peer_rid != runner_id
        )
        if has_terminal_peer:
            continue

        shard = instance.shard(runner_id)
        assert shard is not None