Compare commits

..

7 Commits

Author SHA1 Message Date
Sami Khan
37c5a2a246 Merge branch 'main' into sami/flash 2026-01-15 08:57:36 +05:00
Sami Khan
4d7f03834a deleted separate server 2026-01-15 08:50:45 +05:00
Sami Khan
bdb9fbc8c0 Merge branch 'main' into sami/flash 2026-01-14 08:10:51 +05:00
Sami Khan
8c7180810c type checking 2026-01-14 07:15:45 +05:00
Sami Khan
318c6e000b code cleanup 2026-01-14 04:56:59 +05:00
Sami Khan
2d45544da0 use rsh server instead of ssh 2026-01-13 02:46:25 +05:00
Sami Khan
7cbafa768a flash+exo 2026-01-12 10:26:16 +05:00
35 changed files with 2011 additions and 1380 deletions

19
Cargo.lock generated
View File

@@ -4340,6 +4340,25 @@ dependencies = [
"libc",
]
[[package]]
name = "system_custodian"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
name = "tagptr"
version = "0.2.0"

View File

@@ -3,6 +3,7 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
]
@@ -24,6 +25,7 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools

View File

@@ -56,11 +56,6 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
let result = await self.performCheck()
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { state in
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: timeout)
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:

View File

@@ -241,9 +241,6 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn

View File

