Compare commits


15 Commits

Author SHA1 Message Date
Sami Khan cd630dea43 formatting 2026-01-20 07:54:52 +05:00
Sami Khan e55dae5ce8 code quality 2026-01-20 07:47:14 +05:00
Sami Khan 302c43afd5 Merge main into sami/flash 2026-01-20 07:31:22 +05:00
Sami Khan 2cf59e2322 nix flake check 2026-01-20 07:21:41 +05:00
Sami Khan e506c7d65c exo plugins 2026-01-20 06:53:43 +05:00
Sami Khan c1fa2ddeaf SLURM compatible commands 2026-01-20 06:53:43 +05:00
Alex Cheema 39f0ed6018 Prepend <think> tag to stream for thinking models like GLM-4.7 (#1186)
## Motivation

For thinking models like GLM-4.7, the `<think>` tag is inserted by the
tokenizer's `apply_chat_template()` into the **prompt** (input). The
model generates tokens starting *after* this tag, so `<think>` never
appears in the streamed output. The frontend expects
`<think>...</think>` tags to extract and display thinking content.

**Log evidence:**
```
[gMASK]<sop><|system|>...<|user|>...<|assistant|><think>
```
The prompt ends with `<think>`, so the model generates content after it,
never returning the opening tag.
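
For illustration, the detection step can be as simple as a suffix check. A minimal sketch, reusing the helper name from this PR (the exact signature in `utils_mlx.py` may differ):

```python
THINK_TAG = "<think>"

def detect_thinking_prompt_suffix(prompt: str) -> bool:
    # If apply_chat_template() already appended <think> to the prompt,
    # the model's output will never contain the opening tag itself.
    return prompt.rstrip().endswith(THINK_TAG)
```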

## Changes

- Added `detect_thinking_prompt_suffix()` helper function in
`utils_mlx.py` to detect if a prompt ends with `<think>` tag
- Added `parse_thinking_models()` generator wrapper in `runner.py` that
prepends the thinking tag to the output stream
- Modified the main generation loop to use the thinking wrapper for
non-GptOssModel models when a thinking prefix is detected
- Updated test mocks to handle the new `apply_chat_template` call

## Why It Works

The solution follows the same pattern as `parse_gpt_oss()`: a generator
wrapper that transforms the output stream. When the chat template ends
with `<think>`, we prepend this tag to the first generated token so the
frontend receives the complete `<think>...</think>` structure it
expects.
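
A minimal sketch of that wrapper pattern (an async generator over plain strings is assumed here; the real `parse_thinking_models()` in `runner.py` may use different chunk types):

```python
from collections.abc import AsyncIterator

async def parse_thinking_models(
    tokens: AsyncIterator[str],
    thinking_tag: str = "<think>",
) -> AsyncIterator[str]:
    # Prepend the opening tag to the first generated token so the
    # frontend receives the complete <think>...</think> structure.
    first = True
    async for token in tokens:
        if first:
            yield thinking_tag + token
            first = False
        else:
            yield token
```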

## Test Plan

### Manual Testing
- Run exo: `uv run exo`
- Send a chat request to GLM-4.7:
  ```bash
  curl http://localhost:52415/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
      "model": "mlx-community/GLM-4.7-8bit-gs32",
      "messages": [{"role": "user", "content": "What is 2+2?"}],
      "stream": true
    }'
  ```
- Verify the streamed response starts with the `<think>` tag
- Verify the frontend dashboard correctly shows the thinking section collapsed

### Automated Testing
- All 72 worker tests pass: `uv run pytest src/exo/worker/`
- Type checker passes: `uv run basedpyright`
- Linter passes: `uv run ruff check`

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-01-19 19:44:51 +00:00
Alex Cheema ee43b598fe Split NodePerformanceProfile into granular state mappings (#1209)
## Motivation

The current `NodePerformanceProfile` is a monolithic object where every
update (even 1-second memory updates) replaces the entire profile,
touching unrelated data. Different fields update at vastly different
frequencies:

| Data | Update Frequency |
|------|------------------|
| Memory, System | 1 second |
| Thunderbolt | 5 seconds |
| Network interfaces | 10 seconds |
| Friendly name | 60 seconds |
| Model/Chip ID | Once at startup |

## Changes

Split into separate state mappings so each data type updates
independently:

- `node_identities`: Static and slow-changing data (model_id, chip_id,
friendly_name)
- `node_memory`: RAM and swap usage
- `node_system`: GPU usage, temperature, power, CPU metrics
- `node_network`: Network interface information
- `node_thunderbolt`: Thunderbolt interface identifiers

Added a backwards-compatible `node_profiles` property that reconstructs
`NodePerformanceProfile` from the granular mappings for dashboard
compatibility.
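
A rough sketch of that reconstruction, with stub types for self-containment (the real `State`, `NodeIdentity`, and `MemoryUsage` in `state.py`/`profiling.py` carry more fields, including system, network, and Thunderbolt data):

```python
from dataclasses import dataclass, field

@dataclass
class NodeIdentity:
    model_id: str = "Unknown"
    chip_id: str | None = None
    friendly_name: str | None = None

@dataclass
class MemoryUsage:
    ram_total: int = 0
    ram_available: int = 0

@dataclass
class NodePerformanceProfile:
    model_id: str
    chip_id: str | None
    friendly_name: str | None
    memory: MemoryUsage

@dataclass
class State:
    node_identities: dict[str, NodeIdentity] = field(default_factory=dict)
    node_memory: dict[str, MemoryUsage] = field(default_factory=dict)

    @property
    def node_profiles(self) -> dict[str, NodePerformanceProfile]:
        # Rebuild the legacy monolithic profile for dashboard compatibility.
        profiles: dict[str, NodePerformanceProfile] = {}
        for node_id in self.node_identities.keys() | self.node_memory.keys():
            ident = self.node_identities.get(node_id, NodeIdentity())
            profiles[node_id] = NodePerformanceProfile(
                model_id=ident.model_id,
                chip_id=ident.chip_id,
                friendly_name=ident.friendly_name,
                memory=self.node_memory.get(node_id, MemoryUsage()),
            )
        return profiles
```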

**Files modified:**
- `src/exo/shared/types/profiling.py` - Added `NodeIdentity`,
`NodeNetworkInfo`, `NodeThunderboltInfo` types
- `src/exo/shared/types/state.py` - Added 5 new mappings +
`node_profiles` property
- `src/exo/shared/apply.py` - Updated `apply_node_gathered_info` and
`apply_node_timed_out`

## Why It Works

Each info type now writes only to its specific mapping, avoiding
unnecessary updates to unrelated data. The `MacThunderboltConnections`
handler reads from `node_thunderbolt` instead of the old `node_profiles`
for RDMA connection mapping. The backwards-compatible property ensures
the dashboard continues to work unchanged.
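
Continuing the sketch above, a 1-second memory tick now rewrites only the `node_memory` entry and leaves identity data untouched (values are illustrative):

```python
state = State()
state.node_identities["node-a"] = NodeIdentity(model_id="test-mac", friendly_name="studio")

# High-frequency memory update: only node_memory is touched.
state.node_memory["node-a"] = MemoryUsage(ram_total=64 << 30, ram_available=40 << 30)

assert state.node_profiles["node-a"].friendly_name == "studio"
```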

## Test Plan

### Manual Testing
- Start exo and verify dashboard shows node info
- Verify memory/GPU updates stream correctly
- Check that node timeout properly cleans up all mappings

### Automated Testing
- All 162 existing tests pass
- basedpyright: 0 errors
- ruff check: All checks passed
- nix fmt: Applied

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 18:24:15 +00:00
Sami Khan 37c5a2a246 Merge branch 'main' into sami/flash 2026-01-15 08:57:36 +05:00
Sami Khan 4d7f03834a deleted separate server 2026-01-15 08:50:45 +05:00
Sami Khan bdb9fbc8c0 Merge branch 'main' into sami/flash 2026-01-14 08:10:51 +05:00
Sami Khan 8c7180810c type checking 2026-01-14 07:15:45 +05:00
Sami Khan 318c6e000b code cleanup 2026-01-14 04:56:59 +05:00
Sami Khan 2d45544da0 use rsh server instead of ssh 2026-01-13 02:46:25 +05:00
Sami Khan 7cbafa768a flash+exo 2026-01-12 10:26:16 +05:00
46 changed files with 2717 additions and 588 deletions


@@ -71,44 +71,46 @@ export interface Instance {
};
}
interface RawNodeProfile {
// Granular node state types from the new state structure
interface RawNodeIdentity {
modelId?: string;
chipId?: string;
friendlyName?: string;
networkInterfaces?: Array<{
name?: string;
ipAddress?: string;
addresses?: Array<{ address?: string } | string>;
ipv4?: string;
ipv6?: string;
ipAddresses?: string[];
ips?: string[];
}>;
memory?: {
ramTotal?: { inBytes: number };
ramAvailable?: { inBytes: number };
swapTotal?: { inBytes: number };
swapAvailable?: { inBytes: number };
};
system?: {
gpuUsage?: number;
temp?: number;
sysPower?: number;
};
}
interface RawTopologyNode {
nodeId: string;
nodeProfile?: RawNodeProfile;
interface RawMemoryUsage {
ramTotal?: { inBytes: number };
ramAvailable?: { inBytes: number };
swapTotal?: { inBytes: number };
swapAvailable?: { inBytes: number };
}
interface RawSystemPerformanceProfile {
gpuUsage?: number;
temp?: number;
sysPower?: number;
pcpuUsage?: number;
ecpuUsage?: number;
}
interface RawNetworkInterfaceInfo {
name?: string;
ipAddress?: string;
addresses?: Array<{ address?: string } | string>;
ipv4?: string;
ipv6?: string;
ipAddresses?: string[];
ips?: string[];
}
interface RawNodeNetworkInfo {
interfaces?: RawNetworkInterfaceInfo[];
}
// New connection edge types from Python SocketConnection/RDMAConnection
interface RawSocketConnection {
sinkMultiaddr?: {
address?: string;
// Multiaddr uses snake_case (no camelCase alias)
ip_address?: string;
ipAddress?: string; // fallback in case it changes
address_type?: string;
port?: number;
};
@@ -125,14 +127,10 @@ type RawConnectionEdge = RawSocketConnection | RawRDMAConnection;
type RawConnectionsMap = Record<string, Record<string, RawConnectionEdge[]>>;
interface RawTopology {
// nodes can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
nodes: (string | RawTopologyNode)[];
// New nested mapping format
nodes: string[];
connections?: RawConnectionsMap;
}
type RawNodeProfiles = Record<string, RawNodeProfile>;
export interface DownloadProgress {
totalBytes: number;
downloadedBytes: number;
@@ -187,7 +185,11 @@ interface RawStateResponse {
>;
runners?: Record<string, unknown>;
downloads?: Record<string, unknown[]>;
nodeProfiles?: RawNodeProfiles;
// New granular node state fields
nodeIdentities?: Record<string, RawNodeIdentity>;
nodeMemory?: Record<string, RawMemoryUsage>;
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
}
export interface MessageAttachment {
@@ -222,65 +224,69 @@ export interface Conversation {
const STORAGE_KEY = "exo-conversations";
interface GranularNodeState {
nodeIdentities?: Record<string, RawNodeIdentity>;
nodeMemory?: Record<string, RawMemoryUsage>;
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
}
function transformNetworkInterface(iface: RawNetworkInterfaceInfo): {
name?: string;
addresses: string[];
} {
const addresses: string[] = [];
if (iface.ipAddress && typeof iface.ipAddress === "string") {
addresses.push(iface.ipAddress);
}
if (Array.isArray(iface.addresses)) {
for (const addr of iface.addresses) {
if (typeof addr === "string") addresses.push(addr);
else if (addr && typeof addr === "object" && addr.address)
addresses.push(addr.address);
}
}
if (Array.isArray(iface.ipAddresses)) {
addresses.push(
...iface.ipAddresses.filter((a): a is string => typeof a === "string"),
);
}
if (Array.isArray(iface.ips)) {
addresses.push(
...iface.ips.filter((a): a is string => typeof a === "string"),
);
}
if (iface.ipv4 && typeof iface.ipv4 === "string") addresses.push(iface.ipv4);
if (iface.ipv6 && typeof iface.ipv6 === "string") addresses.push(iface.ipv6);
return {
name: iface.name,
addresses: Array.from(new Set(addresses)),
};
}
function transformTopology(
raw: RawTopology,
profiles?: RawNodeProfiles,
granularState: GranularNodeState,
): TopologyData {
const nodes: Record<string, NodeInfo> = {};
const edges: TopologyEdge[] = [];
// Handle nodes - can be array of strings (node IDs) or array of objects with nodeId/nodeProfile
for (const node of raw.nodes || []) {
// Determine the node ID - could be a string or an object with nodeId property
const nodeId = typeof node === "string" ? node : node.nodeId;
for (const nodeId of raw.nodes || []) {
if (!nodeId) continue;
// Get the profile - from the separate profiles map or from the node object itself
const profileFromMap = profiles?.[nodeId];
const profileFromNode =
typeof node === "object" ? node.nodeProfile : undefined;
const profile = { ...(profileFromNode ?? {}), ...(profileFromMap ?? {}) };
// Get data from granular state mappings
const identity = granularState.nodeIdentities?.[nodeId];
const memory = granularState.nodeMemory?.[nodeId];
const system = granularState.nodeSystem?.[nodeId];
const network = granularState.nodeNetwork?.[nodeId];
const ramTotal = profile?.memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = profile?.memory?.ramAvailable?.inBytes ?? 0;
const ramTotal = memory?.ramTotal?.inBytes ?? 0;
const ramAvailable = memory?.ramAvailable?.inBytes ?? 0;
const ramUsage = Math.max(ramTotal - ramAvailable, 0);
const networkInterfaces = (profile?.networkInterfaces || []).map(
(iface) => {
const addresses: string[] = [];
if (iface.ipAddress && typeof iface.ipAddress === "string") {
addresses.push(iface.ipAddress);
}
if (Array.isArray(iface.addresses)) {
for (const addr of iface.addresses) {
if (typeof addr === "string") addresses.push(addr);
else if (addr && typeof addr === "object" && addr.address)
addresses.push(addr.address);
}
}
if (Array.isArray(iface.ipAddresses)) {
addresses.push(
...iface.ipAddresses.filter(
(a): a is string => typeof a === "string",
),
);
}
if (Array.isArray(iface.ips)) {
addresses.push(
...iface.ips.filter((a): a is string => typeof a === "string"),
);
}
if (iface.ipv4 && typeof iface.ipv4 === "string")
addresses.push(iface.ipv4);
if (iface.ipv6 && typeof iface.ipv6 === "string")
addresses.push(iface.ipv6);
return {
name: iface.name,
addresses: Array.from(new Set(addresses)),
};
},
);
const rawInterfaces = network?.interfaces || [];
const networkInterfaces = rawInterfaces.map(transformNetworkInterface);
const ipToInterface: Record<string, string> = {};
for (const iface of networkInterfaces) {
@@ -291,8 +297,8 @@ function transformTopology(
nodes[nodeId] = {
system_info: {
model_id: profile?.modelId ?? "Unknown",
chip: profile?.chipId,
model_id: identity?.modelId ?? "Unknown",
chip: identity?.chipId,
memory: ramTotal,
},
network_interfaces: networkInterfaces,
@@ -303,17 +309,15 @@ function transformTopology(
ram_total: ramTotal,
},
temp:
profile?.system?.temp !== undefined
? { gpu_temp_avg: profile.system.temp }
system?.temp !== undefined
? { gpu_temp_avg: system.temp }
: undefined,
gpu_usage:
profile?.system?.gpuUsage !== undefined
? [0, profile.system.gpuUsage]
: undefined,
sys_power: profile?.system?.sysPower,
system?.gpuUsage !== undefined ? [0, system.gpuUsage] : undefined,
sys_power: system?.sysPower,
},
last_macmon_update: Date.now() / 1000,
friendly_name: profile?.friendlyName,
friendly_name: identity?.friendlyName,
};
}
@@ -325,19 +329,15 @@ function transformTopology(
for (const [sink, edgeList] of Object.entries(sinks)) {
if (!Array.isArray(edgeList)) continue;
for (const edge of edgeList) {
// Extract IP from SocketConnection (uses snake_case: ip_address)
let sendBackIp: string | undefined;
if (edge && typeof edge === "object" && "sinkMultiaddr" in edge) {
const multiaddr = edge.sinkMultiaddr;
if (multiaddr) {
// Try both snake_case (actual) and camelCase (in case it changes)
sendBackIp =
multiaddr.ip_address ||
multiaddr.ipAddress ||
extractIpFromMultiaddr(multiaddr.address);
}
}
// RDMAConnection (sourceRdmaIface/sinkRdmaIface) has no IP - edge just shows connection exists
if (nodes[source] && nodes[sink] && source !== sink) {
edges.push({ source, target: sink, sendBackIp });
@@ -898,7 +898,12 @@ class AppStore {
const data: RawStateResponse = await response.json();
if (data.topology) {
this.topologyData = transformTopology(data.topology, data.nodeProfiles);
this.topologyData = transformTopology(data.topology, {
nodeIdentities: data.nodeIdentities,
nodeMemory: data.nodeMemory,
nodeSystem: data.nodeSystem,
nodeNetwork: data.nodeNetwork,
});
}
if (data.instances) {
this.instances = data.instances;


@@ -30,6 +30,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-rsh = "exo.rsh.client:main"
# dependencies only required for development
[dependency-groups]

src/exo/cli/__init__.py (new file, +32 lines)

@@ -0,0 +1,32 @@
"""Exo CLI - SLURM-compatible job management commands."""
def run_subcommand(command: str, args: list[str]) -> int:
"""Route to the appropriate subcommand handler.
Args:
command: The subcommand name (sbatch, squeue, scancel, salloc)
args: Command line arguments for the subcommand
Returns:
Exit code from the subcommand
"""
if command == "sbatch":
from exo.cli.sbatch import main
return main(args)
elif command == "squeue":
from exo.cli.squeue import main
return main(args)
elif command == "scancel":
from exo.cli.scancel import main
return main(args)
elif command == "salloc":
from exo.cli.salloc import main
return main(args)
else:
print(f"Unknown subcommand: {command}")
return 1

src/exo/cli/common.py (new file, +118 lines)

@@ -0,0 +1,118 @@
"""Common utilities for Exo CLI commands."""
import json
import os
import urllib.request
from typing import Any
from urllib.error import HTTPError, URLError
# Default API endpoint
DEFAULT_API_HOST = "localhost"
DEFAULT_API_PORT = 52415
def get_api_base() -> str:
"""Get the API base URL from environment or defaults."""
host = os.environ.get("EXO_API_HOST", DEFAULT_API_HOST)
port = os.environ.get("EXO_API_PORT", str(DEFAULT_API_PORT))
return f"http://{host}:{port}"
def api_request(
method: str,
path: str,
data: dict[str, Any] | None = None,
) -> dict[str, Any] | list[Any]:
"""Make an API request to the Exo server.
Args:
method: HTTP method (GET, POST, DELETE, etc.)
path: API path (e.g., "/flash/instances")
data: Optional JSON data for POST/PUT requests
Returns:
Parsed JSON response
Raises:
SystemExit: On connection or HTTP errors
"""
url = f"{get_api_base()}{path}"
request_data = None
if data is not None:
request_data = json.dumps(data).encode("utf-8")
req = urllib.request.Request(
url,
data=request_data,
method=method,
)
req.add_header("Content-Type", "application/json")
try:
with urllib.request.urlopen(req, timeout=30) as response: # pyright: ignore[reportAny]
body: str = response.read().decode("utf-8") # pyright: ignore[reportAny]
if body:
return json.loads(body) # pyright: ignore[reportAny]
return {}
except HTTPError as e:
error_body = e.read().decode("utf-8") if e.fp else ""
print(f"API error: {e.code} {e.reason}")
if error_body:
try:
error_json: dict[str, str] = json.loads(error_body) # pyright: ignore[reportAny]
if "detail" in error_json:
print(f" {error_json['detail']}")
except json.JSONDecodeError:
print(f" {error_body}")
raise SystemExit(1) from None
except URLError as e:
print(f"Connection error: {e.reason}")
print(f"Is Exo running at {get_api_base()}?")
raise SystemExit(1) from None
def truncate_id(instance_id: str, length: int = 8) -> str:
"""Truncate a UUID for display.
Args:
instance_id: Full UUID string
length: Number of characters to keep
Returns:
Truncated ID without hyphens
"""
return instance_id.replace("-", "")[:length]
def format_table(headers: list[str], rows: list[list[str]]) -> str:
"""Format data as a simple text table.
Args:
headers: Column headers
rows: List of rows, each row is a list of column values
Returns:
Formatted table string
"""
if not rows:
return " ".join(f"{h:<10}" for h in headers)
# Calculate column widths
widths = [len(h) for h in headers]
for row in rows:
for i, cell in enumerate(row):
if i < len(widths):
widths[i] = max(widths[i], len(cell))
# Build format string
fmt = " ".join(f"{{:<{w}}}" for w in widths)
# Format output
lines = [fmt.format(*headers)]
for row in rows:
# Pad row if needed
padded = row + [""] * (len(headers) - len(row))
lines.append(fmt.format(*padded[: len(headers)]))
return "\n".join(lines)

src/exo/cli/salloc.py (new file, +100 lines)

@@ -0,0 +1,100 @@
"""salloc - Allocate nodes for interactive use.
Usage:
exo salloc [options] [-- command [args...]]
Options:
-N, --nodes N Number of nodes to allocate (default: 1)
--hosts HOSTS Comma-separated list of hostnames
If a command is provided after --, it will be executed with
SLURM-like environment variables set:
SLURM_JOB_NODELIST - Comma-separated list of allocated nodes
SLURM_NNODES - Number of allocated nodes
Examples:
exo salloc --nodes=2 --hosts=node1,node2 -- mpirun ./my_program
exo salloc --hosts=localhost -- bash
"""
import argparse
import os
import subprocess
import sys
def main(args: list[str]) -> int:
"""Main entry point for salloc command."""
# Split args at -- if present
cmd_args: list[str] = []
salloc_args = args
if "--" in args:
idx = args.index("--")
salloc_args = args[:idx]
cmd_args = args[idx + 1 :]
parser = argparse.ArgumentParser(
prog="exo salloc",
description="Allocate nodes for interactive use",
)
parser.add_argument(
"-N",
"--nodes",
type=int,
default=1,
help="Number of nodes to allocate (default: 1)",
)
parser.add_argument(
"--hosts",
help="Comma-separated list of hostnames (required)",
)
parsed = parser.parse_args(salloc_args)
nodes: int = parsed.nodes # pyright: ignore[reportAny]
hosts: str | None = parsed.hosts # pyright: ignore[reportAny]
# Require explicit hosts since we can't discover them from topology
if not hosts:
print("Error: --hosts is required (e.g., --hosts=node1,node2)", file=sys.stderr)
print(" The Exo topology doesn't expose hostnames.", file=sys.stderr)
return 1
host_list = [h.strip() for h in hosts.split(",") if h.strip()]
if len(host_list) < nodes:
print(
f"Error: Requested {nodes} nodes but only {len(host_list)} hosts provided",
file=sys.stderr,
)
return 1
# Use first N hosts
allocated_hosts = host_list[:nodes]
nodelist = ",".join(allocated_hosts)
# Set environment variables
env = os.environ.copy()
env["SLURM_JOB_NODELIST"] = nodelist
env["SLURM_NNODES"] = str(nodes)
print(f"salloc: Granted job allocation on {nodes} node(s)")
print(f"salloc: Nodes: {nodelist}")
if cmd_args:
# Run the command
print(f"salloc: Running: {' '.join(cmd_args)}")
result = subprocess.run(cmd_args, env=env)
return result.returncode
else:
# Start interactive shell
shell = os.environ.get("SHELL", "/bin/bash")
print(f"salloc: Starting shell {shell}")
print("salloc: Use 'exit' to release allocation")
result = subprocess.run([shell], env=env)
return result.returncode
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

src/exo/cli/sbatch.py (new file, +233 lines)

@@ -0,0 +1,233 @@
"""sbatch - Submit a batch job to Exo.
Usage:
exo sbatch [options] <script|executable>
exo sbatch --job-name=NAME --nodes=N <executable>
Options:
-J, --job-name NAME Job name
-N, --nodes N Number of nodes (default: 1)
--ntasks-per-node N Tasks per node (default: 1)
-D, --chdir DIR Working directory
--hosts HOSTS Comma-separated list of hostnames
Job scripts can contain #SBATCH directives:
#!/bin/bash
#SBATCH --job-name=Sod2D
#SBATCH --nodes=2
#SBATCH --chdir=/path/to/workdir
/path/to/flash4
"""
import argparse
import os
import re
import sys
from exo.cli.common import api_request, truncate_id
def parse_job_script(script_path: str) -> tuple[dict[str, str], str | None]:
"""Parse a job script for #SBATCH directives and executable.
Args:
script_path: Path to the job script
Returns:
Tuple of (directives dict, executable path or None)
"""
directives: dict[str, str] = {}
executable: str | None = None
with open(script_path, "r") as f:
for line in f:
line = line.strip()
# Parse #SBATCH directives
if line.startswith("#SBATCH"):
# Handle both --option=value and --option value formats
match = re.match(r"#SBATCH\s+(-\w|--[\w-]+)(?:=|\s+)(.+)", line)
if match:
opt, val = match.groups()
directives[opt.lstrip("-")] = val.strip()
continue
# Skip comments and empty lines
if line.startswith("#") or not line:
continue
# First non-comment, non-directive line is the executable
if executable is None:
# Handle lines like "/path/to/flash4" or "srun /path/to/flash4"
parts = line.split()
if parts:
# Skip srun/mpirun prefixes if present
for part in parts:
if not part.startswith("-") and "/" in part:
executable = part
break
if executable is None and parts:
executable = parts[-1] # Last token
return directives, executable
def main(args: list[str]) -> int:
"""Main entry point for sbatch command."""
parser = argparse.ArgumentParser(
prog="exo sbatch",
description="Submit a batch job to Exo",
)
parser.add_argument(
"script",
help="Job script or executable path",
)
parser.add_argument(
"-J",
"--job-name",
dest="job_name",
help="Job name",
)
parser.add_argument(
"-N",
"--nodes",
type=int,
default=1,
help="Number of nodes (default: 1)",
)
parser.add_argument(
"--ntasks-per-node",
type=int,
default=1,
help="Tasks per node (default: 1)",
)
parser.add_argument(
"-D",
"--chdir",
help="Working directory",
)
parser.add_argument(
"--hosts",
help="Comma-separated list of hostnames",
)
parsed = parser.parse_args(args)
# Extract typed values from namespace
script_path: str = parsed.script # pyright: ignore[reportAny]
arg_job_name: str | None = parsed.job_name # pyright: ignore[reportAny]
arg_nodes: int = parsed.nodes # pyright: ignore[reportAny]
arg_ntasks: int = parsed.ntasks_per_node # pyright: ignore[reportAny]
arg_chdir: str | None = parsed.chdir # pyright: ignore[reportAny]
arg_hosts: str | None = parsed.hosts # pyright: ignore[reportAny]
# Determine if input is a script or direct executable
executable: str | None = None
directives: dict[str, str] = {}
if os.path.isfile(script_path):
# Check if it's a binary file (executable) or text script
is_binary = False
try:
with open(script_path, "rb") as f:
chunk = f.read(512)
# Binary files typically contain null bytes
is_binary = b"\x00" in chunk
except OSError:
pass
if is_binary:
# It's a binary executable
executable = script_path
else:
# Try to read as text
try:
with open(script_path, "r") as f:
first_line = f.readline()
f.seek(0)
content = f.read(1024)
if first_line.startswith("#!") or "#SBATCH" in content:
# It's a job script - parse it
directives, executable = parse_job_script(script_path)
else:
# It's an executable (text but no shebang/directives)
executable = script_path
except UnicodeDecodeError:
# Can't read as text - treat as binary executable
executable = script_path
else:
# Not a file - treat as executable path
executable = script_path
if executable is None:
print("Error: No executable found in job script", file=sys.stderr)
return 1
# Build job parameters - CLI args override script directives
job_name = arg_job_name or directives.get("job-name") or directives.get("J")
if not job_name:
# Generate name from executable
job_name = os.path.basename(executable).replace(".", "_")
nodes = arg_nodes
if "nodes" in directives:
nodes = int(directives["nodes"])
if "N" in directives:
nodes = int(directives["N"])
if arg_nodes != 1: # CLI override
nodes = arg_nodes
ntasks = arg_ntasks
if "ntasks-per-node" in directives:
ntasks = int(directives["ntasks-per-node"])
if arg_ntasks != 1: # CLI override
ntasks = arg_ntasks
workdir = arg_chdir or directives.get("chdir") or directives.get("D")
if not workdir:
workdir = os.getcwd()
hosts = arg_hosts or directives.get("hosts") or ""
# Resolve executable to absolute path
if not os.path.isabs(executable):
executable = os.path.abspath(os.path.join(workdir, executable))
# Submit job via API using query parameters
from urllib.parse import urlencode
params = {
"simulation_name": job_name,
"flash_executable_path": executable,
"parameter_file_path": "", # FLASH par file - use default
"working_directory": workdir,
"ranks_per_node": str(ntasks),
"min_nodes": str(nodes),
"hosts": hosts,
}
query_string = urlencode(params)
result = api_request("POST", f"/flash/launch?{query_string}")
# Print job submission confirmation
if isinstance(result, dict):
instance_id_val = result.get("instance_id")
if instance_id_val is not None:
job_id = truncate_id(str(instance_id_val)) # pyright: ignore[reportAny]
print(f"Submitted batch job {job_id}")
else:
# Instance created asynchronously - user should check squeue
print("Job submitted successfully")
print("Use 'exo squeue' to view job ID")
else:
print("Job submitted successfully")
print("Use 'exo squeue' to view job ID")
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

src/exo/cli/scancel.py (new file, +95 lines)

@@ -0,0 +1,95 @@
"""scancel - Cancel jobs in the Exo queue.
Usage:
exo scancel <jobid> [<jobid>...]
Arguments:
jobid Job ID (or prefix) to cancel. Can specify multiple.
Examples:
exo scancel abc123 # Cancel job starting with abc123
exo scancel abc123 def456 # Cancel multiple jobs
"""
import argparse
import sys
from typing import Any, cast
from exo.cli.common import api_request, truncate_id
def main(args: list[str]) -> int:
"""Main entry point for scancel command."""
parser = argparse.ArgumentParser(
prog="exo scancel",
description="Cancel jobs in the Exo queue",
)
parser.add_argument(
"jobids",
nargs="+",
help="Job ID(s) to cancel",
)
parsed = parser.parse_args(args)
jobids: list[str] = parsed.jobids # pyright: ignore[reportAny]
# Fetch current jobs to resolve partial IDs
result = api_request("GET", "/flash/instances")
if isinstance(result, list):
instances = cast(list[dict[str, Any]], result)
else:
instances = cast(list[dict[str, Any]], result.get("instances", []))
# Build lookup of full IDs
id_map: dict[str, str] = {}
for inst in instances:
iid = inst.get("instance_id", "") # pyright: ignore[reportAny]
full_id = str(iid) if iid else "" # pyright: ignore[reportAny]
if full_id:
# Map both full ID and truncated versions
normalized = full_id.replace("-", "").lower()
id_map[normalized] = full_id
# Also map prefixes
for length in range(4, len(normalized) + 1):
prefix = normalized[:length]
if prefix not in id_map:
id_map[prefix] = full_id
cancelled = 0
errors = 0
for jobid in jobids:
search = jobid.lower().replace("-", "")
# Find matching full ID
full_id = id_map.get(search)
if not full_id:
# Try prefix match
matches = [fid for key, fid in id_map.items() if key.startswith(search)]
if len(matches) == 1:
full_id = matches[0]
elif len(matches) > 1:
print(f"Ambiguous job ID: {jobid} matches multiple jobs")
errors += 1
continue
else:
print(f"Job not found: {jobid}")
errors += 1
continue
# Cancel the job
try:
api_request("DELETE", f"/flash/{full_id}")
print(f"Job {truncate_id(full_id)} cancelled")
cancelled += 1
except SystemExit:
print(f"Failed to cancel job {truncate_id(full_id)}")
errors += 1
if errors > 0 and cancelled == 0:
return 1
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

src/exo/cli/squeue.py (new file, +165 lines)

@@ -0,0 +1,165 @@
"""squeue - View the Exo job queue.
Usage:
exo squeue [options]
Options:
-l, --long Show detailed output
-j, --job ID Show only this job
Output columns:
JOBID - Job identifier (truncated UUID)
NAME - Job name
NODES - Number of nodes
STATE - Job state (PENDING, RUNNING, FAILED, etc.)
"""
import argparse
import sys
from typing import Any, cast
from exo.cli.common import api_request, format_table, truncate_id
# Map Exo runner statuses to SLURM-like states
STATUS_MAP: dict[str, str] = {
"RunnerIdle": "PENDING",
"RunnerConnecting": "CONFIGURING",
"RunnerConnected": "CONFIGURING",
"RunnerLoading": "CONFIGURING",
"RunnerLoaded": "CONFIGURING",
"RunnerWarmingUp": "CONFIGURING",
"RunnerReady": "COMPLETING",
"RunnerRunning": "RUNNING",
"RunnerShuttingDown": "COMPLETING",
"RunnerShutdown": "COMPLETED",
"RunnerFailed": "FAILED",
}
def get_job_state(runner_statuses: dict[str, Any]) -> str:
"""Determine overall job state from runner statuses."""
if not runner_statuses:
return "PENDING"
states: set[str] = set()
for status_val in runner_statuses.values(): # pyright: ignore[reportAny]
if isinstance(status_val, dict):
# Extract status type from discriminated union
type_val = status_val.get("type", "RunnerIdle") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
status_type = str(type_val) if type_val else "RunnerIdle" # pyright: ignore[reportUnknownArgumentType]
elif isinstance(status_val, str):
status_type = status_val
else:
status_type = "RunnerIdle"
# Strip parentheses from status strings like "RunnerRunning()"
if status_type.endswith("()"):
status_type = status_type[:-2]
states.add(STATUS_MAP.get(status_type, "UNKNOWN"))
# Priority order for overall state
if "FAILED" in states:
return "FAILED"
if "RUNNING" in states:
return "RUNNING"
if "CONFIGURING" in states:
return "CONFIGURING"
if "COMPLETING" in states:
return "COMPLETING"
if "COMPLETED" in states:
return "COMPLETED"
if "PENDING" in states:
return "PENDING"
return "UNKNOWN"
def main(args: list[str]) -> int:
"""Main entry point for squeue command."""
parser = argparse.ArgumentParser(
prog="exo squeue",
description="View the Exo job queue",
)
parser.add_argument(
"-l",
"--long",
action="store_true",
help="Show detailed output",
)
parser.add_argument(
"-j",
"--job",
help="Show only this job ID",
)
parsed = parser.parse_args(args)
# Extract typed values
long_format: bool = parsed.long # pyright: ignore[reportAny]
job_filter: str | None = parsed.job # pyright: ignore[reportAny]
# Fetch jobs from API - returns list directly
result = api_request("GET", "/flash/instances")
# API returns list directly, not {"instances": [...]}
if isinstance(result, list):
instances = cast(list[dict[str, Any]], result)
else:
instances = cast(list[dict[str, Any]], result.get("instances", []))
if not instances:
# No jobs - just print header
if long_format:
print("JOBID NAME NODES RANKS STATE WORKDIR")
else:
print("JOBID NAME NODES STATE")
return 0
# Filter by job ID if specified
if job_filter:
search = job_filter.lower()
filtered: list[dict[str, Any]] = []
for i in instances:
iid = i.get("instance_id", "") # pyright: ignore[reportAny]
if search in str(iid).lower().replace("-", ""): # pyright: ignore[reportAny]
filtered.append(i)
instances = filtered
# Build table
rows: list[list[str]] = []
if long_format:
headers = ["JOBID", "NAME", "NODES", "RANKS", "STATE", "WORKDIR"]
for inst in instances:
iid_val = inst.get("instance_id", "") # pyright: ignore[reportAny]
instance_id = str(iid_val) if iid_val else "" # pyright: ignore[reportAny]
job_id = truncate_id(instance_id, 12)
name_val = inst.get("simulation_name", "") # pyright: ignore[reportAny]
name = (str(name_val) if name_val else "")[:15] # pyright: ignore[reportAny]
runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
nodes = str(len(runner_statuses))
ranks_val = inst.get("total_ranks", 0) # pyright: ignore[reportAny]
ranks = str(ranks_val) if ranks_val else "0" # pyright: ignore[reportAny]
state = get_job_state(runner_statuses)
workdir_val = inst.get("working_directory", "") # pyright: ignore[reportAny]
workdir = str(workdir_val) if workdir_val else "" # pyright: ignore[reportAny]
# Truncate workdir for display
if len(workdir) > 30:
workdir = "..." + workdir[-27:]
rows.append([job_id, name, nodes, ranks, state, workdir])
else:
headers = ["JOBID", "NAME", "NODES", "STATE"]
for inst in instances:
iid_val = inst.get("instance_id", "") # pyright: ignore[reportAny]
instance_id = str(iid_val) if iid_val else "" # pyright: ignore[reportAny]
job_id = truncate_id(instance_id, 8)
name_val = inst.get("simulation_name", "") # pyright: ignore[reportAny]
name = (str(name_val) if name_val else "")[:15] # pyright: ignore[reportAny]
runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
nodes = str(len(runner_statuses))
state = get_job_state(runner_statuses)
rows.append([job_id, name, nodes, state])
print(format_table(headers, rows))
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))


@@ -195,6 +195,14 @@ class Node:
def main():
# Check for SLURM-compatible subcommands first
import sys
if len(sys.argv) > 1 and sys.argv[1] in ("sbatch", "squeue", "scancel", "salloc"):
from exo.cli import run_subcommand
sys.exit(run_subcommand(sys.argv[1], sys.argv[2:]))
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
@@ -205,6 +213,11 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Discover and register plugins
from exo.plugins.registry import discover_plugins
discover_plugins()
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"


@@ -1,7 +1,9 @@
import asyncio
import os
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
from typing import Any, Optional, cast
import anyio
from anyio import BrokenResourceError, create_task_group
@@ -14,6 +16,7 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from pydantic import BaseModel
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -62,7 +65,11 @@ from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -70,6 +77,22 @@ from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
class ExecuteRequest(BaseModel):
"""Request to execute a command."""
command: list[str]
cwd: Optional[str] = None
env: Optional[dict[str, str]] = None
class ExecuteResponse(BaseModel):
"""Response from command execution."""
exit_code: int
stdout: str
stderr: str
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
@@ -126,6 +149,7 @@ class API:
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
self._register_plugin_routes()
self.app.mount(
"/",
@@ -194,6 +218,58 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
# Remote execution endpoint (used by exo-rsh for MPI)
self.app.post("/execute")(self.execute)
def _register_plugin_routes(self) -> None:
"""Register API routes from all loaded plugins."""
import functools
import inspect
from exo.plugins.context import PluginContext
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
for method, path, handler, plugin in registry.get_all_api_routes():
# Create a wrapper that injects the plugin context while preserving
# the original function signature for FastAPI parameter extraction
@functools.wraps(handler)
async def wrapped_handler( # pyright: ignore[reportAny]
*args: Any, # pyright: ignore[reportAny]
_handler: Any = handler, # pyright: ignore[reportAny]
**kwargs: Any, # pyright: ignore[reportAny]
) -> Any:
context = PluginContext(
state=self.state,
send_command=self._send,
node_id=self.node_id,
)
# Pass context as first argument, then forward all other args
return await _handler(context, *args, **kwargs) # pyright: ignore[reportAny]
# Modify the wrapper signature to match the original handler
# but without the 'ctx' parameter (we inject it)
orig_sig = inspect.signature(handler)
params = list(orig_sig.parameters.values())
# Remove the first parameter (ctx: PluginContext)
if params and params[0].name in ("ctx", "context"):
params = params[1:]
wrapped_handler.__signature__ = orig_sig.replace(parameters=params) # pyright: ignore[reportAttributeAccessIssue]
# Register the route based on HTTP method
if method == "get":
self.app.get(path)(wrapped_handler)
elif method == "post":
self.app.post(path)(wrapped_handler)
elif method == "delete":
self.app.delete(path)(wrapped_handler)
elif method == "put":
self.app.put(path)(wrapped_handler)
logger.debug(
f"Registered plugin route: {method.upper()} {path} ({plugin.name})"
)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -252,7 +328,8 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -308,7 +385,8 @@ class API:
instance_meta=instance_meta,
min_nodes=min_nodes,
),
node_profiles=self.state.node_profiles,
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
current_instances=self.state.instances,
)
@@ -602,8 +680,8 @@ class API:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
for profile in self.state.node_profiles.values():
total_available += profile.memory.ram_available
for memory in self.state.node_memory.values():
total_available += memory.ram_available
return total_available
@@ -624,6 +702,65 @@ class API:
]
)
async def execute(self, request: ExecuteRequest) -> ExecuteResponse:
"""Execute a command locally. Used by exo-rsh for MPI remote execution."""
cmd_str = " ".join(request.command)
logger.info(f"Executing: {cmd_str}")
try:
# Build environment
env = os.environ.copy()
if request.env:
env.update(request.env)
# Check if command contains shell metacharacters
# If so, run through shell. mpirun sends complex commands like:
# "VAR=value;export VAR;/path/to/prted --args"
needs_shell = any(c in cmd_str for c in ";|&$`")
if needs_shell:
process = await asyncio.create_subprocess_shell(
cmd_str,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
else:
process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
stdout, stderr = await process.communicate()
exit_code = process.returncode or 0
logger.info(f"Command completed with exit code {exit_code}")
return ExecuteResponse(
exit_code=exit_code,
stdout=stdout.decode("utf-8", errors="replace"),
stderr=stderr.decode("utf-8", errors="replace"),
)
except FileNotFoundError:
logger.error(f"Command not found: {request.command[0]}")
return ExecuteResponse(
exit_code=127,
stdout="",
stderr=f"Command not found: {request.command[0]}",
)
except Exception as e:
logger.error(f"Execution error: {e}")
return ExecuteResponse(
exit_code=1,
stdout="",
stderr=str(e),
)
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"


@@ -96,103 +96,128 @@ class Master:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
with self.command_receiver as commands:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
# Check if a plugin handles this command
plugin = registry.get_plugin_for_command(command)
if plugin is not None:
events = plugin.process_command(
command,
self.state.topology,
self.state.instances,
)
generated_events.extend(events)
else:
# Core command handling
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
task=ChatCompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(
command,
self.state.topology,
self.state.instances,
self.state.node_profiles,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
placement = add_instance_to_placements(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(
command, self.state.instances
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(
command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
placement = add_instance_to_placements(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.finished_command_id
]
)
)
if (
command.finished_command_id
in self.command_task_mapping
):
del self.command_task_mapping[
command.finished_command_id
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
case _:
# Plugin-managed commands are handled above
pass
for event in generated_events:
await self.event_sender.send(event)
except ValueError as e:


@@ -24,7 +24,7 @@ from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
@@ -54,12 +54,13 @@ def place_instance(
command: PlaceInstance,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[InstanceId, Instance]:
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_profiles, command.model_meta.storage_size
candidate_cycles, node_memory, command.model_meta.storage_size
)
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
@@ -104,13 +105,13 @@ def place_instance(
selected_cycle = max(
cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
key=lambda cycle: sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle),
(node_memory[node_id].ram_available for node_id in cycle),
start=Memory(),
),
)
shard_assignments = get_shard_assignments(
command.model_meta, selected_cycle, command.sharding, node_profiles
command.model_meta, selected_cycle, command.sharding, node_memory
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
@@ -136,7 +137,7 @@ def place_instance(
coordinator=selected_cycle.node_ids[0],
coordinator_port=random_ephemeral_port(),
cycle_digraph=cycle_digraph,
node_profiles=node_profiles,
node_network=node_network,
)
target_instances[instance_id] = MlxJacclInstance(
instance_id=instance_id,
@@ -150,7 +151,7 @@ def place_instance(
selected_cycle=selected_cycle,
cycle_digraph=cycle_digraph,
ephemeral_port=ephemeral_port,
node_profiles=node_profiles,
node_network=node_network,
)
target_instances[instance_id] = MlxRingInstance(
instance_id=instance_id,
@@ -158,6 +159,11 @@ def place_instance(
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
case _:
# Plugin-managed instance types have their own placement functions
raise ValueError(
f"Instance type {command.instance_meta} must use plugin placement"
)
return target_instances


@@ -6,7 +6,7 @@ from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
@@ -19,16 +19,16 @@ from exo.shared.types.worker.shards import (
def filter_cycles_by_memory(
cycles: list[Cycle],
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node_memory: Mapping[NodeId, MemoryUsage],
required_memory: Memory,
) -> list[Cycle]:
filtered_cycles: list[Cycle] = []
for cycle in cycles:
if not all(node in node_profiles for node in cycle):
if not all(node in node_memory for node in cycle):
continue
total_mem = sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if total_mem >= required_memory:
@@ -77,13 +77,13 @@ def allocate_layers_proportionally(
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
cycle: Cycle,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node_memory: Mapping[NodeId, MemoryUsage],
):
if not cycle.node_ids:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
@@ -98,7 +98,7 @@ def get_shard_assignments_for_pipeline_parallel(
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node_profiles[node_id].memory.ram_available.in_bytes / cycle_memory.in_bytes
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
@@ -109,7 +109,7 @@ def get_shard_assignments_for_pipeline_parallel(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_profiles[node_id].memory.ram_available.in_bytes
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node_id}) has insufficient memory: "
@@ -182,14 +182,14 @@ def get_shard_assignments(
model_meta: ModelMetadata,
cycle: Cycle,
sharding: Sharding,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_meta=model_meta,
cycle=cycle,
node_profiles=node_profiles,
node_memory=node_memory,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
@@ -288,10 +288,10 @@ def _find_connection_ip(
def _find_interface_name_for_ip(
ip_address: str, node_profile: NodePerformanceProfile
ip_address: str, node_network: NodeNetworkInfo
) -> str | None:
"""Find the interface name for an IP address on a node (any interface)."""
for interface in node_profile.network_interfaces:
for interface in node_network.interfaces:
if interface.ip_address == ip_address:
return interface.name
@@ -302,7 +302,7 @@ def _find_ip_prioritised(
node_id: NodeId,
other_node_id: NodeId,
cycle_digraph: Topology,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> str | None:
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
"""Find an IP address between nodes with prioritization.
@@ -316,7 +316,9 @@ def _find_ip_prioritised(
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
# We expect a unique iface -> ip mapping
iface_map = {
_find_interface_name_for_ip(ip, node_profiles[other_node_id]): ip
_find_interface_name_for_ip(
ip, node_network.get(other_node_id, NodeNetworkInfo())
): ip
for ip, _ in ips
}
@@ -345,7 +347,7 @@ def get_mlx_ring_hosts_by_node(
selected_cycle: Cycle,
cycle_digraph: Topology,
ephemeral_port: int,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[NodeId, list[Host]]:
"""Generate per-node host lists for MLX ring backend.
@@ -377,7 +379,7 @@ def get_mlx_ring_hosts_by_node(
continue
connection_ip = _find_ip_prioritised(
node_id, other_node_id, cycle_digraph, node_profiles
node_id, other_node_id, cycle_digraph, node_network
)
if connection_ip is None:
logger.warning(
@@ -398,7 +400,7 @@ def get_mlx_jaccl_coordinators(
coordinator: NodeId,
coordinator_port: int,
cycle_digraph: Topology,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[NodeId, str]:
"""Get the coordinator addresses for MLX JACCL (rank 0 device).
@@ -411,7 +413,7 @@ def get_mlx_jaccl_coordinators(
if n == coordinator:
return "0.0.0.0"
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_profiles)
ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_network)
if ip is not None:
return ip


@@ -2,28 +2,26 @@ from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
NodePerformanceProfile,
SystemPerformanceProfile,
NodeNetworkInfo,
)
from exo.shared.types.topology import RDMAConnection, SocketConnection
def create_node_profile(memory: int) -> NodePerformanceProfile:
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
),
network_interfaces=[
def create_node_memory(memory: int) -> MemoryUsage:
return MemoryUsage.from_bytes(
ram_total=1000,
ram_available=memory,
swap_total=1000,
swap_available=1000,
)
def create_node_network() -> NodeNetworkInfo:
return NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address=f"169.254.0.{i}")
for i in range(10)
],
system=SystemPerformanceProfile(),
]
)


@@ -73,8 +73,8 @@ async def test_master():
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodePerformanceProfile event
logger.info("inject a NodePerformanceProfile event")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
@@ -99,7 +99,7 @@ async def test_master():
logger.info("wait for initial topology event")
while len(list(master.state.topology.list_nodes())) == 0:
await anyio.sleep(0.001)
while len(master.state.node_profiles) == 0:
while len(master.state.node_memory) == 0:
await anyio.sleep(0.001)
logger.info("inject a CreateInstance Command")


@@ -5,7 +5,8 @@ from exo.master.placement import (
place_instance,
)
from exo.master.tests.conftest import (
create_node_profile,
create_node_memory,
create_node_network,
create_rdma_connection,
create_socket_connection,
)
@@ -16,7 +17,7 @@ from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
Instance,
@@ -109,10 +110,15 @@ def test_get_instance_placements_create_instance(
source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)
)
profiles = {
node_id_a: create_node_profile(available_memory[0]),
node_id_b: create_node_profile(available_memory[1]),
node_id_c: create_node_profile(available_memory[2]),
node_memory = {
node_id_a: create_node_memory(available_memory[0]),
node_id_b: create_node_memory(available_memory[1]),
node_id_c: create_node_memory(available_memory[2]),
}
node_network = {
node_id_a: create_node_network(),
node_id_b: create_node_network(),
node_id_c: create_node_network(),
}
topology.add_node(node_id_a)
topology.add_node(node_id_b)
@@ -125,7 +131,7 @@ def test_get_instance_placements_create_instance(
topology.add_connection(conn_b_a)
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {}, node_memory, node_network)
# assert
assert len(placements) == 1
@@ -155,7 +161,8 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -166,7 +173,7 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {}, node_memory, node_network)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -181,7 +188,8 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1001 * 1024)}
node_memory = {node_id: create_node_memory(1001 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
model_id=ModelId("test-model"),
@@ -192,7 +200,7 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
supports_tensor=True,
),
)
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {}, node_memory, node_network)
assert len(placements) == 1
instance_id = list(placements.keys())[0]
@@ -207,7 +215,8 @@ def test_get_instance_placements_one_node_not_fit() -> None:
topology = Topology()
node_id = NodeId()
topology.add_node(node_id)
profiles = {node_id: create_node_profile(1000 * 1024)}
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
@@ -220,7 +229,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
)
with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
place_instance(cic, topology, {}, profiles)
place_instance(cic, topology, {}, node_memory, node_network)
def test_get_transition_events_no_change(instance: Instance):
@@ -278,11 +287,17 @@ def test_placement_selects_leaf_nodes(
node_id_c = NodeId()
node_id_d = NodeId()
profiles = {
node_id_a: create_node_profile(500),
node_id_b: create_node_profile(600),
node_id_c: create_node_profile(600),
node_id_d: create_node_profile(500),
node_memory = {
node_id_a: create_node_memory(500),
node_id_b: create_node_memory(600),
node_id_c: create_node_memory(600),
node_id_d: create_node_memory(500),
}
node_network = {
node_id_a: create_node_network(),
node_id_b: create_node_network(),
node_id_c: create_node_network(),
node_id_d: create_node_network(),
}
topology.add_node(node_id_a)
@@ -313,7 +328,7 @@ def test_placement_selects_leaf_nodes(
cic = place_instance_command(model_meta=model_meta)
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {}, node_memory, node_network)
# assert
assert len(placements) == 1
@@ -340,10 +355,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
node_b = NodeId()
node_c = NodeId()
profiles = {
node_a: create_node_profile(500),
node_b: create_node_profile(500),
node_c: create_node_profile(500),
node_memory = {
node_a: create_node_memory(500),
node_b: create_node_memory(500),
node_c: create_node_memory(500),
}
ethernet_interface = NetworkInterfaceInfo(
@@ -354,9 +369,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
)
profiles[node_a].network_interfaces = [ethernet_interface]
profiles[node_b].network_interfaces = [ethernet_interface]
profiles[node_c].network_interfaces = [ethernet_interface]
node_network = {
node_a: NodeNetworkInfo(interfaces=[ethernet_interface]),
node_b: NodeNetworkInfo(interfaces=[ethernet_interface]),
node_c: NodeNetworkInfo(interfaces=[ethernet_interface]),
}
topology.add_node(node_a)
topology.add_node(node_b)
@@ -399,7 +416,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
)
# act
placements = place_instance(cic, topology, {}, profiles)
placements = place_instance(cic, topology, {}, node_memory, node_network)
# assert
assert len(placements) == 1

View File

@@ -1,5 +1,3 @@
from copy import copy
import pytest
from exo.master.placement_utils import (
@@ -10,16 +8,17 @@ from exo.master.placement_utils import (
get_shard_assignments,
get_smallest_cycles,
)
from exo.master.tests.conftest import create_node_profile, create_socket_connection
from exo.master.tests.conftest import (
create_node_memory,
create_socket_connection,
)
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
NetworkInterfaceInfo,
NodePerformanceProfile,
SystemPerformanceProfile,
NodeNetworkInfo,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.shards import Sharding
@@ -36,9 +35,9 @@ def test_filter_cycles_by_memory():
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1_mem = create_node_memory(1000 * 1024)
node2_mem = create_node_memory(1000 * 1024)
node_memory = {node1_id: node1_mem, node2_id: node2_mem}
topology = Topology()
topology.add_node(node1_id)
@@ -51,9 +50,7 @@ def test_filter_cycles_by_memory():
assert len(cycles[0]) == 2
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_bytes(1)
)
filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_bytes(1))
# assert
assert len(filtered_cycles) == 1
@@ -72,9 +69,9 @@ def test_filter_cycles_by_insufficient_memory():
source=node2_id, sink=node1_id, edge=create_socket_connection(2)
)
node1 = create_node_profile(1000 * 1024)
node2 = create_node_profile(1000 * 1024)
node_profiles = {node1_id: node1, node2_id: node2}
node1_mem = create_node_memory(1000 * 1024)
node2_mem = create_node_memory(1000 * 1024)
node_memory = {node1_id: node1_mem, node2_id: node2_mem}
topology = Topology()
topology.add_node(node1_id)
@@ -84,7 +81,7 @@ def test_filter_cycles_by_insufficient_memory():
# act
filtered_cycles = filter_cycles_by_memory(
topology.get_cycles(), node_profiles, Memory.from_kb(2001)
topology.get_cycles(), node_memory, Memory.from_kb(2001)
)
# assert
@@ -109,13 +106,13 @@ def test_filter_multiple_cycles_by_memory():
source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
)
node_a = create_node_profile(500 * 1024)
node_b = create_node_profile(500 * 1024)
node_c = create_node_profile(1000 * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
node_a_mem = create_node_memory(500 * 1024)
node_b_mem = create_node_memory(500 * 1024)
node_c_mem = create_node_memory(1000 * 1024)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
}
topology = Topology()
@@ -130,9 +127,7 @@ def test_filter_multiple_cycles_by_memory():
cycles = topology.get_cycles()
# act
filtered_cycles = filter_cycles_by_memory(
cycles, node_profiles, Memory.from_kb(1500)
)
filtered_cycles = filter_cycles_by_memory(cycles, node_memory, Memory.from_kb(1500))
# assert
assert len(filtered_cycles) == 1
@@ -228,13 +223,13 @@ def test_get_shard_assignments(
topology.add_connection(connection3)
topology.add_connection(connection4)
node_a = create_node_profile(available_memory[0] * 1024)
node_b = create_node_profile(available_memory[1] * 1024)
node_c = create_node_profile(available_memory[2] * 1024)
node_profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
node_a_mem = create_node_memory(available_memory[0] * 1024)
node_b_mem = create_node_memory(available_memory[1] * 1024)
node_c_mem = create_node_memory(available_memory[2] * 1024)
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
}
model_meta = ModelMetadata(
@@ -253,7 +248,7 @@ def test_get_shard_assignments(
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline, node_profiles=node_profiles
model_meta, selected_cycle, Sharding.Pipeline, node_memory=node_memory
)
# assert
@@ -343,38 +338,28 @@ def test_get_mlx_jaccl_coordinators():
source=node_a_id, sink=node_c_id, edge=create_socket_connection(6)
)
npp = NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=MemoryUsage.from_bytes(
ram_total=0,
ram_available=0,
swap_total=0,
swap_available=0,
),
network_interfaces=[],
system=SystemPerformanceProfile(),
network_a = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
]
)
npp_a = copy(npp)
npp_a.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
]
npp_b = copy(npp)
npp_b.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
]
npp_c = copy(npp)
npp_c.network_interfaces = [
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
]
node_profiles = {
node_a_id: npp_a,
node_b_id: npp_b,
node_c_id: npp_c,
network_b = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
]
)
network_c = NodeNetworkInfo(
interfaces=[
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
]
)
node_network = {
node_a_id: network_a,
node_b_id: network_b,
node_c_id: network_c,
}
topology = Topology()
@@ -394,7 +379,7 @@ def test_get_mlx_jaccl_coordinators():
node_a_id,
coordinator_port=5000,
cycle_digraph=topology,
node_profiles=node_profiles,
node_network=node_network,
)
# assert
@@ -496,9 +481,9 @@ def test_get_shard_assignments_insufficient_memory_raises():
topology = Topology()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a = create_node_profile(900 * 1024)
node_b = create_node_profile(50 * 1024)
node_c = create_node_profile(10 * 1024) # Insufficient memory
node_a_mem = create_node_memory(900 * 1024)
node_b_mem = create_node_memory(50 * 1024)
node_c_mem = create_node_memory(10 * 1024) # Insufficient memory
topology.add_node(node_a_id)
topology.add_node(node_b_id)
@@ -521,10 +506,10 @@ def test_get_shard_assignments_insufficient_memory_raises():
topology.add_connection(conn_c_a)
topology.add_connection(conn_b_a)
profiles = {
node_a_id: node_a,
node_b_id: node_b,
node_c_id: node_c,
node_memory = {
node_a_id: node_a_mem,
node_b_id: node_b_mem,
node_c_id: node_c_mem,
}
model_meta = ModelMetadata(
@@ -539,4 +524,6 @@ def test_get_shard_assignments_insufficient_memory_raises():
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(model_meta, selected_cycle, Sharding.Pipeline, profiles)
get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline, node_memory
)

View File

@@ -3,11 +3,6 @@ import pytest
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
MemoryUsage,
NodePerformanceProfile,
SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
@@ -23,22 +18,6 @@ def socket_connection() -> SocketConnection:
)
@pytest.fixture
def node_profile() -> NodePerformanceProfile:
memory_profile = MemoryUsage.from_bytes(
ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
)
system_profile = SystemPerformanceProfile()
return NodePerformanceProfile(
model_id="test",
chip_id="test",
friendly_name="test",
memory=memory_profile,
network_interfaces=[],
system=system_profile,
)
def test_add_node(topology: Topology):
# arrange
node_id = NodeId()

View File

@@ -0,0 +1,16 @@
"""Exo Plugin System.
This module provides the plugin architecture for extending exo with custom
workload types (simulations, ML frameworks, etc.) without modifying core code.
"""
from exo.plugins.base import ExoPlugin, PluginCommand, PluginInstance
from exo.plugins.registry import PluginRegistry, discover_plugins
__all__ = [
"ExoPlugin",
"PluginCommand",
"PluginInstance",
"PluginRegistry",
"discover_plugins",
]
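For orientation, a minimal usage sketch of the exports above (hypothetical snippet, not part of this change), assuming at least one implementation such as the FLASH plugin below ships a `register()` entry point:

```python
# Hypothetical illustration of the public plugin API exported above.
from exo.plugins import PluginRegistry, discover_plugins

# Scan exo.plugins.implementations and register every module that
# exposes a register() -> ExoPlugin entry point.
discover_plugins()

registry = PluginRegistry.get()
for plugin in registry.all_plugins():
    print(plugin.name, plugin.version)  # e.g. "flash 1.0.0"
```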

171
src/exo/plugins/base.py Normal file
View File

@@ -0,0 +1,171 @@
"""Base classes and protocols for Exo plugins."""
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from pydantic import Field
from exo.shared.types.common import CommandId
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.pydantic_ext import TaggedModel
if TYPE_CHECKING:
from exo.shared.topology import Topology
from exo.shared.types.worker.instances import BoundInstance, Instance
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor
class PluginCommand(TaggedModel):
"""Base class for plugin-defined commands.
All plugin commands must inherit from this class. Commands are serialized
with their class name as a tag for routing.
"""
command_id: CommandId = Field(default_factory=CommandId)
class PluginInstance(TaggedModel):
"""Base class for plugin-defined instances.
All plugin instances must inherit from this class. Plugins are expected
to define their own instance type with workload-specific fields.
"""
instance_id: InstanceId
class ExoPlugin(ABC):
"""Protocol that all exo plugins must implement.
A plugin provides:
- Custom command types for API -> Master communication
- Custom instance types representing running workloads
- Placement logic for distributing work across nodes
- Planning logic for local task scheduling
- Runner implementation for executing work
"""
@property
@abstractmethod
def name(self) -> str:
"""Unique identifier for this plugin (e.g., 'flash', 'pytorch', 'mpi')."""
...
@property
@abstractmethod
def version(self) -> str:
"""Semantic version string (e.g., '1.0.0')."""
...
# ========== Type Registration ==========
@abstractmethod
def get_command_types(self) -> Sequence[type]:
"""Return command types this plugin handles.
These commands are routed to this plugin's process_command method.
Can return core BaseCommand types or PluginCommand types.
"""
...
@abstractmethod
def get_instance_type(self) -> type:
"""Return the instance type this plugin creates.
This instance type is used for routing in planning and runner bootstrap.
Can return core Instance types or PluginInstance types.
"""
...
# ========== API Routes ==========
@abstractmethod
def get_api_routes(
self,
) -> Sequence[tuple[str, str, Callable[..., Any]]]:
"""Return FastAPI routes to register.
Each tuple: (method, path, handler)
Example: [('post', '/flash/launch', self.launch_handler)]
Handlers receive a PluginContext with access to:
- state: Current State object
- send_command: Async function to send commands
- node_id: Current node's ID
"""
...
# ========== Master Command Handling ==========
@abstractmethod
def handles_command(self, command: Any) -> bool: # pyright: ignore[reportAny]
"""Return True if this plugin handles the given command type."""
...
@abstractmethod
def process_command(
self,
command: Any, # pyright: ignore[reportAny]
topology: "Topology",
current_instances: Mapping[InstanceId, "Instance"],
) -> Sequence[Event]:
"""Process a command and return events to emit.
Typically creates placement and returns InstanceCreated/InstanceDeleted events.
Args:
command: The command to process
topology: Current cluster topology
current_instances: Currently running instances
Returns:
Sequence of events to emit (e.g., InstanceCreated, InstanceDeleted)
"""
...
# ========== Worker Planning ==========
@abstractmethod
def handles_instance(self, instance: object) -> bool:
"""Return True if this plugin manages the given instance type."""
...
@abstractmethod
def plan_task(
self,
runners: Mapping[RunnerId, "RunnerSupervisor"],
instances: Mapping[InstanceId, "Instance"],
) -> Task | None:
"""Plan the next task for plugin instances.
Called during each planning cycle.
Return None if no task is needed.
"""
...
@abstractmethod
def should_skip_download(self, instance: object) -> bool:
"""Return True if this instance type doesn't need model downloads."""
...
# ========== Runner Bootstrap ==========
@abstractmethod
def create_runner(
self,
bound_instance: "BoundInstance",
event_sender: "MpSender[Event]",
task_receiver: "MpReceiver[Task]",
) -> None:
"""Entry point for the runner process.
Called in a subprocess to execute the actual workload.
This function should block until the workload completes.
"""
...
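To make the required surface concrete, here is a sketch of the smallest conforming subclass (the `Noop*` names are hypothetical and not part of this diff):

```python
# Hypothetical do-nothing plugin covering every abstract member of ExoPlugin.
from collections.abc import Callable, Mapping, Sequence
from typing import Any

from exo.plugins.base import ExoPlugin, PluginCommand, PluginInstance


class NoopCommand(PluginCommand):
    """Marker command owned by this plugin."""


class NoopInstance(PluginInstance):
    """Marker instance type owned by this plugin."""


class NoopPlugin(ExoPlugin):
    @property
    def name(self) -> str:
        return "noop"

    @property
    def version(self) -> str:
        return "0.0.1"

    def get_command_types(self) -> Sequence[type]:
        return [NoopCommand]

    def get_instance_type(self) -> type:
        return NoopInstance

    def get_api_routes(self) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        return []  # no HTTP surface

    def handles_command(self, command: Any) -> bool:
        return isinstance(command, NoopCommand)

    def process_command(
        self, command: Any, topology: Any, current_instances: Mapping[Any, Any]
    ) -> Sequence[Any]:
        return []  # emit no events

    def handles_instance(self, instance: object) -> bool:
        return isinstance(instance, NoopInstance)

    def plan_task(
        self, runners: Mapping[Any, Any], instances: Mapping[Any, Any]
    ) -> None:
        return None  # nothing to schedule

    def should_skip_download(self, instance: object) -> bool:
        return True  # no model weights involved

    def create_runner(
        self, bound_instance: Any, event_sender: Any, task_receiver: Any
    ) -> None:
        return  # subprocess entry point; a real plugin blocks here until done
```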

View File

@@ -0,0 +1,21 @@
"""Context objects passed to plugin handlers."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from exo.shared.types.commands import Command
from exo.shared.types.common import NodeId
from exo.shared.types.state import State
@dataclass
class PluginContext:
"""Context provided to plugin API handlers.
This gives plugins access to the current state and the ability to send
commands without direct access to internal API components.
"""
state: State
send_command: Callable[[Command], Awaitable[None]]
node_id: NodeId

View File

@@ -0,0 +1,5 @@
"""Plugin implementations directory.
Each subdirectory should contain a plugin with a register() function
that returns an ExoPlugin instance.
"""

View File

@@ -0,0 +1,8 @@
"""FLASH Plugin - MPI-based simulation support for Exo."""
from exo.plugins.implementations.flash.plugin import FLASHPlugin
def register() -> FLASHPlugin:
"""Entry point for plugin discovery."""
return FLASHPlugin()

View File

@@ -0,0 +1,108 @@
"""FLASH plugin API handlers."""
from typing import Any
from fastapi import HTTPException
from exo.plugins.context import PluginContext
# Use core types for serialization compatibility
from exo.shared.types.commands import LaunchFLASH, StopFLASH
from exo.shared.types.worker.instances import FLASHInstance
async def handle_launch_flash(
ctx: PluginContext,
simulation_name: str,
flash_executable_path: str,
working_directory: str,
parameter_file_path: str = "",
ranks_per_node: int = 1,
min_nodes: int = 1,
hosts: str = "",
) -> dict[str, str]:
"""Launch a FLASH MPI simulation across the cluster.
Args:
ctx: Plugin context with state and send_command
simulation_name: Name of the simulation
flash_executable_path: Path to the FLASH executable
working_directory: Working directory for the simulation
parameter_file_path: Path to parameter file (optional)
ranks_per_node: Number of MPI ranks per node
min_nodes: Minimum number of nodes required
hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
If not provided, IPs are discovered from topology edges.
"""
command = LaunchFLASH(
simulation_name=simulation_name,
flash_executable_path=flash_executable_path,
parameter_file_path=parameter_file_path,
working_directory=working_directory,
ranks_per_node=ranks_per_node,
min_nodes=min_nodes,
hosts=hosts,
)
await ctx.send_command(command)
return {
"message": "FLASH launch command received",
"command_id": str(command.command_id),
"simulation_name": simulation_name,
}
async def handle_stop_flash(
ctx: PluginContext,
instance_id: str,
) -> dict[str, str]:
"""Stop a running FLASH simulation."""
from exo.shared.types.worker.instances import InstanceId
inst_id = InstanceId(instance_id)
if inst_id not in ctx.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
instance = ctx.state.instances[inst_id]
if not isinstance(instance, FLASHInstance):
raise HTTPException(
status_code=400, detail="Instance is not a FLASH simulation"
)
command = StopFLASH(instance_id=inst_id)
await ctx.send_command(command)
return {
"message": "Stop command received",
"command_id": str(command.command_id),
"instance_id": str(instance_id),
}
async def handle_list_flash_instances(ctx: PluginContext) -> list[dict[str, Any]]:
"""List all FLASH simulation instances."""
flash_instances: list[dict[str, Any]] = []
for instance_id, instance in ctx.state.instances.items():
if isinstance(instance, FLASHInstance):
# Get runner statuses for this instance
runner_statuses: dict[str, str | None] = {}
for (
node_id,
runner_id,
) in instance.shard_assignments.node_to_runner.items():
runner_status = ctx.state.runners.get(runner_id)
runner_statuses[str(node_id)] = (
str(runner_status) if runner_status else None
)
flash_instances.append(
{
"instance_id": str(instance_id),
"simulation_name": instance.simulation_name,
"total_ranks": instance.total_ranks,
"working_directory": instance.working_directory,
"runner_statuses": runner_statuses,
}
)
return flash_instances
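Calling the launch route end to end could look roughly like the snippet below. This is a hedged sketch: it assumes the plugin routes are mounted on the main exo API at port 52415 and that the handler's keyword arguments bind as query parameters; the exact wiring of `get_all_api_routes()` into FastAPI is outside this diff, and the parameter values are made up.

```python
# Hypothetical launch request against the /flash/launch route above.
from urllib.parse import urlencode
from urllib.request import Request, urlopen

params = urlencode({
    "simulation_name": "sedov",                 # made-up example values
    "flash_executable_path": "/opt/flash/flash4",
    "working_directory": "/tmp/sedov",
    "ranks_per_node": 2,
    "min_nodes": 2,
})
req = Request(f"http://localhost:52415/flash/launch?{params}", method="POST")
with urlopen(req, timeout=30) as resp:
    # Expect {"message": "FLASH launch command received", "command_id": ..., ...}
    print(resp.read().decode())
```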

View File

@@ -0,0 +1,152 @@
"""FLASH plugin placement logic."""
from collections.abc import Mapping
from copy import deepcopy
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.commands import LaunchFLASH
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerId,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
def place_flash_instance(
command: LaunchFLASH,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
"""Place a FLASH simulation instance across available nodes.
Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
FLASH instances use MPI for communication. We just need to provide the
node IPs so the runner can generate an MPI hostfile.
"""
instance_id = InstanceId()
target_instances: dict[InstanceId, Instance] = dict(deepcopy(current_instances))
all_nodes = list(topology.list_nodes())
if len(all_nodes) < command.min_nodes:
raise ValueError(
f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
)
# Select nodes (take the first min_nodes)
selected_nodes = all_nodes[: command.min_nodes]
logger.info(
f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
)
# Build shard assignments (one runner per node for FLASH)
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
# Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
flash_model_meta = ModelMetadata(
model_id=ModelId(command.simulation_name),
pretty_name=f"FLASH: {command.simulation_name}",
storage_size=Memory(in_bytes=0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
)
for i, node_id in enumerate(selected_nodes):
runner_id = RunnerId()
node_to_runner[node_id] = runner_id
runner_to_shard[runner_id] = PipelineShardMetadata(
device_rank=i,
world_size=len(selected_nodes),
model_meta=flash_model_meta,
start_layer=0,
end_layer=1,
n_layers=1,
)
shard_assignments = ShardAssignments(
model_id=ModelId(command.simulation_name),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
# Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
hosts_by_node: dict[NodeId, list[Host]] = {}
# If explicit hosts are provided, use them directly
if command.hosts:
explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
for i, node_id in enumerate(selected_nodes):
if i < len(explicit_hosts):
hosts_by_node[node_id] = [Host(ip=explicit_hosts[i], port=0)]
logger.info(
f"FLASH placement: node {node_id} (rank {i}) -> IP {explicit_hosts[i]}"
)
else:
logger.warning(
f"Not enough hosts provided for node {i}, using localhost"
)
hosts_by_node[node_id] = [Host(ip="127.0.0.1", port=0)]
logger.info(
f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
)
else:
# Try to get IPs from topology edges
for node_id in selected_nodes:
node_hosts: list[Host] = []
# Get IP from outgoing edges (connections to other nodes via mDNS discovery)
for conn in topology.out_edges(node_id):
if isinstance(conn.edge, SocketConnection):
# Extract IP from multiaddr
ip = conn.edge.sink_multiaddr.ip_address
# Skip link-local and localhost addresses
if not ip.startswith("169.254.") and not ip.startswith("127."):
node_hosts.append(Host(ip=ip, port=0))
break
# Last resort: use localhost (will only work for single-node)
if not node_hosts:
logger.warning(
f"Could not determine IP for node {node_id}, using localhost"
)
node_hosts.append(Host(ip="127.0.0.1", port=0))
hosts_by_node[node_id] = node_hosts
total_ranks = len(selected_nodes) * command.ranks_per_node
# Determine coordinator IP - first node's first host IP
first_node_id: NodeId = next(iter(hosts_by_node.keys()))
coordinator_ip: str = (
hosts_by_node[first_node_id][0].ip
if hosts_by_node[first_node_id]
else "127.0.0.1"
)
target_instances[instance_id] = FLASHInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
flash_executable_path=command.flash_executable_path,
parameter_file_path=command.parameter_file_path,
working_directory=command.working_directory,
ranks_per_node=command.ranks_per_node,
total_ranks=total_ranks,
simulation_name=command.simulation_name,
coordinator_ip=coordinator_ip,
)
logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")
return target_instances
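To make the rank arithmetic and coordinator choice concrete, a small illustration (the node keys are abbreviated strings standing in for NodeIds; the hosts reuse the `s14,james21-1` example from the command docstring):

```python
# Hypothetical placement result for hosts="s14,james21-1",
# ranks_per_node=2, min_nodes=2.
from exo.shared.types.common import Host

hosts_by_node = {
    "node-a": [Host(ip="s14", port=0)],        # rank 0, coordinator
    "node-b": [Host(ip="james21-1", port=0)],  # rank 1
}
total_ranks = 2 * 2      # len(selected_nodes) * ranks_per_node = 4
coordinator_ip = "s14"   # first node's first host
```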

View File

@@ -0,0 +1,36 @@
"""FLASH plugin planning logic."""
from collections.abc import Mapping
from exo.shared.types.tasks import LoadModel, Task
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerIdle
from exo.worker.runner.runner_supervisor import RunnerSupervisor
def plan_flash(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
"""Plan tasks specifically for FLASH instances.
FLASH instances have a simpler lifecycle:
- CreateRunner (handled by core _create_runner)
- LoadModel (starts the simulation immediately)
- Shutdown (handled by core _kill_runner)
This function handles the LoadModel step for FLASH instances,
skipping the MLX-specific download/init/warmup steps.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
# Only handle FLASH instances
if not isinstance(instance, FLASHInstance):
continue
# If runner is idle, emit LoadModel to start the simulation
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
return None

View File

@@ -0,0 +1,98 @@
"""FLASH Plugin - Main plugin class."""
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from exo.plugins.base import ExoPlugin
from exo.plugins.implementations.flash.api_handlers import (
handle_launch_flash,
handle_list_flash_instances,
handle_stop_flash,
)
from exo.plugins.implementations.flash.placement import place_flash_instance
from exo.plugins.implementations.flash.planning import plan_flash
from exo.plugins.implementations.flash.runner import main as flash_runner_main
from exo.shared.topology import Topology
from exo.shared.types.commands import DeleteInstance, LaunchFLASH, StopFLASH
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import (
BoundInstance,
FLASHInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor
class FLASHPlugin(ExoPlugin):
"""Plugin for FLASH MPI simulations."""
@property
def name(self) -> str:
return "flash"
@property
def version(self) -> str:
return "1.0.0"
def get_command_types(self) -> Sequence[type]:
return [LaunchFLASH, StopFLASH]
def get_instance_type(self) -> type:
return FLASHInstance
def get_api_routes(
self,
) -> Sequence[tuple[str, str, Callable[..., Any]]]:
return [
("post", "/flash/launch", handle_launch_flash),
("delete", "/flash/{instance_id}", handle_stop_flash),
("get", "/flash/instances", handle_list_flash_instances),
]
def handles_command(self, command: Any) -> bool: # pyright: ignore[reportAny]
return isinstance(command, (LaunchFLASH, StopFLASH))
def process_command(
self,
command: Any, # pyright: ignore[reportAny]
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> Sequence[Event]:
from exo.master.placement import delete_instance, get_transition_events
if isinstance(command, LaunchFLASH):
placement = place_flash_instance(command, topology, current_instances)
return list(get_transition_events(current_instances, placement))
elif isinstance(command, StopFLASH):
placement = delete_instance(
DeleteInstance(instance_id=command.instance_id),
current_instances,
)
return list(get_transition_events(current_instances, placement))
return []
def handles_instance(self, instance: object) -> bool:
return isinstance(instance, FLASHInstance)
def plan_task(
self,
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
return plan_flash(runners, instances)
def should_skip_download(self, instance: object) -> bool:
# FLASH instances don't need model downloads
return True
def create_runner(
self,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
) -> None:
flash_runner_main(bound_instance, event_sender, task_receiver)

View File

@@ -0,0 +1,302 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.
Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- exo-rsh connects to each node's Exo API (/execute endpoint) for remote execution
- Workers just report ready and wait
"""
import os
import shutil
import socket
import subprocess
import threading
from loguru import logger
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
LoadModel,
Shutdown,
Task,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender
# Find mpirun in PATH, falling back to the default Homebrew location
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"
# exo-rsh is installed as a console script by the exo package
_exo_rsh_path = shutil.which("exo-rsh")
if not _exo_rsh_path:
raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
EXO_RSH_PATH: str = _exo_rsh_path
def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
"""Determine this node's rank based on position in hosts_by_node."""
for i, node_id in enumerate(instance.hosts_by_node.keys()):
if str(node_id) == str(my_node_id):
return i
return -1
def get_coordinator_host(instance: FLASHInstance) -> str:
"""Get the IP of the coordinator node."""
return instance.coordinator_ip
def resolve_host(host: str) -> str:
"""Resolve host string to a usable hostname for MPI hostfile.
Accepts either an IP address or hostname. For IPs, attempts to resolve
to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
"""
# Check if input is already a hostname (not an IP)
try:
socket.inet_aton(host)
is_ip = True
except socket.error:
is_ip = False
if not is_ip:
# Already a hostname, verify it resolves and return as-is
try:
socket.gethostbyname(host)
return host
except socket.gaierror:
logger.warning(f"Hostname {host} does not resolve, using anyway")
return host
# It's an IP address, try to resolve to hostname
try:
hostname, _, _ = socket.gethostbyaddr(host)
hostname = hostname.split(".")[0]
logger.info(f"Resolved {host} to {hostname}")
return hostname
except socket.herror:
pass
# Fall back to IP
logger.warning(f"Could not resolve {host} to hostname, using IP directly")
return host
def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
"""Generate MPI hostfile from instance topology."""
hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
with open(hostfile_path, "w") as f:
for _node_id, hosts in instance.hosts_by_node.items():
if hosts:
host = resolve_host(hosts[0].ip)
f.write(f"{host} slots={instance.ranks_per_node}\n")
logger.info(f"Generated hostfile at {hostfile_path}")
with open(hostfile_path, "r") as f:
logger.info(f"Hostfile contents:\n{f.read()}")
return hostfile_path
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
) -> None:
"""Main FLASH runner loop.
Coordinator: generates hostfile and runs mpirun (uses exo-rsh instead of SSH)
Workers: just report ready and wait for mpirun to spawn processes on them
"""
assert isinstance(bound_instance.instance, FLASHInstance)
instance = bound_instance.instance
runner_id = bound_instance.bound_runner_id
my_node_id = str(bound_instance.bound_node_id)
logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")
my_rank = get_my_rank(instance, my_node_id)
world_size = len(instance.hosts_by_node)
is_coordinator = my_rank == 0
coordinator_ip = get_coordinator_host(instance)
logger.info(
f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
)
logger.info(f"FLASH coordinator IP: {coordinator_ip}")
process: subprocess.Popen[bytes] | None = None
current_status: RunnerStatus = RunnerIdle()
shutdown_requested = False
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
def monitor_output(proc: subprocess.Popen[bytes]) -> None:
"""Monitor FLASH stdout for progress updates."""
if proc.stdout is None:
return
for line in iter(proc.stdout.readline, b""):
if shutdown_requested:
break
try:
decoded: str = line.decode("utf-8", errors="replace").strip()
if decoded:
logger.info(f"[FLASH] {decoded}")
except Exception as e:
logger.warning(f"Error parsing FLASH output: {e}")
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(current_status, RunnerIdle):
current_status = RunnerLoading()
logger.info("Starting FLASH simulation")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
if is_coordinator:
# Coordinator: generate hostfile and run mpirun
hostfile = generate_hostfile(
instance, instance.working_directory
)
iface = instance.network_interface
cmd = [
MPIRUN_PATH,
"-np",
str(instance.total_ranks),
"--hostfile",
hostfile,
"--wdir",
instance.working_directory,
"--oversubscribe",
"--mca",
"btl",
"tcp,self",
"--mca",
"btl_tcp_if_include",
iface,
"--mca",
"oob_tcp_if_include",
iface,
"--mca",
"plm_rsh_no_tree_spawn",
"1",
]
# Use exo-rsh for remote execution (no SSH needed)
cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])
cmd.append(instance.flash_executable_path)
logger.info(f"FLASH distributed launch: {' '.join(cmd)}")
process = subprocess.Popen(
cmd,
cwd=instance.working_directory,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
monitor_thread = threading.Thread(
target=monitor_output, args=(process,), daemon=True
)
monitor_thread.start()
current_status = RunnerRunning()
logger.info(
f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
)
else:
# Worker: mpirun on coordinator will use exo-rsh to spawn processes here
logger.info(
f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
)
current_status = RunnerRunning()
except Exception as e:
logger.error(f"Failed to start FLASH: {e}")
import traceback
logger.error(traceback.format_exc())
current_status = RunnerFailed(error_message=str(e))
case Shutdown():
shutdown_requested = True
current_status = RunnerShuttingDown()
logger.info("FLASH runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if process and process.poll() is None:
logger.info("Terminating FLASH simulation")
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
logger.warning("FLASH didn't terminate, killing")
process.kill()
process.wait()
current_status = RunnerShutdown()
case _:
if process and process.poll() is not None:
exit_code = process.returncode
if exit_code == 0:
logger.info("FLASH simulation completed successfully")
current_status = RunnerReady()
else:
logger.error(
f"FLASH simulation failed with code {exit_code}"
)
current_status = RunnerFailed(
error_message=f"Exit code {exit_code}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
break
if process and process.poll() is None:
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
logger.info("FLASH runner exiting")

110
src/exo/plugins/registry.py Normal file
View File

@@ -0,0 +1,110 @@
"""Plugin registry for discovering and managing plugins."""
from collections.abc import Callable, Sequence
from typing import Any
from loguru import logger
from exo.plugins.base import ExoPlugin
class PluginRegistry:
"""Central registry for all plugins."""
_instance: "PluginRegistry | None" = None
def __init__(self) -> None:
self._plugins: dict[str, ExoPlugin] = {}
self._command_handlers: dict[type, ExoPlugin] = {}
self._instance_handlers: dict[type, ExoPlugin] = {}
@classmethod
def get(cls) -> "PluginRegistry":
"""Get the singleton registry instance."""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def reset(cls) -> None:
"""Reset the singleton instance (useful for testing)."""
cls._instance = None
def register(self, plugin: ExoPlugin) -> None:
"""Register a plugin and its types."""
if plugin.name in self._plugins:
raise ValueError(f"Plugin '{plugin.name}' already registered")
logger.info(f"Registering plugin: {plugin.name} v{plugin.version}")
self._plugins[plugin.name] = plugin
# Register command handlers
for cmd_type in plugin.get_command_types():
self._command_handlers[cmd_type] = plugin
logger.debug(f" Registered command: {cmd_type.__name__}")
# Register instance handler
instance_type = plugin.get_instance_type()
self._instance_handlers[instance_type] = plugin
logger.debug(f" Registered instance: {instance_type.__name__}")
def get_plugin(self, name: str) -> ExoPlugin | None:
"""Get a plugin by name."""
return self._plugins.get(name)
def get_plugin_for_command(self, command: object) -> ExoPlugin | None:
"""Get the plugin that handles a command."""
for plugin in self._plugins.values():
if plugin.handles_command(command):
return plugin
return None
def get_plugin_for_instance(self, instance: object) -> ExoPlugin | None:
"""Get the plugin that manages an instance."""
for plugin in self._plugins.values():
if plugin.handles_instance(instance):
return plugin
return None
def all_plugins(self) -> Sequence[ExoPlugin]:
"""Get all registered plugins."""
return list(self._plugins.values())
def get_all_api_routes(
self,
) -> Sequence[tuple[str, str, Callable[..., Any], ExoPlugin]]:
"""Get all API routes from all plugins."""
routes: list[tuple[str, str, Callable[..., Any], ExoPlugin]] = []
for plugin in self._plugins.values():
for method, path, handler in plugin.get_api_routes():
routes.append((method, path, handler, plugin))
return routes
def discover_plugins() -> None:
"""Auto-discover and register plugins from the implementations directory.
Plugins should have a register() function that returns an ExoPlugin instance.
"""
import importlib
import pkgutil
registry = PluginRegistry.get()
try:
import exo.plugins.implementations as impl_package
for _, module_name, _ in pkgutil.iter_modules(impl_package.__path__):
try:
module = importlib.import_module(
f"exo.plugins.implementations.{module_name}"
)
if hasattr(module, "register"):
plugin = module.register() # pyright: ignore[reportAny]
if plugin is not None:
registry.register(plugin) # pyright: ignore[reportAny]
except Exception as e:
logger.warning(f"Failed to load plugin {module_name}: {e}")
except ImportError:
logger.debug("No plugin implementations package found")
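On the master side, dispatch through the registry could look like the sketch below (hypothetical glue; the actual master integration is not part of this file):

```python
# Hypothetical dispatch: route a command to whichever plugin claims it.
from exo.plugins.registry import PluginRegistry

def dispatch(command, topology, current_instances):
    plugin = PluginRegistry.get().get_plugin_for_command(command)
    if plugin is None:
        return []  # fall through to core command handling
    return plugin.process_command(command, topology, current_instances)
```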

13
src/exo/rsh/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""Exo RSH - Remote Shell for MPI without SSH.
This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:
1. Each Exo node runs an API server on port 52415 with an /execute endpoint
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun invokes "exo-rsh hostname command", it sends an HTTP POST to the target's /execute endpoint
4. The target executes the command and returns output
Usage:
mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""

101
src/exo/rsh/client.py Normal file
View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.
This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]
It connects to the target node's Exo API (port 52415) and executes the command.
"""
import json
import socket
import sys
from typing import Any, cast
from urllib.error import URLError
from urllib.request import Request, urlopen
# Use the same port as Exo's API server
EXO_API_PORT = 52415
def resolve_hostname(hostname: str) -> str:
"""Resolve hostname to IP address."""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
# If resolution fails, try using the hostname directly
return hostname
def main():
# Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
# SSH options we might see: -x (disable X11), -o options, etc.
args = sys.argv[1:]
# Skip SSH-style options
hostname = None
command_start = 0
i = 0
while i < len(args):
arg = args[i]
if arg.startswith("-"):
# Skip option and its value if needed
if arg in ("-o", "-i", "-l", "-p", "-F"):
i += 2 # Skip option and its argument
continue
i += 1
continue
else:
# First non-option is the hostname
hostname = arg
command_start = i + 1
break
i += 1
if hostname is None or command_start >= len(args):
print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
sys.exit(1)
command = args[command_start:]
# Resolve hostname to IP
ip = resolve_hostname(hostname)
# Make request to Exo API
url = f"http://{ip}:{EXO_API_PORT}/execute"
data = json.dumps({"command": command}).encode("utf-8")
try:
req = Request(url, data=data, headers={"Content-Type": "application/json"})
with urlopen(req, timeout=300) as response: # pyright: ignore[reportAny]
response_body: bytes = cast(bytes, response.read()) # pyright: ignore[reportAny]
result: dict[str, Any] = json.loads(response_body.decode("utf-8")) # pyright: ignore[reportAny]
# Output stdout/stderr
stdout: str = cast(str, result.get("stdout", ""))
stderr: str = cast(str, result.get("stderr", ""))
exit_code: int = cast(int, result.get("exit_code", 0))
if stdout:
sys.stdout.write(stdout)
sys.stdout.flush()
if stderr:
sys.stderr.write(stderr)
sys.stderr.flush()
sys.exit(exit_code)
except URLError as e:
print(
f"exo-rsh: Failed to connect to {hostname}:{EXO_API_PORT}: {e}",
file=sys.stderr,
)
sys.exit(255)
except Exception as e:
print(f"exo-rsh: Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -25,7 +25,11 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import (
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import Connection, RDMAConnection
@@ -193,22 +197,43 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
topology = copy.deepcopy(state.topology)
state.topology.remove_node(event.node_id)
node_profiles = {
key: value for key, value in state.node_profiles.items() if key != event.node_id
}
topology.remove_node(event.node_id)
last_seen = {
key: value for key, value in state.last_seen.items() if key != event.node_id
}
downloads = {
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
node_system = {
key: value for key, value in state.node_system.items() if key != event.node_id
}
node_network = {
key: value for key, value in state.node_network.items() if key != event.node_id
}
node_thunderbolt = {
key: value
for key, value in state.node_thunderbolt.items()
if key != event.node_id
}
return state.model_copy(
update={
"downloads": downloads,
"topology": topology,
"node_profiles": node_profiles,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_system": node_system,
"node_network": node_network,
"node_thunderbolt": node_thunderbolt,
}
)
@@ -217,29 +242,60 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
topology = copy.deepcopy(state.topology)
topology.add_node(event.node_id)
info = event.info
profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
# Build update dict with only the mappings that change
update: dict[str, object] = {
"last_seen": {
**state.last_seen,
event.node_id: datetime.fromisoformat(event.when),
},
"topology": topology,
}
match info:
case MacmonMetrics():
profile.system = info.system_profile
profile.memory = info.memory
update["node_system"] = {
**state.node_system,
event.node_id: info.system_profile,
}
update["node_memory"] = {**state.node_memory, event.node_id: info.memory}
case MemoryUsage():
profile.memory = info
update["node_memory"] = {**state.node_memory, event.node_id: info}
case NodeConfig():
pass
case MiscData():
profile.friendly_name = info.friendly_name
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
new_identity = current_identity.model_copy(
update={"friendly_name": info.friendly_name}
)
update["node_identities"] = {
**state.node_identities,
event.node_id: new_identity,
}
case StaticNodeInformation():
profile.model_id = info.model
profile.chip_id = info.chip
current_identity = state.node_identities.get(event.node_id, NodeIdentity())
new_identity = current_identity.model_copy(
update={"model_id": info.model, "chip_id": info.chip}
)
update["node_identities"] = {
**state.node_identities,
event.node_id: new_identity,
}
case NodeNetworkInterfaces():
profile.network_interfaces = info.ifaces
update["node_network"] = {
**state.node_network,
event.node_id: NodeNetworkInfo(interfaces=info.ifaces),
}
case MacThunderboltIdentifiers():
profile.tb_interfaces = info.idents
update["node_thunderbolt"] = {
**state.node_thunderbolt,
event.node_id: NodeThunderboltInfo(interfaces=info.idents),
}
case MacThunderboltConnections():
conn_map = {
tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
for nid in state.node_profiles
for tb_ident in state.node_profiles[nid].tb_interfaces
for nid in state.node_thunderbolt
for tb_ident in state.node_thunderbolt[nid].interfaces
}
as_rdma_conns = [
Connection(
@@ -256,15 +312,7 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
]
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
new_profiles = {**state.node_profiles, event.node_id: profile}
return state.model_copy(
update={
"node_profiles": new_profiles,
"last_seen": last_seen,
"topology": topology,
}
)
return state.model_copy(update=update)
def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:

View File

@@ -35,6 +35,26 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class LaunchFLASH(BaseCommand):
"""Command to launch a FLASH MPI simulation."""
simulation_name: str
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
min_nodes: int = 1
# Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
# Used when topology edges don't contain IP addresses
hosts: str = ""
class StopFLASH(BaseCommand):
"""Command to stop a running FLASH simulation."""
instance_id: InstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -50,6 +70,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| LaunchFLASH
| StopFLASH
| TaskFinished
)

View File

@@ -53,13 +53,21 @@ class NetworkInterfaceInfo(CamelCaseModel):
ip_address: str
class NodePerformanceProfile(CamelCaseModel):
class NodeIdentity(CamelCaseModel):
"""Static and slow-changing node identification data."""
model_id: str = "Unknown"
chip_id: str = "Unknown"
friendly_name: str = "Unknown"
memory: MemoryUsage = MemoryUsage.from_bytes(
ram_total=0, ram_available=0, swap_total=0, swap_available=0
)
network_interfaces: Sequence[NetworkInterfaceInfo] = []
tb_interfaces: Sequence[ThunderboltIdentifier] = []
system: SystemPerformanceProfile = SystemPerformanceProfile()
class NodeNetworkInfo(CamelCaseModel):
"""Network interface information for a node."""
interfaces: Sequence[NetworkInterfaceInfo] = []
class NodeThunderboltInfo(CamelCaseModel):
"""Thunderbolt interface identifiers for a node."""
interfaces: Sequence[ThunderboltIdentifier] = []

View File

@@ -7,7 +7,13 @@ from pydantic.alias_generators import to_camel
from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import (
MemoryUsage,
NodeIdentity,
NodeNetworkInfo,
NodeThunderboltInfo,
SystemPerformanceProfile,
)
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -35,11 +41,17 @@ class State(CamelCaseModel):
runners: Mapping[RunnerId, RunnerStatus] = {}
downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
tasks: Mapping[TaskId, Task] = {}
node_profiles: Mapping[NodeId, NodePerformanceProfile] = {}
last_seen: Mapping[NodeId, datetime] = {}
topology: Topology = Field(default_factory=Topology)
last_event_applied_idx: int = Field(default=-1, ge=-1)
# Granular node state mappings (update independently at different frequencies)
node_identities: Mapping[NodeId, NodeIdentity] = {}
node_memory: Mapping[NodeId, MemoryUsage] = {}
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
@field_serializer("topology", mode="plain")
def _encode_topology(self, value: Topology) -> TopologySnapshot:
return value.to_snapshot()
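A small sketch of what the split buys: a once-per-second memory tick now replaces only `node_memory`, leaving the identity, network, and thunderbolt mappings untouched (hypothetical snippet; assumes `State()` default construction is valid here):

```python
# Hypothetical: apply a memory-only update to the granular state.
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import MemoryUsage
from exo.shared.types.state import State

state = State()
node_id = NodeId()
mem = MemoryUsage.from_bytes(
    ram_total=16_000_000_000, ram_available=8_000_000_000,
    swap_total=0, swap_available=0,
)
# Only node_memory is copied and replaced; the other mappings are reused.
state = state.model_copy(update={"node_memory": {**state.node_memory, node_id: mem}})
```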

View File

@@ -14,6 +14,7 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
FLASH = "FLASH"
class BaseInstance(TaggedModel):
@@ -34,8 +35,27 @@ class MlxJacclInstance(BaseInstance):
jaccl_coordinators: dict[NodeId, str]
class FLASHInstance(BaseInstance):
"""Instance for FLASH MPI simulation.
Unlike MLX instances which do tensor parallelism, FLASH instances
coordinate MPI processes across nodes. Each node runs one or more
MPI ranks of the FLASH simulation.
"""
hosts_by_node: dict[NodeId, list[Host]]
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
total_ranks: int
simulation_name: str
coordinator_ip: str
network_interface: str = "en0" # Network interface for MPI (e.g., en0, eth0)
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
class BoundInstance(CamelCaseModel):

View File

@@ -7,7 +7,7 @@ from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import NodeNetworkInfo
REACHABILITY_ATTEMPTS = 3
@@ -79,7 +79,7 @@ async def check_reachability(
async def check_reachable(
topology: Topology,
self_node_id: NodeId,
node_profiles: Mapping[NodeId, NodePerformanceProfile],
node_network: Mapping[NodeId, NodeNetworkInfo],
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
@@ -98,11 +98,11 @@ async def check_reachable(
create_task_group() as tg,
):
for node_id in topology.list_nodes():
if node_id not in node_profiles:
if node_id not in node_network:
continue
if node_id == self_node_id:
continue
for iface in node_profiles[node_id].network_interfaces:
for iface in node_network[node_id].interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,

View File

@@ -1,5 +1,3 @@
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
@@ -17,27 +15,3 @@ class Model(nn.Module):
cache: list[KVCache] | None,
input_embeddings: mx.array | None = None,
) -> mx.array: ...
class Detokenizer:
def reset(self) -> None: ...
def add_token(self, token: int) -> None: ...
def finalize(self) -> None: ...
@property
def last_segment(self) -> str: ...
class TokenizerWrapper:
bos_token: str | None
eos_token_ids: list[int]
detokenizer: Detokenizer
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]: ...
def apply_chat_template(
self,
messages_dicts: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = True,
) -> str: ...

View File

@@ -1,10 +1,7 @@
import os
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Protocol, cast
from typing import TYPE_CHECKING, Callable, Protocol, cast
import mlx.core as mx
import mlx.nn as nn
@@ -32,40 +29,6 @@ from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
@@ -173,30 +136,10 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
return layers
class _IdentityModule(nn.Module):
"""Identity module that returns input unchanged. Used to skip computation."""
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
return x
class _IdentityLmHead(nn.Module):
"""Identity lm_head that returns zeros. Used for non-final pipeline ranks."""
def __init__(self, vocab_size: int, dtype: mx.Dtype = mx.float16):
super().__init__()
self.vocab_size = vocab_size
self.dtype = dtype
def __call__(self, x: mx.array) -> mx.array:
# Return zeros with correct shape (batch, seq, vocab_size)
return mx.zeros((*x.shape[:-1], self.vocab_size), dtype=self.dtype)
def pipeline_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
model_shard_meta: PipelineShardMetadata
model_shard_meta: PipelineShardMetadata,
) -> nn.Module:
"""
Automatically parallelize a model across multiple devices.
@@ -214,7 +157,6 @@ def pipeline_auto_parallel(
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[-1],
@@ -283,37 +225,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
return model
def patch_tensor_model[T](model: T) -> T:
"""Patch model's __call__ to ensure distributed ops sync during inference."""
cls = model.__class__
original_call = cls.__call__
call_signature = signature(original_call)
def patched_call(
self: T,
*args: object,
**kwargs: object,
) -> mx.array:
logits: mx.array = original_call(self, *args, **kwargs) # pyright: ignore[reportAny]
cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
"cache", None
)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
cls.__call__ = patched_call
return model
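The class-level patch above is easy to miss: it rebinds `__call__` on the model's class so every forward pass threads a dependency from the logits into the last KV-cache entry via `mx.depends`, forcing lazy distributed ops to be evaluated together. A generic, self-contained sketch of the same monkey-patching shape (the side effect here is a stand-in for the `mx.depends` call):

```python
from typing import Any, Callable

def patch_call[T](model: T, after_call: Callable[[T, Any], None]) -> T:
    cls = model.__class__
    original_call = cls.__call__

    def patched_call(self: T, *args: Any, **kwargs: Any) -> Any:
        out = original_call(self, *args, **kwargs)
        after_call(self, out)  # e.g. attach dependency edges to the cache
        return out

    cls.__call__ = patched_call  # patch the class, so all instances see it
    return model

class Toy:
    def __call__(self, x: int) -> int:
        return x + 1

toy = patch_call(Toy(), lambda _self, out: print("saw", out))
assert toy(1) == 2  # prints "saw 2"
```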
def tensor_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> nn.Module:
all_to_sharded_linear = partial(
shard_linear,
@@ -358,7 +272,7 @@ def tensor_auto_parallel(
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
return model
except (AttributeError, TypeError, NameError):
pass
@@ -408,10 +322,7 @@ def tensor_auto_parallel(
else:
raise ValueError(f"Unsupported model type: {type(model)}")
model = tensor_parallel_sharding_strategy.shard_model(
model, timeout_seconds, on_timeout
)
return patch_tensor_model(model)
return tensor_parallel_sharding_strategy.shard_model(model)
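The `Unsupported model type` branch implies an isinstance dispatch over the strategy classes defined below. A hypothetical reconstruction, with constructor arguments elided since the real strategies are built from sharding helpers earlier in the function:

```python
def pick_strategy(model: nn.Module, **strategy_kwargs) -> TensorParallelShardingStrategy:
    # Class names refer to this module's imports and definitions.
    if isinstance(model, LlamaModel):
        return LlamaShardingStrategy(**strategy_kwargs)
    if isinstance(model, DeepseekV3Model):
        return DeepSeekShardingStrategy(**strategy_kwargs)
    if isinstance(model, MiniMaxModel):
        return MiniMaxShardingStrategy(**strategy_kwargs)
    if isinstance(model, Qwen3MoeModel):
        return QwenShardingStrategy(**strategy_kwargs)
    if isinstance(model, GptOssMoeModel):
        return GptOssShardingStrategy(**strategy_kwargs)
    raise ValueError(f"Unsupported model type: {type(model)}")
```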
class TensorParallelShardingStrategy(ABC):
@@ -431,27 +342,13 @@ class TensorParallelShardingStrategy(ABC):
self.N = group.size()
@abstractmethod
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module: ...
def shard_model(self, model: nn.Module) -> nn.Module: ...
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(LlamaModel, model)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -494,17 +391,9 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -556,17 +445,9 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -593,17 +474,9 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(Qwen3MoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -647,18 +520,10 @@ class ShardedQwenMoE(CustomMlxLayer):
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)


@@ -119,6 +119,7 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -130,11 +131,6 @@ def mlx_generate(
if task.seed is not None:
mx.random.seed(task.seed)
prompt = apply_chat_template(
tokenizer=tokenizer,
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
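Taken together with the runner changes further down, the new call pattern looks roughly like this (a sketch using the names from these diffs, not verbatim source): the prompt is rendered once and reused for both generation and thinking detection.

```python
prompt = apply_chat_template(tokenizer=tokenizer, chat_task_data=task)
for response in mlx_generate(model=model, tokenizer=tokenizer, task=task, prompt=prompt):
    print(response.text, end="", flush=True)
```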


@@ -2,7 +2,9 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -57,8 +59,6 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
eval_with_timeout,
pipeline_auto_parallel,
tensor_auto_parallel,
)
@@ -88,6 +88,41 @@ class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If an on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
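A usage sketch for `eval_with_timeout` as defined above; the tensors are throwaway and the callback is a stand-in for the runner's failure event (note the watchdog hard-exits the process if the timeout fires):

```python
import mlx.core as mx

def report_failure() -> None:
    # A real runner would emit a RunnerFailed event here before os._exit(1).
    print("model evaluation timed out")

# mx ops are lazy, so this matmul stays unevaluated until mx.eval runs
# inside eval_with_timeout.
lazy = mx.random.normal((1024, 1024)) @ mx.random.normal((1024, 1024))
eval_with_timeout(lazy, timeout_seconds=30.0, on_timeout=report_failure)
```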
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -177,6 +212,11 @@ def mlx_distributed_init(
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
case _:
raise ValueError(
f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
)
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
@@ -261,6 +301,14 @@ def shard_and_load(
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
@@ -269,23 +317,11 @@ def shard_and_load(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(
model, group, shard_metadata
)
# Skip eval for pipeline parallel to avoid fast synch issues
mx_barrier(group)
return model, tokenizer
# Eager eval for tensor parallel (ranks have same operations on sharded data)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)
logger.debug("SHARDED")
logger.debug(model)
@@ -400,6 +436,16 @@ def apply_chat_template(
return prompt
def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> bool:
"""
Detect whether the prompt ends with a thinking opening tag that should be
prepended to the output stream.
"""
think_token = tokenizer.think_start
return think_token is not None and prompt.rstrip().endswith(think_token)
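The detection rule is small enough to restate as a self-contained check; `FakeTokenizer` is invented for the demo and mirrors only the `think_start` attribute the helper reads:

```python
class FakeTokenizer:
    think_start: str | None = "<think>"

def ends_with_think(prompt: str, tok: FakeTokenizer) -> bool:
    # Mirrors detect_thinking_prompt_suffix: tolerate trailing whitespace.
    return tok.think_start is not None and prompt.rstrip().endswith(tok.think_start)

assert ends_with_think("<|assistant|><think>\n", FakeTokenizer())
assert not ends_with_think("<|assistant|>", FakeTokenizer())
```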
class NullKVCache(KVCache):
"""
A KVCache that pretends to exist but holds zero tokens.


@@ -409,7 +409,7 @@ class Worker:
conns = await check_reachable(
self.state.topology,
self.node_id,
self.state.node_profiles,
self.state.node_network,
)
for nid in conns:
for ip in conns[nid]:


@@ -21,7 +21,11 @@ from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -50,6 +54,16 @@ def plan(
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
) -> Task | None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
# Check plugin tasks first
for plugin in registry.all_plugins():
task = plugin.plan_task(runners, instances)
if task is not None:
return task
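For context, here is a hypothetical plugin implementing the two hooks this planner consults. Method names match the calls in these hunks, while the class itself and its registration are assumptions (the diff shows discovery via `discover_plugins()` but not how plugins register):

```python
class EchoPlugin:
    def plan_task(self, runners, instances):
        # Return a Task to preempt exo's built-in planning, or None to defer.
        return None

    def should_skip_download(self, instance) -> bool:
        # e.g. an instance backed by a remote API needs no local weights.
        return getattr(instance, "remote", False)
```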
# Python's short-circuiting OR evaluates these planners in order, returning the first non-None task.
return (
_kill_runner(runners, all_runners, instances)
@@ -113,7 +127,18 @@ def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
for runner in runners.values():
instance = runner.bound_instance.instance
# Check if any plugin wants to skip download for this instance
plugin = registry.get_plugin_for_instance(instance)
if plugin is not None and plugin.should_skip_download(instance):
continue
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status


@@ -4,7 +4,10 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import (
BoundInstance,
MlxJacclInstance,
)
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -17,6 +20,7 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
# Set FAST_SYNCH based on env var or JACCL device count
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
@@ -34,11 +38,26 @@ def entrypoint(
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
# Route based on instance type (plugins or default MLX)
try:
from exo.worker.runner.runner import main
from exo.plugins.registry import PluginRegistry, discover_plugins
main(bound_instance, event_sender, task_receiver)
# Discover plugins in subprocess (they aren't inherited from main process)
discover_plugins()
registry = PluginRegistry.get()
instance = bound_instance.instance
# Check if a plugin handles this instance type
plugin = registry.get_plugin_for_instance(instance)
if plugin is not None:
# Delegate to plugin runner
plugin.create_runner(bound_instance, event_sender, task_receiver)
else:
# MLX runner (default)
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:


@@ -4,6 +4,7 @@ from functools import cache
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
@@ -50,6 +51,8 @@ from exo.shared.types.worker.runners import (
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
detect_thinking_prompt_suffix,
initialize_mlx,
load_mlx_items,
mlx_force_oom,
@@ -177,17 +180,28 @@ def main(
try:
_check_for_debug_prompts(task_params.messages[0].content)
# Build prompt once - used for both generation and thinking detection
prompt = apply_chat_template(tokenizer, task_params)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# For other thinking models (GLM, etc.), check if we need to
# prepend the thinking tag that was consumed by the chat template
if detect_thinking_prompt_suffix(prompt, tokenizer):
mlx_generator = parse_thinking_models(
mlx_generator, tokenizer
)
# TODO: Add tool call parser here
for response in mlx_generator:
@@ -293,6 +307,28 @@ def parse_gpt_oss(
break
def parse_thinking_models(
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
can properly parse thinking content.
"""
first = True
for response in responses:
if first:
first = False
yield response.model_copy(
update={
"text": tokenizer.think_start,
"token": tokenizer.think_start_id, # type: ignore
}
)
yield response
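A minimal, self-contained demonstration of the prepend behavior; `FakeResponse`/`FakeTok` stand in for `GenerationResponse` and the tokenizer, and dataclasses' `replace` plays the role of `model_copy`:

```python
from dataclasses import dataclass, replace
from typing import Iterator

@dataclass
class FakeResponse:
    text: str
    token: int

class FakeTok:
    think_start = "<think>"
    think_start_id = 7

def prepend_think(stream: Iterator[FakeResponse], tok: FakeTok) -> Iterator[FakeResponse]:
    first = True
    for r in stream:
        if first:
            first = False
            # Emit the consumed opening tag before the first real token.
            yield replace(r, text=tok.think_start, token=tok.think_start_id)
        yield r

out = list(prepend_think(iter([FakeResponse("Let me", 1), FakeResponse(" think", 2)]), FakeTok()))
assert [r.text for r in out] == ["<think>", "Let me", " think"]
```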
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"


@@ -114,6 +114,10 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
# Mock apply_chat_template since we're using a fake tokenizer (integer 1).
# Returns a prompt without a thinking tag so detect_thinking_prompt_suffix returns False.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop")