Compare commits

..

13 Commits

Author SHA1 Message Date
Sami Khan
cd630dea43 formatting 2026-01-20 07:54:52 +05:00
Sami Khan
e55dae5ce8 code quality 2026-01-20 07:47:14 +05:00
Sami Khan
302c43afd5 Merge main into sami/flash 2026-01-20 07:31:22 +05:00
Sami Khan
2cf59e2322 nix flake check 2026-01-20 07:21:41 +05:00
Sami Khan
e506c7d65c exo plugins 2026-01-20 06:53:43 +05:00
Sami Khan
c1fa2ddeaf SLURM compatible commands 2026-01-20 06:53:43 +05:00
Sami Khan
37c5a2a246 Merge branch 'main' into sami/flash 2026-01-15 08:57:36 +05:00
Sami Khan
4d7f03834a deleted separate server 2026-01-15 08:50:45 +05:00
Sami Khan
bdb9fbc8c0 Merge branch 'main' into sami/flash 2026-01-14 08:10:51 +05:00
Sami Khan
8c7180810c type checking 2026-01-14 07:15:45 +05:00
Sami Khan
318c6e000b code cleanup 2026-01-14 04:56:59 +05:00
Sami Khan
2d45544da0 use rsh server instead of ssh 2026-01-13 02:46:25 +05:00
Sami Khan
7cbafa768a flash+exo 2026-01-12 10:26:16 +05:00
59 changed files with 3780 additions and 1404 deletions

View File

@@ -863,6 +863,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -902,6 +903,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,6 +1520,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1527,6 +1530,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1939,6 +1943,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2646,6 +2651,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2833,6 +2839,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2977,6 +2984,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2998,6 +3006,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -434,8 +434,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;

View File

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_card ?? shardData.modelCard;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;

View File

@@ -17,8 +17,8 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx==0.30.1; sys_platform == 'darwin'",
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
@@ -30,6 +30,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-rsh = "exo.rsh.client:main"
# dependencies only required for development
[dependency-groups]

32
src/exo/cli/__init__.py Normal file
View File

@@ -0,0 +1,32 @@
"""Exo CLI - SLURM-compatible job management commands."""
def run_subcommand(command: str, args: list[str]) -> int:
"""Route to the appropriate subcommand handler.
Args:
command: The subcommand name (sbatch, squeue, scancel, salloc)
args: Command line arguments for the subcommand
Returns:
Exit code from the subcommand
"""
if command == "sbatch":
from exo.cli.sbatch import main
return main(args)
elif command == "squeue":
from exo.cli.squeue import main
return main(args)
elif command == "scancel":
from exo.cli.scancel import main
return main(args)
elif command == "salloc":
from exo.cli.salloc import main
return main(args)
else:
print(f"Unknown subcommand: {command}")
return 1

118
src/exo/cli/common.py Normal file
View File

@@ -0,0 +1,118 @@
"""Common utilities for Exo CLI commands."""
import json
import os
import urllib.request
from typing import Any
from urllib.error import HTTPError, URLError
# Default API endpoint
DEFAULT_API_HOST = "localhost"
DEFAULT_API_PORT = 52415


def get_api_base() -> str:
    """Return the base URL of the Exo API.

    The EXO_API_HOST / EXO_API_PORT environment variables override the
    module defaults when set.
    """
    environ = os.environ
    host = environ.get("EXO_API_HOST", DEFAULT_API_HOST)
    port = environ.get("EXO_API_PORT", str(DEFAULT_API_PORT))
    return "http://" + host + ":" + port
def api_request(
    method: str,
    path: str,
    data: dict[str, Any] | None = None,
) -> dict[str, Any] | list[Any]:
    """Issue an HTTP request against the Exo API and return parsed JSON.

    Args:
        method: HTTP method (GET, POST, DELETE, etc.)
        path: API path (e.g., "/flash/instances")
        data: Optional JSON data for POST/PUT requests

    Returns:
        Parsed JSON response ({} when the body is empty)

    Raises:
        SystemExit: On connection or HTTP errors (after printing a
            diagnostic message)
    """
    payload = json.dumps(data).encode("utf-8") if data is not None else None
    req = urllib.request.Request(
        f"{get_api_base()}{path}",
        data=payload,
        method=method,
    )
    req.add_header("Content-Type", "application/json")
    try:
        with urllib.request.urlopen(req, timeout=30) as response:  # pyright: ignore[reportAny]
            body: str = response.read().decode("utf-8")  # pyright: ignore[reportAny]
        return json.loads(body) if body else {}  # pyright: ignore[reportAny]
    except HTTPError as e:
        error_body = e.read().decode("utf-8") if e.fp else ""
        print(f"API error: {e.code} {e.reason}")
        if error_body:
            try:
                error_json: dict[str, str] = json.loads(error_body)  # pyright: ignore[reportAny]
            except json.JSONDecodeError:
                print(f" {error_body}")
            else:
                if "detail" in error_json:
                    print(f" {error_json['detail']}")
        raise SystemExit(1) from None
    except URLError as e:
        print(f"Connection error: {e.reason}")
        print(f"Is Exo running at {get_api_base()}?")
        raise SystemExit(1) from None
def truncate_id(instance_id: str, length: int = 8) -> str:
    """Shorten a UUID for compact display.

    Args:
        instance_id: Full UUID string
        length: Number of characters to keep

    Returns:
        The first *length* characters after removing all hyphens
    """
    compact = "".join(ch for ch in instance_id if ch != "-")
    return compact[:length]
def format_table(headers: list[str], rows: list[list[str]]) -> str:
    """Render headers and rows as aligned plain-text columns.

    Args:
        headers: Column headers
        rows: List of rows; short rows are padded with empty cells and
            long rows are truncated to the header count

    Returns:
        The table as a newline-joined string. With no rows, only the
        header line (fixed 10-char columns) is returned.
    """
    if not rows:
        return " ".join(f"{h:<10}" for h in headers)
    # Each column is as wide as its widest cell (header included).
    widths = [len(h) for h in headers]
    for row in rows:
        for col, cell in enumerate(row):
            if col < len(widths) and len(cell) > widths[col]:
                widths[col] = len(cell)
    fmt = " ".join(f"{{:<{w}}}" for w in widths)
    n_cols = len(headers)
    lines = [fmt.format(*headers)]
    for row in rows:
        # Normalise every row to exactly n_cols cells.
        cells = (row + [""] * n_cols)[:n_cols]
        lines.append(fmt.format(*cells))
    return "\n".join(lines)

100
src/exo/cli/salloc.py Normal file
View File

@@ -0,0 +1,100 @@
"""salloc - Allocate nodes for interactive use.
Usage:
exo salloc [options] [-- command [args...]]
Options:
-N, --nodes N Number of nodes to allocate (default: 1)
--hosts HOSTS Comma-separated list of hostnames
If a command is provided after --, it will be executed with
SLURM-like environment variables set:
SLURM_JOB_NODELIST - Comma-separated list of allocated nodes
SLURM_NNODES - Number of allocated nodes
Examples:
exo salloc --nodes=2 --hosts=node1,node2 -- mpirun ./my_program
exo salloc --hosts=localhost -- bash
"""
import argparse
import os
import subprocess
import sys
def main(args: list[str]) -> int:
    """Entry point for ``exo salloc``.

    Parses allocation options, exports SLURM-style environment
    variables, then either runs the command given after ``--`` or
    starts an interactive shell.

    Args:
        args: Command-line arguments (without the program name)

    Returns:
        1 on bad arguments, otherwise the exit code of the launched
        command or shell
    """
    # Everything after a literal "--" is the command to execute.
    if "--" in args:
        split_at = args.index("--")
        salloc_args, cmd_args = args[:split_at], args[split_at + 1 :]
    else:
        salloc_args, cmd_args = args, []

    parser = argparse.ArgumentParser(
        prog="exo salloc",
        description="Allocate nodes for interactive use",
    )
    parser.add_argument(
        "-N",
        "--nodes",
        type=int,
        default=1,
        help="Number of nodes to allocate (default: 1)",
    )
    parser.add_argument(
        "--hosts",
        help="Comma-separated list of hostnames (required)",
    )
    parsed = parser.parse_args(salloc_args)
    nodes: int = parsed.nodes  # pyright: ignore[reportAny]
    hosts: str | None = parsed.hosts  # pyright: ignore[reportAny]

    # Hosts must be given explicitly; the topology has no hostnames.
    if not hosts:
        print("Error: --hosts is required (e.g., --hosts=node1,node2)", file=sys.stderr)
        print(" The Exo topology doesn't expose hostnames.", file=sys.stderr)
        return 1

    host_list = [h.strip() for h in hosts.split(",") if h.strip()]
    if len(host_list) < nodes:
        print(
            f"Error: Requested {nodes} nodes but only {len(host_list)} hosts provided",
            file=sys.stderr,
        )
        return 1

    # Take the first N hosts and expose them SLURM-style.
    nodelist = ",".join(host_list[:nodes])
    env = dict(os.environ)
    env["SLURM_JOB_NODELIST"] = nodelist
    env["SLURM_NNODES"] = str(nodes)

    print(f"salloc: Granted job allocation on {nodes} node(s)")
    print(f"salloc: Nodes: {nodelist}")

    if cmd_args:
        print(f"salloc: Running: {' '.join(cmd_args)}")
        return subprocess.run(cmd_args, env=env).returncode

    # No command: hand the user an interactive shell.
    shell = os.environ.get("SHELL", "/bin/bash")
    print(f"salloc: Starting shell {shell}")
    print("salloc: Use 'exit' to release allocation")
    return subprocess.run([shell], env=env).returncode


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))

233
src/exo/cli/sbatch.py Normal file
View File