@@ -863,7 +863,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -903,7 +902,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,7 +1518,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1530,7 +1527,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1943,7 +1939,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2651,7 +2646,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2839,7 +2833,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2984,7 +2977,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3006,7 +2998,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -60,39 +60,12 @@
return models;
});
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
// Auto-select the first available model if none is selected
$effect(() => {
const models = availableModels();
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {

View File

@@ -400,8 +400,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -761,10 +763,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -773,24 +771,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);

View File

@@ -1,5 +1,3 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -23,13 +23,13 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
]
[project.scripts]
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-rsh = "exo.rsh.client:main"
# dependencies only required for development
[dependency-groups]

View File

@@ -81,6 +81,20 @@
config = {
packages = {
# The system_custodian binary
system_custodian = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "-p system_custodian";
meta = {
description = "System custodian daemon for exo";
mainProgram = "system_custodian";
};
}
);
# Python bindings wheel via maturin
exo_pyo3_bindings = craneLib.buildPackage (
commonArgs

View File

@@ -0,0 +1,47 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -0,0 +1,4 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -0,0 +1,69 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which Exo application use to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and propper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

View File

@@ -1,6 +1,8 @@
import asyncio
import os
import time
from collections.abc import AsyncGenerator
from typing import cast
from typing import Any, Optional, cast
import anyio
from anyio import create_task_group
@@ -13,6 +15,13 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from pydantic import BaseModel
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -45,7 +54,9 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
StopFLASH,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -54,13 +65,36 @@ from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
FLASHInstance,
Instance,
InstanceId,
InstanceMeta,
)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
class ExecuteRequest(BaseModel):
"""Request to execute a command."""
command: list[str]
cwd: Optional[str] = None
env: Optional[dict[str, str]] = None
class ExecuteResponse(BaseModel):
"""Response from command execution."""
exit_code: int
stdout: str
stderr: str
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
@@ -170,6 +204,12 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
# FLASH simulation endpoints
self.app.post("/flash/launch")(self.launch_flash)
self.app.delete("/flash/{instance_id}")(self.stop_flash)
self.app.get("/flash/instances")(self.list_flash_instances)
# Remote execution endpoint (used by exo-rsh for MPI)
self.app.post("/execute")(self.execute)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -373,8 +413,35 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
@@ -382,10 +449,16 @@ class API:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -401,11 +474,11 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -417,7 +490,7 @@ class API:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -425,7 +498,7 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -454,7 +527,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -462,7 +535,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -503,6 +576,8 @@ class API:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -519,16 +594,17 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
@@ -545,7 +621,10 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
return response
def _calculate_total_available_memory(self) -> Memory:
@@ -575,6 +654,145 @@ class API:
]
)
async def launch_flash(
self,
simulation_name: str,
flash_executable_path: str,
working_directory: str,
parameter_file_path: str = "",
ranks_per_node: int = 1,
min_nodes: int = 1,
hosts: str = "",
) -> dict[str, str]:
"""Launch a FLASH MPI simulation across the cluster.
Args:
hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
If not provided, IPs are discovered from topology edges.
"""
command = LaunchFLASH(
simulation_name=simulation_name,
flash_executable_path=flash_executable_path,
parameter_file_path=parameter_file_path,
working_directory=working_directory,
ranks_per_node=ranks_per_node,
min_nodes=min_nodes,
hosts=hosts,
)
await self._send(command)
return {
"message": "FLASH launch command received",
"command_id": str(command.command_id),
"simulation_name": simulation_name,
}
async def stop_flash(self, instance_id: InstanceId) -> dict[str, str]:
"""Stop a running FLASH simulation."""
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
instance = self.state.instances[instance_id]
if not isinstance(instance, FLASHInstance):
raise HTTPException(
status_code=400, detail="Instance is not a FLASH simulation"
)
command = StopFLASH(instance_id=instance_id)
await self._send(command)
return {
"message": "Stop command received",
"command_id": str(command.command_id),
"instance_id": str(instance_id),
}
async def list_flash_instances(self) -> list[dict[str, Any]]:
"""List all FLASH simulation instances."""
flash_instances: list[dict[str, Any]] = []
for instance_id, instance in self.state.instances.items():
if isinstance(instance, FLASHInstance):
# Get runner statuses for this instance
runner_statuses: dict[str, str | None] = {}
for (
node_id,
runner_id,
) in instance.shard_assignments.node_to_runner.items():
runner_status = self.state.runners.get(runner_id)
runner_statuses[str(node_id)] = (
str(runner_status) if runner_status else None
)
flash_instances.append(
{
"instance_id": str(instance_id),
"simulation_name": instance.simulation_name,
"total_ranks": instance.total_ranks,
"working_directory": instance.working_directory,
"runner_statuses": runner_statuses,
}
)
return flash_instances
async def execute(self, request: ExecuteRequest) -> ExecuteResponse:
"""Execute a command locally. Used by exo-rsh for MPI remote execution."""
cmd_str = " ".join(request.command)
logger.info(f"Executing: {cmd_str}")
try:
# Build environment
env = os.environ.copy()
if request.env:
env.update(request.env)
# Check if command contains shell metacharacters
# If so, run through shell. mpirun sends complex commands like:
# "VAR=value;export VAR;/path/to/prted --args"
needs_shell = any(c in cmd_str for c in ";|&$`")
if needs_shell:
process = await asyncio.create_subprocess_shell(
cmd_str,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
else:
process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
stdout, stderr = await process.communicate()
exit_code = process.returncode or 0
logger.info(f"Command completed with exit code {exit_code}")
return ExecuteResponse(
exit_code=exit_code,
stdout=stdout.decode("utf-8", errors="replace"),
stderr=stderr.decode("utf-8", errors="replace"),
)
except FileNotFoundError:
logger.error(f"Command not found: {request.command[0]}")
return ExecuteResponse(
exit_code=127,
stdout="",
stderr=f"Command not found: {request.command[0]}",
)
except Exception as e:
logger.error(f"Execution error: {e}")
return ExecuteResponse(
exit_code=1,
stdout="",
stderr=str(e),
)
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"

View File

@@ -8,6 +8,7 @@ from exo.master.placement import (
add_instance_to_placements,
delete_instance,
get_transition_events,
place_flash_instance,
place_instance,
)
from exo.shared.apply import apply
@@ -16,8 +17,10 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
LaunchFLASH,
PlaceInstance,
RequestEventLog,
StopFLASH,
TaskFinished,
TestCommand,
)
@@ -173,6 +176,26 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case LaunchFLASH():
placement = place_flash_instance(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case StopFLASH():
# Reuse delete_instance logic to stop FLASH simulation
placement = delete_instance(
DeleteInstance(instance_id=command.instance_id),
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -17,20 +17,24 @@ from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
LaunchFLASH,
PlaceInstance,
)
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
FLASHInstance,
Instance,
InstanceId,
InstanceMeta,
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
def random_ephemeral_port() -> int:
@@ -165,6 +169,9 @@ def place_instance(
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
case InstanceMeta.FLASH:
# FLASH instances are handled by place_flash_instance()
raise ValueError("FLASH instances should use place_flash_instance()")
return target_instances
@@ -180,6 +187,148 @@ def delete_instance(
raise ValueError(f"Instance {command.instance_id} not found")
def place_flash_instance(
command: LaunchFLASH,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
"""Place a FLASH simulation instance across available nodes.
Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
FLASH instances use MPI for communication. We just need to provide the
node IPs so the runner can generate an MPI hostfile.
"""
instance_id = InstanceId()
target_instances = dict(deepcopy(current_instances))
all_nodes = list(topology.list_nodes())
if len(all_nodes) < command.min_nodes:
raise ValueError(
f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
)
# Select nodes (take the first min_nodes)
selected_nodes = all_nodes[: command.min_nodes]
logger.info(
f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
)
# Build shard assignments (one runner per node for FLASH)
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
# Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
flash_model_meta = ModelMetadata(
model_id=ModelId(command.simulation_name),
pretty_name=f"FLASH: {command.simulation_name}",
storage_size=Memory(in_bytes=0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
)
for i, node_info in enumerate(selected_nodes):
runner_id = RunnerId()
node_to_runner[node_info.node_id] = runner_id
runner_to_shard[runner_id] = PipelineShardMetadata(
device_rank=i,
world_size=len(selected_nodes),
model_meta=flash_model_meta,
start_layer=0,
end_layer=1,
n_layers=1,
)
shard_assignments = ShardAssignments(
model_id=ModelId(command.simulation_name),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
# Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
hosts_by_node: dict[NodeId, list[Host]] = {}
# If explicit hosts are provided, use them directly
if command.hosts:
explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
for i, node_info in enumerate(selected_nodes):
if i < len(explicit_hosts):
hosts_by_node[node_info.node_id] = [Host(ip=explicit_hosts[i], port=0)]
logger.info(
f"FLASH placement: node {node_info.node_id} (rank {i}) -> IP {explicit_hosts[i]}"
)
else:
logger.warning(
f"Not enough hosts provided for node {i}, using localhost"
)
hosts_by_node[node_info.node_id] = [Host(ip="127.0.0.1", port=0)]
logger.info(
f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
)
else:
# Try to get IPs from topology edges
for node_info in selected_nodes:
node_hosts: list[Host] = []
# Get IP from outgoing edges (connections to other nodes via mDNS discovery)
for _, edge_data in topology.out_edges(node_info.node_id):
if hasattr(edge_data, "send_back_multiaddr"):
# Extract IP from multiaddr like /ip4/192.168.1.100/tcp/52415
multiaddr = str(edge_data.send_back_multiaddr)
if "/ip4/" in multiaddr:
parts = multiaddr.split("/")
try:
ip_idx = parts.index("ip4") + 1
ip = parts[ip_idx]
# Skip link-local and localhost addresses
if not ip.startswith("169.254.") and not ip.startswith(
"127."
):
node_hosts.append(Host(ip=ip, port=0))
break
except (ValueError, IndexError):
pass
# Last resort: use localhost (will only work for single-node)
if not node_hosts:
logger.warning(
f"Could not determine IP for node {node_info.node_id}, using localhost"
)
node_hosts.append(Host(ip="127.0.0.1", port=0))
hosts_by_node[node_info.node_id] = node_hosts
total_ranks = len(selected_nodes) * command.ranks_per_node
# Determine coordinator IP - first node's first host IP
first_node_id: NodeId = next(iter(hosts_by_node.keys()))
coordinator_ip: str = (
hosts_by_node[first_node_id][0].ip
if hosts_by_node[first_node_id]
else "127.0.0.1"
)
target_instances[instance_id] = FLASHInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
flash_executable_path=command.flash_executable_path,
parameter_file_path=command.parameter_file_path,
working_directory=command.working_directory,
ranks_per_node=command.ranks_per_node,
total_ranks=total_ranks,
simulation_name=command.simulation_name,
coordinator_ip=coordinator_ip,
)
logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")
return target_instances
def get_transition_events(
current_instances: Mapping[InstanceId, Instance],
target_instances: Mapping[InstanceId, Instance],

View File

@@ -49,22 +49,20 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def _assign_layers_by_ram(
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
) -> ShardAssignments:
"""Assign layers proportionally based on available RAM."""
):
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
layers_assigned = 0
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
@@ -79,6 +77,7 @@ def _assign_layers_by_ram(
node_layers = max(1, node_layers)
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
device_rank=i,
@@ -87,143 +86,18 @@ def _assign_layers_by_ram(
end_layer=layers_assigned + node_layers,
n_layers=total_layers,
)
runner_to_shard[runner_id] = shard
node_to_runner[node.node_id] = runner_id
layers_assigned += node_layers
return ShardAssignments(
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
def _reserve_base_layers(world_size: int, total_layers: int) -> dict[int, int]:
"""Reserve 1 layer per node to ensure connectivity."""
assignments = {i: 0 for i in range(world_size)}
remaining_layers = total_layers
for i in range(world_size):
assignments[i] = 1
remaining_layers -= 1
if remaining_layers < 0:
logger.warning(
"Fewer layers than nodes! Reducing to 1 layer per node where possible."
)
assignments = {i: 1 if i < total_layers else 0 for i in range(world_size)}
remaining_layers = 0
return assignments
def _distribute_layers_by_bandwidth(
selected_cycle: list[NodeWithProfile],
assignments: dict[int, int],
remaining_layers: int,
model_meta: ModelMetadata,
) -> None:
"""Distribute remaining layers based on bandwidth and RAM capacity."""
indexed_nodes = list(enumerate(selected_cycle))
sorted_nodes = sorted(
indexed_nodes,
key=lambda x: x[1].node_profile.memory_bandwidth or 0,
reverse=True,
)
for original_idx, node in sorted_nodes:
if remaining_layers <= 0:
break
layer_size_bytes = model_meta.storage_size.in_bytes / model_meta.n_layers
max_layers_by_ram = int(
node.node_profile.memory.ram_available.in_bytes // layer_size_bytes
)
can_take = max(0, max_layers_by_ram - assignments[original_idx])
take = min(can_take, remaining_layers)
assignments[original_idx] += take
remaining_layers -= take
if remaining_layers > 0:
logger.warning(
"All nodes maxed out on RAM estimation, dumping remaining layers on fastest nodes."
)
for original_idx, _ in sorted_nodes:
assignments[original_idx] += 1
remaining_layers -= 1
if remaining_layers == 0:
break
def _create_shard_assignments(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
assignments: dict[int, int],
) -> ShardAssignments:
"""Create shard assignments from layer assignments."""
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
current_start = 0
for i, node in enumerate(selected_cycle):
count = assignments[i]
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=current_start,
end_layer=current_start + count,
n_layers=model_meta.n_layers,
)
runner_to_shard[runner_id] = shard
node_to_runner[node.node_id] = runner_id
current_start += count
return ShardAssignments(
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
def _assign_layers_by_bandwidth(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
) -> ShardAssignments:
"""Assign layers based on memory bandwidth."""
logger.info("Using bandwidth-aware shard assignment")
total_layers = model_meta.n_layers
world_size = len(selected_cycle)
assignments = _reserve_base_layers(world_size, total_layers)
remaining_layers = total_layers - sum(assignments.values())
if remaining_layers > 0:
_distribute_layers_by_bandwidth(
selected_cycle, assignments, remaining_layers, model_meta
)
return _create_shard_assignments(model_meta, selected_cycle, assignments)
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
selected_cycle: list[NodeWithProfile],
):
has_bandwidth = all(
node.node_profile.memory_bandwidth is not None for node in selected_cycle
)
if not has_bandwidth:
logger.info(
"Bandwidth data missing for some nodes, falling back to RAM-proportional assignment"
)
return _assign_layers_by_ram(model_meta, selected_cycle)
return _assign_layers_by_bandwidth(model_meta, selected_cycle)
return shard_assignments
def get_shard_assignments_for_tensor_parallel(

View File

@@ -397,106 +397,3 @@ def test_get_mlx_jaccl_coordinators(
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
def test_get_shard_assignments_bandwidth_aware(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# Create nodes with identical RAM (plenty of it)
# Using 1GB to ensure no RAM constraints (model is small)
node_a = create_node(1024 * 1024 * 1024, node_a_id)
node_b = create_node(1024 * 1024 * 1024, node_b_id)
node_c = create_node(1024 * 1024 * 1024, node_c_id)
# Set Bandwidths: A=400 (Fastest), B=200, C=100 (Slowest)
assert node_a.node_profile is not None
assert node_b.node_profile is not None
assert node_c.node_profile is not None
node_a.node_profile.memory_bandwidth = 400_000_000_000
node_b.node_profile.memory_bandwidth = 200_000_000_000
node_c.node_profile.memory_bandwidth = 100_000_000_000
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
# Needs full cycle edges for get_cycles/get_shard_assignments if strict?
# Actually get_cycles just looks for cycles.
# But let's follow the pattern of other tests if they add bidirectional.
# checking test_filter_cycles_by_memory, it adds both directions.
topology.add_connection(create_connection(node_b_id, node_a_id))
topology.add_connection(create_connection(node_c_id, node_b_id))
topology.add_connection(create_connection(node_a_id, node_c_id))
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=30, # 30 layers
storage_size=Memory.from_kb(
300
), # 10KB per layer. Nodes have 100MB RAM (100*1024 in create_node usually means KB? other tests use 1000*1024).
# create_node arg is likely KB or Bytes.
# test_filter_cycles_by_memory: create_node(1000 * 1024, ...) -> Memory.from_bytes(1) passes.
# Let's assume create_node takes Bytes or KB consistently.
# If I give 100*1024*1024 bytes = 100MB.
# Model storage = 300KB.
# So capacity is definitely not an issue.
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
# Depending on how get_cycles works and order of addition, we might get multiple cycles.
# filtering by memory usually done in master.
# Here we just pick one.
selected_cycle = cycles[0]
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline
)
# assert
runner_id_a = shard_assignments.node_to_runner[node_a_id]
runner_id_b = shard_assignments.node_to_runner[node_b_id]
runner_id_c = shard_assignments.node_to_runner[node_c_id]
# Get layer counts
layers_a = (
shard_assignments.runner_to_shard[runner_id_a].end_layer
- shard_assignments.runner_to_shard[runner_id_a].start_layer
)
layers_b = (
shard_assignments.runner_to_shard[runner_id_b].end_layer
- shard_assignments.runner_to_shard[runner_id_b].start_layer
)
layers_c = (
shard_assignments.runner_to_shard[runner_id_c].end_layer
- shard_assignments.runner_to_shard[runner_id_c].start_layer
)
# Check total
assert layers_a + layers_b + layers_c == 30
# Check that the fastest node (A with 400GB/s) gets saturated first.
# With strict greedy assignment and plenty of RAM:
# 1. Reserve: A=1, B=1, C=1. Remaining=27.
# 2. Sort: [A, B, C]
# 3. A takes min(remaining=27, capacity=huge) = 27.
# 4. A=28, B=1, C=1.
assert layers_a == 28
assert layers_b == 1
assert layers_c == 1

13
src/exo/rsh/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""Exo RSH - Remote Shell for MPI without SSH.
This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:
1. Each Exo node runs an API server on port 52415 with an /execute endpoint
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", it HTTP POSTs to the target's /execute
4. The target executes the command and returns output
Usage:
mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""

101
src/exo/rsh/client.py Normal file
View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.
This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]
It connects to the target node's Exo API (port 52415) and executes the command.
"""
import json
import socket
import sys
from typing import Any, cast
from urllib.error import URLError
from urllib.request import Request, urlopen
# Use the same port as Exo's API server
EXO_API_PORT = 52415
def resolve_hostname(hostname: str) -> str:
"""Resolve hostname to IP address."""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
# If resolution fails, try using the hostname directly
return hostname
def main():
# Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
# SSH options we might see: -x (disable X11), -o options, etc.
args = sys.argv[1:]
# Skip SSH-style options
hostname = None
command_start = 0
i = 0
while i < len(args):
arg = args[i]
if arg.startswith("-"):
# Skip option and its value if needed
if arg in ("-o", "-i", "-l", "-p", "-F"):
i += 2 # Skip option and its argument
continue
i += 1
continue
else:
# First non-option is the hostname
hostname = arg
command_start = i + 1
break
i += 1
if hostname is None or command_start >= len(args):
print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
sys.exit(1)
command = args[command_start:]
# Resolve hostname to IP
ip = resolve_hostname(hostname)
# Make request to Exo API
url = f"http://{ip}:{EXO_API_PORT}/execute"
data = json.dumps({"command": command}).encode("utf-8")
try:
req = Request(url, data=data, headers={"Content-Type": "application/json"})
with urlopen(req, timeout=300) as response: # pyright: ignore[reportAny]
response_body: bytes = cast(bytes, response.read()) # pyright: ignore[reportAny]
result: dict[str, Any] = json.loads(response_body.decode("utf-8")) # pyright: ignore[reportAny]
# Output stdout/stderr
stdout: str = cast(str, result.get("stdout", ""))
stderr: str = cast(str, result.get("stderr", ""))
exit_code: int = cast(int, result.get("exit_code", 0))
if stdout:
sys.stdout.write(stdout)
sys.stdout.flush()
if stderr:
sys.stderr.write(stderr)
sys.stderr.flush()
sys.exit(exit_code)
except URLError as e:
print(
f"exo-rsh: Failed to connect to {hostname}:{EXO_API_PORT}: {e}",
file=sys.stderr,
)
sys.exit(255)
except Exception as e:
print(f"exo-rsh: Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -29,11 +29,6 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

@@ -14,6 +14,32 @@ class ModelCard(CamelCaseModel):
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
# "deepseek-v3-0324:4bit": ModelCard(
# short_id="deepseek-v3-0324:4bit",
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
# name="DeepSeek V3 0324 (4-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
# pretty_name="DeepSeek V3 0324 (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# ),
# ),
# "deepseek-v3-0324": ModelCard(
# short_id="deepseek-v3-0324",
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
# name="DeepSeek V3 0324 (8-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
# pretty_name="DeepSeek V3 0324 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# ),
# ),
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -44,6 +70,65 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# "deepseek-v3.2": ModelCard(
# short_id="deepseek-v3.2",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# name="DeepSeek V3.2 (8-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# pretty_name="DeepSeek V3.2 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# hidden_size=7168,
# supports_tensor=True,
# ),
# ),
# "deepseek-v3.2-4bit": ModelCard(
# short_id="deepseek-v3.2-4bit",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# name="DeepSeek V3.2 (4-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# pretty_name="DeepSeek V3.2 (4-bit)",
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
# n_layers=61,
# hidden_size=7168,
# supports_tensor=True,
# ),
# ),
# deepseek r1
# "deepseek-r1-0528-4bit": ModelCard(
# short_id="deepseek-r1-0528-4bit",
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
# name="DeepSeek-R1-0528 (4-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
# pretty_name="DeepSeek R1 671B (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# "deepseek-r1-0528": ModelCard(
# short_id="deepseek-r1-0528",
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
# name="DeepSeek-R1-0528 (8-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
# pretty_name="DeepSeek R1 671B (8-bit)",
# storage_size=Memory.from_bytes(754998771712),
# n_layers=61,
# . hidden_size=7168,
# ),
# ),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
@@ -425,24 +510,23 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# glm 4.5
# Needs to be quantized g32 or g16.
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
@@ -472,7 +556,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
@@ -518,7 +601,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
@@ -549,4 +631,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# storage_size=Memory.from_kb(133_000_000),
# n_layers=88,
# hidden_size=12288,
# supports_tensor=True,
# ),
# ),
}

View File

@@ -35,6 +35,26 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class LaunchFLASH(BaseCommand):
"""Command to launch a FLASH MPI simulation."""
simulation_name: str
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
min_nodes: int = 1
# Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
# Used when topology edges don't contain IP addresses
hosts: str = ""
class StopFLASH(BaseCommand):
"""Command to stop a running FLASH simulation."""
instance_id: InstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -50,6 +70,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| LaunchFLASH
| StopFLASH
| TaskFinished
)

View File

@@ -57,7 +57,6 @@ class NodePerformanceProfile(CamelCaseModel):
chip_id: str
friendly_name: str
memory: MemoryPerformanceProfile
memory_bandwidth: int | None = None
network_interfaces: list[NetworkInterfaceInfo] = []
system: SystemPerformanceProfile

View File

@@ -14,6 +14,7 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
FLASH = "FLASH"
class BaseInstance(TaggedModel):
@@ -34,8 +35,27 @@ class MlxJacclInstance(BaseInstance):
jaccl_coordinators: dict[NodeId, str]
class FLASHInstance(BaseInstance):
"""Instance for FLASH MPI simulation.
Unlike MLX instances which do tensor parallelism, FLASH instances
coordinate MPI processes across nodes. Each node runs one or more
MPI ranks of the FLASH simulation.
"""
hosts_by_node: dict[NodeId, list[Host]]
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
total_ranks: int
simulation_name: str
coordinator_ip: str
network_interface: str = "en0" # Network interface for MPI (e.g., en0, eth0)
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
class BoundInstance(CamelCaseModel):

View File

@@ -20,7 +20,6 @@ except ImportError:
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -165,6 +164,11 @@ def mlx_distributed_init(
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
case _:
raise ValueError(
f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
)
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
@@ -366,8 +370,6 @@ def apply_chat_template(
tools=chat_task_data.tools,
)
logger.info(prompt)
return prompt
@@ -399,11 +401,6 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -21,7 +21,12 @@ from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.instances import (
BoundInstance,
FLASHInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -50,6 +55,11 @@ def plan(
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
) -> Task | None:
# Check for FLASH instance tasks first
flash_task = _plan_flash(runners, instances)
if flash_task is not None:
return flash_task
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
@@ -62,6 +72,34 @@ def plan(
)
def _plan_flash(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
"""Plan tasks specifically for FLASH instances.
FLASH instances have a simpler lifecycle:
- CreateRunner (handled by _create_runner)
- LoadModel (starts the simulation immediately)
- Shutdown (handled by _kill_runner)
This function handles the LoadModel step for FLASH instances,
skipping the MLX-specific download/init/warmup steps.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
# Only handle FLASH instances
if not isinstance(instance, FLASHInstance):
continue
# If runner is idle, emit LoadModel to start the simulation
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
return None
def _kill_runner(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -114,6 +152,10 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
# FLASH instances don't need model downloads
if isinstance(runner.bound_instance.instance, FLASHInstance):
continue
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status

View File

@@ -4,7 +4,11 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import (
BoundInstance,
FLASHInstance,
MlxJacclInstance,
)
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -17,20 +21,27 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
global logger
logger = _logger
# Import main after setting global logger - this lets us just import logger from this module
# Route based on instance type
try:
from exo.worker.runner.runner import main
if isinstance(bound_instance.instance, FLASHInstance):
# FLASH MPI simulation runner
from exo.worker.runner.flash_runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver)
else:
# MLX runner (default)
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -0,0 +1,301 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.
Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- exo-rsh connects to each node's Exo API (/execute endpoint) for remote execution
- Workers just report ready and wait
"""
import os
import shutil
import socket
import subprocess
import threading
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
LoadModel,
Shutdown,
Task,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.bootstrap import logger
# Find mpirun in PATH, fallback to common locations
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"
# exo-rsh is installed as console script by exo package
_exo_rsh_path = shutil.which("exo-rsh")
if not _exo_rsh_path:
raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
EXO_RSH_PATH: str = _exo_rsh_path
def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
"""Determine this node's rank based on position in hosts_by_node."""
for i, node_id in enumerate(instance.hosts_by_node.keys()):
if str(node_id) == str(my_node_id):
return i
return -1
def get_coordinator_host(instance: FLASHInstance) -> str:
"""Get the IP of the coordinator node."""
return instance.coordinator_ip
def resolve_host(host: str) -> str:
"""Resolve host string to a usable hostname for MPI hostfile.
Accepts either an IP address or hostname. For IPs, attempts to resolve
to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
"""
# Check if input is already a hostname (not an IP)
try:
socket.inet_aton(host)
is_ip = True
except socket.error:
is_ip = False
if not is_ip:
# Already a hostname, verify it resolves and return as-is
try:
socket.gethostbyname(host)
return host
except socket.gaierror:
logger.warning(f"Hostname {host} does not resolve, using anyway")
return host
# It's an IP address, try to resolve to hostname
try:
hostname, _, _ = socket.gethostbyaddr(host)
hostname = hostname.split(".")[0]
logger.info(f"Resolved {host} to {hostname}")
return hostname
except socket.herror:
pass
# Fall back to IP
logger.warning(f"Could not resolve {host} to hostname, using IP directly")
return host
def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
"""Generate MPI hostfile from instance topology."""
hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
with open(hostfile_path, "w") as f:
for _node_id, hosts in instance.hosts_by_node.items():
if hosts:
host = resolve_host(hosts[0].ip)
f.write(f"{host} slots={instance.ranks_per_node}\n")
logger.info(f"Generated hostfile at {hostfile_path}")
with open(hostfile_path, "r") as f:
logger.info(f"Hostfile contents:\n{f.read()}")
return hostfile_path
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
):
"""Main FLASH runner loop.
Coordinator: generates hostfile and runs mpirun (uses exo-rsh instead of SSH)
Workers: just report ready and wait for mpirun to spawn processes on them
"""
assert isinstance(bound_instance.instance, FLASHInstance)
instance = bound_instance.instance
runner_id = bound_instance.bound_runner_id
my_node_id = str(bound_instance.bound_node_id)
logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")
my_rank = get_my_rank(instance, my_node_id)
world_size = len(instance.hosts_by_node)
is_coordinator = my_rank == 0
coordinator_ip = get_coordinator_host(instance)
logger.info(
f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
)
logger.info(f"FLASH coordinator IP: {coordinator_ip}")
process: subprocess.Popen[bytes] | None = None
current_status: RunnerStatus = RunnerIdle()
shutdown_requested = False
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
def monitor_output(proc: subprocess.Popen[bytes]) -> None:
"""Monitor FLASH stdout for progress updates."""
if proc.stdout is None:
return
for line in iter(proc.stdout.readline, b""):
if shutdown_requested:
break
try:
decoded: str = line.decode("utf-8", errors="replace").strip()
if decoded:
logger.info(f"[FLASH] {decoded}")
except Exception as e:
logger.warning(f"Error parsing FLASH output: {e}")
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(current_status, RunnerIdle):
current_status = RunnerLoading()
logger.info("Starting FLASH simulation")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
if is_coordinator:
# Coordinator: generate hostfile and run mpirun
hostfile = generate_hostfile(
instance, instance.working_directory
)
iface = instance.network_interface
cmd = [
MPIRUN_PATH,
"-np",
str(instance.total_ranks),
"--hostfile",
hostfile,
"--wdir",
instance.working_directory,
"--oversubscribe",
"--mca",
"btl",
"tcp,self",
"--mca",
"btl_tcp_if_include",
iface,
"--mca",
"oob_tcp_if_include",
iface,
"--mca",
"plm_rsh_no_tree_spawn",
"1",
]
# Use exo-rsh for remote execution (no SSH needed)
cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])
cmd.append(instance.flash_executable_path)
logger.info(f"FLASH distributed launch: {' '.join(cmd)}")
process = subprocess.Popen(
cmd,
cwd=instance.working_directory,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
monitor_thread = threading.Thread(
target=monitor_output, args=(process,), daemon=True
)
monitor_thread.start()
current_status = RunnerRunning()
logger.info(
f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
)
else:
# Worker: mpirun on coordinator will use exo-rsh to spawn processes here
logger.info(
f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
)
current_status = RunnerRunning()
except Exception as e:
logger.error(f"Failed to start FLASH: {e}")
import traceback
logger.error(traceback.format_exc())
current_status = RunnerFailed(error_message=str(e))
case Shutdown():
shutdown_requested = True
current_status = RunnerShuttingDown()
logger.info("FLASH runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if process and process.poll() is None:
logger.info("Terminating FLASH simulation")
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
logger.warning("FLASH didn't terminate, killing")
process.kill()
process.wait()
current_status = RunnerShutdown()
case _:
if process and process.poll() is not None:
exit_code = process.returncode
if exit_code == 0:
logger.info("FLASH simulation completed successfully")
current_status = RunnerReady()
else:
logger.error(
f"FLASH simulation failed with code {exit_code}"
)
current_status = RunnerFailed(
error_message=f"Exit code {exit_code}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
break
if process and process.poll() is None:
process.terminate()
process.wait(timeout=5)
logger.info("FLASH runner exiting")

View File

@@ -1,15 +1,6 @@
import time
from collections.abc import Generator
from functools import cache
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
@@ -162,19 +153,11 @@ def main(
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
for response in mlx_generator:
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
@@ -224,43 +207,6 @@ def main(
break
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
for response in responses:
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta:
yield response.model_copy(update={"text": delta})
if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -1,64 +1,49 @@
import anyio
import httpx
from anyio import create_task_group
import http.client
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
REACHABILITY_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
try:
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
body = response.read().decode("utf-8").strip()
remote_node_id = NodeId(body)
break
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
# expected failure cases
except (
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
return None
finally:
connection.close()
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
@@ -76,33 +61,18 @@ async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
if node.node_id == self_node_id:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
client,
)
return reachable

View File

@@ -4,7 +4,6 @@ import platform
from typing import Any, Callable, Coroutine
import anyio
from anyio import to_thread
from loguru import logger
from exo.shared.types.memory import Memory
@@ -25,61 +24,8 @@ from .system_info import (
get_friendly_name,
get_model_and_chip,
get_network_interfaces,
profile_memory_bandwidth,
)
# Module-level cache for memory bandwidth (doesn't change at runtime)
_cached_bandwidth: int | None = None
_bandwidth_profiled: bool = False
_bandwidth_profiling_task: asyncio.Task[int | None] | None = None
async def profile_bandwidth_once() -> int | None:
"""Profile bandwidth once in a background thread and cache the result.
This function is non-blocking - it runs the profiling in a thread pool.
Subsequent calls return the cached result immediately.
"""
global _cached_bandwidth, _bandwidth_profiled, _bandwidth_profiling_task
# Already profiled, return cached value
if _bandwidth_profiled:
return _cached_bandwidth
# Profiling already in progress, wait for it
if _bandwidth_profiling_task is not None:
return await _bandwidth_profiling_task
# Start profiling in background thread
async def _do_profile() -> int | None:
global _cached_bandwidth, _bandwidth_profiled
try:
logger.info("Starting memory bandwidth profiling in background thread...")
bandwidth = await to_thread.run_sync(profile_memory_bandwidth, cancellable=True)
_cached_bandwidth = bandwidth
_bandwidth_profiled = True
if bandwidth:
logger.info(f"Memory bandwidth profiled: {bandwidth / 1e9:.1f} GB/s")
else:
logger.warning("Memory bandwidth profiling returned None")
return bandwidth
except Exception as e:
logger.opt(exception=e).error("Memory bandwidth profiling failed")
_bandwidth_profiled = True # Mark as done to avoid retrying
return None
_bandwidth_profiling_task = asyncio.create_task(_do_profile())
return await _bandwidth_profiling_task
def get_memory_bandwidth_cached() -> int | None:
"""Return cached bandwidth or None if not yet profiled.
This is a non-blocking synchronous function that returns immediately.
Call profile_bandwidth_once() first to trigger profiling.
"""
return _cached_bandwidth if _bandwidth_profiled else None
async def get_metrics_async() -> Metrics | None:
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
@@ -125,8 +71,6 @@ async def start_polling_node_metrics(
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
poll_interval_s = 1.0
bandwidth_profile_started = False
while True:
try:
metrics = await get_metrics_async()
@@ -141,15 +85,6 @@ async def start_polling_node_metrics(
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
memory_profile = get_memory_profile()
# Start bandwidth profiling in background on first poll (non-blocking)
if not bandwidth_profile_started:
bandwidth_profile_started = True
# Fire and forget - don't await, let it run in background
asyncio.create_task(profile_bandwidth_once())
# Use cached bandwidth (None until profiling completes)
memory_bandwidth = get_memory_bandwidth_cached()
await callback(
NodePerformanceProfile(
model_id=model_id,
@@ -157,7 +92,6 @@ async def start_polling_node_metrics(
friendly_name=friendly_name,
network_interfaces=network_interfaces,
memory=memory_profile,
memory_bandwidth=memory_bandwidth,
system=SystemPerformanceProfile(
gpu_usage=metrics.gpu_usage[1],
temp=metrics.temp.gpu_temp_avg,

View File

@@ -1,6 +1,5 @@
import socket
import sys
import time
from subprocess import CalledProcessError
import psutil
@@ -82,68 +81,3 @@ async def get_model_and_chip() -> tuple[str, str]:
chip = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
return (model, chip)
def profile_memory_bandwidth() -> int | None:
"""
Profile device memory bandwidth using MLX GPU operations.
Uses a large array copy on the GPU to measure unified memory bandwidth.
Returns measured bandwidth in bytes/second, or None if MLX is unavailable.
"""
try:
import mlx.core as mx
if not mx.metal.is_available():
return None
# Use 2GB buffer to better saturate memory bandwidth
# Use 2D shape to avoid potential issues with very large 1D arrays
size_bytes = 2 * 1024 * 1024 * 1024
side = int((size_bytes // 4) ** 0.5) # Square 2D array of float32
shape = (side, side)
actual_bytes = side * side * 4
bytes_transferred = actual_bytes * 2 # read + write
# Warm-up: run the full benchmark operation multiple times to stabilize GPU
for _ in range(3):
src = mx.random.uniform(shape=shape, dtype=mx.float32)
mx.eval(src)
dst = src + 0.0
mx.eval(dst)
mx.synchronize()
del src, dst
# Benchmark: measure time to copy array
best_bandwidth = 0.0
num_runs = 4
for _ in range(num_runs):
src = mx.random.uniform(shape=shape, dtype=mx.float32)
mx.eval(src)
mx.synchronize()
# Time the copy operation (src + 0.0 forces read of src, write of dst)
start = time.perf_counter()
dst = src + 0.0
mx.eval(dst)
mx.synchronize()
end = time.perf_counter()
bandwidth = bytes_transferred / (end - start)
best_bandwidth = max(best_bandwidth, bandwidth)
del src, dst
return int(best_bandwidth)
except Exception:
return None
def get_memory_bandwidth(_chip_id: str) -> int | None:
"""
Returns measured memory bandwidth in bytes/second.
Uses MLX GPU operations for accurate unified memory bandwidth measurement.
"""
return profile_memory_bandwidth()

1484
uv.lock generated
View File

File diff suppressed because it is too large Load Diff