Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-20 07:46:42 -05:00)
Compare commits: meta-insta...sami/flash (19 commits)
Commits (SHA1):
4d74574edd, 1ed6ae4fc4, b29f5ef1f2, 2df40ae8ad, 1ea358b808,
a9db83ba6b, cd630dea43, e55dae5ce8, 302c43afd5, 2cf59e2322,
e506c7d65c, c1fa2ddeaf, 37c5a2a246, 4d7f03834a, bdb9fbc8c0,
8c7180810c, 318c6e000b, 2d45544da0, 7cbafa768a
@@ -34,6 +34,7 @@ dependencies = [
 exo-master = "exo.master.main:main"
 exo-worker = "exo.worker.main:main"
 exo = "exo.main:main"
+exo-rsh = "exo.rsh.client:main"

 # dependencies only required for development
 [dependency-groups]
src/exo/cli/__init__.py — new file, 32 lines
@@ -0,0 +1,32 @@
"""Exo CLI - SLURM-compatible job management commands."""


def run_subcommand(command: str, args: list[str]) -> int:
    """Route to the appropriate subcommand handler.

    Args:
        command: The subcommand name (sbatch, squeue, scancel, salloc)
        args: Command line arguments for the subcommand

    Returns:
        Exit code from the subcommand
    """
    if command == "sbatch":
        from exo.cli.sbatch import main

        return main(args)
    elif command == "squeue":
        from exo.cli.squeue import main

        return main(args)
    elif command == "scancel":
        from exo.cli.scancel import main

        return main(args)
    elif command == "salloc":
        from exo.cli.salloc import main

        return main(args)
    else:
        print(f"Unknown subcommand: {command}")
        return 1
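The dispatcher is normally reached via `exo <subcommand>` (see the main() hunk further down), but it can also be exercised directly. A minimal sketch:

# Sketch: drive the dispatcher directly; unknown subcommand names return 1.
from exo.cli import run_subcommand

exit_code = run_subcommand("squeue", ["-l"])  # same as `exo squeue -l`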
src/exo/cli/common.py — new file, 118 lines
@@ -0,0 +1,118 @@
"""Common utilities for Exo CLI commands."""

import json
import os
import urllib.request
from typing import Any
from urllib.error import HTTPError, URLError

# Default API endpoint
DEFAULT_API_HOST = "localhost"
DEFAULT_API_PORT = 52415


def get_api_base() -> str:
    """Get the API base URL from environment or defaults."""
    host = os.environ.get("EXO_API_HOST", DEFAULT_API_HOST)
    port = os.environ.get("EXO_API_PORT", str(DEFAULT_API_PORT))
    return f"http://{host}:{port}"


def api_request(
    method: str,
    path: str,
    data: dict[str, Any] | None = None,
) -> dict[str, Any] | list[Any]:
    """Make an API request to the Exo server.

    Args:
        method: HTTP method (GET, POST, DELETE, etc.)
        path: API path (e.g., "/flash/instances")
        data: Optional JSON data for POST/PUT requests

    Returns:
        Parsed JSON response

    Raises:
        SystemExit: On connection or HTTP errors
    """
    url = f"{get_api_base()}{path}"

    request_data = None
    if data is not None:
        request_data = json.dumps(data).encode("utf-8")

    req = urllib.request.Request(
        url,
        data=request_data,
        method=method,
    )
    req.add_header("Content-Type", "application/json")

    try:
        with urllib.request.urlopen(req, timeout=30) as response:  # pyright: ignore[reportAny]
            body: str = response.read().decode("utf-8")  # pyright: ignore[reportAny]
            if body:
                return json.loads(body)  # pyright: ignore[reportAny]
            return {}
    except HTTPError as e:
        error_body = e.read().decode("utf-8") if e.fp else ""
        print(f"API error: {e.code} {e.reason}")
        if error_body:
            try:
                error_json: dict[str, str] = json.loads(error_body)  # pyright: ignore[reportAny]
                if "detail" in error_json:
                    print(f"  {error_json['detail']}")
            except json.JSONDecodeError:
                print(f"  {error_body}")
        raise SystemExit(1) from None
    except URLError as e:
        print(f"Connection error: {e.reason}")
        print(f"Is Exo running at {get_api_base()}?")
        raise SystemExit(1) from None


def truncate_id(instance_id: str, length: int = 8) -> str:
    """Truncate a UUID for display.

    Args:
        instance_id: Full UUID string
        length: Number of characters to keep

    Returns:
        Truncated ID without hyphens
    """
    return instance_id.replace("-", "")[:length]


def format_table(headers: list[str], rows: list[list[str]]) -> str:
    """Format data as a simple text table.

    Args:
        headers: Column headers
        rows: List of rows, each row is a list of column values

    Returns:
        Formatted table string
    """
    if not rows:
        return " ".join(f"{h:<10}" for h in headers)

    # Calculate column widths
    widths = [len(h) for h in headers]
    for row in rows:
        for i, cell in enumerate(row):
            if i < len(widths):
                widths[i] = max(widths[i], len(cell))

    # Build format string
    fmt = " ".join(f"{{:<{w}}}" for w in widths)

    # Format output
    lines = [fmt.format(*headers)]
    for row in rows:
        # Pad row if needed
        padded = row + [""] * (len(headers) - len(row))
        lines.append(fmt.format(*padded[: len(headers)]))

    return "\n".join(lines)
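As a usage sketch (assuming an exo master is reachable on the default localhost:52415), the helpers compose like this:

# Sketch combining the helpers above; requires a running exo server.
from exo.cli.common import api_request, format_table, truncate_id

instances = api_request("GET", "/flash/instances")  # exits on connection error
if isinstance(instances, list):
    rows = [[truncate_id(str(i.get("instance_id", "")))] for i in instances]
    print(format_table(["JOBID"], rows))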
src/exo/cli/salloc.py — new file, 100 lines
@@ -0,0 +1,100 @@
"""salloc - Allocate nodes for interactive use.

Usage:
    exo salloc [options] [-- command [args...]]

Options:
    -N, --nodes N    Number of nodes to allocate (default: 1)
    --hosts HOSTS    Comma-separated list of hostnames

If a command is provided after --, it will be executed with
SLURM-like environment variables set:
    SLURM_JOB_NODELIST - Comma-separated list of allocated nodes
    SLURM_NNODES       - Number of allocated nodes

Examples:
    exo salloc --nodes=2 --hosts=node1,node2 -- mpirun ./my_program
    exo salloc --hosts=localhost -- bash
"""

import argparse
import os
import subprocess
import sys


def main(args: list[str]) -> int:
    """Main entry point for salloc command."""
    # Split args at -- if present
    cmd_args: list[str] = []
    salloc_args = args

    if "--" in args:
        idx = args.index("--")
        salloc_args = args[:idx]
        cmd_args = args[idx + 1 :]

    parser = argparse.ArgumentParser(
        prog="exo salloc",
        description="Allocate nodes for interactive use",
    )
    parser.add_argument(
        "-N",
        "--nodes",
        type=int,
        default=1,
        help="Number of nodes to allocate (default: 1)",
    )
    parser.add_argument(
        "--hosts",
        help="Comma-separated list of hostnames (required)",
    )

    parsed = parser.parse_args(salloc_args)

    nodes: int = parsed.nodes  # pyright: ignore[reportAny]
    hosts: str | None = parsed.hosts  # pyright: ignore[reportAny]

    # Require explicit hosts since we can't discover them from topology
    if not hosts:
        print("Error: --hosts is required (e.g., --hosts=node1,node2)", file=sys.stderr)
        print("       The Exo topology doesn't expose hostnames.", file=sys.stderr)
        return 1

    host_list = [h.strip() for h in hosts.split(",") if h.strip()]

    if len(host_list) < nodes:
        print(
            f"Error: Requested {nodes} nodes but only {len(host_list)} hosts provided",
            file=sys.stderr,
        )
        return 1

    # Use first N hosts
    allocated_hosts = host_list[:nodes]
    nodelist = ",".join(allocated_hosts)

    # Set environment variables
    env = os.environ.copy()
    env["SLURM_JOB_NODELIST"] = nodelist
    env["SLURM_NNODES"] = str(nodes)

    print(f"salloc: Granted job allocation on {nodes} node(s)")
    print(f"salloc: Nodes: {nodelist}")

    if cmd_args:
        # Run the command
        print(f"salloc: Running: {' '.join(cmd_args)}")
        result = subprocess.run(cmd_args, env=env)
        return result.returncode
    else:
        # Start interactive shell
        shell = os.environ.get("SHELL", "/bin/bash")
        print(f"salloc: Starting shell {shell}")
        print("salloc: Use 'exit' to release allocation")
        result = subprocess.run([shell], env=env)
        return result.returncode


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
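A minimal non-interactive sketch, assuming node1/node2 are reachable hostnames (hypothetical); the child process sees the SLURM-style variables:

# Sketch: a scripted allocation; the child inherits SLURM_JOB_NODELIST.
from exo.cli.salloc import main

code = main(["--nodes=2", "--hosts=node1,node2", "--", "printenv", "SLURM_JOB_NODELIST"])
# the child prints "node1,node2"; `code` is its exit status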
src/exo/cli/sbatch.py — new file, 233 lines
@@ -0,0 +1,233 @@
"""sbatch - Submit a batch job to Exo.

Usage:
    exo sbatch [options] <script|executable>
    exo sbatch --job-name=NAME --nodes=N <executable>

Options:
    -J, --job-name NAME    Job name
    -N, --nodes N          Number of nodes (default: 1)
    --ntasks-per-node N    Tasks per node (default: 1)
    -D, --chdir DIR        Working directory
    --hosts HOSTS          Comma-separated list of hostnames

Job scripts can contain #SBATCH directives:
    #!/bin/bash
    #SBATCH --job-name=Sod2D
    #SBATCH --nodes=2
    #SBATCH --chdir=/path/to/workdir

    /path/to/flash4
"""

import argparse
import os
import re
import sys

from exo.cli.common import api_request, truncate_id


def parse_job_script(script_path: str) -> tuple[dict[str, str], str | None]:
    """Parse a job script for #SBATCH directives and executable.

    Args:
        script_path: Path to the job script

    Returns:
        Tuple of (directives dict, executable path or None)
    """
    directives: dict[str, str] = {}
    executable: str | None = None

    with open(script_path, "r") as f:
        for line in f:
            line = line.strip()

            # Parse #SBATCH directives
            if line.startswith("#SBATCH"):
                # Handle both --option=value and --option value formats
                match = re.match(r"#SBATCH\s+(-\w|--[\w-]+)(?:=|\s+)(.+)", line)
                if match:
                    opt, val = match.groups()
                    directives[opt.lstrip("-")] = val.strip()
                continue

            # Skip comments and empty lines
            if line.startswith("#") or not line:
                continue

            # First non-comment, non-directive line is the executable
            if executable is None:
                # Handle lines like "/path/to/flash4" or "srun /path/to/flash4"
                parts = line.split()
                if parts:
                    # Skip srun/mpirun prefixes if present
                    for part in parts:
                        if not part.startswith("-") and "/" in part:
                            executable = part
                            break
                    if executable is None and parts:
                        executable = parts[-1]  # Last token

    return directives, executable


def main(args: list[str]) -> int:
    """Main entry point for sbatch command."""
    parser = argparse.ArgumentParser(
        prog="exo sbatch",
        description="Submit a batch job to Exo",
    )
    parser.add_argument(
        "script",
        help="Job script or executable path",
    )
    parser.add_argument(
        "-J",
        "--job-name",
        dest="job_name",
        help="Job name",
    )
    parser.add_argument(
        "-N",
        "--nodes",
        type=int,
        default=1,
        help="Number of nodes (default: 1)",
    )
    parser.add_argument(
        "--ntasks-per-node",
        type=int,
        default=1,
        help="Tasks per node (default: 1)",
    )
    parser.add_argument(
        "-D",
        "--chdir",
        help="Working directory",
    )
    parser.add_argument(
        "--hosts",
        help="Comma-separated list of hostnames",
    )

    parsed = parser.parse_args(args)

    # Extract typed values from namespace
    script_path: str = parsed.script  # pyright: ignore[reportAny]
    arg_job_name: str | None = parsed.job_name  # pyright: ignore[reportAny]
    arg_nodes: int = parsed.nodes  # pyright: ignore[reportAny]
    arg_ntasks: int = parsed.ntasks_per_node  # pyright: ignore[reportAny]
    arg_chdir: str | None = parsed.chdir  # pyright: ignore[reportAny]
    arg_hosts: str | None = parsed.hosts  # pyright: ignore[reportAny]

    # Determine if input is a script or direct executable
    executable: str | None = None
    directives: dict[str, str] = {}

    if os.path.isfile(script_path):
        # Check if it's a binary file (executable) or text script
        is_binary = False
        try:
            with open(script_path, "rb") as f:
                chunk = f.read(512)
                # Binary files typically contain null bytes
                is_binary = b"\x00" in chunk
        except OSError:
            pass

        if is_binary:
            # It's a binary executable
            executable = script_path
        else:
            # Try to read as text
            try:
                with open(script_path, "r") as f:
                    first_line = f.readline()
                    f.seek(0)
                    content = f.read(1024)

                if first_line.startswith("#!") or "#SBATCH" in content:
                    # It's a job script - parse it
                    directives, executable = parse_job_script(script_path)
                else:
                    # It's an executable (text but no shebang/directives)
                    executable = script_path
            except UnicodeDecodeError:
                # Can't read as text - treat as binary executable
                executable = script_path
    else:
        # Not a file - treat as executable path
        executable = script_path

    if executable is None:
        print("Error: No executable found in job script", file=sys.stderr)
        return 1

    # Build job parameters - CLI args override script directives
    job_name = arg_job_name or directives.get("job-name") or directives.get("J")
    if not job_name:
        # Generate name from executable
        job_name = os.path.basename(executable).replace(".", "_")

    nodes = arg_nodes
    if "nodes" in directives:
        nodes = int(directives["nodes"])
    if "N" in directives:
        nodes = int(directives["N"])
    if arg_nodes != 1:  # CLI override
        nodes = arg_nodes

    ntasks = arg_ntasks
    if "ntasks-per-node" in directives:
        ntasks = int(directives["ntasks-per-node"])
    if arg_ntasks != 1:  # CLI override
        ntasks = arg_ntasks

    workdir = arg_chdir or directives.get("chdir") or directives.get("D")
    if not workdir:
        workdir = os.getcwd()

    hosts = arg_hosts or directives.get("hosts") or ""

    # Resolve executable to absolute path
    if not os.path.isabs(executable):
        executable = os.path.abspath(os.path.join(workdir, executable))

    # Submit job via API using query parameters
    from urllib.parse import urlencode

    params = {
        "simulation_name": job_name,
        "flash_executable_path": executable,
        "parameter_file_path": "",  # FLASH par file - use default
        "working_directory": workdir,
        "ranks_per_node": str(ntasks),
        "min_nodes": str(nodes),
        "hosts": hosts,
    }

    query_string = urlencode(params)
    result = api_request("POST", f"/flash/launch?{query_string}")

    # Print job submission confirmation
    if isinstance(result, dict):
        instance_id_val = result.get("instance_id")

        if instance_id_val is not None:
            job_id = truncate_id(str(instance_id_val))  # pyright: ignore[reportAny]
            print(f"Submitted batch job {job_id}")
        else:
            # Instance created asynchronously - user should check squeue
            print("Job submitted successfully")
            print("Use 'exo squeue' to view job ID")
    else:
        print("Job submitted successfully")
        print("Use 'exo squeue' to view job ID")

    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
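Applied to the docstring's example script, the parser resolves directives and executable like this (the script path is hypothetical):

# Sketch: directive parsing on the example script from the docstring above.
from exo.cli.sbatch import parse_job_script

directives, executable = parse_job_script("/tmp/sod2d.job")  # hypothetical path
# directives == {"job-name": "Sod2D", "nodes": "2", "chdir": "/path/to/workdir"}
# executable == "/path/to/flash4"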
src/exo/cli/scancel.py — new file, 95 lines
@@ -0,0 +1,95 @@
"""scancel - Cancel jobs in the Exo queue.

Usage:
    exo scancel <jobid> [<jobid>...]

Arguments:
    jobid    Job ID (or prefix) to cancel. Can specify multiple.

Examples:
    exo scancel abc123          # Cancel job starting with abc123
    exo scancel abc123 def456   # Cancel multiple jobs
"""

import argparse
import sys
from typing import Any, cast

from exo.cli.common import api_request, truncate_id


def main(args: list[str]) -> int:
    """Main entry point for scancel command."""
    parser = argparse.ArgumentParser(
        prog="exo scancel",
        description="Cancel jobs in the Exo queue",
    )
    parser.add_argument(
        "jobids",
        nargs="+",
        help="Job ID(s) to cancel",
    )

    parsed = parser.parse_args(args)
    jobids: list[str] = parsed.jobids  # pyright: ignore[reportAny]

    # Fetch current jobs to resolve partial IDs
    result = api_request("GET", "/flash/instances")
    if isinstance(result, list):
        instances = cast(list[dict[str, Any]], result)
    else:
        instances = cast(list[dict[str, Any]], result.get("instances", []))

    # Build lookup of full IDs
    id_map: dict[str, str] = {}
    for inst in instances:
        iid = inst.get("instance_id", "")  # pyright: ignore[reportAny]
        full_id = str(iid) if iid else ""  # pyright: ignore[reportAny]
        if full_id:
            # Map both full ID and truncated versions
            normalized = full_id.replace("-", "").lower()
            id_map[normalized] = full_id
            # Also map prefixes
            for length in range(4, len(normalized) + 1):
                prefix = normalized[:length]
                if prefix not in id_map:
                    id_map[prefix] = full_id

    cancelled = 0
    errors = 0

    for jobid in jobids:
        search = jobid.lower().replace("-", "")

        # Find matching full ID
        full_id = id_map.get(search)
        if not full_id:
            # Try prefix match; dedupe since many stored prefixes map to one job
            matches = list({fid for key, fid in id_map.items() if key.startswith(search)})
            if len(matches) == 1:
                full_id = matches[0]
            elif len(matches) > 1:
                print(f"Ambiguous job ID: {jobid} matches multiple jobs")
                errors += 1
                continue
            else:
                print(f"Job not found: {jobid}")
                errors += 1
                continue

        # Cancel the job
        try:
            api_request("DELETE", f"/flash/{full_id}")
            print(f"Job {truncate_id(full_id)} cancelled")
            cancelled += 1
        except SystemExit:
            print(f"Failed to cancel job {truncate_id(full_id)}")
            errors += 1

    if errors > 0 and cancelled == 0:
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
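The prefix map is what lets a short ID like `1ea3` resolve to a full UUID; a self-contained sketch of the same normalization (the UUID below is made up):

# Sketch of prefix resolution; the UUID is illustrative.
full_id = "1ea358b8-0000-4000-8000-000000000000"
normalized = full_id.replace("-", "").lower()
id_map = {normalized[:n]: full_id for n in range(4, len(normalized) + 1)}
assert id_map["1ea3"] == full_id  # `exo scancel 1ea3` would find this job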
src/exo/cli/squeue.py — new file, 165 lines
@@ -0,0 +1,165 @@
"""squeue - View the Exo job queue.

Usage:
    exo squeue [options]

Options:
    -l, --long    Show detailed output
    -j, --job ID  Show only this job

Output columns:
    JOBID - Job identifier (truncated UUID)
    NAME  - Job name
    NODES - Number of nodes
    STATE - Job state (PENDING, RUNNING, FAILED, etc.)
"""

import argparse
import sys
from typing import Any, cast

from exo.cli.common import api_request, format_table, truncate_id

# Map Exo runner statuses to SLURM-like states
STATUS_MAP: dict[str, str] = {
    "RunnerIdle": "PENDING",
    "RunnerConnecting": "CONFIGURING",
    "RunnerConnected": "CONFIGURING",
    "RunnerLoading": "CONFIGURING",
    "RunnerLoaded": "CONFIGURING",
    "RunnerWarmingUp": "CONFIGURING",
    "RunnerReady": "COMPLETING",
    "RunnerRunning": "RUNNING",
    "RunnerShuttingDown": "COMPLETING",
    "RunnerShutdown": "COMPLETED",
    "RunnerFailed": "FAILED",
}


def get_job_state(runner_statuses: dict[str, Any]) -> str:
    """Determine overall job state from runner statuses."""
    if not runner_statuses:
        return "PENDING"

    states: set[str] = set()
    for status_val in runner_statuses.values():  # pyright: ignore[reportAny]
        if isinstance(status_val, dict):
            # Extract status type from discriminated union
            type_val = status_val.get("type", "RunnerIdle")  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            status_type = str(type_val) if type_val else "RunnerIdle"  # pyright: ignore[reportUnknownArgumentType]
        elif isinstance(status_val, str):
            status_type = status_val
        else:
            status_type = "RunnerIdle"
        # Strip parentheses from status strings like "RunnerRunning()"
        if status_type.endswith("()"):
            status_type = status_type[:-2]
        states.add(STATUS_MAP.get(status_type, "UNKNOWN"))

    # Priority order for overall state
    if "FAILED" in states:
        return "FAILED"
    if "RUNNING" in states:
        return "RUNNING"
    if "CONFIGURING" in states:
        return "CONFIGURING"
    if "COMPLETING" in states:
        return "COMPLETING"
    if "COMPLETED" in states:
        return "COMPLETED"
    if "PENDING" in states:
        return "PENDING"
    return "UNKNOWN"


def main(args: list[str]) -> int:
    """Main entry point for squeue command."""
    parser = argparse.ArgumentParser(
        prog="exo squeue",
        description="View the Exo job queue",
    )
    parser.add_argument(
        "-l",
        "--long",
        action="store_true",
        help="Show detailed output",
    )
    parser.add_argument(
        "-j",
        "--job",
        help="Show only this job ID",
    )

    parsed = parser.parse_args(args)

    # Extract typed values
    long_format: bool = parsed.long  # pyright: ignore[reportAny]
    job_filter: str | None = parsed.job  # pyright: ignore[reportAny]

    # Fetch jobs from API - returns list directly
    result = api_request("GET", "/flash/instances")
    # API returns list directly, not {"instances": [...]}
    if isinstance(result, list):
        instances = cast(list[dict[str, Any]], result)
    else:
        instances = cast(list[dict[str, Any]], result.get("instances", []))

    if not instances:
        # No jobs - just print header
        if long_format:
            print("JOBID NAME NODES RANKS STATE WORKDIR")
        else:
            print("JOBID NAME NODES STATE")
        return 0

    # Filter by job ID if specified
    if job_filter:
        search = job_filter.lower()
        filtered: list[dict[str, Any]] = []
        for i in instances:
            iid = i.get("instance_id", "")  # pyright: ignore[reportAny]
            if search in str(iid).lower().replace("-", ""):  # pyright: ignore[reportAny]
                filtered.append(i)
        instances = filtered

    # Build table
    rows: list[list[str]] = []

    if long_format:
        headers = ["JOBID", "NAME", "NODES", "RANKS", "STATE", "WORKDIR"]
        for inst in instances:
            iid_val = inst.get("instance_id", "")  # pyright: ignore[reportAny]
            instance_id = str(iid_val) if iid_val else ""  # pyright: ignore[reportAny]
            job_id = truncate_id(instance_id, 12)
            name_val = inst.get("simulation_name", "")  # pyright: ignore[reportAny]
            name = (str(name_val) if name_val else "")[:15]  # pyright: ignore[reportAny]
            runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
            nodes = str(len(runner_statuses))
            ranks_val = inst.get("total_ranks", 0)  # pyright: ignore[reportAny]
            ranks = str(ranks_val) if ranks_val else "0"  # pyright: ignore[reportAny]
            state = get_job_state(runner_statuses)
            workdir_val = inst.get("working_directory", "")  # pyright: ignore[reportAny]
            workdir = str(workdir_val) if workdir_val else ""  # pyright: ignore[reportAny]
            # Truncate workdir for display
            if len(workdir) > 30:
                workdir = "..." + workdir[-27:]
            rows.append([job_id, name, nodes, ranks, state, workdir])
    else:
        headers = ["JOBID", "NAME", "NODES", "STATE"]
        for inst in instances:
            iid_val = inst.get("instance_id", "")  # pyright: ignore[reportAny]
            instance_id = str(iid_val) if iid_val else ""  # pyright: ignore[reportAny]
            job_id = truncate_id(instance_id, 8)
            name_val = inst.get("simulation_name", "")  # pyright: ignore[reportAny]
            name = (str(name_val) if name_val else "")[:15]  # pyright: ignore[reportAny]
            runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
            nodes = str(len(runner_statuses))
            state = get_job_state(runner_statuses)
            rows.append([job_id, name, nodes, state])

    print(format_table(headers, rows))
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
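A rendering sketch for one running job, using the shared table helper; the values are illustrative and match the short (default) column set:

# Sketch: one row through format_table.
from exo.cli.common import format_table

print(format_table(
    ["JOBID", "NAME", "NODES", "STATE"],
    [["4d74574e", "Sod2D", "2", "RUNNING"]],
))
# prints an aligned header row followed by the job row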
@@ -195,6 +195,14 @@ class Node:

 def main():
+    # Check for SLURM-compatible subcommands first
+    import sys
+
+    if len(sys.argv) > 1 and sys.argv[1] in ("sbatch", "squeue", "scancel", "salloc"):
+        from exo.cli import run_subcommand
+
+        sys.exit(run_subcommand(sys.argv[1], sys.argv[2:]))
+
     args = Args.parse()
     soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
     resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
@@ -205,6 +213,11 @@ def main():
     logger.info("Starting EXO")
     logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
+
+    # Discover and register plugins
+    from exo.plugins.registry import discover_plugins
+
+    discover_plugins()
+
     # Set FAST_SYNCH override env var for runner subprocesses
     if args.fast_synch is True:
         os.environ["EXO_FAST_SYNCH"] = "on"
@@ -1,9 +1,11 @@
+import asyncio
 import base64
 import json
+import os
 import time
 from collections.abc import AsyncGenerator
 from http import HTTPStatus
-from typing import Literal, cast
+from typing import Any, Callable, Literal, Optional, cast

 import anyio
 from anyio import BrokenResourceError, create_task_group
@@ -16,6 +18,7 @@ from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]
 from hypercorn.config import Config
 from hypercorn.typing import ASGIFramework
 from loguru import logger
+from pydantic import BaseModel

 from exo.master.image_store import ImageStore
 from exo.master.placement import place_instance as get_instance_placements
@@ -59,8 +62,8 @@ from exo.shared.types.api import (
 )
 from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
 from exo.shared.types.commands import (
+    BaseCommand,
     ChatCompletion,
-    Command,
     CreateInstance,
     DeleteInstance,
     ForwarderCommand,
@@ -72,15 +75,20 @@ from exo.shared.types.commands import (
 )
 from exo.shared.types.common import CommandId, Id, NodeId, SessionId
 from exo.shared.types.events import (
+    BaseEvent,
     ChunkGenerated,
-    Event,
     ForwarderEvent,
     IndexedEvent,
 )
 from exo.shared.types.memory import Memory
 from exo.shared.types.state import State
 from exo.shared.types.tasks import ChatCompletionTaskParams
-from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
+from exo.shared.types.worker.instances import (
+    BaseInstance,
+    Instance,
+    InstanceId,
+    InstanceMeta,
+)
 from exo.shared.types.worker.shards import Sharding
 from exo.utils.banner import print_startup_banner
 from exo.utils.channels import Receiver, Sender, channel
@@ -92,6 +100,22 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
     return f"image/{image_format or 'png'}"


+class ExecuteRequest(BaseModel):
+    """Request to execute a command."""
+
+    command: list[str]
+    cwd: Optional[str] = None
+    env: Optional[dict[str, str]] = None
+
+
+class ExecuteResponse(BaseModel):
+    """Response from command execution."""
+
+    exit_code: int
+    stdout: str
+    stderr: str
+
+
 def chunk_to_response(
     chunk: TokenChunk, command_id: CommandId
 ) -> ChatCompletionResponse:
@@ -135,11 +159,11 @@ class API:
         election_receiver: Receiver[ElectionMessage],
     ) -> None:
         self.state = State()
-        self._event_log: list[Event] = []
+        self._event_log: list[BaseEvent] = []
         self.command_sender = command_sender
         self.global_event_receiver = global_event_receiver
         self.election_receiver = election_receiver
-        self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
+        self.event_buffer: OrderedBuffer[BaseEvent] = OrderedBuffer[BaseEvent]()
         self.node_id: NodeId = node_id
         self.session_id: SessionId = session_id
         self.last_completed_election: int = 0
@@ -171,7 +195,7 @@ class API:
         logger.info("Resetting API State")
         self.state = State()
         self.session_id = new_session_id
-        self.event_buffer = OrderedBuffer[Event]()
+        self.event_buffer = OrderedBuffer[BaseEvent]()
         self._chat_completion_queues = {}
         self._image_generation_queues = {}
         self.unpause(result_clock)
@@ -231,6 +255,129 @@ class API:
         self.app.get("/images/{image_id}")(self.get_image)
         self.app.get("/state")(lambda: self.state)
         self.app.get("/events")(lambda: self._event_log)
+        self.app.post("/execute")(self.execute)
+
+        # Register plugin routes
+        self._setup_plugin_routes()
+
+    def _setup_plugin_routes(self) -> None:
+        """Register API routes from all plugins."""
+        from exo.plugins.registry import PluginRegistry
+
+        registry = PluginRegistry.get()
+
+        for plugin in registry.all_plugins():
+            for method, path, handler in plugin.get_api_routes():
+                # Create a wrapper that injects PluginContext
+                # We need to capture handler in closure properly
+                self._register_plugin_route(method, path, handler)
+
+    def _register_plugin_route(
+        self,
+        method: str,
+        path: str,
+        handler: Callable[..., Any],
+    ) -> None:
+        """Register a single plugin route with proper closure."""
+        import functools
+        import inspect
+
+        from exo.plugins.context import PluginContext
+
+        # Get the original handler's signature (excluding ctx)
+        sig = inspect.signature(handler)
+        params = [p for p in sig.parameters.values() if p.name != "ctx"]
+        new_sig = sig.replace(parameters=params)
+
+        @functools.wraps(handler)
+        async def route_wrapper(**kwargs: Any) -> Any:  # pyright: ignore[reportAny]
+            ctx = PluginContext(
+                state=self.state,
+                send_command=self._send,
+                node_id=self.node_id,
+            )
+            return await handler(ctx, **kwargs)  # pyright: ignore[reportAny]
+
+        # Override the signature for FastAPI
+        route_wrapper.__signature__ = new_sig  # type: ignore[attr-defined]
+
+        # Register the route
+        if method == "get":
+            self.app.get(path)(route_wrapper)
+        elif method == "post":
+            self.app.post(path)(route_wrapper)
+        elif method == "delete":
+            self.app.delete(path)(route_wrapper)
+        elif method == "put":
+            self.app.put(path)(route_wrapper)
+
+        logger.info(f"Registered plugin route: {method.upper()} {path}")
+
+    async def execute(self, request: ExecuteRequest) -> ExecuteResponse:
+        """Execute a command locally. Used by exo-rsh for MPI remote execution."""
+        cmd_str = " ".join(request.command)
+        logger.info(f"Executing: {cmd_str}")
+
+        try:
+            # Build environment
+            env = os.environ.copy()
+            if request.env:
+                env.update(request.env)
+
+            # Check if command contains shell metacharacters
+            # If so, run through shell. mpirun sends complex commands like:
+            # "VAR=value;export VAR;/path/to/prted --args"
+            needs_shell = any(c in cmd_str for c in ";|&$`")
+
+            # Commands with --daemonize (e.g., prted) fork a child that inherits
+            # stdout/stderr pipe fds. Using PIPE would cause communicate() to hang
+            # because the daemon child never closes them. Use DEVNULL instead.
+            is_daemonize = "--daemonize" in cmd_str
+            out_mode = (
+                asyncio.subprocess.DEVNULL if is_daemonize else asyncio.subprocess.PIPE
+            )
+
+            if needs_shell:
+                process = await asyncio.create_subprocess_shell(
+                    cmd_str,
+                    stdout=out_mode,
+                    stderr=out_mode,
+                    cwd=request.cwd,
+                    env=env,
+                )
+            else:
+                process = await asyncio.create_subprocess_exec(
+                    *request.command,
+                    stdout=out_mode,
+                    stderr=out_mode,
+                    cwd=request.cwd,
+                    env=env,
+                )
+
+            if is_daemonize:
+                await process.wait()
+                exit_code = process.returncode or 0
+                logger.info(f"Daemonized command completed with exit code {exit_code}")
+                return ExecuteResponse(exit_code=exit_code, stdout="", stderr="")
+
+            stdout, stderr = await process.communicate()
+            exit_code = process.returncode or 0
+
+            logger.info(f"Command completed with exit code {exit_code}")
+
+            return ExecuteResponse(
+                exit_code=exit_code,
+                stdout=stdout.decode("utf-8", errors="replace"),
+                stderr=stderr.decode("utf-8", errors="replace"),
+            )
+
+        except FileNotFoundError:
+            logger.error(f"Command not found: {request.command[0]}")
+            return ExecuteResponse(
+                exit_code=127,
+                stdout="",
+                stderr=f"Command not found: {request.command[0]}",
+            )
+
     async def place_instance(self, payload: PlaceInstanceParams):
         command = PlaceInstance(
@@ -278,7 +425,7 @@ class API:
         sharding: Sharding = Sharding.Pipeline,
         instance_meta: InstanceMeta = InstanceMeta.MlxRing,
         min_nodes: int = 1,
-    ) -> Instance:
+    ) -> BaseInstance:
         model_card = await resolve_model_card(model_id)

         try:
@@ -409,7 +556,7 @@ class API:
                 model_id=model_card.model_id,
                 sharding=sharding,
                 instance_meta=instance_meta,
-                instance=instance,
+                instance=cast(Instance, instance),
                 memory_delta_by_node=memory_delta_by_node or None,
                 error=None,
             )
@@ -418,7 +565,7 @@ class API:

         return PlacementPreviewResponse(previews=previews)

-    def get_instance(self, instance_id: InstanceId) -> Instance:
+    def get_instance(self, instance_id: InstanceId) -> BaseInstance:
         if instance_id not in self.state.instances:
             raise HTTPException(status_code=404, detail="Instance not found")
         return self.state.instances[instance_id]
@@ -1185,7 +1332,7 @@ class API:
         if removed > 0:
             logger.debug(f"Cleaned up {removed} expired images")

-    async def _send(self, command: Command):
+    async def _send(self, command: BaseCommand):
         while self.paused:
             await self.paused_ev.wait()
         await self.command_sender.send(
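A minimal sketch of exercising the new POST /execute endpoint (used by exo-rsh for MPI remote execution), assuming a master on the default port 52415; the echo command is illustrative:

# Sketch: POST an ExecuteRequest body; requires a running exo master.
import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:52415/execute",
    data=json.dumps({"command": ["echo", "hello"]}).encode("utf-8"),
    method="POST",
)
req.add_header("Content-Type", "application/json")
with urllib.request.urlopen(req, timeout=30) as resp:
    print(json.loads(resp.read()))  # e.g. {"exit_code": 0, "stdout": "hello\n", "stderr": ""}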
@@ -10,6 +10,7 @@ from exo.master.placement import (
     get_transition_events,
     place_instance,
 )
+from exo.plugins.registry import PluginRegistry
 from exo.shared.apply import apply
 from exo.shared.types.commands import (
     ChatCompletion,
@@ -26,6 +27,7 @@ from exo.shared.types.commands import (
 )
 from exo.shared.types.common import CommandId, NodeId, SessionId
 from exo.shared.types.events import (
+    BaseEvent,
     Event,
     ForwarderEvent,
     IndexedEvent,
@@ -83,9 +85,9 @@ class Master:
         self._loopback_event_sender: Sender[ForwarderEvent] = (
             local_event_receiver.clone_sender()
         )
-        self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
+        self._multi_buffer = MultiSourceBuffer[NodeId, BaseEvent]()
         # TODO: not have this
-        self._event_log: list[Event] = []
+        self._event_log: list[BaseEvent] = []

     async def run(self):
         logger.info("Starting Master")
@@ -296,6 +298,17 @@ class Master:
                         await self._send_event(
                             IndexedEvent(idx=i, event=self._event_log[i])
                         )
+                case _:
+                    # Check if a plugin handles this command
+                    registry = PluginRegistry.get()
+                    plugin = registry.get_plugin_for_command(command)
+                    if plugin is not None:
+                        events = plugin.process_command(
+                            command,
+                            self.state.topology,
+                            self.state.instances,
+                        )
+                        generated_events.extend(events)
             for event in generated_events:
                 await self.event_sender.send(event)
         except ValueError as e:
@@ -24,7 +24,7 @@ from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
 from exo.shared.types.memory import Memory
 from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
 from exo.shared.types.worker.instances import (
-    Instance,
+    BaseInstance,
     InstanceId,
     InstanceMeta,
     MlxJacclInstance,
@@ -41,8 +41,8 @@ def random_ephemeral_port() -> int:
 def add_instance_to_placements(
     command: CreateInstance,
     topology: Topology,
-    current_instances: Mapping[InstanceId, Instance],
-) -> Mapping[InstanceId, Instance]:
+    current_instances: Mapping[InstanceId, BaseInstance],
+) -> Mapping[InstanceId, BaseInstance]:
     # TODO: validate against topology

     return {**current_instances, command.instance.instance_id: command.instance}
@@ -51,10 +51,10 @@ def add_instance_to_placements(
 def place_instance(
     command: PlaceInstance,
     topology: Topology,
-    current_instances: Mapping[InstanceId, Instance],
+    current_instances: Mapping[InstanceId, BaseInstance],
     node_memory: Mapping[NodeId, MemoryUsage],
     node_network: Mapping[NodeId, NodeNetworkInfo],
-) -> dict[InstanceId, Instance]:
+) -> dict[InstanceId, BaseInstance]:
     cycles = topology.get_cycles()
     candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
     cycles_with_sufficient_memory = filter_cycles_by_memory(
@@ -159,8 +159,8 @@ def place_instance(

 def delete_instance(
     command: DeleteInstance,
-    current_instances: Mapping[InstanceId, Instance],
-) -> dict[InstanceId, Instance]:
+    current_instances: Mapping[InstanceId, BaseInstance],
+) -> dict[InstanceId, BaseInstance]:
     target_instances = dict(deepcopy(current_instances))
     if command.instance_id in target_instances:
         del target_instances[command.instance_id]
@@ -169,8 +169,8 @@ def delete_instance(


 def get_transition_events(
-    current_instances: Mapping[InstanceId, Instance],
-    target_instances: Mapping[InstanceId, Instance],
+    current_instances: Mapping[InstanceId, BaseInstance],
+    target_instances: Mapping[InstanceId, BaseInstance],
 ) -> Sequence[Event]:
     events: list[Event] = []
src/exo/plugins/__init__.py — new file, 24 lines
@@ -0,0 +1,24 @@
"""Exo Plugin System.

This module provides the plugin architecture for extending exo with custom
workload types (simulations, ML frameworks, etc.) without modifying core code.
"""

from exo.plugins.base import EXOPlugin, PluginCommand, PluginInstance
from exo.plugins.registry import PluginRegistry, discover_plugins
from exo.plugins.type_registry import (
    command_registry,
    event_registry,
    instance_registry,
)

__all__ = [
    "EXOPlugin",
    "PluginCommand",
    "PluginInstance",
    "PluginRegistry",
    "discover_plugins",
    "command_registry",
    "event_registry",
    "instance_registry",
]
src/exo/plugins/base.py — new file, 171 lines
@@ -0,0 +1,171 @@
"""Base classes and protocols for Exo plugins."""

from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from pydantic import Field

from exo.shared.types.common import CommandId
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.pydantic_ext import TaggedModel

if TYPE_CHECKING:
    from exo.shared.topology import Topology
    from exo.shared.types.worker.instances import BaseInstance, BoundInstance
    from exo.utils.channels import MpReceiver, MpSender
    from exo.worker.runner.runner_supervisor import RunnerSupervisor


class PluginCommand(TaggedModel):
    """Base class for plugin-defined commands.

    All plugin commands must inherit from this class. Commands are serialized
    with their class name as a tag for routing.
    """

    command_id: CommandId = Field(default_factory=CommandId)


class PluginInstance(TaggedModel):
    """Base class for plugin-defined instances.

    All plugin instances must inherit from this class. Plugins are expected
    to define their own instance type with workload-specific fields.
    """

    instance_id: InstanceId


class EXOPlugin(ABC):
    """Protocol that all exo plugins must implement.

    A plugin provides:
    - Custom command types for API -> Master communication
    - Custom instance types representing running workloads
    - Placement logic for distributing work across nodes
    - Planning logic for local task scheduling
    - Runner implementation for executing work
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier for this plugin (e.g., 'flash', 'pytorch', 'mpi')."""
        ...

    @property
    @abstractmethod
    def version(self) -> str:
        """Semantic version string (e.g., '1.0.0')."""
        ...

    # ========== Type Registration ==========

    @abstractmethod
    def get_command_types(self) -> Sequence[type]:
        """Return command types this plugin handles.

        These commands are routed to this plugin's process_command method.
        Can return core BaseCommand types or PluginCommand types.
        """
        ...

    @abstractmethod
    def get_instance_type(self) -> type:
        """Return the instance type this plugin creates.

        This instance type is used for routing in planning and runner bootstrap.
        Can return core Instance types or PluginInstance types.
        """
        ...

    # ========== API Routes ==========

    @abstractmethod
    def get_api_routes(
        self,
    ) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        """Return FastAPI routes to register.

        Each tuple: (method, path, handler)
        Example: [('post', '/flash/launch', self.launch_handler)]

        Handlers receive a PluginContext with access to:
        - state: Current State object
        - send_command: Async function to send commands
        - node_id: Current node's ID
        """
        ...

    # ========== Master Command Handling ==========

    @abstractmethod
    def handles_command(self, command: Any) -> bool:  # pyright: ignore[reportAny]
        """Return True if this plugin handles the given command type."""
        ...

    @abstractmethod
    def process_command(
        self,
        command: Any,  # pyright: ignore[reportAny]
        topology: "Topology",
        current_instances: Mapping[InstanceId, "BaseInstance"],
    ) -> Sequence[Event]:
        """Process a command and return events to emit.

        Typically creates placement and returns InstanceCreated/InstanceDeleted events.

        Args:
            command: The command to process
            topology: Current cluster topology
            current_instances: Currently running instances

        Returns:
            Sequence of events to emit (e.g., InstanceCreated, InstanceDeleted)
        """
        ...

    # ========== Worker Planning ==========

    @abstractmethod
    def handles_instance(self, instance: object) -> bool:
        """Return True if this plugin manages the given instance type."""
        ...

    @abstractmethod
    def plan_task(
        self,
        runners: Mapping[RunnerId, "RunnerSupervisor"],
        instances: Mapping[InstanceId, "BaseInstance"],
    ) -> Task | None:
        """Plan the next task for plugin instances.

        Called during each planning cycle.
        Return None if no task is needed.
        """
        ...

    @abstractmethod
    def should_skip_download(self, instance: object) -> bool:
        """Return True if this instance type doesn't need model downloads."""
        ...

    # ========== Runner Bootstrap ==========

    @abstractmethod
    def create_runner(
        self,
        bound_instance: "BoundInstance",
        event_sender: "MpSender[Event]",
        task_receiver: "MpReceiver[Task]",
    ) -> None:
        """Entry point for the runner process.

        Called in a subprocess to execute the actual workload.
        This function should block until the workload completes.
        """
        ...
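To make the contract concrete, here is a hypothetical minimal plugin satisfying the ABC; EchoPlugin, EchoCommand, and EchoInstance are illustrative names, not part of the exo codebase, and the stubs return the simplest legal values:

# Hypothetical minimal plugin (illustrative only; not in the exo codebase).
from collections.abc import Callable, Sequence
from typing import Any

from exo.plugins.base import EXOPlugin, PluginCommand, PluginInstance
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task


class EchoCommand(PluginCommand):
    message: str  # workload-specific payload


class EchoInstance(PluginInstance):
    message: str


class EchoPlugin(EXOPlugin):
    @property
    def name(self) -> str:
        return "echo"

    @property
    def version(self) -> str:
        return "0.1.0"

    def get_command_types(self) -> Sequence[type]:
        return [EchoCommand]

    def get_instance_type(self) -> type:
        return EchoInstance

    def get_api_routes(self) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        return []  # no HTTP surface for this toy plugin

    def handles_command(self, command: Any) -> bool:
        return isinstance(command, EchoCommand)

    def process_command(self, command: Any, topology: Any, current_instances: Any) -> Sequence[Event]:
        return []  # a real plugin would emit InstanceCreated here

    def handles_instance(self, instance: object) -> bool:
        return isinstance(instance, EchoInstance)

    def plan_task(self, runners: Any, instances: Any) -> Task | None:
        return None  # nothing to schedule

    def should_skip_download(self, instance: object) -> bool:
        return True  # echo workloads need no model weights

    def create_runner(self, bound_instance: Any, event_sender: Any, task_receiver: Any) -> None:
        pass  # a real plugin would block here until the workload finishes


def register() -> EchoPlugin:
    """Entry point for plugin discovery (mirrors the FLASH plugin below)."""
    return EchoPlugin()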
21
src/exo/plugins/context.py
Normal file
@@ -0,0 +1,21 @@
"""Context objects passed to plugin handlers."""

from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from exo.shared.types.commands import BaseCommand
from exo.shared.types.common import NodeId
from exo.shared.types.state import State


@dataclass
class PluginContext:
    """Context provided to plugin API handlers.

    This gives plugins access to the current state and the ability to send
    commands without direct access to internal API components.
    """

    state: State
    send_command: Callable[[BaseCommand], Awaitable[None]]
    node_id: NodeId
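A handler that only reads state needs nothing beyond this dataclass; a hypothetical sketch:

# Hypothetical read-only handler: counts instances visible in state.
async def handle_instance_count(ctx: PluginContext) -> dict[str, int]:
    return {"count": len(ctx.state.instances)}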
5
src/exo/plugins/implementations/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Plugin implementations directory.

Each subdirectory should contain a plugin with a register() function
that returns an EXOPlugin instance.
"""
15
src/exo/plugins/implementations/flash/__init__.py
Normal file
@@ -0,0 +1,15 @@
"""FLASH Plugin - MPI-based simulation support for Exo."""

from exo.plugins.implementations.flash.plugin import FLASHPlugin
from exo.plugins.implementations.flash.types import (
    FLASHInstance,
    LaunchFLASH,
    StopFLASH,
)

__all__ = ["FLASHPlugin", "FLASHInstance", "LaunchFLASH", "StopFLASH", "register"]


def register() -> FLASHPlugin:
    """Entry point for plugin discovery."""
    return FLASHPlugin()
109
src/exo/plugins/implementations/flash/api_handlers.py
Normal file
@@ -0,0 +1,109 @@
"""FLASH plugin API handlers."""

from typing import Any

from fastapi import HTTPException

from exo.plugins.context import PluginContext
from exo.plugins.implementations.flash.types import (
    FLASHInstance,
    LaunchFLASH,
    StopFLASH,
)


async def handle_launch_flash(
    ctx: PluginContext,
    simulation_name: str,
    flash_executable_path: str,
    working_directory: str,
    parameter_file_path: str = "",
    ranks_per_node: int = 1,
    min_nodes: int = 1,
    hosts: str = "",
) -> dict[str, str]:
    """Launch a FLASH MPI simulation across the cluster.

    Args:
        ctx: Plugin context with state and send_command
        simulation_name: Name of the simulation
        flash_executable_path: Path to the FLASH executable
        working_directory: Working directory for the simulation
        parameter_file_path: Path to parameter file (optional)
        ranks_per_node: Number of MPI ranks per node
        min_nodes: Minimum number of nodes required
        hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
            If not provided, IPs are discovered from topology edges.
    """
    command = LaunchFLASH(
        simulation_name=simulation_name,
        flash_executable_path=flash_executable_path,
        parameter_file_path=parameter_file_path,
        working_directory=working_directory,
        ranks_per_node=ranks_per_node,
        min_nodes=min_nodes,
        hosts=hosts,
    )
    await ctx.send_command(command)

    return {
        "message": "FLASH launch command received",
        "command_id": str(command.command_id),
        "simulation_name": simulation_name,
    }


async def handle_stop_flash(
    ctx: PluginContext,
    instance_id: str,
) -> dict[str, str]:
    """Stop a running FLASH simulation."""
    from exo.shared.types.worker.instances import InstanceId

    inst_id = InstanceId(instance_id)

    if inst_id not in ctx.state.instances:
        raise HTTPException(status_code=404, detail="Instance not found")

    instance = ctx.state.instances[inst_id]
    if not isinstance(instance, FLASHInstance):
        raise HTTPException(
            status_code=400, detail="Instance is not a FLASH simulation"
        )

    command = StopFLASH(instance_id=inst_id)
    await ctx.send_command(command)

    return {
        "message": "Stop command received",
        "command_id": str(command.command_id),
        "instance_id": str(instance_id),
    }


async def handle_list_flash_instances(ctx: PluginContext) -> list[dict[str, Any]]:
    """List all FLASH simulation instances."""
    flash_instances: list[dict[str, Any]] = []
    for instance_id, instance in ctx.state.instances.items():
        if isinstance(instance, FLASHInstance):
            # Get runner statuses for this instance
            runner_statuses: dict[str, str | None] = {}
            for (
                node_id,
                runner_id,
            ) in instance.shard_assignments.node_to_runner.items():
                runner_status = ctx.state.runners.get(runner_id)
                runner_statuses[str(node_id)] = (
                    str(runner_status) if runner_status else None
                )

            flash_instances.append(
                {
                    "instance_id": str(instance_id),
                    "simulation_name": instance.simulation_name,
                    "total_ranks": instance.total_ranks,
                    "working_directory": instance.working_directory,
                    "runner_statuses": runner_statuses,
                }
            )
    return flash_instances
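Assuming the default API port from the CLI utilities (52415) and that FastAPI exposes the scalar handler parameters as query parameters (its default for plain arguments), a launch request might look like this sketch (values invented):

# Hypothetical client call against a running master node.
import json
from urllib.parse import urlencode
from urllib.request import urlopen

params = urlencode({
    "simulation_name": "sedov",                  # illustrative values
    "flash_executable_path": "/opt/flash/flash4",
    "working_directory": "/tmp/sedov",
    "min_nodes": 2,
})
# data=b"" forces a POST request
with urlopen(f"http://localhost:52415/flash/launch?{params}", data=b"") as resp:
    print(json.load(resp))  # {"message": "FLASH launch command received", ...}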
167
src/exo/plugins/implementations/flash/placement.py
Normal file
@@ -0,0 +1,167 @@
"""FLASH plugin placement logic."""

from collections.abc import Mapping
from copy import deepcopy

from loguru import logger

from exo.plugins.implementations.flash.types import FLASHInstance, LaunchFLASH
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, ModelId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.instances import BaseInstance, InstanceId
from exo.shared.types.worker.runners import (
    RunnerId,
    ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata


def place_flash_instance(
    command: LaunchFLASH,
    topology: Topology,
    current_instances: Mapping[InstanceId, BaseInstance],
) -> dict[InstanceId, BaseInstance]:
    """Place a FLASH simulation instance across available nodes.

    Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
    FLASH instances use MPI for communication. We just need to provide the
    node IPs so the runner can generate an MPI hostfile.
    """
    instance_id = InstanceId()
    target_instances: dict[InstanceId, BaseInstance] = dict(deepcopy(current_instances))

    all_nodes = list(topology.list_nodes())

    if len(all_nodes) < command.min_nodes:
        raise ValueError(
            f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
        )

    # Select nodes (take the first min_nodes)
    selected_nodes = all_nodes[: command.min_nodes]

    logger.info(
        f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
    )

    # Build shard assignments (one runner per node for FLASH)
    runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

    # Create a dummy ModelCard for FLASH (required by ShardMetadata interface)
    flash_model_card = ModelCard(
        model_id=ModelId(command.simulation_name),
        storage_size=Memory(in_bytes=0),
        n_layers=1,
        hidden_size=1,
        supports_tensor=False,
        tasks=[],
    )

    for i, node_id in enumerate(selected_nodes):
        runner_id = RunnerId()
        node_to_runner[node_id] = runner_id
        runner_to_shard[runner_id] = PipelineShardMetadata(
            device_rank=i,
            world_size=len(selected_nodes),
            model_card=flash_model_card,
            start_layer=0,
            end_layer=1,
            n_layers=1,
        )

    shard_assignments = ShardAssignments(
        model_id=ModelId(command.simulation_name),
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )

    # Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
    hosts_by_node: dict[NodeId, list[Host]] = {}

    # If explicit hosts are provided, use them directly
    if command.hosts:
        explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
        logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
        for i, node_id in enumerate(selected_nodes):
            if i < len(explicit_hosts):
                hosts_by_node[node_id] = [Host(ip=explicit_hosts[i], port=0)]
                logger.info(
                    f"FLASH placement: node {node_id} (rank {i}) -> IP {explicit_hosts[i]}"
                )
            else:
                logger.warning(
                    f"Not enough hosts provided for node {i}, using localhost"
                )
                hosts_by_node[node_id] = [Host(ip="127.0.0.1", port=0)]
        logger.info(
            f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
        )
    else:
        # Get each node's own IP by looking at how OTHER nodes see it.
        # out_edges(A) gives edges A->B with B's IP in sink_multiaddr,
        # so to find A's IP we look at edges B->A from any other node.
        for node_id in selected_nodes:
            candidate_ips: set[str] = set()

            for other_node in all_nodes:
                if other_node == node_id:
                    continue
                for conn in topology.out_edges(other_node):
                    if conn.sink == node_id and isinstance(conn.edge, SocketConnection):
                        ip = conn.edge.sink_multiaddr.ip_address
                        # Skip link-local and localhost addresses
                        if not ip.startswith("169.254.") and not ip.startswith("127."):
                            candidate_ips.add(ip)

            # Prefer private network IPs (10.x, 192.168.x) over Tailscale CGNAT (100.64-127.x)
            chosen_ip: str | None = None
            for ip in candidate_ips:
                if ip.startswith(("10.", "192.168.")):
                    chosen_ip = ip
                    break
            if chosen_ip is None and candidate_ips:
                chosen_ip = next(iter(candidate_ips))

            if chosen_ip:
                hosts_by_node[node_id] = [Host(ip=chosen_ip, port=0)]
            else:
                logger.warning(
                    f"Could not determine IP for node {node_id}, using localhost"
                )
                hosts_by_node[node_id] = [Host(ip="127.0.0.1", port=0)]

            logger.info(
                f"FLASH placement: node {node_id} -> IP {hosts_by_node[node_id][0].ip}"
                f" (candidates: {candidate_ips})"
            )

    total_ranks = len(selected_nodes) * command.ranks_per_node

    # Determine coordinator IP - first node's first host IP
    first_node_id: NodeId = next(iter(hosts_by_node.keys()))
    coordinator_ip: str = (
        hosts_by_node[first_node_id][0].ip
        if hosts_by_node[first_node_id]
        else "127.0.0.1"
    )

    target_instances[instance_id] = FLASHInstance(
        instance_id=instance_id,
        shard_assignments=shard_assignments,
        hosts_by_node=hosts_by_node,
        flash_executable_path=command.flash_executable_path,
        parameter_file_path=command.parameter_file_path,
        working_directory=command.working_directory,
        ranks_per_node=command.ranks_per_node,
        total_ranks=total_ranks,
        simulation_name=command.simulation_name,
        coordinator_ip=coordinator_ip,
    )

    logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")

    return target_instances
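To make the IP preference concrete, a small worked example of the selection above (addresses invented):

# A node reachable over both Tailscale CGNAT and a LAN keeps its LAN address.
candidate_ips = {"100.101.8.3", "192.168.1.14"}
chosen_ip = next(
    (ip for ip in candidate_ips if ip.startswith(("10.", "192.168."))), None
)
if chosen_ip is None and candidate_ips:
    chosen_ip = next(iter(candidate_ips))
assert chosen_ip == "192.168.1.14"  # the private LAN address wins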
37
src/exo/plugins/implementations/flash/planning.py
Normal file
@@ -0,0 +1,37 @@
"""FLASH plugin planning logic."""

from collections.abc import Mapping

from exo.plugins.implementations.flash.types import FLASHInstance
from exo.shared.types.tasks import LoadModel, Task
from exo.shared.types.worker.instances import BaseInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerIdle
from exo.worker.runner.runner_supervisor import RunnerSupervisor


def plan_flash(
    runners: Mapping[RunnerId, RunnerSupervisor],
    instances: Mapping[InstanceId, BaseInstance],
) -> Task | None:
    """Plan tasks specifically for FLASH instances.

    FLASH instances have a simpler lifecycle:
    - CreateRunner (handled by core _create_runner)
    - LoadModel (starts the simulation immediately)
    - Shutdown (handled by core _kill_runner)

    This function handles the LoadModel step for FLASH instances,
    skipping the MLX-specific download/init/warmup steps.
    """
    for runner in runners.values():
        instance = runner.bound_instance.instance

        # Only handle FLASH instances
        if not isinstance(instance, FLASHInstance):
            continue

        # If runner is idle, emit LoadModel to start the simulation
        if isinstance(runner.status, RunnerIdle):
            return LoadModel(instance_id=instance.instance_id)

    return None
98
src/exo/plugins/implementations/flash/plugin.py
Normal file
@@ -0,0 +1,98 @@
"""FLASH Plugin - Main plugin class."""

from collections.abc import Callable, Mapping, Sequence
from typing import Any

from exo.plugins.base import EXOPlugin
from exo.plugins.implementations.flash.api_handlers import (
    handle_launch_flash,
    handle_list_flash_instances,
    handle_stop_flash,
)
from exo.plugins.implementations.flash.placement import place_flash_instance
from exo.plugins.implementations.flash.planning import plan_flash
from exo.plugins.implementations.flash.runner import main as flash_runner_main
from exo.plugins.implementations.flash.types import (
    FLASHInstance,
    LaunchFLASH,
    StopFLASH,
)
from exo.shared.topology import Topology
from exo.shared.types.commands import DeleteInstance
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BaseInstance, BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor


class FLASHPlugin(EXOPlugin):
    """Plugin for FLASH MPI simulations."""

    @property
    def name(self) -> str:
        return "flash"

    @property
    def version(self) -> str:
        return "1.0.0"

    def get_command_types(self) -> Sequence[type]:
        return [LaunchFLASH, StopFLASH]

    def get_instance_type(self) -> type:
        return FLASHInstance

    def get_api_routes(
        self,
    ) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        return [
            ("post", "/flash/launch", handle_launch_flash),
            ("delete", "/flash/{instance_id}", handle_stop_flash),
            ("get", "/flash/instances", handle_list_flash_instances),
        ]

    def handles_command(self, command: Any) -> bool:  # pyright: ignore[reportAny]
        return isinstance(command, (LaunchFLASH, StopFLASH))

    def process_command(
        self,
        command: Any,  # pyright: ignore[reportAny]
        topology: Topology,
        current_instances: Mapping[InstanceId, BaseInstance],
    ) -> Sequence[Event]:
        from exo.master.placement import delete_instance, get_transition_events

        if isinstance(command, LaunchFLASH):
            placement = place_flash_instance(command, topology, current_instances)
            return list(get_transition_events(current_instances, placement))
        elif isinstance(command, StopFLASH):
            placement = delete_instance(
                DeleteInstance(instance_id=command.instance_id),
                current_instances,
            )
            return list(get_transition_events(current_instances, placement))
        return []

    def handles_instance(self, instance: object) -> bool:
        return isinstance(instance, FLASHInstance)

    def plan_task(
        self,
        runners: Mapping[RunnerId, RunnerSupervisor],
        instances: Mapping[InstanceId, BaseInstance],
    ) -> Task | None:
        return plan_flash(runners, instances)

    def should_skip_download(self, instance: object) -> bool:
        # FLASH instances don't need model downloads
        return True

    def create_runner(
        self,
        bound_instance: BoundInstance,
        event_sender: MpSender[Event],
        task_receiver: MpReceiver[Task],
    ) -> None:
        flash_runner_main(bound_instance, event_sender, task_receiver)
305
src/exo/plugins/implementations/flash/runner.py
Normal file
@@ -0,0 +1,305 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.

Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- exo-rsh connects to each node's Exo API (/execute endpoint) for remote execution
- Workers just report ready and wait
"""
# ruff: noqa: I001 - Import order intentional (plugin types before shared types)

import os
import shutil
import socket
import subprocess
import threading

from loguru import logger

from exo.shared.types.events import (
    Event,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,
)
from exo.shared.types.tasks import (
    LoadModel,
    Shutdown,
    Task,
    TaskStatus,
)
from exo.plugins.implementations.flash.types import FLASHInstance
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerIdle,
    RunnerLoading,
    RunnerReady,
    RunnerRunning,
    RunnerShutdown,
    RunnerShuttingDown,
    RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender

# Find mpirun in PATH, fallback to common locations
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"

# exo-rsh is installed as console script by exo package
_exo_rsh_path = shutil.which("exo-rsh")
if not _exo_rsh_path:
    raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
EXO_RSH_PATH: str = _exo_rsh_path


def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
    """Determine this node's rank based on position in hosts_by_node."""
    for i, node_id in enumerate(instance.hosts_by_node.keys()):
        if str(node_id) == str(my_node_id):
            return i
    return -1


def get_coordinator_host(instance: FLASHInstance) -> str:
    """Get the IP of the coordinator node."""
    return instance.coordinator_ip


def resolve_host(host: str) -> str:
    """Resolve host string to a usable hostname for MPI hostfile.

    Accepts either an IP address or hostname. For IPs, attempts to resolve
    to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
    """
    # Check if input is already a hostname (not an IP)
    try:
        socket.inet_aton(host)
        is_ip = True
    except socket.error:
        is_ip = False

    if not is_ip:
        # Already a hostname, verify it resolves and return as-is
        try:
            socket.gethostbyname(host)
            return host
        except socket.gaierror:
            logger.warning(f"Hostname {host} does not resolve, using anyway")
            return host

    # It's an IP address, try to resolve to hostname
    try:
        hostname, _, _ = socket.gethostbyaddr(host)
        hostname = hostname.split(".")[0]
        logger.info(f"Resolved {host} to {hostname}")
        return hostname
    except socket.herror:
        pass

    # Fall back to IP
    logger.warning(f"Could not resolve {host} to hostname, using IP directly")
    return host


def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
    """Generate MPI hostfile from instance topology."""
    hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
    with open(hostfile_path, "w") as f:
        for _node_id, hosts in instance.hosts_by_node.items():
            if hosts:
                host = resolve_host(hosts[0].ip)
                f.write(f"{host} slots={instance.ranks_per_node}\n")
    logger.info(f"Generated hostfile at {hostfile_path}")
    with open(hostfile_path, "r") as f:
        logger.info(f"Hostfile contents:\n{f.read()}")
    return hostfile_path


def main(
    bound_instance: BoundInstance,
    event_sender: MpSender[Event],
    task_receiver: MpReceiver[Task],
) -> None:
    """Main FLASH runner loop.

    Coordinator: generates hostfile and runs mpirun (uses exo-rsh instead of SSH)
    Workers: just report ready and wait for mpirun to spawn processes on them
    """
    assert isinstance(bound_instance.instance, FLASHInstance)
    instance = bound_instance.instance
    runner_id = bound_instance.bound_runner_id
    my_node_id = str(bound_instance.bound_node_id)

    logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")

    my_rank = get_my_rank(instance, my_node_id)
    world_size = len(instance.hosts_by_node)
    is_coordinator = my_rank == 0
    coordinator_ip = get_coordinator_host(instance)

    logger.info(
        f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
    )
    logger.info(f"FLASH coordinator IP: {coordinator_ip}")

    process: subprocess.Popen[bytes] | None = None
    current_status: RunnerStatus = RunnerIdle()
    shutdown_requested = False

    event_sender.send(
        RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
    )

    def monitor_output(proc: subprocess.Popen[bytes]) -> None:
        """Monitor FLASH stdout for progress updates."""
        if proc.stdout is None:
            return
        for line in iter(proc.stdout.readline, b""):
            if shutdown_requested:
                break
            try:
                decoded: str = line.decode("utf-8", errors="replace").strip()
                if decoded:
                    logger.info(f"[FLASH] {decoded}")
            except Exception as e:
                logger.warning(f"Error parsing FLASH output: {e}")

    with task_receiver as tasks:
        for task in tasks:
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
            )
            event_sender.send(TaskAcknowledged(task_id=task.task_id))

            match task:
                case LoadModel() if isinstance(current_status, RunnerIdle):
                    current_status = RunnerLoading()
                    logger.info("Starting FLASH simulation")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )

                    try:
                        if is_coordinator:
                            # Coordinator: generate hostfile and run mpirun
                            hostfile = generate_hostfile(
                                instance, instance.working_directory
                            )

                            iface = instance.network_interface
                            cmd = [
                                MPIRUN_PATH,
                                "-np",
                                str(instance.total_ranks),
                                "--hostfile",
                                hostfile,
                                "--wdir",
                                instance.working_directory,
                                "--oversubscribe",
                                "--allow-run-as-root",
                                "--mca",
                                "btl",
                                "tcp,self",
                                "--mca",
                                "btl_tcp_if_include",
                                iface,
                                "--mca",
                                "oob_tcp_if_include",
                                iface,
                                "--mca",
                                "plm_rsh_no_tree_spawn",
                                "1",
                            ]

                            # Use exo-rsh for remote execution (no SSH needed)
                            cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])

                            cmd.append(instance.flash_executable_path)

                            logger.info(f"FLASH distributed launch: {' '.join(cmd)}")

                            process = subprocess.Popen(
                                cmd,
                                cwd=instance.working_directory,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                            )

                            monitor_thread = threading.Thread(
                                target=monitor_output, args=(process,), daemon=True
                            )
                            monitor_thread.start()

                            current_status = RunnerRunning()
                            logger.info(
                                f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
                            )

                        else:
                            # Worker: mpirun on coordinator will use exo-rsh to spawn processes here
                            logger.info(
                                f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
                            )
                            current_status = RunnerRunning()

                    except Exception as e:
                        logger.error(f"Failed to start FLASH: {e}")
                        import traceback

                        logger.error(traceback.format_exc())
                        current_status = RunnerFailed(error_message=str(e))

                case Shutdown():
                    shutdown_requested = True
                    current_status = RunnerShuttingDown()
                    logger.info("FLASH runner shutting down")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )

                    if process and process.poll() is None:
                        logger.info("Terminating FLASH simulation")
                        process.terminate()
                        try:
                            process.wait(timeout=10)
                        except subprocess.TimeoutExpired:
                            logger.warning("FLASH didn't terminate, killing")
                            process.kill()
                            process.wait()

                    current_status = RunnerShutdown()

                case _:
                    if process and process.poll() is not None:
                        exit_code = process.returncode
                        if exit_code == 0:
                            logger.info("FLASH simulation completed successfully")
                            current_status = RunnerReady()
                        else:
                            logger.error(
                                f"FLASH simulation failed with code {exit_code}"
                            )
                            current_status = RunnerFailed(
                                error_message=f"Exit code {exit_code}"
                            )

            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
            )
            event_sender.send(
                RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
            )

            if isinstance(current_status, RunnerShutdown):
                break

    if process and process.poll() is None:
        process.terminate()
        process.wait(timeout=5)

    logger.info("FLASH runner exiting")
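For reference, a sketch of the artifacts this runner produces for a 2-node run with ranks_per_node=2 (hostnames invented; the flags mirror the cmd list built in main() above):

# Hypothetical flash_hosts.txt written by generate_hostfile():
hostfile_contents = "node-a slots=2\nnode-b slots=2\n"
# Resulting coordinator launch command (one line, shown wrapped):
mpirun_cmd = (
    "mpirun -np 4 --hostfile flash_hosts.txt --wdir /tmp/sedov "
    "--oversubscribe --allow-run-as-root --mca btl tcp,self "
    "--mca btl_tcp_if_include en0 --mca oob_tcp_if_include en0 "
    "--mca plm_rsh_no_tree_spawn 1 --mca plm_rsh_agent exo-rsh ./flash4"
)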
62
src/exo/plugins/implementations/flash/types.py
Normal file
@@ -0,0 +1,62 @@
"""FLASH plugin types - commands and instances."""

from exo.plugins.type_registry import command_registry, instance_registry
from exo.shared.types.commands import BaseCommand
from exo.shared.types.common import Host, NodeId
from exo.shared.types.worker.instances import BaseInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata

# ============================================================================
# Commands
# ============================================================================


@command_registry.register
class LaunchFLASH(BaseCommand):
    """Command to launch a FLASH MPI simulation."""

    simulation_name: str
    flash_executable_path: str
    parameter_file_path: str
    working_directory: str
    ranks_per_node: int = 1
    min_nodes: int = 1
    # Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
    # Used when topology edges don't contain IP addresses
    hosts: str = ""


@command_registry.register
class StopFLASH(BaseCommand):
    """Command to stop a running FLASH simulation."""

    instance_id: InstanceId


# ============================================================================
# Instances
# ============================================================================


@instance_registry.register
class FLASHInstance(BaseInstance):
    """Instance for FLASH MPI simulation.

    Unlike MLX instances which do tensor parallelism, FLASH instances
    coordinate MPI processes across nodes. Each node runs one or more
    MPI ranks of the FLASH simulation.
    """

    hosts_by_node: dict[NodeId, list[Host]]
    flash_executable_path: str
    parameter_file_path: str
    working_directory: str
    ranks_per_node: int = 1
    total_ranks: int
    simulation_name: str
    coordinator_ip: str
    network_interface: str = "en0"  # Network interface for MPI (e.g., en0, eth0)

    def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
        return self.shard_assignments.runner_to_shard.get(runner_id, None)
110
src/exo/plugins/registry.py
Normal file
@@ -0,0 +1,110 @@
"""Plugin registry for discovering and managing plugins."""

from collections.abc import Callable, Sequence
from typing import Any

from loguru import logger

from exo.plugins.base import EXOPlugin


class PluginRegistry:
    """Central registry for all plugins."""

    _instance: "PluginRegistry | None" = None

    def __init__(self) -> None:
        self._plugins: dict[str, EXOPlugin] = {}
        self._command_handlers: dict[type, EXOPlugin] = {}
        self._instance_handlers: dict[type, EXOPlugin] = {}

    @classmethod
    def get(cls) -> "PluginRegistry":
        """Get the singleton registry instance."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls) -> None:
        """Reset the singleton instance (useful for testing)."""
        cls._instance = None

    def register(self, plugin: EXOPlugin) -> None:
        """Register a plugin and its types."""
        if plugin.name in self._plugins:
            raise ValueError(f"Plugin '{plugin.name}' already registered")

        logger.info(f"Registering plugin: {plugin.name} v{plugin.version}")

        self._plugins[plugin.name] = plugin

        # Register command handlers
        for cmd_type in plugin.get_command_types():
            self._command_handlers[cmd_type] = plugin
            logger.debug(f"  Registered command: {cmd_type.__name__}")

        # Register instance handler
        instance_type = plugin.get_instance_type()
        self._instance_handlers[instance_type] = plugin
        logger.debug(f"  Registered instance: {instance_type.__name__}")

    def get_plugin(self, name: str) -> EXOPlugin | None:
        """Get a plugin by name."""
        return self._plugins.get(name)

    def get_plugin_for_command(self, command: object) -> EXOPlugin | None:
        """Get the plugin that handles a command."""
        for plugin in self._plugins.values():
            if plugin.handles_command(command):
                return plugin
        return None

    def get_plugin_for_instance(self, instance: object) -> EXOPlugin | None:
        """Get the plugin that manages an instance."""
        for plugin in self._plugins.values():
            if plugin.handles_instance(instance):
                return plugin
        return None

    def all_plugins(self) -> Sequence[EXOPlugin]:
        """Get all registered plugins."""
        return list(self._plugins.values())

    def get_all_api_routes(
        self,
    ) -> Sequence[tuple[str, str, Callable[..., Any], EXOPlugin]]:
        """Get all API routes from all plugins."""
        routes: list[tuple[str, str, Callable[..., Any], EXOPlugin]] = []
        for plugin in self._plugins.values():
            for method, path, handler in plugin.get_api_routes():
                routes.append((method, path, handler, plugin))
        return routes


def discover_plugins() -> None:
    """Auto-discover and register plugins from the implementations directory.

    Plugins should have a register() function that returns an EXOPlugin instance.
    """
    import importlib
    import pkgutil

    registry = PluginRegistry.get()

    try:
        import exo.plugins.implementations as impl_package

        for _, module_name, _ in pkgutil.iter_modules(impl_package.__path__):
            try:
                module = importlib.import_module(
                    f"exo.plugins.implementations.{module_name}"
                )
                if hasattr(module, "register"):
                    plugin = module.register()  # pyright: ignore[reportAny]
                    if plugin is not None:
                        registry.register(plugin)  # pyright: ignore[reportAny]
            except Exception as e:
                logger.warning(f"Failed to load plugin {module_name}: {e}")
    except ImportError:
        logger.debug("No plugin implementations package found")
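A typical boot sequence using this registry might look like the following sketch, assuming `app` is the FastAPI application the server constructs elsewhere, and ignoring how PluginContext is injected into handlers (the server wires that separately):

# Hypothetical startup wiring.
from fastapi import FastAPI

app = FastAPI()
discover_plugins()
for method, path, handler, plugin in PluginRegistry.get().get_all_api_routes():
    # e.g. ("post", "/flash/launch", handle_launch_flash, FLASHPlugin)
    getattr(app, method)(path)(handler)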
84
src/exo/plugins/type_registry.py
Normal file
@@ -0,0 +1,84 @@
"""Dynamic type registry for plugin types.

This module provides a registry system that allows plugins to register their
command and instance types dynamically, eliminating the need for static union
types and avoiding circular imports.
"""

from typing import TypeVar

from loguru import logger

from exo.utils.pydantic_ext import CamelCaseModel

# TypeVar for preserving exact types through the register decorator
_TCls = TypeVar("_TCls", bound=type[CamelCaseModel])


class TypeRegistry[T: CamelCaseModel]:
    """Registry for dynamically registered Pydantic types.

    Enables plugins to register their types at import time. Deserialization
    uses the class name from the tagged JSON format to look up the correct type.
    """

    def __init__(self, name: str) -> None:
        self._name = name
        self._types: dict[str, type[T]] = {}

    def register(self, cls: _TCls) -> _TCls:
        """Decorator to register a type with this registry.

        Preserves the exact type through the decorator for proper type checking.
        """
        self._types[cls.__name__] = cls  # type: ignore[assignment]
        logger.debug(f"{self._name}: registered {cls.__name__}")
        return cls

    def get(self, name: str) -> type[T] | None:
        """Look up a type by class name."""
        return self._types.get(name)

    def all_types(self) -> dict[str, type[T]]:
        """Return all registered types."""
        return dict(self._types)

    def deserialize(self, data: dict[str, dict[str, object]] | CamelCaseModel) -> T:
        """Deserialize dict to the appropriate registered type.

        Supports two formats:
        1. Tagged format: {"ClassName": {...fields...}} - used for network serialization
        2. Flat format: {...fields...} - used for API requests, tries each type
        """
        # If already deserialized (e.g., from Pydantic), return as-is
        if isinstance(data, CamelCaseModel):
            return data  # type: ignore[return-value]

        # Check for tagged format: single key that matches a registered type
        if len(data) == 1:
            class_name: str = next(iter(data.keys()))
            cls = self._types.get(class_name)
            if cls is not None:
                return cls.model_validate(data[class_name], strict=False)

        # Flat format: try each registered type, use first that validates
        errors: list[str] = []
        for type_name, cls in self._types.items():
            try:
                return cls.model_validate(data, strict=False)
            except Exception as e:  # noqa: BLE001
                errors.append(f"{type_name}: {e}")

        # None matched - provide helpful error
        available = ", ".join(self._types.keys())
        raise ValueError(
            f"{self._name}: could not deserialize data. "
            f"Available types: {available}. Errors: {'; '.join(errors[:3])}"
        )


# Global registries for commands, instances, events, and tasks
command_registry: TypeRegistry[CamelCaseModel] = TypeRegistry("CommandRegistry")
instance_registry: TypeRegistry[CamelCaseModel] = TypeRegistry("InstanceRegistry")
event_registry: TypeRegistry[CamelCaseModel] = TypeRegistry("EventRegistry")
task_registry: TypeRegistry[CamelCaseModel] = TypeRegistry("TaskRegistry")
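To make the two accepted formats concrete, using the LaunchFLASH command registered earlier (a sketch; field values invented, camelCase aliases assumed from CamelCaseModel):

# Tagged format: the single key selects the registered class directly.
cmd = command_registry.deserialize(
    {
        "LaunchFLASH": {
            "simulationName": "sedov",
            "flashExecutablePath": "/opt/flash/flash4",
            "parameterFilePath": "",
            "workingDirectory": "/tmp/sedov",
        }
    }
)
assert type(cmd).__name__ == "LaunchFLASH"

# Flat format omits the tag. The registry tries types in registration order,
# so this only resolves unambiguously if other models reject these fields.
cmd2 = command_registry.deserialize(
    {
        "simulationName": "sedov",
        "flashExecutablePath": "/opt/flash/flash4",
        "parameterFilePath": "",
        "workingDirectory": "/tmp/sedov",
    }
)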
@@ -30,7 +30,7 @@ class TypedTopic[T: CamelCaseModel]:
    @staticmethod
    def serialize(t: T) -> bytes:
        return t.model_dump_json(by_alias=True, serialize_as_any=True).encode("utf-8")

    def deserialize(self, b: bytes) -> T:
        return self.model_type.model_validate_json(b.decode("utf-8"))
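This one-line serialize change matters for the plugin types introduced here: when a field is declared as a base class, Pydantic's default serialization drops subclass-only fields (such as FLASHInstance.coordinator_ip), while serialize_as_any=True switches to duck-typed serialization. A minimal sketch of the difference (behavior as of Pydantic v2.7+):

from pydantic import BaseModel

class Base(BaseModel):
    x: int = 1

class Sub(Base):
    y: int = 2

class Wrapper(BaseModel):
    item: Base  # declared as the base type

w = Wrapper(item=Sub())
print(w.model_dump_json())                       # {"item":{"x":1}} - y is dropped
print(w.model_dump_json(serialize_as_any=True))  # {"item":{"x":1,"y":2}}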
13
src/exo/rsh/__init__.py
Normal file
@@ -0,0 +1,13 @@
"""Exo RSH - Remote Shell for MPI without SSH.

This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:

1. Each Exo node runs an API server on port 52415 with an /execute endpoint
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", it HTTP POSTs to the target's /execute
4. The target executes the command and returns output

Usage:
    mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""
101
src/exo/rsh/client.py
Normal file
@@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.

This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]

It connects to the target node's Exo API (port 52415) and executes the command.
"""

import json
import socket
import sys
from typing import Any, cast
from urllib.error import URLError
from urllib.request import Request, urlopen

# Use the same port as Exo's API server
EXO_API_PORT = 52415


def resolve_hostname(hostname: str) -> str:
    """Resolve hostname to IP address."""
    try:
        return socket.gethostbyname(hostname)
    except socket.gaierror:
        # If resolution fails, try using the hostname directly
        return hostname


def main():
    # Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
    # SSH options we might see: -x (disable X11), -o options, etc.
    args = sys.argv[1:]

    # Skip SSH-style options
    hostname = None
    command_start = 0

    i = 0
    while i < len(args):
        arg = args[i]
        if arg.startswith("-"):
            # Skip option and its value if needed
            if arg in ("-o", "-i", "-l", "-p", "-F"):
                i += 2  # Skip option and its argument
                continue
            i += 1
            continue
        else:
            # First non-option is the hostname
            hostname = arg
            command_start = i + 1
            break

    if hostname is None or command_start >= len(args):
        print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
        sys.exit(1)

    command = args[command_start:]

    # Resolve hostname to IP
    ip = resolve_hostname(hostname)

    # Make request to Exo API
    url = f"http://{ip}:{EXO_API_PORT}/execute"
    data = json.dumps({"command": command}).encode("utf-8")

    try:
        req = Request(url, data=data, headers={"Content-Type": "application/json"})
        with urlopen(req, timeout=300) as response:  # pyright: ignore[reportAny]
            response_body: bytes = cast(bytes, response.read())  # pyright: ignore[reportAny]
            result: dict[str, Any] = json.loads(response_body.decode("utf-8"))  # pyright: ignore[reportAny]

        # Output stdout/stderr
        stdout: str = cast(str, result.get("stdout", ""))
        stderr: str = cast(str, result.get("stderr", ""))
        exit_code: int = cast(int, result.get("exit_code", 0))

        if stdout:
            sys.stdout.write(stdout)
            sys.stdout.flush()
        if stderr:
            sys.stderr.write(stderr)
            sys.stderr.flush()

        sys.exit(exit_code)

    except URLError as e:
        print(
            f"exo-rsh: Failed to connect to {hostname}:{EXO_API_PORT}: {e}",
            file=sys.stderr,
        )
        sys.exit(255)
    except Exception as e:
        print(f"exo-rsh: Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
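The client assumes a specific /execute contract on the worker side; a sketch of a compatible request/response pair (response field names taken from the result.get(...) calls above; the command argv is illustrative - mpirun typically asks the remote side to start its orted daemon):

# Request POSTed by exo-rsh to http://<node>:52415/execute:
request_body = {"command": ["orted", "-mca", "ess", "env"]}
# Response shape expected back:
expected_response = {"stdout": "", "stderr": "", "exit_code": 0}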
@@ -6,8 +6,8 @@ from loguru import logger

from exo.shared.types.common import NodeId
from exo.shared.types.events import (
    BaseEvent,
    ChunkGenerated,
    IndexedEvent,
    InputChunkReceived,
    InstanceCreated,

@@ -32,10 +32,10 @@ from exo.shared.types.profiling import (
    NodeThunderboltInfo,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import BaseTask, TaskId, TaskStatus
from exo.shared.types.topology import Connection, RDMAConnection
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import BaseInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
    MacmonMetrics,

@@ -49,7 +49,7 @@ from exo.utils.info_gatherer.info_gatherer import (
)


def event_apply(event: BaseEvent, state: State) -> State:
    """Apply an event to state."""
    match event:
        case (

@@ -82,6 +82,10 @@ def event_apply(event: Event, state: State) -> State:
            return apply_topology_edge_created(event, state)
        case TopologyEdgeDeleted():
            return apply_topology_edge_deleted(event, state)
        case _:
            # Unknown event types from plugins are ignored
            logger.debug(f"Ignoring unknown event type: {type(event).__name__}")
            return state


def apply(state: State, event: IndexedEvent) -> State:

@@ -122,12 +126,12 @@ def apply_node_download_progress(event: NodeDownloadProgress, state: State) -> State:


def apply_task_created(event: TaskCreated, state: State) -> State:
    new_tasks: Mapping[TaskId, BaseTask] = {**state.tasks, event.task_id: event.task}
    return state.model_copy(update={"tasks": new_tasks})


def apply_task_deleted(event: TaskDeleted, state: State) -> State:
    new_tasks: Mapping[TaskId, BaseTask] = {
        tid: task for tid, task in state.tasks.items() if tid != event.task_id
    }
    return state.model_copy(update={"tasks": new_tasks})

@@ -146,7 +150,7 @@ def apply_task_status_updated(event: TaskStatusUpdated, state: State) -> State:
        update["error_message"] = None

    updated_task = state.tasks[event.task_id].model_copy(update=update)
    new_tasks: Mapping[TaskId, BaseTask] = {**state.tasks, event.task_id: updated_task}
    return state.model_copy(update={"tasks": new_tasks})

@@ -158,13 +162,13 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
    updated_task = state.tasks[event.task_id].model_copy(
        update={"error_type": event.error_type, "error_message": event.error_message}
    )
    new_tasks: Mapping[TaskId, BaseTask] = {**state.tasks, event.task_id: updated_task}
    return state.model_copy(update={"tasks": new_tasks})


def apply_instance_created(event: InstanceCreated, state: State) -> State:
    instance = event.instance
    new_instances: Mapping[InstanceId, BaseInstance] = {
        **state.instances,
        instance.instance_id: instance,
    }

@@ -172,7 +176,7 @@ def apply_instance_created(event: InstanceCreated, state: State) -> State:


def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
    new_instances: Mapping[InstanceId, BaseInstance] = {
        iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
    }
    return state.model_copy(update={"instances": new_instances})
@@ -1,15 +1,21 @@
import time
from collections.abc import Generator
from typing import Annotated, Any, Literal, cast

from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault

from exo.plugins.type_registry import instance_registry
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import (
    BaseInstance,
    Instance,
    InstanceId,
    InstanceMeta,
)
from exo.shared.types.worker.shards import Sharding

FinishReason = Literal[

@@ -200,6 +206,12 @@ class PlaceInstanceParams(BaseModel):
class CreateInstanceParams(BaseModel):
    instance: Instance

    @field_validator("instance", mode="before")
    @classmethod
    def validate_instance(cls, v: Any) -> BaseInstance:  # noqa: ANN401 # pyright: ignore[reportAny]
        """Validate instance using registry to handle both tagged and flat formats."""
        return cast(BaseInstance, instance_registry.deserialize(v))  # pyright: ignore[reportAny]


class PlacementPreview(BaseModel):
    model_id: ModelId
@@ -1,5 +1,14 @@
|
|||||||
from pydantic import Field
|
"""Command types for exo.
|
||||||
|
|
||||||
|
Commands are registered dynamically via the command_registry, allowing plugins
|
||||||
|
to add their own command types without modifying this file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
|
from exo.plugins.type_registry import command_registry
|
||||||
from exo.shared.models.model_cards import ModelCard
|
from exo.shared.models.model_cards import ModelCard
|
||||||
from exo.shared.types.api import (
|
from exo.shared.types.api import (
|
||||||
ChatCompletionTaskParams,
|
ChatCompletionTaskParams,
|
||||||
@@ -14,25 +23,32 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
|||||||
|
|
||||||
|
|
||||||
class BaseCommand(TaggedModel):
|
class BaseCommand(TaggedModel):
|
||||||
|
"""Base class for all commands."""
|
||||||
|
|
||||||
command_id: CommandId = Field(default_factory=CommandId)
|
command_id: CommandId = Field(default_factory=CommandId)
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class TestCommand(BaseCommand):
|
class TestCommand(BaseCommand):
|
||||||
__test__ = False
|
__test__ = False
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class ChatCompletion(BaseCommand):
|
class ChatCompletion(BaseCommand):
|
||||||
request_params: ChatCompletionTaskParams
|
request_params: ChatCompletionTaskParams
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class ImageGeneration(BaseCommand):
|
class ImageGeneration(BaseCommand):
|
||||||
request_params: ImageGenerationTaskParams
|
request_params: ImageGenerationTaskParams
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class ImageEdits(BaseCommand):
|
class ImageEdits(BaseCommand):
|
||||||
request_params: ImageEditsInternalParams
|
request_params: ImageEditsInternalParams
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class PlaceInstance(BaseCommand):
|
class PlaceInstance(BaseCommand):
|
||||||
model_card: ModelCard
|
model_card: ModelCard
|
||||||
sharding: Sharding
|
sharding: Sharding
|
||||||
@@ -40,28 +56,34 @@ class PlaceInstance(BaseCommand):
|
|||||||
min_nodes: int
|
min_nodes: int
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class CreateInstance(BaseCommand):
|
class CreateInstance(BaseCommand):
|
||||||
instance: Instance
|
instance: Instance
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class DeleteInstance(BaseCommand):
|
class DeleteInstance(BaseCommand):
|
||||||
instance_id: InstanceId
|
instance_id: InstanceId
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class TaskFinished(BaseCommand):
|
class TaskFinished(BaseCommand):
|
||||||
finished_command_id: CommandId
|
finished_command_id: CommandId
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class SendInputChunk(BaseCommand):
|
class SendInputChunk(BaseCommand):
|
||||||
"""Command to send an input image chunk (converted to event by master)."""
|
"""Command to send an input image chunk (converted to event by master)."""
|
||||||
|
|
||||||
chunk: InputImageChunk
|
chunk: InputImageChunk
|
||||||
|
|
||||||
|
|
||||||
|
@command_registry.register
|
||||||
class RequestEventLog(BaseCommand):
|
class RequestEventLog(BaseCommand):
|
||||||
since_idx: int
|
since_idx: int
|
||||||
|
|
||||||
|
|
||||||
|
# Union type for core commands - used by ForwarderCommand for network deserialization
|
||||||
Command = (
|
Command = (
|
||||||
TestCommand
|
TestCommand
|
||||||
| RequestEventLog
|
| RequestEventLog
|
||||||
@@ -77,5 +99,14 @@ Command = (
|
|||||||
|
|
||||||
|
|
||||||
class ForwarderCommand(CamelCaseModel):
|
class ForwarderCommand(CamelCaseModel):
|
||||||
|
"""Wrapper for commands that includes origin node."""
|
||||||
|
|
||||||
origin: NodeId
|
origin: NodeId
|
||||||
command: Command
|
command: BaseCommand
|
||||||
|
|
||||||
|
@field_validator("command", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_command(cls, v: Any) -> BaseCommand: # noqa: ANN401 # pyright: ignore[reportAny]
|
||||||
|
"""Validate command, using registry for plugin commands not in Command union."""
|
||||||
|
# First try the registry (handles both core and plugin commands)
|
||||||
|
return cast(BaseCommand, command_registry.deserialize(v)) # pyright: ignore[reportAny]
|
||||||
|
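
The `exo.plugins.type_registry` module itself is not shown in this compare view; for orientation, here is one plausible shape for a registry with the `register` / `deserialize` surface used above. Treat everything below as an assumption about the implementation, not the actual API:

from typing import Any, TypeVar

from pydantic import BaseModel

M = TypeVar("M", bound=BaseModel)


class TypeRegistry:
    """Hypothetical tag-to-class registry; the real one is not in this diff."""

    def __init__(self) -> None:
        self._by_tag: dict[str, type[BaseModel]] = {}

    def register(self, cls: type[M]) -> type[M]:
        # Used as a class decorator, e.g. @command_registry.register
        self._by_tag[cls.__name__] = cls
        return cls

    def deserialize(self, v: Any) -> BaseModel:
        if isinstance(v, BaseModel):
            return v  # already a validated model (in-process path)
        # Assumed tag key; the actual TaggedModel wire format may differ.
        cls = self._by_tag[v["type"]]
        return cls.model_validate(v)


command_registry = TypeRegistry()

With that shape, a plugin only has to import `command_registry` and decorate its own `BaseCommand` subclass; `ForwarderCommand.validate_command` then resolves it on receipt without the core `Command` union ever learning about it.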

@@ -1,13 +1,15 @@
 from datetime import datetime
+from typing import Any, cast

-from pydantic import Field
+from pydantic import Field, field_validator

+from exo.plugins.type_registry import event_registry, instance_registry, task_registry
 from exo.shared.topology import Connection
 from exo.shared.types.chunks import GenerationChunk, InputImageChunk
 from exo.shared.types.common import CommandId, Id, NodeId, SessionId
-from exo.shared.types.tasks import Task, TaskId, TaskStatus
+from exo.shared.types.tasks import BaseTask, TaskId, TaskStatus
 from exo.shared.types.worker.downloads import DownloadProgress
-from exo.shared.types.worker.instances import Instance, InstanceId
+from exo.shared.types.worker.instances import BaseInstance, InstanceId
 from exo.shared.types.worker.runners import RunnerId, RunnerStatus
 from exo.utils.info_gatherer.info_gatherer import GatheredInfo
 from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -25,36 +27,53 @@ class BaseEvent(TaggedModel):
     _master_time_stamp: None | datetime = None


+@event_registry.register
 class TestEvent(BaseEvent):
     __test__ = False


+@event_registry.register
 class TaskCreated(BaseEvent):
     task_id: TaskId
-    task: Task
+    task: BaseTask
+
+    @field_validator("task", mode="before")
+    @classmethod
+    def validate_task(cls, v: Any) -> BaseTask: # noqa: ANN401 # pyright: ignore[reportAny]
+        return cast(BaseTask, task_registry.deserialize(v)) # pyright: ignore[reportAny]


+@event_registry.register
 class TaskAcknowledged(BaseEvent):
     task_id: TaskId


+@event_registry.register
 class TaskDeleted(BaseEvent):
     task_id: TaskId


+@event_registry.register
 class TaskStatusUpdated(BaseEvent):
     task_id: TaskId
     task_status: TaskStatus


+@event_registry.register
 class TaskFailed(BaseEvent):
     task_id: TaskId
     error_type: str
     error_message: str


+@event_registry.register
 class InstanceCreated(BaseEvent):
-    instance: Instance
+    instance: BaseInstance
+
+    @field_validator("instance", mode="before")
+    @classmethod
+    def validate_instance(cls, v: Any) -> BaseInstance: # noqa: ANN401 # pyright: ignore[reportAny]
+        return cast(BaseInstance, instance_registry.deserialize(v)) # pyright: ignore[reportAny]

     def __eq__(self, other: object) -> bool:
         if isinstance(other, InstanceCreated):
@@ -63,52 +82,63 @@ class InstanceCreated(BaseEvent):
         return False


+@event_registry.register
 class InstanceDeleted(BaseEvent):
     instance_id: InstanceId


+@event_registry.register
 class RunnerStatusUpdated(BaseEvent):
     runner_id: RunnerId
     runner_status: RunnerStatus


+@event_registry.register
 class RunnerDeleted(BaseEvent):
     runner_id: RunnerId


+@event_registry.register
 class NodeTimedOut(BaseEvent):
     node_id: NodeId


 # TODO: bikeshed this name
+@event_registry.register
 class NodeGatheredInfo(BaseEvent):
     node_id: NodeId
     when: str # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
     info: GatheredInfo


+@event_registry.register
 class NodeDownloadProgress(BaseEvent):
     download_progress: DownloadProgress


+@event_registry.register
 class ChunkGenerated(BaseEvent):
     command_id: CommandId
     chunk: GenerationChunk


+@event_registry.register
 class InputChunkReceived(BaseEvent):
     command_id: CommandId
     chunk: InputImageChunk


+@event_registry.register
 class TopologyEdgeCreated(BaseEvent):
     conn: Connection


+@event_registry.register
 class TopologyEdgeDeleted(BaseEvent):
     conn: Connection


+# Union type for Pydantic validation - tries each type in order
 Event = (
     TestEvent
     | TaskCreated
@@ -134,7 +164,12 @@ class IndexedEvent(CamelCaseModel):
     """An event indexed by the master, with a globally unique index"""

     idx: int = Field(ge=0)
-    event: Event
+    event: BaseEvent
+
+    @field_validator("event", mode="before")
+    @classmethod
+    def validate_event(cls, v: Any) -> BaseEvent: # noqa: ANN401 # pyright: ignore[reportAny]
+        return cast(BaseEvent, event_registry.deserialize(v)) # pyright: ignore[reportAny]


 class ForwarderEvent(CamelCaseModel):
@@ -143,4 +178,9 @@ class ForwarderEvent(CamelCaseModel):
     origin_idx: int = Field(ge=0)
     origin: NodeId
     session: SessionId
-    event: Event
+    event: BaseEvent
+
+    @field_validator("event", mode="before")
+    @classmethod
+    def validate_event(cls, v: Any) -> BaseEvent: # noqa: ANN401 # pyright: ignore[reportAny]
+        return cast(BaseEvent, event_registry.deserialize(v)) # pyright: ignore[reportAny]
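
The pattern here is the same as in commands, with one consequence worth spelling out: `IndexedEvent.event` and `ForwarderEvent.event` are now annotated with `BaseEvent` rather than the closed `Event` union, because a union type freezes the set of acceptable subclasses at class-definition time. A self-contained sketch of the difference:

from pydantic import BaseModel, ValidationError


class BaseEvent(BaseModel):
    pass


class CoreEvent(BaseEvent):
    n: int = 0


class PluginEvent(BaseEvent):  # contributed later by a plugin
    name: str = "x"


class ClosedWrapper(BaseModel):
    event: CoreEvent  # stand-in for the closed `Event` union


class OpenWrapper(BaseModel):
    event: BaseEvent  # base-class annotation, resolved by the registry validator


try:
    ClosedWrapper(event=PluginEvent())
except ValidationError:
    print("closed union rejects the plugin event")

OpenWrapper(event=PluginEvent())  # accepted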

@@ -14,9 +14,9 @@ from exo.shared.types.profiling import (
     NodeThunderboltInfo,
     SystemPerformanceProfile,
 )
-from exo.shared.types.tasks import Task, TaskId
+from exo.shared.types.tasks import BaseTask, TaskId
 from exo.shared.types.worker.downloads import DownloadProgress
-from exo.shared.types.worker.instances import Instance, InstanceId
+from exo.shared.types.worker.instances import BaseInstance, InstanceId
 from exo.shared.types.worker.runners import RunnerId, RunnerStatus
 from exo.utils.pydantic_ext import CamelCaseModel

@@ -37,10 +37,10 @@ class State(CamelCaseModel):
         strict=True,
         arbitrary_types_allowed=True,
     )
-    instances: Mapping[InstanceId, Instance] = {}
+    instances: Mapping[InstanceId, BaseInstance] = {}
     runners: Mapping[RunnerId, RunnerStatus] = {}
     downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
-    tasks: Mapping[TaskId, Task] = {}
+    tasks: Mapping[TaskId, BaseTask] = {}
     last_seen: Mapping[NodeId, datetime] = {}
     topology: Topology = Field(default_factory=Topology)
     last_event_applied_idx: int = Field(default=-1, ge=-1)
@@ -52,6 +52,16 @@ class State(CamelCaseModel):
     node_network: Mapping[NodeId, NodeNetworkInfo] = {}
     node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}

+    @field_serializer("instances", mode="plain")
+    def _encode_instances(
+        self, value: Mapping[InstanceId, BaseInstance]
+    ) -> dict[str, Any]:
+        """Serialize instances with full subclass fields."""
+        return {
+            str(k): v.model_dump(by_alias=True, serialize_as_any=True)
+            for k, v in value.items()
+        }
+
     @field_serializer("topology", mode="plain")
     def _encode_topology(self, value: Topology) -> TopologySnapshot:
         return value.to_snapshot()
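
The new `_encode_instances` serializer exists because of a Pydantic v2 behavior that the `BaseInstance` annotation would otherwise trigger: by default a field is serialized against its declared type, so subclass-only fields are silently dropped. `serialize_as_any=True` switches to duck-typed serialization. A small demonstration:

from pydantic import BaseModel


class Base(BaseModel):
    a: int = 1


class Sub(Base):
    b: int = 2


class Holder(BaseModel):
    item: Base


h = Holder(item=Sub())
print(h.model_dump())                       # {'item': {'a': 1}} - 'b' dropped
print(h.model_dump(serialize_as_any=True))  # {'item': {'a': 1, 'b': 2}}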

@@ -2,6 +2,7 @@ from enum import Enum

 from pydantic import Field

+from exo.plugins.type_registry import task_registry
 from exo.shared.types.api import (
     ChatCompletionTaskParams,
     ImageEditsInternalParams,
@@ -32,26 +33,32 @@ class BaseTask(TaggedModel):
     instance_id: InstanceId


+@task_registry.register
 class CreateRunner(BaseTask): # emitted by Worker
     bound_instance: BoundInstance


+@task_registry.register
 class DownloadModel(BaseTask): # emitted by Worker
     shard_metadata: ShardMetadata


+@task_registry.register
 class LoadModel(BaseTask): # emitted by Worker
     pass


+@task_registry.register
 class ConnectToGroup(BaseTask): # emitted by Worker
     pass


+@task_registry.register
 class StartWarmup(BaseTask): # emitted by Worker
     pass


+@task_registry.register
 class ChatCompletion(BaseTask): # emitted by Master
     command_id: CommandId
     task_params: ChatCompletionTaskParams
@@ -60,6 +67,7 @@ class ChatCompletion(BaseTask): # emitted by Master
     error_message: str | None = Field(default=None)


+@task_registry.register
 class ImageGeneration(BaseTask): # emitted by Master
     command_id: CommandId
     task_params: ImageGenerationTaskParams
@@ -68,6 +76,7 @@ class ImageGeneration(BaseTask): # emitted by Master
     error_message: str | None = Field(default=None)


+@task_registry.register
 class ImageEdits(BaseTask): # emitted by Master
     command_id: CommandId
     task_params: ImageEditsInternalParams
@@ -76,10 +85,12 @@ class ImageEdits(BaseTask): # emitted by Master
     error_message: str | None = Field(default=None)


+@task_registry.register
 class Shutdown(BaseTask): # emitted by Worker
     runner_id: RunnerId


+# Union type for Pydantic validation - tries each type in order
 Task = (
     CreateRunner
     | DownloadModel

@@ -1,7 +1,15 @@
+"""Instance types for exo.
+
+Instances are registered dynamically via the instance_registry, allowing plugins
+to add their own instance types without modifying this file.
+"""
+
 from enum import Enum
+from typing import Any, cast

-from pydantic import model_validator
+from pydantic import field_validator, model_validator

+from exo.plugins.type_registry import instance_registry
 from exo.shared.types.common import Host, Id, NodeId
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
 from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -17,6 +25,8 @@ class InstanceMeta(str, Enum):


 class BaseInstance(TaggedModel):
+    """Base class for all instance types."""
+
     instance_id: InstanceId
     shard_assignments: ShardAssignments

@@ -24,25 +34,36 @@ class BaseInstance(TaggedModel):
         return self.shard_assignments.runner_to_shard.get(runner_id, None)


+@instance_registry.register
 class MlxRingInstance(BaseInstance):
     hosts_by_node: dict[NodeId, list[Host]]
     ephemeral_port: int


+@instance_registry.register
 class MlxJacclInstance(BaseInstance):
     jaccl_devices: list[list[str | None]]
     jaccl_coordinators: dict[NodeId, str]


-# TODO: Single node instance
+# Union type for Pydantic validation - tries each type in order
+# This is used by API endpoints (dashboard) which send flat format
 Instance = MlxRingInstance | MlxJacclInstance


 class BoundInstance(CamelCaseModel):
-    instance: Instance
+    """An instance bound to a specific runner on a specific node."""
+
+    instance: BaseInstance
     bound_runner_id: RunnerId
     bound_node_id: NodeId

+    @field_validator("instance", mode="before")
+    @classmethod
+    def validate_instance(cls, v: Any) -> BaseInstance: # noqa: ANN401 # pyright: ignore[reportAny]
+        """Validate instance using registry to handle both tagged and flat formats."""
+        return cast(BaseInstance, instance_registry.deserialize(v)) # pyright: ignore[reportAny]
+
     @property
     def bound_shard(self) -> ShardMetadata:
         shard = self.instance.shard(self.bound_runner_id)
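
Put together, these hooks are what let a plugin ship a new instance type without touching the core union. A hedged sketch of what that registration could look like; `instance_registry`, `BaseInstance`, and the decorator come from this diff, while the plugin class and its field are hypothetical:

from exo.plugins.type_registry import instance_registry
from exo.shared.types.worker.instances import BaseInstance


@instance_registry.register
class LlamaCppInstance(BaseInstance):  # hypothetical plugin instance type
    server_port: int


# BoundInstance.validate_instance and State._encode_instances now handle
# this type end to end, even though it is absent from the `Instance` union.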

@@ -178,6 +178,11 @@ def mlx_distributed_init(
             os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
             group = mx.distributed.init(backend="jaccl", strict=True)

+        case _:
+            raise ValueError(
+                f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
+            )
+
     logger.info(f"Rank {rank} mlx distributed initialization complete")

     return group

@@ -13,6 +13,7 @@ from exo.shared.types.api import ImageEditsInternalParams
 from exo.shared.types.commands import ForwarderCommand, RequestEventLog
 from exo.shared.types.common import CommandId, NodeId, SessionId
 from exo.shared.types.events import (
+    BaseEvent,
     Event,
     EventId,
     ForwarderEvent,
@@ -28,11 +29,11 @@ from exo.shared.types.events import (
 from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.state import State
 from exo.shared.types.tasks import (
+    BaseTask,
     CreateRunner,
     DownloadModel,
     ImageEdits,
     Shutdown,
-    Task,
     TaskStatus,
 )
 from exo.shared.types.topology import Connection, SocketConnection
@@ -81,7 +82,7 @@ class Worker:
         self.local_event_index = 0
         self.command_sender = command_sender
         self.connection_message_receiver = connection_message_receiver
-        self.event_buffer = OrderedBuffer[Event]()
+        self.event_buffer = OrderedBuffer[BaseEvent]()
         self.out_for_delivery: dict[EventId, ForwarderEvent] = {}

         self.state: State = State()
@@ -179,7 +180,7 @@ class Worker:
         while True:
             await anyio.sleep(0.1)
             # 3. based on the updated state, we plan & execute an operation.
-            task: Task | None = plan(
+            task: BaseTask | None = plan(
                 self.node_id,
                 self.runners,
                 self.download_status,
@@ -298,7 +299,7 @@ class Worker:
     def shutdown(self):
         self._tg.cancel_scope.cancel()

-    def _task_to_runner_id(self, task: Task):
+    def _task_to_runner_id(self, task: BaseTask):
         instance = self.state.instances[task.instance_id]
         return instance.shard_assignments.node_to_runner[self.node_id]


@@ -5,6 +5,7 @@ from collections.abc import Mapping, Sequence
 from exo.shared.models.model_cards import ModelId
 from exo.shared.types.common import CommandId, NodeId
 from exo.shared.types.tasks import (
+    BaseTask,
     ChatCompletion,
     ConnectToGroup,
     CreateRunner,
@@ -14,7 +15,6 @@ from exo.shared.types.tasks import (
     LoadModel,
     Shutdown,
     StartWarmup,
-    Task,
     TaskId,
     TaskStatus,
 )
@@ -23,7 +23,11 @@ from exo.shared.types.worker.downloads import (
     DownloadOngoing,
     DownloadProgress,
 )
-from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
+from exo.shared.types.worker.instances import (
+    BaseInstance,
+    BoundInstance,
+    InstanceId,
+)
 from exo.shared.types.worker.runners import (
     RunnerConnected,
     RunnerConnecting,
@@ -48,12 +52,22 @@ def plan(
     download_status: Mapping[ModelId, DownloadProgress],
     # gdls is not expected to be fresh
     global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
-    instances: Mapping[InstanceId, Instance],
+    instances: Mapping[InstanceId, BaseInstance],
     all_runners: Mapping[RunnerId, RunnerStatus], # all global
-    tasks: Mapping[TaskId, Task],
+    tasks: Mapping[TaskId, BaseTask],
     input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
     input_chunk_counts: Mapping[CommandId, int] | None = None,
-) -> Task | None:
+) -> BaseTask | None:
+    from exo.plugins.registry import PluginRegistry
+
+    registry = PluginRegistry.get()
+
+    # Check plugin tasks first
+    for plugin in registry.all_plugins():
+        task = plugin.plan_task(runners, instances)
+        if task is not None:
+            return task
+
     # Python short circuiting OR logic should evaluate these sequentially.
     return (
         _kill_runner(runners, all_runners, instances)
@@ -69,7 +83,7 @@ def plan(
 def _kill_runner(
     runners: Mapping[RunnerId, RunnerSupervisor],
     all_runners: Mapping[RunnerId, RunnerStatus],
-    instances: Mapping[InstanceId, Instance],
+    instances: Mapping[InstanceId, BaseInstance],
 ) -> Shutdown | None:
     for runner in runners.values():
         runner_id = runner.bound_instance.bound_runner_id
@@ -92,7 +106,7 @@ def _kill_runner(
 def _create_runner(
     node_id: NodeId,
     runners: Mapping[RunnerId, RunnerSupervisor],
-    instances: Mapping[InstanceId, Instance],
+    instances: Mapping[InstanceId, BaseInstance],
 ) -> CreateRunner | None:
     for instance in instances.values():
         runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
@@ -117,7 +131,18 @@ def _model_needs_download(
     runners: Mapping[RunnerId, RunnerSupervisor],
     download_status: Mapping[ModelId, DownloadProgress],
 ) -> DownloadModel | None:
+    from exo.plugins.registry import PluginRegistry
+
+    registry = PluginRegistry.get()
+
     for runner in runners.values():
+        instance = runner.bound_instance.instance
+
+        # Check if any plugin wants to skip download for this instance
+        plugin = registry.get_plugin_for_instance(instance)
+        if plugin is not None and plugin.should_skip_download(instance):
+            continue
+
         model_id = runner.bound_instance.bound_shard.model_card.model_id
         if isinstance(runner.status, RunnerIdle) and (
             model_id not in download_status
@@ -264,10 +289,10 @@ def _ready_to_warmup(

 def _pending_tasks(
     runners: Mapping[RunnerId, RunnerSupervisor],
-    tasks: Mapping[TaskId, Task],
+    tasks: Mapping[TaskId, BaseTask],
     all_runners: Mapping[RunnerId, RunnerStatus],
     input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
-) -> Task | None:
+) -> BaseTask | None:
     for task in tasks.values():
         # for now, just forward chat completions
         # TODO(ciaran): do this better!
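
`plan` and `_model_needs_download` now give plugins the first word. The hook names (`all_plugins`, `plan_task`, `get_plugin_for_instance`, `should_skip_download`) are taken from the calls above; the class shape wrapping them is an assumption, sketched for orientation:

class MyPlugin:  # hypothetical; only the hook names are from the diff
    def plan_task(self, runners, instances):
        """Return a BaseTask to schedule next, or None to defer to core logic."""
        return None

    def should_skip_download(self, instance) -> bool:
        """True when this instance type has no local weights to fetch,
        e.g. an instance that proxies to a remote inference server."""
        return False

Note the ordering: `plan` returns the first non-None plugin task before trying `_kill_runner` and the other built-in checks, so plugin tasks get strict priority over the core planner.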

@@ -4,7 +4,10 @@ import loguru

 from exo.shared.types.events import Event, RunnerStatusUpdated
 from exo.shared.types.tasks import Task
-from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
+from exo.shared.types.worker.instances import (
+    BoundInstance,
+    MlxJacclInstance,
+)
 from exo.shared.types.worker.runners import RunnerFailed
 from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender

@@ -17,6 +20,7 @@ def entrypoint(
     task_receiver: MpReceiver[Task],
     _logger: "loguru.Logger",
 ) -> None:
+    # Set FAST_SYNCH based on env var or JACCL device count
     fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
     if fast_synch_override == "on" or (
         fast_synch_override != "off"
@@ -34,11 +38,26 @@ def entrypoint(

     logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")

-    # Import main after setting global logger - this lets us just import logger from this module
+    # Route based on instance type (plugins or default MLX)
     try:
-        from exo.worker.runner.runner import main
+        from exo.plugins.registry import PluginRegistry, discover_plugins

-        main(bound_instance, event_sender, task_receiver)
+        # Discover plugins in subprocess (they aren't inherited from main process)
+        discover_plugins()
+
+        registry = PluginRegistry.get()
+        instance = bound_instance.instance
+
+        # Check if a plugin handles this instance type
+        plugin = registry.get_plugin_for_instance(instance)
+        if plugin is not None:
+            # Delegate to plugin runner
+            plugin.create_runner(bound_instance, event_sender, task_receiver)
+        else:
+            # MLX runner (default)
+            from exo.worker.runner.runner import main
+
+            main(bound_instance, event_sender, task_receiver)
     except ClosedResourceError:
         logger.warning("Runner communication closed unexpectedly")
     except Exception as e:
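
On the runner side, the subprocess re-runs plugin discovery because a spawned process does not inherit the parent's registrations, then delegates to `plugin.create_runner(bound_instance, event_sender, task_receiver)`. A hypothetical sketch of that hook; the signature mirrors the call in `entrypoint`, while the body and the receiver API are assumptions:

class MyPlugin:  # hypothetical plugin runner hook
    def create_runner(self, bound_instance, event_sender, task_receiver) -> None:
        # Runs inside the spawned runner process: consume tasks from the
        # worker, run them against the plugin's backend, emit events back.
        while True:
            task = task_receiver.recv()  # assumed MpReceiver method
            ...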

@@ -20,7 +20,7 @@ from exo.shared.types.events import (
     TaskAcknowledged,
     TaskStatusUpdated,
 )
-from exo.shared.types.tasks import Task, TaskId, TaskStatus
+from exo.shared.types.tasks import BaseTask, TaskId, TaskStatus
 from exo.shared.types.worker.instances import BoundInstance
 from exo.shared.types.worker.runners import (
     RunnerConnecting,
@@ -47,7 +47,7 @@ class RunnerSupervisor:
     runner_process: Process
     initialize_timeout: float
     _ev_recv: MpReceiver[Event]
-    _task_sender: MpSender[Task]
+    _task_sender: MpSender[BaseTask]
     _event_sender: Sender[Event]
     _tg: TaskGroup | None = field(default=None, init=False)
     status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
@@ -64,7 +64,7 @@ class RunnerSupervisor:
     ) -> Self:
         ev_send, ev_recv = mp_channel[Event]()
         # A task is kind of a runner command
-        task_sender, task_recv = mp_channel[Task]()
+        task_sender, task_recv = mp_channel[BaseTask]()

         runner_process = Process(
             target=entrypoint,
@@ -126,7 +126,7 @@ class RunnerSupervisor:
         assert self._tg
         self._tg.cancel_scope.cancel()

-    async def start_task(self, task: Task):
+    async def start_task(self, task: BaseTask):
         if task.task_id in self.completed:
             logger.info(
                 f"Skipping invalid task {task} as it has already been completed"