@@ -0,0 +1,233 @@
"""sbatch - Submit a batch job to Exo.
Usage:
exo sbatch [options] <script|executable>
exo sbatch --job-name=NAME --nodes=N <executable>
Options:
-J, --job-name NAME Job name
-N, --nodes N Number of nodes (default: 1)
--ntasks-per-node N Tasks per node (default: 1)
-D, --chdir DIR Working directory
--hosts HOSTS Comma-separated list of hostnames
Job scripts can contain #SBATCH directives:
#!/bin/bash
#SBATCH --job-name=Sod2D
#SBATCH --nodes=2
#SBATCH --chdir=/path/to/workdir
/path/to/flash4
"""
import argparse
import os
import re
import sys
from exo.cli.common import api_request, truncate_id
def parse_job_script(script_path: str) -> tuple[dict[str, str], str | None]:
"""Parse a job script for #SBATCH directives and executable.
Args:
script_path: Path to the job script
Returns:
Tuple of (directives dict, executable path or None)
"""
directives: dict[str, str] = {}
executable: str | None = None
with open(script_path, "r") as f:
for line in f:
line = line.strip()
# Parse #SBATCH directives
if line.startswith("#SBATCH"):
# Handle both --option=value and --option value formats
match = re.match(r"#SBATCH\s+(-\w|--[\w-]+)(?:=|\s+)(.+)", line)
if match:
opt, val = match.groups()
directives[opt.lstrip("-")] = val.strip()
continue
# Skip comments and empty lines
if line.startswith("#") or not line:
continue
# First non-comment, non-directive line is the executable
if executable is None:
# Handle lines like "/path/to/flash4" or "srun /path/to/flash4"
parts = line.split()
if parts:
# Skip srun/mpirun prefixes if present
for part in parts:
if not part.startswith("-") and "/" in part:
executable = part
break
if executable is None and parts:
executable = parts[-1] # Last token
return directives, executable
def _classify_input(script_path: str) -> tuple[dict[str, str], str | None]:
    """Decide whether *script_path* is a job script or an executable.

    Returns:
        (directives, executable): directives parsed from #SBATCH lines
        when the input is a text job script (empty otherwise), and the
        program to run (None only when a job script contains no
        command line).
    """
    if not os.path.isfile(script_path):
        # Not an existing file - treat the argument as an executable path.
        return {}, script_path
    # Binary files typically contain NUL bytes; those are executables.
    is_binary = False
    try:
        with open(script_path, "rb") as f:
            is_binary = b"\x00" in f.read(512)
    except OSError:
        pass
    if is_binary:
        return {}, script_path
    try:
        with open(script_path, "r") as f:
            first_line = f.readline()
            f.seek(0)
            content = f.read(1024)
    except UnicodeDecodeError:
        # Cannot be read as text - treat as a binary executable.
        return {}, script_path
    if first_line.startswith("#!") or "#SBATCH" in content:
        # It's a job script - parse its directives.
        return parse_job_script(script_path)
    # Plain text without shebang/directives: treat as an executable.
    return {}, script_path


def main(args: list[str]) -> int:
    """Entry point for ``exo sbatch``: submit a batch job to Exo.

    Args:
        args: Command-line arguments (without the program name)

    Returns:
        0 on successful submission, 1 on error. Connection and HTTP
        failures exit via SystemExit raised by api_request.
    """
    parser = argparse.ArgumentParser(
        prog="exo sbatch",
        description="Submit a batch job to Exo",
    )
    parser.add_argument(
        "script",
        help="Job script or executable path",
    )
    parser.add_argument(
        "-J",
        "--job-name",
        dest="job_name",
        help="Job name",
    )
    # Numeric defaults are None so an explicitly supplied value can be
    # distinguished from "not given": this lets `--nodes=1` override a
    # script directive (the old `!= 1` check could not).
    parser.add_argument(
        "-N",
        "--nodes",
        type=int,
        default=None,
        help="Number of nodes (default: 1)",
    )
    parser.add_argument(
        "--ntasks-per-node",
        type=int,
        default=None,
        help="Tasks per node (default: 1)",
    )
    parser.add_argument(
        "-D",
        "--chdir",
        help="Working directory",
    )
    parser.add_argument(
        "--hosts",
        help="Comma-separated list of hostnames",
    )
    parsed = parser.parse_args(args)

    # Extract typed values from the namespace.
    script_path: str = parsed.script  # pyright: ignore[reportAny]
    arg_job_name: str | None = parsed.job_name  # pyright: ignore[reportAny]
    arg_nodes: int | None = parsed.nodes  # pyright: ignore[reportAny]
    arg_ntasks: int | None = parsed.ntasks_per_node  # pyright: ignore[reportAny]
    arg_chdir: str | None = parsed.chdir  # pyright: ignore[reportAny]
    arg_hosts: str | None = parsed.hosts  # pyright: ignore[reportAny]

    directives, executable = _classify_input(script_path)
    if executable is None:
        print("Error: No executable found in job script", file=sys.stderr)
        return 1

    # Build job parameters - CLI args override script directives.
    job_name = arg_job_name or directives.get("job-name") or directives.get("J")
    if not job_name:
        # Generate a name from the executable file name.
        job_name = os.path.basename(executable).replace(".", "_")

    if arg_nodes is not None:
        nodes = arg_nodes
    else:
        # "-N" (short form) wins over "--nodes" when both appear.
        directive_nodes = directives.get("N") or directives.get("nodes")
        nodes = int(directive_nodes) if directive_nodes else 1

    if arg_ntasks is not None:
        ntasks = arg_ntasks
    else:
        directive_ntasks = directives.get("ntasks-per-node")
        ntasks = int(directive_ntasks) if directive_ntasks else 1

    workdir = arg_chdir or directives.get("chdir") or directives.get("D") or os.getcwd()
    hosts = arg_hosts or directives.get("hosts") or ""

    # Resolve the executable to an absolute path relative to workdir.
    if not os.path.isabs(executable):
        executable = os.path.abspath(os.path.join(workdir, executable))

    # Submit job via API using query parameters.
    from urllib.parse import urlencode

    params = {
        "simulation_name": job_name,
        "flash_executable_path": executable,
        "parameter_file_path": "",  # FLASH par file - use default
        "working_directory": workdir,
        "ranks_per_node": str(ntasks),
        "min_nodes": str(nodes),
        "hosts": hosts,
    }
    result = api_request("POST", f"/flash/launch?{urlencode(params)}")

    # Print job submission confirmation.
    if isinstance(result, dict) and result.get("instance_id") is not None:
        job_id = truncate_id(str(result["instance_id"]))  # pyright: ignore[reportAny]
        print(f"Submitted batch job {job_id}")
    else:
        # Instance created asynchronously - user should check squeue.
        print("Job submitted successfully")
        print("Use 'exo squeue' to view job ID")
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))

95
src/exo/cli/scancel.py Normal file
View File

@@ -0,0 +1,95 @@
"""scancel - Cancel jobs in the Exo queue.
Usage:
exo scancel <jobid> [<jobid>...]
Arguments:
jobid Job ID (or prefix) to cancel. Can specify multiple.
Examples:
exo scancel abc123 # Cancel job starting with abc123
exo scancel abc123 def456 # Cancel multiple jobs
"""
import argparse
import sys
from typing import Any, cast
from exo.cli.common import api_request, truncate_id
def main(args: list[str]) -> int:
    """Entry point for ``exo scancel``: cancel one or more jobs.

    Args:
        args: Command-line arguments (job IDs or unique ID prefixes)

    Returns:
        1 when every requested cancellation failed, 0 otherwise
    """
    parser = argparse.ArgumentParser(
        prog="exo scancel",
        description="Cancel jobs in the Exo queue",
    )
    parser.add_argument(
        "jobids",
        nargs="+",
        help="Job ID(s) to cancel",
    )
    parsed = parser.parse_args(args)
    jobids: list[str] = parsed.jobids  # pyright: ignore[reportAny]

    # Fetch current jobs so partial IDs can be resolved to full UUIDs.
    result = api_request("GET", "/flash/instances")
    if isinstance(result, list):
        instances = cast(list[dict[str, Any]], result)
    else:
        instances = cast(list[dict[str, Any]], result.get("instances", []))

    # Map normalized (lowercase, hyphen-free) IDs and every prefix of
    # length >= 4 to the full ID.
    id_map: dict[str, str] = {}
    for inst in instances:
        iid = inst.get("instance_id", "")  # pyright: ignore[reportAny]
        full = str(iid) if iid else ""  # pyright: ignore[reportAny]
        if not full:
            continue
        normalized = full.replace("-", "").lower()
        id_map[normalized] = full
        for length in range(4, len(normalized) + 1):
            prefix = normalized[:length]
            # First job wins on a prefix collision (insertion order).
            if prefix not in id_map:
                id_map[prefix] = full

    cancelled = 0
    errors = 0
    for jobid in jobids:
        search = jobid.lower().replace("-", "")
        full_id = id_map.get(search)
        if not full_id:
            # Prefix search fallback. Deduplicate into a set of full
            # IDs: one job is reachable through many stored prefixes,
            # and the previous list-based match wrongly reported a
            # single job as "ambiguous".
            matches = {fid for key, fid in id_map.items() if key.startswith(search)}
            if len(matches) == 1:
                full_id = next(iter(matches))
            elif len(matches) > 1:
                print(f"Ambiguous job ID: {jobid} matches multiple jobs")
                errors += 1
                continue
            else:
                print(f"Job not found: {jobid}")
                errors += 1
                continue
        # Cancel the job; api_request raises SystemExit on failure.
        try:
            api_request("DELETE", f"/flash/{full_id}")
            print(f"Job {truncate_id(full_id)} cancelled")
            cancelled += 1
        except SystemExit:
            print(f"Failed to cancel job {truncate_id(full_id)}")
            errors += 1

    return 1 if errors > 0 and cancelled == 0 else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))

165
src/exo/cli/squeue.py Normal file
View File

@@ -0,0 +1,165 @@
"""squeue - View the Exo job queue.
Usage:
exo squeue [options]
Options:
-l, --long Show detailed output
-j, --job ID Show only this job
Output columns:
JOBID - Job identifier (truncated UUID)
NAME - Job name
NODES - Number of nodes
STATE - Job state (PENDING, RUNNING, FAILED, etc.)
"""
import argparse
import sys
from typing import Any, cast
from exo.cli.common import api_request, format_table, truncate_id
# Translate Exo runner status names into SLURM-style job states.
STATUS_MAP: dict[str, str] = {
    "RunnerIdle": "PENDING",
    "RunnerConnecting": "CONFIGURING",
    "RunnerConnected": "CONFIGURING",
    "RunnerLoading": "CONFIGURING",
    "RunnerLoaded": "CONFIGURING",
    "RunnerWarmingUp": "CONFIGURING",
    "RunnerReady": "COMPLETING",
    "RunnerRunning": "RUNNING",
    "RunnerShuttingDown": "COMPLETING",
    "RunnerShutdown": "COMPLETED",
    "RunnerFailed": "FAILED",
}


def get_job_state(runner_statuses: dict[str, Any]) -> str:
    """Collapse per-runner statuses into a single SLURM-like job state.

    Args:
        runner_statuses: Mapping of runner ID to status. Each status may
            be a dict carrying a "type" key (discriminated union), a
            plain string, or anything else (treated as idle).

    Returns:
        One of FAILED, RUNNING, CONFIGURING, COMPLETING, COMPLETED,
        PENDING or UNKNOWN, picked in that priority order.
    """
    if not runner_statuses:
        return "PENDING"
    observed: set[str] = set()
    for raw in runner_statuses.values():  # pyright: ignore[reportAny]
        if isinstance(raw, dict):
            # Discriminated-union form: {"type": "RunnerRunning", ...}
            tag = raw.get("type", "RunnerIdle")  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            name = str(tag) if tag else "RunnerIdle"  # pyright: ignore[reportUnknownArgumentType]
        elif isinstance(raw, str):
            name = raw
        else:
            name = "RunnerIdle"
        # Serialized statuses may carry a trailing "()" - strip it.
        name = name.removesuffix("()")
        observed.add(STATUS_MAP.get(name, "UNKNOWN"))
    # Report the most significant state present.
    for state in ("FAILED", "RUNNING", "CONFIGURING", "COMPLETING", "COMPLETED", "PENDING"):
        if state in observed:
            return state
    return "UNKNOWN"
