Compare commits

...

14 Commits

Author     SHA1        Message                                               Date
Sami Khan  a9db83ba6b  consolidated all FLASH types into the plugin folder  2026-01-21 11:59:32 +05:00
Sami Khan  cd630dea43  formatting                                           2026-01-20 07:54:52 +05:00
Sami Khan  e55dae5ce8  code quality                                         2026-01-20 07:47:14 +05:00
Sami Khan  302c43afd5  Merge main into sami/flash                           2026-01-20 07:31:22 +05:00
Sami Khan  2cf59e2322  nix flake check                                      2026-01-20 07:21:41 +05:00
Sami Khan  e506c7d65c  exo plugins                                          2026-01-20 06:53:43 +05:00
Sami Khan  c1fa2ddeaf  SLURM compatible commands                            2026-01-20 06:53:43 +05:00
Sami Khan  37c5a2a246  Merge branch 'main' into sami/flash                  2026-01-15 08:57:36 +05:00
Sami Khan  4d7f03834a  deleted separate server                              2026-01-15 08:50:45 +05:00
Sami Khan  bdb9fbc8c0  Merge branch 'main' into sami/flash                  2026-01-14 08:10:51 +05:00
Sami Khan  8c7180810c  type checking                                        2026-01-14 07:15:45 +05:00
Sami Khan  318c6e000b  code cleanup                                         2026-01-14 04:56:59 +05:00
Sami Khan  2d45544da0  use rsh server instead of ssh                        2026-01-13 02:46:25 +05:00
Sami Khan  7cbafa768a  flash+exo                                            2026-01-12 10:26:16 +05:00
30 changed files with 2355 additions and 91 deletions


@@ -30,6 +30,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-rsh = "exo.rsh.client:main"
# dependencies only required for development
[dependency-groups]
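
(For context: each entry above maps a console script onto a Python callable, so an installed exo-rsh runs exo.rsh.client:main. The FLASH runner further down hands this script to Open MPI as its remote-launch agent via --mca plm_rsh_agent, and rsh-style agents are invoked with the shape below; host and command are placeholders, not values from this diff.)

exo-rsh <hostname> <command and args to run on that host>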

src/exo/cli/__init__.py Normal file

@@ -0,0 +1,32 @@
"""Exo CLI - SLURM-compatible job management commands."""
def run_subcommand(command: str, args: list[str]) -> int:
"""Route to the appropriate subcommand handler.
Args:
command: The subcommand name (sbatch, squeue, scancel, salloc)
args: Command line arguments for the subcommand
Returns:
Exit code from the subcommand
"""
if command == "sbatch":
from exo.cli.sbatch import main
return main(args)
elif command == "squeue":
from exo.cli.squeue import main
return main(args)
elif command == "scancel":
from exo.cli.scancel import main
return main(args)
elif command == "salloc":
from exo.cli.salloc import main
return main(args)
else:
print(f"Unknown subcommand: {command}")
return 1
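
A shell invocation like "exo squeue -l" reaches this router as run_subcommand("squeue", ["-l"]); a minimal sketch of calling it directly:

from exo.cli import run_subcommand

# Equivalent to running `exo squeue -l` from the shell; imports
# exo.cli.squeue lazily and returns its exit code.
exit_code = run_subcommand("squeue", ["-l"])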

src/exo/cli/common.py Normal file

@@ -0,0 +1,118 @@
"""Common utilities for Exo CLI commands."""
import json
import os
import urllib.request
from typing import Any
from urllib.error import HTTPError, URLError
# Default API endpoint
DEFAULT_API_HOST = "localhost"
DEFAULT_API_PORT = 52415
def get_api_base() -> str:
"""Get the API base URL from environment or defaults."""
host = os.environ.get("EXO_API_HOST", DEFAULT_API_HOST)
port = os.environ.get("EXO_API_PORT", str(DEFAULT_API_PORT))
return f"http://{host}:{port}"
def api_request(
method: str,
path: str,
data: dict[str, Any] | None = None,
) -> dict[str, Any] | list[Any]:
"""Make an API request to the Exo server.
Args:
method: HTTP method (GET, POST, DELETE, etc.)
path: API path (e.g., "/flash/instances")
data: Optional JSON data for POST/PUT requests
Returns:
Parsed JSON response
Raises:
SystemExit: On connection or HTTP errors
"""
url = f"{get_api_base()}{path}"
request_data = None
if data is not None:
request_data = json.dumps(data).encode("utf-8")
req = urllib.request.Request(
url,
data=request_data,
method=method,
)
req.add_header("Content-Type", "application/json")
try:
with urllib.request.urlopen(req, timeout=30) as response: # pyright: ignore[reportAny]
body: str = response.read().decode("utf-8") # pyright: ignore[reportAny]
if body:
return json.loads(body) # pyright: ignore[reportAny]
return {}
except HTTPError as e:
error_body = e.read().decode("utf-8") if e.fp else ""
print(f"API error: {e.code} {e.reason}")
if error_body:
try:
error_json: dict[str, str] = json.loads(error_body) # pyright: ignore[reportAny]
if "detail" in error_json:
print(f" {error_json['detail']}")
except json.JSONDecodeError:
print(f" {error_body}")
raise SystemExit(1) from None
except URLError as e:
print(f"Connection error: {e.reason}")
print(f"Is Exo running at {get_api_base()}?")
raise SystemExit(1) from None
def truncate_id(instance_id: str, length: int = 8) -> str:
"""Truncate a UUID for display.
Args:
instance_id: Full UUID string
length: Number of characters to keep
Returns:
Truncated ID without hyphens
"""
return instance_id.replace("-", "")[:length]
def format_table(headers: list[str], rows: list[list[str]]) -> str:
"""Format data as a simple text table.
Args:
headers: Column headers
rows: List of rows, each row is a list of column values
Returns:
Formatted table string
"""
if not rows:
return " ".join(f"{h:<10}" for h in headers)
# Calculate column widths
widths = [len(h) for h in headers]
for row in rows:
for i, cell in enumerate(row):
if i < len(widths):
widths[i] = max(widths[i], len(cell))
# Build format string
fmt = " ".join(f"{{:<{w}}}" for w in widths)
# Format output
lines = [fmt.format(*headers)]
for row in rows:
# Pad row if needed
padded = row + [""] * (len(headers) - len(row))
lines.append(fmt.format(*padded[: len(headers)]))
return "\n".join(lines)

src/exo/cli/salloc.py Normal file

@@ -0,0 +1,100 @@
"""salloc - Allocate nodes for interactive use.
Usage:
exo salloc [options] [-- command [args...]]
Options:
-N, --nodes N Number of nodes to allocate (default: 1)
--hosts HOSTS Comma-separated list of hostnames
If a command is provided after --, it will be executed with
SLURM-like environment variables set:
SLURM_JOB_NODELIST - Comma-separated list of allocated nodes
SLURM_NNODES - Number of allocated nodes
Examples:
exo salloc --nodes=2 --hosts=node1,node2 -- mpirun ./my_program
exo salloc --hosts=localhost -- bash
"""
import argparse
import os
import subprocess
import sys
def main(args: list[str]) -> int:
"""Main entry point for salloc command."""
# Split args at -- if present
cmd_args: list[str] = []
salloc_args = args
if "--" in args:
idx = args.index("--")
salloc_args = args[:idx]
cmd_args = args[idx + 1 :]
parser = argparse.ArgumentParser(
prog="exo salloc",
description="Allocate nodes for interactive use",
)
parser.add_argument(
"-N",
"--nodes",
type=int,
default=1,
help="Number of nodes to allocate (default: 1)",
)
parser.add_argument(
"--hosts",
help="Comma-separated list of hostnames (required)",
)
parsed = parser.parse_args(salloc_args)
nodes: int = parsed.nodes # pyright: ignore[reportAny]
hosts: str | None = parsed.hosts # pyright: ignore[reportAny]
# Require explicit hosts since we can't discover them from topology
if not hosts:
print("Error: --hosts is required (e.g., --hosts=node1,node2)", file=sys.stderr)
print(" The Exo topology doesn't expose hostnames.", file=sys.stderr)
return 1
host_list = [h.strip() for h in hosts.split(",") if h.strip()]
if len(host_list) < nodes:
print(
f"Error: Requested {nodes} nodes but only {len(host_list)} hosts provided",
file=sys.stderr,
)
return 1
# Use first N hosts
allocated_hosts = host_list[:nodes]
nodelist = ",".join(allocated_hosts)
# Set environment variables
env = os.environ.copy()
env["SLURM_JOB_NODELIST"] = nodelist
env["SLURM_NNODES"] = str(nodes)
print(f"salloc: Granted job allocation on {nodes} node(s)")
print(f"salloc: Nodes: {nodelist}")
if cmd_args:
# Run the command
print(f"salloc: Running: {' '.join(cmd_args)}")
result = subprocess.run(cmd_args, env=env)
return result.returncode
else:
# Start interactive shell
shell = os.environ.get("SHELL", "/bin/bash")
print(f"salloc: Starting shell {shell}")
print("salloc: Use 'exit' to release allocation")
result = subprocess.run([shell], env=env)
return result.returncode
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
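
A program started under the allocation sees the injected variables; a minimal sketch of a payload, assuming it was launched via "exo salloc --nodes=2 --hosts=node1,node2 -- python payload.py":

import os

nodelist = os.environ["SLURM_JOB_NODELIST"].split(",")  # ["node1", "node2"]
nnodes = int(os.environ["SLURM_NNODES"])                # 2
print(f"allocated {nnodes} node(s): {nodelist}")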

src/exo/cli/sbatch.py Normal file

@@ -0,0 +1,233 @@
"""sbatch - Submit a batch job to Exo.
Usage:
exo sbatch [options] <script|executable>
exo sbatch --job-name=NAME --nodes=N <executable>
Options:
-J, --job-name NAME Job name
-N, --nodes N Number of nodes (default: 1)
--ntasks-per-node N Tasks per node (default: 1)
-D, --chdir DIR Working directory
--hosts HOSTS Comma-separated list of hostnames
Job scripts can contain #SBATCH directives:
#!/bin/bash
#SBATCH --job-name=Sod2D
#SBATCH --nodes=2
#SBATCH --chdir=/path/to/workdir
/path/to/flash4
"""
import argparse
import os
import re
import sys
from exo.cli.common import api_request, truncate_id
def parse_job_script(script_path: str) -> tuple[dict[str, str], str | None]:
"""Parse a job script for #SBATCH directives and executable.
Args:
script_path: Path to the job script
Returns:
Tuple of (directives dict, executable path or None)
"""
directives: dict[str, str] = {}
executable: str | None = None
with open(script_path, "r") as f:
for line in f:
line = line.strip()
# Parse #SBATCH directives
if line.startswith("#SBATCH"):
# Handle both --option=value and --option value formats
match = re.match(r"#SBATCH\s+(-\w|--[\w-]+)(?:=|\s+)(.+)", line)
if match:
opt, val = match.groups()
directives[opt.lstrip("-")] = val.strip()
continue
# Skip comments and empty lines
if line.startswith("#") or not line:
continue
# First non-comment, non-directive line is the executable
if executable is None:
# Handle lines like "/path/to/flash4" or "srun /path/to/flash4"
parts = line.split()
if parts:
# Skip srun/mpirun prefixes if present
for part in parts:
if not part.startswith("-") and "/" in part:
executable = part
break
if executable is None and parts:
executable = parts[-1] # Last token
return directives, executable
def main(args: list[str]) -> int:
"""Main entry point for sbatch command."""
parser = argparse.ArgumentParser(
prog="exo sbatch",
description="Submit a batch job to Exo",
)
parser.add_argument(
"script",
help="Job script or executable path",
)
parser.add_argument(
"-J",
"--job-name",
dest="job_name",
help="Job name",
)
parser.add_argument(
"-N",
"--nodes",
type=int,
default=1,
help="Number of nodes (default: 1)",
)
parser.add_argument(
"--ntasks-per-node",
type=int,
default=1,
help="Tasks per node (default: 1)",
)
parser.add_argument(
"-D",
"--chdir",
help="Working directory",
)
parser.add_argument(
"--hosts",
help="Comma-separated list of hostnames",
)
parsed = parser.parse_args(args)
# Extract typed values from namespace
script_path: str = parsed.script # pyright: ignore[reportAny]
arg_job_name: str | None = parsed.job_name # pyright: ignore[reportAny]
arg_nodes: int = parsed.nodes # pyright: ignore[reportAny]
arg_ntasks: int = parsed.ntasks_per_node # pyright: ignore[reportAny]
arg_chdir: str | None = parsed.chdir # pyright: ignore[reportAny]
arg_hosts: str | None = parsed.hosts # pyright: ignore[reportAny]
# Determine if input is a script or direct executable
executable: str | None = None
directives: dict[str, str] = {}
if os.path.isfile(script_path):
# Check if it's a binary file (executable) or text script
is_binary = False
try:
with open(script_path, "rb") as f:
chunk = f.read(512)
# Binary files typically contain null bytes
is_binary = b"\x00" in chunk
except OSError:
pass
if is_binary:
# It's a binary executable
executable = script_path
else:
# Try to read as text
try:
with open(script_path, "r") as f:
first_line = f.readline()
f.seek(0)
content = f.read(1024)
if first_line.startswith("#!") or "#SBATCH" in content:
# It's a job script - parse it
directives, executable = parse_job_script(script_path)
else:
# It's an executable (text but no shebang/directives)
executable = script_path
except UnicodeDecodeError:
# Can't read as text - treat as binary executable
executable = script_path
else:
# Not a file - treat as executable path
executable = script_path
if executable is None:
print("Error: No executable found in job script", file=sys.stderr)
return 1
# Build job parameters - CLI args override script directives
job_name = arg_job_name or directives.get("job-name") or directives.get("J")
if not job_name:
# Generate name from executable
job_name = os.path.basename(executable).replace(".", "_")
nodes = arg_nodes
if "nodes" in directives:
nodes = int(directives["nodes"])
if "N" in directives:
nodes = int(directives["N"])
if arg_nodes != 1: # CLI override
nodes = arg_nodes
ntasks = arg_ntasks
if "ntasks-per-node" in directives:
ntasks = int(directives["ntasks-per-node"])
if arg_ntasks != 1: # CLI override
ntasks = arg_ntasks
workdir = arg_chdir or directives.get("chdir") or directives.get("D")
if not workdir:
workdir = os.getcwd()
hosts = arg_hosts or directives.get("hosts") or ""
# Resolve executable to absolute path
if not os.path.isabs(executable):
executable = os.path.abspath(os.path.join(workdir, executable))
# Submit job via API using query parameters
from urllib.parse import urlencode
params = {
"simulation_name": job_name,
"flash_executable_path": executable,
"parameter_file_path": "", # FLASH par file - use default
"working_directory": workdir,
"ranks_per_node": str(ntasks),
"min_nodes": str(nodes),
"hosts": hosts,
}
query_string = urlencode(params)
result = api_request("POST", f"/flash/launch?{query_string}")
# Print job submission confirmation
if isinstance(result, dict):
instance_id_val = result.get("instance_id")
if instance_id_val is not None:
job_id = truncate_id(str(instance_id_val)) # pyright: ignore[reportAny]
print(f"Submitted batch job {job_id}")
else:
# Instance created asynchronously - user should check squeue
print("Job submitted successfully")
print("Use 'exo squeue' to view job ID")
else:
print("Job submitted successfully")
print("Use 'exo squeue' to view job ID")
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
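
The directive parser is easy to exercise on its own; a quick sketch (paths are hypothetical):

from exo.cli.sbatch import parse_job_script

with open("/tmp/job.sh", "w") as f:
    _ = f.write(
        "#!/bin/bash\n"
        "#SBATCH --job-name=Sod2D\n"
        "#SBATCH --nodes=2\n"
        "/opt/flash/flash4\n"
    )
directives, exe = parse_job_script("/tmp/job.sh")
assert directives == {"job-name": "Sod2D", "nodes": "2"}
assert exe == "/opt/flash/flash4"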

src/exo/cli/scancel.py Normal file

@@ -0,0 +1,95 @@
"""scancel - Cancel jobs in the Exo queue.
Usage:
exo scancel <jobid> [<jobid>...]
Arguments:
jobid Job ID (or prefix) to cancel. Can specify multiple.
Examples:
exo scancel abc123 # Cancel job starting with abc123
exo scancel abc123 def456 # Cancel multiple jobs
"""
import argparse
import sys
from typing import Any, cast
from exo.cli.common import api_request, truncate_id
def main(args: list[str]) -> int:
"""Main entry point for scancel command."""
parser = argparse.ArgumentParser(
prog="exo scancel",
description="Cancel jobs in the Exo queue",
)
parser.add_argument(
"jobids",
nargs="+",
help="Job ID(s) to cancel",
)
parsed = parser.parse_args(args)
jobids: list[str] = parsed.jobids # pyright: ignore[reportAny]
# Fetch current jobs to resolve partial IDs
result = api_request("GET", "/flash/instances")
if isinstance(result, list):
instances = cast(list[dict[str, Any]], result)
else:
instances = cast(list[dict[str, Any]], result.get("instances", []))
# Build lookup of full IDs
id_map: dict[str, str] = {}
for inst in instances:
iid = inst.get("instance_id", "") # pyright: ignore[reportAny]
full_id = str(iid) if iid else "" # pyright: ignore[reportAny]
if full_id:
# Map both full ID and truncated versions
normalized = full_id.replace("-", "").lower()
id_map[normalized] = full_id
# Also map prefixes
for length in range(4, len(normalized) + 1):
prefix = normalized[:length]
if prefix not in id_map:
id_map[prefix] = full_id
cancelled = 0
errors = 0
for jobid in jobids:
search = jobid.lower().replace("-", "")
# Find matching full ID
full_id = id_map.get(search)
if not full_id:
# Try prefix match
# Try prefix match; deduplicate, since several stored prefixes can map to the same job
matches = {fid for key, fid in id_map.items() if key.startswith(search)}
if len(matches) == 1:
full_id = next(iter(matches))
elif len(matches) > 1:
print(f"Ambiguous job ID: {jobid} matches multiple jobs")
errors += 1
continue
else:
print(f"Job not found: {jobid}")
errors += 1
continue
# Cancel the job
try:
api_request("DELETE", f"/flash/{full_id}")
print(f"Job {truncate_id(full_id)} cancelled")
cancelled += 1
except SystemExit:
print(f"Failed to cancel job {truncate_id(full_id)}")
errors += 1
if errors > 0 and cancelled == 0:
return 1
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

src/exo/cli/squeue.py Normal file

@@ -0,0 +1,165 @@
"""squeue - View the Exo job queue.
Usage:
exo squeue [options]
Options:
-l, --long Show detailed output
-j, --job ID Show only this job
Output columns:
JOBID - Job identifier (truncated UUID)
NAME - Job name
NODES - Number of nodes
STATE - Job state (PENDING, RUNNING, FAILED, etc.)
"""
import argparse
import sys
from typing import Any, cast
from exo.cli.common import api_request, format_table, truncate_id
# Map Exo runner statuses to SLURM-like states
STATUS_MAP: dict[str, str] = {
"RunnerIdle": "PENDING",
"RunnerConnecting": "CONFIGURING",
"RunnerConnected": "CONFIGURING",
"RunnerLoading": "CONFIGURING",
"RunnerLoaded": "CONFIGURING",
"RunnerWarmingUp": "CONFIGURING",
"RunnerReady": "COMPLETING",
"RunnerRunning": "RUNNING",
"RunnerShuttingDown": "COMPLETING",
"RunnerShutdown": "COMPLETED",
"RunnerFailed": "FAILED",
}
def get_job_state(runner_statuses: dict[str, Any]) -> str:
"""Determine overall job state from runner statuses."""
if not runner_statuses:
return "PENDING"
states: set[str] = set()
for status_val in runner_statuses.values(): # pyright: ignore[reportAny]
if isinstance(status_val, dict):
# Extract status type from discriminated union
type_val = status_val.get("type", "RunnerIdle") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
status_type = str(type_val) if type_val else "RunnerIdle" # pyright: ignore[reportUnknownArgumentType]
elif isinstance(status_val, str):
status_type = status_val
else:
status_type = "RunnerIdle"
# Strip parentheses from status strings like "RunnerRunning()"
if status_type.endswith("()"):
status_type = status_type[:-2]
states.add(STATUS_MAP.get(status_type, "UNKNOWN"))
# Priority order for overall state
if "FAILED" in states:
return "FAILED"
if "RUNNING" in states:
return "RUNNING"
if "CONFIGURING" in states:
return "CONFIGURING"
if "COMPLETING" in states:
return "COMPLETING"
if "COMPLETED" in states:
return "COMPLETED"
if "PENDING" in states:
return "PENDING"
return "UNKNOWN"
def main(args: list[str]) -> int:
"""Main entry point for squeue command."""
parser = argparse.ArgumentParser(
prog="exo squeue",
description="View the Exo job queue",
)
parser.add_argument(
"-l",
"--long",
action="store_true",
help="Show detailed output",
)
parser.add_argument(
"-j",
"--job",
help="Show only this job ID",
)
parsed = parser.parse_args(args)
# Extract typed values
long_format: bool = parsed.long # pyright: ignore[reportAny]
job_filter: str | None = parsed.job # pyright: ignore[reportAny]
# Fetch jobs from API - returns list directly
result = api_request("GET", "/flash/instances")
# API returns list directly, not {"instances": [...]}
if isinstance(result, list):
instances = cast(list[dict[str, Any]], result)
else:
instances = cast(list[dict[str, Any]], result.get("instances", []))
if not instances:
# No jobs - just print header
if long_format:
print("JOBID NAME NODES RANKS STATE WORKDIR")
else:
print("JOBID NAME NODES STATE")
return 0
# Filter by job ID if specified
if job_filter:
search = job_filter.lower()
filtered: list[dict[str, Any]] = []
for i in instances:
iid = i.get("instance_id", "") # pyright: ignore[reportAny]
if search in str(iid).lower().replace("-", ""): # pyright: ignore[reportAny]
filtered.append(i)
instances = filtered
# Build table
rows: list[list[str]] = []
if long_format:
headers = ["JOBID", "NAME", "NODES", "RANKS", "STATE", "WORKDIR"]
for inst in instances:
iid_val = inst.get("instance_id", "") # pyright: ignore[reportAny]
instance_id = str(iid_val) if iid_val else "" # pyright: ignore[reportAny]
job_id = truncate_id(instance_id, 12)
name_val = inst.get("simulation_name", "") # pyright: ignore[reportAny]
name = (str(name_val) if name_val else "")[:15] # pyright: ignore[reportAny]
runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
nodes = str(len(runner_statuses))
ranks_val = inst.get("total_ranks", 0) # pyright: ignore[reportAny]
ranks = str(ranks_val) if ranks_val else "0" # pyright: ignore[reportAny]
state = get_job_state(runner_statuses)
workdir_val = inst.get("working_directory", "") # pyright: ignore[reportAny]
workdir = str(workdir_val) if workdir_val else "" # pyright: ignore[reportAny]
# Truncate workdir for display
if len(workdir) > 30:
workdir = "..." + workdir[-27:]
rows.append([job_id, name, nodes, ranks, state, workdir])
else:
headers = ["JOBID", "NAME", "NODES", "STATE"]
for inst in instances:
iid_val = inst.get("instance_id", "") # pyright: ignore[reportAny]
instance_id = str(iid_val) if iid_val else "" # pyright: ignore[reportAny]
job_id = truncate_id(instance_id, 8)
name_val = inst.get("simulation_name", "") # pyright: ignore[reportAny]
name = (str(name_val) if name_val else "")[:15] # pyright: ignore[reportAny]
runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
nodes = str(len(runner_statuses))
state = get_job_state(runner_statuses)
rows.append([job_id, name, nodes, state])
print(format_table(headers, rows))
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))


@@ -195,6 +195,14 @@ class Node:
def main():
# Check for SLURM-compatible subcommands first
import sys
if len(sys.argv) > 1 and sys.argv[1] in ("sbatch", "squeue", "scancel", "salloc"):
from exo.cli import run_subcommand
sys.exit(run_subcommand(sys.argv[1], sys.argv[2:]))
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
@@ -205,6 +213,11 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Discover and register plugins
from exo.plugins.registry import discover_plugins
discover_plugins()
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"


@@ -1,7 +1,9 @@
import asyncio
import os
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Any, Optional, cast
import anyio
from anyio import BrokenResourceError, create_task_group
@@ -14,6 +16,7 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from pydantic import BaseModel
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -62,7 +65,11 @@ from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -70,6 +77,22 @@ from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
class ExecuteRequest(BaseModel):
"""Request to execute a command."""
command: list[str]
cwd: Optional[str] = None
env: Optional[dict[str, str]] = None
class ExecuteResponse(BaseModel):
"""Response from command execution."""
exit_code: int
stdout: str
stderr: str
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
@@ -126,6 +149,7 @@ class API:
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
self._register_plugin_routes()
self.app.mount(
"/",
@@ -194,6 +218,58 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
# Remote execution endpoint (used by exo-rsh for MPI)
self.app.post("/execute")(self.execute)
def _register_plugin_routes(self) -> None:
"""Register API routes from all loaded plugins."""
import functools
import inspect
from exo.plugins.context import PluginContext
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
for method, path, handler, plugin in registry.get_all_api_routes():
# Create a wrapper that injects the plugin context while preserving
# the original function signature for FastAPI parameter extraction
@functools.wraps(handler)
async def wrapped_handler( # pyright: ignore[reportAny]
*args: Any, # pyright: ignore[reportAny]
_handler: Any = handler, # pyright: ignore[reportAny]
**kwargs: Any, # pyright: ignore[reportAny]
) -> Any:
context = PluginContext(
state=self.state,
send_command=self._send,
node_id=self.node_id,
)
# Pass context as first argument, then forward all other args
return await _handler(context, *args, **kwargs) # pyright: ignore[reportAny]
# Modify the wrapper signature to match the original handler
# but without the 'ctx' parameter (we inject it)
orig_sig = inspect.signature(handler)
params = list(orig_sig.parameters.values())
# Remove the first parameter (ctx: PluginContext)
if params and params[0].name in ("ctx", "context"):
params = params[1:]
wrapped_handler.__signature__ = orig_sig.replace(parameters=params) # pyright: ignore[reportAttributeAccessIssue]
# Register the route based on HTTP method
if method == "get":
self.app.get(path)(wrapped_handler)
elif method == "post":
self.app.post(path)(wrapped_handler)
elif method == "delete":
self.app.delete(path)(wrapped_handler)
elif method == "put":
self.app.put(path)(wrapped_handler)
logger.debug(
f"Registered plugin route: {method.upper()} {path} ({plugin.name})"
)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -626,6 +702,65 @@ class API:
]
)
async def execute(self, request: ExecuteRequest) -> ExecuteResponse:
"""Execute a command locally. Used by exo-rsh for MPI remote execution."""
cmd_str = " ".join(request.command)
logger.info(f"Executing: {cmd_str}")
try:
# Build environment
env = os.environ.copy()
if request.env:
env.update(request.env)
# Check if command contains shell metacharacters
# If so, run through shell. mpirun sends complex commands like:
# "VAR=value;export VAR;/path/to/prted --args"
needs_shell = any(c in cmd_str for c in ";|&$`")
if needs_shell:
process = await asyncio.create_subprocess_shell(
cmd_str,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
else:
process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
stdout, stderr = await process.communicate()
exit_code = process.returncode or 0
logger.info(f"Command completed with exit code {exit_code}")
return ExecuteResponse(
exit_code=exit_code,
stdout=stdout.decode("utf-8", errors="replace"),
stderr=stderr.decode("utf-8", errors="replace"),
)
except FileNotFoundError:
logger.error(f"Command not found: {request.command[0]}")
return ExecuteResponse(
exit_code=127,
stdout="",
stderr=f"Command not found: {request.command[0]}",
)
except Exception as e:
logger.error(f"Execution error: {e}")
return ExecuteResponse(
exit_code=1,
stdout="",
stderr=str(e),
)
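A minimal sketch of driving this endpoint directly, assuming a node API listening on the default localhost:52415 used by the CLI helpers above:

import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:52415/execute",
    data=json.dumps({"command": ["echo", "hello"], "cwd": "/tmp"}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req, timeout=30) as resp:
    body = json.loads(resp.read())
print(body["exit_code"], body["stdout"])  # 0 hello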
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"


@@ -96,104 +96,128 @@ class Master:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
with self.command_receiver as commands:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
# Check if a plugin handles this command
plugin = registry.get_plugin_for_command(command)
if plugin is not None:
events = plugin.process_command(
command,
self.state.topology,
self.state.instances,
)
generated_events.extend(events)
else:
# Core command handling
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(
command, self.state.instances
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(
command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
placement = add_instance_to_placements(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.finished_command_id
]
)
)
if (
command.finished_command_id
in self.command_task_mapping
):
del self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
case _:
# Plugin-managed commands are handled above
pass
for event in generated_events:
await self.event_sender.send(event)
except ValueError as e:


@@ -159,6 +159,11 @@ def place_instance(
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
case _:
# Plugin-managed instance types have their own placement functions
raise ValueError(
f"Instance type {command.instance_meta} must use plugin placement"
)
return target_instances

src/exo/plugins/__init__.py Normal file

@@ -0,0 +1,38 @@
"""Exo Plugin System.
This module provides the plugin architecture for extending exo with custom
workload types (simulations, ML frameworks, etc.) without modifying core code.
"""
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from exo.plugins.base import EXOPlugin, PluginCommand, PluginInstance
from exo.plugins.registry import PluginRegistry, discover_plugins
__all__ = [
"EXOPlugin",
"PluginCommand",
"PluginInstance",
"PluginRegistry",
"discover_plugins",
]
def __getattr__(name: str) -> Any: # pyright: ignore[reportAny]
"""Lazy import to avoid circular dependencies."""
if name in ("EXOPlugin", "PluginCommand", "PluginInstance"):
from exo.plugins.base import EXOPlugin, PluginCommand, PluginInstance
return {
"EXOPlugin": EXOPlugin,
"PluginCommand": PluginCommand,
"PluginInstance": PluginInstance,
}[name]
if name in ("PluginRegistry", "discover_plugins"):
from exo.plugins.registry import PluginRegistry, discover_plugins
return {"PluginRegistry": PluginRegistry, "discover_plugins": discover_plugins}[
name
]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

src/exo/plugins/base.py Normal file

@@ -0,0 +1,171 @@
"""Base classes and protocols for Exo plugins."""
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from pydantic import Field
from exo.shared.types.common import CommandId
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.pydantic_ext import TaggedModel
if TYPE_CHECKING:
from exo.shared.topology import Topology
from exo.shared.types.worker.instances import BoundInstance, Instance
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor
class PluginCommand(TaggedModel):
"""Base class for plugin-defined commands.
All plugin commands must inherit from this class. Commands are serialized
with their class name as a tag for routing.
"""
command_id: CommandId = Field(default_factory=CommandId)
class PluginInstance(TaggedModel):
"""Base class for plugin-defined instances.
All plugin instances must inherit from this class. Plugins are expected
to define their own instance type with workload-specific fields.
"""
instance_id: InstanceId
class EXOPlugin(ABC):
"""Protocol that all exo plugins must implement.
A plugin provides:
- Custom command types for API -> Master communication
- Custom instance types representing running workloads
- Placement logic for distributing work across nodes
- Planning logic for local task scheduling
- Runner implementation for executing work
"""
@property
@abstractmethod
def name(self) -> str:
"""Unique identifier for this plugin (e.g., 'flash', 'pytorch', 'mpi')."""
...
@property
@abstractmethod
def version(self) -> str:
"""Semantic version string (e.g., '1.0.0')."""
...
# ========== Type Registration ==========
@abstractmethod
def get_command_types(self) -> Sequence[type]:
"""Return command types this plugin handles.
These commands are routed to this plugin's process_command method.
Can return core BaseCommand types or PluginCommand types.
"""
...
@abstractmethod
def get_instance_type(self) -> type:
"""Return the instance type this plugin creates.
This instance type is used for routing in planning and runner bootstrap.
Can return core Instance types or PluginInstance types.
"""
...
# ========== API Routes ==========
@abstractmethod
def get_api_routes(
self,
) -> Sequence[tuple[str, str, Callable[..., Any]]]:
"""Return FastAPI routes to register.
Each tuple: (method, path, handler)
Example: [('post', '/flash/launch', self.launch_handler)]
Handlers receive a PluginContext with access to:
- state: Current State object
- send_command: Async function to send commands
- node_id: Current node's ID
"""
...
# ========== Master Command Handling ==========
@abstractmethod
def handles_command(self, command: Any) -> bool: # pyright: ignore[reportAny]
"""Return True if this plugin handles the given command type."""
...
@abstractmethod
def process_command(
self,
command: Any, # pyright: ignore[reportAny]
topology: "Topology",
current_instances: Mapping[InstanceId, "Instance"],
) -> Sequence[Event]:
"""Process a command and return events to emit.
Typically creates placement and returns InstanceCreated/InstanceDeleted events.
Args:
command: The command to process
topology: Current cluster topology
current_instances: Currently running instances
Returns:
Sequence of events to emit (e.g., InstanceCreated, InstanceDeleted)
"""
...
# ========== Worker Planning ==========
@abstractmethod
def handles_instance(self, instance: object) -> bool:
"""Return True if this plugin manages the given instance type."""
...
@abstractmethod
def plan_task(
self,
runners: Mapping[RunnerId, "RunnerSupervisor"],
instances: Mapping[InstanceId, "Instance"],
) -> Task | None:
"""Plan the next task for plugin instances.
Called during each planning cycle.
Return None if no task is needed.
"""
...
@abstractmethod
def should_skip_download(self, instance: object) -> bool:
"""Return True if this instance type doesn't need model downloads."""
...
# ========== Runner Bootstrap ==========
@abstractmethod
def create_runner(
self,
bound_instance: "BoundInstance",
event_sender: "MpSender[Event]",
task_receiver: "MpReceiver[Task]",
) -> None:
"""Entry point for the runner process.
Called in a subprocess to execute the actual workload.
This function should block until the workload completes.
"""
...
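
For orientation, a minimal do-nothing plugin satisfying this ABC could look like the sketch below (the name and behavior are hypothetical; the FLASH plugin later in this diff is the real reference implementation):

class NoOpPlugin(EXOPlugin):
    """Hypothetical plugin that registers nothing and handles nothing."""

    @property
    def name(self) -> str:
        return "noop"

    @property
    def version(self) -> str:
        return "0.1.0"

    def get_command_types(self) -> Sequence[type]:
        return []

    def get_instance_type(self) -> type:
        return PluginInstance

    def get_api_routes(self) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        return []

    def handles_command(self, command: Any) -> bool:
        return False

    def process_command(self, command, topology, current_instances) -> Sequence[Event]:
        return []

    def handles_instance(self, instance: object) -> bool:
        return False

    def plan_task(self, runners, instances) -> Task | None:
        return None

    def should_skip_download(self, instance: object) -> bool:
        return True

    def create_runner(self, bound_instance, event_sender, task_receiver) -> None:
        raise NotImplementedError("noop plugin never places instances")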

src/exo/plugins/context.py Normal file

@@ -0,0 +1,21 @@
"""Context objects passed to plugin handlers."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from exo.shared.types.commands import Command
from exo.shared.types.common import NodeId
from exo.shared.types.state import State
@dataclass
class PluginContext:
"""Context provided to plugin API handlers.
This gives plugins access to the current state and the ability to send
commands without direct access to internal API components.
"""
state: State
send_command: Callable[[Command], Awaitable[None]]
node_id: NodeId

src/exo/plugins/implementations/__init__.py Normal file

@@ -0,0 +1,5 @@
"""Plugin implementations directory.
Each subdirectory should contain a plugin with a register() function
that returns an EXOPlugin instance.
"""

src/exo/plugins/implementations/flash/__init__.py Normal file

@@ -0,0 +1,32 @@
"""FLASH Plugin - MPI-based simulation support for Exo."""
from typing import TYPE_CHECKING, Any
# Import types directly (these don't cause circular imports)
from exo.plugins.implementations.flash.types import (
FLASHInstance,
LaunchFLASH,
StopFLASH,
)
if TYPE_CHECKING:
from exo.plugins.implementations.flash.plugin import FLASHPlugin
__all__ = ["FLASHPlugin", "FLASHInstance", "LaunchFLASH", "StopFLASH", "register"]
def register() -> "FLASHPlugin":
"""Entry point for plugin discovery."""
# Lazy import to avoid circular imports during module loading
from exo.plugins.implementations.flash.plugin import FLASHPlugin
return FLASHPlugin()
# For backwards compatibility, allow importing FLASHPlugin from this module
def __getattr__(name: str) -> Any: # pyright: ignore[reportAny]
if name == "FLASHPlugin":
from exo.plugins.implementations.flash.plugin import FLASHPlugin
return FLASHPlugin
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

src/exo/plugins/implementations/flash/api_handlers.py Normal file

@@ -0,0 +1,109 @@
"""FLASH plugin API handlers."""
from typing import Any
from fastapi import HTTPException
from exo.plugins.context import PluginContext
from exo.plugins.implementations.flash.types import (
FLASHInstance,
LaunchFLASH,
StopFLASH,
)
async def handle_launch_flash(
ctx: PluginContext,
simulation_name: str,
flash_executable_path: str,
working_directory: str,
parameter_file_path: str = "",
ranks_per_node: int = 1,
min_nodes: int = 1,
hosts: str = "",
) -> dict[str, str]:
"""Launch a FLASH MPI simulation across the cluster.
Args:
ctx: Plugin context with state and send_command
simulation_name: Name of the simulation
flash_executable_path: Path to the FLASH executable
working_directory: Working directory for the simulation
parameter_file_path: Path to parameter file (optional)
ranks_per_node: Number of MPI ranks per node
min_nodes: Minimum number of nodes required
hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
If not provided, IPs are discovered from topology edges.
"""
command = LaunchFLASH(
simulation_name=simulation_name,
flash_executable_path=flash_executable_path,
parameter_file_path=parameter_file_path,
working_directory=working_directory,
ranks_per_node=ranks_per_node,
min_nodes=min_nodes,
hosts=hosts,
)
await ctx.send_command(command)
return {
"message": "FLASH launch command received",
"command_id": str(command.command_id),
"simulation_name": simulation_name,
}
async def handle_stop_flash(
ctx: PluginContext,
instance_id: str,
) -> dict[str, str]:
"""Stop a running FLASH simulation."""
from exo.shared.types.worker.instances import InstanceId
inst_id = InstanceId(instance_id)
if inst_id not in ctx.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
instance = ctx.state.instances[inst_id]
if not isinstance(instance, FLASHInstance):
raise HTTPException(
status_code=400, detail="Instance is not a FLASH simulation"
)
command = StopFLASH(instance_id=inst_id)
await ctx.send_command(command)
return {
"message": "Stop command received",
"command_id": str(command.command_id),
"instance_id": str(instance_id),
}
async def handle_list_flash_instances(ctx: PluginContext) -> list[dict[str, Any]]:
"""List all FLASH simulation instances."""
flash_instances: list[dict[str, Any]] = []
for instance_id, instance in ctx.state.instances.items():
if isinstance(instance, FLASHInstance):
# Get runner statuses for this instance
runner_statuses: dict[str, str | None] = {}
for (
node_id,
runner_id,
) in instance.shard_assignments.node_to_runner.items():
runner_status = ctx.state.runners.get(runner_id)
runner_statuses[str(node_id)] = (
str(runner_status) if runner_status else None
)
flash_instances.append(
{
"instance_id": str(instance_id),
"simulation_name": instance.simulation_name,
"total_ranks": instance.total_ranks,
"working_directory": instance.working_directory,
"runner_statuses": runner_statuses,
}
)
return flash_instances

src/exo/plugins/implementations/flash/placement.py Normal file

@@ -0,0 +1,152 @@
"""FLASH plugin placement logic."""
from collections.abc import Mapping
from copy import deepcopy
from loguru import logger
from exo.plugins.implementations.flash.types import FLASHInstance, LaunchFLASH
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerId,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
def place_flash_instance(
command: LaunchFLASH,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
"""Place a FLASH simulation instance across available nodes.
Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
FLASH instances use MPI for communication. We just need to provide the
node IPs so the runner can generate an MPI hostfile.
"""
instance_id = InstanceId()
target_instances: dict[InstanceId, Instance] = dict(deepcopy(current_instances))
all_nodes = list(topology.list_nodes())
if len(all_nodes) < command.min_nodes:
raise ValueError(
f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
)
# Select nodes (take the first min_nodes)
selected_nodes = all_nodes[: command.min_nodes]
logger.info(
f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
)
# Build shard assignments (one runner per node for FLASH)
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
# Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
flash_model_meta = ModelMetadata(
model_id=ModelId(command.simulation_name),
pretty_name=f"FLASH: {command.simulation_name}",
storage_size=Memory(in_bytes=0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
)
for i, node_id in enumerate(selected_nodes):
runner_id = RunnerId()
node_to_runner[node_id] = runner_id
runner_to_shard[runner_id] = PipelineShardMetadata(
device_rank=i,
world_size=len(selected_nodes),
model_meta=flash_model_meta,
start_layer=0,
end_layer=1,
n_layers=1,
)
shard_assignments = ShardAssignments(
model_id=ModelId(command.simulation_name),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
# Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
hosts_by_node: dict[NodeId, list[Host]] = {}
# If explicit hosts are provided, use them directly
if command.hosts:
explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
for i, node_id in enumerate(selected_nodes):
if i < len(explicit_hosts):
hosts_by_node[node_id] = [Host(ip=explicit_hosts[i], port=0)]
logger.info(
f"FLASH placement: node {node_id} (rank {i}) -> IP {explicit_hosts[i]}"
)
else:
logger.warning(
f"Not enough hosts provided for node {i}, using localhost"
)
hosts_by_node[node_id] = [Host(ip="127.0.0.1", port=0)]
logger.info(
f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
)
else:
# Try to get IPs from topology edges
for node_id in selected_nodes:
node_hosts: list[Host] = []
# Get IP from outgoing edges (connections to other nodes via mDNS discovery)
for conn in topology.out_edges(node_id):
if isinstance(conn.edge, SocketConnection):
# Extract IP from multiaddr
ip = conn.edge.sink_multiaddr.ip_address
# Skip link-local and localhost addresses
if not ip.startswith("169.254.") and not ip.startswith("127."):
node_hosts.append(Host(ip=ip, port=0))
break
# Last resort: use localhost (will only work for single-node)
if not node_hosts:
logger.warning(
f"Could not determine IP for node {node_id}, using localhost"
)
node_hosts.append(Host(ip="127.0.0.1", port=0))
hosts_by_node[node_id] = node_hosts
total_ranks = len(selected_nodes) * command.ranks_per_node
# Determine coordinator IP - first node's first host IP
first_node_id: NodeId = next(iter(hosts_by_node.keys()))
coordinator_ip: str = (
hosts_by_node[first_node_id][0].ip
if hosts_by_node[first_node_id]
else "127.0.0.1"
)
target_instances[instance_id] = FLASHInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
flash_executable_path=command.flash_executable_path,
parameter_file_path=command.parameter_file_path,
working_directory=command.working_directory,
ranks_per_node=command.ranks_per_node,
total_ranks=total_ranks,
simulation_name=command.simulation_name,
coordinator_ip=coordinator_ip,
)
logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")
return target_instances

src/exo/plugins/implementations/flash/planning.py Normal file

@@ -0,0 +1,37 @@
"""FLASH plugin planning logic."""
from collections.abc import Mapping
from exo.plugins.implementations.flash.types import FLASHInstance
from exo.shared.types.tasks import LoadModel, Task
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerIdle
from exo.worker.runner.runner_supervisor import RunnerSupervisor
def plan_flash(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
"""Plan tasks specifically for FLASH instances.
FLASH instances have a simpler lifecycle:
- CreateRunner (handled by core _create_runner)
- LoadModel (starts the simulation immediately)
- Shutdown (handled by core _kill_runner)
This function handles the LoadModel step for FLASH instances,
skipping the MLX-specific download/init/warmup steps.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
# Only handle FLASH instances
if not isinstance(instance, FLASHInstance):
continue
# If runner is idle, emit LoadModel to start the simulation
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
return None

src/exo/plugins/implementations/flash/plugin.py Normal file

@@ -0,0 +1,98 @@
"""FLASH Plugin - Main plugin class."""
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from exo.plugins.base import EXOPlugin
from exo.plugins.implementations.flash.api_handlers import (
handle_launch_flash,
handle_list_flash_instances,
handle_stop_flash,
)
from exo.plugins.implementations.flash.placement import place_flash_instance
from exo.plugins.implementations.flash.planning import plan_flash
from exo.plugins.implementations.flash.runner import main as flash_runner_main
from exo.plugins.implementations.flash.types import (
FLASHInstance,
LaunchFLASH,
StopFLASH,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import DeleteInstance
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor
class FLASHPlugin(EXOPlugin):
"""Plugin for FLASH MPI simulations."""
@property
def name(self) -> str:
return "flash"
@property
def version(self) -> str:
return "1.0.0"
def get_command_types(self) -> Sequence[type]:
return [LaunchFLASH, StopFLASH]
def get_instance_type(self) -> type:
return FLASHInstance
def get_api_routes(
self,
) -> Sequence[tuple[str, str, Callable[..., Any]]]:
return [
("post", "/flash/launch", handle_launch_flash),
("delete", "/flash/{instance_id}", handle_stop_flash),
("get", "/flash/instances", handle_list_flash_instances),
]
def handles_command(self, command: Any) -> bool: # pyright: ignore[reportAny]
return isinstance(command, (LaunchFLASH, StopFLASH))
def process_command(
self,
command: Any, # pyright: ignore[reportAny]
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> Sequence[Event]:
from exo.master.placement import delete_instance, get_transition_events
if isinstance(command, LaunchFLASH):
placement = place_flash_instance(command, topology, current_instances)
return list(get_transition_events(current_instances, placement))
elif isinstance(command, StopFLASH):
placement = delete_instance(
DeleteInstance(instance_id=command.instance_id),
current_instances,
)
return list(get_transition_events(current_instances, placement))
return []
def handles_instance(self, instance: object) -> bool:
return isinstance(instance, FLASHInstance)
def plan_task(
self,
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
return plan_flash(runners, instances)
def should_skip_download(self, instance: object) -> bool:
# FLASH instances don't need model downloads
return True
def create_runner(
self,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
) -> None:
flash_runner_main(bound_instance, event_sender, task_receiver)

src/exo/plugins/implementations/flash/runner.py Normal file

@@ -0,0 +1,304 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.
Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- exo-rsh connects to each node's Exo API (/execute endpoint) for remote execution
- Workers just report ready and wait
"""
# ruff: noqa: I001 - Import order intentional (plugin types before shared types)
import os
import shutil
import socket
import subprocess
import threading
from loguru import logger
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
LoadModel,
Shutdown,
Task,
TaskStatus,
)
from exo.plugins.implementations.flash.types import FLASHInstance
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender
# Find mpirun in PATH, fallback to common locations
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"
# exo-rsh is installed as console script by exo package
_exo_rsh_path = shutil.which("exo-rsh")
if not _exo_rsh_path:
raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
EXO_RSH_PATH: str = _exo_rsh_path
def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
"""Determine this node's rank based on position in hosts_by_node."""
for i, node_id in enumerate(instance.hosts_by_node.keys()):
if str(node_id) == str(my_node_id):
return i
return -1
def get_coordinator_host(instance: FLASHInstance) -> str:
"""Get the IP of the coordinator node."""
return instance.coordinator_ip
def resolve_host(host: str) -> str:
"""Resolve host string to a usable hostname for MPI hostfile.
Accepts either an IP address or hostname. For IPs, attempts to resolve
to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
"""
# Check if input is already a hostname (not an IP)
try:
socket.inet_aton(host)
is_ip = True
except socket.error:
is_ip = False
if not is_ip:
# Already a hostname, verify it resolves and return as-is
try:
socket.gethostbyname(host)
return host
except socket.gaierror:
logger.warning(f"Hostname {host} does not resolve, using anyway")
return host
# It's an IP address, try to resolve to hostname
try:
hostname, _, _ = socket.gethostbyaddr(host)
hostname = hostname.split(".")[0]
logger.info(f"Resolved {host} to {hostname}")
return hostname
except socket.herror:
pass
# Fall back to IP
logger.warning(f"Could not resolve {host} to hostname, using IP directly")
return host
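
Example resolutions (results depend on local DNS/mDNS, so these are illustrative only):

resolve_host("s14")       # hostname: returned as-is once it resolves
resolve_host("10.0.0.7")  # IP: becomes "node1" if reverse lookup yields
                          # node1.local, otherwise falls back to "10.0.0.7"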
def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
"""Generate MPI hostfile from instance topology."""
hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
with open(hostfile_path, "w") as f:
for _node_id, hosts in instance.hosts_by_node.items():
if hosts:
host = resolve_host(hosts[0].ip)
f.write(f"{host} slots={instance.ranks_per_node}\n")
logger.info(f"Generated hostfile at {hostfile_path}")
with open(hostfile_path, "r") as f:
logger.info(f"Hostfile contents:\n{f.read()}")
return hostfile_path
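
For two nodes with ranks_per_node=4, the generated flash_hosts.txt would read (hostnames depend on resolve_host; these are placeholders):

node1 slots=4
node2 slots=4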
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
) -> None:
"""Main FLASH runner loop.
Coordinator: generates hostfile and runs mpirun (uses exo-rsh instead of SSH)
Workers: just report ready and wait for mpirun to spawn processes on them
"""
assert isinstance(bound_instance.instance, FLASHInstance)
instance = bound_instance.instance
runner_id = bound_instance.bound_runner_id
my_node_id = str(bound_instance.bound_node_id)
logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")
my_rank = get_my_rank(instance, my_node_id)
world_size = len(instance.hosts_by_node)
is_coordinator = my_rank == 0
coordinator_ip = get_coordinator_host(instance)
logger.info(
f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
)
logger.info(f"FLASH coordinator IP: {coordinator_ip}")
process: subprocess.Popen[bytes] | None = None
current_status: RunnerStatus = RunnerIdle()
shutdown_requested = False
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
def monitor_output(proc: subprocess.Popen[bytes]) -> None:
"""Monitor FLASH stdout for progress updates."""
if proc.stdout is None:
return
for line in iter(proc.stdout.readline, b""):
if shutdown_requested:
break
try:
decoded: str = line.decode("utf-8", errors="replace").strip()
if decoded:
logger.info(f"[FLASH] {decoded}")
except Exception as e:
logger.warning(f"Error parsing FLASH output: {e}")
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(current_status, RunnerIdle):
current_status = RunnerLoading()
logger.info("Starting FLASH simulation")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
if is_coordinator:
# Coordinator: generate hostfile and run mpirun
hostfile = generate_hostfile(
instance, instance.working_directory
)
iface = instance.network_interface
cmd = [
MPIRUN_PATH,
"-np",
str(instance.total_ranks),
"--hostfile",
hostfile,
"--wdir",
instance.working_directory,
"--oversubscribe",
"--mca",
"btl",
"tcp,self",
"--mca",
"btl_tcp_if_include",
iface,
"--mca",
"oob_tcp_if_include",
iface,
"--mca",
"plm_rsh_no_tree_spawn",
"1",
]
# Use exo-rsh for remote execution (no SSH needed)
cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])
cmd.append(instance.flash_executable_path)
logger.info(f"FLASH distributed launch: {' '.join(cmd)}")
process = subprocess.Popen(
cmd,
cwd=instance.working_directory,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
monitor_thread = threading.Thread(
target=monitor_output, args=(process,), daemon=True
)
monitor_thread.start()
current_status = RunnerRunning()
logger.info(
f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
)
else:
# Worker: mpirun on coordinator will use exo-rsh to spawn processes here
logger.info(
f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
)
current_status = RunnerRunning()
except Exception as e:
logger.error(f"Failed to start FLASH: {e}")
import traceback
logger.error(traceback.format_exc())
current_status = RunnerFailed(error_message=str(e))
case Shutdown():
shutdown_requested = True
current_status = RunnerShuttingDown()
logger.info("FLASH runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if process and process.poll() is None:
logger.info("Terminating FLASH simulation")
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
logger.warning("FLASH didn't terminate, killing")
process.kill()
process.wait()
current_status = RunnerShutdown()
case _:
if process and process.poll() is not None:
exit_code = process.returncode
if exit_code == 0:
logger.info("FLASH simulation completed successfully")
current_status = RunnerReady()
else:
logger.error(
f"FLASH simulation failed with code {exit_code}"
)
current_status = RunnerFailed(
error_message=f"Exit code {exit_code}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
break
    if process and process.poll() is None:
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            logger.warning("FLASH didn't exit after terminate, killing")
            process.kill()
logger.info("FLASH runner exiting")
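For reference, with a two-node instance, total_ranks=8, and the default en0 interface, the coordinator assembles a launch command along these lines (paths hypothetical; MPIRUN_PATH and EXO_RSH_PATH are assumed to resolve to the mpirun and exo-rsh binaries):

mpirun -np 8 --hostfile /sims/sedov/flash_hosts.txt --wdir /sims/sedov \
    --oversubscribe --mca btl tcp,self --mca btl_tcp_if_include en0 \
    --mca oob_tcp_if_include en0 --mca plm_rsh_no_tree_spawn 1 \
    --mca plm_rsh_agent exo-rsh /sims/sedov/flash4

Setting plm_rsh_no_tree_spawn to 1 makes mpirun contact every worker directly rather than fanning out through intermediate nodes, so each worker is reached by exactly one exo-rsh call from the coordinator.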

View File

@@ -0,0 +1,91 @@
"""FLASH plugin types - commands and instances."""
# ruff: noqa: I001 - Import order intentional for Pydantic model_rebuild
from __future__ import annotations
from typing import TYPE_CHECKING
from pydantic import Field
from exo.shared.types.common import CommandId, Host, NodeId
from exo.shared.types.worker.runners import ShardAssignments
from exo.utils.pydantic_ext import TaggedModel
if TYPE_CHECKING:
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
TensorShardMetadata,
)
# ============================================================================
# Commands
# ============================================================================
class LaunchFLASH(TaggedModel):
"""Command to launch a FLASH MPI simulation."""
command_id: CommandId = Field(default_factory=CommandId)
simulation_name: str
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
min_nodes: int = 1
# Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
# Used when topology edges don't contain IP addresses
hosts: str = ""
class StopFLASH(TaggedModel):
"""Command to stop a running FLASH simulation."""
command_id: CommandId = Field(default_factory=CommandId)
instance_id: "InstanceId"
# ============================================================================
# Instances
# ============================================================================
class FLASHInstance(TaggedModel):
"""Instance for FLASH MPI simulation.
Unlike MLX instances which do tensor parallelism, FLASH instances
coordinate MPI processes across nodes. Each node runs one or more
MPI ranks of the FLASH simulation.
"""
instance_id: "InstanceId"
shard_assignments: ShardAssignments
hosts_by_node: dict[NodeId, list[Host]]
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
total_ranks: int
simulation_name: str
coordinator_ip: str
network_interface: str = "en0" # Network interface for MPI (e.g., en0, eth0)
def shard(
self, runner_id: "RunnerId"
) -> "PipelineShardMetadata | TensorShardMetadata | None":
return self.shard_assignments.runner_to_shard.get(runner_id, None)
# Import types into module namespace for Pydantic model_rebuild() to resolve forward refs
from exo.shared.types.worker.instances import InstanceId as InstanceId # noqa: E402, I001
from exo.shared.types.worker.runners import RunnerId as RunnerId # noqa: E402, I001
from exo.shared.types.worker.shards import ( # noqa: E402, I001
PipelineShardMetadata as PipelineShardMetadata,
TensorShardMetadata as TensorShardMetadata,
)
# Rebuild models to resolve forward references
StopFLASH.model_rebuild()
FLASHInstance.model_rebuild()
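As a usage sketch, a client submitting a simulation builds the launch command like this (paths and simulation name hypothetical; command_id is filled in by the default_factory):

from exo.plugins.implementations.flash.types import LaunchFLASH

cmd = LaunchFLASH(
    simulation_name="sedov",
    flash_executable_path="/sims/sedov/flash4",
    parameter_file_path="/sims/sedov/flash.par",
    working_directory="/sims/sedov",
    ranks_per_node=4,
    min_nodes=2,
    hosts="s14,james21-1",  # optional; only needed when topology edges lack IPs
)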

110
src/exo/plugins/registry.py Normal file
View File

@@ -0,0 +1,110 @@
"""Plugin registry for discovering and managing plugins."""
from collections.abc import Callable, Sequence
from typing import Any
from loguru import logger
from exo.plugins.base import EXOPlugin
class PluginRegistry:
"""Central registry for all plugins."""
_instance: "PluginRegistry | None" = None
def __init__(self) -> None:
self._plugins: dict[str, EXOPlugin] = {}
self._command_handlers: dict[type, EXOPlugin] = {}
self._instance_handlers: dict[type, EXOPlugin] = {}
@classmethod
def get(cls) -> "PluginRegistry":
"""Get the singleton registry instance."""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def reset(cls) -> None:
"""Reset the singleton instance (useful for testing)."""
cls._instance = None
def register(self, plugin: EXOPlugin) -> None:
"""Register a plugin and its types."""
if plugin.name in self._plugins:
raise ValueError(f"Plugin '{plugin.name}' already registered")
logger.info(f"Registering plugin: {plugin.name} v{plugin.version}")
self._plugins[plugin.name] = plugin
# Register command handlers
for cmd_type in plugin.get_command_types():
self._command_handlers[cmd_type] = plugin
logger.debug(f" Registered command: {cmd_type.__name__}")
# Register instance handler
instance_type = plugin.get_instance_type()
self._instance_handlers[instance_type] = plugin
logger.debug(f" Registered instance: {instance_type.__name__}")
def get_plugin(self, name: str) -> EXOPlugin | None:
"""Get a plugin by name."""
return self._plugins.get(name)
def get_plugin_for_command(self, command: object) -> EXOPlugin | None:
"""Get the plugin that handles a command."""
for plugin in self._plugins.values():
if plugin.handles_command(command):
return plugin
return None
def get_plugin_for_instance(self, instance: object) -> EXOPlugin | None:
"""Get the plugin that manages an instance."""
for plugin in self._plugins.values():
if plugin.handles_instance(instance):
return plugin
return None
def all_plugins(self) -> Sequence[EXOPlugin]:
"""Get all registered plugins."""
return list(self._plugins.values())
def get_all_api_routes(
self,
) -> Sequence[tuple[str, str, Callable[..., Any], EXOPlugin]]:
"""Get all API routes from all plugins."""
routes: list[tuple[str, str, Callable[..., Any], EXOPlugin]] = []
for plugin in self._plugins.values():
for method, path, handler in plugin.get_api_routes():
routes.append((method, path, handler, plugin))
return routes
def discover_plugins() -> None:
"""Auto-discover and register plugins from the implementations directory.
Plugins should have a register() function that returns an EXOPlugin instance.
"""
import importlib
import pkgutil
registry = PluginRegistry.get()
try:
import exo.plugins.implementations as impl_package
for _, module_name, _ in pkgutil.iter_modules(impl_package.__path__):
try:
module = importlib.import_module(
f"exo.plugins.implementations.{module_name}"
)
if hasattr(module, "register"):
plugin = module.register() # pyright: ignore[reportAny]
if plugin is not None:
registry.register(plugin) # pyright: ignore[reportAny]
except Exception as e:
logger.warning(f"Failed to load plugin {module_name}: {e}")
except ImportError:
logger.debug("No plugin implementations package found")
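For orientation, a module under exo.plugins.implementations only needs a module-level register() for discover_plugins() to pick it up. A minimal sketch, assuming EXOPlugin exposes the hooks referenced in this file and in the scheduler (method names are taken from call sites; the actual base-class signatures may differ):

"""Hypothetical example plugin module (illustrative only)."""

from exo.plugins.base import EXOPlugin
from exo.plugins.implementations.flash.types import FLASHInstance, LaunchFLASH, StopFLASH


class FlashLikePlugin(EXOPlugin):
    name = "flash-like"
    version = "0.1.0"

    def get_command_types(self) -> list[type]:
        # Commands the registry will route to this plugin
        return [LaunchFLASH, StopFLASH]

    def get_instance_type(self) -> type:
        # Instance class this plugin manages (used by get_plugin_for_instance)
        return FLASHInstance

    def should_skip_download(self, instance: object) -> bool:
        # Consulted by the scheduler's _model_needs_download hook later in this change
        return True


def register() -> EXOPlugin:
    # discover_plugins() imports this module and calls register()
    return FlashLikePlugin()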

13
src/exo/rsh/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""Exo RSH - Remote Shell for MPI without SSH.
This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:
1. Each Exo node runs an API server on port 52415 with an /execute endpoint
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", it sends an HTTP POST to the target's /execute endpoint
4. The target executes the command and returns output
Usage:
mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""

101
src/exo/rsh/client.py Normal file
View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.
This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]
It connects to the target node's Exo API (port 52415) and executes the command.
"""
import json
import socket
import sys
from typing import Any, cast
from urllib.error import URLError
from urllib.request import Request, urlopen
# Use the same port as Exo's API server
EXO_API_PORT = 52415
def resolve_hostname(hostname: str) -> str:
"""Resolve hostname to IP address."""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
# If resolution fails, try using the hostname directly
return hostname
def main() -> None:
# Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
# SSH options we might see: -x (disable X11), -o options, etc.
args = sys.argv[1:]
# Skip SSH-style options
hostname = None
command_start = 0
i = 0
while i < len(args):
arg = args[i]
if arg.startswith("-"):
# Skip option and its value if needed
if arg in ("-o", "-i", "-l", "-p", "-F"):
i += 2 # Skip option and its argument
continue
i += 1
continue
else:
# First non-option is the hostname
hostname = arg
command_start = i + 1
break
if hostname is None or command_start >= len(args):
print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
sys.exit(1)
command = args[command_start:]
# Resolve hostname to IP
ip = resolve_hostname(hostname)
# Make request to Exo API
url = f"http://{ip}:{EXO_API_PORT}/execute"
data = json.dumps({"command": command}).encode("utf-8")
try:
req = Request(url, data=data, headers={"Content-Type": "application/json"})
with urlopen(req, timeout=300) as response: # pyright: ignore[reportAny]
response_body: bytes = cast(bytes, response.read()) # pyright: ignore[reportAny]
result: dict[str, Any] = json.loads(response_body.decode("utf-8")) # pyright: ignore[reportAny]
# Output stdout/stderr
stdout: str = cast(str, result.get("stdout", ""))
stderr: str = cast(str, result.get("stderr", ""))
exit_code: int = cast(int, result.get("exit_code", 0))
if stdout:
sys.stdout.write(stdout)
sys.stdout.flush()
if stderr:
sys.stderr.write(stderr)
sys.stderr.flush()
sys.exit(exit_code)
except URLError as e:
print(
f"exo-rsh: Failed to connect to {hostname}:{EXO_API_PORT}: {e}",
file=sys.stderr,
)
sys.exit(255)
except Exception as e:
print(f"exo-rsh: Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()
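To make the option handling concrete: if mpirun invokes the agent as

exo-rsh -x -o PasswordAuthentication=no s14 orted -mca ess_base_jobid 123

the loop skips -x (no argument) and -o together with its value, binds hostname="s14", and posts command=["orted", "-mca", "ess_base_jobid", "123"] unchanged. Only options before the hostname are consumed, so flags belonging to the remote command pass through intact (the orted arguments here are hypothetical).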

View File

@@ -1,3 +1,4 @@
# ruff: noqa: I001 - Import order intentional to avoid circular imports
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
@@ -7,6 +8,9 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
# Import FLASH commands from plugin (for serialization compatibility)
from exo.plugins.implementations.flash.types import LaunchFLASH, StopFLASH # noqa: E402, I001
class BaseCommand(TaggedModel):
command_id: CommandId = Field(default_factory=CommandId)
@@ -50,6 +54,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| LaunchFLASH
| StopFLASH
| TaskFinished
)

View File

@@ -1,3 +1,4 @@
# ruff: noqa: I001 - Import order intentional to avoid circular imports
from enum import Enum
from pydantic import model_validator
@@ -14,6 +15,7 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
FLASH = "FLASH"
class BaseInstance(TaggedModel):
@@ -34,8 +36,12 @@ class MlxJacclInstance(BaseInstance):
jaccl_coordinators: dict[NodeId, str]
# Import FLASHInstance from plugin (for serialization compatibility)
from exo.plugins.implementations.flash.types import FLASHInstance # noqa: E402, I001
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
class BoundInstance(CamelCaseModel):

View File

@@ -212,6 +212,11 @@ def mlx_distributed_init(
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
case _:
raise ValueError(
f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
)
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group

View File

@@ -21,7 +21,11 @@ from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -50,6 +54,16 @@ def plan(
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
) -> Task | None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
# Check plugin tasks first
for plugin in registry.all_plugins():
task = plugin.plan_task(runners, instances)
if task is not None:
return task
    # Python's short-circuiting `or` evaluates these in order and returns the first non-None task.
return (
_kill_runner(runners, all_runners, instances)
@@ -113,7 +127,18 @@ def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
for runner in runners.values():
instance = runner.bound_instance.instance
# Check if any plugin wants to skip download for this instance
plugin = registry.get_plugin_for_instance(instance)
if plugin is not None and plugin.should_skip_download(instance):
continue
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status

View File

@@ -4,7 +4,10 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import (
BoundInstance,
MlxJacclInstance,
)
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -17,6 +20,7 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
# Set FAST_SYNCH based on env var or JACCL device count
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
@@ -34,11 +38,26 @@ def entrypoint(
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
# Route based on instance type (plugins or default MLX)
try:
from exo.worker.runner.runner import main
from exo.plugins.registry import PluginRegistry, discover_plugins
main(bound_instance, event_sender, task_receiver)
# Discover plugins in subprocess (they aren't inherited from main process)
discover_plugins()
registry = PluginRegistry.get()
instance = bound_instance.instance
# Check if a plugin handles this instance type
plugin = registry.get_plugin_for_instance(instance)
if plugin is not None:
# Delegate to plugin runner
plugin.create_runner(bound_instance, event_sender, task_receiver)
else:
# MLX runner (default)
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e: