Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-20 11:58:57 -05:00)

Compare commits: foo...sami/flash (13 commits)
| Author | SHA1 | Date |
|---|---|---|
| | cd630dea43 | |
| | e55dae5ce8 | |
| | 302c43afd5 | |
| | 2cf59e2322 | |
| | e506c7d65c | |
| | c1fa2ddeaf | |
| | 37c5a2a246 | |
| | 4d7f03834a | |
| | bdb9fbc8c0 | |
| | 8c7180810c | |
| | 318c6e000b | |
| | 2d45544da0 | |
| | 7cbafa768a | |
dashboard/package-lock.json (generated, 9 lines changed)
@@ -863,6 +863,7 @@
       "integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
       "dev": true,
       "license": "MIT",
+      "peer": true,
       "dependencies": {
         "@standard-schema/spec": "^1.0.0",
         "@sveltejs/acorn-typescript": "^1.0.5",
@@ -902,6 +903,7 @@
       "integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
       "dev": true,
       "license": "MIT",
+      "peer": true,
       "dependencies": {
         "@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
         "debug": "^4.4.1",
@@ -1518,6 +1520,7 @@
       "integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
       "dev": true,
       "license": "MIT",
+      "peer": true,
       "dependencies": {
         "undici-types": "~6.21.0"
       }
@@ -1527,6 +1530,7 @@
       "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
       "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
       "license": "MIT",
+      "peer": true,
       "bin": {
         "acorn": "bin/acorn"
       },
@@ -1939,6 +1943,7 @@
       "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
       "dev": true,
       "license": "ISC",
+      "peer": true,
       "engines": {
         "node": ">=12"
       }
@@ -2646,6 +2651,7 @@
       "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
       "dev": true,
       "license": "MIT",
+      "peer": true,
       "engines": {
         "node": ">=12"
       },
@@ -2833,6 +2839,7 @@
       "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
       "integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
       "license": "MIT",
+      "peer": true,
       "dependencies": {
         "@jridgewell/remapping": "^2.3.4",
         "@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2977,6 +2984,7 @@
       "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
       "dev": true,
       "license": "Apache-2.0",
+      "peer": true,
       "bin": {
         "tsc": "bin/tsc",
         "tsserver": "bin/tsserver"
@@ -2998,6 +3006,7 @@
       "integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
       "dev": true,
       "license": "MIT",
+      "peer": true,
       "dependencies": {
         "esbuild": "^0.25.0",
         "fdir": "^6.4.4",
@@ -434,8 +434,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
   const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
   if (!shardData) return null;

-  // Model meta is nested: shard.model_card.model_id
-  const modelMeta = shardData.model_card ?? shardData.modelCard;
+  // Model meta is nested: shard.model_meta.model_id
+  const modelMeta = shardData.model_meta ?? shardData.modelMeta;
   if (!modelMeta || typeof modelMeta !== 'object') return null;

   const meta = modelMeta as Record<string, unknown>;
@@ -98,7 +98,7 @@
   const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
   if (!shardData) return null;

-  const modelMeta = shardData.model_card ?? shardData.modelCard;
+  const modelMeta = shardData.model_meta ?? shardData.modelMeta;
   if (!modelMeta || typeof modelMeta !== 'object') return null;

   const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
   const shardKeys = Object.keys(shardObj);
   if (shardKeys.length !== 1) return null;
   const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
-  const modelMeta = shardData?.model_card ?? shardData?.modelCard;
+  const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
   if (!modelMeta || typeof modelMeta !== 'object') return null;
   const meta = modelMeta as Record<string, unknown>;
   return (meta.prettyName as string) ?? null;
@@ -17,8 +17,8 @@ dependencies = [
    "loguru>=0.7.3",
    "exo_pyo3_bindings", # rust bindings
    "anyio==4.11.0",
-   "mlx==0.30.3; sys_platform == 'darwin'",
-   "mlx[cpu]==0.30.3; sys_platform == 'linux'",
+   "mlx==0.30.1; sys_platform == 'darwin'",
+   "mlx[cpu]==0.30.1; sys_platform == 'linux'",
    "mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
    "tiktoken>=0.12.0", # required for kimi k2 tokenizer
    "hypercorn>=0.18.0",
@@ -30,6 +30,7 @@ dependencies = [
 exo-master = "exo.master.main:main"
 exo-worker = "exo.worker.main:main"
 exo = "exo.main:main"
+exo-rsh = "exo.rsh.client:main"

 # dependencies only required for development
 [dependency-groups]
src/exo/cli/__init__.py (new file, 32 lines)
@@ -0,0 +1,32 @@
"""Exo CLI - SLURM-compatible job management commands."""


def run_subcommand(command: str, args: list[str]) -> int:
    """Route to the appropriate subcommand handler.

    Args:
        command: The subcommand name (sbatch, squeue, scancel, salloc)
        args: Command line arguments for the subcommand

    Returns:
        Exit code from the subcommand
    """
    if command == "sbatch":
        from exo.cli.sbatch import main

        return main(args)
    elif command == "squeue":
        from exo.cli.squeue import main

        return main(args)
    elif command == "scancel":
        from exo.cli.scancel import main

        return main(args)
    elif command == "salloc":
        from exo.cli.salloc import main

        return main(args)
    else:
        print(f"Unknown subcommand: {command}")
        return 1
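Reviewer note: a minimal sanity check of the dispatcher, using only what this file defines (assumes the exo.cli package from this PR is importable):

from exo.cli import run_subcommand

# An unknown subcommand falls through to the error branch and returns 1.
assert run_subcommand("sinfo", []) == 1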
src/exo/cli/common.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""Common utilities for Exo CLI commands."""

import json
import os
import urllib.request
from typing import Any
from urllib.error import HTTPError, URLError

# Default API endpoint
DEFAULT_API_HOST = "localhost"
DEFAULT_API_PORT = 52415


def get_api_base() -> str:
    """Get the API base URL from environment or defaults."""
    host = os.environ.get("EXO_API_HOST", DEFAULT_API_HOST)
    port = os.environ.get("EXO_API_PORT", str(DEFAULT_API_PORT))
    return f"http://{host}:{port}"


def api_request(
    method: str,
    path: str,
    data: dict[str, Any] | None = None,
) -> dict[str, Any] | list[Any]:
    """Make an API request to the Exo server.

    Args:
        method: HTTP method (GET, POST, DELETE, etc.)
        path: API path (e.g., "/flash/instances")
        data: Optional JSON data for POST/PUT requests

    Returns:
        Parsed JSON response

    Raises:
        SystemExit: On connection or HTTP errors
    """
    url = f"{get_api_base()}{path}"

    request_data = None
    if data is not None:
        request_data = json.dumps(data).encode("utf-8")

    req = urllib.request.Request(
        url,
        data=request_data,
        method=method,
    )
    req.add_header("Content-Type", "application/json")

    try:
        with urllib.request.urlopen(req, timeout=30) as response:  # pyright: ignore[reportAny]
            body: str = response.read().decode("utf-8")  # pyright: ignore[reportAny]
            if body:
                return json.loads(body)  # pyright: ignore[reportAny]
            return {}
    except HTTPError as e:
        error_body = e.read().decode("utf-8") if e.fp else ""
        print(f"API error: {e.code} {e.reason}")
        if error_body:
            try:
                error_json: dict[str, str] = json.loads(error_body)  # pyright: ignore[reportAny]
                if "detail" in error_json:
                    print(f" {error_json['detail']}")
            except json.JSONDecodeError:
                print(f" {error_body}")
        raise SystemExit(1) from None
    except URLError as e:
        print(f"Connection error: {e.reason}")
        print(f"Is Exo running at {get_api_base()}?")
        raise SystemExit(1) from None


def truncate_id(instance_id: str, length: int = 8) -> str:
    """Truncate a UUID for display.

    Args:
        instance_id: Full UUID string
        length: Number of characters to keep

    Returns:
        Truncated ID without hyphens
    """
    return instance_id.replace("-", "")[:length]


def format_table(headers: list[str], rows: list[list[str]]) -> str:
    """Format data as a simple text table.

    Args:
        headers: Column headers
        rows: List of rows, each row is a list of column values

    Returns:
        Formatted table string
    """
    if not rows:
        return " ".join(f"{h:<10}" for h in headers)

    # Calculate column widths
    widths = [len(h) for h in headers]
    for row in rows:
        for i, cell in enumerate(row):
            if i < len(widths):
                widths[i] = max(widths[i], len(cell))

    # Build format string
    fmt = " ".join(f"{{:<{w}}}" for w in widths)

    # Format output
    lines = [fmt.format(*headers)]
    for row in rows:
        # Pad row if needed
        padded = row + [""] * (len(headers) - len(row))
        lines.append(fmt.format(*padded[: len(headers)]))

    return "\n".join(lines)
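Reviewer note: a quick illustration of the two display helpers above (the UUID is invented for the example):

from exo.cli.common import format_table, truncate_id

# truncate_id strips hyphens, then keeps the first 8 characters.
print(truncate_id("abc12345-6789-4abc-8def-0123456789ab"))  # -> "abc12345"
# format_table pads each column to the widest cell.
print(format_table(["JOBID", "STATE"], [["abc12345", "RUNNING"]]))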
src/exo/cli/salloc.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""salloc - Allocate nodes for interactive use.

Usage:
    exo salloc [options] [-- command [args...]]

Options:
    -N, --nodes N   Number of nodes to allocate (default: 1)
    --hosts HOSTS   Comma-separated list of hostnames

If a command is provided after --, it will be executed with
SLURM-like environment variables set:
    SLURM_JOB_NODELIST - Comma-separated list of allocated nodes
    SLURM_NNODES       - Number of allocated nodes

Examples:
    exo salloc --nodes=2 --hosts=node1,node2 -- mpirun ./my_program
    exo salloc --hosts=localhost -- bash
"""

import argparse
import os
import subprocess
import sys


def main(args: list[str]) -> int:
    """Main entry point for salloc command."""
    # Split args at -- if present
    cmd_args: list[str] = []
    salloc_args = args

    if "--" in args:
        idx = args.index("--")
        salloc_args = args[:idx]
        cmd_args = args[idx + 1 :]

    parser = argparse.ArgumentParser(
        prog="exo salloc",
        description="Allocate nodes for interactive use",
    )
    parser.add_argument(
        "-N",
        "--nodes",
        type=int,
        default=1,
        help="Number of nodes to allocate (default: 1)",
    )
    parser.add_argument(
        "--hosts",
        help="Comma-separated list of hostnames (required)",
    )

    parsed = parser.parse_args(salloc_args)

    nodes: int = parsed.nodes  # pyright: ignore[reportAny]
    hosts: str | None = parsed.hosts  # pyright: ignore[reportAny]

    # Require explicit hosts since we can't discover them from topology
    if not hosts:
        print("Error: --hosts is required (e.g., --hosts=node1,node2)", file=sys.stderr)
        print("The Exo topology doesn't expose hostnames.", file=sys.stderr)
        return 1

    host_list = [h.strip() for h in hosts.split(",") if h.strip()]

    if len(host_list) < nodes:
        print(
            f"Error: Requested {nodes} nodes but only {len(host_list)} hosts provided",
            file=sys.stderr,
        )
        return 1

    # Use first N hosts
    allocated_hosts = host_list[:nodes]
    nodelist = ",".join(allocated_hosts)

    # Set environment variables
    env = os.environ.copy()
    env["SLURM_JOB_NODELIST"] = nodelist
    env["SLURM_NNODES"] = str(nodes)

    print(f"salloc: Granted job allocation on {nodes} node(s)")
    print(f"salloc: Nodes: {nodelist}")

    if cmd_args:
        # Run the command
        print(f"salloc: Running: {' '.join(cmd_args)}")
        result = subprocess.run(cmd_args, env=env)
        return result.returncode
    else:
        # Start interactive shell
        shell = os.environ.get("SHELL", "/bin/bash")
        print(f"salloc: Starting shell {shell}")
        print("salloc: Use 'exit' to release allocation")
        result = subprocess.run([shell], env=env)
        return result.returncode


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
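Reviewer note: a smoke test of the allocation logic (the host names are made up; `env` is the standard POSIX env binary, so this only assumes a Unix-like system):

from exo.cli.salloc import main

# Runs `env` under a fake 2-node allocation; SLURM_JOB_NODELIST and
# SLURM_NNODES appear in the output because salloc injects them.
main(["--nodes=2", "--hosts=node1,node2", "--", "env"])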
src/exo/cli/sbatch.py (new file, 233 lines)
@@ -0,0 +1,233 @@
"""sbatch - Submit a batch job to Exo.

Usage:
    exo sbatch [options] <script|executable>
    exo sbatch --job-name=NAME --nodes=N <executable>

Options:
    -J, --job-name NAME   Job name
    -N, --nodes N         Number of nodes (default: 1)
    --ntasks-per-node N   Tasks per node (default: 1)
    -D, --chdir DIR       Working directory
    --hosts HOSTS         Comma-separated list of hostnames

Job scripts can contain #SBATCH directives:
    #!/bin/bash
    #SBATCH --job-name=Sod2D
    #SBATCH --nodes=2
    #SBATCH --chdir=/path/to/workdir

    /path/to/flash4
"""

import argparse
import os
import re
import sys

from exo.cli.common import api_request, truncate_id


def parse_job_script(script_path: str) -> tuple[dict[str, str], str | None]:
    """Parse a job script for #SBATCH directives and executable.

    Args:
        script_path: Path to the job script

    Returns:
        Tuple of (directives dict, executable path or None)
    """
    directives: dict[str, str] = {}
    executable: str | None = None

    with open(script_path, "r") as f:
        for line in f:
            line = line.strip()

            # Parse #SBATCH directives
            if line.startswith("#SBATCH"):
                # Handle both --option=value and --option value formats
                match = re.match(r"#SBATCH\s+(-\w|--[\w-]+)(?:=|\s+)(.+)", line)
                if match:
                    opt, val = match.groups()
                    directives[opt.lstrip("-")] = val.strip()
                continue

            # Skip comments and empty lines
            if line.startswith("#") or not line:
                continue

            # First non-comment, non-directive line is the executable
            if executable is None:
                # Handle lines like "/path/to/flash4" or "srun /path/to/flash4"
                parts = line.split()
                if parts:
                    # Skip srun/mpirun prefixes if present
                    for part in parts:
                        if not part.startswith("-") and "/" in part:
                            executable = part
                            break
                    if executable is None and parts:
                        executable = parts[-1]  # Last token

    return directives, executable


def main(args: list[str]) -> int:
    """Main entry point for sbatch command."""
    parser = argparse.ArgumentParser(
        prog="exo sbatch",
        description="Submit a batch job to Exo",
    )
    parser.add_argument(
        "script",
        help="Job script or executable path",
    )
    parser.add_argument(
        "-J",
        "--job-name",
        dest="job_name",
        help="Job name",
    )
    parser.add_argument(
        "-N",
        "--nodes",
        type=int,
        default=1,
        help="Number of nodes (default: 1)",
    )
    parser.add_argument(
        "--ntasks-per-node",
        type=int,
        default=1,
        help="Tasks per node (default: 1)",
    )
    parser.add_argument(
        "-D",
        "--chdir",
        help="Working directory",
    )
    parser.add_argument(
        "--hosts",
        help="Comma-separated list of hostnames",
    )

    parsed = parser.parse_args(args)

    # Extract typed values from namespace
    script_path: str = parsed.script  # pyright: ignore[reportAny]
    arg_job_name: str | None = parsed.job_name  # pyright: ignore[reportAny]
    arg_nodes: int = parsed.nodes  # pyright: ignore[reportAny]
    arg_ntasks: int = parsed.ntasks_per_node  # pyright: ignore[reportAny]
    arg_chdir: str | None = parsed.chdir  # pyright: ignore[reportAny]
    arg_hosts: str | None = parsed.hosts  # pyright: ignore[reportAny]

    # Determine if input is a script or direct executable
    executable: str | None = None
    directives: dict[str, str] = {}

    if os.path.isfile(script_path):
        # Check if it's a binary file (executable) or text script
        is_binary = False
        try:
            with open(script_path, "rb") as f:
                chunk = f.read(512)
                # Binary files typically contain null bytes
                is_binary = b"\x00" in chunk
        except OSError:
            pass

        if is_binary:
            # It's a binary executable
            executable = script_path
        else:
            # Try to read as text
            try:
                with open(script_path, "r") as f:
                    first_line = f.readline()
                    f.seek(0)
                    content = f.read(1024)

                if first_line.startswith("#!") or "#SBATCH" in content:
                    # It's a job script - parse it
                    directives, executable = parse_job_script(script_path)
                else:
                    # It's an executable (text but no shebang/directives)
                    executable = script_path
            except UnicodeDecodeError:
                # Can't read as text - treat as binary executable
                executable = script_path
    else:
        # Not a file - treat as executable path
        executable = script_path

    if executable is None:
        print("Error: No executable found in job script", file=sys.stderr)
        return 1

    # Build job parameters - CLI args override script directives
    job_name = arg_job_name or directives.get("job-name") or directives.get("J")
    if not job_name:
        # Generate name from executable
        job_name = os.path.basename(executable).replace(".", "_")

    nodes = arg_nodes
    if "nodes" in directives:
        nodes = int(directives["nodes"])
    if "N" in directives:
        nodes = int(directives["N"])
    if arg_nodes != 1:  # CLI override
        nodes = arg_nodes

    ntasks = arg_ntasks
    if "ntasks-per-node" in directives:
        ntasks = int(directives["ntasks-per-node"])
    if arg_ntasks != 1:  # CLI override
        ntasks = arg_ntasks

    workdir = arg_chdir or directives.get("chdir") or directives.get("D")
    if not workdir:
        workdir = os.getcwd()

    hosts = arg_hosts or directives.get("hosts") or ""

    # Resolve executable to absolute path
    if not os.path.isabs(executable):
        executable = os.path.abspath(os.path.join(workdir, executable))

    # Submit job via API using query parameters
    from urllib.parse import urlencode

    params = {
        "simulation_name": job_name,
        "flash_executable_path": executable,
        "parameter_file_path": "",  # FLASH par file - use default
        "working_directory": workdir,
        "ranks_per_node": str(ntasks),
        "min_nodes": str(nodes),
        "hosts": hosts,
    }

    query_string = urlencode(params)
    result = api_request("POST", f"/flash/launch?{query_string}")

    # Print job submission confirmation
    if isinstance(result, dict):
        instance_id_val = result.get("instance_id")

        if instance_id_val is not None:
            job_id = truncate_id(str(instance_id_val))  # pyright: ignore[reportAny]
            print(f"Submitted batch job {job_id}")
        else:
            # Instance created asynchronously - user should check squeue
            print("Job submitted successfully")
            print("Use 'exo squeue' to view job ID")
    else:
        print("Job submitted successfully")
        print("Use 'exo squeue' to view job ID")

    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
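Reviewer note: a minimal check of the directive parser (the script body is fabricated for illustration):

import tempfile
from exo.cli.sbatch import parse_job_script

script = "#!/bin/bash\n#SBATCH --job-name=Sod2D\n#SBATCH --nodes=2\n/path/to/flash4\n"
with tempfile.NamedTemporaryFile("w", suffix=".sh", delete=False) as f:
    f.write(script)
    path = f.name

directives, executable = parse_job_script(path)
# directives == {"job-name": "Sod2D", "nodes": "2"}; executable == "/path/to/flash4"
print(directives, executable)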
src/exo/cli/scancel.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""scancel - Cancel jobs in the Exo queue.

Usage:
    exo scancel <jobid> [<jobid>...]

Arguments:
    jobid   Job ID (or prefix) to cancel. Can specify multiple.

Examples:
    exo scancel abc123          # Cancel job starting with abc123
    exo scancel abc123 def456   # Cancel multiple jobs
"""

import argparse
import sys
from typing import Any, cast

from exo.cli.common import api_request, truncate_id


def main(args: list[str]) -> int:
    """Main entry point for scancel command."""
    parser = argparse.ArgumentParser(
        prog="exo scancel",
        description="Cancel jobs in the Exo queue",
    )
    parser.add_argument(
        "jobids",
        nargs="+",
        help="Job ID(s) to cancel",
    )

    parsed = parser.parse_args(args)
    jobids: list[str] = parsed.jobids  # pyright: ignore[reportAny]

    # Fetch current jobs to resolve partial IDs
    result = api_request("GET", "/flash/instances")
    if isinstance(result, list):
        instances = cast(list[dict[str, Any]], result)
    else:
        instances = cast(list[dict[str, Any]], result.get("instances", []))

    # Build lookup of full IDs
    id_map: dict[str, str] = {}
    for inst in instances:
        iid = inst.get("instance_id", "")  # pyright: ignore[reportAny]
        full_id = str(iid) if iid else ""  # pyright: ignore[reportAny]
        if full_id:
            # Map both full ID and truncated versions
            normalized = full_id.replace("-", "").lower()
            id_map[normalized] = full_id
            # Also map prefixes
            for length in range(4, len(normalized) + 1):
                prefix = normalized[:length]
                if prefix not in id_map:
                    id_map[prefix] = full_id

    cancelled = 0
    errors = 0

    for jobid in jobids:
        search = jobid.lower().replace("-", "")

        # Find matching full ID
        full_id = id_map.get(search)
        if not full_id:
            # Try prefix match
            matches = [fid for key, fid in id_map.items() if key.startswith(search)]
            if len(matches) == 1:
                full_id = matches[0]
            elif len(matches) > 1:
                print(f"Ambiguous job ID: {jobid} matches multiple jobs")
                errors += 1
                continue
            else:
                print(f"Job not found: {jobid}")
                errors += 1
                continue

        # Cancel the job
        try:
            api_request("DELETE", f"/flash/{full_id}")
            print(f"Job {truncate_id(full_id)} cancelled")
            cancelled += 1
        except SystemExit:
            print(f"Failed to cancel job {truncate_id(full_id)}")
            errors += 1

    if errors > 0 and cancelled == 0:
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
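Reviewer note: how the prefix resolution behaves, sketched with a fabricated UUID:

# id_map maps every normalized prefix (length >= 4) to the full ID, so
# "exo scancel abc1" resolves as long as the prefix is unambiguous.
full_id = "abc12345-6789-4abc-8def-0123456789ab"
normalized = full_id.replace("-", "").lower()
id_map = {normalized[:n]: full_id for n in range(4, len(normalized) + 1)}
assert id_map["abc1"] == full_id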
src/exo/cli/squeue.py (new file, 165 lines)
@@ -0,0 +1,165 @@
"""squeue - View the Exo job queue.

Usage:
    exo squeue [options]

Options:
    -l, --long    Show detailed output
    -j, --job ID  Show only this job

Output columns:
    JOBID - Job identifier (truncated UUID)
    NAME  - Job name
    NODES - Number of nodes
    STATE - Job state (PENDING, RUNNING, FAILED, etc.)
"""

import argparse
import sys
from typing import Any, cast

from exo.cli.common import api_request, format_table, truncate_id

# Map Exo runner statuses to SLURM-like states
STATUS_MAP: dict[str, str] = {
    "RunnerIdle": "PENDING",
    "RunnerConnecting": "CONFIGURING",
    "RunnerConnected": "CONFIGURING",
    "RunnerLoading": "CONFIGURING",
    "RunnerLoaded": "CONFIGURING",
    "RunnerWarmingUp": "CONFIGURING",
    "RunnerReady": "COMPLETING",
    "RunnerRunning": "RUNNING",
    "RunnerShuttingDown": "COMPLETING",
    "RunnerShutdown": "COMPLETED",
    "RunnerFailed": "FAILED",
}


def get_job_state(runner_statuses: dict[str, Any]) -> str:
    """Determine overall job state from runner statuses."""
    if not runner_statuses:
        return "PENDING"

    states: set[str] = set()
    for status_val in runner_statuses.values():  # pyright: ignore[reportAny]
        if isinstance(status_val, dict):
            # Extract status type from discriminated union
            type_val = status_val.get("type", "RunnerIdle")  # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
            status_type = str(type_val) if type_val else "RunnerIdle"  # pyright: ignore[reportUnknownArgumentType]
        elif isinstance(status_val, str):
            status_type = status_val
        else:
            status_type = "RunnerIdle"
        # Strip parentheses from status strings like "RunnerRunning()"
        if status_type.endswith("()"):
            status_type = status_type[:-2]
        states.add(STATUS_MAP.get(status_type, "UNKNOWN"))

    # Priority order for overall state
    if "FAILED" in states:
        return "FAILED"
    if "RUNNING" in states:
        return "RUNNING"
    if "CONFIGURING" in states:
        return "CONFIGURING"
    if "COMPLETING" in states:
        return "COMPLETING"
    if "COMPLETED" in states:
        return "COMPLETED"
    if "PENDING" in states:
        return "PENDING"
    return "UNKNOWN"


def main(args: list[str]) -> int:
    """Main entry point for squeue command."""
    parser = argparse.ArgumentParser(
        prog="exo squeue",
        description="View the Exo job queue",
    )
    parser.add_argument(
        "-l",
        "--long",
        action="store_true",
        help="Show detailed output",
    )
    parser.add_argument(
        "-j",
        "--job",
        help="Show only this job ID",
    )

    parsed = parser.parse_args(args)

    # Extract typed values
    long_format: bool = parsed.long  # pyright: ignore[reportAny]
    job_filter: str | None = parsed.job  # pyright: ignore[reportAny]

    # Fetch jobs from API - returns list directly
    result = api_request("GET", "/flash/instances")
    # API returns list directly, not {"instances": [...]}
    if isinstance(result, list):
        instances = cast(list[dict[str, Any]], result)
    else:
        instances = cast(list[dict[str, Any]], result.get("instances", []))

    if not instances:
        # No jobs - just print header
        if long_format:
            print("JOBID NAME NODES RANKS STATE WORKDIR")
        else:
            print("JOBID NAME NODES STATE")
        return 0

    # Filter by job ID if specified
    if job_filter:
        search = job_filter.lower()
        filtered: list[dict[str, Any]] = []
        for i in instances:
            iid = i.get("instance_id", "")  # pyright: ignore[reportAny]
            if search in str(iid).lower().replace("-", ""):  # pyright: ignore[reportAny]
                filtered.append(i)
        instances = filtered

    # Build table
    rows: list[list[str]] = []

    if long_format:
        headers = ["JOBID", "NAME", "NODES", "RANKS", "STATE", "WORKDIR"]
        for inst in instances:
            iid_val = inst.get("instance_id", "")  # pyright: ignore[reportAny]
            instance_id = str(iid_val) if iid_val else ""  # pyright: ignore[reportAny]
            job_id = truncate_id(instance_id, 12)
            name_val = inst.get("simulation_name", "")  # pyright: ignore[reportAny]
            name = (str(name_val) if name_val else "")[:15]  # pyright: ignore[reportAny]
            runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
            nodes = str(len(runner_statuses))
            ranks_val = inst.get("total_ranks", 0)  # pyright: ignore[reportAny]
            ranks = str(ranks_val) if ranks_val else "0"  # pyright: ignore[reportAny]
            state = get_job_state(runner_statuses)
            workdir_val = inst.get("working_directory", "")  # pyright: ignore[reportAny]
            workdir = str(workdir_val) if workdir_val else ""  # pyright: ignore[reportAny]
            # Truncate workdir for display
            if len(workdir) > 30:
                workdir = "..." + workdir[-27:]
            rows.append([job_id, name, nodes, ranks, state, workdir])
    else:
        headers = ["JOBID", "NAME", "NODES", "STATE"]
        for inst in instances:
            iid_val = inst.get("instance_id", "")  # pyright: ignore[reportAny]
            instance_id = str(iid_val) if iid_val else ""  # pyright: ignore[reportAny]
            job_id = truncate_id(instance_id, 8)
            name_val = inst.get("simulation_name", "")  # pyright: ignore[reportAny]
            name = (str(name_val) if name_val else "")[:15]  # pyright: ignore[reportAny]
            runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
            nodes = str(len(runner_statuses))
            state = get_job_state(runner_statuses)
            rows.append([job_id, name, nodes, state])

    print(format_table(headers, rows))
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
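Reviewer note: a small check of the state aggregation (runner statuses invented for the example):

from exo.cli.squeue import get_job_state

# FAILED wins over RUNNING in the priority order above; the "()" suffix
# on stringly-typed statuses is stripped before the STATUS_MAP lookup.
statuses = {"r1": "RunnerRunning()", "r2": {"type": "RunnerFailed"}}
assert get_job_state(statuses) == "FAILED"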
@@ -195,6 +195,14 @@ class Node:


 def main():
+    # Check for SLURM-compatible subcommands first
+    import sys
+
+    if len(sys.argv) > 1 and sys.argv[1] in ("sbatch", "squeue", "scancel", "salloc"):
+        from exo.cli import run_subcommand
+
+        sys.exit(run_subcommand(sys.argv[1], sys.argv[2:]))
+
     args = Args.parse()
     soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
     resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
@@ -205,6 +213,11 @@ def main():
     logger.info("Starting EXO")
     logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")

+    # Discover and register plugins
+    from exo.plugins.registry import discover_plugins
+
+    discover_plugins()
+
     # Set FAST_SYNCH override env var for runner subprocesses
     if args.fast_synch is True:
         os.environ["EXO_FAST_SYNCH"] = "on"
@@ -1,7 +1,9 @@
 import asyncio
+import os
 import time
 from collections.abc import AsyncGenerator
 from http import HTTPStatus
-from typing import cast
+from typing import Any, Optional, cast

 import anyio
 from anyio import BrokenResourceError, create_task_group
@@ -14,13 +16,14 @@ from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]
 from hypercorn.config import Config
 from hypercorn.typing import ASGIFramework
 from loguru import logger
+from pydantic import BaseModel

 from exo.master.placement import place_instance as get_instance_placements
 from exo.shared.apply import apply
 from exo.shared.election import ElectionMessage
 from exo.shared.logging import InterceptLogger
-from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
-from exo.shared.models.model_meta import get_model_card
+from exo.shared.models.model_cards import MODEL_CARDS
+from exo.shared.models.model_meta import get_model_meta
 from exo.shared.types.api import (
     BenchChatCompletionResponse,
     BenchChatCompletionTaskParams,
@@ -59,9 +62,14 @@ from exo.shared.types.events import (
     IndexedEvent,
 )
 from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.state import State
 from exo.shared.types.tasks import ChatCompletionTaskParams
-from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
+from exo.shared.types.worker.instances import (
+    Instance,
+    InstanceId,
+    InstanceMeta,
+)
 from exo.shared.types.worker.shards import Sharding
 from exo.utils.banner import print_startup_banner
 from exo.utils.channels import Receiver, Sender, channel
@@ -69,6 +77,22 @@ from exo.utils.dashboard_path import find_dashboard
 from exo.utils.event_buffer import OrderedBuffer


+class ExecuteRequest(BaseModel):
+    """Request to execute a command."""
+
+    command: list[str]
+    cwd: Optional[str] = None
+    env: Optional[dict[str, str]] = None
+
+
+class ExecuteResponse(BaseModel):
+    """Response from command execution."""
+
+    exit_code: int
+    stdout: str
+    stderr: str
+
+
 def chunk_to_response(
     chunk: TokenChunk, command_id: CommandId
 ) -> ChatCompletionResponse:
@@ -86,12 +110,12 @@ def chunk_to_response(
     )


-async def resolve_model_card(model_id: str) -> ModelCard:
+async def resolve_model_meta(model_id: str) -> ModelMetadata:
     if model_id in MODEL_CARDS:
         model_card = MODEL_CARDS[model_id]
-        return model_card
+        return model_card.metadata
     else:
-        return await get_model_card(model_id)
+        return await get_model_meta(model_id)


 class API:
@@ -125,6 +149,7 @@ class API:
         self._setup_exception_handlers()
         self._setup_cors()
         self._setup_routes()
+        self._register_plugin_routes()

         self.app.mount(
             "/",
@@ -193,10 +218,62 @@ class API:
         self.app.post("/bench/chat/completions")(self.bench_chat_completions)
         self.app.get("/state")(lambda: self.state)
         self.app.get("/events")(lambda: self._event_log)
+        # Remote execution endpoint (used by exo-rsh for MPI)
+        self.app.post("/execute")(self.execute)
+
+    def _register_plugin_routes(self) -> None:
+        """Register API routes from all loaded plugins."""
+        import functools
+        import inspect
+
+        from exo.plugins.context import PluginContext
+        from exo.plugins.registry import PluginRegistry
+
+        registry = PluginRegistry.get()
+
+        for method, path, handler, plugin in registry.get_all_api_routes():
+            # Create a wrapper that injects the plugin context while preserving
+            # the original function signature for FastAPI parameter extraction
+            @functools.wraps(handler)
+            async def wrapped_handler(  # pyright: ignore[reportAny]
+                *args: Any,  # pyright: ignore[reportAny]
+                _handler: Any = handler,  # pyright: ignore[reportAny]
+                **kwargs: Any,  # pyright: ignore[reportAny]
+            ) -> Any:
+                context = PluginContext(
+                    state=self.state,
+                    send_command=self._send,
+                    node_id=self.node_id,
+                )
+                # Pass context as first argument, then forward all other args
+                return await _handler(context, *args, **kwargs)  # pyright: ignore[reportAny]
+
+            # Modify the wrapper signature to match the original handler
+            # but without the 'ctx' parameter (we inject it)
+            orig_sig = inspect.signature(handler)
+            params = list(orig_sig.parameters.values())
+            # Remove the first parameter (ctx: PluginContext)
+            if params and params[0].name in ("ctx", "context"):
+                params = params[1:]
+            wrapped_handler.__signature__ = orig_sig.replace(parameters=params)  # pyright: ignore[reportAttributeAccessIssue]
+
+            # Register the route based on HTTP method
+            if method == "get":
+                self.app.get(path)(wrapped_handler)
+            elif method == "post":
+                self.app.post(path)(wrapped_handler)
+            elif method == "delete":
+                self.app.delete(path)(wrapped_handler)
+            elif method == "put":
+                self.app.put(path)(wrapped_handler)
+
+            logger.debug(
+                f"Registered plugin route: {method.upper()} {path} ({plugin.name})"
+            )

     async def place_instance(self, payload: PlaceInstanceParams):
         command = PlaceInstance(
-            model_card=await resolve_model_card(payload.model_id),
+            model_meta=await resolve_model_meta(payload.model_id),
             sharding=payload.sharding,
             instance_meta=payload.instance_meta,
             min_nodes=payload.min_nodes,
@@ -206,15 +283,15 @@ class API:
         return CreateInstanceResponse(
             message="Command received.",
             command_id=command.command_id,
-            model_card=command.model_card,
+            model_meta=command.model_meta,
         )

     async def create_instance(
         self, payload: CreateInstanceParams
     ) -> CreateInstanceResponse:
         instance = payload.instance
-        model_card = await resolve_model_card(instance.shard_assignments.model_id)
-        required_memory = model_card.storage_size
+        model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
+        required_memory = model_meta.storage_size
         available_memory = self._calculate_total_available_memory()

         if required_memory > available_memory:
@@ -231,7 +308,7 @@ class API:
         return CreateInstanceResponse(
             message="Command received.",
             command_id=command.command_id,
-            model_card=model_card,
+            model_meta=model_meta,
         )

     async def get_placement(
@@ -241,12 +318,12 @@ class API:
         instance_meta: InstanceMeta = InstanceMeta.MlxRing,
         min_nodes: int = 1,
     ) -> Instance:
-        model_card = await resolve_model_card(model_id)
+        model_meta = await resolve_model_meta(model_id)

         try:
             placements = get_instance_placements(
                 PlaceInstance(
-                    model_card=model_card,
+                    model_meta=model_meta,
                     sharding=sharding,
                     instance_meta=instance_meta,
                     min_nodes=min_nodes,
@@ -279,7 +356,7 @@ class API:
         if len(list(self.state.topology.list_nodes())) == 0:
             return PlacementPreviewResponse(previews=[])

-        cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
+        cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
         if not cards:
             raise HTTPException(status_code=404, detail=f"Model {model_id} not found")

@@ -297,12 +374,13 @@ class API:
         # TODO: PDD
         # instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))

-        for model_card in cards:
+        for card in cards:
+            model_meta = card.metadata
             for sharding, instance_meta, min_nodes in instance_combinations:
                 try:
                     placements = get_instance_placements(
                         PlaceInstance(
-                            model_card=model_card,
+                            model_meta=model_meta,
                             sharding=sharding,
                             instance_meta=instance_meta,
                             min_nodes=min_nodes,
@@ -313,17 +391,17 @@ class API:
                         current_instances=self.state.instances,
                     )
                 except ValueError as exc:
-                    if (model_card.model_id, sharding, instance_meta, 0) not in seen:
+                    if (card.model_id, sharding, instance_meta, 0) not in seen:
                         previews.append(
                             PlacementPreview(
-                                model_id=model_card.model_id,
+                                model_id=card.model_id,
                                 sharding=sharding,
                                 instance_meta=instance_meta,
                                 instance=None,
                                 error=str(exc),
                             )
                         )
-                        seen.add((model_card.model_id, sharding, instance_meta, 0))
+                        seen.add((card.model_id, sharding, instance_meta, 0))
                     continue

                 current_ids = set(self.state.instances.keys())
@@ -334,17 +412,17 @@ class API:
                 ]

                 if len(new_instances) != 1:
-                    if (model_card.model_id, sharding, instance_meta, 0) not in seen:
+                    if (card.model_id, sharding, instance_meta, 0) not in seen:
                         previews.append(
                             PlacementPreview(
-                                model_id=model_card.model_id,
+                                model_id=card.model_id,
                                 sharding=sharding,
                                 instance_meta=instance_meta,
                                 instance=None,
                                 error="Expected exactly one new instance from placement",
                             )
                         )
-                        seen.add((model_card.model_id, sharding, instance_meta, 0))
+                        seen.add((card.model_id, sharding, instance_meta, 0))
                     continue

                 instance = new_instances[0]
@@ -353,7 +431,7 @@ class API:

                 memory_delta_by_node: dict[str, int] = {}
                 if node_ids:
-                    total_bytes = model_card.storage_size.in_bytes
+                    total_bytes = model_meta.storage_size.in_bytes
                     per_node = total_bytes // len(node_ids)
                     remainder = total_bytes % len(node_ids)
                     for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -361,14 +439,14 @@ class API:
                         memory_delta_by_node[str(node_id)] = per_node + extra

                 if (
-                    model_card.model_id,
+                    card.model_id,
                     sharding,
                     instance_meta,
                     len(node_ids),
                 ) not in seen:
                     previews.append(
                         PlacementPreview(
-                            model_id=model_card.model_id,
+                            model_id=card.model_id,
                             sharding=sharding,
                             instance_meta=instance_meta,
                             instance=instance,
@@ -376,7 +454,7 @@ class API:
                             error=None,
                         )
                     )
-                    seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
+                    seen.add((card.model_id, sharding, instance_meta, len(node_ids)))

         return PlacementPreviewResponse(previews=previews)

@@ -551,8 +629,8 @@ class API:
         self, payload: ChatCompletionTaskParams
     ) -> ChatCompletionResponse | StreamingResponse:
         """Handle chat completions, supporting both streaming and non-streaming responses."""
-        model_card = await resolve_model_card(payload.model)
-        payload.model = model_card.model_id
+        model_meta = await resolve_model_meta(payload.model)
+        payload.model = model_meta.model_id

         if not any(
             instance.shard_assignments.model_id == payload.model
@@ -578,8 +656,8 @@ class API:
     async def bench_chat_completions(
         self, payload: BenchChatCompletionTaskParams
     ) -> BenchChatCompletionResponse:
-        model_card = await resolve_model_card(payload.model)
-        payload.model = model_card.model_id
+        model_meta = await resolve_model_meta(payload.model)
+        payload.model = model_meta.model_id

         if not any(
             instance.shard_assignments.model_id == payload.model
@@ -612,18 +690,77 @@ class API:
         return ModelList(
             data=[
                 ModelListModel(
-                    id=card.model_id,
+                    id=card.short_id,
+                    hugging_face_id=card.model_id,
-                    name=card.model_id.short(),
-                    description="",
-                    tags=[],
-                    storage_size_megabytes=int(card.storage_size.in_mb),
-                    supports_tensor=card.supports_tensor,
+                    name=card.name,
+                    description=card.description,
+                    tags=card.tags,
+                    storage_size_megabytes=int(card.metadata.storage_size.in_mb),
+                    supports_tensor=card.metadata.supports_tensor,
                 )
                 for card in MODEL_CARDS.values()
             ]
         )

+    async def execute(self, request: ExecuteRequest) -> ExecuteResponse:
+        """Execute a command locally. Used by exo-rsh for MPI remote execution."""
+        cmd_str = " ".join(request.command)
+        logger.info(f"Executing: {cmd_str}")
+
+        try:
+            # Build environment
+            env = os.environ.copy()
+            if request.env:
+                env.update(request.env)
+
+            # Check if command contains shell metacharacters
+            # If so, run through shell. mpirun sends complex commands like:
+            # "VAR=value;export VAR;/path/to/prted --args"
+            needs_shell = any(c in cmd_str for c in ";|&$`")
+
+            if needs_shell:
+                process = await asyncio.create_subprocess_shell(
+                    cmd_str,
+                    stdout=asyncio.subprocess.PIPE,
+                    stderr=asyncio.subprocess.PIPE,
+                    cwd=request.cwd,
+                    env=env,
+                )
+            else:
+                process = await asyncio.create_subprocess_exec(
+                    *request.command,
+                    stdout=asyncio.subprocess.PIPE,
+                    stderr=asyncio.subprocess.PIPE,
+                    cwd=request.cwd,
+                    env=env,
+                )
+
+            stdout, stderr = await process.communicate()
+            exit_code = process.returncode or 0
+
+            logger.info(f"Command completed with exit code {exit_code}")
+
+            return ExecuteResponse(
+                exit_code=exit_code,
+                stdout=stdout.decode("utf-8", errors="replace"),
+                stderr=stderr.decode("utf-8", errors="replace"),
+            )
+
+        except FileNotFoundError:
+            logger.error(f"Command not found: {request.command[0]}")
+            return ExecuteResponse(
+                exit_code=127,
+                stdout="",
+                stderr=f"Command not found: {request.command[0]}",
+            )
+        except Exception as e:
+            logger.error(f"Execution error: {e}")
+            return ExecuteResponse(
+                exit_code=1,
+                stdout="",
+                stderr=str(e),
+            )
+
     async def run(self):
         cfg = Config()
         cfg.bind = f"0.0.0.0:{self.port}"
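Reviewer note: the new /execute endpoint can be driven with the api_request helper added in this PR. A hedged sketch (assumes a master listening on the default port and a system `echo` on PATH):

from exo.cli.common import api_request

resp = api_request(
    "POST",
    "/execute",
    data={"command": ["echo", "hello"], "cwd": None, "env": {"MY_VAR": "1"}},
)
# Expected shape, per ExecuteResponse: {"exit_code": 0, "stdout": "hello\n", "stderr": ""}
print(resp)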
@@ -96,104 +96,128 @@ class Master:
             self._tg.cancel_scope.cancel()

     async def _command_processor(self) -> None:
+        from exo.plugins.registry import PluginRegistry
+
+        registry = PluginRegistry.get()
+
         with self.command_receiver as commands:
             async for forwarder_command in commands:
                 try:
                     logger.info(f"Executing command: {forwarder_command.command}")
                     generated_events: list[Event] = []
                     command = forwarder_command.command
-                    match command:
-                        case TestCommand():
-                            pass
-                        case ChatCompletion():
-                            instance_task_counts: dict[InstanceId, int] = {}
-                            for instance in self.state.instances.values():
-                                if (
-                                    instance.shard_assignments.model_id
-                                    == command.request_params.model
-                                ):
-                                    task_count = sum(
-                                        1
-                                        for task in self.state.tasks.values()
-                                        if task.instance_id == instance.instance_id
-                                    )
-                                    instance_task_counts[instance.instance_id] = (
-                                        task_count
-                                    )
-
-                            if not instance_task_counts:
-                                raise ValueError(
-                                    f"No instance found for model {command.request_params.model}"
-                                )
-
-                            available_instance_ids = sorted(
-                                instance_task_counts.keys(),
-                                key=lambda instance_id: instance_task_counts[
-                                    instance_id
-                                ],
-                            )
-
-                            task_id = TaskId()
-                            generated_events.append(
-                                TaskCreated(
-                                    task_id=task_id,
-                                    task=ChatCompletionTask(
-                                        task_id=task_id,
-                                        command_id=command.command_id,
-                                        instance_id=available_instance_ids[0],
-                                        task_status=TaskStatus.Pending,
-                                        task_params=command.request_params,
-                                    ),
-                                )
-                            )
-
-                            self.command_task_mapping[command.command_id] = task_id
-                        case DeleteInstance():
-                            placement = delete_instance(command, self.state.instances)
-                            transition_events = get_transition_events(
-                                self.state.instances, placement
-                            )
-                            generated_events.extend(transition_events)
-                        case PlaceInstance():
-                            placement = place_instance(
-                                command,
-                                self.state.topology,
-                                self.state.instances,
-                                self.state.node_memory,
-                                self.state.node_network,
-                            )
-                            transition_events = get_transition_events(
-                                self.state.instances, placement
-                            )
-                            generated_events.extend(transition_events)
-                        case CreateInstance():
-                            placement = add_instance_to_placements(
-                                command,
-                                self.state.topology,
-                                self.state.instances,
-                            )
-                            transition_events = get_transition_events(
-                                self.state.instances, placement
-                            )
-                            generated_events.extend(transition_events)
-                        case TaskFinished():
-                            generated_events.append(
-                                TaskDeleted(
-                                    task_id=self.command_task_mapping[
-                                        command.finished_command_id
-                                    ]
-                                )
-                            )
-                            if command.finished_command_id in self.command_task_mapping:
-                                del self.command_task_mapping[
-                                    command.finished_command_id
-                                ]
-                        case RequestEventLog():
-                            # We should just be able to send everything, since other buffers will ignore old messages
-                            for i in range(command.since_idx, len(self._event_log)):
-                                await self._send_event(
-                                    IndexedEvent(idx=i, event=self._event_log[i])
-                                )
+
+                    # Check if a plugin handles this command
+                    plugin = registry.get_plugin_for_command(command)
+                    if plugin is not None:
+                        events = plugin.process_command(
+                            command,
+                            self.state.topology,
+                            self.state.instances,
+                        )
+                        generated_events.extend(events)
+                    else:
+                        # Core command handling
+                        match command:
+                            case TestCommand():
+                                pass
+                            case ChatCompletion():
+                                instance_task_counts: dict[InstanceId, int] = {}
+                                for instance in self.state.instances.values():
+                                    if (
+                                        instance.shard_assignments.model_id
+                                        == command.request_params.model
+                                    ):
+                                        task_count = sum(
+                                            1
+                                            for task in self.state.tasks.values()
+                                            if task.instance_id == instance.instance_id
+                                        )
+                                        instance_task_counts[instance.instance_id] = (
+                                            task_count
+                                        )
+
+                                if not instance_task_counts:
+                                    raise ValueError(
+                                        f"No instance found for model {command.request_params.model}"
+                                    )
+
+                                available_instance_ids = sorted(
+                                    instance_task_counts.keys(),
+                                    key=lambda instance_id: instance_task_counts[
+                                        instance_id
+                                    ],
+                                )
+
+                                task_id = TaskId()
+                                generated_events.append(
+                                    TaskCreated(
+                                        task_id=task_id,
+                                        task=ChatCompletionTask(
+                                            task_id=task_id,
+                                            command_id=command.command_id,
+                                            instance_id=available_instance_ids[0],
+                                            task_status=TaskStatus.Pending,
+                                            task_params=command.request_params,
+                                        ),
+                                    )
+                                )
+
+                                self.command_task_mapping[command.command_id] = task_id
+                            case DeleteInstance():
+                                placement = delete_instance(
+                                    command, self.state.instances
+                                )
+                                transition_events = get_transition_events(
+                                    self.state.instances, placement
+                                )
+                                generated_events.extend(transition_events)
+                            case PlaceInstance():
+                                placement = place_instance(
+                                    command,
+                                    self.state.topology,
+                                    self.state.instances,
+                                    self.state.node_memory,
+                                    self.state.node_network,
+                                )
+                                transition_events = get_transition_events(
+                                    self.state.instances, placement
+                                )
+                                generated_events.extend(transition_events)
+                            case CreateInstance():
+                                placement = add_instance_to_placements(
+                                    command,
+                                    self.state.topology,
+                                    self.state.instances,
+                                )
+                                transition_events = get_transition_events(
+                                    self.state.instances, placement
+                                )
+                                generated_events.extend(transition_events)
+                            case TaskFinished():
+                                generated_events.append(
+                                    TaskDeleted(
+                                        task_id=self.command_task_mapping[
+                                            command.finished_command_id
+                                        ]
+                                    )
+                                )
+                                if (
+                                    command.finished_command_id
+                                    in self.command_task_mapping
+                                ):
+                                    del self.command_task_mapping[
+                                        command.finished_command_id
+                                    ]
+                            case RequestEventLog():
+                                # We should just be able to send everything, since other buffers will ignore old messages
+                                for i in range(command.since_idx, len(self._event_log)):
+                                    await self._send_event(
+                                        IndexedEvent(idx=i, event=self._event_log[i])
+                                    )
+                            case _:
+                                # Plugin-managed commands are handled above
+                                pass
                     for event in generated_events:
                         await self.event_sender.send(event)
                 except ValueError as e:
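Reviewer note: for context, the contract this dispatch implies for a plugin, sketched only from the call sites above (the real PluginRegistry and plugin base types live in exo.plugins and are not shown in this diff):

# Hypothetical plugin shape, inferred from the dispatch above:
# get_plugin_for_command(command) returns a plugin or None, and
# process_command(command, topology, instances) returns events to emit.
class ExamplePlugin:
    def process_command(self, command, topology, instances):
        # Inspect the command and return a list of Events; core state is
        # only mutated via the events the master applies afterwards.
        return []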
@@ -14,7 +14,6 @@ from exo.master.placement_utils import (
     get_shard_assignments,
     get_smallest_cycles,
 )
-from exo.shared.models.model_cards import ModelId
 from exo.shared.topology import Topology
 from exo.shared.types.commands import (
     CreateInstance,
@@ -24,6 +23,7 @@ from exo.shared.types.commands import (
 from exo.shared.types.common import NodeId
 from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
 from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelId
 from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
 from exo.shared.types.worker.instances import (
     Instance,
@@ -60,27 +60,27 @@ def place_instance(
     cycles = topology.get_cycles()
     candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
     cycles_with_sufficient_memory = filter_cycles_by_memory(
-        candidate_cycles, node_memory, command.model_card.storage_size
+        candidate_cycles, node_memory, command.model_meta.storage_size
     )
     if len(cycles_with_sufficient_memory) == 0:
         raise ValueError("No cycles found with sufficient memory")

     if command.sharding == Sharding.Tensor:
-        if not command.model_card.supports_tensor:
+        if not command.model_meta.supports_tensor:
             raise ValueError(
-                f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
+                f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
             )
         # TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
         cycles_with_sufficient_memory = [
             cycle
             for cycle in cycles_with_sufficient_memory
-            if command.model_card.hidden_size % len(cycle) == 0
+            if command.model_meta.hidden_size % len(cycle) == 0
         ]
         if not cycles_with_sufficient_memory:
             raise ValueError(
-                f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
+                f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
             )
-    if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
+    if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
         "mlx-community/DeepSeek-V3.1-8bit"
     ):
         raise ValueError(
@@ -111,7 +111,7 @@ def place_instance(
     )

     shard_assignments = get_shard_assignments(
-        command.model_card, selected_cycle, command.sharding, node_memory
+        command.model_meta, selected_cycle, command.sharding, node_memory
     )

     cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
@@ -159,6 +159,11 @@ def place_instance(
                 hosts_by_node=hosts_by_node,
                 ephemeral_port=ephemeral_port,
             )
+        case _:
+            # Plugin-managed instance types have their own placement functions
+            raise ValueError(
+                f"Instance type {command.instance_meta} must use plugin placement"
+            )

     return target_instances
@@ -2,10 +2,10 @@ from collections.abc import Generator, Mapping
 
 from loguru import logger
 
-from exo.shared.models.model_cards import ModelCard
 from exo.shared.topology import Topology
 from exo.shared.types.common import Host, NodeId
 from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelMetadata
 from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
 from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments
@@ -75,7 +75,7 @@ def allocate_layers_proportionally(
 
 
 def get_shard_assignments_for_pipeline_parallel(
-    model_card: ModelCard,
+    model_meta: ModelMetadata,
     cycle: Cycle,
     node_memory: Mapping[NodeId, MemoryUsage],
 ):
@@ -86,10 +86,11 @@ def get_shard_assignments_for_pipeline_parallel(
         (node_memory[node_id].ram_available for node_id in cycle.node_ids),
         start=Memory(),
     )
 
+    if cycle_memory.in_bytes == 0:
+        raise ValueError("Cannot create shard assignments: total available memory is 0")
+
-    total_layers = model_card.n_layers
+    total_layers = model_meta.n_layers
     world_size = len(cycle)
     runner_to_shard: dict[RunnerId, ShardMetadata] = {}
     node_to_runner: dict[NodeId, RunnerId] = {}
@@ -103,7 +104,7 @@ def get_shard_assignments_for_pipeline_parallel(
     )
 
     # Validate each node has sufficient memory for its assigned layers
-    memory_per_layer = model_card.storage_size.in_bytes / total_layers
+    memory_per_layer = model_meta.storage_size.in_bytes / total_layers
     for i, (node_id, node_layers) in enumerate(
         zip(cycle.node_ids, layer_allocations, strict=True)
     ):
@@ -123,7 +124,7 @@ def get_shard_assignments_for_pipeline_parallel(
         runner_id = RunnerId()
 
         shard = PipelineShardMetadata(
-            model_card=model_card,
+            model_meta=model_meta,
             device_rank=i,
             world_size=world_size,
             start_layer=layers_assigned,
@@ -136,7 +137,7 @@ def get_shard_assignments_for_pipeline_parallel(
         layers_assigned += node_layers
 
     shard_assignments = ShardAssignments(
-        model_id=model_card.model_id,
+        model_id=model_meta.model_id,
         runner_to_shard=runner_to_shard,
         node_to_runner=node_to_runner,
     )
@@ -145,17 +146,17 @@ def get_shard_assignments_for_pipeline_parallel(
 
 
 def get_shard_assignments_for_tensor_parallel(
-    model_card: ModelCard,
+    model_meta: ModelMetadata,
     cycle: Cycle,
 ):
-    total_layers = model_card.n_layers
+    total_layers = model_meta.n_layers
     world_size = len(cycle)
     runner_to_shard: dict[RunnerId, ShardMetadata] = {}
     node_to_runner: dict[NodeId, RunnerId] = {}
 
     for i, node_id in enumerate(cycle):
         shard = TensorShardMetadata(
-            model_card=model_card,
+            model_meta=model_meta,
             device_rank=i,
             world_size=world_size,
             start_layer=0,
@@ -169,7 +170,7 @@ def get_shard_assignments_for_tensor_parallel(
         node_to_runner[node_id] = runner_id
 
     shard_assignments = ShardAssignments(
-        model_id=model_card.model_id,
+        model_id=model_meta.model_id,
         runner_to_shard=runner_to_shard,
         node_to_runner=node_to_runner,
     )
@@ -178,7 +179,7 @@ def get_shard_assignments_for_tensor_parallel(
 
 
 def get_shard_assignments(
-    model_card: ModelCard,
+    model_meta: ModelMetadata,
     cycle: Cycle,
     sharding: Sharding,
     node_memory: Mapping[NodeId, MemoryUsage],
@@ -186,13 +187,13 @@ def get_shard_assignments(
     match sharding:
         case Sharding.Pipeline:
             return get_shard_assignments_for_pipeline_parallel(
-                model_card=model_card,
+                model_meta=model_meta,
                 cycle=cycle,
                 node_memory=node_memory,
             )
         case Sharding.Tensor:
            return get_shard_assignments_for_tensor_parallel(
-                model_card=model_card,
+                model_meta=model_meta,
                 cycle=cycle,
            )
 
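The hunk header above names `allocate_layers_proportionally`, whose body is not shown in this diff. A plausible sketch, assuming it splits `n_layers` across nodes in proportion to each node's share of available memory (the real function may round differently):

# Sketch only; signature and rounding strategy are assumptions.
def allocate_layers_proportionally(
    total_layers: int, memory_fractions: list[float]
) -> list[int]:
    # Give each node floor(share) layers, then hand out the remainder to the
    # nodes with the largest fractional parts so the totals add up exactly.
    raw = [f * total_layers for f in memory_fractions]
    layers = [int(r) for r in raw]
    remainder = total_layers - sum(layers)
    by_fraction = sorted(
        range(len(raw)), key=lambda i: raw[i] - layers[i], reverse=True
    )
    for i in by_fraction[:remainder]:
        layers[i] += 1
    return layers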
@@ -7,7 +7,6 @@ from loguru import logger
 
 from exo.master.main import Master
 from exo.routing.router import get_node_id_keypair
-from exo.shared.models.model_cards import ModelCard, ModelId
 from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
 from exo.shared.types.commands import (
     ChatCompletion,
@@ -24,6 +23,7 @@ from exo.shared.types.events import (
     TaskCreated,
 )
 from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.profiling import (
     MemoryUsage,
 )
@@ -109,8 +109,9 @@ async def test_master():
             command=(
                 PlaceInstance(
                     command_id=CommandId(),
-                    model_card=ModelCard(
+                    model_meta=ModelMetadata(
                         model_id=ModelId("llama-3.2-1b"),
+                        pretty_name="Llama 3.2 1B",
                         n_layers=16,
                         storage_size=Memory.from_bytes(678948),
                         hidden_size=7168,
@@ -166,8 +167,9 @@ async def test_master():
                     start_layer=0,
                     end_layer=16,
                     n_layers=16,
-                    model_card=ModelCard(
+                    model_meta=ModelMetadata(
                         model_id=ModelId("llama-3.2-1b"),
+                        pretty_name="Llama 3.2 1B",
                         n_layers=16,
                         storage_size=Memory.from_bytes(678948),
                         hidden_size=7168,
 
@@ -10,12 +10,12 @@ from exo.master.tests.conftest import (
     create_rdma_connection,
     create_socket_connection,
 )
-from exo.shared.models.model_cards import ModelCard, ModelId
 from exo.shared.topology import Topology
 from exo.shared.types.commands import PlaceInstance
 from exo.shared.types.common import CommandId, NodeId
 from exo.shared.types.events import InstanceCreated, InstanceDeleted
 from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
 from exo.shared.types.topology import Connection, SocketConnection
@@ -43,20 +43,21 @@ def instance() -> Instance:
 
 
 @pytest.fixture
-def model_card() -> ModelCard:
-    return ModelCard(
+def model_meta() -> ModelMetadata:
+    return ModelMetadata(
         model_id=ModelId("test-model"),
         storage_size=Memory.from_kb(1000),
+        pretty_name="Test Model",
         n_layers=10,
         hidden_size=30,
         supports_tensor=True,
     )
 
 
-def place_instance_command(model_card: ModelCard) -> PlaceInstance:
+def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
     return PlaceInstance(
         command_id=CommandId(),
-        model_card=model_card,
+        model_meta=model_meta,
         sharding=Sharding.Pipeline,
         instance_meta=InstanceMeta.MlxRing,
         min_nodes=1,
@@ -75,16 +76,16 @@ def test_get_instance_placements_create_instance(
     available_memory: tuple[int, int, int],
     total_layers: int,
     expected_layers: tuple[int, int, int],
-    model_card: ModelCard,
+    model_meta: ModelMetadata,
 ):
     # arrange
-    model_card.n_layers = total_layers
-    model_card.storage_size.in_bytes = sum(
+    model_meta.n_layers = total_layers
+    model_meta.storage_size.in_bytes = sum(
         available_memory
     )  # make it exactly fit across all nodes
     topology = Topology()
 
-    cic = place_instance_command(model_card)
+    cic = place_instance_command(model_meta)
     node_id_a = NodeId()
     node_id_b = NodeId()
     node_id_c = NodeId()
@@ -136,7 +137,7 @@ def test_get_instance_placements_create_instance(
     assert len(placements) == 1
     instance_id = list(placements.keys())[0]
     instance = placements[instance_id]
-    assert instance.shard_assignments.model_id == model_card.model_id
+    assert instance.shard_assignments.model_id == model_meta.model_id
 
     runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
     runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -163,9 +164,10 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
     node_memory = {node_id: create_node_memory(1000 * 1024)}
     node_network = {node_id: create_node_network()}
     cic = place_instance_command(
-        ModelCard(
+        ModelMetadata(
             model_id=ModelId("test-model"),
             storage_size=Memory.from_kb(1000),
+            pretty_name="Test Model",
             n_layers=10,
             hidden_size=1000,
             supports_tensor=True,
@@ -189,9 +191,10 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
     node_memory = {node_id: create_node_memory(1001 * 1024)}
     node_network = {node_id: create_node_network()}
     cic = place_instance_command(
-        ModelCard(
+        ModelMetadata(
             model_id=ModelId("test-model"),
             storage_size=Memory.from_kb(1000),
+            pretty_name="Test Model",
             n_layers=10,
             hidden_size=1000,
             supports_tensor=True,
@@ -215,9 +218,10 @@ def test_get_instance_placements_one_node_not_fit() -> None:
     node_memory = {node_id: create_node_memory(1000 * 1024)}
     node_network = {node_id: create_node_network()}
     cic = place_instance_command(
-        model_card=ModelCard(
+        model_meta=ModelMetadata(
            model_id=ModelId("test-model"),
            storage_size=Memory.from_kb(1001),
+            pretty_name="Test Model",
            n_layers=10,
            hidden_size=1000,
            supports_tensor=True,
@@ -271,12 +275,12 @@ def test_get_transition_events_delete_instance(instance: Instance):
 
 
 def test_placement_selects_leaf_nodes(
-    model_card: ModelCard,
+    model_meta: ModelMetadata,
 ):
     # arrange
     topology = Topology()
 
-    model_card.storage_size = Memory.from_bytes(1000)
+    model_meta.storage_size = Memory.from_bytes(1000)
 
     node_id_a = NodeId()
     node_id_b = NodeId()
@@ -321,7 +325,7 @@ def test_placement_selects_leaf_nodes(
         Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
     )
 
-    cic = place_instance_command(model_card=model_card)
+    cic = place_instance_command(model_meta=model_meta)
 
     # act
     placements = place_instance(cic, topology, {}, node_memory, node_network)
@@ -340,12 +344,12 @@ def test_placement_selects_leaf_nodes(
 
 
 def test_tensor_rdma_backend_connectivity_matrix(
-    model_card: ModelCard,
+    model_meta: ModelMetadata,
 ):
     # arrange
     topology = Topology()
-    model_card.n_layers = 12
-    model_card.storage_size.in_bytes = 1500
+    model_meta.n_layers = 12
+    model_meta.storage_size.in_bytes = 1500
 
     node_a = NodeId()
     node_b = NodeId()
@@ -407,7 +411,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
         sharding=Sharding.Tensor,
         instance_meta=InstanceMeta.MlxJaccl,
         command_id=CommandId(),
-        model_card=model_card,
+        model_meta=model_meta,
         min_nodes=1,
     )
 
@@ -12,10 +12,10 @@ from exo.master.tests.conftest import (
     create_node_memory,
     create_socket_connection,
 )
-from exo.shared.models.model_cards import ModelCard, ModelId
 from exo.shared.topology import Topology
 from exo.shared.types.common import Host, NodeId
 from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.profiling import (
     NetworkInterfaceInfo,
     NodeNetworkInfo,
@@ -232,8 +232,9 @@ def test_get_shard_assignments(
         node_c_id: node_c_mem,
     }
 
-    model_card = ModelCard(
+    model_meta = ModelMetadata(
         model_id=ModelId("test-model"),
+        pretty_name="Test Model",
         n_layers=total_layers,
         storage_size=Memory.from_kb(1000),
         hidden_size=1000,
@@ -247,7 +248,7 @@ def test_get_shard_assignments(
 
     # act
     shard_assignments = get_shard_assignments(
-        model_card, selected_cycle, Sharding.Pipeline, node_memory=node_memory
+        model_meta, selected_cycle, Sharding.Pipeline, node_memory=node_memory
     )
 
     # assert
@@ -511,8 +512,9 @@ def test_get_shard_assignments_insufficient_memory_raises():
         node_c_id: node_c_mem,
     }
 
-    model_card = ModelCard(
+    model_meta = ModelMetadata(
         model_id=ModelId("test-model"),
+        pretty_name="Test Model",
         n_layers=20,
         storage_size=Memory.from_kb(1000),
         hidden_size=1000,
@@ -523,5 +525,5 @@ def test_get_shard_assignments_insufficient_memory_raises():
 
     with pytest.raises(ValueError, match="insufficient memory"):
         get_shard_assignments(
-            model_card, selected_cycle, Sharding.Pipeline, node_memory
+            model_meta, selected_cycle, Sharding.Pipeline, node_memory
         )
 
src/exo/plugins/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""Exo Plugin System.

This module provides the plugin architecture for extending exo with custom
workload types (simulations, ML frameworks, etc.) without modifying core code.
"""

from exo.plugins.base import ExoPlugin, PluginCommand, PluginInstance
from exo.plugins.registry import PluginRegistry, discover_plugins

__all__ = [
    "ExoPlugin",
    "PluginCommand",
    "PluginInstance",
    "PluginRegistry",
    "discover_plugins",
]
src/exo/plugins/base.py (new file, 171 lines)
@@ -0,0 +1,171 @@
"""Base classes and protocols for Exo plugins."""

from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from pydantic import Field

from exo.shared.types.common import CommandId
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.pydantic_ext import TaggedModel

if TYPE_CHECKING:
    from exo.shared.topology import Topology
    from exo.shared.types.worker.instances import BoundInstance, Instance
    from exo.utils.channels import MpReceiver, MpSender
    from exo.worker.runner.runner_supervisor import RunnerSupervisor


class PluginCommand(TaggedModel):
    """Base class for plugin-defined commands.

    All plugin commands must inherit from this class. Commands are serialized
    with their class name as a tag for routing.
    """

    command_id: CommandId = Field(default_factory=CommandId)


class PluginInstance(TaggedModel):
    """Base class for plugin-defined instances.

    All plugin instances must inherit from this class. Plugins are expected
    to define their own instance type with workload-specific fields.
    """

    instance_id: InstanceId


class ExoPlugin(ABC):
    """Protocol that all exo plugins must implement.

    A plugin provides:
    - Custom command types for API -> Master communication
    - Custom instance types representing running workloads
    - Placement logic for distributing work across nodes
    - Planning logic for local task scheduling
    - Runner implementation for executing work
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier for this plugin (e.g., 'flash', 'pytorch', 'mpi')."""
        ...

    @property
    @abstractmethod
    def version(self) -> str:
        """Semantic version string (e.g., '1.0.0')."""
        ...

    # ========== Type Registration ==========

    @abstractmethod
    def get_command_types(self) -> Sequence[type]:
        """Return command types this plugin handles.

        These commands are routed to this plugin's process_command method.
        Can return core BaseCommand types or PluginCommand types.
        """
        ...

    @abstractmethod
    def get_instance_type(self) -> type:
        """Return the instance type this plugin creates.

        This instance type is used for routing in planning and runner bootstrap.
        Can return core Instance types or PluginInstance types.
        """
        ...

    # ========== API Routes ==========

    @abstractmethod
    def get_api_routes(
        self,
    ) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        """Return FastAPI routes to register.

        Each tuple: (method, path, handler)
        Example: [('post', '/flash/launch', self.launch_handler)]

        Handlers receive a PluginContext with access to:
        - state: Current State object
        - send_command: Async function to send commands
        - node_id: Current node's ID
        """
        ...

    # ========== Master Command Handling ==========

    @abstractmethod
    def handles_command(self, command: Any) -> bool:  # pyright: ignore[reportAny]
        """Return True if this plugin handles the given command type."""
        ...

    @abstractmethod
    def process_command(
        self,
        command: Any,  # pyright: ignore[reportAny]
        topology: "Topology",
        current_instances: Mapping[InstanceId, "Instance"],
    ) -> Sequence[Event]:
        """Process a command and return events to emit.

        Typically creates placement and returns InstanceCreated/InstanceDeleted events.

        Args:
            command: The command to process
            topology: Current cluster topology
            current_instances: Currently running instances

        Returns:
            Sequence of events to emit (e.g., InstanceCreated, InstanceDeleted)
        """
        ...

    # ========== Worker Planning ==========

    @abstractmethod
    def handles_instance(self, instance: object) -> bool:
        """Return True if this plugin manages the given instance type."""
        ...

    @abstractmethod
    def plan_task(
        self,
        runners: Mapping[RunnerId, "RunnerSupervisor"],
        instances: Mapping[InstanceId, "Instance"],
    ) -> Task | None:
        """Plan the next task for plugin instances.

        Called during each planning cycle.
        Return None if no task is needed.
        """
        ...

    @abstractmethod
    def should_skip_download(self, instance: object) -> bool:
        """Return True if this instance type doesn't need model downloads."""
        ...

    # ========== Runner Bootstrap ==========

    @abstractmethod
    def create_runner(
        self,
        bound_instance: "BoundInstance",
        event_sender: "MpSender[Event]",
        task_receiver: "MpReceiver[Task]",
    ) -> None:
        """Entry point for the runner process.

        Called in a subprocess to execute the actual workload.
        This function should block until the workload completes.
        """
        ...
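To make the surface area concrete, here is a minimal, hypothetical plugin that satisfies the ABC above. The names (EchoPlugin, EchoCommand, EchoInstance) are illustrative and not part of this diff; a real plugin would emit InstanceCreated events and run a workload.

# Hypothetical minimal plugin; everything named "Echo" is illustrative only.
from collections.abc import Callable, Mapping, Sequence
from typing import Any

from exo.plugins.base import ExoPlugin, PluginCommand, PluginInstance
from exo.shared.topology import Topology
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor


class EchoCommand(PluginCommand):
    message: str


class EchoInstance(PluginInstance):
    message: str


class EchoPlugin(ExoPlugin):
    @property
    def name(self) -> str:
        return "echo"

    @property
    def version(self) -> str:
        return "0.1.0"

    def get_command_types(self) -> Sequence[type]:
        return [EchoCommand]

    def get_instance_type(self) -> type:
        return EchoInstance

    def get_api_routes(self) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        return []  # no extra HTTP endpoints in this sketch

    def handles_command(self, command: Any) -> bool:
        return isinstance(command, EchoCommand)

    def process_command(
        self,
        command: Any,
        topology: Topology,
        current_instances: Mapping[InstanceId, Instance],
    ) -> Sequence[Event]:
        return []  # a real plugin would place an instance and emit InstanceCreated

    def handles_instance(self, instance: object) -> bool:
        return isinstance(instance, EchoInstance)

    def plan_task(
        self,
        runners: Mapping[RunnerId, RunnerSupervisor],
        instances: Mapping[InstanceId, Instance],
    ) -> Task | None:
        return None

    def should_skip_download(self, instance: object) -> bool:
        return True

    def create_runner(
        self,
        bound_instance: BoundInstance,
        event_sender: MpSender[Event],
        task_receiver: MpReceiver[Task],
    ) -> None:
        pass  # runner process entry point; would block until the workload ends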
src/exo/plugins/context.py (new file, 21 lines)
@@ -0,0 +1,21 @@
"""Context objects passed to plugin handlers."""

from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from exo.shared.types.commands import Command
from exo.shared.types.common import NodeId
from exo.shared.types.state import State


@dataclass
class PluginContext:
    """Context provided to plugin API handlers.

    This gives plugins access to the current state and the ability to send
    commands without direct access to internal API components.
    """

    state: State
    send_command: Callable[[Command], Awaitable[None]]
    node_id: NodeId
src/exo/plugins/implementations/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""Plugin implementations directory.

Each subdirectory should contain a plugin with a register() function
that returns an ExoPlugin instance.
"""
src/exo/plugins/implementations/flash/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""FLASH Plugin - MPI-based simulation support for Exo."""

from exo.plugins.implementations.flash.plugin import FLASHPlugin


def register() -> FLASHPlugin:
    """Entry point for plugin discovery."""
    return FLASHPlugin()
src/exo/plugins/implementations/flash/api_handlers.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""FLASH plugin API handlers."""

from typing import Any

from fastapi import HTTPException

from exo.plugins.context import PluginContext

# Use core types for serialization compatibility
from exo.shared.types.commands import LaunchFLASH, StopFLASH
from exo.shared.types.worker.instances import FLASHInstance


async def handle_launch_flash(
    ctx: PluginContext,
    simulation_name: str,
    flash_executable_path: str,
    working_directory: str,
    parameter_file_path: str = "",
    ranks_per_node: int = 1,
    min_nodes: int = 1,
    hosts: str = "",
) -> dict[str, str]:
    """Launch a FLASH MPI simulation across the cluster.

    Args:
        ctx: Plugin context with state and send_command
        simulation_name: Name of the simulation
        flash_executable_path: Path to the FLASH executable
        working_directory: Working directory for the simulation
        parameter_file_path: Path to parameter file (optional)
        ranks_per_node: Number of MPI ranks per node
        min_nodes: Minimum number of nodes required
        hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
            If not provided, IPs are discovered from topology edges.
    """
    command = LaunchFLASH(
        simulation_name=simulation_name,
        flash_executable_path=flash_executable_path,
        parameter_file_path=parameter_file_path,
        working_directory=working_directory,
        ranks_per_node=ranks_per_node,
        min_nodes=min_nodes,
        hosts=hosts,
    )
    await ctx.send_command(command)

    return {
        "message": "FLASH launch command received",
        "command_id": str(command.command_id),
        "simulation_name": simulation_name,
    }


async def handle_stop_flash(
    ctx: PluginContext,
    instance_id: str,
) -> dict[str, str]:
    """Stop a running FLASH simulation."""
    from exo.shared.types.worker.instances import InstanceId

    inst_id = InstanceId(instance_id)

    if inst_id not in ctx.state.instances:
        raise HTTPException(status_code=404, detail="Instance not found")

    instance = ctx.state.instances[inst_id]
    if not isinstance(instance, FLASHInstance):
        raise HTTPException(
            status_code=400, detail="Instance is not a FLASH simulation"
        )

    command = StopFLASH(instance_id=inst_id)
    await ctx.send_command(command)

    return {
        "message": "Stop command received",
        "command_id": str(command.command_id),
        "instance_id": str(instance_id),
    }


async def handle_list_flash_instances(ctx: PluginContext) -> list[dict[str, Any]]:
    """List all FLASH simulation instances."""
    flash_instances: list[dict[str, Any]] = []
    for instance_id, instance in ctx.state.instances.items():
        if isinstance(instance, FLASHInstance):
            # Get runner statuses for this instance
            runner_statuses: dict[str, str | None] = {}
            for (
                node_id,
                runner_id,
            ) in instance.shard_assignments.node_to_runner.items():
                runner_status = ctx.state.runners.get(runner_id)
                runner_statuses[str(node_id)] = (
                    str(runner_status) if runner_status else None
                )

            flash_instances.append(
                {
                    "instance_id": str(instance_id),
                    "simulation_name": instance.simulation_name,
                    "total_ranks": instance.total_ranks,
                    "working_directory": instance.working_directory,
                    "runner_statuses": runner_statuses,
                }
            )
    return flash_instances
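Assuming these handlers are mounted on the node's API server (port 52415, per the exo-rsh module below) with their scalar arguments bound as query parameters (FastAPI's default), a launch call from a client might look like this. The simulation name and paths are placeholders:

# Hypothetical client call; the query-parameter binding is an assumption.
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({
    "simulation_name": "sedov",
    "flash_executable_path": "/opt/flash/flash4",
    "working_directory": "/tmp/sedov",
    "ranks_per_node": 4,
    "min_nodes": 2,
})
req = urllib.request.Request(
    f"http://localhost:52415/flash/launch?{params}", method="POST"
)
with urllib.request.urlopen(req) as resp:
    # Expected shape: {"message": "FLASH launch command received", ...}
    print(resp.read().decode())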
src/exo/plugins/implementations/flash/placement.py (new file, 152 lines)
@@ -0,0 +1,152 @@
"""FLASH plugin placement logic."""

from collections.abc import Mapping
from copy import deepcopy

from loguru import logger

from exo.shared.topology import Topology
from exo.shared.types.commands import LaunchFLASH
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
    RunnerId,
    ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata


def place_flash_instance(
    command: LaunchFLASH,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
    """Place a FLASH simulation instance across available nodes.

    Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
    FLASH instances use MPI for communication. We just need to provide the
    node IPs so the runner can generate an MPI hostfile.
    """
    instance_id = InstanceId()
    target_instances: dict[InstanceId, Instance] = dict(deepcopy(current_instances))

    all_nodes = list(topology.list_nodes())

    if len(all_nodes) < command.min_nodes:
        raise ValueError(
            f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
        )

    # Select nodes (take the first min_nodes)
    selected_nodes = all_nodes[: command.min_nodes]

    logger.info(
        f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
    )

    # Build shard assignments (one runner per node for FLASH)
    runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

    # Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
    flash_model_meta = ModelMetadata(
        model_id=ModelId(command.simulation_name),
        pretty_name=f"FLASH: {command.simulation_name}",
        storage_size=Memory(in_bytes=0),
        n_layers=1,
        hidden_size=1,
        supports_tensor=False,
    )

    for i, node_id in enumerate(selected_nodes):
        runner_id = RunnerId()
        node_to_runner[node_id] = runner_id
        runner_to_shard[runner_id] = PipelineShardMetadata(
            device_rank=i,
            world_size=len(selected_nodes),
            model_meta=flash_model_meta,
            start_layer=0,
            end_layer=1,
            n_layers=1,
        )

    shard_assignments = ShardAssignments(
        model_id=ModelId(command.simulation_name),
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )

    # Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
    hosts_by_node: dict[NodeId, list[Host]] = {}

    # If explicit hosts are provided, use them directly
    if command.hosts:
        explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
        logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
        for i, node_id in enumerate(selected_nodes):
            if i < len(explicit_hosts):
                hosts_by_node[node_id] = [Host(ip=explicit_hosts[i], port=0)]
                logger.info(
                    f"FLASH placement: node {node_id} (rank {i}) -> IP {explicit_hosts[i]}"
                )
            else:
                logger.warning(
                    f"Not enough hosts provided for node {i}, using localhost"
                )
                hosts_by_node[node_id] = [Host(ip="127.0.0.1", port=0)]
        logger.info(
            f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
        )
    else:
        # Try to get IPs from topology edges
        for node_id in selected_nodes:
            node_hosts: list[Host] = []

            # Get IP from outgoing edges (connections to other nodes via mDNS discovery)
            for conn in topology.out_edges(node_id):
                if isinstance(conn.edge, SocketConnection):
                    # Extract IP from multiaddr
                    ip = conn.edge.sink_multiaddr.ip_address
                    # Skip link-local and localhost addresses
                    if not ip.startswith("169.254.") and not ip.startswith("127."):
                        node_hosts.append(Host(ip=ip, port=0))
                        break

            # Last resort: use localhost (will only work for single-node)
            if not node_hosts:
                logger.warning(
                    f"Could not determine IP for node {node_id}, using localhost"
                )
                node_hosts.append(Host(ip="127.0.0.1", port=0))

            hosts_by_node[node_id] = node_hosts

    total_ranks = len(selected_nodes) * command.ranks_per_node

    # Determine coordinator IP - first node's first host IP
    first_node_id: NodeId = next(iter(hosts_by_node.keys()))
    coordinator_ip: str = (
        hosts_by_node[first_node_id][0].ip
        if hosts_by_node[first_node_id]
        else "127.0.0.1"
    )

    target_instances[instance_id] = FLASHInstance(
        instance_id=instance_id,
        shard_assignments=shard_assignments,
        hosts_by_node=hosts_by_node,
        flash_executable_path=command.flash_executable_path,
        parameter_file_path=command.parameter_file_path,
        working_directory=command.working_directory,
        ranks_per_node=command.ranks_per_node,
        total_ranks=total_ranks,
        simulation_name=command.simulation_name,
        coordinator_ip=coordinator_ip,
    )

    logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")

    return target_instances
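A worked example of the explicit-hosts branch, using the "s14,james21-1" hostnames from the API docstring (the node ids and paths here are illustrative):

# Illustrative only: a LaunchFLASH command with explicit hosts and two nodes.
command = LaunchFLASH(
    simulation_name="sedov",
    flash_executable_path="/opt/flash/flash4",
    parameter_file_path="",
    working_directory="/tmp/sedov",
    ranks_per_node=4,
    min_nodes=2,
    hosts="s14,james21-1",
)
# place_flash_instance then assigns, per the loop above:
#   first selected node  -> [Host(ip="s14", port=0)]        (rank 0, coordinator)
#   second selected node -> [Host(ip="james21-1", port=0)]  (rank 1)
# and total_ranks = 2 nodes * 4 ranks_per_node = 8, coordinator_ip = "s14".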
src/exo/plugins/implementations/flash/planning.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""FLASH plugin planning logic."""

from collections.abc import Mapping

from exo.shared.types.tasks import LoadModel, Task
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerIdle
from exo.worker.runner.runner_supervisor import RunnerSupervisor


def plan_flash(
    runners: Mapping[RunnerId, RunnerSupervisor],
    instances: Mapping[InstanceId, Instance],
) -> Task | None:
    """Plan tasks specifically for FLASH instances.

    FLASH instances have a simpler lifecycle:
    - CreateRunner (handled by core _create_runner)
    - LoadModel (starts the simulation immediately)
    - Shutdown (handled by core _kill_runner)

    This function handles the LoadModel step for FLASH instances,
    skipping the MLX-specific download/init/warmup steps.
    """
    for runner in runners.values():
        instance = runner.bound_instance.instance

        # Only handle FLASH instances
        if not isinstance(instance, FLASHInstance):
            continue

        # If runner is idle, emit LoadModel to start the simulation
        if isinstance(runner.status, RunnerIdle):
            return LoadModel(instance_id=instance.instance_id)

    return None
src/exo/plugins/implementations/flash/plugin.py (new file, 98 lines)
@@ -0,0 +1,98 @@
"""FLASH Plugin - Main plugin class."""

from collections.abc import Callable, Mapping, Sequence
from typing import Any

from exo.plugins.base import ExoPlugin
from exo.plugins.implementations.flash.api_handlers import (
    handle_launch_flash,
    handle_list_flash_instances,
    handle_stop_flash,
)
from exo.plugins.implementations.flash.placement import place_flash_instance
from exo.plugins.implementations.flash.planning import plan_flash
from exo.plugins.implementations.flash.runner import main as flash_runner_main
from exo.shared.topology import Topology
from exo.shared.types.commands import DeleteInstance, LaunchFLASH, StopFLASH
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import (
    BoundInstance,
    FLASHInstance,
    Instance,
    InstanceId,
)
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor


class FLASHPlugin(ExoPlugin):
    """Plugin for FLASH MPI simulations."""

    @property
    def name(self) -> str:
        return "flash"

    @property
    def version(self) -> str:
        return "1.0.0"

    def get_command_types(self) -> Sequence[type]:
        return [LaunchFLASH, StopFLASH]

    def get_instance_type(self) -> type:
        return FLASHInstance

    def get_api_routes(
        self,
    ) -> Sequence[tuple[str, str, Callable[..., Any]]]:
        return [
            ("post", "/flash/launch", handle_launch_flash),
            ("delete", "/flash/{instance_id}", handle_stop_flash),
            ("get", "/flash/instances", handle_list_flash_instances),
        ]

    def handles_command(self, command: Any) -> bool:  # pyright: ignore[reportAny]
        return isinstance(command, (LaunchFLASH, StopFLASH))

    def process_command(
        self,
        command: Any,  # pyright: ignore[reportAny]
        topology: Topology,
        current_instances: Mapping[InstanceId, Instance],
    ) -> Sequence[Event]:
        from exo.master.placement import delete_instance, get_transition_events

        if isinstance(command, LaunchFLASH):
            placement = place_flash_instance(command, topology, current_instances)
            return list(get_transition_events(current_instances, placement))
        elif isinstance(command, StopFLASH):
            placement = delete_instance(
                DeleteInstance(instance_id=command.instance_id),
                current_instances,
            )
            return list(get_transition_events(current_instances, placement))
        return []

    def handles_instance(self, instance: object) -> bool:
        return isinstance(instance, FLASHInstance)

    def plan_task(
        self,
        runners: Mapping[RunnerId, RunnerSupervisor],
        instances: Mapping[InstanceId, Instance],
    ) -> Task | None:
        return plan_flash(runners, instances)

    def should_skip_download(self, instance: object) -> bool:
        # FLASH instances don't need model downloads
        return True

    def create_runner(
        self,
        bound_instance: BoundInstance,
        event_sender: MpSender[Event],
        task_receiver: MpReceiver[Task],
    ) -> None:
        flash_runner_main(bound_instance, event_sender, task_receiver)
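The master-side dispatch that feeds process_command is not shown in this diff; a sketch of how a LaunchFLASH command could be routed through the registry to this plugin:

# Sketch only; the actual dispatch site in exo.master may differ.
from exo.plugins.registry import PluginRegistry


def dispatch(command, topology, instances):
    plugin = PluginRegistry.get().get_plugin_for_command(command)
    if plugin is None:
        raise ValueError(f"No plugin handles {type(command).__name__}")
    # Returns the events (InstanceCreated/InstanceDeleted) to apply to state.
    return plugin.process_command(command, topology, instances)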
src/exo/plugins/implementations/flash/runner.py (new file, 302 lines)
@@ -0,0 +1,302 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.

Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- exo-rsh connects to each node's Exo API (/execute endpoint) for remote execution
- Workers just report ready and wait
"""

import os
import shutil
import socket
import subprocess
import threading

from loguru import logger

from exo.shared.types.events import (
    Event,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,
)
from exo.shared.types.tasks import (
    LoadModel,
    Shutdown,
    Task,
    TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerIdle,
    RunnerLoading,
    RunnerReady,
    RunnerRunning,
    RunnerShutdown,
    RunnerShuttingDown,
    RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender

# Find mpirun in PATH, fallback to common locations
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"

# exo-rsh is installed as console script by exo package
_exo_rsh_path = shutil.which("exo-rsh")
if not _exo_rsh_path:
    raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
EXO_RSH_PATH: str = _exo_rsh_path


def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
    """Determine this node's rank based on position in hosts_by_node."""
    for i, node_id in enumerate(instance.hosts_by_node.keys()):
        if str(node_id) == str(my_node_id):
            return i
    return -1


def get_coordinator_host(instance: FLASHInstance) -> str:
    """Get the IP of the coordinator node."""
    return instance.coordinator_ip


def resolve_host(host: str) -> str:
    """Resolve host string to a usable hostname for MPI hostfile.

    Accepts either an IP address or hostname. For IPs, attempts to resolve
    to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
    """
    # Check if input is already a hostname (not an IP)
    try:
        socket.inet_aton(host)
        is_ip = True
    except socket.error:
        is_ip = False

    if not is_ip:
        # Already a hostname, verify it resolves and return as-is
        try:
            socket.gethostbyname(host)
            return host
        except socket.gaierror:
            logger.warning(f"Hostname {host} does not resolve, using anyway")
            return host

    # It's an IP address, try to resolve to hostname
    try:
        hostname, _, _ = socket.gethostbyaddr(host)
        hostname = hostname.split(".")[0]
        logger.info(f"Resolved {host} to {hostname}")
        return hostname
    except socket.herror:
        pass

    # Fall back to IP
    logger.warning(f"Could not resolve {host} to hostname, using IP directly")
    return host


def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
    """Generate MPI hostfile from instance topology."""
    hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
    with open(hostfile_path, "w") as f:
        for _node_id, hosts in instance.hosts_by_node.items():
            if hosts:
                host = resolve_host(hosts[0].ip)
                f.write(f"{host} slots={instance.ranks_per_node}\n")
    logger.info(f"Generated hostfile at {hostfile_path}")
    with open(hostfile_path, "r") as f:
        logger.info(f"Hostfile contents:\n{f.read()}")
    return hostfile_path


def main(
    bound_instance: BoundInstance,
    event_sender: MpSender[Event],
    task_receiver: MpReceiver[Task],
) -> None:
    """Main FLASH runner loop.

    Coordinator: generates hostfile and runs mpirun (uses exo-rsh instead of SSH)
    Workers: just report ready and wait for mpirun to spawn processes on them
    """
    assert isinstance(bound_instance.instance, FLASHInstance)
    instance = bound_instance.instance
    runner_id = bound_instance.bound_runner_id
    my_node_id = str(bound_instance.bound_node_id)

    logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")

    my_rank = get_my_rank(instance, my_node_id)
    world_size = len(instance.hosts_by_node)
    is_coordinator = my_rank == 0
    coordinator_ip = get_coordinator_host(instance)

    logger.info(
        f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
    )
    logger.info(f"FLASH coordinator IP: {coordinator_ip}")

    process: subprocess.Popen[bytes] | None = None
    current_status: RunnerStatus = RunnerIdle()
    shutdown_requested = False

    event_sender.send(
        RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
    )

    def monitor_output(proc: subprocess.Popen[bytes]) -> None:
        """Monitor FLASH stdout for progress updates."""
        if proc.stdout is None:
            return
        for line in iter(proc.stdout.readline, b""):
            if shutdown_requested:
                break
            try:
                decoded: str = line.decode("utf-8", errors="replace").strip()
                if decoded:
                    logger.info(f"[FLASH] {decoded}")
            except Exception as e:
                logger.warning(f"Error parsing FLASH output: {e}")

    with task_receiver as tasks:
        for task in tasks:
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
            )
            event_sender.send(TaskAcknowledged(task_id=task.task_id))

            match task:
                case LoadModel() if isinstance(current_status, RunnerIdle):
                    current_status = RunnerLoading()
                    logger.info("Starting FLASH simulation")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )

                    try:
                        if is_coordinator:
                            # Coordinator: generate hostfile and run mpirun
                            hostfile = generate_hostfile(
                                instance, instance.working_directory
                            )

                            iface = instance.network_interface
                            cmd = [
                                MPIRUN_PATH,
                                "-np",
                                str(instance.total_ranks),
                                "--hostfile",
                                hostfile,
                                "--wdir",
                                instance.working_directory,
                                "--oversubscribe",
                                "--mca",
                                "btl",
                                "tcp,self",
                                "--mca",
                                "btl_tcp_if_include",
                                iface,
                                "--mca",
                                "oob_tcp_if_include",
                                iface,
                                "--mca",
                                "plm_rsh_no_tree_spawn",
                                "1",
                            ]

                            # Use exo-rsh for remote execution (no SSH needed)
                            cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])

                            cmd.append(instance.flash_executable_path)

                            logger.info(f"FLASH distributed launch: {' '.join(cmd)}")

                            process = subprocess.Popen(
                                cmd,
                                cwd=instance.working_directory,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                            )

                            monitor_thread = threading.Thread(
                                target=monitor_output, args=(process,), daemon=True
                            )
                            monitor_thread.start()

                            current_status = RunnerRunning()
                            logger.info(
                                f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
                            )

                        else:
                            # Worker: mpirun on coordinator will use exo-rsh to spawn processes here
                            logger.info(
                                f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
                            )
                            current_status = RunnerRunning()

                    except Exception as e:
                        logger.error(f"Failed to start FLASH: {e}")
                        import traceback

                        logger.error(traceback.format_exc())
                        current_status = RunnerFailed(error_message=str(e))

                case Shutdown():
                    shutdown_requested = True
                    current_status = RunnerShuttingDown()
                    logger.info("FLASH runner shutting down")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )

                    if process and process.poll() is None:
                        logger.info("Terminating FLASH simulation")
                        process.terminate()
                        try:
                            process.wait(timeout=10)
                        except subprocess.TimeoutExpired:
                            logger.warning("FLASH didn't terminate, killing")
                            process.kill()
                            process.wait()

                    current_status = RunnerShutdown()

                case _:
                    if process and process.poll() is not None:
                        exit_code = process.returncode
                        if exit_code == 0:
                            logger.info("FLASH simulation completed successfully")
                            current_status = RunnerReady()
                        else:
                            logger.error(
                                f"FLASH simulation failed with code {exit_code}"
                            )
                            current_status = RunnerFailed(
                                error_message=f"Exit code {exit_code}"
                            )

            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
            )
            event_sender.send(
                RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
            )

            if isinstance(current_status, RunnerShutdown):
                break

    if process and process.poll() is None:
        process.terminate()
        process.wait(timeout=5)

    logger.info("FLASH runner exiting")
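Continuing the two-node example from the placement section, generate_hostfile and the coordinator's launch command would come out roughly as follows (the network interface name is illustrative):

# With hosts="s14,james21-1" and ranks_per_node=4, flash_hosts.txt contains:
#
#   s14 slots=4
#   james21-1 slots=4
#
# and the assembled mpirun invocation is approximately:
#
#   mpirun -np 8 --hostfile /tmp/sedov/flash_hosts.txt --wdir /tmp/sedov \
#       --oversubscribe --mca btl tcp,self --mca btl_tcp_if_include en0 \
#       --mca oob_tcp_if_include en0 --mca plm_rsh_no_tree_spawn 1 \
#       --mca plm_rsh_agent exo-rsh /opt/flash/flash4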
src/exo/plugins/registry.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""Plugin registry for discovering and managing plugins."""

from collections.abc import Callable, Sequence
from typing import Any

from loguru import logger

from exo.plugins.base import ExoPlugin


class PluginRegistry:
    """Central registry for all plugins."""

    _instance: "PluginRegistry | None" = None

    def __init__(self) -> None:
        self._plugins: dict[str, ExoPlugin] = {}
        self._command_handlers: dict[type, ExoPlugin] = {}
        self._instance_handlers: dict[type, ExoPlugin] = {}

    @classmethod
    def get(cls) -> "PluginRegistry":
        """Get the singleton registry instance."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls) -> None:
        """Reset the singleton instance (useful for testing)."""
        cls._instance = None

    def register(self, plugin: ExoPlugin) -> None:
        """Register a plugin and its types."""
        if plugin.name in self._plugins:
            raise ValueError(f"Plugin '{plugin.name}' already registered")

        logger.info(f"Registering plugin: {plugin.name} v{plugin.version}")

        self._plugins[plugin.name] = plugin

        # Register command handlers
        for cmd_type in plugin.get_command_types():
            self._command_handlers[cmd_type] = plugin
            logger.debug(f"  Registered command: {cmd_type.__name__}")

        # Register instance handler
        instance_type = plugin.get_instance_type()
        self._instance_handlers[instance_type] = plugin
        logger.debug(f"  Registered instance: {instance_type.__name__}")

    def get_plugin(self, name: str) -> ExoPlugin | None:
        """Get a plugin by name."""
        return self._plugins.get(name)

    def get_plugin_for_command(self, command: object) -> ExoPlugin | None:
        """Get the plugin that handles a command."""
        for plugin in self._plugins.values():
            if plugin.handles_command(command):
                return plugin
        return None

    def get_plugin_for_instance(self, instance: object) -> ExoPlugin | None:
        """Get the plugin that manages an instance."""
        for plugin in self._plugins.values():
            if plugin.handles_instance(instance):
                return plugin
        return None

    def all_plugins(self) -> Sequence[ExoPlugin]:
        """Get all registered plugins."""
        return list(self._plugins.values())

    def get_all_api_routes(
        self,
    ) -> Sequence[tuple[str, str, Callable[..., Any], ExoPlugin]]:
        """Get all API routes from all plugins."""
        routes: list[tuple[str, str, Callable[..., Any], ExoPlugin]] = []
        for plugin in self._plugins.values():
            for method, path, handler in plugin.get_api_routes():
                routes.append((method, path, handler, plugin))
        return routes


def discover_plugins() -> None:
    """Auto-discover and register plugins from the implementations directory.

    Plugins should have a register() function that returns an ExoPlugin instance.
    """
    import importlib
    import pkgutil

    registry = PluginRegistry.get()

    try:
        import exo.plugins.implementations as impl_package

        for _, module_name, _ in pkgutil.iter_modules(impl_package.__path__):
            try:
                module = importlib.import_module(
                    f"exo.plugins.implementations.{module_name}"
                )
                if hasattr(module, "register"):
                    plugin = module.register()  # pyright: ignore[reportAny]
                    if plugin is not None:
                        registry.register(plugin)  # pyright: ignore[reportAny]
            except Exception as e:
                logger.warning(f"Failed to load plugin {module_name}: {e}")
    except ImportError:
        logger.debug("No plugin implementations package found")
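Typical wiring at startup looks like the following sketch; the exact call site in exo's bootstrap code is not shown in this diff:

from exo.plugins.registry import PluginRegistry, discover_plugins

discover_plugins()  # imports exo.plugins.implementations.* and calls register()
registry = PluginRegistry.get()

plugin = registry.get_plugin("flash")  # -> FLASHPlugin instance, or None
for method, path, handler, owner in registry.get_all_api_routes():
    # e.g. ("post", "/flash/launch", handle_launch_flash, FLASHPlugin)
    ...  # mount each route on the API server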
src/exo/rsh/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""Exo RSH - Remote Shell for MPI without SSH.

This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:

1. Each Exo node runs an API server on port 52415 with an /execute endpoint
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", it HTTP POSTs to the target's /execute
4. The target executes the command and returns output

Usage:
    mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""
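The /execute endpoint itself is not part of this diff. A minimal sketch of the contract the client below relies on, JSON {"command": [...]} in and {"stdout", "stderr", "exit_code"} out, might look like this; it is an assumption, not the actual exo implementation:

# Hypothetical server-side handler for the contract exo-rsh expects.
import subprocess

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class ExecuteRequest(BaseModel):
    command: list[str]


@app.post("/execute")
def execute(req: ExecuteRequest) -> dict[str, str | int]:
    # Run the requested command and report its output and exit code back to
    # the exo-rsh client, which relays them to mpirun.
    proc = subprocess.run(req.command, capture_output=True, text=True)
    return {
        "stdout": proc.stdout,
        "stderr": proc.stderr,
        "exit_code": proc.returncode,
    }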
src/exo/rsh/client.py (new file, 101 lines)
@@ -0,0 +1,101 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.

This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]

It connects to the target node's Exo API (port 52415) and executes the command.
"""

import json
import socket
import sys
from typing import Any, cast
from urllib.error import URLError
from urllib.request import Request, urlopen

# Use the same port as Exo's API server
EXO_API_PORT = 52415


def resolve_hostname(hostname: str) -> str:
    """Resolve hostname to IP address."""
    try:
        return socket.gethostbyname(hostname)
    except socket.gaierror:
        # If resolution fails, try using the hostname directly
        return hostname


def main():
    # Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
    # SSH options we might see: -x (disable X11), -o options, etc.
    args = sys.argv[1:]

    # Skip SSH-style options
    hostname = None
    command_start = 0

    i = 0
    while i < len(args):
        arg = args[i]
        if arg.startswith("-"):
            # Skip option and its value if needed
            if arg in ("-o", "-i", "-l", "-p", "-F"):
                i += 2  # Skip option and its argument
                continue
            i += 1
            continue
        else:
            # First non-option is the hostname
            hostname = arg
            command_start = i + 1
            break

    if hostname is None or command_start >= len(args):
        print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
        sys.exit(1)

    command = args[command_start:]

    # Resolve hostname to IP
    ip = resolve_hostname(hostname)

    # Make request to Exo API
    url = f"http://{ip}:{EXO_API_PORT}/execute"
    data = json.dumps({"command": command}).encode("utf-8")

    try:
        req = Request(url, data=data, headers={"Content-Type": "application/json"})
        with urlopen(req, timeout=300) as response:  # pyright: ignore[reportAny]
            response_body: bytes = cast(bytes, response.read())  # pyright: ignore[reportAny]
            result: dict[str, Any] = json.loads(response_body.decode("utf-8"))  # pyright: ignore[reportAny]

        # Output stdout/stderr
        stdout: str = cast(str, result.get("stdout", ""))
        stderr: str = cast(str, result.get("stderr", ""))
        exit_code: int = cast(int, result.get("exit_code", 0))

        if stdout:
            sys.stdout.write(stdout)
            sys.stdout.flush()
        if stderr:
            sys.stderr.write(stderr)
            sys.stderr.flush()

        sys.exit(exit_code)

    except URLError as e:
        print(
            f"exo-rsh: Failed to connect to {hostname}:{EXO_API_PORT}: {e}",
            file=sys.stderr,
        )
        sys.exit(255)
    except Exception as e:
        print(f"exo-rsh: Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
@@ -1,310 +1,552 @@
from pydantic import PositiveInt

from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel


class ModelId(Id):
    def normalize(self) -> str:
        return self.replace("/", "--")

    def short(self) -> str:
        return self.split("/")[-1]


class ModelCard(CamelCaseModel):
    short_id: str
    model_id: ModelId
    storage_size: Memory
    n_layers: PositiveInt
    hidden_size: PositiveInt
    supports_tensor: bool
    name: str
    description: str
    tags: list[str]
    metadata: ModelMetadata


MODEL_CARDS: dict[str, ModelCard] = {
    # deepseek v3
    "deepseek-v3.1-4bit": ModelCard(
        short_id="deepseek-v3.1-4bit",
        model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
        storage_size=Memory.from_gb(378),
        n_layers=61,
        hidden_size=7168,
        supports_tensor=True,
        name="DeepSeek V3.1 (4-bit)",
        description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
            pretty_name="DeepSeek V3.1 (4-bit)",
            storage_size=Memory.from_gb(378),
            n_layers=61,
            hidden_size=7168,
            supports_tensor=True,
        ),
    ),
    "deepseek-v3.1-8bit": ModelCard(
        short_id="deepseek-v3.1-8bit",
        model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
        storage_size=Memory.from_gb(713),
        n_layers=61,
        hidden_size=7168,
        supports_tensor=True,
        name="DeepSeek V3.1 (8-bit)",
        description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
            pretty_name="DeepSeek V3.1 (8-bit)",
            storage_size=Memory.from_gb(713),
            n_layers=61,
            hidden_size=7168,
            supports_tensor=True,
        ),
    ),
    # kimi k2
    "kimi-k2-instruct-4bit": ModelCard(
        short_id="kimi-k2-instruct-4bit",
        model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
        storage_size=Memory.from_gb(578),
        n_layers=61,
        hidden_size=7168,
        supports_tensor=True,
        name="Kimi K2 Instruct (4-bit)",
        description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
            pretty_name="Kimi K2 Instruct (4-bit)",
            storage_size=Memory.from_gb(578),
            n_layers=61,
            hidden_size=7168,
            supports_tensor=True,
        ),
    ),
    "kimi-k2-thinking": ModelCard(
        short_id="kimi-k2-thinking",
        model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
        storage_size=Memory.from_gb(658),
        n_layers=61,
        hidden_size=7168,
        supports_tensor=True,
        name="Kimi K2 Thinking (4-bit)",
        description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
            pretty_name="Kimi K2 Thinking (4-bit)",
            storage_size=Memory.from_gb(658),
            n_layers=61,
            hidden_size=7168,
            supports_tensor=True,
        ),
    ),
    # llama-3.1
    "llama-3.1-8b": ModelCard(
        short_id="llama-3.1-8b",
        model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
        storage_size=Memory.from_mb(4423),
        n_layers=32,
        hidden_size=4096,
        supports_tensor=True,
        name="Llama 3.1 8B (4-bit)",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
            pretty_name="Llama 3.1 8B (4-bit)",
            storage_size=Memory.from_mb(4423),
            n_layers=32,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "llama-3.1-8b-8bit": ModelCard(
        short_id="llama-3.1-8b-8bit",
        model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
        storage_size=Memory.from_mb(8540),
        n_layers=32,
        hidden_size=4096,
        supports_tensor=True,
        name="Llama 3.1 8B (8-bit)",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
            pretty_name="Llama 3.1 8B (8-bit)",
            storage_size=Memory.from_mb(8540),
            n_layers=32,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "llama-3.1-8b-bf16": ModelCard(
        short_id="llama-3.1-8b-bf16",
        model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
        storage_size=Memory.from_mb(16100),
        n_layers=32,
        hidden_size=4096,
        supports_tensor=True,
        name="Llama 3.1 8B (BF16)",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
            pretty_name="Llama 3.1 8B (BF16)",
            storage_size=Memory.from_mb(16100),
            n_layers=32,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "llama-3.1-70b": ModelCard(
        short_id="llama-3.1-70b",
        model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
        storage_size=Memory.from_mb(38769),
        n_layers=80,
        hidden_size=8192,
        supports_tensor=True,
        name="Llama 3.1 70B (4-bit)",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
            pretty_name="Llama 3.1 70B (4-bit)",
            storage_size=Memory.from_mb(38769),
            n_layers=80,
            hidden_size=8192,
            supports_tensor=True,
        ),
    ),
    # llama-3.2
    "llama-3.2-1b": ModelCard(
        short_id="llama-3.2-1b",
        model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
        storage_size=Memory.from_mb(696),
        n_layers=16,
        hidden_size=2048,
        supports_tensor=True,
        name="Llama 3.2 1B (4-bit)",
        description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
            pretty_name="Llama 3.2 1B (4-bit)",
            storage_size=Memory.from_mb(696),
            n_layers=16,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "llama-3.2-3b": ModelCard(
        short_id="llama-3.2-3b",
        model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
        storage_size=Memory.from_mb(1777),
        n_layers=28,
        hidden_size=3072,
        supports_tensor=True,
        name="Llama 3.2 3B (4-bit)",
        description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
            pretty_name="Llama 3.2 3B (4-bit)",
            storage_size=Memory.from_mb(1777),
            n_layers=28,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
    "llama-3.2-3b-8bit": ModelCard(
        short_id="llama-3.2-3b-8bit",
        model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
        storage_size=Memory.from_mb(3339),
        n_layers=28,
        hidden_size=3072,
        supports_tensor=True,
        name="Llama 3.2 3B (8-bit)",
        description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
            pretty_name="Llama 3.2 3B (8-bit)",
            storage_size=Memory.from_mb(3339),
            n_layers=28,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
    # llama-3.3
    "llama-3.3-70b": ModelCard(
        short_id="llama-3.3-70b",
        model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
        storage_size=Memory.from_mb(38769),
        n_layers=80,
        hidden_size=8192,
        supports_tensor=True,
        name="Llama 3.3 70B (4-bit)",
        description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
            pretty_name="Llama 3.3 70B",
            storage_size=Memory.from_mb(38769),
            n_layers=80,
            hidden_size=8192,
            supports_tensor=True,
        ),
    ),
    "llama-3.3-70b-8bit": ModelCard(
        short_id="llama-3.3-70b-8bit",
        model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
        storage_size=Memory.from_mb(73242),
        n_layers=80,
        hidden_size=8192,
        supports_tensor=True,
        name="Llama 3.3 70B (8-bit)",
        description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
            pretty_name="Llama 3.3 70B (8-bit)",
            storage_size=Memory.from_mb(73242),
            n_layers=80,
            hidden_size=8192,
            supports_tensor=True,
        ),
    ),
    "llama-3.3-70b-fp16": ModelCard(
        short_id="llama-3.3-70b-fp16",
        model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
        storage_size=Memory.from_mb(137695),
        n_layers=80,
        hidden_size=8192,
        supports_tensor=True,
        name="Llama 3.3 70B (FP16)",
        description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
            pretty_name="Llama 3.3 70B (FP16)",
            storage_size=Memory.from_mb(137695),
            n_layers=80,
            hidden_size=8192,
            supports_tensor=True,
        ),
    ),
    # qwen3
    "qwen3-0.6b": ModelCard(
        short_id="qwen3-0.6b",
        model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
        storage_size=Memory.from_mb(327),
        n_layers=28,
        hidden_size=1024,
        supports_tensor=False,
        name="Qwen3 0.6B (4-bit)",
        description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
            pretty_name="Qwen3 0.6B (4-bit)",
            storage_size=Memory.from_mb(327),
            n_layers=28,
            hidden_size=1024,
            supports_tensor=False,
        ),
    ),
    "qwen3-0.6b-8bit": ModelCard(
        short_id="qwen3-0.6b-8bit",
        model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
        storage_size=Memory.from_mb(666),
        n_layers=28,
        hidden_size=1024,
        supports_tensor=False,
        name="Qwen3 0.6B (8-bit)",
        description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
            pretty_name="Qwen3 0.6B (8-bit)",
            storage_size=Memory.from_mb(666),
            n_layers=28,
            hidden_size=1024,
            supports_tensor=False,
        ),
    ),
    "qwen3-30b": ModelCard(
        short_id="qwen3-30b",
        model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
        storage_size=Memory.from_mb(16797),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 30B A3B (4-bit)",
        description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
            pretty_name="Qwen3 30B A3B (4-bit)",
            storage_size=Memory.from_mb(16797),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-30b-8bit": ModelCard(
        short_id="qwen3-30b-8bit",
        model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
        storage_size=Memory.from_mb(31738),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 30B A3B (8-bit)",
        description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
            pretty_name="Qwen3 30B A3B (8-bit)",
            storage_size=Memory.from_mb(31738),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-80b-a3B-4bit": ModelCard(
        short_id="qwen3-80b-a3B-4bit",
        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
        storage_size=Memory.from_mb(44800),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 80B A3B (4-bit)",
        description="""Qwen3 80B""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
            pretty_name="Qwen3 80B A3B (4-bit)",
            storage_size=Memory.from_mb(44800),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-80b-a3B-8bit": ModelCard(
        short_id="qwen3-80b-a3B-8bit",
        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
        storage_size=Memory.from_mb(84700),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 80B A3B (8-bit)",
        description="""Qwen3 80B""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
            pretty_name="Qwen3 80B A3B (8-bit)",
            storage_size=Memory.from_mb(84700),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-80b-a3B-thinking-4bit": ModelCard(
        short_id="qwen3-80b-a3B-thinking-4bit",
        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
        storage_size=Memory.from_mb(84700),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 80B A3B Thinking (4-bit)",
        description="""Qwen3 80B Reasoning model""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
            pretty_name="Qwen3 80B A3B (4-bit)",
            storage_size=Memory.from_mb(84700),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-80b-a3B-thinking-8bit": ModelCard(
        short_id="qwen3-80b-a3B-thinking-8bit",
        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
        storage_size=Memory.from_mb(84700),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 80B A3B Thinking (8-bit)",
        description="""Qwen3 80B Reasoning model""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
            pretty_name="Qwen3 80B A3B (8-bit)",
            storage_size=Memory.from_mb(84700),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-235b-a22b-4bit": ModelCard(
        short_id="qwen3-235b-a22b-4bit",
        model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
        storage_size=Memory.from_gb(132),
        n_layers=94,
        hidden_size=4096,
        supports_tensor=True,
        name="Qwen3 235B A22B (4-bit)",
        description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
            pretty_name="Qwen3 235B A22B (4-bit)",
            storage_size=Memory.from_gb(132),
            n_layers=94,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "qwen3-235b-a22b-8bit": ModelCard(
        short_id="qwen3-235b-a22b-8bit",
        model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
        storage_size=Memory.from_gb(250),
        n_layers=94,
        hidden_size=4096,
        supports_tensor=True,
        name="Qwen3 235B A22B (8-bit)",
        description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
            pretty_name="Qwen3 235B A22B (8-bit)",
            storage_size=Memory.from_gb(250),
            n_layers=94,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "qwen3-coder-480b-a35b-4bit": ModelCard(
        short_id="qwen3-coder-480b-a35b-4bit",
        model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
        storage_size=Memory.from_gb(270),
        n_layers=62,
        hidden_size=6144,
        supports_tensor=True,
        name="Qwen3 Coder 480B A35B (4-bit)",
        description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
            pretty_name="Qwen3 Coder 480B A35B (4-bit)",
            storage_size=Memory.from_gb(270),
            n_layers=62,
            hidden_size=6144,
            supports_tensor=True,
        ),
    ),
    "qwen3-coder-480b-a35b-8bit": ModelCard(
        short_id="qwen3-coder-480b-a35b-8bit",
        model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
        storage_size=Memory.from_gb(540),
        n_layers=62,
        hidden_size=6144,
        supports_tensor=True,
        name="Qwen3 Coder 480B A35B (8-bit)",
        description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
            pretty_name="Qwen3 Coder 480B A35B (8-bit)",
            storage_size=Memory.from_gb(540),
            n_layers=62,
            hidden_size=6144,
            supports_tensor=True,
        ),
    ),
    # gpt-oss
    "gpt-oss-120b-MXFP4-Q8": ModelCard(
        short_id="gpt-oss-120b-MXFP4-Q8",
        model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
        storage_size=Memory.from_kb(68_996_301),
        n_layers=36,
        hidden_size=2880,
        supports_tensor=True,
        name="GPT-OSS 120B (MXFP4-Q8, MLX)",
        description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
            pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
            storage_size=Memory.from_kb(68_996_301),
            n_layers=36,
            hidden_size=2880,
            supports_tensor=True,
        ),
    ),
    "gpt-oss-20b-MXFP4-Q8": ModelCard(
        short_id="gpt-oss-20b-MXFP4-Q8",
        model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
        storage_size=Memory.from_kb(11_744_051),
        n_layers=24,
        hidden_size=2880,
        supports_tensor=True,
        name="GPT-OSS 20B (MXFP4-Q8, MLX)",
        description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
            pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
            storage_size=Memory.from_kb(11_744_051),
            n_layers=24,
            hidden_size=2880,
            supports_tensor=True,
        ),
    ),
    # glm 4.5
    "glm-4.5-air-8bit": ModelCard(
        # Needs to be quantized g32 or g16 to work with tensor parallel
        short_id="glm-4.5-air-8bit",
        model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
        storage_size=Memory.from_gb(114),
        n_layers=46,
        hidden_size=4096,
        supports_tensor=False,
        name="GLM 4.5 Air 8bit",
        description="""GLM 4.5 Air 8bit""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
            pretty_name="GLM 4.5 Air 8bit",
            storage_size=Memory.from_gb(114),
            n_layers=46,
            hidden_size=4096,
            supports_tensor=False,
        ),
    ),
    "glm-4.5-air-bf16": ModelCard(
        short_id="glm-4.5-air-bf16",
        model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
        storage_size=Memory.from_gb(214),
        n_layers=46,
        hidden_size=4096,
        supports_tensor=True,
        name="GLM 4.5 Air bf16",
        description="""GLM 4.5 Air bf16""",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
            pretty_name="GLM 4.5 Air bf16",
            storage_size=Memory.from_gb(214),
            n_layers=46,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    # glm 4.7
    "glm-4.7-4bit": ModelCard(
        short_id="glm-4.7-4bit",
        model_id=ModelId("mlx-community/GLM-4.7-4bit"),
        storage_size=Memory.from_bytes(198556925568),
        n_layers=91,
        hidden_size=5120,
        supports_tensor=True,
        name="GLM 4.7 4bit",
        description="GLM 4.7 4bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-4bit"),
            pretty_name="GLM 4.7 4bit",
            storage_size=Memory.from_bytes(198556925568),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    "glm-4.7-6bit": ModelCard(
        short_id="glm-4.7-6bit",
        model_id=ModelId("mlx-community/GLM-4.7-6bit"),
        storage_size=Memory.from_bytes(286737579648),
        n_layers=91,
        hidden_size=5120,
        supports_tensor=True,
        name="GLM 4.7 6bit",
        description="GLM 4.7 6bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-6bit"),
            pretty_name="GLM 4.7 6bit",
            storage_size=Memory.from_bytes(286737579648),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    "glm-4.7-8bit-gs32": ModelCard(
        short_id="glm-4.7-8bit-gs32",
        model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
        storage_size=Memory.from_bytes(396963397248),
        n_layers=91,
        hidden_size=5120,
        supports_tensor=True,
        name="GLM 4.7 8bit (gs32)",
        description="GLM 4.7 8bit (gs32)",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
            pretty_name="GLM 4.7 8bit (gs32)",
            storage_size=Memory.from_bytes(396963397248),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    # glm 4.7 flash
    "glm-4.7-flash-4bit": ModelCard(
        model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
        storage_size=Memory.from_gb(18),
        n_layers=47,
        hidden_size=2048,
        supports_tensor=True,
    ),
    "glm-4.7-flash-5bit": ModelCard(
        model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
        storage_size=Memory.from_gb(21),
        n_layers=47,
        hidden_size=2048,
        supports_tensor=True,
    ),
    "glm-4.7-flash-6bit": ModelCard(
        model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
        storage_size=Memory.from_gb(25),
        n_layers=47,
        hidden_size=2048,
        supports_tensor=True,
    ),
    "glm-4.7-flash-8bit": ModelCard(
        model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
        storage_size=Memory.from_gb(32),
        n_layers=47,
        hidden_size=2048,
        supports_tensor=True,
    ),
    # minimax-m2
    "minimax-m2.1-8bit": ModelCard(
        short_id="minimax-m2.1-8bit",
        model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
        storage_size=Memory.from_bytes(242986745856),
        n_layers=61,
        hidden_size=3072,
        supports_tensor=True,
        name="MiniMax M2.1 8bit",
        description="MiniMax M2.1 8bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
            pretty_name="MiniMax M2.1 8bit",
            storage_size=Memory.from_bytes(242986745856),
            n_layers=61,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
    "minimax-m2.1-3bit": ModelCard(
        short_id="minimax-m2.1-3bit",
        model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
        storage_size=Memory.from_bytes(100086644736),
        n_layers=61,
        hidden_size=3072,
        supports_tensor=True,
        name="MiniMax M2.1 3bit",
        description="MiniMax M2.1 3bit",
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
            pretty_name="MiniMax M2.1 3bit",
            storage_size=Memory.from_bytes(100086644736),
            n_layers=61,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
}
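The registry above is keyed by human-readable short ids rather than repo paths. A minimal lookup sketch (ours; the key and field names come from the dict and class above):

# Hedged example: resolving a card from MODEL_CARDS.
card = MODEL_CARDS["llama-3.1-8b"]
assert str(card.model_id) == "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
print(card.name, card.n_layers, card.metadata.pretty_name)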
@@ -6,8 +6,9 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
    ModelSafetensorsIndex,
    download_file_with_retry,

@@ -91,18 +92,18 @@ async def get_safetensors_size(model_id: str) -> Memory:
    return Memory.from_bytes(info.safetensors.total)


_model_card_cache: dict[str, ModelCard] = {}
_model_meta_cache: dict[str, ModelMetadata] = {}


async def get_model_card(model_id: str) -> ModelCard:
    if model_id in _model_card_cache:
        return _model_card_cache[model_id]
    model_card = await _get_model_card(model_id)
    _model_card_cache[model_id] = model_card
    return model_card
async def get_model_meta(model_id: str) -> ModelMetadata:
    if model_id in _model_meta_cache:
        return _model_meta_cache[model_id]
    model_meta = await _get_model_meta(model_id)
    _model_meta_cache[model_id] = model_meta
    return model_meta
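The function above is backed by a plain module-level memo: the first await populates the dict, later awaits return the cached object. A usage sketch (ours; only get_model_meta and its cache come from the hunk):

# Hedged example: the second call should be a cache hit with no
# extra Hugging Face round-trip.
async def describe(model_id: str) -> None:
    meta = await get_model_meta(model_id)   # fetches and caches
    again = await get_model_meta(model_id)  # served from _model_meta_cache
    assert meta is again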
async def _get_model_card(model_id: str) -> ModelCard:
async def _get_model_meta(model_id: str) -> ModelMetadata:
    """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
    config_data = await get_config_data(model_id)
    num_layers = config_data.layer_count

@@ -112,11 +113,14 @@ async def _get_model_card(model_id: str) -> ModelCard:
        None,
    )

    return ModelCard(
    return ModelMetadata(
        model_id=ModelId(model_id),
        pretty_name=model_card.name if model_card is not None else model_id,
        storage_size=mem_size_bytes,
        n_layers=num_layers,
        hidden_size=config_data.hidden_size or 0,
        # TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
        supports_tensor=model_card.supports_tensor if model_card is not None else False,
        supports_tensor=model_card.metadata.supports_tensor
        if model_card is not None
        else False,
    )

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata


@@ -31,8 +31,9 @@ def get_pipeline_shard_metadata(
    model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
    return PipelineShardMetadata(
        model_card=ModelCard(
        model_meta=ModelMetadata(
            model_id=model_id,
            pretty_name=str(model_id),
            storage_size=Memory.from_mb(100000),
            n_layers=32,
            hidden_size=1000,
@@ -4,9 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding

@@ -206,7 +206,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    model_card: ModelCard
    model_meta: ModelMetadata


class DeleteInstanceResponse(BaseModel):

@@ -1,10 +1,10 @@
from enum import Enum

from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel

from .api import FinishReason
from .models import ModelId


class ChunkType(str, Enum):

@@ -1,8 +1,8 @@
from pydantic import Field

from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -21,7 +21,7 @@ class ChatCompletion(BaseCommand):


class PlaceInstance(BaseCommand):
    model_card: ModelCard
    model_meta: ModelMetadata
    sharding: Sharding
    instance_meta: InstanceMeta
    min_nodes: int
@@ -35,6 +35,26 @@ class DeleteInstance(BaseCommand):
    instance_id: InstanceId


class LaunchFLASH(BaseCommand):
    """Command to launch a FLASH MPI simulation."""

    simulation_name: str
    flash_executable_path: str
    parameter_file_path: str
    working_directory: str
    ranks_per_node: int = 1
    min_nodes: int = 1
    # Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
    # Used when topology edges don't contain IP addresses
    hosts: str = ""


class StopFLASH(BaseCommand):
    """Command to stop a running FLASH simulation."""

    instance_id: InstanceId


class TaskFinished(BaseCommand):
    finished_command_id: CommandId
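To make the new command concrete, a construction sketch; every literal below is an illustrative placeholder, only the field names come from LaunchFLASH above (any inherited base-command fields are assumed to default):

# Hedged example: launching a 2-node, 8-rank FLASH run.
cmd = LaunchFLASH(
    simulation_name="sedov",
    flash_executable_path="/opt/flash/bin/flash4",
    parameter_file_path="/opt/flash/run/flash.par",
    working_directory="/opt/flash/run",
    ranks_per_node=4,
    min_nodes=2,
    hosts="s14,james21-1",  # optional explicit MPI hosts
)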
@@ -50,6 +70,8 @@ Command = (
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | LaunchFLASH
    | StopFLASH
    | TaskFinished
)


@@ -16,9 +16,7 @@ class Id(str):
        cls, _source: type, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Just use a plain string schema
        return core_schema.no_info_after_validator_function(
            cls, core_schema.str_schema()
        )
        return core_schema.str_schema()


class NodeId(Id):
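
One consequence of that hunk worth spelling out (our reading, not text from the patch): pydantic v2's no_info_after_validator_function re-wrapped each validated string in the Id subclass, while a bare str_schema() leaves it a plain str. A sketch of the observable difference (the model name is ours):

from pydantic import BaseModel

class Edge(BaseModel):
    node_id: NodeId

edge = Edge(node_id="worker-0")
# old schema: type(edge.node_id) is NodeId
# new schema: type(edge.node_id) is str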
18  src/exo/shared/types/models.py  Normal file

@@ -0,0 +1,18 @@
from pydantic import PositiveInt

from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel


class ModelId(Id):
    pass


class ModelMetadata(CamelCaseModel):
    model_id: ModelId
    pretty_name: str
    storage_size: Memory
    n_layers: PositiveInt
    hidden_size: PositiveInt
    supports_tensor: bool
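
For reference, instantiating the new type directly, with values borrowed from the llama-3.1-8b card earlier in this diff:

meta = ModelMetadata(
    model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
    pretty_name="Llama 3.1 8B (4-bit)",
    storage_size=Memory.from_mb(4423),
    n_layers=32,
    hidden_size=4096,
    supports_tensor=True,
)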
@@ -14,6 +14,7 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
    MlxRing = "MlxRing"
    MlxJaccl = "MlxJaccl"
    FLASH = "FLASH"


class BaseInstance(TaggedModel):

@@ -34,8 +35,27 @@ class MlxJacclInstance(BaseInstance):
    jaccl_coordinators: dict[NodeId, str]


class FLASHInstance(BaseInstance):
    """Instance for FLASH MPI simulation.

    Unlike MLX instances which do tensor parallelism, FLASH instances
    coordinate MPI processes across nodes. Each node runs one or more
    MPI ranks of the FLASH simulation.
    """

    hosts_by_node: dict[NodeId, list[Host]]
    flash_executable_path: str
    parameter_file_path: str
    working_directory: str
    ranks_per_node: int = 1
    total_ranks: int
    simulation_name: str
    coordinator_ip: str
    network_interface: str = "en0"  # Network interface for MPI (e.g., en0, eth0)


# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
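
The implied bookkeeping (our illustration; only the field names come from FLASHInstance above): total_ranks should equal ranks_per_node summed over the nodes in hosts_by_node.

# Hedged sketch: node ids and counts are placeholders.
nodes = [NodeId("s14"), NodeId("james21-1")]
ranks_per_node = 4
total_ranks = ranks_per_node * len(nodes)  # 8 MPI ranks in total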


class BoundInstance(CamelCaseModel):
@@ -2,8 +2,8 @@ from collections.abc import Mapping

from pydantic import model_validator

from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel


@@ -2,7 +2,7 @@ from enum import Enum

from pydantic import Field

from exo.shared.models.model_cards import ModelCard
from exo.shared.types.models import ModelMetadata
from exo.utils.pydantic_ext import TaggedModel


@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
    Replaces previous `Shard` object.
    """

    model_card: ModelCard
    model_meta: ModelMetadata
    device_rank: int
    world_size: int

@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
    def __hash__(self) -> int:
        return hash(
            (
                self.model_card.model_id,
                self.model_meta.model_id,
                self.start_layer,
                self.end_layer,
                self.n_layers,
@@ -460,10 +460,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
        # (iii) Tensor parallel requires all files.
        return ["*"]
    try:
        weight_map = await get_weight_map(str(shard.model_card.model_id))
        weight_map = await get_weight_map(str(shard.model_meta.model_id))
        return get_allow_patterns(weight_map, shard)
    except Exception:
        logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
        logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
        logger.error(traceback.format_exc())
        return ["*"]

@@ -532,18 +532,18 @@ async def download_shard(
    allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
    if not skip_download:
        logger.info(f"Downloading {shard.model_card.model_id=}")
        logger.info(f"Downloading {shard.model_meta.model_id=}")

    # Handle local paths
    if await aios.path.exists(str(shard.model_card.model_id)):
        logger.info(f"Using local model path {shard.model_card.model_id}")
        local_path = Path(str(shard.model_card.model_id))
    if await aios.path.exists(str(shard.model_meta.model_id)):
        logger.info(f"Using local model path {shard.model_meta.model_id}")
        local_path = Path(str(shard.model_meta.model_id))
        return local_path, await download_progress_for_local_path(
            str(shard.model_card.model_id), shard, local_path
            str(shard.model_meta.model_id), shard, local_path
        )

    revision = "main"
    target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
    target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
        "/", "--"
    )
    if not skip_download:

@@ -552,13 +552,13 @@ async def download_shard(
    if not allow_patterns:
        allow_patterns = await resolve_allow_patterns(shard)

    logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
    logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")

    all_start_time = time.time()
    # TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
    # Update: <- This does not seem to be the case. Yay?
    file_list = await fetch_file_list_with_cache(
        str(shard.model_card.model_id), revision, recursive=True
        str(shard.model_meta.model_id), revision, recursive=True
    )
    filtered_file_list = list(
        filter_repo_objects(

@@ -592,7 +592,7 @@ async def download_shard(
        else timedelta(seconds=0)
    )
    file_progress[file.path] = RepoFileDownloadProgress(
        repo_id=str(shard.model_card.model_id),
        repo_id=str(shard.model_meta.model_id),
        repo_revision=revision,
        file_path=file.path,
        downloaded=Memory.from_bytes(curr_bytes),

@@ -609,7 +609,7 @@ async def download_shard(
        shard,
        calculate_repo_progress(
            shard,
            str(shard.model_card.model_id),
            str(shard.model_meta.model_id),
            revision,
            file_progress,
            all_start_time,

@@ -619,7 +619,7 @@ async def download_shard(
    for file in filtered_file_list:
        downloaded_bytes = await get_downloaded_size(target_dir / file.path)
        file_progress[file.path] = RepoFileDownloadProgress(
            repo_id=str(shard.model_card.model_id),
            repo_id=str(shard.model_meta.model_id),
            repo_revision=revision,
            file_path=file.path,
            downloaded=Memory.from_bytes(downloaded_bytes),

@@ -643,7 +643,7 @@ async def download_shard(
    async def download_with_semaphore(file: FileListEntry) -> None:
        async with semaphore:
            await download_file_with_retry(
                str(shard.model_card.model_id),
                str(shard.model_meta.model_id),
                revision,
                file.path,
                target_dir,

@@ -657,7 +657,7 @@ async def download_shard(
        *[download_with_semaphore(file) for file in filtered_file_list]
    )
    final_repo_progress = calculate_repo_progress(
        shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
        shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
    )
    await on_progress(shard, final_repo_progress)
    if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
@@ -4,7 +4,7 @@ from pathlib import Path
from typing import AsyncIterator, Callable

from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
    ShardMetadata,

@@ -20,21 +20,21 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:


async def build_base_shard(model_id: str) -> ShardMetadata:
    model_card = await get_model_card(model_id)
    model_meta = await get_model_meta(model_id)
    return PipelineShardMetadata(
        model_card=model_card,
        model_meta=model_meta,
        device_rank=0,
        world_size=1,
        start_layer=0,
        end_layer=model_card.n_layers,
        n_layers=model_card.n_layers,
        end_layer=model_meta.n_layers,
        n_layers=model_meta.n_layers,
    )
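
A quick orientation point (ours, not part of the patch): build_base_shard always yields a whole-model pipeline shard, so the layer bounds are degenerate by construction.

# Hedged example, inside an async context; the id is a card from this diff.
shard = await build_base_shard("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
assert shard.start_layer == 0
assert shard.end_layer == shard.n_layers == 32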


async def build_full_shard(model_id: str) -> PipelineShardMetadata:
    base_shard = await build_base_shard(model_id)
    return PipelineShardMetadata(
        model_card=base_shard.model_card,
        model_meta=base_shard.model_meta,
        device_rank=base_shard.device_rank,
        world_size=base_shard.world_size,
        start_layer=base_shard.start_layer,

@@ -93,11 +93,11 @@ class CachedShardDownloader(ShardDownloader):
    async def ensure_shard(
        self, shard: ShardMetadata, config_only: bool = False
    ) -> Path:
        if (shard.model_card.model_id, shard) in self.cache:
            return self.cache[(shard.model_card.model_id, shard)]
        if (shard.model_meta.model_id, shard) in self.cache:
            return self.cache[(shard.model_meta.model_id, shard)]

        target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
        self.cache[(shard.model_card.model_id, shard)] = target_dir
        self.cache[(shard.model_meta.model_id, shard)] = target_dir
        return target_dir

    async def get_shard_download_status(
@@ -5,8 +5,8 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
    ShardMetadata,

@@ -86,8 +86,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
    repo_id="noop",
    repo_revision="noop",
    shard=PipelineShardMetadata(
        model_card=ModelCard(
        model_meta=ModelMetadata(
            model_id=ModelId("noop"),
            pretty_name="noope",
            storage_size=Memory.from_bytes(0),
            n_layers=1,
            hidden_size=1,
@@ -1,10 +1,7 @@
import os
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Callable, Protocol, cast

import mlx.core as mx
import mlx.nn as nn

@@ -32,50 +29,28 @@ from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata

TimeoutCallback = Callable[[], None]


def eval_with_timeout(
    mlx_item: Any,  # pyright: ignore[reportAny]
    timeout_seconds: float = 60.0,
    on_timeout: TimeoutCallback | None = None,
) -> None:
    """Evaluate MLX item with a hard timeout.

    If on_timeout callback is provided, it will be called before terminating
    the process. This allows the runner to send a failure event before exit.
    """
    completed = threading.Event()

    def watchdog() -> None:
        if not completed.wait(timeout=timeout_seconds):
            logger.error(
                f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
                "This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
                "Terminating process."
            )
            if on_timeout is not None:
                on_timeout()
            os._exit(1)

    watchdog_thread = threading.Thread(target=watchdog, daemon=True)
    watchdog_thread.start()

    try:
        mx.eval(mlx_item)  # pyright: ignore[reportAny]
    finally:
        completed.set()
class _LayerCallable(Protocol):
    """Structural type that any compatible layer must satisfy.

    We require a single positional input of type ``mx.array`` and an
    ``mx.array`` output, while permitting arbitrary *args / **kwargs so this
    protocol matches the vast majority of `mlx.nn.Module` subclasses.
    """

    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...


class CustomMlxLayer(nn.Module):
    """Base class for replacing an MLX layer with a custom implementation."""

    def __init__(self, original_layer: nn.Module):
    def __init__(self, original_layer: _LayerCallable):
        super().__init__()
        object.__setattr__(self, "_original_layer", original_layer)

    @property
    def original_layer(self) -> nn.Module:
        return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
    def original_layer(self) -> _LayerCallable:
        return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))

    # Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
    if not TYPE_CHECKING:
@@ -88,49 +63,53 @@ class CustomMlxLayer(nn.Module):
        return getattr(original_layer, name)


def patch_pipeline_first_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
    orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)

    rank = group.rank()

    class PatchedFirstLayer(nn.Module):
        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
            if rank != 0:
                x = mx.distributed.recv_like(x, (rank - 1), group=group)
            return orig_call(x, *args, **kwargs)

    pipeline_layer.__class__ = PatchedFirstLayer

    return pipeline_layer


def patch_pipeline_last_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
    orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)
    orig_call_sig = signature(orig_call)

    rank = group.rank()
    size = group.size()

    class PatchedLastLayer(nn.Module):
        def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
            cache = orig_call_sig.bind_partial(
                x, *args, **kwargs
            ).arguments.get("cache", None)

            output: mx.array = orig_call(x, *args, **kwargs)

            if rank != size - 1:
                output = mx.distributed.send(
                    output, (rank + 1) % size, group=group
                )
                if cache is not None:
                    cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]

            return output

    pipeline_layer.__class__ = PatchedLastLayer

    return pipeline_layer
class PipelineFirstLayer(CustomMlxLayer):
    def __init__(
        self,
        original_layer: _LayerCallable,
        r: int,
        group: mx.distributed.Group,
    ):
        super().__init__(original_layer)
        self.r: int = r
        self.group = group

    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
        if self.r != 0:
            x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
        return self.original_layer(x, *args, **kwargs)


class PipelineLastLayer(CustomMlxLayer):
    def __init__(
        self,
        original_layer: _LayerCallable,
        r: int,
        s: int,
        group: mx.distributed.Group,
    ):
        super().__init__(original_layer)
        self.r: int = r
        self.s: int = s
        self.group = group
        self.original_layer_signature = signature(self.original_layer.__call__)

    def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
        cache = self.original_layer_signature.bind_partial(
            x, *args, **kwargs
        ).arguments.get("cache", None)

        output: mx.array = self.original_layer(x, *args, **kwargs)

        if self.r != self.s - 1:
            output = mx.distributed.send(
                output, (self.r + 1) % self.s, group=self.group
            )
            if cache is not None:
                cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]

        return output
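
Both the patch-function and class variants implement the same hand-off, summarized here (our distillation; recv_like/send/depends are the MLX calls already used above): rank r blocks on activations from r-1 before its first layer, forwards its last layer's output to r+1, and pins the lazy send into the compute graph via an mx.depends edge on the KV cache so the send is actually evaluated.

# Hedged sketch of the per-rank hand-off, stripped of the wrapper classes.
def pipeline_step(x, rank: int, size: int, group, layer):
    if rank != 0:
        x = mx.distributed.recv_like(x, rank - 1, group=group)  # wait for stage r-1
    y = layer(x)
    if rank != size - 1:
        y = mx.distributed.send(y, (rank + 1) % size, group=group)  # lazy; needs an anchor
    return y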


def _inner_model(model: nn.Module) -> nn.Module:
    inner = getattr(model, "model", None)
@@ -144,13 +123,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
    raise ValueError("Model must either have a 'model' or 'transformer' attribute")


def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
    # Handle both model.layers and model.h cases
    layers: list[nn.Module]
    layers: list[_LayerCallable]
    if hasattr(inner_model_instance, "layers"):
        layers = cast(list[nn.Module], inner_model_instance.layers)
        layers = cast(list[_LayerCallable], inner_model_instance.layers)
    elif hasattr(inner_model_instance, "h"):
        layers = cast(list[nn.Module], inner_model_instance.h)
        layers = cast(list[_LayerCallable], inner_model_instance.h)
    else:
        raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -175,12 +154,15 @@ def pipeline_auto_parallel(
    layers = _get_layers(inner_model_instance)

    start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
    device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size

    layers = layers[start_layer:end_layer]
    layers[0] = patch_pipeline_first_layer(layers[0], group)
    layers[-1] = patch_pipeline_last_layer(
        layers[-1],
        group,
    )
    layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
    layers[-1] = PipelineLastLayer(
        layers[-1],
        device_rank,
        world_size,
        group=group,
    )

    if isinstance(inner_model_instance, GptOssMoeModel):
@@ -243,37 +225,9 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
    return model


def patch_tensor_model[T](model: T) -> T:
    """Patch model's __call__ to ensure distributed ops sync during inference."""
    cls = model.__class__
    original_call = cls.__call__
    call_signature = signature(original_call)

    def patched_call(
        self: T,
        *args: object,
        **kwargs: object,
    ) -> mx.array:
        logits: mx.array = original_call(self, *args, **kwargs)  # pyright: ignore[reportAny]
        cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
            "cache", None
        )

        # Add dependency to last cache entry to ensure distributed ops are evaluated
        if cache is not None and len(cache) > 0:  # pyright: ignore[reportAny]
            cache[-1].state = mx.depends(cache[-1].state, logits)  # pyright: ignore[reportAny,reportUnknownMemberType]

        return logits

    cls.__call__ = patched_call
    return model


def tensor_auto_parallel(
    model: nn.Module,
    group: mx.distributed.Group,
    timeout_seconds: float = 60.0,
    on_timeout: TimeoutCallback | None = None,
) -> nn.Module:
    all_to_sharded_linear = partial(
        shard_linear,

@@ -318,7 +272,7 @@ def tensor_auto_parallel(
    if hasattr(model, "shard"):
        try:
            model.shard(group)  # type: ignore
            return patch_tensor_model(model)
            return model
        except (AttributeError, TypeError, NameError):
            pass

@@ -368,10 +322,7 @@ def tensor_auto_parallel(
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

    model = tensor_parallel_sharding_strategy.shard_model(
        model, timeout_seconds, on_timeout
    )
    return patch_tensor_model(model)
    return tensor_parallel_sharding_strategy.shard_model(model)
@@ -391,27 +342,13 @@ class TensorParallelShardingStrategy(ABC):
        self.N = group.size()

    @abstractmethod
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module: ...
    def shard_model(self, model: nn.Module) -> nn.Module: ...


class LlamaShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(LlamaModel, model)
        for layer in model.layers:
            # Force load weights before sharding to avoid FAST_SYNCH deadlock
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -427,7 +364,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
        return model


def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
    inner_model_instance = _inner_model(model)
    if hasattr(inner_model_instance, "layers"):
        inner_model_instance.layers = layers

@@ -454,17 +391,9 @@ def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:


class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(DeepseekV3Model, model)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
            # Shard the self attention
            if layer.self_attn.q_lora_rank is None:
                layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -502,31 +431,23 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):


class ShardedDeepseekV3MoE(CustomMlxLayer):
    def __init__(self, layer: nn.Module):
    def __init__(self, layer: _LayerCallable):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None

    def __call__(self, x: mx.array) -> mx.array:
        if self.sharding_group is not None:
            x = sum_gradients(self.sharding_group)(x)
        y = self.original_layer.__call__(x)  # type: ignore
        y = self.original_layer.__call__(x)
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
        return y  # type: ignore
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y


class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(MiniMaxModel, model)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
            # Shard the self attention
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -546,24 +467,16 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
            self.all_to_sharded_linear_in_place(
                layer.block_sparse_moe.switch_mlp.up_proj
            )
            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue]
            layer.block_sparse_moe.sharding_group = self.group  # pyright: ignore[reportAttributeAccessIssue]
            layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
            layer.block_sparse_moe.sharding_group = self.group

        return model


class QwenShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(Qwen3MoeModel, model)
        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
            # Shard the self attention
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)

@@ -580,7 +493,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
            self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
            self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
            self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
            layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue]
            layer.mlp = ShardedQwenMoE(layer.mlp)  # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
            layer.mlp.sharding_group = self.group

            # Shard the MLP
@@ -593,32 +506,24 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):


class ShardedQwenMoE(CustomMlxLayer):
    def __init__(self, layer: nn.Module):
    def __init__(self, layer: _LayerCallable):
        super().__init__(layer)
        self.sharding_group: mx.distributed.Group | None = None

    def __call__(self, x: mx.array) -> mx.array:
        if self.sharding_group is not None:
            x = sum_gradients(self.sharding_group)(x)
        y = self.original_layer.__call__(x)  # type: ignore
        y = self.original_layer.__call__(x)
        if self.sharding_group is not None:
            y = mx.distributed.all_sum(y, group=self.sharding_group)  # type: ignore
        return y  # type: ignore
            y = mx.distributed.all_sum(y, group=self.sharding_group)
        return y


class GptOssShardingStrategy(TensorParallelShardingStrategy):
    def shard_model(
        self,
        model: nn.Module,
        timeout_seconds: float,
        on_timeout: TimeoutCallback | None,
    ) -> nn.Module:
    def shard_model(self, model: nn.Module) -> nn.Module:
        model = cast(GptOssMoeModel, model)

        for layer in model.layers:
            eval_with_timeout(
                layer.parameters(), timeout_seconds / len(model.layers), on_timeout
            )
            layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
            layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
            layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -642,7 +547,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
@@ -655,7 +560,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x) # type: ignore
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
|
||||
return y # type: ignore
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
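Note on the two wrappers above: ShardedQwenMoE and ShardedGptOssMoE implement the same tensor-parallel MoE pattern, where each rank evaluates only its slice of the experts and an all_sum reconstitutes the full activation (sum_gradients plays the mirror role on the backward pass). A minimal sketch of the forward pattern detached from exo's class hierarchy (partial_layer is a hypothetical stand-in for the sharded expert block):

import mlx.core as mx

def sharded_moe_forward(partial_layer, x: mx.array,
                        group: mx.distributed.Group | None) -> mx.array:
    # Each rank evaluates only its expert slice, producing a partial output...
    y = partial_layer(x)
    # ...and partial outputs are summed across ranks into the full activation.
    if group is not None:
        y = mx.distributed.all_sum(y, group=group)
    return y
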
@@ -2,7 +2,9 @@ import json
 import os
 import resource
 import sys
+import threading
 import time
+from collections.abc import Callable
 from pathlib import Path
 from typing import Any, cast

@@ -57,8 +59,6 @@ from exo.shared.types.worker.shards import (
 from exo.worker.download.download_utils import build_model_path
 from exo.worker.engines.mlx import Model
 from exo.worker.engines.mlx.auto_parallel import (
-    TimeoutCallback,
-    eval_with_timeout,
     pipeline_auto_parallel,
     tensor_auto_parallel,
 )
@@ -75,7 +75,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
     return Memory.from_float_kb(
         (model_shard_meta.end_layer - model_shard_meta.start_layer)
         / model_shard_meta.n_layers
-        * model_shard_meta.model_card.storage_size.in_kb
+        * model_shard_meta.model_meta.storage_size.in_kb
         / (
             1
             if isinstance(model_shard_meta, PipelineShardMetadata)
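For intuition, get_weights_size scales the full checkpoint size by the fraction of layers the shard owns (the divisor cut off at the end of this hunk additionally splits tensor shards across the group). A worked example with hypothetical numbers:

# Hypothetical: a 100 GB checkpoint, 32 layers, pipeline shard owning half of them.
storage_kb = 100 * 1024**2                 # 100 GB expressed in KB
start_layer, end_layer, n_layers = 0, 16, 32

shard_kb = (end_layer - start_layer) / n_layers * storage_kb
print(shard_kb / 1024**2)                  # -> 50.0 GB for this shard
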
@@ -88,6 +88,41 @@ class ModelLoadingTimeoutError(Exception):
     pass


+TimeoutCallback = Callable[[], None]
+
+
+def eval_with_timeout(
+    mlx_item: Any,  # pyright: ignore[reportAny]
+    timeout_seconds: float = 60.0,
+    on_timeout: TimeoutCallback | None = None,
+) -> None:
+    """Evaluate MLX item with a hard timeout.
+
+    If on_timeout callback is provided, it will be called before terminating
+    the process. This allows the runner to send a failure event before exit.
+    """
+    completed = threading.Event()
+
+    def watchdog() -> None:
+        if not completed.wait(timeout=timeout_seconds):
+            logger.error(
+                f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
+                "This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
+                "Terminating process."
+            )
+            if on_timeout is not None:
+                on_timeout()
+            os._exit(1)
+
+    watchdog_thread = threading.Thread(target=watchdog, daemon=True)
+    watchdog_thread.start()
+
+    try:
+        mx.eval(mlx_item)  # pyright: ignore[reportAny]
+    finally:
+        completed.set()
+
+
 def mx_barrier(group: Group | None = None):
     mx.eval(
         mx.distributed.all_sum(
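The eval_with_timeout helper added above is a general watchdog pattern for blocking calls that cannot be interrupted cooperatively: a daemon thread waits on an Event and hard-exits if the main thread never sets it. A standalone sketch of the same pattern without MLX (names hypothetical):

import os
import threading
import time

def run_with_watchdog(fn, timeout_seconds: float) -> None:
    completed = threading.Event()

    def watchdog() -> None:
        # wait() returns False only if the timeout elapsed before set().
        if not completed.wait(timeout=timeout_seconds):
            print("blocking call timed out; terminating hard")
            os._exit(1)  # a daemon thread cannot raise into the main thread

    threading.Thread(target=watchdog, daemon=True).start()
    try:
        fn()  # the potentially hanging call
    finally:
        completed.set()  # always unblock the watchdog

run_with_watchdog(lambda: time.sleep(0.1), timeout_seconds=5.0)
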
@@ -169,14 +204,19 @@ def mlx_distributed_init(

             # TODO: update once upstream fixes
             logger.info(
-                f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
+                f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
             )
             logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
-            os.environ["MLX_JACCL_DEVICES"] = coordination_file
+            os.environ["MLX_IBV_DEVICES"] = coordination_file
             os.environ["MLX_RANK"] = str(rank)
             os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
             group = mx.distributed.init(backend="jaccl", strict=True)

         case _:
             raise ValueError(
                 f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
             )

     logger.info(f"Rank {rank} mlx distributed initialization complete")

     return group
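This hunk renames the device-list variable from MLX_JACCL_DEVICES to MLX_IBV_DEVICES while keeping the rest of the environment-driven handshake intact. A hedged sketch of the sequence as it stands after the change (paths and addresses are placeholders; in exo they come from the bound instance):

import os
import mlx.core as mx

os.environ["MLX_IBV_DEVICES"] = "/tmp/jaccl_devices.json"   # device coordination file
os.environ["MLX_RANK"] = "0"
os.environ["MLX_JACCL_COORDINATOR"] = "10.0.0.1:5000"

# strict=True fails loudly instead of silently falling back to a 1-process group.
group = mx.distributed.init(backend="jaccl", strict=True)
print(group.rank(), group.size())
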
@@ -206,7 +246,7 @@ def load_mlx_items(
 ) -> tuple[Model, TokenizerWrapper]:
     if group is None:
         logger.info(f"Single device used for {bound_instance.instance}")
-        model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
+        model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
         start_time = time.perf_counter()
         model, _ = load_model(model_path, strict=True)
         end_time = time.perf_counter()
@@ -234,7 +274,7 @@ def shard_and_load(
     group: Group,
     on_timeout: TimeoutCallback | None = None,
 ) -> tuple[nn.Module, TokenizerWrapper]:
-    model_path = build_model_path(shard_metadata.model_card.model_id)
+    model_path = build_model_path(shard_metadata.model_meta.model_id)

     model, _ = load_model(model_path, lazy=True, strict=False)
     logger.debug(model)
@@ -261,6 +301,14 @@ def shard_and_load(

     logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")

+    match shard_metadata:
+        case TensorShardMetadata():
+            logger.info(f"loading model from {model_path} with tensor parallelism")
+            model = tensor_auto_parallel(model, group)
+        case PipelineShardMetadata():
+            logger.info(f"loading model from {model_path} with pipeline parallelism")
+            model = pipeline_auto_parallel(model, group, shard_metadata)
+
     # Estimate timeout based on model size
     base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
     model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
@@ -269,15 +317,7 @@ def shard_and_load(
         f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
         f"(model size: {model_size_gb:.1f}GB)"
     )
-
-    match shard_metadata:
-        case TensorShardMetadata():
-            logger.info(f"loading model from {model_path} with tensor parallelism")
-            model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
-        case PipelineShardMetadata():
-            logger.info(f"loading model from {model_path} with pipeline parallelism")
-            model = pipeline_auto_parallel(model, group, shard_metadata)
-    eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
+    eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)

     # TODO: Do we need this?
     mx.eval(model)
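The net effect of these hunks: model_card becomes model_meta, and the timed evaluation moves out of the sharding strategies into shard_and_load, so the model is parallelized first and the whole parameter tree is materialized once under a single size-scaled timeout. A self-contained sketch of the timeout estimate (only the 60s default is visible in the hunk; the scaling is an assumption):

import os

base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = 50.0          # would come from get_weights_size(shard_metadata)
seconds_per_gb = 2.0          # hypothetical scaling factor, not from the diff
timeout_seconds = base_timeout + seconds_per_gb * model_size_gb

print(f"evaluating parameters with a {timeout_seconds:.0f}s watchdog")
# then: model = tensor_auto_parallel(model, group)
#       eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
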
@@ -293,7 +333,7 @@ def shard_and_load(

 def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
     """Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
-    return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
+    return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)


 def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
@@ -312,9 +352,6 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
     model_id_lower = model_id.lower()
     if "kimi-k2" in model_id_lower:
         return [163586]
-    elif "glm-4.7-flash" in model_id_lower:
-        # 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
-        return [154820, 154827, 154829]
     elif "glm" in model_id_lower:
         return [151336, 151329, 151338]
     return None

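Since get_eos_token_ids_for_model matches on lowercase substrings, dropping the glm-4.7-flash branch means those ids now fall through to the generic glm case. A quick check of the surviving behavior (expected values read off the code above):

assert get_eos_token_ids_for_model("kimi-k2-instruct") == [163586]
assert get_eos_token_ids_for_model("GLM-4.5-Air") == [151336, 151329, 151338]
assert get_eos_token_ids_for_model("llama-3.1-8b") is None  # defer to tokenizer defaults
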
@@ -8,7 +8,6 @@ from loguru import logger

 from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
 from exo.shared.apply import apply
-from exo.shared.models.model_cards import ModelId
 from exo.shared.types.commands import ForwarderCommand, RequestEventLog
 from exo.shared.types.common import NodeId, SessionId
 from exo.shared.types.events import (
@@ -23,6 +22,7 @@ from exo.shared.types.events import (
     TopologyEdgeCreated,
     TopologyEdgeDeleted,
 )
+from exo.shared.types.models import ModelId
 from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.state import State
 from exo.shared.types.tasks import (
@@ -186,11 +186,11 @@ class Worker:
                     )
                 )
             case DownloadModel(shard_metadata=shard):
-                if shard.model_card.model_id not in self.download_status:
+                if shard.model_meta.model_id not in self.download_status:
                     progress = DownloadPending(
                         shard_metadata=shard, node_id=self.node_id
                     )
-                    self.download_status[shard.model_card.model_id] = progress
+                    self.download_status[shard.model_meta.model_id] = progress
                     await self.event_sender.send(
                         NodeDownloadProgress(download_progress=progress)
                     )
@@ -205,7 +205,7 @@ class Worker:
                         node_id=self.node_id,
                         total_bytes=initial_progress.total_bytes,
                     )
-                    self.download_status[shard.model_card.model_id] = progress
+                    self.download_status[shard.model_meta.model_id] = progress
                     await self.event_sender.send(
                         NodeDownloadProgress(download_progress=progress)
                     )
@@ -339,7 +339,7 @@ class Worker:
                 initial_progress
             ),
         )
-        self.download_status[task.shard_metadata.model_card.model_id] = status
+        self.download_status[task.shard_metadata.model_meta.model_id] = status
         self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))

         last_progress_time = 0.0
@@ -356,7 +356,7 @@ class Worker:
                     node_id=self.node_id,
                     total_bytes=progress.total_bytes,
                 )
-                self.download_status[shard.model_card.model_id] = status
+                self.download_status[shard.model_meta.model_id] = status
                 await self.event_sender.send(
                     NodeDownloadProgress(download_progress=status)
                 )
@@ -376,7 +376,7 @@ class Worker:
                         progress
                    ),
                )
-                self.download_status[shard.model_card.model_id] = status
+                self.download_status[shard.model_meta.model_id] = status
                 await self.event_sender.send(
                     NodeDownloadProgress(download_progress=status)
                 )
@@ -413,6 +413,11 @@ class Worker:
         )
         for nid in conns:
             for ip in conns[nid]:
+                if "127.0.0.1" in ip or "localhost" in ip:
+                    logger.warning(
+                        f"Loopback connection should not happen: {ip=} for {nid=}"
+                    )
+
                 edge = SocketConnection(
                     # nonsense multiaddr
                     sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
@@ -433,9 +438,6 @@ class Worker:
         for conn in self.state.topology.out_edges(self.node_id):
             if not isinstance(conn.edge, SocketConnection):
                 continue
-            # ignore mDNS discovered connections
-            if conn.edge.sink_multiaddr.port != 52415:
-                continue
             if (
                 conn.sink not in conns
                 or conn.edge.sink_multiaddr.ip_address
@@ -476,7 +478,7 @@ class Worker:
             else:
                 continue

-            self.download_status[progress.shard.model_card.model_id] = status
+            self.download_status[progress.shard.model_meta.model_id] = status
             await self.event_sender.send(
                 NodeDownloadProgress(download_progress=status)
             )

@@ -2,8 +2,8 @@

 from collections.abc import Mapping, Sequence

-from exo.shared.models.model_cards import ModelId
 from exo.shared.types.common import NodeId
+from exo.shared.types.models import ModelId
 from exo.shared.types.tasks import (
     ChatCompletion,
     ConnectToGroup,
@@ -21,7 +21,11 @@ from exo.shared.types.worker.downloads import (
     DownloadOngoing,
     DownloadProgress,
 )
-from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
+from exo.shared.types.worker.instances import (
+    BoundInstance,
+    Instance,
+    InstanceId,
+)
 from exo.shared.types.worker.runners import (
     RunnerConnected,
     RunnerConnecting,
@@ -50,6 +54,16 @@ def plan(
     all_runners: Mapping[RunnerId, RunnerStatus],  # all global
     tasks: Mapping[TaskId, Task],
 ) -> Task | None:
+    from exo.plugins.registry import PluginRegistry
+
+    registry = PluginRegistry.get()
+
+    # Check plugin tasks first
+    for plugin in registry.all_plugins():
+        task = plugin.plan_task(runners, instances)
+        if task is not None:
+            return task
+
     # Python short circuiting OR logic should evaluate these sequentially.
     return (
         _kill_runner(runners, all_runners, instances)
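The planning loop above gives plugins first refusal on task creation. For orientation, a minimal plugin shape that satisfies the two hooks this file calls, plan_task here and should_skip_download in _model_needs_download below (the hook names come from the diff; the base class and registration mechanism are not shown, so this is only a sketch):

class NoopPlugin:
    """Hypothetical plugin: never plans tasks, never skips downloads."""

    def plan_task(self, runners, instances):
        # Return a Task to preempt exo's default planning, or None to defer.
        return None

    def should_skip_download(self, instance) -> bool:
        # True would suppress DownloadModel tasks for this instance.
        return False
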
@@ -113,8 +127,19 @@ def _model_needs_download(
     runners: Mapping[RunnerId, RunnerSupervisor],
     download_status: Mapping[ModelId, DownloadProgress],
 ) -> DownloadModel | None:
+    from exo.plugins.registry import PluginRegistry
+
+    registry = PluginRegistry.get()
+
     for runner in runners.values():
-        model_id = runner.bound_instance.bound_shard.model_card.model_id
+        instance = runner.bound_instance.instance
+
+        # Check if any plugin wants to skip download for this instance
+        plugin = registry.get_plugin_for_instance(instance)
+        if plugin is not None and plugin.should_skip_download(instance):
+            continue
+
+        model_id = runner.bound_instance.bound_shard.model_meta.model_id
         if isinstance(runner.status, RunnerIdle) and (
             model_id not in download_status
             or not isinstance(
@@ -191,7 +216,7 @@ def _load_model(
         nid in global_download_status
         and any(
             isinstance(dp, DownloadCompleted)
-            and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
+            and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
             for dp in global_download_status[nid]
         )
         for nid in shard_assignments.node_to_runner

@@ -4,7 +4,10 @@ import loguru

 from exo.shared.types.events import Event, RunnerStatusUpdated
 from exo.shared.types.tasks import Task
-from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
+from exo.shared.types.worker.instances import (
+    BoundInstance,
+    MlxJacclInstance,
+)
 from exo.shared.types.worker.runners import RunnerFailed
 from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender

@@ -17,6 +20,7 @@ def entrypoint(
     task_receiver: MpReceiver[Task],
     _logger: "loguru.Logger",
 ) -> None:
+    # Set FAST_SYNCH based on env var or JACCL device count
     fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
     if fast_synch_override == "on" or (
         fast_synch_override != "off"
@@ -34,11 +38,26 @@ def entrypoint(

     logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")

-    # Import main after setting global logger - this lets us just import logger from this module
+    # Route based on instance type (plugins or default MLX)
     try:
-        from exo.worker.runner.runner import main
-
-        main(bound_instance, event_sender, task_receiver)
+        from exo.plugins.registry import PluginRegistry, discover_plugins
+
+        # Discover plugins in subprocess (they aren't inherited from main process)
+        discover_plugins()
+
+        registry = PluginRegistry.get()
+        instance = bound_instance.instance
+
+        # Check if a plugin handles this instance type
+        plugin = registry.get_plugin_for_instance(instance)
+        if plugin is not None:
+            # Delegate to plugin runner
+            plugin.create_runner(bound_instance, event_sender, task_receiver)
+        else:
+            # MLX runner (default)
+            from exo.worker.runner.runner import main
+
+            main(bound_instance, event_sender, task_receiver)
     except ClosedResourceError:
         logger.warning("Runner communication closed unexpectedly")
     except Exception as e:
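The EXO_FAST_SYNCH override visible at the top of this hunk resolves three ways: "on" forces the flag, "off" suppresses it, and anything else defers to the JACCL device-count heuristic. The predicate, isolated (the heuristic itself is elided in the hunk, so it is abbreviated to a boolean here):

def resolve_fast_synch(override: str | None, heuristic: bool) -> bool:
    # "on" always wins; "off" always loses; otherwise use the heuristic.
    return override == "on" or (override != "off" and heuristic)

assert resolve_fast_synch("on", False)
assert not resolve_fast_synch("off", True)
assert resolve_fast_synch(None, True)
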
@@ -213,7 +213,7 @@ def main(
                     command_id=command_id,
                     chunk=TokenChunk(
                         idx=response.token,
-                        model=shard_metadata.model_card.model_id,
+                        model=shard_metadata.model_meta.model_id,
                         text=response.text,
                         token_id=response.token,
                         finish_reason=response.finish_reason,
@@ -230,7 +230,7 @@ def main(
                     command_id=command_id,
                     chunk=TokenChunk(
                         idx=0,
-                        model=shard_metadata.model_card.model_id,
+                        model=shard_metadata.model_meta.model_id,
                         text="",
                         token_id=0,
                         finish_reason="error",

@@ -1,7 +1,7 @@
 from typing import Final

-from exo.shared.models.model_cards import ModelId
 from exo.shared.types.common import CommandId, NodeId
+from exo.shared.types.models import ModelId
 from exo.shared.types.tasks import TaskId
 from exo.shared.types.worker.instances import InstanceId, RunnerId


@@ -1,8 +1,8 @@
 from dataclasses import dataclass, field

-from exo.shared.models.model_cards import ModelCard, ModelId
 from exo.shared.types.common import NodeId
 from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.tasks import BaseTask, TaskId
 from exo.shared.types.worker.instances import (
     BoundInstance,
@@ -32,8 +32,9 @@ def get_pipeline_shard_metadata(
     model_id: ModelId, device_rank: int, world_size: int = 1
 ) -> ShardMetadata:
     return PipelineShardMetadata(
-        model_card=ModelCard(
+        model_meta=ModelMetadata(
             model_id=model_id,
             pretty_name=str(model_id),
             storage_size=Memory.from_mb(100000),
             n_layers=32,
             hidden_size=2048,

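A hedged usage sketch of the updated fixture, assuming the renamed field is reachable as shard.model_meta per the rest of this diff:

shard = get_pipeline_shard_metadata(ModelId("dummy/model"), device_rank=0, world_size=2)
assert shard.model_meta.n_layers == 32
assert shard.model_meta.storage_size == Memory.from_mb(100000)  # dummy ~100 GB size
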
@@ -11,9 +11,9 @@ import mlx.core as mx
 import mlx.nn as nn

 from exo.shared.constants import EXO_MODELS_DIR
-from exo.shared.models.model_cards import ModelCard, ModelId
 from exo.shared.types.api import ChatCompletionMessage
 from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.tasks import ChatCompletionTaskParams
 from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
 from exo.worker.engines.mlx import Model
@@ -81,8 +81,9 @@ def run_gpt_oss_pipeline_device(
     start_layer, end_layer = layer_splits[rank]

     shard_meta = PipelineShardMetadata(
-        model_card=ModelCard(
+        model_meta=ModelMetadata(
             model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
             pretty_name="GPT-OSS 20B",
             storage_size=Memory.from_gb(12),
             n_layers=24,
             hidden_size=2880,
@@ -150,8 +151,9 @@ def run_gpt_oss_tensor_parallel_device(

     # For tensor parallelism, all devices run all layers
     shard_meta = TensorShardMetadata(
-        model_card=ModelCard(
+        model_meta=ModelMetadata(
             model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
             pretty_name="GPT-OSS 20B",
             storage_size=Memory.from_gb(12),
             n_layers=24,
             hidden_size=2880,

@@ -18,7 +18,6 @@ def _check_model_exists() -> bool:


 pytestmark = [
     pytest.mark.slow,
     pytest.mark.skipif(
         not _check_model_exists(),
         reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",

@@ -76,21 +76,19 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
     """Get a representative sample of models to test."""
     # Pick one model from each family to test
     families: dict[str, tuple[str, ModelCard]] = {}
-    for _, card in MODEL_CARDS.items():
+    for short_id, card in MODEL_CARDS.items():
         # Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
-        parts = card.model_id.short().split("-")
+        parts = short_id.split("-")
         family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]

         if family not in families:
-            families[family] = (card.model_id.short(), card)
+            families[family] = (short_id, card)

     return list(families.values())


 TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()

-pytestmark = pytest.mark.slow
-

 @pytest.fixture(scope="module")
 def event_loop():
@@ -1,7 +1,7 @@
 import exo.worker.plan as plan_mod
-from exo.shared.models.model_cards import ModelId
 from exo.shared.types.common import NodeId
 from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelId
 from exo.shared.types.tasks import LoadModel
 from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
 from exo.shared.types.worker.instances import BoundInstance

@@ -82,7 +82,7 @@ async def tb_detection():
     send, recv = channel[GatheredInfo]()
     ig = InfoGatherer(send)
     with anyio.move_on_after(1):
-        await ig._monitor_system_profiler_thunderbolt_data()  # pyright: ignore[reportPrivateUsage]
+        await ig._monitor_system_profiler()  # pyright: ignore[reportPrivateUsage]
     with recv:
         return recv.collect()

@@ -135,7 +135,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
     else:
         raise ValueError(f"{hn} not in {test.devs}")

-    card = MODEL_CARDS[test.model_id]
+    meta = MODEL_CARDS[test.model_id].metadata
     instance = MlxRingInstance(
         instance_id=iid,
         ephemeral_port=52416,
@@ -145,15 +145,15 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
         node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
         runner_to_shard={
             RunnerId(test.devs[i][0]): PipelineShardMetadata(
-                model_card=card,
+                model_meta=meta,
                 device_rank=i,
                 world_size=world_size,
-                start_layer=(card.n_layers // world_size) * i,
+                start_layer=(meta.n_layers // world_size) * i,
                 end_layer=min(
-                    card.n_layers, (card.n_layers // world_size) * (i + 1)
+                    meta.n_layers, (meta.n_layers // world_size) * (i + 1)
                 ),
-                n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
-                - (card.n_layers // world_size) * i,
+                n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
+                - (meta.n_layers // world_size) * i,
             )
             for i in range(world_size)
         },
@@ -224,7 +224,7 @@ async def jaccl_backend(test: Tests):


 def jaccl_instance(test: Tests, iid: InstanceId):
-    card = MODEL_CARDS[test.model_id]
+    meta = MODEL_CARDS[test.model_id].metadata
     world_size = len(test.devs)

     return MlxJacclInstance(
@@ -239,12 +239,12 @@ def jaccl_instance(test: Tests, iid: InstanceId):
         node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
         runner_to_shard={
             RunnerId(test.devs[i][0]): TensorShardMetadata(
-                model_card=card,
+                model_meta=meta,
                 device_rank=i,
                 world_size=world_size,
-                start_layer=card.n_layers,
-                end_layer=card.n_layers,
-                n_layers=card.n_layers,
+                start_layer=meta.n_layers,
+                end_layer=meta.n_layers,
+                n_layers=meta.n_layers,
             )
             for i in range(world_size)
         },

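The ring fixture's pipeline split is worth a worked check: each rank gets n_layers // world_size layers and min() clamps the final shard, while the jaccl fixture instead sets start_layer = end_layer = n_layers on every rank, since tensor parallelism gives each device a slice of every layer rather than a contiguous band. The pipeline arithmetic with n_layers=32, world_size=2:

n_layers, world_size = 32, 2
per = n_layers // world_size                       # 16 layers per rank
splits = [(per * i, min(n_layers, per * (i + 1))) for i in range(world_size)]
assert splits == [(0, 16), (16, 32)]               # contiguous, full coverage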