def _row_fields(inst: dict[str, Any]) -> tuple[str, str, str, str, str, str]:
    """Extract display fields from one instance record.

    Returns:
        (instance_id, name, nodes, ranks, state, workdir) as strings,
        with name capped at 15 chars and long workdirs shortened to a
        "...suffix" form.
    """
    iid_val = inst.get("instance_id", "")  # pyright: ignore[reportAny]
    instance_id = str(iid_val) if iid_val else ""  # pyright: ignore[reportAny]
    name_val = inst.get("simulation_name", "")  # pyright: ignore[reportAny]
    name = (str(name_val) if name_val else "")[:15]  # pyright: ignore[reportAny]
    runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
    nodes = str(len(runner_statuses))
    ranks_val = inst.get("total_ranks", 0)  # pyright: ignore[reportAny]
    ranks = str(ranks_val) if ranks_val else "0"  # pyright: ignore[reportAny]
    state = get_job_state(runner_statuses)
    workdir_val = inst.get("working_directory", "")  # pyright: ignore[reportAny]
    workdir = str(workdir_val) if workdir_val else ""  # pyright: ignore[reportAny]
    # Truncate workdir for display.
    if len(workdir) > 30:
        workdir = "..." + workdir[-27:]
    return instance_id, name, nodes, ranks, state, workdir


def main(args: list[str]) -> int:
    """Entry point for ``exo squeue``: list jobs in the queue.

    Args:
        args: Command-line arguments (without the program name)

    Returns:
        0 always; connection errors exit via SystemExit from
        api_request
    """
    parser = argparse.ArgumentParser(
        prog="exo squeue",
        description="View the Exo job queue",
    )
    parser.add_argument(
        "-l",
        "--long",
        action="store_true",
        help="Show detailed output",
    )
    parser.add_argument(
        "-j",
        "--job",
        help="Show only this job ID",
    )
    parsed = parser.parse_args(args)
    long_format: bool = parsed.long  # pyright: ignore[reportAny]
    job_filter: str | None = parsed.job  # pyright: ignore[reportAny]

    # The API returns the instance list directly; tolerate the wrapped
    # {"instances": [...]} shape as well.
    result = api_request("GET", "/flash/instances")
    if isinstance(result, list):
        instances = cast(list[dict[str, Any]], result)
    else:
        instances = cast(list[dict[str, Any]], result.get("instances", []))

    if not instances:
        # No jobs - still print a header so output stays parseable.
        if long_format:
            print("JOBID NAME NODES RANKS STATE WORKDIR")
        else:
            print("JOBID NAME NODES STATE")
        return 0

    # Filter by job ID if specified. Normalize the search string the
    # same way as the target (lowercase, hyphen-free); previously a
    # hyphenated UUID passed to -j could never match.
    if job_filter:
        search = job_filter.lower().replace("-", "")
        instances = [
            i
            for i in instances
            if search in str(i.get("instance_id", "")).lower().replace("-", "")  # pyright: ignore[reportAny]
        ]

    rows: list[list[str]] = []
    if long_format:
        headers = ["JOBID", "NAME", "NODES", "RANKS", "STATE", "WORKDIR"]
        for inst in instances:
            instance_id, name, nodes, ranks, state, workdir = _row_fields(inst)
            rows.append([truncate_id(instance_id, 12), name, nodes, ranks, state, workdir])
    else:
        headers = ["JOBID", "NAME", "NODES", "STATE"]
        for inst in instances:
            instance_id, name, nodes, _ranks, state, _workdir = _row_fields(inst)
            rows.append([truncate_id(instance_id, 8), name, nodes, state])

    print(format_table(headers, rows))
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))

View File

@@ -195,6 +195,14 @@ class Node:
def main():
# Check for SLURM-compatible subcommands first
import sys
if len(sys.argv) > 1 and sys.argv[1] in ("sbatch", "squeue", "scancel", "salloc"):
from exo.cli import run_subcommand
sys.exit(run_subcommand(sys.argv[1], sys.argv[2:]))
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
@@ -205,6 +213,11 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Discover and register plugins
from exo.plugins.registry import discover_plugins
discover_plugins()
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"

View File

@@ -1,7 +1,9 @@
import asyncio
import os
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
from typing import Any, Optional, cast
import anyio
from anyio import BrokenResourceError, create_task_group
@@ -14,13 +16,14 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from pydantic import BaseModel
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -59,9 +62,14 @@ from exo.shared.types.events import (
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -69,6 +77,22 @@ from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
class ExecuteRequest(BaseModel):
"""Request to execute a command."""
command: list[str]
cwd: Optional[str] = None
env: Optional[dict[str, str]] = None
class ExecuteResponse(BaseModel):
"""Response from command execution."""
exit_code: int
stdout: str
stderr: str
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
@@ -86,12 +110,12 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: str) -> ModelCard:
async def resolve_model_meta(model_id: str) -> ModelMetadata:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
return model_card.metadata
else:
return await get_model_card(model_id)
return await get_model_meta(model_id)
class API:
@@ -125,6 +149,7 @@ class API:
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
self._register_plugin_routes()
self.app.mount(
"/",
@@ -193,10 +218,62 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
# Remote execution endpoint (used by exo-rsh for MPI)
self.app.post("/execute")(self.execute)
def _register_plugin_routes(self) -> None:
"""Register API routes from all loaded plugins."""
import functools
import inspect
from exo.plugins.context import PluginContext
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
for method, path, handler, plugin in registry.get_all_api_routes():
# Create a wrapper that injects the plugin context while preserving
# the original function signature for FastAPI parameter extraction
@functools.wraps(handler)
async def wrapped_handler( # pyright: ignore[reportAny]
*args: Any, # pyright: ignore[reportAny]
_handler: Any = handler, # pyright: ignore[reportAny]
**kwargs: Any, # pyright: ignore[reportAny]
) -> Any:
context = PluginContext(
state=self.state,
send_command=self._send,
node_id=self.node_id,
)
# Pass context as first argument, then forward all other args
return await _handler(context, *args, **kwargs) # pyright: ignore[reportAny]
# Modify the wrapper signature to match the original handler
# but without the 'ctx' parameter (we inject it)
orig_sig = inspect.signature(handler)
params = list(orig_sig.parameters.values())
# Remove the first parameter (ctx: PluginContext)
if params and params[0].name in ("ctx", "context"):
params = params[1:]
wrapped_handler.__signature__ = orig_sig.replace(parameters=params) # pyright: ignore[reportAttributeAccessIssue]
# Register the route based on HTTP method
if method == "get":
self.app.get(path)(wrapped_handler)
elif method == "post":
self.app.post(path)(wrapped_handler)
elif method == "delete":
self.app.delete(path)(wrapped_handler)
elif method == "put":
self.app.put(path)(wrapped_handler)
logger.debug(
f"Registered plugin route: {method.upper()} {path} ({plugin.name})"
)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await resolve_model_card(payload.model_id),
model_meta=await resolve_model_meta(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -206,15 +283,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=command.model_card,
model_meta=command.model_meta,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -231,7 +308,7 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_card=model_card,
model_meta=model_meta,
)
async def get_placement(
@@ -241,12 +318,12 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await resolve_model_card(model_id)
model_meta = await resolve_model_meta(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -279,7 +356,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -297,12 +374,13 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for model_card in cards:
for card in cards:
model_meta = card.metadata
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_card=model_card,
model_meta=model_meta,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -313,17 +391,17 @@ class API:
current_instances=self.state.instances,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -334,17 +412,17 @@ class API:
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add((card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -353,7 +431,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_card.storage_size.in_bytes
total_bytes = model_meta.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -361,14 +439,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
model_card.model_id,
card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
model_id=card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -376,7 +454,7 @@ class API:
error=None,
)
)
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -551,8 +629,8 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -578,8 +656,8 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -612,18 +690,77 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.model_id,
id=card.short_id,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
)
for card in MODEL_CARDS.values()
]
)
async def execute(self, request: ExecuteRequest) -> ExecuteResponse:
    """Run a command on this node and capture its output.

    Used by exo-rsh for MPI remote execution: mpirun contacts this endpoint
    instead of ssh/rsh to start its daemons on remote nodes.

    Args:
        request: Command argv plus optional working directory and extra
            environment variables to overlay on this process's environment.

    Returns:
        ExecuteResponse with the child's exit code and decoded stdout/stderr.
        Errors are reported in-band: 127 for a missing executable, 1 for any
        other failure — the coroutine itself does not raise.
    """
    command_line = " ".join(request.command)
    logger.info(f"Executing: {command_line}")
    try:
        # Overlay request-supplied variables on a copy of our environment.
        env = dict(os.environ)
        if request.env:
            env.update(request.env)

        # mpirun sends compound shell commands such as
        # "VAR=value;export VAR;/path/to/prted --args", so route anything
        # containing shell metacharacters through a shell; plain argv runs
        # without one.
        if any(meta in command_line for meta in ";|&$`"):
            proc = await asyncio.create_subprocess_shell(
                command_line,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=request.cwd,
                env=env,
            )
        else:
            proc = await asyncio.create_subprocess_exec(
                *request.command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=request.cwd,
                env=env,
            )

        stdout, stderr = await proc.communicate()
        # returncode is an int after communicate(); `or 0` guards None.
        exit_code = proc.returncode or 0
        logger.info(f"Command completed with exit code {exit_code}")
        return ExecuteResponse(
            exit_code=exit_code,
            stdout=stdout.decode("utf-8", errors="replace"),
            stderr=stderr.decode("utf-8", errors="replace"),
        )
    except FileNotFoundError:
        # exec-mode only: the binary itself was not found. 127 mirrors the
        # shell's "command not found" convention.
        logger.error(f"Command not found: {request.command[0]}")
        return ExecuteResponse(
            exit_code=127,
            stdout="",
            stderr=f"Command not found: {request.command[0]}",
        )
    except Exception as exc:
        logger.error(f"Execution error: {exc}")
        return ExecuteResponse(
            exit_code=1,
            stdout="",
            stderr=str(exc),
        )
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"

View File

