Compare commits

..

1 Commits

Author SHA1 Message Date
Sami Khan
40cbecb5c4 optimize dashboard 2025-12-29 22:00:59 +05:00
25 changed files with 338 additions and 370 deletions

View File

@@ -20,8 +20,6 @@ struct ContentView: View {
@State private var showDebugInfo = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@State private var showAdvancedOptions = false
@State private var pendingNamespace: String = ""
var body: some View {
VStack(alignment: .leading, spacing: 12) {
@@ -199,8 +197,6 @@ struct ContentView: View {
updater.checkForUpdates()
}
.padding(.bottom, 8)
advancedOptionsSection
.padding(.bottom, 8)
debugSection
.padding(.bottom, 8)
controlButton(title: "Quit", tint: .secondary) {
@@ -331,47 +327,6 @@ struct ContentView: View {
}
}
private var advancedOptionsSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced Options")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvancedOptions)
}
.animation(nil, value: showAdvancedOptions)
if showAdvancedOptions {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
}
private var debugSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {

View File

@@ -2,8 +2,6 @@ import AppKit
import Combine
import Foundation
private let customNamespaceKey = "EXOCustomNamespace"
@MainActor
final class ExoProcessController: ObservableObject {
enum Status: Equatable {
@@ -29,13 +27,6 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var status: Status = .stopped
@Published private(set) var lastError: String?
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}() {
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
}
private var process: Process?
private var runtimeDirectoryURL: URL?
@@ -189,7 +180,7 @@ final class ExoProcessController: ObservableObject {
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
var environment = ProcessInfo.processInfo.environment
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
environment["EXO_LIBP2P_NAMESPACE"] = buildTag()
var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {
@@ -226,12 +217,6 @@ final class ExoProcessController: ObservableObject {
}
return "dev"
}
private func computeNamespace() -> String {
let base = buildTag()
let custom = customNamespace.trimmingCharacters(in: .whitespaces)
return custom.isEmpty ? base : custom
}
}
struct RuntimeError: LocalizedError {

View File

@@ -198,8 +198,10 @@
stroke: oklch(0.85 0.18 85 / 0.4);
stroke-width: 1.5px;
stroke-dasharray: 8, 8;
animation: flowAnimation 1s linear infinite;
animation: flowAnimation 1.5s linear infinite;
filter: drop-shadow(0 0 3px oklch(0.85 0.18 85 / 0.5));
/* GPU optimization - hint to browser this element will animate */
will-change: stroke-dashoffset;
}
.graph-link-active {
@@ -208,6 +210,24 @@
filter: drop-shadow(0 0 6px oklch(0.85 0.18 85 / 0.8));
}
/* Reduce motion for users who prefer it - also saves GPU */
@media (prefers-reduced-motion: reduce) {
.graph-link {
animation: none;
}
.shooting-star {
animation: none;
display: none;
}
.status-pulse,
.cursor-blink,
.animate-pulse {
animation: none;
}
}
/* CRT Screen effect for topology */
.crt-screen {
position: relative;
@@ -266,13 +286,15 @@ input:focus, textarea:focus {
box-shadow: none;
}
/* Shooting Stars Animation */
/* Shooting Stars Animation - GPU optimized */
.shooting-stars {
position: fixed;
inset: 0;
overflow: hidden;
pointer-events: none;
z-index: 0;
/* Only render when visible */
content-visibility: auto;
}
.shooting-star {
@@ -285,6 +307,9 @@ input:focus, textarea:focus {
animation: shootingStar var(--duration, 3s) linear infinite;
animation-delay: var(--delay, 0s);
opacity: 0;
/* GPU optimization */
will-change: transform, opacity;
transform: translateZ(0);
}
.shooting-star::before {
@@ -320,3 +345,13 @@ input:focus, textarea:focus {
transform: translate(400px, 400px);
}
}
/* Pause animations when page is hidden to save resources */
:root:has(body[data-page-hidden="true"]) {
.shooting-star,
.graph-link,
.status-pulse,
.cursor-blink {
animation-play-state: paused;
}
}

View File

@@ -139,11 +139,6 @@
}
function handleKeydown(event: KeyboardEvent) {
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
if (event.isComposing || event.keyCode === 229) {
return;
}
if (event.key === 'Enter' && !event.shiftKey) {
event.preventDefault();
handleSubmit();

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import { onMount, onDestroy, tick } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';
@@ -12,11 +12,35 @@ import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.sv
let svgContainer: SVGSVGElement | undefined = $state();
let resizeObserver: ResizeObserver | undefined;
// Optimization: Track last render state to avoid unnecessary re-renders
let lastRenderHash = '';
let lastHighlightedNodesHash = '';
let lastDimensions = { width: 0, height: 0 };
let isRendering = false;
let pendingRender = false;
const isMinimized = $derived(isTopologyMinimized());
const data = $derived(topologyData());
const debugEnabled = $derived(debugMode());
// Generate a hash of relevant data to detect actual changes
function generateDataHash(topologyData: typeof data, minimized: boolean, debug: boolean): string {
if (!topologyData) return 'null';
const nodes = topologyData.nodes || {};
const edges = topologyData.edges || [];
// Create a lightweight hash from key properties only
const nodeHashes = Object.entries(nodes).map(([id, n]) => {
const macmon = n.macmon_info;
return `${id}:${n.friendly_name || ''}:${macmon?.memory?.ram_usage || 0}:${macmon?.memory?.ram_total || 0}:${macmon?.temp?.gpu_temp_avg || 0}:${macmon?.gpu_usage?.[1] || 0}:${macmon?.sys_power || 0}`;
}).sort().join('|');
const edgeHash = edges.map(e => `${e.source}-${e.target}`).sort().join(',');
return `${nodeHashes}::${edgeHash}::${minimized}::${debug}`;
}
function getNodeLabel(nodeId: string): string {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
@@ -932,16 +956,59 @@ function wrapLine(text: string, maxLen: number): string[] {
}
$effect(() => {
if (data) {
// Throttled render function to prevent too-frequent updates
function scheduleRender() {
if (isRendering) {
pendingRender = true;
return;
}
isRendering = true;
requestAnimationFrame(() => {
renderGraph();
isRendering = false;
if (pendingRender) {
pendingRender = false;
scheduleRender();
}
});
}
$effect(() => {
if (!data || !svgContainer) return;
// Generate hash of current state
const currentHash = generateDataHash(data, isMinimized, debugEnabled);
const highlightHash = Array.from(highlightedNodes).sort().join(',');
// Get current dimensions
const rect = svgContainer.getBoundingClientRect();
const dimensionsChanged = rect.width !== lastDimensions.width || rect.height !== lastDimensions.height;
// Only re-render if something actually changed
if (currentHash !== lastRenderHash || highlightHash !== lastHighlightedNodesHash || dimensionsChanged) {
lastRenderHash = currentHash;
lastHighlightedNodesHash = highlightHash;
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
}
});
onMount(() => {
if (svgContainer) {
// Use a debounced resize observer to prevent rapid re-renders
let resizeTimeout: ReturnType<typeof setTimeout> | null = null;
resizeObserver = new ResizeObserver(() => {
renderGraph();
if (resizeTimeout) clearTimeout(resizeTimeout);
resizeTimeout = setTimeout(() => {
const rect = svgContainer!.getBoundingClientRect();
if (rect.width !== lastDimensions.width || rect.height !== lastDimensions.height) {
lastDimensions = { width: rect.width, height: rect.height };
scheduleRender();
}
}, 100);
});
resizeObserver.observe(svgContainer);
}
@@ -969,11 +1036,20 @@ function wrapLine(text: string, maxLen: number): string[] {
stroke-width: 1px;
stroke-dasharray: 4, 4;
opacity: 0.8;
animation: flowAnimation 0.75s linear infinite;
/* Slower animation = less GPU usage */
animation: flowAnimation 2s linear infinite;
/* GPU optimization */
will-change: stroke-dashoffset;
}
@keyframes flowAnimation {
from { stroke-dashoffset: 0; }
to { stroke-dashoffset: -10; }
}
/* Respect reduced motion preference */
@media (prefers-reduced-motion: reduce) {
:global(.graph-link) {
animation: none;
}
}
</style>

View File

@@ -297,6 +297,35 @@ function extractIpFromMultiaddr(ma?: string): string | undefined {
return undefined;
}
// Deep comparison utility for preventing unnecessary state updates
function shallowEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
if (a === null || b === null) return false;
if (typeof a !== 'object' || typeof b !== 'object') return false;
const aObj = a as Record<string, unknown>;
const bObj = b as Record<string, unknown>;
const aKeys = Object.keys(aObj);
const bKeys = Object.keys(bObj);
if (aKeys.length !== bKeys.length) return false;
for (const key of aKeys) {
if (aObj[key] !== bObj[key]) return false;
}
return true;
}
// Faster JSON comparison for complex nested objects
function jsonEqual(a: unknown, b: unknown): boolean {
if (a === b) return true;
try {
return JSON.stringify(a) === JSON.stringify(b);
} catch {
return false;
}
}
class AppStore {
// Conversation state
conversations = $state<Conversation[]>([]);
@@ -330,9 +359,18 @@ class AppStore {
topologyOnlyMode = $state(false);
chatSidebarVisible = $state(true); // Shown by default
// Visibility state - used to pause polling when tab is hidden
private isPageVisible = true;
private fetchInterval: ReturnType<typeof setInterval> | null = null;
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
// Cache for comparison - prevents unnecessary reactivity
private lastTopologyJson = '';
private lastInstancesJson = '';
private lastRunnersJson = '';
private lastDownloadsJson = '';
constructor() {
if (browser) {
@@ -341,9 +379,26 @@ class AppStore {
this.loadDebugModeFromStorage();
this.loadTopologyOnlyModeFromStorage();
this.loadChatSidebarVisibleFromStorage();
this.setupVisibilityListener();
}
}
/**
* Listen for page visibility changes to pause polling when hidden
*/
private setupVisibilityListener() {
if (typeof document === 'undefined') return;
document.addEventListener('visibilitychange', () => {
this.isPageVisible = document.visibilityState === 'visible';
if (this.isPageVisible) {
// Resume polling when page becomes visible
this.fetchState();
}
});
}
/**
* Load conversations from localStorage
*/
@@ -770,7 +825,9 @@ class AppStore {
startPolling() {
this.fetchState();
this.fetchInterval = setInterval(() => this.fetchState(), 1000);
// Poll every 2 seconds instead of 1 second - reduces CPU/GPU load by 50%
// Data comparison ensures we only update when something actually changes
this.fetchInterval = setInterval(() => this.fetchState(), 2000);
}
stopPolling() {
@@ -782,6 +839,9 @@ class AppStore {
}
async fetchState() {
// Skip polling when page is hidden to save resources
if (!this.isPageVisible) return;
try {
const response = await fetch('/state');
if (!response.ok) {
@@ -789,19 +849,44 @@ class AppStore {
}
const data: RawStateResponse = await response.json();
// Only update topology if it actually changed (prevents unnecessary D3 re-renders)
if (data.topology) {
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
const newTopology = transformTopology(data.topology, data.nodeProfiles);
const newTopologyJson = JSON.stringify(newTopology);
if (newTopologyJson !== this.lastTopologyJson) {
this.lastTopologyJson = newTopologyJson;
this.topologyData = newTopology;
}
}
// Only update instances if changed
if (data.instances) {
this.instances = data.instances;
this.refreshConversationModelFromInstances();
const newInstancesJson = JSON.stringify(data.instances);
if (newInstancesJson !== this.lastInstancesJson) {
this.lastInstancesJson = newInstancesJson;
this.instances = data.instances;
this.refreshConversationModelFromInstances();
}
}
// Only update runners if changed
if (data.runners) {
this.runners = data.runners;
const newRunnersJson = JSON.stringify(data.runners);
if (newRunnersJson !== this.lastRunnersJson) {
this.lastRunnersJson = newRunnersJson;
this.runners = data.runners;
}
}
// Only update downloads if changed
if (data.downloads) {
this.downloads = data.downloads;
const newDownloadsJson = JSON.stringify(data.downloads);
if (newDownloadsJson !== this.lastDownloadsJson) {
this.lastDownloadsJson = newDownloadsJson;
this.downloads = data.downloads;
}
}
this.lastUpdate = Date.now();
} catch (error) {
console.error('Error fetching state:', error);

View File

@@ -1,7 +1,25 @@
<script lang="ts">
import '../app.css';
import { onMount } from 'svelte';
import { browser } from '$app/environment';
let { children } = $props();
let isPageHidden = $state(false);
onMount(() => {
if (!browser) return;
// Listen for visibility changes to pause animations when hidden
const handleVisibilityChange = () => {
isPageHidden = document.visibilityState === 'hidden';
};
document.addEventListener('visibilitychange', handleVisibilityChange);
return () => {
document.removeEventListener('visibilitychange', handleVisibilityChange);
};
});
</script>
<svelte:head>
@@ -9,7 +27,7 @@
<meta name="description" content="EXO - Distributed AI Cluster Dashboard" />
</svelte:head>
<div class="min-h-screen bg-background text-foreground">
<div class="min-h-screen bg-background text-foreground" data-page-hidden={isPageHidden}>
{@render children?.()}
</div>

View File

@@ -51,59 +51,6 @@ const sidebarVisible = $derived(chatSidebarVisible());
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
// Launch defaults persistence
const LAUNCH_DEFAULTS_KEY = 'exo-launch-defaults';
interface LaunchDefaults {
modelId: string | null;
sharding: 'Pipeline' | 'Tensor';
instanceType: InstanceMeta;
minNodes: number;
}
function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
};
try {
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
} catch (e) {
console.warn('Failed to save launch defaults:', e);
}
}
function loadLaunchDefaults(): LaunchDefaults | null {
try {
const stored = localStorage.getItem(LAUNCH_DEFAULTS_KEY);
if (!stored) return null;
return JSON.parse(stored) as LaunchDefaults;
} catch (e) {
console.warn('Failed to load launch defaults:', e);
return null;
}
}
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
}
}
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let minNodesInitialized = $state(false);
@@ -152,17 +99,35 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Compute highlighted nodes from hovered instance or hovered preview
// Memoized to avoid creating new Sets on every render
let lastHighlightedNodesKey = '';
let cachedHighlightedNodes: Set<string> = new Set();
const highlightedNodes = $derived(() => {
// Create a key for the current state to enable memoization
const previewKey = Array.from(hoveredPreviewNodes).sort().join(',');
const currentKey = `${hoveredInstanceId || 'null'}:${previewKey}`;
// Return cached value if nothing changed
if (currentKey === lastHighlightedNodesKey) {
return cachedHighlightedNodes;
}
lastHighlightedNodesKey = currentKey;
// First check instance hover
if (hoveredInstanceId) {
const instanceWrapped = instanceData[hoveredInstanceId];
return unwrapInstanceNodes(instanceWrapped);
cachedHighlightedNodes = unwrapInstanceNodes(instanceWrapped);
return cachedHighlightedNodes;
}
// Then check preview hover
if (hoveredPreviewNodes.size > 0) {
return hoveredPreviewNodes;
cachedHighlightedNodes = hoveredPreviewNodes;
return cachedHighlightedNodes;
}
return new Set<string>();
cachedHighlightedNodes = new Set<string>();
return cachedHighlightedNodes;
});
// Helper to estimate memory from model ID (mirrors ModelCard logic)
@@ -351,9 +316,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const data = await response.json();
// API returns { data: [{ id, name }] } format
models = data.data || [];
// Restore last launch defaults if available
const currentNodeCount = topologyData() ? Object.keys(topologyData()!.nodes).length : 1;
applyLaunchDefaults(models, currentNodeCount);
}
} catch (error) {
console.error('Failed to fetch models:', error);
@@ -572,12 +534,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
};
}
// Debug: Log downloads data when it changes
$effect(() => {
if (downloadsData && Object.keys(downloadsData).length > 0) {
console.log('[Download Debug] Current downloads:', downloadsData);
}
});
// Debug: Log downloads data when it changes (disabled in production for performance)
// Uncomment for debugging:
// $effect(() => {
// if (downloadsData && Object.keys(downloadsData).length > 0) {
// console.log('[Download Debug] Current downloads:', downloadsData);
// }
// });
// Helper to get download status for an instance
function getInstanceDownloadStatus(instanceId: string, instanceWrapped: unknown): {
@@ -1044,7 +1007,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function handleSliderMouseUp() {
isDraggingSlider = false;
saveLaunchDefaults();
}
// Handle touch events for mobile
@@ -1064,7 +1026,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function handleSliderTouchEnd() {
isDraggingSlider = false;
saveLaunchDefaults();
}
const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);
@@ -1522,7 +1483,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = '';
}
@@ -1556,7 +1516,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Sharding:</div>
<div class="flex gap-2">
<button
onclick={() => { selectedSharding = 'Pipeline'; saveLaunchDefaults(); }}
onclick={() => selectedSharding = 'Pipeline'}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Pipeline' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Pipeline' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1567,7 +1527,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
Pipeline
</button>
<button
onclick={() => { selectedSharding = 'Tensor'; saveLaunchDefaults(); }}
onclick={() => selectedSharding = 'Tensor'}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Tensor' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Tensor' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1585,7 +1545,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Instance Type:</div>
<div class="flex gap-2">
<button
onclick={() => { selectedInstanceType = 'MlxRing'; saveLaunchDefaults(); }}
onclick={() => selectedInstanceType = 'MlxRing'}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxRing' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxRing' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1596,7 +1556,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
MLX Ring
</button>
<button
onclick={() => { selectedInstanceType = 'MlxIbv'; saveLaunchDefaults(); }}
onclick={() => selectedInstanceType = 'MlxIbv'}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxIbv' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxIbv' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">

View File

@@ -1,6 +1,5 @@
import argparse
import multiprocessing as mp
import os
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -28,7 +27,7 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
worker: Worker | None
worker: Worker
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
master: Master | None
@@ -62,19 +61,15 @@ class Node:
else:
api = None
if not args.no_worker:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
else:
worker = None
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
# We start every node with a master
master = Master(
node_id,
@@ -104,9 +99,8 @@ class Node:
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.worker.run)
tg.start_soon(self.election.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
tg.start_soon(self.master.run)
if self.api:
@@ -200,7 +194,6 @@ def main():
# TODO: Refactor the current verbosity system
logger_setup(EXO_LOG, args.verbosity)
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
node = anyio.run(Node.create, args)
anyio.run(node.run)
@@ -214,7 +207,6 @@ class Args(CamelCaseModel):
spawn_api: bool = False
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
@classmethod
def parse(cls) -> Self:
@@ -252,10 +244,6 @@ class Args(CamelCaseModel):
dest="api_port",
default=52415,
)
parser.add_argument(
"--no-worker",
action="store_true",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -21,7 +21,6 @@ from exo.shared.types.commands import (
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -30,7 +29,6 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
@@ -67,28 +65,6 @@ def place_instance(
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [

View File

@@ -385,14 +385,13 @@ def get_mlx_jaccl_coordinators(
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
return ip
logger.warning(

View File

@@ -50,7 +50,7 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
hidden_size=10,
supports_tensor=True,
)

View File

@@ -53,10 +53,6 @@ class RunnerRunning(BaseRunnerStatus):
pass
class RunnerShuttingDown(BaseRunnerStatus):
pass
class RunnerShutdown(BaseRunnerStatus):
pass
@@ -74,7 +70,6 @@ RunnerStatus = (
| RunnerWarmingUp
| RunnerReady
| RunnerRunning
| RunnerShuttingDown
| RunnerShutdown
| RunnerFailed
)

View File

@@ -450,11 +450,6 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# TODO: 'Smart' downloads are disabled because:
# (i) We don't handle all kinds of files;
# (ii) We don't have sticky sessions.
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)

View File

@@ -9,7 +9,7 @@ MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = None
KV_CACHE_BITS: int | None = 8
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

View File

@@ -395,5 +395,11 @@ def set_wired_limit_for_model(model_size: Memory):
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
kv_bytes = int(0.02 * model_bytes)
target_cache = int(1.10 * (model_bytes + kv_bytes))
target_cache = min(target_cache, max_rec_size)
mx.set_cache_limit(target_cache)
mx.set_wired_limit(max_rec_size)
logger.info(f"Wired limit set to {max_rec_size}.")
logger.info(
f"Wired limit set to {max_rec_size}. Cache limit set to {target_cache}."
)

View File

@@ -23,7 +23,6 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
@@ -84,7 +83,7 @@ class Worker:
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup | None = None
@@ -129,7 +128,6 @@ class Worker:
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -202,11 +200,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_meta.model_id not in self.download_status:
if shard not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -219,7 +217,7 @@ class Worker:
progress = DownloadCompleted(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -351,7 +349,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.download_status[task.shard_metadata] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -365,7 +363,7 @@ class Worker:
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard] = status
# Footgun!
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
@@ -386,7 +384,7 @@ class Worker:
progress
),
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard] = status
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
@@ -446,40 +444,3 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id, shard_metadata=progress.shard
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -3,7 +3,6 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -35,6 +34,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -43,7 +43,7 @@ def plan(
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ModelId, DownloadProgress],
download_status: Mapping[ShardMetadata, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
@@ -111,14 +111,13 @@ def _create_runner(
def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
download_status: Mapping[ShardMetadata, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
not isinstance(
download_status.get(runner.bound_instance.bound_shard, None),
(DownloadOngoing, DownloadCompleted),
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
@@ -236,8 +235,9 @@ def _ready_to_warmup(
assert device_rank < world_size
assert device_rank >= 0
# Rank != 0
accepting_ranks_ready = device_rank > 0 and all(
# TODO: Ensure these align with MLX distributeds expectations.
# Rank < n-1
accepting_ranks_ready = device_rank < world_size - 1 and all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerLoaded, RunnerWarmingUp),
@@ -245,8 +245,8 @@ def _ready_to_warmup(
for global_runner_id in shard_assignments.runner_to_shard
)
# Rank = 0
connecting_rank_ready = device_rank == 0 and all(
# Rank = n-1
connecting_rank_ready = device_rank == world_size - 1 and all(
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
for global_runner_id in shard_assignments.runner_to_shard
if global_runner_id != runner_id
@@ -274,12 +274,6 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(

View File

@@ -32,7 +32,6 @@ from exo.shared.types.worker.runners import (
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
)
@@ -188,14 +187,13 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
current_status = RunnerShutdown()
break
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
@@ -210,8 +208,9 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
if isinstance(current_status, RunnerShutdown):
break
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
)
except ClosedResourceError:
logger.warning("runner communication closed unexpectedly")
except Exception as e:

View File

@@ -14,23 +14,13 @@ from anyio import (
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnecting,
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerRunning,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
@@ -49,10 +39,10 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
# err_path: str
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -87,6 +77,7 @@ class RunnerSupervisor:
_ev_recv=ev_recv,
_task_sender=task_sender,
_event_sender=event_sender,
# err_path=err_path,
)
return self
@@ -127,10 +118,6 @@ class RunnerSupervisor:
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.completed:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -151,22 +138,6 @@ class RunnerSupervisor:
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
continue
if (
isinstance(event, TaskStatusUpdated)
and event.task_status == TaskStatus.Complete
):
# If a task has just been completed, we should be working on it.
assert isinstance(
self.status,
(
RunnerRunning,
RunnerWarmingUp,
RunnerLoading,
RunnerConnecting,
RunnerShuttingDown,
),
)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)

View File

@@ -9,11 +9,9 @@ MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
NODE_C: Final[NodeId] = NodeId("cccccccc-cccc-4ccc-8ccc-cccccccccccc")
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
RUNNER_3_ID: Final[RunnerId] = RunnerId("Runner3")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")

View File

@@ -1,9 +1,11 @@
from dataclasses import dataclass, field
from __future__ import annotations
from dataclasses import dataclass
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.tasks import BaseTask
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
@@ -19,7 +21,6 @@ from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
class OtherTask(BaseTask):

View File

@@ -1,6 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
@@ -8,6 +7,7 @@ from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerIdle,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
@@ -46,7 +46,7 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ModelId, DownloadProgress] = {}
download_status: dict[ShardMetadata, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
@@ -94,7 +94,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
}
# Global view has completed downloads for both nodes
@@ -140,7 +140,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,7 +192,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],

View File

@@ -12,10 +12,8 @@ from exo.worker.tests.constants import (
MODEL_A_ID,
NODE_A,
NODE_B,
NODE_C,
RUNNER_1_ID,
RUNNER_2_ID,
RUNNER_3_ID,
)
from exo.worker.tests.unittests.conftest import (
FakeRunnerSupervisor,
@@ -26,39 +24,37 @@ from exo.worker.tests.unittests.conftest import (
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
"""
For non-zero device_rank shards, StartWarmup should be emitted when all
For non-final device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
RUNNER_3_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_B,
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
tasks={},
@@ -154,9 +150,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
For accepting ranks (device_rank != 0), StartWarmup should be
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 1 is the accepting rank.
In a 2-node setup, rank 0 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -167,7 +163,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 1 is the accepting rank
# Rank 0 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -192,23 +188,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
tasks={},
)
assert result is None
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
@@ -301,8 +280,9 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == 0) should not start warmup
Connecting rank (device_rank == world_size - 1) should not start warmup
until all other ranks are already WarmingUp.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -315,13 +295,13 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
@@ -329,7 +309,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
}
result = plan_mod.plan(
node_id=NODE_A,
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},

View File

@@ -34,7 +34,6 @@ from exo.shared.types.worker.runners import (
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
@@ -200,9 +199,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),