Compare commits

...

11 Commits

Author SHA1 Message Date
madanlalit
4f6fcd9e93 feat(macos-app): add custom namespace UI for cluster isolation
Add Advanced Options section with custom namespace field that allows
users to override EXO_LIBP2P_NAMESPACE environment variable. This
enables splitting machines that can see each other into separate
clusters.

- Added customNamespace property with UserDefaults persistence
- Added Advanced Options collapsible section with text field
- Added Save & Restart button that auto-restarts exo process
- Namespace replaces buildTag when custom value is set
- Falls back to buildTag (version) when namespace is empty
2026-01-05 15:25:00 +01:00
Evan Quiney
839b67f318 [feat] Add an option to disable the worker (#1091)
## Motivation

Workerless machines can be used for networking without running any gpu
jobs - add a cli flag that adds this basic functionality.

## Changes

Adds the --no-worker cli flag

## Test Plan

### Manual Testing

Exo starts as expected

### Automated Testing

None
2026-01-05 12:05:03 +00:00
Drifter4242
47b8e0ce12 feat: remember last launch settings (model, sharding, instance type) (#1028)
## Motivation

Saves the last launch settings, so that the next time you run exo it
will default to the same launch settings.
This is just a small quality of life improvement.

## Changes

When you launch it saves the settings to the web browser local storage.
When it fills out the model list, it reads the settings and sets the
default.

I reviewed, tested and edited the code, but some of the code was written
by Claude Opus. I hope that's ok.

## Why It Works

See above

## Test Plan

### Manual Testing

I have two Macbook Studio M3 Ultras, each with 512Gb ram, connected with
Thunderbolt 5. I ran Kimi K2 Thinking with MLX Ring and Tensor Split.
I ran exo multiple times to confirm that the default works.

### Automated Testing

No changes to automated testing.
2026-01-05 11:27:14 +00:00
Evan Quiney
17f9b583a4 Task Deduplication (#1062) 2026-01-03 20:01:49 +00:00
RickyChen / 陳昭儒
844bcc7ce6 fix: prevent form submission during IME composition (#1069)
## Problem
When typing in Chinese (or other IME-based languages like
Japanese/Korean), pressing Enter to select a character from the IME
candidate list would incorrectly submit the message instead of
confirming the character selection.

## Solution
Added IME composition state detection in the `handleKeydown` function in
`ChatForm.svelte`:
- Check `event.isComposing` to detect active IME composition
- Fallback to `event.keyCode === 229` for broader browser compatibility
- Return early when IME is active, allowing normal character selection

## Changes
- Modified `dashboard/src/lib/components/ChatForm.svelte` 
- Added IME composition check before Enter key handling

Co-authored-by: Ricky Chen <rickychen@Rickys-MacBook-Pro.local>
2025-12-31 17:11:04 +00:00
Evan Quiney
c1be5184b2 Fix tests broken by 283c (#1063)
Some tests were broken by #1058 and #1046 - this fixes them.
2025-12-31 01:53:55 +00:00
Alex Cheema
1ec550dff1 Emit download progress on start, and change downloads to be keyed by model_id (#1044)
## Motivation

We added a download page to the dashboard which shows the currently
download status of each model on each node. Users have reported this to
be extremely useful.

However, we don't currently fetch the download progress on start, so it
doesn't show any model's download status.

## Changes

Fetch and emit model download status on start of worker, and
periodically every 5 mins.
Also to support this, I changed download_status to be keyed by model_id
instead of shard, since we want download_status of each model, not each
shard.

## Why It Works

The dashboard already implements the correct functionality, we just
weren't populating the download status in the state. Now it gets
populated and shows correctly.

## Test Plan

### Manual Testing
On a cluster of 2 x 512GB M3 Ultra Mac Studio, I launched an instance
onto one node that hadn't been downloaded. I checked the download page
and it showed the in progress download. I downloaded it to completion,
restarted exo on both nodes, and then opened the download page and it
showed the model as 100% downloaded and other models as 0% that hadn't
been downloaded.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2025-12-31 01:18:10 +00:00
Alex Cheema
283c0e39e4 Placement filters for tensor parallel supports_tensor, tensor dimension and pipeline parallel deepseek v3.1 (#1058)
## Motivation

Certain placements are not valid. Added filters to exclude these placements. There were invalid placement previews being shown in the dashboard which would then fail when the user actually tries to launch an instance with that placement.


## Changes

Three filters added:

1. Certain models do not support tensor parallel at all. Checks `supports_tensor` on the model_meta.
2. For models that do support tensor parallelism, certain tensor parallel sizes are not valid. This check is actually not correct right now but it works fine for now. The actual correct check is more involved.
3. For unknown reasons, deepseek v3.1 (8-bit) does not work with tensor parallelism.

## Why It Works

`place_instance` now raises an `Exception` for invalid placements.

## Test Plan

### Manual Testing
Since `/instance/previews` enumerates all possible placements and runs `place_instance`, I checked the dashboard to see if invalid placements are still shown.
2025-12-31 00:33:40 +00:00
Alex Cheema
35be4c55c3 prioritise mlx jaccl coordinator ip (en0 -> en1 -> non-TB5 -> other) 2025-12-31 00:10:19 +00:00
Alex Cheema
31d4cd8409 set KV_CACHE_BITS to None to disable quantized kv cache 2025-12-31 00:03:30 +00:00
Alex Cheema
8a6da58404 remove mx.set_cache_limit 2025-12-30 23:58:15 +00:00
21 changed files with 340 additions and 76 deletions

View File

@@ -20,6 +20,8 @@ struct ContentView: View {
@State private var showDebugInfo = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@State private var showAdvancedOptions = false
@State private var pendingNamespace: String = ""
var body: some View {
VStack(alignment: .leading, spacing: 12) {
@@ -197,6 +199,8 @@ struct ContentView: View {
updater.checkForUpdates()
}
.padding(.bottom, 8)
advancedOptionsSection
.padding(.bottom, 8)
debugSection
.padding(.bottom, 8)
controlButton(title: "Quit", tint: .secondary) {
@@ -327,6 +331,47 @@ struct ContentView: View {
}
}
private var advancedOptionsSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {
Text("Advanced Options")
.font(.caption)
.foregroundColor(.secondary)
Spacer()
collapseButton(isExpanded: $showAdvancedOptions)
}
.animation(nil, value: showAdvancedOptions)
if showAdvancedOptions {
VStack(alignment: .leading, spacing: 8) {
VStack(alignment: .leading, spacing: 4) {
Text("Cluster Namespace")
.font(.caption2)
.foregroundColor(.secondary)
HStack {
TextField("optional", text: $pendingNamespace)
.textFieldStyle(.roundedBorder)
.font(.caption2)
.onAppear {
pendingNamespace = controller.customNamespace
}
Button("Save & Restart") {
controller.customNamespace = pendingNamespace
if controller.status == .running || controller.status == .starting {
controller.restart()
}
}
.font(.caption2)
.disabled(pendingNamespace == controller.customNamespace)
}
}
}
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
}
private var debugSection: some View {
VStack(alignment: .leading, spacing: 6) {
HStack {

View File

@@ -2,6 +2,8 @@ import AppKit
import Combine
import Foundation
private let customNamespaceKey = "EXOCustomNamespace"
@MainActor
final class ExoProcessController: ObservableObject {
enum Status: Equatable {
@@ -27,6 +29,13 @@ final class ExoProcessController: ObservableObject {
@Published private(set) var status: Status = .stopped
@Published private(set) var lastError: String?
@Published private(set) var launchCountdownSeconds: Int?
@Published var customNamespace: String = {
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
}() {
didSet {
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
}
}
private var process: Process?
private var runtimeDirectoryURL: URL?
@@ -180,7 +189,7 @@ final class ExoProcessController: ObservableObject {
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
var environment = ProcessInfo.processInfo.environment
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
environment["EXO_LIBP2P_NAMESPACE"] = buildTag()
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {
@@ -217,6 +226,12 @@ final class ExoProcessController: ObservableObject {
}
return "dev"
}
private func computeNamespace() -> String {
let base = buildTag()
let custom = customNamespace.trimmingCharacters(in: .whitespaces)
return custom.isEmpty ? base : custom
}
}
struct RuntimeError: LocalizedError {

View File

@@ -139,6 +139,11 @@
}
function handleKeydown(event: KeyboardEvent) {
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
if (event.isComposing || event.keyCode === 229) {
return;
}
if (event.key === 'Enter' && !event.shiftKey) {
event.preventDefault();
handleSubmit();

View File

@@ -51,6 +51,59 @@ const sidebarVisible = $derived(chatSidebarVisible());
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
// Launch defaults persistence
const LAUNCH_DEFAULTS_KEY = 'exo-launch-defaults';
interface LaunchDefaults {
modelId: string | null;
sharding: 'Pipeline' | 'Tensor';
instanceType: InstanceMeta;
minNodes: number;
}
function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
};
try {
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
} catch (e) {
console.warn('Failed to save launch defaults:', e);
}
}
function loadLaunchDefaults(): LaunchDefaults | null {
try {
const stored = localStorage.getItem(LAUNCH_DEFAULTS_KEY);
if (!stored) return null;
return JSON.parse(stored) as LaunchDefaults;
} catch (e) {
console.warn('Failed to load launch defaults:', e);
return null;
}
}
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
}
}
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let minNodesInitialized = $state(false);
@@ -298,6 +351,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const data = await response.json();
// API returns { data: [{ id, name }] } format
models = data.data || [];
// Restore last launch defaults if available
const currentNodeCount = topologyData() ? Object.keys(topologyData()!.nodes).length : 1;
applyLaunchDefaults(models, currentNodeCount);
}
} catch (error) {
console.error('Failed to fetch models:', error);
@@ -988,6 +1044,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function handleSliderMouseUp() {
isDraggingSlider = false;
saveLaunchDefaults();
}
// Handle touch events for mobile
@@ -1007,6 +1064,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
function handleSliderTouchEnd() {
isDraggingSlider = false;
saveLaunchDefaults();
}
const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);
@@ -1464,6 +1522,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = '';
}
@@ -1497,7 +1556,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Sharding:</div>
<div class="flex gap-2">
<button
onclick={() => selectedSharding = 'Pipeline'}
onclick={() => { selectedSharding = 'Pipeline'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Pipeline' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Pipeline' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1508,7 +1567,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
Pipeline
</button>
<button
onclick={() => selectedSharding = 'Tensor'}
onclick={() => { selectedSharding = 'Tensor'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Tensor' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Tensor' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1526,7 +1585,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="text-xs text-white/70 font-mono mb-2">Instance Type:</div>
<div class="flex gap-2">
<button
onclick={() => selectedInstanceType = 'MlxRing'}
onclick={() => { selectedInstanceType = 'MlxRing'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxRing' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxRing' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
@@ -1537,7 +1596,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
MLX Ring
</button>
<button
onclick={() => selectedInstanceType = 'MlxIbv'}
onclick={() => { selectedInstanceType = 'MlxIbv'; saveLaunchDefaults(); }}
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxIbv' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxIbv' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">

View File

@@ -28,7 +28,7 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
worker: Worker
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
master: Master | None
@@ -62,15 +62,19 @@ class Node:
else:
api = None
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
if not args.no_worker:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
)
else:
worker = None
# We start every node with a master
master = Master(
node_id,
@@ -100,8 +104,9 @@ class Node:
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.worker.run)
tg.start_soon(self.election.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
tg.start_soon(self.master.run)
if self.api:
@@ -209,6 +214,7 @@ class Args(CamelCaseModel):
spawn_api: bool = False
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
@classmethod
def parse(cls) -> Self:
@@ -246,6 +252,10 @@ class Args(CamelCaseModel):
dest="api_port",
default=52415,
)
parser.add_argument(
"--no-worker",
action="store_true",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -29,6 +30,7 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
@@ -65,6 +67,28 @@ def place_instance(
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
)
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
smallest_tb_cycles = [

View File

@@ -385,13 +385,14 @@ def get_mlx_jaccl_coordinators(
address in format "X.X.X.X:PORT" per node.
"""
rank_0_node = selected_cycle[0]
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
def get_ip_for_node(n: NodeInfo) -> str:
if n.node_id == rank_0_node.node_id:
return "0.0.0.0"
for ip, _ in _find_connection_ip(n, rank_0_node, cycle_digraph):
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
if ip:
return ip
logger.warning(

View File

@@ -50,7 +50,7 @@ def model_meta() -> ModelMetadata:
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=10,
hidden_size=30,
supports_tensor=True,
)

View File

@@ -53,6 +53,10 @@ class RunnerRunning(BaseRunnerStatus):
pass
class RunnerShuttingDown(BaseRunnerStatus):
pass
class RunnerShutdown(BaseRunnerStatus):
pass
@@ -70,6 +74,7 @@ RunnerStatus = (
| RunnerWarmingUp
| RunnerReady
| RunnerRunning
| RunnerShuttingDown
| RunnerShutdown
| RunnerFailed
)

View File

@@ -450,6 +450,11 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# TODO: 'Smart' downloads are disabled because:
# (i) We don't handle all kinds of files;
# (ii) We don't have sticky sessions.
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)

View File

@@ -9,7 +9,7 @@ MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"
CACHE_GROUP_SIZE: int = 64
KV_CACHE_BITS: int | None = 8
KV_CACHE_BITS: int | None = None
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

View File

@@ -395,11 +395,5 @@ def set_wired_limit_for_model(model_size: Memory):
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
)
kv_bytes = int(0.02 * model_bytes)
target_cache = int(1.10 * (model_bytes + kv_bytes))
target_cache = min(target_cache, max_rec_size)
mx.set_cache_limit(target_cache)
mx.set_wired_limit(max_rec_size)
logger.info(
f"Wired limit set to {max_rec_size}. Cache limit set to {target_cache}."
)
logger.info(f"Wired limit set to {max_rec_size}.")

View File

@@ -23,6 +23,7 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
@@ -83,7 +84,7 @@ class Worker:
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup | None = None
@@ -128,6 +129,7 @@ class Worker:
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -200,11 +202,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard not in self.download_status:
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -217,7 +219,7 @@ class Worker:
progress = DownloadCompleted(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -349,7 +351,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata] = status
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -363,7 +365,7 @@ class Worker:
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
self.download_status[shard] = status
self.download_status[shard.model_meta.model_id] = status
# Footgun!
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
@@ -384,7 +386,7 @@ class Worker:
progress
),
)
self.download_status[shard] = status
self.download_status[shard.model_meta.model_id] = status
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
@@ -444,3 +446,40 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id, shard_metadata=progress.shard
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -3,6 +3,7 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -34,7 +35,6 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -43,7 +43,7 @@ def plan(
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ShardMetadata, DownloadProgress],
download_status: Mapping[ModelId, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
@@ -111,13 +111,14 @@ def _create_runner(
def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ShardMetadata, DownloadProgress],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
not isinstance(
download_status.get(runner.bound_instance.bound_shard, None),
(DownloadOngoing, DownloadCompleted),
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
@@ -273,6 +274,12 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(

View File

@@ -32,6 +32,7 @@ from exo.shared.types.worker.runners import (
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
)
@@ -187,13 +188,14 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
break
current_status = RunnerShutdown()
case _:
raise ValueError(
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
@@ -208,9 +210,8 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
)
if isinstance(current_status, RunnerShutdown):
break
except ClosedResourceError:
logger.warning("runner communication closed unexpectedly")
except Exception as e:

View File

@@ -14,13 +14,23 @@ from anyio import (
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnecting,
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerRunning,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
@@ -39,10 +49,10 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
# err_path: str
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -77,7 +87,6 @@ class RunnerSupervisor:
_ev_recv=ev_recv,
_task_sender=task_sender,
_event_sender=event_sender,
# err_path=err_path,
)
return self
@@ -118,6 +127,10 @@ class RunnerSupervisor:
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.completed:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -138,6 +151,22 @@ class RunnerSupervisor:
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
continue
if (
isinstance(event, TaskStatusUpdated)
and event.task_status == TaskStatus.Complete
):
# If a task has just been completed, we should be working on it.
assert isinstance(
self.status,
(
RunnerRunning,
RunnerWarmingUp,
RunnerLoading,
RunnerConnecting,
RunnerShuttingDown,
),
)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)

View File

@@ -9,9 +9,11 @@ MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
NODE_C: Final[NodeId] = NodeId("cccccccc-cccc-4ccc-8ccc-cccccccccccc")
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
RUNNER_3_ID: Final[RunnerId] = RunnerId("Runner3")
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
from dataclasses import dataclass
from dataclasses import dataclass, field
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
@@ -21,6 +19,7 @@ from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
class OtherTask(BaseTask):

View File

@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance
@@ -7,7 +8,6 @@ from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerIdle,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
@@ -46,7 +46,7 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ShardMetadata, DownloadProgress] = {}
download_status: dict[ModelId, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
@@ -94,7 +94,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
}
# Global view has completed downloads for both nodes
@@ -140,7 +140,7 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Local status claims the shard is downloaded already
local_download_status = {
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,7 +192,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],

View File

@@ -12,8 +12,10 @@ from exo.worker.tests.constants import (
MODEL_A_ID,
NODE_A,
NODE_B,
NODE_C,
RUNNER_1_ID,
RUNNER_2_ID,
RUNNER_3_ID,
)
from exo.worker.tests.unittests.conftest import (
FakeRunnerSupervisor,
@@ -24,37 +26,39 @@ from exo.worker.tests.unittests.conftest import (
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
"""
For non-final device_rank shards, StartWarmup should be emitted when all
For non-zero device_rank shards, StartWarmup should be emitted when all
shards in the instance are Loaded/WarmingUp.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_1_ID: local_runner}
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerLoaded(),
RUNNER_3_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
@@ -150,9 +154,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
"""
Rank-zero shard should not start warmup until all non-zero ranks are
already WarmingUp.
For accepting ranks (device_rank != world_size - 1), StartWarmup should be
For accepting ranks (device_rank != 0), StartWarmup should be
emitted when all shards in the instance are Loaded/WarmingUp.
In a 2-node setup, rank 0 is the accepting rank.
In a 2-node setup, rank 1 is the accepting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -163,7 +167,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
)
# Rank 0 is the accepting rank
# Rank 1 is the accepting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
@@ -188,6 +192,23 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
tasks={},
)
assert result is None
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
RUNNER_2_ID: RunnerWarmingUp(),
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, StartWarmup)
assert result.instance_id == INSTANCE_1_ID
@@ -280,9 +301,8 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
"""
Connecting rank (device_rank == world_size - 1) should not start warmup
Connecting rank (device_rank == 0) should not start warmup
until all other ranks are already WarmingUp.
In a 2-node setup, rank 1 is the connecting rank.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -295,13 +315,13 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
# Rank 1 is the connecting rank
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerLoaded()
)
runners = {RUNNER_2_ID: local_runner}
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerLoaded(),
@@ -309,7 +329,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
}
result = plan_mod.plan(
node_id=NODE_B,
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},

View File

@@ -34,6 +34,7 @@ from exo.shared.types.worker.runners import (
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerWarmingUp,
)
from exo.utils.channels import mp_channel
@@ -199,6 +200,9 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),