@@ -96,104 +96,128 @@ class Master:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
with self.command_receiver as commands:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
# Check if a plugin handles this command
plugin = registry.get_plugin_for_command(command)
if plugin is not None:
events = plugin.process_command(
command,
self.state.topology,
self.state.instances,
)
generated_events.extend(events)
else:
# Core command handling
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
task=ChatCompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(
command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
placement = add_instance_to_placements(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(
command, self.state.instances
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(
command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
placement = add_instance_to_placements(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.finished_command_id
]
)
)
if (
command.finished_command_id
in self.command_task_mapping
):
del self.command_task_mapping[
command.finished_command_id
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
case _:
# Plugin-managed commands are handled above
pass
for event in generated_events:
await self.event_sender.send(event)
except ValueError as e:

View File

@@ -14,7 +14,6 @@ from exo.master.placement_utils import (
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
@@ -24,6 +23,7 @@ from exo.shared.types.commands import (
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -60,27 +60,27 @@ def place_instance(
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_card.storage_size
candidate_cycles, node_memory, command.model_meta.storage_size
)
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_card.supports_tensor:
if not command.model_meta.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_card.hidden_size % len(cycle) == 0
if command.model_meta.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -111,7 +111,7 @@ def place_instance(
)
shard_assignments = get_shard_assignments(
command.model_card, selected_cycle, command.sharding, node_memory
command.model_meta, selected_cycle, command.sharding, node_memory
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
@@ -159,6 +159,11 @@ def place_instance(
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
case _:
# Plugin-managed instance types have their own placement functions
raise ValueError(
f"Instance type {command.instance_meta} must use plugin placement"
)
return target_instances

View File

@@ -2,10 +2,10 @@ from collections.abc import Generator, Mapping
from loguru import logger
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
@@ -75,7 +75,7 @@ def allocate_layers_proportionally(
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
model_meta: ModelMetadata,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
):
@@ -86,10 +86,11 @@ def get_shard_assignments_for_pipeline_parallel(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_card.n_layers
total_layers = model_meta.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
@@ -103,7 +104,7 @@ def get_shard_assignments_for_pipeline_parallel(
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
@@ -123,7 +124,7 @@ def get_shard_assignments_for_pipeline_parallel(
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -136,7 +137,7 @@ def get_shard_assignments_for_pipeline_parallel(
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -145,17 +146,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_card: ModelCard,
model_meta: ModelMetadata,
cycle: Cycle,
):
total_layers = model_card.n_layers
total_layers = model_meta.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node_id in enumerate(cycle):
shard = TensorShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -169,7 +170,7 @@ def get_shard_assignments_for_tensor_parallel(
node_to_runner[node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
model_id=model_meta.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -178,7 +179,7 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_card: ModelCard,
model_meta: ModelMetadata,
cycle: Cycle,
sharding: Sharding,
node_memory: Mapping[NodeId, MemoryUsage],
@@ -186,13 +187,13 @@ def get_shard_assignments(
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_card=model_card,
model_meta=model_meta,
cycle=cycle,
node_memory=node_memory,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_card=model_card,
model_meta=model_meta,
cycle=cycle,
)

View File

@@ -7,7 +7,6 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -24,6 +23,7 @@ from exo.shared.types.events import (
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
)
@@ -109,8 +109,9 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -166,8 +167,9 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -10,12 +10,12 @@ from exo.master.tests.conftest import (
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.topology import Connection, SocketConnection
@@ -43,20 +43,21 @@ def instance() -> Instance:
@pytest.fixture
def model_card() -> ModelCard:
return ModelCard(
def model_meta() -> ModelMetadata:
return ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -75,16 +76,16 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
model_card: ModelCard,
model_meta: ModelMetadata,
):
# arrange
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_card)
cic = place_instance_command(model_meta)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
@@ -136,7 +137,7 @@ def test_get_instance_placements_create_instance(
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_card.model_id
assert instance.shard_assignments.model_id == model_meta.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -163,9 +164,10 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -189,9 +191,10 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
node_memory = {node_id: create_node_memory(1001 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelCard(
ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -215,9 +218,10 @@ def test_get_instance_placements_one_node_not_fit() -> None:
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -271,12 +275,12 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_selects_leaf_nodes(
model_card: ModelCard,
model_meta: ModelMetadata,
):
# arrange
topology = Topology()
model_card.storage_size = Memory.from_bytes(1000)
model_meta.storage_size = Memory.from_bytes(1000)
node_id_a = NodeId()
node_id_b = NodeId()
@@ -321,7 +325,7 @@ def test_placement_selects_leaf_nodes(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
)
cic = place_instance_command(model_card=model_card)
cic = place_instance_command(model_meta=model_meta)
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
@@ -340,12 +344,12 @@ def test_placement_selects_leaf_nodes(
def test_tensor_rdma_backend_connectivity_matrix(
model_card: ModelCard,
model_meta: ModelMetadata,
):
# arrange
topology = Topology()
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
@@ -407,7 +411,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_card=model_card,
model_meta=model_meta,
min_nodes=1,
)

View File

@@ -12,10 +12,10 @@ from exo.master.tests.conftest import (
create_node_memory,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
NetworkInterfaceInfo,
NodeNetworkInfo,
@@ -232,8 +232,9 @@ def test_get_shard_assignments(
node_c_id: node_c_mem,
}
model_card = ModelCard(
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -247,7 +248,7 @@ def test_get_shard_assignments(
# act
shard_assignments = get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory=node_memory
model_meta, selected_cycle, Sharding.Pipeline, node_memory=node_memory
)
# assert
@@ -511,8 +512,9 @@ def test_get_shard_assignments_insufficient_memory_raises():
node_c_id: node_c_mem,
}
model_card = ModelCard(
model_meta = ModelMetadata(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -523,5 +525,5 @@ def test_get_shard_assignments_insufficient_memory_raises():
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory
model_meta, selected_cycle, Sharding.Pipeline, node_memory
)

View File

@@ -0,0 +1,16 @@
"""Exo Plugin System.
This module provides the plugin architecture for extending exo with custom
workload types (simulations, ML frameworks, etc.) without modifying core code.
"""
from exo.plugins.base import ExoPlugin, PluginCommand, PluginInstance
from exo.plugins.registry import PluginRegistry, discover_plugins
__all__ = [
"ExoPlugin",
"PluginCommand",
"PluginInstance",
"PluginRegistry",
"discover_plugins",
]

171
src/exo/plugins/base.py Normal file
View File

@@ -0,0 +1,171 @@
"""Base classes and protocols for Exo plugins."""
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from pydantic import Field
from exo.shared.types.common import CommandId
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.pydantic_ext import TaggedModel
if TYPE_CHECKING:
from exo.shared.topology import Topology
from exo.shared.types.worker.instances import BoundInstance, Instance
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor
class PluginCommand(TaggedModel):
    """Base class for plugin-defined commands.

    All plugin commands must inherit from this class. Commands are serialized
    with their class name as a tag for routing.
    """

    # Auto-generated unique id; correlates a command with the tasks/events
    # it later produces.
    command_id: CommandId = Field(default_factory=CommandId)
class PluginInstance(TaggedModel):
    """Base class for plugin-defined instances.

    All plugin instances must inherit from this class. Plugins are expected
    to define their own instance type with workload-specific fields.
    """

    # Identity of the running workload; set by the plugin's placement logic.
    instance_id: InstanceId
class ExoPlugin(ABC):
    """Protocol that all exo plugins must implement.

    A plugin provides:
    - Custom command types for API -> Master communication
    - Custom instance types representing running workloads
    - Placement logic for distributing work across nodes
    - Planning logic for local task scheduling
    - Runner implementation for executing work
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier for this plugin (e.g., 'flash', 'pytorch', 'mpi')."""
        ...

    @property
    @abstractmethod
    def version(self) -> str:
        """Semantic version string (e.g., '1.0.0')."""
        ...

    # ========== Type Registration ==========

    @abstractmethod
    def get_command_types(self) -> Sequence[type]:
        """Return command types this plugin handles.

        These commands are routed to this plugin's process_command method.
        Can return core BaseCommand types or PluginCommand types.
        """
        ...

    @abstractmethod
    def get_instance_type(self) -> type:
        """Return the instance type this plugin creates.

        This instance type is used for routing in planning and runner bootstrap.
        Can return core Instance types or PluginInstance types.
        """
        ...

    # ========== API Routes ==========

    @abstractmethod
    def get_api_routes(
        self,
    ) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        """Return FastAPI routes to register.

        Each tuple: (method, path, handler)
        Example: [('post', '/flash/launch', self.launch_handler)]

        Handlers receive a PluginContext with access to:
        - state: Current State object
        - send_command: Async function to send commands
        - node_id: Current node's ID
        """
        ...

    # ========== Master Command Handling ==========

    # `Any` is deliberate here: commands may be core types or plugin types,
    # so static narrowing is left to each implementation.
    @abstractmethod
    def handles_command(self, command: Any) -> bool:  # pyright: ignore[reportAny]
        """Return True if this plugin handles the given command type."""
        ...

    @abstractmethod
    def process_command(
        self,
        command: Any,  # pyright: ignore[reportAny]
        topology: "Topology",
        current_instances: Mapping[InstanceId, "Instance"],
    ) -> Sequence[Event]:
        """Process a command and return events to emit.

        Typically creates placement and returns InstanceCreated/InstanceDeleted events.

        Args:
            command: The command to process
            topology: Current cluster topology
            current_instances: Currently running instances

        Returns:
            Sequence of events to emit (e.g., InstanceCreated, InstanceDeleted)
        """
        ...

    # ========== Worker Planning ==========

    @abstractmethod
    def handles_instance(self, instance: object) -> bool:
        """Return True if this plugin manages the given instance type."""
        ...

    @abstractmethod
    def plan_task(
        self,
        runners: Mapping[RunnerId, "RunnerSupervisor"],
        instances: Mapping[InstanceId, "Instance"],
    ) -> Task | None:
        """Plan the next task for plugin instances.

        Called during each planning cycle.
        Return None if no task is needed.
        """
        ...

    @abstractmethod
    def should_skip_download(self, instance: object) -> bool:
        """Return True if this instance type doesn't need model downloads."""
        ...

    # ========== Runner Bootstrap ==========

    @abstractmethod
    def create_runner(
        self,
        bound_instance: "BoundInstance",
        event_sender: "MpSender[Event]",
        task_receiver: "MpReceiver[Task]",
    ) -> None:
        """Entry point for the runner process.

        Called in a subprocess to execute the actual workload.
        This function should block until the workload completes.
        """
        ...

View File

@@ -0,0 +1,21 @@
"""Context objects passed to plugin handlers."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from exo.shared.types.commands import Command
from exo.shared.types.common import NodeId
from exo.shared.types.state import State
@dataclass
class PluginContext:
    """Context provided to plugin API handlers.

    This gives plugins access to the current state and the ability to send
    commands without direct access to internal API components.
    """

    # Snapshot of cluster state (topology, instances, runners, tasks).
    state: State
    # Async callable that forwards a command to the master for processing.
    send_command: Callable[[Command], Awaitable[None]]
    # Identity of the node serving this API request.
    node_id: NodeId

View File

@@ -0,0 +1,5 @@
"""Plugin implementations directory.
Each subdirectory should contain a plugin with a register() function
that returns an ExoPlugin instance.
"""

View File

@@ -0,0 +1,8 @@
"""FLASH Plugin - MPI-based simulation support for Exo."""
from exo.plugins.implementations.flash.plugin import FLASHPlugin
def register() -> FLASHPlugin:
    """Entry point for plugin discovery.

    Called by the plugin registry to obtain a fresh FLASHPlugin instance;
    each call constructs a new plugin object.
    """
    return FLASHPlugin()

View File

@@ -0,0 +1,108 @@
"""FLASH plugin API handlers."""
from typing import Any
from fastapi import HTTPException
from exo.plugins.context import PluginContext
# Use core types for serialization compatibility
from exo.shared.types.commands import LaunchFLASH, StopFLASH
from exo.shared.types.worker.instances import FLASHInstance
async def handle_launch_flash(
    ctx: PluginContext,
    simulation_name: str,
    flash_executable_path: str,
    working_directory: str,
    parameter_file_path: str = "",
    ranks_per_node: int = 1,
    min_nodes: int = 1,
    hosts: str = "",
) -> dict[str, str]:
    """Launch a FLASH MPI simulation across the cluster.

    Builds a LaunchFLASH command from the request parameters and forwards it
    to the master; placement and execution happen asynchronously.

    Args:
        ctx: Plugin context with state and send_command
        simulation_name: Name of the simulation
        flash_executable_path: Path to the FLASH executable
        working_directory: Working directory for the simulation
        parameter_file_path: Path to parameter file (optional)
        ranks_per_node: Number of MPI ranks per node
        min_nodes: Minimum number of nodes required
        hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
            If not provided, IPs are discovered from topology edges.

    Returns:
        Acknowledgement dict with the generated command id.
    """
    launch = LaunchFLASH(
        simulation_name=simulation_name,
        flash_executable_path=flash_executable_path,
        parameter_file_path=parameter_file_path,
        working_directory=working_directory,
        ranks_per_node=ranks_per_node,
        min_nodes=min_nodes,
        hosts=hosts,
    )
    await ctx.send_command(launch)

    acknowledgement = {
        "message": "FLASH launch command received",
        "command_id": str(launch.command_id),
        "simulation_name": simulation_name,
    }
    return acknowledgement
async def handle_stop_flash(
    ctx: PluginContext,
    instance_id: str,
) -> dict[str, str]:
    """Stop a running FLASH simulation.

    Validates that the id names an existing FLASH instance, then submits a
    StopFLASH command for it.

    Raises:
        HTTPException: 404 if no such instance, 400 if the instance is not
            a FLASH simulation.
    """
    # Local import to avoid a module-level dependency cycle.
    from exo.shared.types.worker.instances import InstanceId

    target_id = InstanceId(instance_id)
    target = ctx.state.instances.get(target_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Instance not found")
    if not isinstance(target, FLASHInstance):
        raise HTTPException(
            status_code=400, detail="Instance is not a FLASH simulation"
        )

    stop_command = StopFLASH(instance_id=target_id)
    await ctx.send_command(stop_command)

    return {
        "message": "Stop command received",
        "command_id": str(stop_command.command_id),
        "instance_id": str(instance_id),
    }
async def handle_list_flash_instances(ctx: PluginContext) -> list[dict[str, Any]]:
    """List all FLASH simulation instances with per-node runner statuses."""
    listing: list[dict[str, Any]] = []
    for inst_id, inst in ctx.state.instances.items():
        if not isinstance(inst, FLASHInstance):
            continue
        # Map each participating node to its runner's current status (or None
        # when the runner is not known to the state yet).
        statuses: dict[str, str | None] = {}
        for node_id, runner_id in inst.shard_assignments.node_to_runner.items():
            status = ctx.state.runners.get(runner_id)
            statuses[str(node_id)] = str(status) if status else None
        listing.append(
            {
                "instance_id": str(inst_id),
                "simulation_name": inst.simulation_name,
                "total_ranks": inst.total_ranks,
                "working_directory": inst.working_directory,
                "runner_statuses": statuses,
            }
        )
    return listing

View File

@@ -0,0 +1,152 @@
"""FLASH plugin placement logic."""
from collections.abc import Mapping
from copy import deepcopy
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.commands import LaunchFLASH
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerId,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
def place_flash_instance(
    command: LaunchFLASH,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
    """Place a FLASH simulation instance across available nodes.

    Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
    FLASH instances use MPI for communication. We just need to provide the
    node IPs so the runner can generate an MPI hostfile.

    Args:
        command: Launch request (simulation name, paths, rank counts, hosts).
        topology: Cluster topology used for node selection and IP discovery.
        current_instances: Existing instances; a deep copy is returned with the
            new FLASH instance added.

    Returns:
        Target instance mapping including the newly placed FLASHInstance.

    Raises:
        ValueError: If fewer than ``command.min_nodes`` nodes are available.
    """
    instance_id = InstanceId()
    # Deep-copy so callers can diff current vs. target without aliasing.
    target_instances: dict[InstanceId, Instance] = dict(deepcopy(current_instances))
    all_nodes = list(topology.list_nodes())
    if len(all_nodes) < command.min_nodes:
        raise ValueError(
            f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
        )
    # Select nodes (take the first min_nodes)
    selected_nodes = all_nodes[: command.min_nodes]
    logger.info(
        f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
    )
    # Build shard assignments (one runner per node for FLASH)
    runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}
    # Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
    flash_model_meta = ModelMetadata(
        model_id=ModelId(command.simulation_name),
        pretty_name=f"FLASH: {command.simulation_name}",
        storage_size=Memory(in_bytes=0),
        n_layers=1,
        hidden_size=1,
        supports_tensor=False,
    )
    for i, node_id in enumerate(selected_nodes):
        runner_id = RunnerId()
        node_to_runner[node_id] = runner_id
        # device_rank doubles as the MPI rank ordering for this simulation.
        runner_to_shard[runner_id] = PipelineShardMetadata(
            device_rank=i,
            world_size=len(selected_nodes),
            model_meta=flash_model_meta,
            start_layer=0,
            end_layer=1,
            n_layers=1,
        )
    shard_assignments = ShardAssignments(
        model_id=ModelId(command.simulation_name),
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )
    # Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
    hosts_by_node: dict[NodeId, list[Host]] = {}
    # If explicit hosts are provided, use them directly
    if command.hosts:
        explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
        logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
        for i, node_id in enumerate(selected_nodes):
            if i < len(explicit_hosts):
                hosts_by_node[node_id] = [Host(ip=explicit_hosts[i], port=0)]
                logger.info(
                    f"FLASH placement: node {node_id} (rank {i}) -> IP {explicit_hosts[i]}"
                )
            else:
                logger.warning(
                    f"Not enough hosts provided for node {i}, using localhost"
                )
                hosts_by_node[node_id] = [Host(ip="127.0.0.1", port=0)]
        # BUGFIX: a hosts string like " , " is truthy but yields no usable
        # entries, so explicit_hosts[0] would raise IndexError here. Guard it.
        if explicit_hosts:
            logger.info(
                f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
            )
        else:
            logger.warning(
                "FLASH placement: hosts string contained no usable entries; "
                "all nodes fall back to localhost"
            )
    else:
        # Try to get IPs from topology edges
        for node_id in selected_nodes:
            node_hosts: list[Host] = []
            # Get IP from outgoing edges (connections to other nodes via mDNS discovery)
            for conn in topology.out_edges(node_id):
                if isinstance(conn.edge, SocketConnection):
                    # Extract IP from multiaddr
                    ip = conn.edge.sink_multiaddr.ip_address
                    # Skip link-local and localhost addresses
                    if not ip.startswith("169.254.") and not ip.startswith("127."):
                        node_hosts.append(Host(ip=ip, port=0))
                        break
            # Last resort: use localhost (will only work for single-node)
            if not node_hosts:
                logger.warning(
                    f"Could not determine IP for node {node_id}, using localhost"
                )
                node_hosts.append(Host(ip="127.0.0.1", port=0))
            hosts_by_node[node_id] = node_hosts
    total_ranks = len(selected_nodes) * command.ranks_per_node
    # Determine coordinator IP - first node's first host IP
    first_node_id: NodeId = next(iter(hosts_by_node.keys()))
    coordinator_ip: str = (
        hosts_by_node[first_node_id][0].ip
        if hosts_by_node[first_node_id]
        else "127.0.0.1"
    )
    target_instances[instance_id] = FLASHInstance(
        instance_id=instance_id,
        shard_assignments=shard_assignments,
        hosts_by_node=hosts_by_node,
        flash_executable_path=command.flash_executable_path,
        parameter_file_path=command.parameter_file_path,
        working_directory=command.working_directory,
        ranks_per_node=command.ranks_per_node,
        total_ranks=total_ranks,
        simulation_name=command.simulation_name,
        coordinator_ip=coordinator_ip,
    )
    logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")
    return target_instances

View File

@@ -0,0 +1,36 @@
"""FLASH plugin planning logic."""
from collections.abc import Mapping
from exo.shared.types.tasks import LoadModel, Task
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerIdle
from exo.worker.runner.runner_supervisor import RunnerSupervisor
def plan_flash(
    runners: Mapping[RunnerId, RunnerSupervisor],
    instances: Mapping[InstanceId, Instance],
) -> Task | None:
    """Plan tasks specifically for FLASH instances.

    FLASH instances have a simpler lifecycle than MLX ones:
    - CreateRunner (handled by core _create_runner)
    - LoadModel (starts the simulation immediately)
    - Shutdown (handled by core _kill_runner)

    This emits the LoadModel step for the first idle FLASH runner, skipping
    the MLX-specific download/init/warmup steps.
    """
    for supervisor in runners.values():
        bound = supervisor.bound_instance.instance
        if isinstance(bound, FLASHInstance) and isinstance(
            supervisor.status, RunnerIdle
        ):
            # Idle FLASH runner: kick off the simulation right away.
            return LoadModel(instance_id=bound.instance_id)
    return None

View File

@@ -0,0 +1,98 @@
"""FLASH Plugin - Main plugin class."""
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from exo.plugins.base import ExoPlugin
from exo.plugins.implementations.flash.api_handlers import (
handle_launch_flash,
handle_list_flash_instances,
handle_stop_flash,
)
from exo.plugins.implementations.flash.placement import place_flash_instance
from exo.plugins.implementations.flash.planning import plan_flash
from exo.plugins.implementations.flash.runner import main as flash_runner_main
from exo.shared.topology import Topology
from exo.shared.types.commands import DeleteInstance, LaunchFLASH, StopFLASH
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import (
BoundInstance,
FLASHInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor
class FLASHPlugin(ExoPlugin):
    """Plugin for FLASH MPI simulations."""
    @property
    def name(self) -> str:
        """Stable plugin identifier used by the registry."""
        return "flash"
    @property
    def version(self) -> str:
        """Plugin version string."""
        return "1.0.0"
    def get_command_types(self) -> Sequence[type]:
        """Command types this plugin registers with the core."""
        return [LaunchFLASH, StopFLASH]
    def get_instance_type(self) -> type:
        """Instance type managed by this plugin."""
        return FLASHInstance
    def get_api_routes(
        self,
    ) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        """(method, path, handler) triples to mount on the API server."""
        return [
            ("post", "/flash/launch", handle_launch_flash),
            ("delete", "/flash/{instance_id}", handle_stop_flash),
            ("get", "/flash/instances", handle_list_flash_instances),
        ]
    def handles_command(self, command: Any) -> bool:  # pyright: ignore[reportAny]
        """True for the FLASH launch/stop commands."""
        return isinstance(command, (LaunchFLASH, StopFLASH))
    def process_command(
        self,
        command: Any,  # pyright: ignore[reportAny]
        topology: Topology,
        current_instances: Mapping[InstanceId, Instance],
    ) -> Sequence[Event]:
        """Translate a FLASH command into instance-transition events."""
        # Local import to avoid a circular dependency with the master package.
        from exo.master.placement import delete_instance, get_transition_events
        if isinstance(command, LaunchFLASH):
            placement = place_flash_instance(command, topology, current_instances)
        elif isinstance(command, StopFLASH):
            placement = delete_instance(
                DeleteInstance(instance_id=command.instance_id),
                current_instances,
            )
        else:
            return []
        return list(get_transition_events(current_instances, placement))
    def handles_instance(self, instance: object) -> bool:
        """True when the instance was placed by this plugin."""
        return isinstance(instance, FLASHInstance)
    def plan_task(
        self,
        runners: Mapping[RunnerId, RunnerSupervisor],
        instances: Mapping[InstanceId, Instance],
    ) -> Task | None:
        """Delegate FLASH-specific task planning to plan_flash()."""
        return plan_flash(runners, instances)
    def should_skip_download(self, instance: object) -> bool:
        # FLASH instances don't need model downloads
        return True
    def create_runner(
        self,
        bound_instance: BoundInstance,
        event_sender: MpSender[Event],
        task_receiver: MpReceiver[Task],
    ) -> None:
        """Run the FLASH runner main loop for the bound instance."""
        flash_runner_main(bound_instance, event_sender, task_receiver)

View File

@@ -0,0 +1,302 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.
Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- exo-rsh connects to each node's Exo API (/execute endpoint) for remote execution
- Workers just report ready and wait
"""
import os
import shutil
import socket
import subprocess
import threading
from loguru import logger
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
LoadModel,
Shutdown,
Task,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender
# Find mpirun in PATH, fallback to common locations
# (the hard-coded fallback covers Homebrew installs of Open MPI on macOS).
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"
# exo-rsh is installed as console script by exo package
_exo_rsh_path = shutil.which("exo-rsh")
if not _exo_rsh_path:
    # NOTE(review): this raises at import time, so merely importing this module
    # fails on any machine without the exo-rsh console script on PATH — confirm
    # that is intended (workers and coordinator both need it for mpirun spawn).
    raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
EXO_RSH_PATH: str = _exo_rsh_path
def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
    """Return this node's rank (its position in hosts_by_node), or -1 if absent."""
    wanted = str(my_node_id)
    for rank, node_id in enumerate(instance.hosts_by_node):
        if str(node_id) == wanted:
            return rank
    return -1
def get_coordinator_host(instance: FLASHInstance) -> str:
    """Return the coordinator node's IP as recorded on the instance."""
    coordinator = instance.coordinator_ip
    return coordinator
def resolve_host(host: str) -> str:
    """Resolve host string to a usable hostname for MPI hostfile.

    Accepts either an IP address or hostname. For IPs, attempts to resolve
    to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.

    Args:
        host: An IPv4 dotted-quad address or a hostname.

    Returns:
        A short hostname when one can be determined, otherwise ``host`` unchanged.
    """
    # Classify the input: inet_aton accepts (loosely) dotted-quad IPv4 only.
    try:
        socket.inet_aton(host)
        is_ip = True
    except socket.error:
        is_ip = False
    if not is_ip:
        # Already a hostname, verify it resolves and return as-is
        try:
            socket.gethostbyname(host)
            return host
        except socket.gaierror:
            logger.warning(f"Hostname {host} does not resolve, using anyway")
            return host
    # It's an IP address, try to resolve to hostname
    try:
        hostname, _, _ = socket.gethostbyaddr(host)
        # Keep only the short name (strip the domain) for the hostfile.
        hostname = hostname.split(".")[0]
        logger.info(f"Resolved {host} to {hostname}")
        return hostname
    except OSError:
        # BUGFIX: gethostbyaddr raises socket.gaierror (not just herror) on
        # some platforms for addresses without a PTR record; catching only
        # herror let that propagate and crash the runner. OSError covers
        # both herror and gaierror, so any failed reverse lookup now falls
        # through to the IP fallback below.
        pass
    # Fall back to IP
    logger.warning(f"Could not resolve {host} to hostname, using IP directly")
    return host
def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
    """Write an MPI hostfile (one "host slots=N" line per node) and return its path."""
    hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
    # Nodes with an empty host list are silently omitted from the hostfile.
    entries: list[str] = []
    for hosts in instance.hosts_by_node.values():
        if not hosts:
            continue
        resolved = resolve_host(hosts[0].ip)
        entries.append(f"{resolved} slots={instance.ranks_per_node}\n")
    with open(hostfile_path, "w") as f:
        f.writelines(entries)
    logger.info(f"Generated hostfile at {hostfile_path}")
    # Echo the final contents for debugging distributed launches.
    with open(hostfile_path, "r") as f:
        logger.info(f"Hostfile contents:\n{f.read()}")
    return hostfile_path
def main(
    bound_instance: BoundInstance,
    event_sender: MpSender[Event],
    task_receiver: MpReceiver[Task],
) -> None:
    """Main FLASH runner loop.
    Coordinator: generates hostfile and runs mpirun (uses exo-rsh instead of SSH)
    Workers: just report ready and wait for mpirun to spawn processes on them
    """
    assert isinstance(bound_instance.instance, FLASHInstance)
    instance = bound_instance.instance
    runner_id = bound_instance.bound_runner_id
    my_node_id = str(bound_instance.bound_node_id)
    logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")
    # Rank is this node's position in hosts_by_node; rank 0 is the coordinator.
    my_rank = get_my_rank(instance, my_node_id)
    world_size = len(instance.hosts_by_node)
    is_coordinator = my_rank == 0
    coordinator_ip = get_coordinator_host(instance)
    logger.info(
        f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
    )
    logger.info(f"FLASH coordinator IP: {coordinator_ip}")
    # The mpirun child process (coordinator only); stays None on workers.
    process: subprocess.Popen[bytes] | None = None
    current_status: RunnerStatus = RunnerIdle()
    shutdown_requested = False
    # Announce the initial Idle status so the planner can emit LoadModel.
    event_sender.send(
        RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
    )
    def monitor_output(proc: subprocess.Popen[bytes]) -> None:
        """Monitor FLASH stdout for progress updates."""
        if proc.stdout is None:
            return
        # Streams mpirun's merged stdout/stderr (stderr is redirected to STDOUT
        # at Popen time) line-by-line into our log from a daemon thread.
        for line in iter(proc.stdout.readline, b""):
            # Closure over main's local; flipped to True when Shutdown arrives.
            if shutdown_requested:
                break
            try:
                decoded: str = line.decode("utf-8", errors="replace").strip()
                if decoded:
                    logger.info(f"[FLASH] {decoded}")
            except Exception as e:
                logger.warning(f"Error parsing FLASH output: {e}")
    with task_receiver as tasks:
        for task in tasks:
            # Every task is acknowledged and marked Running before dispatch.
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
            )
            event_sender.send(TaskAcknowledged(task_id=task.task_id))
            match task:
                # LoadModel is only honored from the Idle state; duplicates fall
                # through to the `case _` branch below.
                case LoadModel() if isinstance(current_status, RunnerIdle):
                    current_status = RunnerLoading()
                    logger.info("Starting FLASH simulation")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    try:
                        if is_coordinator:
                            # Coordinator: generate hostfile and run mpirun
                            hostfile = generate_hostfile(
                                instance, instance.working_directory
                            )
                            iface = instance.network_interface
                            # Open MPI invocation: TCP-only BTL pinned to the
                            # instance's network interface, flat (non-tree)
                            # rsh spawn so every launch goes through rank 0.
                            cmd = [
                                MPIRUN_PATH,
                                "-np",
                                str(instance.total_ranks),
                                "--hostfile",
                                hostfile,
                                "--wdir",
                                instance.working_directory,
                                "--oversubscribe",
                                "--mca",
                                "btl",
                                "tcp,self",
                                "--mca",
                                "btl_tcp_if_include",
                                iface,
                                "--mca",
                                "oob_tcp_if_include",
                                iface,
                                "--mca",
                                "plm_rsh_no_tree_spawn",
                                "1",
                            ]
                            # Use exo-rsh for remote execution (no SSH needed)
                            cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])
                            cmd.append(instance.flash_executable_path)
                            logger.info(f"FLASH distributed launch: {' '.join(cmd)}")
                            process = subprocess.Popen(
                                cmd,
                                cwd=instance.working_directory,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                            )
                            # Daemon thread so a wedged pipe never blocks shutdown.
                            monitor_thread = threading.Thread(
                                target=monitor_output, args=(process,), daemon=True
                            )
                            monitor_thread.start()
                            current_status = RunnerRunning()
                            logger.info(
                                f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
                            )
                        else:
                            # Worker: mpirun on coordinator will use exo-rsh to spawn processes here
                            logger.info(
                                f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
                            )
                            current_status = RunnerRunning()
                    except Exception as e:
                        logger.error(f"Failed to start FLASH: {e}")
                        import traceback
                        logger.error(traceback.format_exc())
                        current_status = RunnerFailed(error_message=str(e))
                case Shutdown():
                    shutdown_requested = True
                    current_status = RunnerShuttingDown()
                    logger.info("FLASH runner shutting down")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    # Graceful terminate with a 10s grace period, then SIGKILL.
                    if process and process.poll() is None:
                        logger.info("Terminating FLASH simulation")
                        process.terminate()
                        try:
                            process.wait(timeout=10)
                        except subprocess.TimeoutExpired:
                            logger.warning("FLASH didn't terminate, killing")
                            process.kill()
                            process.wait()
                    current_status = RunnerShutdown()
                case _:
                    # Any other task doubles as a poll point: reap the mpirun
                    # process if it has exited and translate its exit code.
                    if process and process.poll() is not None:
                        exit_code = process.returncode
                        if exit_code == 0:
                            logger.info("FLASH simulation completed successfully")
                            current_status = RunnerReady()
                        else:
                            logger.error(
                                f"FLASH simulation failed with code {exit_code}"
                            )
                            current_status = RunnerFailed(
                                error_message=f"Exit code {exit_code}"
                            )
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
            )
            event_sender.send(
                RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
            )
            if isinstance(current_status, RunnerShutdown):
                break
    # Safety net for exits that bypass the Shutdown task (e.g. channel closed).
    # NOTE(review): wait(timeout=5) can raise TimeoutExpired uncaught here if
    # the process ignores SIGTERM — confirm whether a kill() fallback is wanted.
    if process and process.poll() is None:
        process.terminate()
        process.wait(timeout=5)
    logger.info("FLASH runner exiting")

110
src/exo/plugins/registry.py Normal file
View File

@@ -0,0 +1,110 @@
"""Plugin registry for discovering and managing plugins."""
from collections.abc import Callable, Sequence
from typing import Any
from loguru import logger
from exo.plugins.base import ExoPlugin
class PluginRegistry:
    """Central registry for all plugins."""
    # Process-wide singleton storage; managed via get()/reset().
    _instance: "PluginRegistry | None" = None
    def __init__(self) -> None:
        self._plugins: dict[str, ExoPlugin] = {}
        self._command_handlers: dict[type, ExoPlugin] = {}
        self._instance_handlers: dict[type, ExoPlugin] = {}
    @classmethod
    def get(cls) -> "PluginRegistry":
        """Return the singleton registry, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
    @classmethod
    def reset(cls) -> None:
        """Drop the singleton so the next get() builds a fresh registry (tests)."""
        cls._instance = None
    def register(self, plugin: ExoPlugin) -> None:
        """Register a plugin along with its command and instance types."""
        if plugin.name in self._plugins:
            raise ValueError(f"Plugin '{plugin.name}' already registered")
        logger.info(f"Registering plugin: {plugin.name} v{plugin.version}")
        self._plugins[plugin.name] = plugin
        # Index the plugin by every command type it declares.
        for command_type in plugin.get_command_types():
            self._command_handlers[command_type] = plugin
            logger.debug(f" Registered command: {command_type.__name__}")
        # And by the single instance type it manages.
        managed_type = plugin.get_instance_type()
        self._instance_handlers[managed_type] = plugin
        logger.debug(f" Registered instance: {managed_type.__name__}")
    def get_plugin(self, name: str) -> ExoPlugin | None:
        """Look up a plugin by name; None when unknown."""
        return self._plugins.get(name)
    def get_plugin_for_command(self, command: object) -> ExoPlugin | None:
        """Return the first plugin whose handles_command() accepts the command."""
        return next(
            (p for p in self._plugins.values() if p.handles_command(command)),
            None,
        )
    def get_plugin_for_instance(self, instance: object) -> ExoPlugin | None:
        """Return the first plugin whose handles_instance() accepts the instance."""
        return next(
            (p for p in self._plugins.values() if p.handles_instance(instance)),
            None,
        )
    def all_plugins(self) -> Sequence[ExoPlugin]:
        """Snapshot list of every registered plugin."""
        return list(self._plugins.values())
    def get_all_api_routes(
        self,
    ) -> Sequence[tuple[str, str, Callable[..., Any], ExoPlugin]]:
        """Flatten (method, path, handler, plugin) tuples across all plugins."""
        return [
            (method, path, handler, plugin)
            for plugin in self._plugins.values()
            for method, path, handler in plugin.get_api_routes()
        ]
def discover_plugins() -> None:
    """Auto-discover and register plugins from the implementations directory.
    Plugins should have a register() function that returns an ExoPlugin instance.
    """
    import importlib
    import pkgutil
    registry = PluginRegistry.get()
    try:
        import exo.plugins.implementations as impl_package
        # One candidate module per subpackage of exo.plugins.implementations.
        for _, module_name, _ in pkgutil.iter_modules(impl_package.__path__):
            try:
                candidate = importlib.import_module(
                    f"exo.plugins.implementations.{module_name}"
                )
                if hasattr(candidate, "register"):
                    plugin = candidate.register()  # pyright: ignore[reportAny]
                    if plugin is not None:
                        registry.register(plugin)  # pyright: ignore[reportAny]
            except Exception as e:
                # A broken plugin is skipped, never fatal to discovery.
                logger.warning(f"Failed to load plugin {module_name}: {e}")
    except ImportError:
        logger.debug("No plugin implementations package found")

13
src/exo/rsh/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""Exo RSH - Remote Shell for MPI without SSH.
This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:
1. Each Exo node runs an API server on port 52415 with an /execute endpoint
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", it HTTP POSTs to the target's /execute
4. The target executes the command and returns output
Usage:
mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""

101
src/exo/rsh/client.py Normal file
View File

@@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.
This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]
It connects to the target node's Exo API (port 52415) and executes the command.
"""
import json
import socket
import sys
from typing import Any, cast
from urllib.error import URLError
from urllib.request import Request, urlopen
# Use the same port as Exo's API server
EXO_API_PORT = 52415
def resolve_hostname(hostname: str) -> str:
    """Best-effort hostname-to-IP resolution; returns the input on failure."""
    try:
        resolved = socket.gethostbyname(hostname)
    except socket.gaierror:
        # If resolution fails, try using the hostname directly
        return hostname
    return resolved
def main():
    """Parse ssh-style argv, POST the command to the target's /execute, mirror output."""
    # mpirun invokes us like ssh: exo-rsh [options] hostname command [args...]
    argv = sys.argv[1:]
    # These ssh options take a value, so they consume two argv slots.
    value_taking = ("-o", "-i", "-l", "-p", "-F")
    hostname = None
    command_start = 0
    idx = 0
    while idx < len(argv):
        token = argv[idx]
        if not token.startswith("-"):
            # First non-option token is the target hostname.
            hostname = token
            command_start = idx + 1
            break
        idx += 2 if token in value_taking else 1
    if hostname is None or command_start >= len(argv):
        print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
        sys.exit(1)
    command = argv[command_start:]
    # Resolve hostname to IP
    ip = resolve_hostname(hostname)
    # Make request to Exo API
    url = f"http://{ip}:{EXO_API_PORT}/execute"
    data = json.dumps({"command": command}).encode("utf-8")
    try:
        req = Request(url, data=data, headers={"Content-Type": "application/json"})
        with urlopen(req, timeout=300) as response:  # pyright: ignore[reportAny]
            raw: bytes = cast(bytes, response.read())  # pyright: ignore[reportAny]
            result: dict[str, Any] = json.loads(raw.decode("utf-8"))  # pyright: ignore[reportAny]
        # Mirror the remote process's output and exit code locally so mpirun
        # sees us behave exactly like ssh would.
        out: str = cast(str, result.get("stdout", ""))
        err: str = cast(str, result.get("stderr", ""))
        exit_code: int = cast(int, result.get("exit_code", 0))
        if out:
            sys.stdout.write(out)
            sys.stdout.flush()
        if err:
            sys.stderr.write(err)
            sys.stderr.flush()
        sys.exit(exit_code)
    except URLError as e:
        print(
            f"exo-rsh: Failed to connect to {hostname}:{EXO_API_PORT}: {e}",
            file=sys.stderr,
        )
        sys.exit(255)
    except Exception as e:
        print(f"exo-rsh: Error: {e}", file=sys.stderr)
        sys.exit(1)
if __name__ == "__main__":
    main()

View File

@@ -1,310 +1,552 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
    """String id of a model repository, e.g. "mlx-community/Some-Model".

    NOTE(review): `ModelId` is also imported from exo.shared.types.models
    above; this local class shadows that import — confirm which definition
    is intended (looks like a merge artifact).
    """
    def normalize(self) -> str:
        # Filesystem-safe form: "org/name" -> "org--name".
        return self.replace("/", "--")
    def short(self) -> str:
        # Trailing path component: "org/name" -> "name".
        return self.split("/")[-1]
class ModelCard(CamelCaseModel):
    """Catalog entry describing one downloadable model.

    NOTE(review): short_id/model_id/storage_size/n_layers/hidden_size/
    supports_tensor duplicate the same fields inside `metadata` below — the
    two must be kept in sync by hand wherever cards are defined.
    """
    # Short human-friendly key, also used as the MODEL_CARDS dict key.
    short_id: str
    model_id: ModelId
    storage_size: Memory
    n_layers: PositiveInt
    hidden_size: PositiveInt
    supports_tensor: bool
    # Display name and blurb for UIs.
    name: str
    description: str
    tags: list[str]
    # Canonical metadata consumed by placement/sharding code.
    metadata: ModelMetadata
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
}

View File

@@ -6,8 +6,9 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
@@ -91,18 +92,18 @@ async def get_safetensors_size(model_id: str) -> Memory:
return Memory.from_bytes(info.safetensors.total)
_model_card_cache: dict[str, ModelCard] = {}
_model_meta_cache: dict[str, ModelMetadata] = {}
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_card(model_id: str) -> ModelCard:
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
@@ -112,11 +113,14 @@ async def _get_model_card(model_id: str) -> ModelCard:
None,
)
return ModelCard(
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.supports_tensor if model_card is not None else False,
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,8 +31,9 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,

View File

@@ -4,9 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -206,7 +206,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_card: ModelCard
model_meta: ModelMetadata
class DeleteInstanceResponse(BaseModel):

View File

@@ -1,10 +1,10 @@
from enum import Enum
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .models import ModelId
class ChunkType(str, Enum):

View File

@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -21,7 +21,7 @@ class ChatCompletion(BaseCommand):
class PlaceInstance(BaseCommand):
model_card: ModelCard
model_meta: ModelMetadata
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
@@ -35,6 +35,26 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class LaunchFLASH(BaseCommand):
"""Command to launch a FLASH MPI simulation."""
simulation_name: str
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
min_nodes: int = 1
# Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
# Used when topology edges don't contain IP addresses
hosts: str = ""
class StopFLASH(BaseCommand):
"""Command to stop a running FLASH simulation."""
instance_id: InstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -50,6 +70,8 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| LaunchFLASH
| StopFLASH
| TaskFinished
)

View File

@@ -16,9 +16,7 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
return core_schema.str_schema()
class NodeId(Id):

View File

@@ -0,0 +1,18 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -14,6 +14,7 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
FLASH = "FLASH"
class BaseInstance(TaggedModel):
@@ -34,8 +35,27 @@ class MlxJacclInstance(BaseInstance):
jaccl_coordinators: dict[NodeId, str]
class FLASHInstance(BaseInstance):
"""Instance for FLASH MPI simulation.
Unlike MLX instances which do tensor parallelism, FLASH instances
coordinate MPI processes across nodes. Each node runs one or more
MPI ranks of the FLASH simulation.
"""
hosts_by_node: dict[NodeId, list[Host]]
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
total_ranks: int
simulation_name: str
coordinator_ip: str
network_interface: str = "en0" # Network interface for MPI (e.g., en0, eth0)
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
class BoundInstance(CamelCaseModel):

View File

@@ -2,8 +2,8 @@ from collections.abc import Mapping
from pydantic import model_validator
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.models import ModelMetadata
from exo.utils.pydantic_ext import TaggedModel
@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
Replaces previous `Shard` object.
"""
model_card: ModelCard
model_meta: ModelMetadata
device_rank: int
world_size: int
@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
def __hash__(self) -> int:
return hash(
(
self.model_card.model_id,
self.model_meta.model_id,
self.start_layer,
self.end_layer,
self.n_layers,

View File

@@ -460,10 +460,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_card.model_id))
weight_map = await get_weight_map(str(shard.model_meta.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(traceback.format_exc())
return ["*"]
@@ -532,18 +532,18 @@ async def download_shard(
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=}")
logger.info(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_card.model_id)):
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
if await aios.path.exists(str(shard.model_meta.model_id)):
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_card.model_id), shard, local_path
str(shard.model_meta.model_id), shard, local_path
)
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
"/", "--"
)
if not skip_download:
@@ -552,13 +552,13 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_card.model_id), revision, recursive=True
str(shard.model_meta.model_id), revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -592,7 +592,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -609,7 +609,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
str(shard.model_card.model_id),
str(shard.model_meta.model_id),
revision,
file_progress,
all_start_time,
@@ -619,7 +619,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_id=str(shard.model_meta.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -643,7 +643,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
str(shard.model_card.model_id),
str(shard.model_meta.model_id),
revision,
file.path,
target_dir,
@@ -657,7 +657,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -20,21 +20,21 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: str) -> ShardMetadata:
model_card = await get_model_card(model_id)
model_meta = await get_model_meta(model_id)
return PipelineShardMetadata(
model_card=model_card,
model_meta=model_meta,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=model_card.n_layers,
n_layers=model_card.n_layers,
end_layer=model_meta.n_layers,
n_layers=model_meta.n_layers,
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_card=base_shard.model_card,
model_meta=base_shard.model_meta,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,
@@ -93,11 +93,11 @@ class CachedShardDownloader(ShardDownloader):
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
if (shard.model_card.model_id, shard) in self.cache:
return self.cache[(shard.model_card.model_id, shard)]
if (shard.model_meta.model_id, shard) in self.cache:
return self.cache[(shard.model_meta.model_id, shard)]
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_card.model_id, shard)] = target_dir
self.cache[(shard.model_meta.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(

View File

@@ -5,8 +5,8 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -86,8 +86,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,

View File

@@ -1,10 +1,7 @@
import os
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Callable, Protocol, cast
import mlx.core as mx
import mlx.nn as nn
@@ -32,50 +29,28 @@ from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
TimeoutCallback = Callable[[], None]
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
We require a single positional input of type ``mx.array`` and an
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
protocol matches the vast majority of `mlx.nn.Module` subclasses.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
class CustomMlxLayer(nn.Module):
"""Base class for replacing an MLX layer with a custom implementation."""
def __init__(self, original_layer: nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
object.__setattr__(self, "_original_layer", original_layer)
@property
def original_layer(self) -> nn.Module:
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -88,49 +63,53 @@ class CustomMlxLayer(nn.Module):
return getattr(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.group = group
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
return self.original_layer(x, *args, **kwargs)
def patch_pipeline_first_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)
class PipelineLastLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
s: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.s: int = s
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
rank = group.rank()
class PatchedFirstLayer(nn.Module):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if rank != 0:
x = mx.distributed.recv_like(x, (rank - 1), group=group)
return orig_call(x, *args, **kwargs)
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
pipeline_layer.__class__ = PatchedFirstLayer
return pipeline_layer
output: mx.array = self.original_layer(x, *args, **kwargs)
def patch_pipeline_last_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)
orig_call_sig = signature(orig_call)
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
rank = group.rank()
size = group.size()
class PatchedLastLayer(nn.Module):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = orig_call_sig.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
return output
output: mx.array = orig_call(x, *args, **kwargs)
if rank != size - 1:
output = mx.distributed.send(
output, (rank + 1) % size, group=group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
return output
pipeline_layer.__class__ = PatchedLastLayer
return pipeline_layer
def _inner_model(model: nn.Module) -> nn.Module:
inner = getattr(model, "model", None)
@@ -144,13 +123,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
# Handle both model.layers and model.h cases
layers: list[nn.Module]
layers: list[_LayerCallable]
if hasattr(inner_model_instance, "layers"):
layers = cast(list[nn.Module], inner_model_instance.layers)
layers = cast(list[_LayerCallable], inner_model_instance.layers)
elif hasattr(inner_model_instance, "h"):
layers = cast(list[nn.Module], inner_model_instance.h)
layers = cast(list[_LayerCallable], inner_model_instance.h)
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -175,12 +154,15 @@ def pipeline_auto_parallel(
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = patch_pipeline_first_layer(layers[0], group)
layers[-1] = patch_pipeline_last_layer(
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[-1],
group,
device_rank,
world_size,
group=group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
@@ -243,37 +225,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
return model
def patch_tensor_model[T](model: T) -> T:
"""Patch model's __call__ to ensure distributed ops sync during inference."""
cls = model.__class__
original_call = cls.__call__
call_signature = signature(original_call)
def patched_call(
self: T,
*args: object,
**kwargs: object,
) -> mx.array:
logits: mx.array = original_call(self, *args, **kwargs) # pyright: ignore[reportAny]
cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
"cache", None
)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
cls.__call__ = patched_call
return model
def tensor_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> nn.Module:
all_to_sharded_linear = partial(
shard_linear,
@@ -318,7 +272,7 @@ def tensor_auto_parallel(
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
return model
except (AttributeError, TypeError, NameError):
pass
@@ -368,10 +322,7 @@ def tensor_auto_parallel(
else:
raise ValueError(f"Unsupported model type: {type(model)}")
model = tensor_parallel_sharding_strategy.shard_model(
model, timeout_seconds, on_timeout
)
return patch_tensor_model(model)
return tensor_parallel_sharding_strategy.shard_model(model)
class TensorParallelShardingStrategy(ABC):
@@ -391,27 +342,13 @@ class TensorParallelShardingStrategy(ABC):
self.N = group.size()
@abstractmethod
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module: ...
def shard_model(self, model: nn.Module) -> nn.Module: ...
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(LlamaModel, model)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -427,7 +364,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
@@ -454,17 +391,9 @@ def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -502,31 +431,23 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x) # type: ignore
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -546,24 +467,16 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group
return model
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(Qwen3MoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -580,7 +493,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -593,32 +506,24 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
class ShardedQwenMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x) # type: ignore
y = self.original_layer.__call__(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -642,7 +547,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
layer.mlp.sharding_group = self.group
return model
@@ -655,7 +560,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x) # type: ignore
y = self.original_layer(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y

View File

@@ -2,7 +2,9 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -57,8 +59,6 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
eval_with_timeout,
pipeline_auto_parallel,
tensor_auto_parallel,
)
@@ -75,7 +75,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
/ model_shard_meta.n_layers
* model_shard_meta.model_card.storage_size.in_kb
* model_shard_meta.model_meta.storage_size.in_kb
/ (
1
if isinstance(model_shard_meta, PipelineShardMetadata)
@@ -88,6 +88,41 @@ class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -169,14 +204,19 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
case _:
raise ValueError(
f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
)
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
@@ -206,7 +246,7 @@ def load_mlx_items(
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
@@ -234,7 +274,7 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_card.model_id)
model_path = build_model_path(shard_metadata.model_meta.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
@@ -261,6 +301,14 @@ def shard_and_load(
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
@@ -269,15 +317,7 @@ def shard_and_load(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)
@@ -293,7 +333,7 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
@@ -312,9 +352,6 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-4.7-flash" in model_id_lower:
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
return None

View File

@@ -8,7 +8,6 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
@@ -23,6 +22,7 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -186,11 +186,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_card.model_id not in self.download_status:
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -205,7 +205,7 @@ class Worker:
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = progress
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -339,7 +339,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -356,7 +356,7 @@ class Worker:
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
self.download_status[shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -376,7 +376,7 @@ class Worker:
progress
),
)
self.download_status[shard.model_card.model_id] = status
self.download_status[shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -413,6 +413,11 @@ class Worker:
)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
@@ -433,9 +438,6 @@ class Worker:
for conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn.edge, SocketConnection):
continue
# ignore mDNS discovered connections
if conn.edge.sink_multiaddr.port != 52415:
continue
if (
conn.sink not in conns
or conn.edge.sink_multiaddr.ip_address
@@ -476,7 +478,7 @@ class Worker:
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
self.download_status[progress.shard.model_meta.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)

View File

@@ -2,8 +2,8 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -21,7 +21,11 @@ from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -50,6 +54,16 @@ def plan(
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
) -> Task | None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
# Check plugin tasks first
for plugin in registry.all_plugins():
task = plugin.plan_task(runners, instances)
if task is not None:
return task
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
@@ -113,8 +127,19 @@ def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
instance = runner.bound_instance.instance
# Check if any plugin wants to skip download for this instance
plugin = registry.get_plugin_for_instance(instance)
if plugin is not None and plugin.should_skip_download(instance):
continue
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
@@ -191,7 +216,7 @@ def _load_model(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner

View File

@@ -4,7 +4,10 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import (
BoundInstance,
MlxJacclInstance,
)
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -17,6 +20,7 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
# Set FAST_SYNCH based on env var or JACCL device count
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
@@ -34,11 +38,26 @@ def entrypoint(
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
# Route based on instance type (plugins or default MLX)
try:
from exo.worker.runner.runner import main
from exo.plugins.registry import PluginRegistry, discover_plugins
main(bound_instance, event_sender, task_receiver)
# Discover plugins in subprocess (they aren't inherited from main process)
discover_plugins()
registry = PluginRegistry.get()
instance = bound_instance.instance
# Check if a plugin handles this instance type
plugin = registry.get_plugin_for_instance(instance)
if plugin is not None:
# Delegate to plugin runner
plugin.create_runner(bound_instance, event_sender, task_receiver)
else:
# MLX runner (default)
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -213,7 +213,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_card.model_id,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
@@ -230,7 +230,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_card.model_id,
model=shard_metadata.model_meta.model_id,
text="",
token_id=0,
finish_reason="error",

View File

@@ -1,7 +1,7 @@
from typing import Final
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import TaskId
from exo.shared.types.worker.instances import InstanceId, RunnerId

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -32,8 +32,9 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,

View File

@@ -11,9 +11,9 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
@@ -81,8 +81,9 @@ def run_gpt_oss_pipeline_device(
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
@@ -150,8 +151,9 @@ def run_gpt_oss_tensor_parallel_device(
# For tensor parallelism, all devices run all layers
shard_meta = TensorShardMetadata(
model_card=ModelCard(
model_meta=ModelMetadata(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,

View File

@@ -18,7 +18,6 @@ def _check_model_exists() -> bool:
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",

View File

@@ -76,21 +76,19 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for _, card in MODEL_CARDS.items():
for short_id, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = card.model_id.short().split("-")
parts = short_id.split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (card.model_id.short(), card)
families[family] = (short_id, card)
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
pytestmark = pytest.mark.slow
@pytest.fixture(scope="module")
def event_loop():

View File

@@ -1,7 +1,7 @@
import exo.worker.plan as plan_mod
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance

View File

@@ -82,7 +82,7 @@ async def tb_detection():
send, recv = channel[GatheredInfo]()
ig = InfoGatherer(send)
with anyio.move_on_after(1):
await ig._monitor_system_profiler_thunderbolt_data() # pyright: ignore[reportPrivateUsage]
await ig._monitor_system_profiler() # pyright: ignore[reportPrivateUsage]
with recv:
return recv.collect()
@@ -135,7 +135,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
meta = MODEL_CARDS[test.model_id].metadata
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
@@ -145,15 +145,15 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
model_card=card,
model_meta=meta,
device_rank=i,
world_size=world_size,
start_layer=(card.n_layers // world_size) * i,
start_layer=(meta.n_layers // world_size) * i,
end_layer=min(
card.n_layers, (card.n_layers // world_size) * (i + 1)
meta.n_layers, (meta.n_layers // world_size) * (i + 1)
),
n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
- (card.n_layers // world_size) * i,
n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
- (meta.n_layers // world_size) * i,
)
for i in range(world_size)
},
@@ -224,7 +224,7 @@ async def jaccl_backend(test: Tests):
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
meta = MODEL_CARDS[test.model_id].metadata
world_size = len(test.devs)
return MlxJacclInstance(
@@ -239,12 +239,12 @@ def jaccl_instance(test: Tests, iid: InstanceId):
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_card=card,
model_meta=meta,
device_rank=i,
world_size=world_size,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
start_layer=meta.n_layers,
end_layer=meta.n_layers,
n_layers=meta.n_layers,
)
for i in range(world_size)
},

1496
uv.lock generated
View File

File diff suppressed because it is too large Load Diff