Compare commits


1 Commit

Author      SHA1         Message              Date
Sami Khan   40cbecb5c4   optimize dashboard   2025-12-29 22:00:59 +05:00
18 changed files with 311 additions and 177 deletions

View File

@@ -198,8 +198,10 @@
stroke: oklch(0.85 0.18 85 / 0.4);
stroke-width: 1.5px;
stroke-dasharray: 8, 8;
animation: flowAnimation 1s linear infinite;
animation: flowAnimation 1.5s linear infinite;
filter: drop-shadow(0 0 3px oklch(0.85 0.18 85 / 0.5));
/* GPU optimization - hint to browser this element will animate */
will-change: stroke-dashoffset;
}
.graph-link-active {
@@ -208,6 +210,24 @@
filter: drop-shadow(0 0 6px oklch(0.85 0.18 85 / 0.8));
}
/* Reduce motion for users who prefer it - also saves GPU */
@media (prefers-reduced-motion: reduce) {
.graph-link {
animation: none;
}
.shooting-star {
animation: none;
display: none;
}
.status-pulse,
.cursor-blink,
.animate-pulse {
animation: none;
}
}
/* CRT Screen effect for topology */
.crt-screen {
position: relative;
@@ -266,13 +286,15 @@ input:focus, textarea:focus {
box-shadow: none;
}
/* Shooting Stars Animation */
/* Shooting Stars Animation - GPU optimized */
.shooting-stars {
position: fixed;
inset: 0;
overflow: hidden;
pointer-events: none;
z-index: 0;
/* Only render when visible */
content-visibility: auto;
}
.shooting-star {
@@ -285,6 +307,9 @@ input:focus, textarea:focus {
animation: shootingStar var(--duration, 3s) linear infinite;
animation-delay: var(--delay, 0s);
opacity: 0;
/* GPU optimization */
will-change: transform, opacity;
transform: translateZ(0);
}
.shooting-star::before {
@@ -320,3 +345,13 @@ input:focus, textarea:focus {
transform: translate(400px, 400px);
}
}
/* Pause animations when page is hidden to save resources */
:root:has(body[data-page-hidden="true"]) {
.shooting-star,
.graph-link,
.status-pulse,
.cursor-blink {
animation-play-state: paused;
}
}
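
The pause-when-hidden rule above keys off a data-page-hidden attribute that a visibilitychange listener toggles (see the +layout.svelte change later in this diff). A minimal standalone sketch of that wiring in TypeScript, assuming the attribute is set on document.body itself so the :root:has(body[data-page-hidden="true"]) selector can match it:

// Toggle a body attribute on tab hide/show so the CSS above can pause animations.
// Sketch only; the actual change sets the attribute on a wrapper div in +layout.svelte.
function trackPageVisibility(): () => void {
  const update = () => {
    document.body.setAttribute(
      'data-page-hidden',
      document.visibilityState === 'hidden' ? 'true' : 'false'
    );
  };
  update(); // initialise for the current state
  document.addEventListener('visibilitychange', update);
  // Return a cleanup callback for the component's onMount/onDestroy lifecycle.
  return () => document.removeEventListener('visibilitychange', update);
}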

View File

@@ -60,18 +60,11 @@
return models;
});
// Auto-select the first available model if none is selected OR if current selection is stale
// Auto-select the first available model if none is selected
$effect(() => {
const models = availableModels();
if (models.length > 0) {
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
}
});
@@ -146,11 +139,6 @@
}
function handleKeydown(event: KeyboardEvent) {
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
if (event.isComposing || event.keyCode === 229) {
return;
}
if (event.key === 'Enter' && !event.shiftKey) {
event.preventDefault();
handleSubmit();

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import { onMount, onDestroy, tick } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
@@ -12,11 +12,35 @@ import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.sv
let svgContainer: SVGSVGElement | undefined = $state();
let resizeObserver: ResizeObserver | undefined;
// Optimization: Track last render state to avoid unnecessary re-renders
let lastRenderHash = '';
let lastHighlightedNodesHash = '';
let lastDimensions = { width: 0, height: 0 };
let isRendering = false;
let pendingRender = false;
const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());
// Generate a hash of relevant data to detect actual changes
function generateDataHash(topologyData: typeof data, minimized: boolean, debug: boolean): string {
if (!topologyData) return 'null';
const nodes = topologyData.nodes || {};
const edges = topologyData.edges || [];
// Create a lightweight hash from key properties only
const nodeHashes = Object.entries(nodes).map(([id, n]) => {
const macmon = n.macmon_info;
return `${id}:${n.friendly_name || ''}:${macmon?.memory?.ram_usage || 0}:${macmon?.memory?.ram_total || 0}:${macmon?.temp?.gpu_temp_avg || 0}:${macmon?.gpu_usage?.[1] || 0}:${macmon?.sys_power || 0}`;
}).sort().join('|');
const edgeHash = edges.map(e => `${e.source}-${e.target}`).sort().join(',');
return `${nodeHashes}::${edgeHash}::${minimized}::${debug}`;
}
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
@@ -932,16 +956,59 @@ function wrapLine(text: string, maxLen: number): string[] {
}
$effect(() => {
if (data) {
// Throttled render function to prevent too-frequent updates
function scheduleRender() {
if (isRendering) {
pendingRender = true;
return;
}
isRendering = true;
requestAnimationFrame(() => {
renderGraph();
isRendering = false;
if (pendingRender) {
pendingRender = false;
scheduleRender();
}
});
}
$effect(() => {
if (!data || !svgContainer) return;
// Generate hash of current state
const currentHash = generateDataHash(data, isMinimized, debugEnabled);
const highlightHash = Array.from(highlightedNodes).sort().join(',');
// Get current dimensions
const rect = svgContainer.getBoundingClientRect();
const dimensionsChanged = rect.width !== lastDimensions.width || rect.height !== lastDimensions.height;
// Only re-render if something actually changed
if (currentHash !== lastRenderHash || highlightHash !== lastHighlightedNodesHash || dimensionsChanged) {
lastRenderHash = currentHash;
lastHighlightedNodesHash = highlightHash;
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
}
});
onMount(() => {
if (svgContainer) {
// Use a debounced resize observer to prevent rapid re-renders
let resizeTimeout: ReturnType<typeof setTimeout> | null = null;
resizeObserver = new ResizeObserver(() => {
renderGraph();
if (resizeTimeout) clearTimeout(resizeTimeout);
resizeTimeout = setTimeout(() => {
const rect = svgContainer!.getBoundingClientRect();
if (rect.width !== lastDimensions.width || rect.height !== lastDimensions.height) {
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
}
}, 100);
});
resizeObserver.observe(svgContainer);
}
@@ -969,11 +1036,20 @@ function wrapLine(text: string, maxLen: number): string[] {
stroke-width: 1px;
stroke-dasharray: 4, 4;
opacity: 0.8;
animation: flowAnimation 0.75s linear infinite;
/* Slower animation = less GPU usage */
animation: flowAnimation 2s linear infinite;
/* GPU optimization */
will-change: stroke-dashoffset;
}
@keyframes flowAnimation {
from { stroke-dashoffset: 0; }
to { stroke-dashoffset: -10; }
}
/* Respect reduced motion preference */
@media (prefers-reduced-motion: reduce) {
:global(.graph-link) {
animation: none;
}
}
</style>
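
The effect rewrite combines two guards: a cheap string key built from only the fields the graph actually draws, so D3 re-renders only when those fields change, and a requestAnimationFrame throttle that coalesces bursts of store updates into at most one render per frame. A framework-agnostic sketch of the same scheduler shape (illustrative names, not the component's actual API):

// Coalesce render requests into at most one renderGraph() call per animation frame,
// and skip scheduling entirely when the derived key has not changed.
function makeRenderScheduler(renderGraph: () => void): (key: string) => void {
  let isRendering = false;
  let pendingRender = false;
  let lastKey = '';

  function scheduleRender(): void {
    if (isRendering) {
      pendingRender = true; // a request arrived while a frame was already queued
      return;
    }
    isRendering = true;
    requestAnimationFrame(() => {
      renderGraph();
      isRendering = false;
      if (pendingRender) {
        pendingRender = false;
        scheduleRender(); // honour the request that arrived mid-frame
      }
    });
  }

  return (key: string) => {
    if (key !== lastKey) {
      lastKey = key;
      scheduleRender();
    }
  };
}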

View File

@@ -297,6 +297,35 @@ function extractIpFromMultiaddr(ma?: string): string | undefined {
return undefined;
}
// Shallow comparison utility for preventing unnecessary state updates
function shallowEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
if (a === null || b === null) return false;
if (typeof a !== 'object' || typeof b !== 'object') return false;
const aObj = a as Record<string, unknown>;
const bObj = b as Record<string, unknown>;
const aKeys = Object.keys(aObj);
const bKeys = Object.keys(bObj);
if (aKeys.length !== bKeys.length) return false;
for (const key of aKeys) {
if (aObj[key] !== bObj[key]) return false;
}
return true;
}
// Faster JSON comparison for complex nested objects
function jsonEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
try {
return JSON.stringify(a) === JSON.stringify(b);
} catch {
return false;
}
}
class AppStore {
// Conversation state
conversations = $state<Conversation[]>([]);
@@ -330,9 +359,18 @@ class AppStore {
topologyOnlyMode = $state(false);
chatSidebarVisible = $state(true); // Shown by default
// Visibility state - used to pause polling when tab is hidden
private isPageVisible = true;
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
// Cache for comparison - prevents unnecessary reactivity
private lastTopologyJson = '';
private lastInstancesJson = '';
private lastRunnersJson = '';
private lastDownloadsJson = '';
constructor() {
if (browser) {
@@ -341,9 +379,26 @@ class AppStore {
this.loadDebugModeFromStorage();
this.loadTopologyOnlyModeFromStorage();
this.loadChatSidebarVisibleFromStorage();
this.setupVisibilityListener();
}
}
/**
* Listen for page visibility changes to pause polling when hidden
*/
private setupVisibilityListener() {
if (typeof document === 'undefined') return;
document.addEventListener('visibilitychange', () => {
this.isPageVisible = document.visibilityState === 'visible';
if (this.isPageVisible) {
// Resume polling when page becomes visible
this.fetchState();
}
});
}
/**
* Load conversations from localStorage
*/
@@ -770,7 +825,9 @@ class AppStore {
startPolling() {
this.fetchState();
this.fetchInterval = setInterval(() => this.fetchState(), 1000);
// Poll every 2 seconds instead of 1 second - halves the polling frequency
// Data comparison ensures we only update when something actually changes
this.fetchInterval = setInterval(() => this.fetchState(), 2000);
}
stopPolling() {
@@ -782,6 +839,9 @@ class AppStore {
}
async fetchState() {
// Skip polling when page is hidden to save resources
if (!this.isPageVisible) return;
try {
const response = await fetch('/state');
if (!response.ok) {
@@ -789,19 +849,44 @@ class AppStore {
}
const data: RawStateResponse = await response.json();
// Only update topology if it actually changed (prevents unnecessary D3 re-renders)
if (data.topology) {
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
const newTopology = transformTopology(data.topology, data.nodeProfiles);
const newTopologyJson = JSON.stringify(newTopology);
if (newTopologyJson !== this.lastTopologyJson) {
this.lastTopologyJson = newTopologyJson;
this.topologyData = newTopology;
}
}
// Only update instances if changed
if (data.instances) {
this.instances = data.instances;
this.refreshConversationModelFromInstances();
const newInstancesJson = JSON.stringify(data.instances);
if (newInstancesJson !== this.lastInstancesJson) {
this.lastInstancesJson = newInstancesJson;
this.instances = data.instances;
this.refreshConversationModelFromInstances();
}
}
// Only update runners if changed
if (data.runners) {
this.runners = data.runners;
const newRunnersJson = JSON.stringify(data.runners);
if (newRunnersJson !== this.lastRunnersJson) {
this.lastRunnersJson = newRunnersJson;
this.runners = data.runners;
}
}
// Only update downloads if changed
if (data.downloads) {
this.downloads = data.downloads;
const newDownloadsJson = JSON.stringify(data.downloads);
if (newDownloadsJson !== this.lastDownloadsJson) {
this.lastDownloadsJson = newDownloadsJson;
this.downloads = data.downloads;
}
}
this.lastUpdate = Date.now();
} catch (error) {
console.error('Error fetching state:', error);
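
fetchState now serialises each slice of the /state response and assigns it into reactive state only when the JSON differs from the previous poll, so Svelte reactivity (and the D3 effect above) stays quiet while nothing has changed. A reduced sketch of that compare-before-assign idea, with hypothetical names:

// Keep the last serialised snapshot per key; assign only when the payload changed.
const lastJson = new Map<string, string>();

function assignIfChanged<T>(key: string, next: T, assign: (value: T) => void): boolean {
  const json = JSON.stringify(next);
  if (lastJson.get(key) === json) return false; // identical payload: skip reactivity
  lastJson.set(key, json);
  assign(next);
  return true;
}

// Hypothetical usage inside the polling handler:
// assignIfChanged('instances', data.instances, v => { this.instances = v; });
// assignIfChanged('runners', data.runners, v => { this.runners = v; });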

View File

@@ -1,7 +1,25 @@
<script lang="ts">
import '../app.css';
import { onMount } from 'svelte';
import { browser } from '$app/environment';
let { children } = $props();
let isPageHidden = $state(false);
onMount(() => {
if (!browser) return;
// Listen for visibility changes to pause animations when hidden
const handleVisibilityChange = () => {
isPageHidden = document.visibilityState === 'hidden';
};
document.addEventListener('visibilitychange', handleVisibilityChange);
return () => {
document.removeEventListener('visibilitychange', handleVisibilityChange);
};
});
</script>
<svelte:head>
@@ -9,7 +27,7 @@
<meta name="description" content="EXO - Distributed AI Cluster Dashboard" />
</svelte:head>
<div class="min-h-screen bg-background text-foreground">
<div class="min-h-screen bg-background text-foreground" data-page-hidden={isPageHidden}>
{@render children?.()}
</div>

View File

@@ -99,17 +99,35 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Compute highlighted nodes from hovered instance or hovered preview
// Memoized to avoid creating new Sets on every render
let lastHighlightedNodesKey = '';
let cachedHighlightedNodes: Set<string> = new Set();
const highlightedNodes = $derived(() => {
// Create a key for the current state to enable memoization
const previewKey = Array.from(hoveredPreviewNodes).sort().join(',');
const currentKey = `${hoveredInstanceId || 'null'}:${previewKey}`;
// Return cached value if nothing changed
if (currentKey === lastHighlightedNodesKey) {
return cachedHighlightedNodes;
}
lastHighlightedNodesKey = currentKey;
// First check instance hover
if (hoveredInstanceId) {
const instanceWrapped = instanceData[hoveredInstanceId];
return unwrapInstanceNodes(instanceWrapped);
cachedHighlightedNodes = unwrapInstanceNodes(instanceWrapped);
return cachedHighlightedNodes;
}
// Then check preview hover
if (hoveredPreviewNodes.size > 0) {
return hoveredPreviewNodes;
cachedHighlightedNodes = hoveredPreviewNodes;
return cachedHighlightedNodes;
}
return new Set<string>();
cachedHighlightedNodes = new Set<string>();
return cachedHighlightedNodes;
});
// Helper to estimate memory from model ID (mirrors ModelCard logic)
@@ -516,12 +534,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
};
}
// Debug: Log downloads data when it changes
$effect(() => {
if (downloadsData && Object.keys(downloadsData).length > 0) {
console.log('[Download Debug] Current downloads:', downloadsData);
}
});
// Debug: Log downloads data when it changes (disabled in production for performance)
// Uncomment for debugging:
// $effect(() => {
// if (downloadsData && Object.keys(downloadsData).length > 0) {
// console.log('[Download Debug] Current downloads:', downloadsData);
// }
// });
// Helper to get download status for an instance
function getInstanceDownloadStatus(instanceId: string, instanceWrapped: unknown): {
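
Returning the cached Set when the hover key has not changed keeps the derived value referentially stable, so dependent effects are not re-triggered by fresh Set allocations on every pointer move. A stripped-down sketch of that memoisation shape (illustrative helper, not the page's actual code):

// Memoise a derived Set behind a cheap string key; reuse the same Set instance
// whenever the key repeats so downstream reference checks stay cheap.
function memoizeSet<T>(compute: () => Set<T>, keyOf: () => string): () => Set<T> {
  let lastKey = '';
  let cached = new Set<T>();
  return () => {
    const key = keyOf();
    if (key !== lastKey) {
      lastKey = key;
      cached = compute();
    }
    return cached;
  };
}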

View File

@@ -1,6 +1,5 @@
import argparse
import multiprocessing as mp
import os
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -195,7 +194,6 @@ def main():
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
node = anyio.run(Node.create, args)
anyio.run(node.run)

View File

@@ -21,7 +21,6 @@ from exo.shared.types.commands import (
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -30,7 +29,6 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
@@ -67,28 +65,6 @@ def place_instance(
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [

View File

@@ -385,14 +385,13 @@ def get_mlx_jaccl_coordinators(
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
return ip
logger.warning(

View File

@@ -50,7 +50,7 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
hidden_size=10,
supports_tensor=True,
)

View File

@@ -450,11 +450,6 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# TODO: 'Smart' downloads are disabled because:
# (i) We don't handle all kinds of files;
# (ii) We don't have sticky sessions.
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)

View File

@@ -9,7 +9,7 @@ MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = None
KV_CACHE_BITS: int | None = 8
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

View File

@@ -395,5 +395,11 @@ def set_wired_limit_for_model(model_size: Memory):
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
kv_bytes = int(0.02 * model_bytes)
target_cache = int(1.10 * (model_bytes + kv_bytes))
target_cache = min(target_cache, max_rec_size)
mx.set_cache_limit(target_cache)
mx.set_wired_limit(max_rec_size)
logger.info(f"Wired limit set to {max_rec_size}.")
logger.info(
f"Wired limit set to {max_rec_size}. Cache limit set to {target_cache}."
)
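
The new cache limit is derived from the model size: the KV cache is estimated at 2% of the model bytes, the target cache is 110% of model-plus-KV, and the result is capped at max_rec_size, which is still used as the wired limit. A worked illustration of that arithmetic as a TypeScript sketch with made-up numbers; the real code passes the results to mx.set_cache_limit and mx.set_wired_limit:

// Mirrors the cache-limit arithmetic from set_wired_limit_for_model (illustrative only).
function targetCacheBytes(modelBytes: number, maxRecSize: number): number {
  const kvBytes = Math.floor(0.02 * modelBytes);             // ~2% KV-cache estimate
  const target = Math.floor(1.10 * (modelBytes + kvBytes));  // 10% headroom over model + KV
  return Math.min(target, maxRecSize);                       // never exceed the recommended max
}

// e.g. a 40 GB model on a device whose recommended working set is 48 GB:
// kv ≈ 0.8 GB, target ≈ 44.9 GB → cache limit ≈ 44.9 GB, wired limit stays at 48 GB.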

View File

@@ -23,7 +23,6 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
@@ -84,7 +83,7 @@ class Worker:
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup | None = None
@@ -129,7 +128,6 @@ class Worker:
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -202,11 +200,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_meta.model_id not in self.download_status:
if shard not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -219,7 +217,7 @@ class Worker:
progress = DownloadCompleted(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -351,7 +349,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.download_status[task.shard_metadata] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -365,7 +363,7 @@ class Worker:
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard] = status
# Footgun!
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
@@ -386,7 +384,7 @@ class Worker:
progress
),
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard] = status
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
@@ -446,40 +444,3 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id, shard_metadata=progress.shard
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")
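
download_status is now keyed by the shard itself rather than by model id, presumably so that progress for different shards of the same model no longer collapses onto a single entry, and the periodic _emit_existing_download_progress loop is dropped. A toy TypeScript illustration of why the key granularity matters (hypothetical types and keys):

type ShardKey = string; // e.g. "modelId:deviceRank/worldSize"
type DownloadStatus = 'pending' | 'ongoing' | 'complete';

// Keyed per shard: rank 0 and rank 1 of the same model track progress independently.
const downloadStatus = new Map<ShardKey, DownloadStatus>();
downloadStatus.set('model-a:0/2', 'complete');
downloadStatus.set('model-a:1/2', 'ongoing');
// With a plain model-id key, the second set() would have clobbered the first entry.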

View File

@@ -3,7 +3,6 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -35,6 +34,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -43,7 +43,7 @@ def plan(
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ModelId, DownloadProgress],
download_status: Mapping[ShardMetadata, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
@@ -111,14 +111,13 @@ def _create_runner(
def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
download_status: Mapping[ShardMetadata, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
not isinstance(
download_status.get(runner.bound_instance.bound_shard, None),
(DownloadOngoing, DownloadCompleted),
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
@@ -236,8 +235,9 @@ def _ready_to_warmup(
assert device_rank < world_size
assert device_rank >= 0
# Rank != 0
accepting_ranks_ready = device_rank > 0 and all(
# TODO: Ensure these align with MLX distributed's expectations.
# Rank < n-1
accepting_ranks_ready = device_rank < world_size - 1 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
@@ -245,8 +245,8 @@ def _ready_to_warmup(
for global_runner_id in shard_assignments.runner_to_shard
)
# Rank = 0
connecting_rank_ready = device_rank == 0 and all(
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
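
The readiness check is inverted relative to the old code: every rank except the last (the "accepting" ranks, device_rank < world_size - 1) may start warmup once all runners in the instance are Loaded or WarmingUp, while the final rank (the "connecting" rank, device_rank == world_size - 1) waits until every other runner is already WarmingUp. A compact TypeScript sketch of that predicate (illustrative names, not the plan() signature):

type RunnerStatus = 'Idle' | 'Loaded' | 'WarmingUp' | 'Connected';

// Warmup-readiness rule after this change (sketch, not the actual _ready_to_warmup code).
function readyToWarmup(
  deviceRank: number,
  worldSize: number,
  statuses: Map<string, RunnerStatus>, // every runner in the instance, keyed by runner id
  selfRunnerId: string
): boolean {
  if (deviceRank < worldSize - 1) {
    // Accepting ranks: all runners (including this one) must be Loaded or WarmingUp.
    return [...statuses.values()].every(s => s === 'Loaded' || s === 'WarmingUp');
  }
  // Connecting rank (the last one): every other runner must already be WarmingUp.
  return [...statuses].every(([id, s]) => id === selfRunnerId || s === 'WarmingUp');
}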

View File

@@ -9,11 +9,9 @@ MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
NODE_C: Final[NodeId] = NodeId("cccccccc-cccc-4ccc-8ccc-cccccccccccc")
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
RUNNER_3_ID: Final[RunnerId] = RunnerId("Runner3")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")

View File

@@ -1,6 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
@@ -8,6 +7,7 @@ from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerIdle,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
@@ -46,7 +46,7 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ModelId, DownloadProgress] = {}
download_status: dict[ShardMetadata, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
@@ -94,7 +94,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
}
# Global view has completed downloads for both nodes
@@ -140,7 +140,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,7 +192,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],

View File

@@ -12,10 +12,8 @@ from exo.worker.tests.constants import (
MODEL_A_ID,
NODE_A,
NODE_B,
NODE_C,
RUNNER_1_ID,
RUNNER_2_ID,
RUNNER_3_ID,
)
from exo.worker.tests.unittests.conftest import (
FakeRunnerSupervisor,
@@ -26,39 +24,37 @@ from exo.worker.tests.unittests.conftest import (
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
"""
For non-zero device_rank shards, StartWarmup should be emitted when all
For non-final device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
RUNNER_3_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_B,
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
@@ -154,9 +150,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
For accepting ranks (device_rank != 0), StartWarmup should be
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 1 is the accepting rank.
In a 2-node setup, rank 0 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -167,7 +163,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the accepting rank
# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -192,23 +188,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
tasks={},
)
assert result is None
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
@@ -301,8 +280,9 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == 0) should not start warmup
Connecting rank (device_rank == world_size - 1) should not start warmup
until all other ranks are already WarmingUp.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -315,13 +295,13 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
@@ -329,7 +309,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
}
result = plan_mod.plan(
node_id=NODE_A,
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},