Compare commits

..

10 Commits

Author SHA1 Message Date
Sami Khan
046c80c299 resolve issue #1070 2026-01-01 05:01:25 +05:00
RickyChen / 陳昭儒
844bcc7ce6 fix: prevent form submission during IME composition (#1069)
## Problem
When typing in Chinese (or other IME-based languages like
Japanese/Korean), pressing Enter to select a character from the IME
candidate list would incorrectly submit the message instead of
confirming the character selection.

## Solution
Added IME composition state detection in the `handleKeydown` function in
`ChatForm.svelte`:
- Check `event.isComposing` to detect active IME composition
- Fall back to `event.keyCode === 229` for broader browser compatibility
- Return early when IME is active, allowing normal character selection

## Changes
- Modified `dashboard/src/lib/components/ChatForm.svelte` 
- Added IME composition check before Enter key handling

Co-authored-by: Ricky Chen <rickychen@Rickys-MacBook-Pro.local>
2025-12-31 17:11:04 +00:00
Evan Quiney
c1be5184b2 Fix tests broken by 283c (#1063)
Some tests were broken by #1058 and #1046 - this fixes them.
2025-12-31 01:53:55 +00:00
Alex Cheema
1ec550dff1 Emit download progress on start, and change downloads to be keyed by model_id (#1044)
## Motivation

We added a download page to the dashboard which shows the current
download status of each model on each node. Users have reported this to
be extremely useful.

However, we don't currently fetch the download progress on start, so it
doesn't show any model's download status.

## Changes

Fetch and emit model download status when the worker starts, and
periodically every 5 minutes thereafter.
To support this, download_status is now keyed by model_id instead of by
shard, since we want the download status of each model, not of each
shard.
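The start-and-periodic re-emit described above can be sketched as follows. This is a minimal illustration, not the actual exo worker code: `get_download_status` and `emit_status` are hypothetical stand-ins for the downloader query and the event sender.

```python
import asyncio

async def emit_download_progress(get_download_status, emit_status, interval=5 * 60):
    # Emit the full download status map once on start, then re-emit
    # periodically. The map is keyed by model_id rather than by shard.
    while True:
        for model_id, progress in (await get_download_status()).items():
            emit_status(model_id, progress)
        await asyncio.sleep(interval)
```

Because the loop runs once immediately before sleeping, the dashboard gets a complete status snapshot right after the worker starts.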

## Why It Works

The dashboard already implements the correct functionality; we just
weren't populating the download status in the state. Now it gets
populated and displays correctly.

## Test Plan

### Manual Testing
On a cluster of 2 x 512GB M3 Ultra Mac Studios, I launched an instance
of a model that hadn't been downloaded onto one node. I checked the
download page and it showed the in-progress download. I let the download
run to completion, restarted exo on both nodes, and then opened the
download page: it showed that model as 100% downloaded, and the models
that hadn't been downloaded as 0%.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2025-12-31 01:18:10 +00:00
Alex Cheema
283c0e39e4 Placement filters for tensor parallel supports_tensor, tensor dimension and pipeline parallel deepseek v3.1 (#1058)
## Motivation

Certain placements are not valid. Added filters to exclude these placements. Invalid placement previews were being shown in the dashboard, which would then fail when the user actually tried to launch an instance with that placement.


## Changes

Three filters added:

1. Certain models do not support tensor parallel at all. Checks `supports_tensor` on the model_meta.
2. For models that do support tensor parallelism, certain tensor parallel sizes are not valid. This check is not strictly correct yet, but it works well enough for now; the fully correct check is more involved.
3. For unknown reasons, deepseek v3.1 (8-bit) does not work with pipeline parallelism.
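Filter 2 amounts to a divisibility check on the candidate cycles; a minimal sketch (function and argument names are illustrative, not the actual exo code):

```python
def filter_tensor_parallel_cycles(cycles: list[list[str]], hidden_size: int) -> list[list[str]]:
    """Keep only candidate cycles whose node count evenly divides hidden_size,
    since tensor parallelism splits the hidden dimension across ranks."""
    return [cycle for cycle in cycles if hidden_size % len(cycle) == 0]
```

For example, with `hidden_size=10`, a 2-node cycle is a valid tensor parallel placement but a 3-node cycle is not.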

## Why It Works

`place_instance` now raises a `ValueError` for invalid placements.

## Test Plan

### Manual Testing
Since `/instance/previews` enumerates all possible placements and runs `place_instance`, I checked the dashboard to verify that invalid placements are no longer shown.
2025-12-31 00:33:40 +00:00
Alex Cheema
35be4c55c3 prioritise mlx jaccl coordinator ip (en0 -> en1 -> non-TB5 -> other) 2025-12-31 00:10:19 +00:00
Alex Cheema
31d4cd8409 set KV_CACHE_BITS to None to disable quantized kv cache 2025-12-31 00:03:30 +00:00
Alex Cheema
8a6da58404 remove mx.set_cache_limit 2025-12-30 23:58:15 +00:00
Alex Cheema
16e2bfd3b3 log EXO_LIBP2P_NAMESPACE on start 2025-12-30 04:08:47 +00:00
Alex Cheema
ade3ee7ec5 fix warmup order. should be rank!=0 then rank=0 2025-12-30 03:29:34 +00:00
18 changed files with 177 additions and 311 deletions

View File

@@ -198,10 +198,8 @@
stroke: oklch(0.85 0.18 85 / 0.4);
stroke-width: 1.5px;
stroke-dasharray: 8, 8;
animation: flowAnimation 1.5s linear infinite;
animation: flowAnimation 1s linear infinite;
filter: drop-shadow(0 0 3px oklch(0.85 0.18 85 / 0.5));
/* GPU optimization - hint to browser this element will animate */
will-change: stroke-dashoffset;
}
.graph-link-active {
@@ -210,24 +208,6 @@
filter: drop-shadow(0 0 6px oklch(0.85 0.18 85 / 0.8));
}
/* Reduce motion for users who prefer it - also saves GPU */
@media (prefers-reduced-motion: reduce) {
.graph-link {
animation: none;
}
.shooting-star {
animation: none;
display: none;
}
.status-pulse,
.cursor-blink,
.animate-pulse {
animation: none;
}
}
/* CRT Screen effect for topology */
.crt-screen {
position: relative;
@@ -286,15 +266,13 @@ input:focus, textarea:focus {
box-shadow: none;
}
/* Shooting Stars Animation - GPU optimized */
/* Shooting Stars Animation */
.shooting-stars {
position: fixed;
inset: 0;
overflow: hidden;
pointer-events: none;
z-index: 0;
/* Only render when visible */
content-visibility: auto;
}
.shooting-star {
@@ -307,9 +285,6 @@ input:focus, textarea:focus {
animation: shootingStar var(--duration, 3s) linear infinite;
animation-delay: var(--delay, 0s);
opacity: 0;
/* GPU optimization */
will-change: transform, opacity;
transform: translateZ(0);
}
.shooting-star::before {
@@ -345,13 +320,3 @@ input:focus, textarea:focus {
transform: translate(400px, 400px);
}
}
/* Pause animations when page is hidden to save resources */
:root:has(body[data-page-hidden="true"]) {
.shooting-star,
.graph-link,
.status-pulse,
.cursor-blink {
animation-play-state: paused;
}
}

View File

@@ -60,11 +60,18 @@
return models;
});
// Auto-select the first available model if none is selected
// Auto-select the first available model if none is selected OR if current selection is stale
$effect(() => {
const models = availableModels();
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
if (models.length > 0) {
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
}
});
@@ -139,6 +146,11 @@
}
function handleKeydown(event: KeyboardEvent) {
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
if (event.isComposing || event.keyCode === 229) {
return;
}
if (event.key === 'Enter' && !event.shiftKey) {
event.preventDefault();
handleSubmit();

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { onMount, onDestroy, tick } from 'svelte';
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
@@ -12,35 +12,11 @@ import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.sv
let svgContainer: SVGSVGElement | undefined = $state();
let resizeObserver: ResizeObserver | undefined;
// Optimization: Track last render state to avoid unnecessary re-renders
let lastRenderHash = '';
let lastHighlightedNodesHash = '';
let lastDimensions = { width: 0, height: 0 };
let isRendering = false;
let pendingRender = false;
const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());
// Generate a hash of relevant data to detect actual changes
function generateDataHash(topologyData: typeof data, minimized: boolean, debug: boolean): string {
if (!topologyData) return 'null';
const nodes = topologyData.nodes || {};
const edges = topologyData.edges || [];
// Create a lightweight hash from key properties only
const nodeHashes = Object.entries(nodes).map(([id, n]) => {
const macmon = n.macmon_info;
return `${id}:${n.friendly_name || ''}:${macmon?.memory?.ram_usage || 0}:${macmon?.memory?.ram_total || 0}:${macmon?.temp?.gpu_temp_avg || 0}:${macmon?.gpu_usage?.[1] || 0}:${macmon?.sys_power || 0}`;
}).sort().join('|');
const edgeHash = edges.map(e => `${e.source}-${e.target}`).sort().join(',');
return `${nodeHashes}::${edgeHash}::${minimized}::${debug}`;
}
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
@@ -956,59 +932,16 @@ function wrapLine(text: string, maxLen: number): string[] {
}
// Throttled render function to prevent too-frequent updates
function scheduleRender() {
if (isRendering) {
pendingRender = true;
return;
}
isRendering = true;
requestAnimationFrame(() => {
renderGraph();
isRendering = false;
if (pendingRender) {
pendingRender = false;
scheduleRender();
}
});
}
$effect(() => {
if (!data || !svgContainer) return;
// Generate hash of current state
const currentHash = generateDataHash(data, isMinimized, debugEnabled);
const highlightHash = Array.from(highlightedNodes).sort().join(',');
// Get current dimensions
const rect = svgContainer.getBoundingClientRect();
const dimensionsChanged = rect.width !== lastDimensions.width || rect.height !== lastDimensions.height;
// Only re-render if something actually changed
if (currentHash !== lastRenderHash || highlightHash !== lastHighlightedNodesHash || dimensionsChanged) {
lastRenderHash = currentHash;
lastHighlightedNodesHash = highlightHash;
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
if (data) {
renderGraph();
}
});
onMount(() => {
if (svgContainer) {
// Use a debounced resize observer to prevent rapid re-renders
let resizeTimeout: ReturnType<typeof setTimeout> | null = null;
resizeObserver = new ResizeObserver(() => {
if (resizeTimeout) clearTimeout(resizeTimeout);
resizeTimeout = setTimeout(() => {
const rect = svgContainer!.getBoundingClientRect();
if (rect.width !== lastDimensions.width || rect.height !== lastDimensions.height) {
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
}
}, 100);
renderGraph();
});
resizeObserver.observe(svgContainer);
}
@@ -1036,20 +969,11 @@ function wrapLine(text: string, maxLen: number): string[] {
stroke-width: 1px;
stroke-dasharray: 4, 4;
opacity: 0.8;
/* Slower animation = less GPU usage */
animation: flowAnimation 2s linear infinite;
/* GPU optimization */
will-change: stroke-dashoffset;
animation: flowAnimation 0.75s linear infinite;
}
@keyframes flowAnimation {
from { stroke-dashoffset: 0; }
to { stroke-dashoffset: -10; }
}
/* Respect reduced motion preference */
@media (prefers-reduced-motion: reduce) {
:global(.graph-link) {
animation: none;
}
}
</style>

View File

@@ -297,35 +297,6 @@ function extractIpFromMultiaddr(ma?: string): string | undefined {
return undefined;
}
// Deep comparison utility for preventing unnecessary state updates
function shallowEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
if (a === null || b === null) return false;
if (typeof a !== 'object' || typeof b !== 'object') return false;
const aObj = a as Record<string, unknown>;
const bObj = b as Record<string, unknown>;
const aKeys = Object.keys(aObj);
const bKeys = Object.keys(bObj);
if (aKeys.length !== bKeys.length) return false;
for (const key of aKeys) {
if (aObj[key] !== bObj[key]) return false;
}
return true;
}
// Faster JSON comparison for complex nested objects
function jsonEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
try {
return JSON.stringify(a) === JSON.stringify(b);
} catch {
return false;
}
}
class AppStore {
// Conversation state
conversations = $state<Conversation[]>([]);
@@ -359,18 +330,9 @@ class AppStore {
topologyOnlyMode = $state(false);
chatSidebarVisible = $state(true); // Shown by default
// Visibility state - used to pause polling when tab is hidden
private isPageVisible = true;
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
// Cache for comparison - prevents unnecessary reactivity
private lastTopologyJson = '';
private lastInstancesJson = '';
private lastRunnersJson = '';
private lastDownloadsJson = '';
constructor() {
if (browser) {
@@ -379,26 +341,9 @@ class AppStore {
this.loadDebugModeFromStorage();
this.loadTopologyOnlyModeFromStorage();
this.loadChatSidebarVisibleFromStorage();
this.setupVisibilityListener();
}
}
/**
* Listen for page visibility changes to pause polling when hidden
*/
private setupVisibilityListener() {
if (typeof document === 'undefined') return;
document.addEventListener('visibilitychange', () => {
this.isPageVisible = document.visibilityState === 'visible';
if (this.isPageVisible) {
// Resume polling when page becomes visible
this.fetchState();
}
});
}
/**
* Load conversations from localStorage
*/
@@ -825,9 +770,7 @@ class AppStore {
startPolling() {
this.fetchState();
// Poll every 2 seconds instead of 1 second - reduces CPU/GPU load by 50%
// Data comparison ensures we only update when something actually changes
this.fetchInterval = setInterval(() => this.fetchState(), 2000);
this.fetchInterval = setInterval(() => this.fetchState(), 1000);
}
stopPolling() {
@@ -839,9 +782,6 @@ class AppStore {
}
async fetchState() {
// Skip polling when page is hidden to save resources
if (!this.isPageVisible) return;
try {
const response = await fetch('/state');
if (!response.ok) {
@@ -849,44 +789,19 @@ class AppStore {
}
const data: RawStateResponse = await response.json();
// Only update topology if it actually changed (prevents unnecessary D3 re-renders)
if (data.topology) {
const newTopology = transformTopology(data.topology, data.nodeProfiles);
const newTopologyJson = JSON.stringify(newTopology);
if (newTopologyJson !== this.lastTopologyJson) {
this.lastTopologyJson = newTopologyJson;
this.topologyData = newTopology;
}
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
}
// Only update instances if changed
if (data.instances) {
const newInstancesJson = JSON.stringify(data.instances);
if (newInstancesJson !== this.lastInstancesJson) {
this.lastInstancesJson = newInstancesJson;
this.instances = data.instances;
this.refreshConversationModelFromInstances();
}
this.instances = data.instances;
this.refreshConversationModelFromInstances();
}
// Only update runners if changed
if (data.runners) {
const newRunnersJson = JSON.stringify(data.runners);
if (newRunnersJson !== this.lastRunnersJson) {
this.lastRunnersJson = newRunnersJson;
this.runners = data.runners;
}
this.runners = data.runners;
}
// Only update downloads if changed
if (data.downloads) {
const newDownloadsJson = JSON.stringify(data.downloads);
if (newDownloadsJson !== this.lastDownloadsJson) {
this.lastDownloadsJson = newDownloadsJson;
this.downloads = data.downloads;
}
this.downloads = data.downloads;
}
this.lastUpdate = Date.now();
} catch (error) {
console.error('Error fetching state:', error);

View File

@@ -1,25 +1,7 @@
<script lang="ts">
import '../app.css';
import { onMount } from 'svelte';
import { browser } from '$app/environment';
let { children } = $props();
let isPageHidden = $state(false);
onMount(() => {
if (!browser) return;
// Listen for visibility changes to pause animations when hidden
const handleVisibilityChange = () => {
isPageHidden = document.visibilityState === 'hidden';
};
document.addEventListener('visibilitychange', handleVisibilityChange);
return () => {
document.removeEventListener('visibilitychange', handleVisibilityChange);
};
});
</script>
<svelte:head>
@@ -27,7 +9,7 @@
<meta name="description" content="EXO - Distributed AI Cluster Dashboard" />
</svelte:head>
<div class="min-h-screen bg-background text-foreground" data-page-hidden={isPageHidden}>
<div class="min-h-screen bg-background text-foreground">
{@render children?.()}
</div>

View File

@@ -99,35 +99,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Compute highlighted nodes from hovered instance or hovered preview
// Memoized to avoid creating new Sets on every render
let lastHighlightedNodesKey = '';
let cachedHighlightedNodes: Set<string> = new Set();
const highlightedNodes = $derived(() => {
// Create a key for the current state to enable memoization
const previewKey = Array.from(hoveredPreviewNodes).sort().join(',');
const currentKey = `${hoveredInstanceId || 'null'}:${previewKey}`;
// Return cached value if nothing changed
if (currentKey === lastHighlightedNodesKey) {
return cachedHighlightedNodes;
}
lastHighlightedNodesKey = currentKey;
// First check instance hover
if (hoveredInstanceId) {
const instanceWrapped = instanceData[hoveredInstanceId];
cachedHighlightedNodes = unwrapInstanceNodes(instanceWrapped);
return cachedHighlightedNodes;
return unwrapInstanceNodes(instanceWrapped);
}
// Then check preview hover
if (hoveredPreviewNodes.size > 0) {
cachedHighlightedNodes = hoveredPreviewNodes;
return cachedHighlightedNodes;
return hoveredPreviewNodes;
}
cachedHighlightedNodes = new Set<string>();
return cachedHighlightedNodes;
return new Set<string>();
});
// Helper to estimate memory from model ID (mirrors ModelCard logic)
@@ -534,13 +516,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
};
}
// Debug: Log downloads data when it changes (disabled in production for performance)
// Uncomment for debugging:
// $effect(() => {
// if (downloadsData && Object.keys(downloadsData).length > 0) {
// console.log('[Download Debug] Current downloads:', downloadsData);
// }
// });
// Debug: Log downloads data when it changes
$effect(() => {
if (downloadsData && Object.keys(downloadsData).length > 0) {
console.log('[Download Debug] Current downloads:', downloadsData);
}
});
// Helper to get download status for an instance
function getInstanceDownloadStatus(instanceId: string, instanceWrapped: unknown): {

View File

@@ -1,5 +1,6 @@
import argparse
import multiprocessing as mp
import os
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -194,6 +195,7 @@ def main():
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
node = anyio.run(Node.create, args)
anyio.run(node.run)

View File

@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -29,6 +30,7 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
@@ -65,6 +67,28 @@ def place_instance(
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [

View File

@@ -385,13 +385,14 @@ def get_mlx_jaccl_coordinators(
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(

View File

@@ -50,7 +50,7 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=10,
hidden_size=30,
supports_tensor=True,
)

View File

@@ -450,6 +450,11 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# TODO: 'Smart' downloads are disabled because:
# (i) We don't handle all kinds of files;
# (ii) We don't have sticky sessions.
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)

View File

@@ -9,7 +9,7 @@ MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
KV_CACHE_BITS: int | None = None
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

View File

@@ -395,11 +395,5 @@ def set_wired_limit_for_model(model_size: Memory):
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
kv_bytes = int(0.02 * model_bytes)
target_cache = int(1.10 * (model_bytes + kv_bytes))
target_cache = min(target_cache, max_rec_size)
mx.set_cache_limit(target_cache)
mx.set_wired_limit(max_rec_size)
logger.info(
f"Wired limit set to {max_rec_size}. Cache limit set to {target_cache}."
)
logger.info(f"Wired limit set to {max_rec_size}.")

View File

@@ -23,6 +23,7 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
@@ -83,7 +84,7 @@ class Worker:
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup | None = None
@@ -128,6 +129,7 @@ class Worker:
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -200,11 +202,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard not in self.download_status:
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -217,7 +219,7 @@ class Worker:
progress = DownloadCompleted(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -349,7 +351,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata] = status
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -363,7 +365,7 @@ class Worker:
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
self.download_status[shard] = status
self.download_status[shard.model_meta.model_id] = status
# Footgun!
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
@@ -384,7 +386,7 @@ class Worker:
progress
),
)
self.download_status[shard] = status
self.download_status[shard.model_meta.model_id] = status
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
@@ -444,3 +446,40 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id, shard_metadata=progress.shard
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -3,6 +3,7 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -34,7 +35,6 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -43,7 +43,7 @@ def plan(
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ShardMetadata, DownloadProgress],
download_status: Mapping[ModelId, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
@@ -111,13 +111,14 @@ def _create_runner(
def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ShardMetadata, DownloadProgress],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
not isinstance(
download_status.get(runner.bound_instance.bound_shard, None),
(DownloadOngoing, DownloadCompleted),
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
@@ -235,9 +236,8 @@ def _ready_to_warmup(
assert device_rank < world_size
assert device_rank >= 0
# TODO: Ensure these align with MLX distributeds expectations.
# Rank < n-1
accepting_ranks_ready = device_rank < world_size - 1 and all(
# Rank != 0
accepting_ranks_ready = device_rank > 0 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
@@ -245,8 +245,8 @@ def _ready_to_warmup(
for global_runner_id in shard_assignments.runner_to_shard
)
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
# Rank = 0
connecting_rank_ready = device_rank == 0 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id

View File

@@ -9,9 +9,11 @@ MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
NODE_C: Final[NodeId] = NodeId("cccccccc-cccc-4ccc-8ccc-cccccccccccc")
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
RUNNER_3_ID: Final[RunnerId] = RunnerId("Runner3")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")

View File

@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
@@ -7,7 +8,6 @@ from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerIdle,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
@@ -46,7 +46,7 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ShardMetadata, DownloadProgress] = {}
download_status: dict[ModelId, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
@@ -94,7 +94,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
}
# Global view has completed downloads for both nodes
@@ -140,7 +140,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Local status claims the shard is downloaded already
local_download_status = {
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,7 +192,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],

View File

@@ -12,8 +12,10 @@ from exo.worker.tests.constants import (
MODEL_A_ID,
NODE_A,
NODE_B,
NODE_C,
RUNNER_1_ID,
RUNNER_2_ID,
RUNNER_3_ID,
)
from exo.worker.tests.unittests.conftest import (
FakeRunnerSupervisor,
@@ -24,37 +26,39 @@ from exo.worker.tests.unittests.conftest import (
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
"""
For non-final device_rank shards, StartWarmup should be emitted when all
For non-zero device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
RUNNER_3_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
@@ -150,9 +154,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
For accepting ranks (device_rank != 0), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
In a 2-node setup, rank 1 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -163,7 +167,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
# Rank 1 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -188,6 +192,23 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
tasks={},
)
assert result is None
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
@@ -280,9 +301,8 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == world_size - 1) should not start warmup
Connecting rank (device_rank == 0) should not start warmup
until all other ranks are already WarmingUp.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -295,13 +315,13 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
@@ -309,7 +329,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
}
result = plan_mod.plan(
node_id=NODE_B,
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},