Compare commits


13 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
541339aae6 Don't warn on single-node jaccl placement 2026-01-20 18:28:51 +00:00
rltakashige
758464703d Fix GPT OSS tensor sharding with upstream MLX LM (#1223)
## Motivation
MLX LM has given GPT OSS a `shard` method, but MLX does not yet have a
matching update.

2026-01-20 18:24:54 +00:00
rltakashige
9e2179c848 Register original layer in CustomMlxLayer (#1229)
## Motivation
Kimi K2 Thinking Pipeline RDMA was broken before.

## Why It Works
No clue tbh

## Test Plan

### Manual Testing
Kimi K2 Thinking and GPT OSS work at the same time on Pipeline RDMA.
Needs exo bench to check more thoroughly.

### Automated Testing
Layer composition tests still pass.
2026-01-20 18:20:01 +00:00
Evan Quiney
22b5d836ef swap all instances of model_id: str for model_id: ModelId (#1221)
This change uses the more strongly typed ModelId and introduces some
convenience methods. It also cleans up some code left over from #1204.

## Changes

`model_id: str -> model_id: ModelId`
`repo_id: str -> model_id: ModelId`

Introduces methods on ModelId, in particular ModelId.normalize() to
replace `/` with `--`.
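
As a rough sketch of the idea (the real `ModelId` class in exo may be implemented differently):

```python
# Hypothetical sketch only -- not the actual exo implementation.
class ModelId(str):
    """A HuggingFace-style repo id such as "mlx-community/Llama-3.2-3B-Instruct-4bit"."""

    def normalize(self) -> str:
        # Replace "/" with "--" so the id is safe to use as a directory name.
        return self.replace("/", "--")


model_id = ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit")
assert model_id.normalize() == "mlx-community--Llama-3.2-3B-Instruct-4bit"
```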

This PR did introduce some circular imports, so some code has been moved
around to limit them.

## Test Plan

Tests still pass, types still check. As this is about metadata, I
haven't tested inference.
2026-01-20 17:38:06 +00:00
Alex Cheema
ea9c6d6bdf Remove dead local paths code from download_shard (#1227)
## Motivation

The `download_progress_for_local_path` function and the "Handle local
paths" code block in `download_shard` are dead code that cannot be
reached in normal usage. The code checks if `model_id` (e.g.,
"mlx-community/Llama-3.2-3B-Instruct-4bit") exists as a filesystem path,
but model IDs are constrained to HuggingFace repo format and there's no
API pathway to pass local paths.

## Changes

- Removed `download_progress_for_local_path()` function (45 lines)
- Removed the "Handle local paths" block in `download_shard()` (7 lines)

## Why It Works

This code was added in PR #669 as part of a "feature-local-models"
branch, but the feature was never fully integrated. The check
`aios.path.exists(str(shard.model_card.model_id))` would only return
true if a directory literally named
"mlx-community/Llama-3.2-3B-Instruct-4bit" existed in the cwd, which
doesn't happen in practice. Offline caching is already handled by
`fetch_file_list_with_cache`.

## Test Plan

### Manual Testing
- Run exo normally and verify downloads still work

### Automated Testing
- Existing tests pass (this code had no test coverage)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 17:07:27 +00:00
Alex Cheema
4ea66d427b Reduce download log spam (#1225)
## Motivation

When `skip_download=True`, exo was logging a lot of unnecessary messages during periodic download status checks. This resulted in spammy logs that made it hard to see important messages.

## Changes

- Only log "Downloading ... with allow_patterns=..." when actually downloading (not when skip_download is true)
- Changed periodic download progress check logs from INFO to DEBUG level

## Why It Works

The `skip_download=True` parameter is used when checking download status without actually downloading. By guarding the log behind `if not skip_download:`, we avoid logging on every status check. Changing the periodic progress logs to DEBUG level reduces noise while still keeping them available for debugging.
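
A minimal sketch of the guard, with illustrative names (the real `download_shard` in exo takes different parameters):

```python
from loguru import logger

async def download_shard(shard, allow_patterns: list[str], skip_download: bool = False):
    # Illustrative signature only.
    if not skip_download:
        # Log only when a download will actually happen, not on status checks.
        logger.info(f"Downloading {shard} with allow_patterns={allow_patterns}")
        ...  # perform the download
    # Periodic progress checks are DEBUG, so they only show with -v/-vv.
    logger.debug(f"Checked download progress for {shard}")
```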

## Test Plan

### Manual Testing
- Run exo and observe that logs are less spammy during normal operation
- Use -v or -vv flags to see DEBUG logs when needed

### Automated Testing
- Existing tests cover this code path
2026-01-20 16:57:05 +00:00
rltakashige
8b709e68b2 Mark slow tests as slow (#1220)
2026-01-20 15:03:46 +00:00
Evan Quiney
4da6eeb11f fix a test broken by #1204 (#1219)
bad merge broke a test - fix it
2026-01-20 14:56:20 +00:00
Evan
3d2eee4884 quiet localhost log
this log is just noise - remove it
2026-01-20 14:51:26 +00:00
Evan
116558839e don't clear mdns discovered connections
pingers currently remove mdns-discovered connections - these systems
should be independent
2026-01-20 14:46:20 +00:00
Evan Quiney
d4f551c602 Simplify model cards (#1204)
## Motivation

We have a lot of unneeded data in the model card - let's just keep the
necessary stuff and add back more data when we need it.

## Test Plan

EXO still runs! (pipeline on 2)

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-20 11:01:19 +00:00
Alex Cheema
176ab5ba40 Add GLM-4.7-Flash model cards (4bit, 5bit, 6bit, 8bit) (#1214)
## Motivation

Add support for GLM-4.7-Flash, a lighter variant of GLM-4.7 with the
`glm4_moe_lite` architecture. These models are smaller and faster while
maintaining good performance.

## Changes

1. **Added 4 new model cards** for GLM-4.7-Flash variants:
   - `glm-4.7-flash-4bit` (~18 GB)
   - `glm-4.7-flash-5bit` (~21 GB)
   - `glm-4.7-flash-6bit` (~25 GB)
   - `glm-4.7-flash-8bit` (~32 GB)

   All variants have:
   - `n_layers`: 47 (vs 91 in GLM-4.7)
   - `hidden_size`: 2048 (vs 5120 in GLM-4.7)
   - `supports_tensor`: True (native `shard()` method)

2. **Bumped mlx from 0.30.1 to 0.30.3** - required by mlx-lm 0.30.4

3. **Updated mlx-lm from 0.30.2 to 0.30.4** - adds `glm4_moe_lite`
architecture support

4. **Added type ignores** in `auto_parallel.py` for stricter type
annotations in new mlx-lm

5. **Fixed EOS token IDs** for GLM-4.7-Flash - uses different tokenizer
with IDs `[154820, 154827, 154829]` vs other GLM models' `[151336,
151329, 151338]`

6. **Renamed `MLX_IBV_DEVICES` to `MLX_JACCL_DEVICES`** - env var name
changed in new mlx

## Why It Works

The model cards follow the same pattern as existing GLM-4.7 models.
Tensor parallel support is enabled because GLM-4.7-Flash implements the
native `shard()` method in mlx-lm 0.30.4, which is automatically
detected in `auto_parallel.py`.
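
Roughly how that detection can work - a hedged sketch, not the actual `auto_parallel.py` code:

```python
import mlx.nn as nn

def supports_native_tensor_parallel(model: nn.Module) -> bool:
    # mlx-lm models that implement their own shard() method (as glm4_moe_lite
    # does in mlx-lm 0.30.4) can be tensor-sharded directly; otherwise exo
    # falls back to its generic sharding path.
    return callable(getattr(model, "shard", None))
```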

GLM-4.7-Flash uses a new tokenizer with different special token IDs.
Without the correct EOS tokens, generation wouldn't stop properly.
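
In sketch form, why the ids matter (illustrative loop, not exo's sampler):

```python
GLM_4_7_FLASH_EOS_IDS = {154820, 154827, 154829}  # ids from the model card above

def sample_until_eos(next_token, max_tokens: int, eos_ids: set[int]) -> list[int]:
    # next_token() returns one token id per call; purely illustrative.
    out: list[int] = []
    for _ in range(max_tokens):
        tok = next_token()
        if tok in eos_ids:  # with the wrong ids this never fires and
            break           # generation runs all the way to max_tokens
        out.append(tok)
    return out
```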

## Test Plan

### Manual Testing
Tested generation with GLM-4.7-Flash-4bit - now correctly stops at EOS
tokens.

### Automated Testing
- `basedpyright`: 0 errors
- `ruff check`: All checks passed
- `pytest`: 162/162 tests pass (excluding pre-existing
`test_distributed_fix.py` timeout failures)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 03:58:09 +00:00
rltakashige
f5e6aa82d2 Load layers individually (#1211)
## Motivation

Certain models hang at model loading in tensor parallel. 

Hopefully closes #1205 

## Changes

- Load layer by layer for tensor parallel sharding
- Move eval_with_timeout to auto_parallel.py to resolve a circular import.

## Why It Works

The naive way to fix this is to load the model with `lazy=False` and then
shard for tensor parallel. However, this requires the entire model to be
loaded into memory.

Instead, we can load layer by layer and shard after loading. This adds a
small extra memory footprint, but it is negligible.

I tried loading layer by layer after the sharding; this allowed the model
to load but got stuck at warm-up.
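
A hedged sketch of that approach, assuming a lazily constructed MLX model that exposes a `layers` list and a per-layer shard function (the real loading code in exo differs):

```python
import mlx.core as mx

def load_and_shard_layer_by_layer(model, shard_layer):
    # Assumes `model` was built lazily (weights not yet materialised) and that
    # `shard_layer` slices one layer's weights for this rank. Illustrative only.
    for layer in model.layers:
        mx.eval(layer.parameters())  # materialise just this layer's weights
        shard_layer(layer)           # shard it before touching the next layer,
                                     # keeping peak memory near one full layer
    # Materialise whatever sits outside the repeated layers (embeddings, lm head, ...).
    mx.eval(model.parameters())
```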

## Test Plan

### Manual Testing
GPT OSS loads with TP and FAST SYNCH. Kimi does too.

### Automated Testing
We need to run a suite of exo_bench before merging this!
2026-01-20 03:26:51 +00:00
60 changed files with 1632 additions and 4006 deletions

View File

@@ -863,7 +863,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -903,7 +902,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,7 +1518,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1530,7 +1527,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1943,7 +1939,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2651,7 +2646,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2839,7 +2833,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2984,7 +2977,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3006,7 +2998,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -434,8 +434,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;

View File

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;

View File

@@ -17,20 +17,20 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.1; sys_platform == 'darwin'",
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
]
[project.scripts]
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-rsh = "exo.rsh.client:main"
# dependencies only required for development
[dependency-groups]

View File

@@ -1,32 +0,0 @@
"""Exo CLI - SLURM-compatible job management commands."""
def run_subcommand(command: str, args: list[str]) -> int:
"""Route to the appropriate subcommand handler.
Args:
command: The subcommand name (sbatch, squeue, scancel, salloc)
args: Command line arguments for the subcommand
Returns:
Exit code from the subcommand
"""
if command == "sbatch":
from exo.cli.sbatch import main
return main(args)
elif command == "squeue":
from exo.cli.squeue import main
return main(args)
elif command == "scancel":
from exo.cli.scancel import main
return main(args)
elif command == "salloc":
from exo.cli.salloc import main
return main(args)
else:
print(f"Unknown subcommand: {command}")
return 1

View File

@@ -1,118 +0,0 @@
"""Common utilities for Exo CLI commands."""
import json
import os
import urllib.request
from typing import Any
from urllib.error import HTTPError, URLError
# Default API endpoint
DEFAULT_API_HOST = "localhost"
DEFAULT_API_PORT = 52415
def get_api_base() -> str:
"""Get the API base URL from environment or defaults."""
host = os.environ.get("EXO_API_HOST", DEFAULT_API_HOST)
port = os.environ.get("EXO_API_PORT", str(DEFAULT_API_PORT))
return f"http://{host}:{port}"
def api_request(
method: str,
path: str,
data: dict[str, Any] | None = None,
) -> dict[str, Any] | list[Any]:
"""Make an API request to the Exo server.
Args:
method: HTTP method (GET, POST, DELETE, etc.)
path: API path (e.g., "/flash/instances")
data: Optional JSON data for POST/PUT requests
Returns:
Parsed JSON response
Raises:
SystemExit: On connection or HTTP errors
"""
url = f"{get_api_base()}{path}"
request_data = None
if data is not None:
request_data = json.dumps(data).encode("utf-8")
req = urllib.request.Request(
url,
data=request_data,
method=method,
)
req.add_header("Content-Type", "application/json")
try:
with urllib.request.urlopen(req, timeout=30) as response: # pyright: ignore[reportAny]
body: str = response.read().decode("utf-8") # pyright: ignore[reportAny]
if body:
return json.loads(body) # pyright: ignore[reportAny]
return {}
except HTTPError as e:
error_body = e.read().decode("utf-8") if e.fp else ""
print(f"API error: {e.code} {e.reason}")
if error_body:
try:
error_json: dict[str, str] = json.loads(error_body) # pyright: ignore[reportAny]
if "detail" in error_json:
print(f" {error_json['detail']}")
except json.JSONDecodeError:
print(f" {error_body}")
raise SystemExit(1) from None
except URLError as e:
print(f"Connection error: {e.reason}")
print(f"Is Exo running at {get_api_base()}?")
raise SystemExit(1) from None
def truncate_id(instance_id: str, length: int = 8) -> str:
"""Truncate a UUID for display.
Args:
instance_id: Full UUID string
length: Number of characters to keep
Returns:
Truncated ID without hyphens
"""
return instance_id.replace("-", "")[:length]
def format_table(headers: list[str], rows: list[list[str]]) -> str:
"""Format data as a simple text table.
Args:
headers: Column headers
rows: List of rows, each row is a list of column values
Returns:
Formatted table string
"""
if not rows:
return " ".join(f"{h:<10}" for h in headers)
# Calculate column widths
widths = [len(h) for h in headers]
for row in rows:
for i, cell in enumerate(row):
if i < len(widths):
widths[i] = max(widths[i], len(cell))
# Build format string
fmt = " ".join(f"{{:<{w}}}" for w in widths)
# Format output
lines = [fmt.format(*headers)]
for row in rows:
# Pad row if needed
padded = row + [""] * (len(headers) - len(row))
lines.append(fmt.format(*padded[: len(headers)]))
return "\n".join(lines)

View File

@@ -1,100 +0,0 @@
"""salloc - Allocate nodes for interactive use.
Usage:
exo salloc [options] [-- command [args...]]
Options:
-N, --nodes N Number of nodes to allocate (default: 1)
--hosts HOSTS Comma-separated list of hostnames
If a command is provided after --, it will be executed with
SLURM-like environment variables set:
SLURM_JOB_NODELIST - Comma-separated list of allocated nodes
SLURM_NNODES - Number of allocated nodes
Examples:
exo salloc --nodes=2 --hosts=node1,node2 -- mpirun ./my_program
exo salloc --hosts=localhost -- bash
"""
import argparse
import os
import subprocess
import sys
def main(args: list[str]) -> int:
"""Main entry point for salloc command."""
# Split args at -- if present
cmd_args: list[str] = []
salloc_args = args
if "--" in args:
idx = args.index("--")
salloc_args = args[:idx]
cmd_args = args[idx + 1 :]
parser = argparse.ArgumentParser(
prog="exo salloc",
description="Allocate nodes for interactive use",
)
parser.add_argument(
"-N",
"--nodes",
type=int,
default=1,
help="Number of nodes to allocate (default: 1)",
)
parser.add_argument(
"--hosts",
help="Comma-separated list of hostnames (required)",
)
parsed = parser.parse_args(salloc_args)
nodes: int = parsed.nodes # pyright: ignore[reportAny]
hosts: str | None = parsed.hosts # pyright: ignore[reportAny]
# Require explicit hosts since we can't discover them from topology
if not hosts:
print("Error: --hosts is required (e.g., --hosts=node1,node2)", file=sys.stderr)
print(" The Exo topology doesn't expose hostnames.", file=sys.stderr)
return 1
host_list = [h.strip() for h in hosts.split(",") if h.strip()]
if len(host_list) < nodes:
print(
f"Error: Requested {nodes} nodes but only {len(host_list)} hosts provided",
file=sys.stderr,
)
return 1
# Use first N hosts
allocated_hosts = host_list[:nodes]
nodelist = ",".join(allocated_hosts)
# Set environment variables
env = os.environ.copy()
env["SLURM_JOB_NODELIST"] = nodelist
env["SLURM_NNODES"] = str(nodes)
print(f"salloc: Granted job allocation on {nodes} node(s)")
print(f"salloc: Nodes: {nodelist}")
if cmd_args:
# Run the command
print(f"salloc: Running: {' '.join(cmd_args)}")
result = subprocess.run(cmd_args, env=env)
return result.returncode
else:
# Start interactive shell
shell = os.environ.get("SHELL", "/bin/bash")
print(f"salloc: Starting shell {shell}")
print("salloc: Use 'exit' to release allocation")
result = subprocess.run([shell], env=env)
return result.returncode
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

View File

@@ -1,233 +0,0 @@
"""sbatch - Submit a batch job to Exo.
Usage:
exo sbatch [options] <script|executable>
exo sbatch --job-name=NAME --nodes=N <executable>
Options:
-J, --job-name NAME Job name
-N, --nodes N Number of nodes (default: 1)
--ntasks-per-node N Tasks per node (default: 1)
-D, --chdir DIR Working directory
--hosts HOSTS Comma-separated list of hostnames
Job scripts can contain #SBATCH directives:
#!/bin/bash
#SBATCH --job-name=Sod2D
#SBATCH --nodes=2
#SBATCH --chdir=/path/to/workdir
/path/to/flash4
"""
import argparse
import os
import re
import sys
from exo.cli.common import api_request, truncate_id
def parse_job_script(script_path: str) -> tuple[dict[str, str], str | None]:
"""Parse a job script for #SBATCH directives and executable.
Args:
script_path: Path to the job script
Returns:
Tuple of (directives dict, executable path or None)
"""
directives: dict[str, str] = {}
executable: str | None = None
with open(script_path, "r") as f:
for line in f:
line = line.strip()
# Parse #SBATCH directives
if line.startswith("#SBATCH"):
# Handle both --option=value and --option value formats
match = re.match(r"#SBATCH\s+(-\w|--[\w-]+)(?:=|\s+)(.+)", line)
if match:
opt, val = match.groups()
directives[opt.lstrip("-")] = val.strip()
continue
# Skip comments and empty lines
if line.startswith("#") or not line:
continue
# First non-comment, non-directive line is the executable
if executable is None:
# Handle lines like "/path/to/flash4" or "srun /path/to/flash4"
parts = line.split()
if parts:
# Skip srun/mpirun prefixes if present
for part in parts:
if not part.startswith("-") and "/" in part:
executable = part
break
if executable is None and parts:
executable = parts[-1] # Last token
return directives, executable
def main(args: list[str]) -> int:
"""Main entry point for sbatch command."""
parser = argparse.ArgumentParser(
prog="exo sbatch",
description="Submit a batch job to Exo",
)
parser.add_argument(
"script",
help="Job script or executable path",
)
parser.add_argument(
"-J",
"--job-name",
dest="job_name",
help="Job name",
)
parser.add_argument(
"-N",
"--nodes",
type=int,
default=1,
help="Number of nodes (default: 1)",
)
parser.add_argument(
"--ntasks-per-node",
type=int,
default=1,
help="Tasks per node (default: 1)",
)
parser.add_argument(
"-D",
"--chdir",
help="Working directory",
)
parser.add_argument(
"--hosts",
help="Comma-separated list of hostnames",
)
parsed = parser.parse_args(args)
# Extract typed values from namespace
script_path: str = parsed.script # pyright: ignore[reportAny]
arg_job_name: str | None = parsed.job_name # pyright: ignore[reportAny]
arg_nodes: int = parsed.nodes # pyright: ignore[reportAny]
arg_ntasks: int = parsed.ntasks_per_node # pyright: ignore[reportAny]
arg_chdir: str | None = parsed.chdir # pyright: ignore[reportAny]
arg_hosts: str | None = parsed.hosts # pyright: ignore[reportAny]
# Determine if input is a script or direct executable
executable: str | None = None
directives: dict[str, str] = {}
if os.path.isfile(script_path):
# Check if it's a binary file (executable) or text script
is_binary = False
try:
with open(script_path, "rb") as f:
chunk = f.read(512)
# Binary files typically contain null bytes
is_binary = b"\x00" in chunk
except OSError:
pass
if is_binary:
# It's a binary executable
executable = script_path
else:
# Try to read as text
try:
with open(script_path, "r") as f:
first_line = f.readline()
f.seek(0)
content = f.read(1024)
if first_line.startswith("#!") or "#SBATCH" in content:
# It's a job script - parse it
directives, executable = parse_job_script(script_path)
else:
# It's an executable (text but no shebang/directives)
executable = script_path
except UnicodeDecodeError:
# Can't read as text - treat as binary executable
executable = script_path
else:
# Not a file - treat as executable path
executable = script_path
if executable is None:
print("Error: No executable found in job script", file=sys.stderr)
return 1
# Build job parameters - CLI args override script directives
job_name = arg_job_name or directives.get("job-name") or directives.get("J")
if not job_name:
# Generate name from executable
job_name = os.path.basename(executable).replace(".", "_")
nodes = arg_nodes
if "nodes" in directives:
nodes = int(directives["nodes"])
if "N" in directives:
nodes = int(directives["N"])
if arg_nodes != 1: # CLI override
nodes = arg_nodes
ntasks = arg_ntasks
if "ntasks-per-node" in directives:
ntasks = int(directives["ntasks-per-node"])
if arg_ntasks != 1: # CLI override
ntasks = arg_ntasks
workdir = arg_chdir or directives.get("chdir") or directives.get("D")
if not workdir:
workdir = os.getcwd()
hosts = arg_hosts or directives.get("hosts") or ""
# Resolve executable to absolute path
if not os.path.isabs(executable):
executable = os.path.abspath(os.path.join(workdir, executable))
# Submit job via API using query parameters
from urllib.parse import urlencode
params = {
"simulation_name": job_name,
"flash_executable_path": executable,
"parameter_file_path": "", # FLASH par file - use default
"working_directory": workdir,
"ranks_per_node": str(ntasks),
"min_nodes": str(nodes),
"hosts": hosts,
}
query_string = urlencode(params)
result = api_request("POST", f"/flash/launch?{query_string}")
# Print job submission confirmation
if isinstance(result, dict):
instance_id_val = result.get("instance_id")
if instance_id_val is not None:
job_id = truncate_id(str(instance_id_val)) # pyright: ignore[reportAny]
print(f"Submitted batch job {job_id}")
else:
# Instance created asynchronously - user should check squeue
print("Job submitted successfully")
print("Use 'exo squeue' to view job ID")
else:
print("Job submitted successfully")
print("Use 'exo squeue' to view job ID")
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

View File

@@ -1,95 +0,0 @@
"""scancel - Cancel jobs in the Exo queue.
Usage:
exo scancel <jobid> [<jobid>...]
Arguments:
jobid Job ID (or prefix) to cancel. Can specify multiple.
Examples:
exo scancel abc123 # Cancel job starting with abc123
exo scancel abc123 def456 # Cancel multiple jobs
"""
import argparse
import sys
from typing import Any, cast
from exo.cli.common import api_request, truncate_id
def main(args: list[str]) -> int:
"""Main entry point for scancel command."""
parser = argparse.ArgumentParser(
prog="exo scancel",
description="Cancel jobs in the Exo queue",
)
parser.add_argument(
"jobids",
nargs="+",
help="Job ID(s) to cancel",
)
parsed = parser.parse_args(args)
jobids: list[str] = parsed.jobids # pyright: ignore[reportAny]
# Fetch current jobs to resolve partial IDs
result = api_request("GET", "/flash/instances")
if isinstance(result, list):
instances = cast(list[dict[str, Any]], result)
else:
instances = cast(list[dict[str, Any]], result.get("instances", []))
# Build lookup of full IDs
id_map: dict[str, str] = {}
for inst in instances:
iid = inst.get("instance_id", "") # pyright: ignore[reportAny]
full_id = str(iid) if iid else "" # pyright: ignore[reportAny]
if full_id:
# Map both full ID and truncated versions
normalized = full_id.replace("-", "").lower()
id_map[normalized] = full_id
# Also map prefixes
for length in range(4, len(normalized) + 1):
prefix = normalized[:length]
if prefix not in id_map:
id_map[prefix] = full_id
cancelled = 0
errors = 0
for jobid in jobids:
search = jobid.lower().replace("-", "")
# Find matching full ID
full_id = id_map.get(search)
if not full_id:
# Try prefix match
matches = [fid for key, fid in id_map.items() if key.startswith(search)]
if len(matches) == 1:
full_id = matches[0]
elif len(matches) > 1:
print(f"Ambiguous job ID: {jobid} matches multiple jobs")
errors += 1
continue
else:
print(f"Job not found: {jobid}")
errors += 1
continue
# Cancel the job
try:
api_request("DELETE", f"/flash/{full_id}")
print(f"Job {truncate_id(full_id)} cancelled")
cancelled += 1
except SystemExit:
print(f"Failed to cancel job {truncate_id(full_id)}")
errors += 1
if errors > 0 and cancelled == 0:
return 1
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

View File

@@ -1,165 +0,0 @@
"""squeue - View the Exo job queue.
Usage:
exo squeue [options]
Options:
-l, --long Show detailed output
-j, --job ID Show only this job
Output columns:
JOBID - Job identifier (truncated UUID)
NAME - Job name
NODES - Number of nodes
STATE - Job state (PENDING, RUNNING, FAILED, etc.)
"""
import argparse
import sys
from typing import Any, cast
from exo.cli.common import api_request, format_table, truncate_id
# Map Exo runner statuses to SLURM-like states
STATUS_MAP: dict[str, str] = {
"RunnerIdle": "PENDING",
"RunnerConnecting": "CONFIGURING",
"RunnerConnected": "CONFIGURING",
"RunnerLoading": "CONFIGURING",
"RunnerLoaded": "CONFIGURING",
"RunnerWarmingUp": "CONFIGURING",
"RunnerReady": "COMPLETING",
"RunnerRunning": "RUNNING",
"RunnerShuttingDown": "COMPLETING",
"RunnerShutdown": "COMPLETED",
"RunnerFailed": "FAILED",
}
def get_job_state(runner_statuses: dict[str, Any]) -> str:
"""Determine overall job state from runner statuses."""
if not runner_statuses:
return "PENDING"
states: set[str] = set()
for status_val in runner_statuses.values(): # pyright: ignore[reportAny]
if isinstance(status_val, dict):
# Extract status type from discriminated union
type_val = status_val.get("type", "RunnerIdle") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
status_type = str(type_val) if type_val else "RunnerIdle" # pyright: ignore[reportUnknownArgumentType]
elif isinstance(status_val, str):
status_type = status_val
else:
status_type = "RunnerIdle"
# Strip parentheses from status strings like "RunnerRunning()"
if status_type.endswith("()"):
status_type = status_type[:-2]
states.add(STATUS_MAP.get(status_type, "UNKNOWN"))
# Priority order for overall state
if "FAILED" in states:
return "FAILED"
if "RUNNING" in states:
return "RUNNING"
if "CONFIGURING" in states:
return "CONFIGURING"
if "COMPLETING" in states:
return "COMPLETING"
if "COMPLETED" in states:
return "COMPLETED"
if "PENDING" in states:
return "PENDING"
return "UNKNOWN"
def main(args: list[str]) -> int:
"""Main entry point for squeue command."""
parser = argparse.ArgumentParser(
prog="exo squeue",
description="View the Exo job queue",
)
parser.add_argument(
"-l",
"--long",
action="store_true",
help="Show detailed output",
)
parser.add_argument(
"-j",
"--job",
help="Show only this job ID",
)
parsed = parser.parse_args(args)
# Extract typed values
long_format: bool = parsed.long # pyright: ignore[reportAny]
job_filter: str | None = parsed.job # pyright: ignore[reportAny]
# Fetch jobs from API - returns list directly
result = api_request("GET", "/flash/instances")
# API returns list directly, not {"instances": [...]}
if isinstance(result, list):
instances = cast(list[dict[str, Any]], result)
else:
instances = cast(list[dict[str, Any]], result.get("instances", []))
if not instances:
# No jobs - just print header
if long_format:
print("JOBID NAME NODES RANKS STATE WORKDIR")
else:
print("JOBID NAME NODES STATE")
return 0
# Filter by job ID if specified
if job_filter:
search = job_filter.lower()
filtered: list[dict[str, Any]] = []
for i in instances:
iid = i.get("instance_id", "") # pyright: ignore[reportAny]
if search in str(iid).lower().replace("-", ""): # pyright: ignore[reportAny]
filtered.append(i)
instances = filtered
# Build table
rows: list[list[str]] = []
if long_format:
headers = ["JOBID", "NAME", "NODES", "RANKS", "STATE", "WORKDIR"]
for inst in instances:
iid_val = inst.get("instance_id", "") # pyright: ignore[reportAny]
instance_id = str(iid_val) if iid_val else "" # pyright: ignore[reportAny]
job_id = truncate_id(instance_id, 12)
name_val = inst.get("simulation_name", "") # pyright: ignore[reportAny]
name = (str(name_val) if name_val else "")[:15] # pyright: ignore[reportAny]
runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
nodes = str(len(runner_statuses))
ranks_val = inst.get("total_ranks", 0) # pyright: ignore[reportAny]
ranks = str(ranks_val) if ranks_val else "0" # pyright: ignore[reportAny]
state = get_job_state(runner_statuses)
workdir_val = inst.get("working_directory", "") # pyright: ignore[reportAny]
workdir = str(workdir_val) if workdir_val else "" # pyright: ignore[reportAny]
# Truncate workdir for display
if len(workdir) > 30:
workdir = "..." + workdir[-27:]
rows.append([job_id, name, nodes, ranks, state, workdir])
else:
headers = ["JOBID", "NAME", "NODES", "STATE"]
for inst in instances:
iid_val = inst.get("instance_id", "") # pyright: ignore[reportAny]
instance_id = str(iid_val) if iid_val else "" # pyright: ignore[reportAny]
job_id = truncate_id(instance_id, 8)
name_val = inst.get("simulation_name", "") # pyright: ignore[reportAny]
name = (str(name_val) if name_val else "")[:15] # pyright: ignore[reportAny]
runner_statuses = cast(dict[str, Any], inst.get("runner_statuses", {}))
nodes = str(len(runner_statuses))
state = get_job_state(runner_statuses)
rows.append([job_id, name, nodes, state])
print(format_table(headers, rows))
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

View File

@@ -195,14 +195,6 @@ class Node:
def main():
# Check for SLURM-compatible subcommands first
import sys
if len(sys.argv) > 1 and sys.argv[1] in ("sbatch", "squeue", "scancel", "salloc"):
from exo.cli import run_subcommand
sys.exit(run_subcommand(sys.argv[1], sys.argv[2:]))
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
@@ -213,11 +205,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Discover and register plugins
from exo.plugins.registry import discover_plugins
discover_plugins()
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"

View File

@@ -1,9 +1,7 @@
import asyncio
import os
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Any, Optional, cast
from typing import cast
import anyio
from anyio import BrokenResourceError, create_task_group
@@ -16,14 +14,16 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from pydantic import BaseModel
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
)
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -62,14 +62,9 @@ from exo.shared.types.events import (
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
)
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
@@ -77,22 +72,6 @@ from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
class ExecuteRequest(BaseModel):
"""Request to execute a command."""
command: list[str]
cwd: Optional[str] = None
env: Optional[dict[str, str]] = None
class ExecuteResponse(BaseModel):
"""Response from command execution."""
exit_code: int
stdout: str
stderr: str
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
@@ -110,12 +89,12 @@ def chunk_to_response(
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card.metadata
return model_card
else:
return await get_model_meta(model_id)
return await ModelCard.from_hf(model_id)
class API:
@@ -149,7 +128,6 @@ class API:
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
self._register_plugin_routes()
self.app.mount(
"/",
@@ -218,62 +196,10 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
# Remote execution endpoint (used by exo-rsh for MPI)
self.app.post("/execute")(self.execute)
def _register_plugin_routes(self) -> None:
"""Register API routes from all loaded plugins."""
import functools
import inspect
from exo.plugins.context import PluginContext
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
for method, path, handler, plugin in registry.get_all_api_routes():
# Create a wrapper that injects the plugin context while preserving
# the original function signature for FastAPI parameter extraction
@functools.wraps(handler)
async def wrapped_handler( # pyright: ignore[reportAny]
*args: Any, # pyright: ignore[reportAny]
_handler: Any = handler, # pyright: ignore[reportAny]
**kwargs: Any, # pyright: ignore[reportAny]
) -> Any:
context = PluginContext(
state=self.state,
send_command=self._send,
node_id=self.node_id,
)
# Pass context as first argument, then forward all other args
return await _handler(context, *args, **kwargs) # pyright: ignore[reportAny]
# Modify the wrapper signature to match the original handler
# but without the 'ctx' parameter (we inject it)
orig_sig = inspect.signature(handler)
params = list(orig_sig.parameters.values())
# Remove the first parameter (ctx: PluginContext)
if params and params[0].name in ("ctx", "context"):
params = params[1:]
wrapped_handler.__signature__ = orig_sig.replace(parameters=params) # pyright: ignore[reportAttributeAccessIssue]
# Register the route based on HTTP method
if method == "get":
self.app.get(path)(wrapped_handler)
elif method == "post":
self.app.post(path)(wrapped_handler)
elif method == "delete":
self.app.delete(path)(wrapped_handler)
elif method == "put":
self.app.put(path)(wrapped_handler)
logger.debug(
f"Registered plugin route: {method.upper()} {path} ({plugin.name})"
)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_meta=await resolve_model_meta(payload.model_id),
model_card=await resolve_model_card(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -283,15 +209,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=command.model_meta,
model_card=command.model_card,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -308,22 +234,22 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
model_card=model_card,
)
async def get_placement(
self,
model_id: str,
model_id: ModelId,
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_meta = await resolve_model_meta(model_id)
model_card = await resolve_model_card(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -356,7 +282,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -374,13 +300,12 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for card in cards:
model_meta = card.metadata
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -391,17 +316,17 @@ class API:
current_instances=self.state.instances,
)
except ValueError as exc:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -412,17 +337,17 @@ class API:
]
if len(new_instances) != 1:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -431,7 +356,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_meta.storage_size.in_bytes
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -439,14 +364,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
card.model_id,
model_card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -454,7 +379,7 @@ class API:
error=None,
)
)
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -629,8 +554,8 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -656,8 +581,8 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -690,77 +615,18 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.short_id,
id=card.model_id,
hugging_face_id=card.model_id,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
)
for card in MODEL_CARDS.values()
]
)
async def execute(self, request: ExecuteRequest) -> ExecuteResponse:
"""Execute a command locally. Used by exo-rsh for MPI remote execution."""
cmd_str = " ".join(request.command)
logger.info(f"Executing: {cmd_str}")
try:
# Build environment
env = os.environ.copy()
if request.env:
env.update(request.env)
# Check if command contains shell metacharacters
# If so, run through shell. mpirun sends complex commands like:
# "VAR=value;export VAR;/path/to/prted --args"
needs_shell = any(c in cmd_str for c in ";|&$`")
if needs_shell:
process = await asyncio.create_subprocess_shell(
cmd_str,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
else:
process = await asyncio.create_subprocess_exec(
*request.command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=request.cwd,
env=env,
)
stdout, stderr = await process.communicate()
exit_code = process.returncode or 0
logger.info(f"Command completed with exit code {exit_code}")
return ExecuteResponse(
exit_code=exit_code,
stdout=stdout.decode("utf-8", errors="replace"),
stderr=stderr.decode("utf-8", errors="replace"),
)
except FileNotFoundError:
logger.error(f"Command not found: {request.command[0]}")
return ExecuteResponse(
exit_code=127,
stdout="",
stderr=f"Command not found: {request.command[0]}",
)
except Exception as e:
logger.error(f"Execution error: {e}")
return ExecuteResponse(
exit_code=1,
stdout="",
stderr=str(e),
)
async def run(self):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"

View File

@@ -96,128 +96,104 @@ class Master:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
with self.command_receiver as commands:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
# Check if a plugin handles this command
plugin = registry.get_plugin_for_command(command)
if plugin is not None:
events = plugin.process_command(
command,
self.state.topology,
self.state.instances,
)
generated_events.extend(events)
else:
# Core command handling
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(
command, self.state.instances
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(
command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
placement = add_instance_to_placements(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.finished_command_id
]
)
)
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
command.finished_command_id
in self.command_task_mapping
instance.shard_assignments.model_id
== command.request_params.model
):
del self.command_task_mapping[
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ChatCompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case PlaceInstance():
placement = place_instance(
command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case CreateInstance():
placement = add_instance_to_placements(
command,
self.state.topology,
self.state.instances,
)
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
case _:
# Plugin-managed commands are handled above
pass
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
for event in generated_events:
await self.event_sender.send(event)
except ValueError as e:

View File

@@ -14,6 +14,7 @@ from exo.master.placement_utils import (
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
@@ -23,7 +24,6 @@ from exo.shared.types.commands import (
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -60,27 +60,27 @@ def place_instance(
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_meta.storage_size
candidate_cycles, node_memory, command.model_card.storage_size
)
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
if not command.model_card.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
if command.model_card.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -111,7 +111,7 @@ def place_instance(
)
shard_assignments = get_shard_assignments(
command.model_meta, selected_cycle, command.sharding, node_memory
command.model_card, selected_cycle, command.sharding, node_memory
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
@@ -120,7 +120,7 @@ def place_instance(
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
logger.debug(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
)
@@ -159,11 +159,6 @@ def place_instance(
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
)
case _:
# Plugin-managed instance types have their own placement functions
raise ValueError(
f"Instance type {command.instance_meta} must use plugin placement"
)
return target_instances

View File

@@ -2,10 +2,10 @@ from collections.abc import Generator, Mapping
from loguru import logger
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
@@ -75,7 +75,7 @@ def allocate_layers_proportionally(
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
):
@@ -86,11 +86,10 @@ def get_shard_assignments_for_pipeline_parallel(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_meta.n_layers
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
@@ -104,7 +103,7 @@ def get_shard_assignments_for_pipeline_parallel(
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
@@ -124,7 +123,7 @@ def get_shard_assignments_for_pipeline_parallel(
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -137,7 +136,7 @@ def get_shard_assignments_for_pipeline_parallel(
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -146,17 +145,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_meta: ModelMetadata,
model_card: ModelCard,
cycle: Cycle,
):
total_layers = model_meta.n_layers
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node_id in enumerate(cycle):
shard = TensorShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -170,7 +169,7 @@ def get_shard_assignments_for_tensor_parallel(
node_to_runner[node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -179,7 +178,7 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
model_card: ModelCard,
cycle: Cycle,
sharding: Sharding,
node_memory: Mapping[NodeId, MemoryUsage],
@@ -187,13 +186,13 @@ def get_shard_assignments(
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_meta=model_meta,
model_card=model_card,
cycle=cycle,
node_memory=node_memory,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_meta=model_meta,
model_card=model_card,
cycle=cycle,
)

View File

@@ -7,6 +7,7 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -23,7 +24,6 @@ from exo.shared.types.events import (
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
)
@@ -109,9 +109,8 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -167,9 +166,8 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -10,12 +10,12 @@ from exo.master.tests.conftest import (
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.topology import Connection, SocketConnection
@@ -43,21 +43,20 @@ def instance() -> Instance:
@pytest.fixture
def model_meta() -> ModelMetadata:
return ModelMetadata(
def model_card() -> ModelCard:
return ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -76,16 +75,16 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
model_meta: ModelMetadata,
model_card: ModelCard,
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
cic = place_instance_command(model_card)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
@@ -137,7 +136,7 @@ def test_get_instance_placements_create_instance(
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_meta.model_id
assert instance.shard_assignments.model_id == model_card.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -164,10 +163,9 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -191,10 +189,9 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
node_memory = {node_id: create_node_memory(1001 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -218,10 +215,9 @@ def test_get_instance_placements_one_node_not_fit() -> None:
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -275,12 +271,12 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_selects_leaf_nodes(
model_meta: ModelMetadata,
model_card: ModelCard,
):
# arrange
topology = Topology()
model_meta.storage_size = Memory.from_bytes(1000)
model_card.storage_size = Memory.from_bytes(1000)
node_id_a = NodeId()
node_id_b = NodeId()
@@ -325,7 +321,7 @@ def test_placement_selects_leaf_nodes(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
)
cic = place_instance_command(model_meta=model_meta)
cic = place_instance_command(model_card=model_card)
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
@@ -344,12 +340,12 @@ def test_placement_selects_leaf_nodes(
def test_tensor_rdma_backend_connectivity_matrix(
model_meta: ModelMetadata,
model_card: ModelCard,
):
# arrange
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
@@ -411,7 +407,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
min_nodes=1,
)

View File

@@ -12,10 +12,10 @@ from exo.master.tests.conftest import (
create_node_memory,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
NetworkInterfaceInfo,
NodeNetworkInfo,
@@ -232,9 +232,8 @@ def test_get_shard_assignments(
node_c_id: node_c_mem,
}
model_meta = ModelMetadata(
model_card = ModelCard(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -248,7 +247,7 @@ def test_get_shard_assignments(
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline, node_memory=node_memory
model_card, selected_cycle, Sharding.Pipeline, node_memory=node_memory
)
# assert
@@ -512,9 +511,8 @@ def test_get_shard_assignments_insufficient_memory_raises():
node_c_id: node_c_mem,
}
model_meta = ModelMetadata(
model_card = ModelCard(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -525,5 +523,5 @@ def test_get_shard_assignments_insufficient_memory_raises():
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline, node_memory
model_card, selected_cycle, Sharding.Pipeline, node_memory
)

View File

@@ -1,16 +0,0 @@
"""Exo Plugin System.
This module provides the plugin architecture for extending exo with custom
workload types (simulations, ML frameworks, etc.) without modifying core code.
"""
from exo.plugins.base import ExoPlugin, PluginCommand, PluginInstance
from exo.plugins.registry import PluginRegistry, discover_plugins
__all__ = [
"ExoPlugin",
"PluginCommand",
"PluginInstance",
"PluginRegistry",
"discover_plugins",
]

View File

@@ -1,171 +0,0 @@
"""Base classes and protocols for Exo plugins."""
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from pydantic import Field
from exo.shared.types.common import CommandId
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.pydantic_ext import TaggedModel
if TYPE_CHECKING:
from exo.shared.topology import Topology
from exo.shared.types.worker.instances import BoundInstance, Instance
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor
class PluginCommand(TaggedModel):
"""Base class for plugin-defined commands.
All plugin commands must inherit from this class. Commands are serialized
with their class name as a tag for routing.
"""
command_id: CommandId = Field(default_factory=CommandId)
class PluginInstance(TaggedModel):
"""Base class for plugin-defined instances.
All plugin instances must inherit from this class. Plugins are expected
to define their own instance type with workload-specific fields.
"""
instance_id: InstanceId
class ExoPlugin(ABC):
"""Protocol that all exo plugins must implement.
A plugin provides:
- Custom command types for API -> Master communication
- Custom instance types representing running workloads
- Placement logic for distributing work across nodes
- Planning logic for local task scheduling
- Runner implementation for executing work
"""
@property
@abstractmethod
def name(self) -> str:
"""Unique identifier for this plugin (e.g., 'flash', 'pytorch', 'mpi')."""
...
@property
@abstractmethod
def version(self) -> str:
"""Semantic version string (e.g., '1.0.0')."""
...
# ========== Type Registration ==========
@abstractmethod
def get_command_types(self) -> Sequence[type]:
"""Return command types this plugin handles.
These commands are routed to this plugin's process_command method.
Can return core BaseCommand types or PluginCommand types.
"""
...
@abstractmethod
def get_instance_type(self) -> type:
"""Return the instance type this plugin creates.
This instance type is used for routing in planning and runner bootstrap.
Can return core Instance types or PluginInstance types.
"""
...
# ========== API Routes ==========
@abstractmethod
def get_api_routes(
self,
) -> Sequence[tuple[str, str, Callable[..., Any]]]:
"""Return FastAPI routes to register.
Each tuple: (method, path, handler)
Example: [('post', '/flash/launch', self.launch_handler)]
Handlers receive a PluginContext with access to:
- state: Current State object
- send_command: Async function to send commands
- node_id: Current node's ID
"""
...
# ========== Master Command Handling ==========
@abstractmethod
def handles_command(self, command: Any) -> bool: # pyright: ignore[reportAny]
"""Return True if this plugin handles the given command type."""
...
@abstractmethod
def process_command(
self,
command: Any, # pyright: ignore[reportAny]
topology: "Topology",
current_instances: Mapping[InstanceId, "Instance"],
) -> Sequence[Event]:
"""Process a command and return events to emit.
Typically creates placement and returns InstanceCreated/InstanceDeleted events.
Args:
command: The command to process
topology: Current cluster topology
current_instances: Currently running instances
Returns:
Sequence of events to emit (e.g., InstanceCreated, InstanceDeleted)
"""
...
# ========== Worker Planning ==========
@abstractmethod
def handles_instance(self, instance: object) -> bool:
"""Return True if this plugin manages the given instance type."""
...
@abstractmethod
def plan_task(
self,
runners: Mapping[RunnerId, "RunnerSupervisor"],
instances: Mapping[InstanceId, "Instance"],
) -> Task | None:
"""Plan the next task for plugin instances.
Called during each planning cycle.
Return None if no task is needed.
"""
...
@abstractmethod
def should_skip_download(self, instance: object) -> bool:
"""Return True if this instance type doesn't need model downloads."""
...
# ========== Runner Bootstrap ==========
@abstractmethod
def create_runner(
self,
bound_instance: "BoundInstance",
event_sender: "MpSender[Event]",
task_receiver: "MpReceiver[Task]",
) -> None:
"""Entry point for the runner process.
Called in a subprocess to execute the actual workload.
This function should block until the workload completes.
"""
...
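
Concretely, a workload plugin just subclasses ExoPlugin and fills in these hooks. A bare-bones skeleton with hypothetical Echo* names and no real behaviour might look like this (the FLASHPlugin further down is a full example of the same shape):

    # Hypothetical skeleton for illustration; not part of exo.
    from collections.abc import Callable, Sequence
    from typing import Any

    from exo.plugins.base import ExoPlugin, PluginCommand, PluginInstance
    from exo.shared.types.events import Event
    from exo.shared.types.tasks import Task

    class EchoCommand(PluginCommand):
        message: str

    class EchoInstance(PluginInstance):
        message: str

    class EchoPlugin(ExoPlugin):
        @property
        def name(self) -> str:
            return "echo"

        @property
        def version(self) -> str:
            return "0.1.0"

        def get_command_types(self) -> Sequence[type]:
            return [EchoCommand]

        def get_instance_type(self) -> type:
            return EchoInstance

        def get_api_routes(self) -> Sequence[tuple[str, str, Callable[..., Any]]]:
            return []  # no HTTP surface in this toy example

        def handles_command(self, command: Any) -> bool:
            return isinstance(command, EchoCommand)

        def process_command(self, command: Any, topology, current_instances) -> Sequence[Event]:
            return []  # a real plugin would return InstanceCreated/InstanceDeleted events

        def handles_instance(self, instance: object) -> bool:
            return isinstance(instance, EchoInstance)

        def plan_task(self, runners, instances) -> Task | None:
            return None  # nothing to schedule

        def should_skip_download(self, instance: object) -> bool:
            return True  # no model weights involved

        def create_runner(self, bound_instance, event_sender, task_receiver) -> None:
            pass  # a real runner blocks here until the workload completes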

View File

@@ -1,21 +0,0 @@
"""Context objects passed to plugin handlers."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from exo.shared.types.commands import Command
from exo.shared.types.common import NodeId
from exo.shared.types.state import State
@dataclass
class PluginContext:
"""Context provided to plugin API handlers.
This gives plugins access to the current state and the ability to send
commands without direct access to internal API components.
"""
state: State
send_command: Callable[[Command], Awaitable[None]]
node_id: NodeId

View File

@@ -1,5 +0,0 @@
"""Plugin implementations directory.
Each subdirectory should contain a plugin with a register() function
that returns an ExoPlugin instance.
"""

View File

@@ -1,8 +0,0 @@
"""FLASH Plugin - MPI-based simulation support for Exo."""
from exo.plugins.implementations.flash.plugin import FLASHPlugin
def register() -> FLASHPlugin:
"""Entry point for plugin discovery."""
return FLASHPlugin()

View File

@@ -1,108 +0,0 @@
"""FLASH plugin API handlers."""
from typing import Any
from fastapi import HTTPException
from exo.plugins.context import PluginContext
# Use core types for serialization compatibility
from exo.shared.types.commands import LaunchFLASH, StopFLASH
from exo.shared.types.worker.instances import FLASHInstance
async def handle_launch_flash(
ctx: PluginContext,
simulation_name: str,
flash_executable_path: str,
working_directory: str,
parameter_file_path: str = "",
ranks_per_node: int = 1,
min_nodes: int = 1,
hosts: str = "",
) -> dict[str, str]:
"""Launch a FLASH MPI simulation across the cluster.
Args:
ctx: Plugin context with state and send_command
simulation_name: Name of the simulation
flash_executable_path: Path to the FLASH executable
working_directory: Working directory for the simulation
parameter_file_path: Path to parameter file (optional)
ranks_per_node: Number of MPI ranks per node
min_nodes: Minimum number of nodes required
hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
If not provided, IPs are discovered from topology edges.
"""
command = LaunchFLASH(
simulation_name=simulation_name,
flash_executable_path=flash_executable_path,
parameter_file_path=parameter_file_path,
working_directory=working_directory,
ranks_per_node=ranks_per_node,
min_nodes=min_nodes,
hosts=hosts,
)
await ctx.send_command(command)
return {
"message": "FLASH launch command received",
"command_id": str(command.command_id),
"simulation_name": simulation_name,
}
async def handle_stop_flash(
ctx: PluginContext,
instance_id: str,
) -> dict[str, str]:
"""Stop a running FLASH simulation."""
from exo.shared.types.worker.instances import InstanceId
inst_id = InstanceId(instance_id)
if inst_id not in ctx.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
instance = ctx.state.instances[inst_id]
if not isinstance(instance, FLASHInstance):
raise HTTPException(
status_code=400, detail="Instance is not a FLASH simulation"
)
command = StopFLASH(instance_id=inst_id)
await ctx.send_command(command)
return {
"message": "Stop command received",
"command_id": str(command.command_id),
"instance_id": str(instance_id),
}
async def handle_list_flash_instances(ctx: PluginContext) -> list[dict[str, Any]]:
"""List all FLASH simulation instances."""
flash_instances: list[dict[str, Any]] = []
for instance_id, instance in ctx.state.instances.items():
if isinstance(instance, FLASHInstance):
# Get runner statuses for this instance
runner_statuses: dict[str, str | None] = {}
for (
node_id,
runner_id,
) in instance.shard_assignments.node_to_runner.items():
runner_status = ctx.state.runners.get(runner_id)
runner_statuses[str(node_id)] = (
str(runner_status) if runner_status else None
)
flash_instances.append(
{
"instance_id": str(instance_id),
"simulation_name": instance.simulation_name,
"total_ranks": instance.total_ranks,
"working_directory": instance.working_directory,
"runner_statuses": runner_statuses,
}
)
return flash_instances

View File

@@ -1,152 +0,0 @@
"""FLASH plugin placement logic."""
from collections.abc import Mapping
from copy import deepcopy
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.commands import LaunchFLASH
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.topology import SocketConnection
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerId,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
def place_flash_instance(
command: LaunchFLASH,
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
"""Place a FLASH simulation instance across available nodes.
Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
FLASH instances use MPI for communication. We just need to provide the
node IPs so the runner can generate an MPI hostfile.
"""
instance_id = InstanceId()
target_instances: dict[InstanceId, Instance] = dict(deepcopy(current_instances))
all_nodes = list(topology.list_nodes())
if len(all_nodes) < command.min_nodes:
raise ValueError(
f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
)
# Select nodes (take the first min_nodes)
selected_nodes = all_nodes[: command.min_nodes]
logger.info(
f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
)
# Build shard assignments (one runner per node for FLASH)
runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
# Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
flash_model_meta = ModelMetadata(
model_id=ModelId(command.simulation_name),
pretty_name=f"FLASH: {command.simulation_name}",
storage_size=Memory(in_bytes=0),
n_layers=1,
hidden_size=1,
supports_tensor=False,
)
for i, node_id in enumerate(selected_nodes):
runner_id = RunnerId()
node_to_runner[node_id] = runner_id
runner_to_shard[runner_id] = PipelineShardMetadata(
device_rank=i,
world_size=len(selected_nodes),
model_meta=flash_model_meta,
start_layer=0,
end_layer=1,
n_layers=1,
)
shard_assignments = ShardAssignments(
model_id=ModelId(command.simulation_name),
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
# Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
hosts_by_node: dict[NodeId, list[Host]] = {}
# If explicit hosts are provided, use them directly
if command.hosts:
explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
for i, node_id in enumerate(selected_nodes):
if i < len(explicit_hosts):
hosts_by_node[node_id] = [Host(ip=explicit_hosts[i], port=0)]
logger.info(
f"FLASH placement: node {node_id} (rank {i}) -> IP {explicit_hosts[i]}"
)
else:
logger.warning(
f"Not enough hosts provided for node {i}, using localhost"
)
hosts_by_node[node_id] = [Host(ip="127.0.0.1", port=0)]
logger.info(
f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
)
else:
# Try to get IPs from topology edges
for node_id in selected_nodes:
node_hosts: list[Host] = []
# Get IP from outgoing edges (connections to other nodes via mDNS discovery)
for conn in topology.out_edges(node_id):
if isinstance(conn.edge, SocketConnection):
# Extract IP from multiaddr
ip = conn.edge.sink_multiaddr.ip_address
# Skip link-local and localhost addresses
if not ip.startswith("169.254.") and not ip.startswith("127."):
node_hosts.append(Host(ip=ip, port=0))
break
# Last resort: use localhost (will only work for single-node)
if not node_hosts:
logger.warning(
f"Could not determine IP for node {node_id}, using localhost"
)
node_hosts.append(Host(ip="127.0.0.1", port=0))
hosts_by_node[node_id] = node_hosts
total_ranks = len(selected_nodes) * command.ranks_per_node
# Determine coordinator IP - first node's first host IP
first_node_id: NodeId = next(iter(hosts_by_node.keys()))
coordinator_ip: str = (
hosts_by_node[first_node_id][0].ip
if hosts_by_node[first_node_id]
else "127.0.0.1"
)
target_instances[instance_id] = FLASHInstance(
instance_id=instance_id,
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
flash_executable_path=command.flash_executable_path,
parameter_file_path=command.parameter_file_path,
working_directory=command.working_directory,
ranks_per_node=command.ranks_per_node,
total_ranks=total_ranks,
simulation_name=command.simulation_name,
coordinator_ip=coordinator_ip,
)
logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")
return target_instances

View File

@@ -1,36 +0,0 @@
"""FLASH plugin planning logic."""
from collections.abc import Mapping
from exo.shared.types.tasks import LoadModel, Task
from exo.shared.types.worker.instances import FLASHInstance, Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerIdle
from exo.worker.runner.runner_supervisor import RunnerSupervisor
def plan_flash(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
"""Plan tasks specifically for FLASH instances.
FLASH instances have a simpler lifecycle:
- CreateRunner (handled by core _create_runner)
- LoadModel (starts the simulation immediately)
- Shutdown (handled by core _kill_runner)
This function handles the LoadModel step for FLASH instances,
skipping the MLX-specific download/init/warmup steps.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
# Only handle FLASH instances
if not isinstance(instance, FLASHInstance):
continue
# If runner is idle, emit LoadModel to start the simulation
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
return None

View File

@@ -1,98 +0,0 @@
"""FLASH Plugin - Main plugin class."""
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from exo.plugins.base import ExoPlugin
from exo.plugins.implementations.flash.api_handlers import (
handle_launch_flash,
handle_list_flash_instances,
handle_stop_flash,
)
from exo.plugins.implementations.flash.placement import place_flash_instance
from exo.plugins.implementations.flash.planning import plan_flash
from exo.plugins.implementations.flash.runner import main as flash_runner_main
from exo.shared.topology import Topology
from exo.shared.types.commands import DeleteInstance, LaunchFLASH, StopFLASH
from exo.shared.types.events import Event
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import (
BoundInstance,
FLASHInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.runner.runner_supervisor import RunnerSupervisor
class FLASHPlugin(ExoPlugin):
"""Plugin for FLASH MPI simulations."""
@property
def name(self) -> str:
return "flash"
@property
def version(self) -> str:
return "1.0.0"
def get_command_types(self) -> Sequence[type]:
return [LaunchFLASH, StopFLASH]
def get_instance_type(self) -> type:
return FLASHInstance
def get_api_routes(
self,
) -> Sequence[tuple[str, str, Callable[..., Any]]]:
return [
("post", "/flash/launch", handle_launch_flash),
("delete", "/flash/{instance_id}", handle_stop_flash),
("get", "/flash/instances", handle_list_flash_instances),
]
def handles_command(self, command: Any) -> bool: # pyright: ignore[reportAny]
return isinstance(command, (LaunchFLASH, StopFLASH))
def process_command(
self,
command: Any, # pyright: ignore[reportAny]
topology: Topology,
current_instances: Mapping[InstanceId, Instance],
) -> Sequence[Event]:
from exo.master.placement import delete_instance, get_transition_events
if isinstance(command, LaunchFLASH):
placement = place_flash_instance(command, topology, current_instances)
return list(get_transition_events(current_instances, placement))
elif isinstance(command, StopFLASH):
placement = delete_instance(
DeleteInstance(instance_id=command.instance_id),
current_instances,
)
return list(get_transition_events(current_instances, placement))
return []
def handles_instance(self, instance: object) -> bool:
return isinstance(instance, FLASHInstance)
def plan_task(
self,
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
) -> Task | None:
return plan_flash(runners, instances)
def should_skip_download(self, instance: object) -> bool:
# FLASH instances don't need model downloads
return True
def create_runner(
self,
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
) -> None:
flash_runner_main(bound_instance, event_sender, task_receiver)

View File

@@ -1,302 +0,0 @@
"""FLASH MPI Runner - spawns and monitors FLASH simulations.
Exo-native distributed MPI:
- Exo handles node discovery and coordination
- Coordinator generates hostfile from Exo topology
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
- exo-rsh connects to each node's Exo API (/execute endpoint) for remote execution
- Workers just report ready and wait
"""
import os
import shutil
import socket
import subprocess
import threading
from loguru import logger
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
LoadModel,
Shutdown,
Task,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerIdle,
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
)
from exo.utils.channels import MpReceiver, MpSender
# Find mpirun in PATH, fallback to common locations
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"
# exo-rsh is installed as console script by exo package
_exo_rsh_path = shutil.which("exo-rsh")
if not _exo_rsh_path:
raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
EXO_RSH_PATH: str = _exo_rsh_path
def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
"""Determine this node's rank based on position in hosts_by_node."""
for i, node_id in enumerate(instance.hosts_by_node.keys()):
if str(node_id) == str(my_node_id):
return i
return -1
def get_coordinator_host(instance: FLASHInstance) -> str:
"""Get the IP of the coordinator node."""
return instance.coordinator_ip
def resolve_host(host: str) -> str:
"""Resolve host string to a usable hostname for MPI hostfile.
Accepts either an IP address or hostname. For IPs, attempts to resolve
to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
"""
# Check if input is already a hostname (not an IP)
try:
socket.inet_aton(host)
is_ip = True
except socket.error:
is_ip = False
if not is_ip:
# Already a hostname, verify it resolves and return as-is
try:
socket.gethostbyname(host)
return host
except socket.gaierror:
logger.warning(f"Hostname {host} does not resolve, using anyway")
return host
# It's an IP address, try to resolve to hostname
try:
hostname, _, _ = socket.gethostbyaddr(host)
hostname = hostname.split(".")[0]
logger.info(f"Resolved {host} to {hostname}")
return hostname
except socket.herror:
pass
# Fall back to IP
logger.warning(f"Could not resolve {host} to hostname, using IP directly")
return host
def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
"""Generate MPI hostfile from instance topology."""
hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
with open(hostfile_path, "w") as f:
for _node_id, hosts in instance.hosts_by_node.items():
if hosts:
host = resolve_host(hosts[0].ip)
f.write(f"{host} slots={instance.ranks_per_node}\n")
logger.info(f"Generated hostfile at {hostfile_path}")
with open(hostfile_path, "r") as f:
logger.info(f"Hostfile contents:\n{f.read()}")
return hostfile_path
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
) -> None:
"""Main FLASH runner loop.
Coordinator: generates hostfile and runs mpirun (uses exo-rsh instead of SSH)
Workers: just report ready and wait for mpirun to spawn processes on them
"""
assert isinstance(bound_instance.instance, FLASHInstance)
instance = bound_instance.instance
runner_id = bound_instance.bound_runner_id
my_node_id = str(bound_instance.bound_node_id)
logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")
my_rank = get_my_rank(instance, my_node_id)
world_size = len(instance.hosts_by_node)
is_coordinator = my_rank == 0
coordinator_ip = get_coordinator_host(instance)
logger.info(
f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
)
logger.info(f"FLASH coordinator IP: {coordinator_ip}")
process: subprocess.Popen[bytes] | None = None
current_status: RunnerStatus = RunnerIdle()
shutdown_requested = False
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
def monitor_output(proc: subprocess.Popen[bytes]) -> None:
"""Monitor FLASH stdout for progress updates."""
if proc.stdout is None:
return
for line in iter(proc.stdout.readline, b""):
if shutdown_requested:
break
try:
decoded: str = line.decode("utf-8", errors="replace").strip()
if decoded:
logger.info(f"[FLASH] {decoded}")
except Exception as e:
logger.warning(f"Error parsing FLASH output: {e}")
with task_receiver as tasks:
for task in tasks:
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
event_sender.send(TaskAcknowledged(task_id=task.task_id))
match task:
case LoadModel() if isinstance(current_status, RunnerIdle):
current_status = RunnerLoading()
logger.info("Starting FLASH simulation")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
if is_coordinator:
# Coordinator: generate hostfile and run mpirun
hostfile = generate_hostfile(
instance, instance.working_directory
)
iface = instance.network_interface
cmd = [
MPIRUN_PATH,
"-np",
str(instance.total_ranks),
"--hostfile",
hostfile,
"--wdir",
instance.working_directory,
"--oversubscribe",
"--mca",
"btl",
"tcp,self",
"--mca",
"btl_tcp_if_include",
iface,
"--mca",
"oob_tcp_if_include",
iface,
"--mca",
"plm_rsh_no_tree_spawn",
"1",
]
# Use exo-rsh for remote execution (no SSH needed)
cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])
cmd.append(instance.flash_executable_path)
logger.info(f"FLASH distributed launch: {' '.join(cmd)}")
process = subprocess.Popen(
cmd,
cwd=instance.working_directory,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
monitor_thread = threading.Thread(
target=monitor_output, args=(process,), daemon=True
)
monitor_thread.start()
current_status = RunnerRunning()
logger.info(
f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
)
else:
# Worker: mpirun on coordinator will use exo-rsh to spawn processes here
logger.info(
f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
)
current_status = RunnerRunning()
except Exception as e:
logger.error(f"Failed to start FLASH: {e}")
import traceback
logger.error(traceback.format_exc())
current_status = RunnerFailed(error_message=str(e))
case Shutdown():
shutdown_requested = True
current_status = RunnerShuttingDown()
logger.info("FLASH runner shutting down")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
if process and process.poll() is None:
logger.info("Terminating FLASH simulation")
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
logger.warning("FLASH didn't terminate, killing")
process.kill()
process.wait()
current_status = RunnerShutdown()
case _:
if process and process.poll() is not None:
exit_code = process.returncode
if exit_code == 0:
logger.info("FLASH simulation completed successfully")
current_status = RunnerReady()
else:
logger.error(
f"FLASH simulation failed with code {exit_code}"
)
current_status = RunnerFailed(
error_message=f"Exit code {exit_code}"
)
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
event_sender.send(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
break
if process and process.poll() is None:
process.terminate()
process.wait(timeout=5)
logger.info("FLASH runner exiting")

View File

@@ -1,110 +0,0 @@
"""Plugin registry for discovering and managing plugins."""
from collections.abc import Callable, Sequence
from typing import Any
from loguru import logger
from exo.plugins.base import ExoPlugin
class PluginRegistry:
"""Central registry for all plugins."""
_instance: "PluginRegistry | None" = None
def __init__(self) -> None:
self._plugins: dict[str, ExoPlugin] = {}
self._command_handlers: dict[type, ExoPlugin] = {}
self._instance_handlers: dict[type, ExoPlugin] = {}
@classmethod
def get(cls) -> "PluginRegistry":
"""Get the singleton registry instance."""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def reset(cls) -> None:
"""Reset the singleton instance (useful for testing)."""
cls._instance = None
def register(self, plugin: ExoPlugin) -> None:
"""Register a plugin and its types."""
if plugin.name in self._plugins:
raise ValueError(f"Plugin '{plugin.name}' already registered")
logger.info(f"Registering plugin: {plugin.name} v{plugin.version}")
self._plugins[plugin.name] = plugin
# Register command handlers
for cmd_type in plugin.get_command_types():
self._command_handlers[cmd_type] = plugin
logger.debug(f" Registered command: {cmd_type.__name__}")
# Register instance handler
instance_type = plugin.get_instance_type()
self._instance_handlers[instance_type] = plugin
logger.debug(f" Registered instance: {instance_type.__name__}")
def get_plugin(self, name: str) -> ExoPlugin | None:
"""Get a plugin by name."""
return self._plugins.get(name)
def get_plugin_for_command(self, command: object) -> ExoPlugin | None:
"""Get the plugin that handles a command."""
for plugin in self._plugins.values():
if plugin.handles_command(command):
return plugin
return None
def get_plugin_for_instance(self, instance: object) -> ExoPlugin | None:
"""Get the plugin that manages an instance."""
for plugin in self._plugins.values():
if plugin.handles_instance(instance):
return plugin
return None
def all_plugins(self) -> Sequence[ExoPlugin]:
"""Get all registered plugins."""
return list(self._plugins.values())
def get_all_api_routes(
self,
) -> Sequence[tuple[str, str, Callable[..., Any], ExoPlugin]]:
"""Get all API routes from all plugins."""
routes: list[tuple[str, str, Callable[..., Any], ExoPlugin]] = []
for plugin in self._plugins.values():
for method, path, handler in plugin.get_api_routes():
routes.append((method, path, handler, plugin))
return routes
def discover_plugins() -> None:
"""Auto-discover and register plugins from the implementations directory.
Plugins should have a register() function that returns an ExoPlugin instance.
"""
import importlib
import pkgutil
registry = PluginRegistry.get()
try:
import exo.plugins.implementations as impl_package
for _, module_name, _ in pkgutil.iter_modules(impl_package.__path__):
try:
module = importlib.import_module(
f"exo.plugins.implementations.{module_name}"
)
if hasattr(module, "register"):
plugin = module.register() # pyright: ignore[reportAny]
if plugin is not None:
registry.register(plugin) # pyright: ignore[reportAny]
except Exception as e:
logger.warning(f"Failed to load plugin {module_name}: {e}")
except ImportError:
logger.debug("No plugin implementations package found")

View File

@@ -1,13 +0,0 @@
"""Exo RSH - Remote Shell for MPI without SSH.
This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:
1. Each Exo node runs an API server on port 52415 with an /execute endpoint
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", it HTTP POSTs to the target's /execute
4. The target executes the command and returns output
Usage:
mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""

View File

@@ -1,101 +0,0 @@
#!/usr/bin/env python3
"""exo-rsh - Remote shell client for MPI.
This script is called by mpirun as a replacement for ssh.
Usage: exo-rsh [ssh-options...] hostname command [args...]
It connects to the target node's Exo API (port 52415) and executes the command.
"""
import json
import socket
import sys
from typing import Any, cast
from urllib.error import URLError
from urllib.request import Request, urlopen
# Use the same port as Exo's API server
EXO_API_PORT = 52415
def resolve_hostname(hostname: str) -> str:
"""Resolve hostname to IP address."""
try:
return socket.gethostbyname(hostname)
except socket.gaierror:
# If resolution fails, try using the hostname directly
return hostname
def main():
# Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
# SSH options we might see: -x (disable X11), -o options, etc.
args = sys.argv[1:]
# Skip SSH-style options
hostname = None
command_start = 0
i = 0
while i < len(args):
arg = args[i]
if arg.startswith("-"):
# Skip option and its value if needed
if arg in ("-o", "-i", "-l", "-p", "-F"):
i += 2 # Skip option and its argument
continue
i += 1
continue
else:
# First non-option is the hostname
hostname = arg
command_start = i + 1
break
i += 1
if hostname is None or command_start >= len(args):
print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
sys.exit(1)
command = args[command_start:]
# Resolve hostname to IP
ip = resolve_hostname(hostname)
# Make request to Exo API
url = f"http://{ip}:{EXO_API_PORT}/execute"
data = json.dumps({"command": command}).encode("utf-8")
try:
req = Request(url, data=data, headers={"Content-Type": "application/json"})
with urlopen(req, timeout=300) as response: # pyright: ignore[reportAny]
response_body: bytes = cast(bytes, response.read()) # pyright: ignore[reportAny]
result: dict[str, Any] = json.loads(response_body.decode("utf-8")) # pyright: ignore[reportAny]
# Output stdout/stderr
stdout: str = cast(str, result.get("stdout", ""))
stderr: str = cast(str, result.get("stderr", ""))
exit_code: int = cast(int, result.get("exit_code", 0))
if stdout:
sys.stdout.write(stdout)
sys.stdout.flush()
if stderr:
sys.stderr.write(stderr)
sys.stderr.flush()
sys.exit(exit_code)
except URLError as e:
print(
f"exo-rsh: Failed to connect to {hostname}:{EXO_API_PORT}: {e}",
file=sys.stderr,
)
sys.exit(255)
except Exception as e:
print(f"exo-rsh: Error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()
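
The contract this client assumes is small: it POSTs the argv list as JSON to the target node's Exo API and mirrors back the captured output and exit code (field names as read above; the /execute server side is not part of this range, and the command below is a placeholder):

    POST http://<node-ip>:52415/execute
    {"command": ["./flash4", "-par", "flash.par"]}

    -> {"stdout": "...", "stderr": "", "exit_code": 0}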

View File

@@ -1,552 +1,445 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
_card_cache: dict[str, "ModelCard"] = {}
class ModelCard(CamelCaseModel):
short_id: str
model_id: ModelId
name: str
description: str
tags: list[str]
metadata: ModelMetadata
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id in MODEL_CARDS:
return MODEL_CARDS[model_id]
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
)
_card_cache[model_id] = mc
return mc
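# Usage sketch (illustrative, not part of this module): short ids listed in MODEL_CARDS
# below resolve directly; anything else falls through to from_hf, which fills n_layers
# and storage_size via get_config_data and get_safetensors_size.
#
#     card = await ModelCard.load(ModelId("llama-3.2-1b"))
#     print(card.n_layers, card.storage_size, card.supports_tensor)
#
#     # cards round-trip through TOML on disk
#     await card.save(Path("llama-3.2-1b.toml"))
#     card_again = await ModelCard.load_from_path(Path("llama-3.2-1b.toml"))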
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
# minimax-m2.1
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
}
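For orientation, the registry above is keyed by short id, while lookups by Hugging Face repo id scan its values (as the metadata helpers further down in this diff do). A small illustrative check, assuming the dict closed here is the MODEL_CARDS mapping imported elsewhere in this diff:

card = MODEL_CARDS["glm-4.7-flash-4bit"]
assert card.model_id == ModelId("mlx-community/GLM-4.7-Flash-4bit")
# Reverse lookup by repo id, mirroring the pattern used by the removed _get_model_meta below.
by_repo = next((c for c in MODEL_CARDS.values() if c.model_id == ModelId("mlx-community/GLM-4.7-Flash-4bit")), None)
assert by_repo is card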
from exo.worker.download.download_utils import ( # noqa: E402
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
@property
def supports_tensor(self) -> bool:
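# `architectures` in config.json is a single-element list, so membership is checked
# against one-element lists; these are the architectures exo treats as supporting
# tensor-parallel sharding.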
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],
["Qwen3MoeForCausalLM"],
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
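Together, get_config_data and get_safetensors_size cover everything the removed _get_model_meta (in the deleted file below) used to compute. A minimal sketch of how the ModelCard.from_hf classmethod referenced later in this diff might compose them; the function name, fallback order, and defaults here are assumptions, not the actual implementation:

async def _card_from_hf(model_id: ModelId) -> ModelCard:
    # Hypothetical sketch: fetch config.json and the safetensors index, then fall
    # back to any static MODEL_CARDS entry for the supports_tensor flag, mirroring
    # the removed _get_model_meta logic.
    config = await get_config_data(model_id)
    size = await get_safetensors_size(model_id)
    known = next((c for c in MODEL_CARDS.values() if c.model_id == model_id), None)
    return ModelCard(
        model_id=model_id,
        storage_size=size,
        n_layers=config.layer_count,
        hidden_size=config.hidden_size or 0,
        supports_tensor=known.supports_tensor if known is not None else config.supports_tensor,
    )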

View File

@@ -1,126 +0,0 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_meta_cache: dict[str, ModelMetadata] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,9 +31,8 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,

View File

@@ -4,9 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
class PlaceInstanceParams(BaseModel):
model_id: str
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
@@ -206,7 +206,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata
model_card: ModelCard
class DeleteInstanceResponse(BaseModel):

View File

@@ -1,10 +1,10 @@
from enum import Enum
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .models import ModelId
class ChunkType(str, Enum):

View File

@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -21,7 +21,7 @@ class ChatCompletion(BaseCommand):
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
model_card: ModelCard
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
@@ -35,26 +35,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class LaunchFLASH(BaseCommand):
"""Command to launch a FLASH MPI simulation."""
simulation_name: str
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
min_nodes: int = 1
# Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
# Used when topology edges don't contain IP addresses
hosts: str = ""
class StopFLASH(BaseCommand):
"""Command to stop a running FLASH simulation."""
instance_id: InstanceId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -70,8 +50,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| LaunchFLASH
| StopFLASH
| TaskFinished
)

View File

@@ -16,13 +16,23 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.str_schema()
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
class NodeId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class SessionId(CamelCaseModel):
master_node_id: NodeId
election_clock: int
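As a quick sanity check of the new ModelId helpers above (the repo id is just an example taken from the model cards earlier in this diff):

mid = ModelId("mlx-community/GLM-4.5-Air-8bit")
assert mid.normalize() == "mlx-community--GLM-4.5-Air-8bit"  # on-disk directory name
assert mid.short() == "GLM-4.5-Air-8bit"  # repo name without the org prefix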

View File

@@ -1,18 +0,0 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -1,3 +1,8 @@
from datetime import timedelta
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -42,3 +47,50 @@ class DownloadOngoing(BaseDownloadProgress):
DownloadProgress = (
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)

View File

@@ -14,7 +14,6 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
MlxRing = "MlxRing"
MlxJaccl = "MlxJaccl"
FLASH = "FLASH"
class BaseInstance(TaggedModel):
@@ -35,27 +34,8 @@ class MlxJacclInstance(BaseInstance):
jaccl_coordinators: dict[NodeId, str]
class FLASHInstance(BaseInstance):
"""Instance for FLASH MPI simulation.
Unlike MLX instances which do tensor parallelism, FLASH instances
coordinate MPI processes across nodes. Each node runs one or more
MPI ranks of the FLASH simulation.
"""
hosts_by_node: dict[NodeId, list[Host]]
flash_executable_path: str
parameter_file_path: str
working_directory: str
ranks_per_node: int = 1
total_ranks: int
simulation_name: str
coordinator_ip: str
network_interface: str = "en0" # Network interface for MPI (e.g., en0, eth0)
# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
Instance = MlxRingInstance | MlxJacclInstance
class BoundInstance(CamelCaseModel):

View File

@@ -2,8 +2,8 @@ from collections.abc import Mapping
from pydantic import model_validator
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.models import ModelMetadata
from exo.shared.models.model_cards import ModelCard
from exo.utils.pydantic_ext import TaggedModel
@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
Replaces previous `Shard` object.
"""
model_meta: ModelMetadata
model_card: ModelCard
device_rank: int
world_size: int
@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
def __hash__(self) -> int:
return hash(
(
self.model_meta.model_id,
self.model_card.model_id,
self.start_layer,
self.end_layer,
self.n_layers,

View File

@@ -17,17 +17,20 @@ import aiohttp
import certifi
from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
Field,
PositiveInt,
TypeAdapter,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.downloads import (
DownloadProgressData,
FileListEntry,
ModelSafetensorsIndex,
RepoDownloadProgress,
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
@@ -37,53 +40,6 @@ from exo.worker.download.huggingface_utils import (
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
@@ -125,12 +81,12 @@ def map_repo_download_progress_to_download_progress_data(
)
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.replace("/", "--")
def build_model_path(model_id: ModelId) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.normalize()
async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir()) / repo_id.replace("/", "--")
async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
return (await ensure_models_dir()) / model_id.normalize()
async def ensure_models_dir() -> Path:
@@ -138,8 +94,8 @@ async def ensure_models_dir() -> Path:
return EXO_MODELS_DIR
async def delete_model(repo_id: str) -> bool:
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
async def delete_model(model_id: ModelId) -> bool:
model_dir = await ensure_models_dir() / model_id.normalize()
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
@@ -164,19 +120,17 @@ async def seed_models(seed_dir: str | Path):
async def fetch_file_list_with_cache(
repo_id: str, revision: str = "main", recursive: bool = False
model_id: ModelId, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
)
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = (
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
@@ -184,25 +138,25 @@ async def fetch_file_list_with_cache(
async def fetch_file_list_with_retry(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(repo_id, revision, path, recursive)
return await _fetch_file_list(model_id, revision, path, recursive)
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
async def _fetch_file_list(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_download_headers()
@@ -219,7 +173,7 @@ async def _fetch_file_list(
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
model_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
@@ -276,10 +230,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
async def file_meta(
repo_id: str, revision: str, path: str, redirected_location: str | None = None
model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
) -> tuple[int, str]:
url = (
urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
if redirected_location is None
else f"{get_hf_endpoint()}{redirected_location}"
)
@@ -298,7 +252,7 @@ async def file_meta(
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(repo_id, revision, path, redirected_location)
return await file_meta(model_id, revision, path, redirected_location)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -310,7 +264,7 @@ async def file_meta(
async def download_file_with_retry(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -320,23 +274,23 @@ async def download_file_with_retry(
for attempt in range(n_attempts):
try:
return await _download_file(
repo_id, revision, path, target_dir, on_progress
model_id, revision, path, target_dir, on_progress
)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
async def _download_file(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -345,7 +299,7 @@ async def _download_file(
if await aios.path.exists(target_dir / path):
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(repo_id, revision, path)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
partial_path = target_dir / f"{path}.partial"
resume_byte_pos = (
@@ -354,7 +308,7 @@ async def _download_file(
else None
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
headers = await get_download_headers()
if resume_byte_pos:
headers["Range"] = f"bytes={resume_byte_pos}-"
@@ -394,7 +348,7 @@ async def _download_file(
def calculate_repo_progress(
shard: ShardMetadata,
repo_id: str,
model_id: ModelId,
revision: str,
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
@@ -423,7 +377,7 @@ def calculate_repo_progress(
else "not_started"
)
return RepoDownloadProgress(
repo_id=repo_id,
repo_id=model_id,
repo_revision=revision,
shard=shard,
completed_files=len(
@@ -442,11 +396,11 @@ def calculate_repo_progress(
)
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
model_id, revision, "model.safetensors.index.json", target_dir
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
@@ -460,10 +414,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
weight_map = await get_weight_map(str(shard.model_card.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
logger.error(traceback.format_exc())
return ["*"]
@@ -477,53 +431,6 @@ async def get_downloaded_size(path: Path) -> int:
return 0
async def download_progress_for_local_path(
repo_id: str, shard: ShardMetadata, local_path: Path
) -> RepoDownloadProgress:
file_progress: dict[str, RepoFileDownloadProgress] = {}
total_files = 0
total_bytes = 0
if await aios.path.isdir(local_path):
for root, _, files in os.walk(local_path):
for f in files:
if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
file_path = Path(root) / f
size = (await aios.stat(file_path)).st_size
rel_path = str(file_path.relative_to(local_path))
file_progress[rel_path] = RepoFileDownloadProgress(
repo_id=repo_id,
repo_revision="local",
file_path=rel_path,
downloaded=Memory.from_bytes(size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(size),
speed=0,
eta=timedelta(0),
status="complete",
start_time=time.time(),
)
total_files += 1
total_bytes += size
else:
raise ValueError(f"Local path {local_path} is not a directory")
return RepoDownloadProgress(
repo_id=repo_id,
repo_revision="local",
shard=shard,
completed_files=total_files,
total_files=total_files,
downloaded_bytes=Memory.from_bytes(total_bytes),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(total_bytes),
overall_speed=0,
overall_eta=timedelta(0),
status="complete",
file_progress=file_progress,
)
async def download_shard(
shard: ShardMetadata,
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -532,18 +439,10 @@ async def download_shard(
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_meta.model_id)):
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_meta.model_id), shard, local_path
)
logger.info(f"Downloading {shard.model_card.model_id=}")
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
"/", "--"
)
if not skip_download:
@@ -552,13 +451,14 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_meta.model_id), revision, recursive=True
shard.model_card.model_id, revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -592,7 +492,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_meta.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -609,7 +509,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
str(shard.model_meta.model_id),
shard.model_card.model_id,
revision,
file_progress,
all_start_time,
@@ -619,7 +519,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_meta.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -643,7 +543,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
str(shard.model_meta.model_id),
shard.model_card.model_id,
revision,
file.path,
target_dir,
@@ -657,7 +557,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
shard, shard.model_card.model_id, revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -3,8 +3,7 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -19,22 +18,22 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
)
async def build_base_shard(model_id: str) -> ShardMetadata:
model_meta = await get_model_meta(model_id)
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=model_meta.n_layers,
n_layers=model_meta.n_layers,
end_layer=model_card.n_layers,
n_layers=model_card.n_layers,
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_meta=base_shard.model_meta,
model_card=base_shard.model_card,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,
@@ -93,11 +92,11 @@ class CachedShardDownloader(ShardDownloader):
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
if (shard.model_meta.model_id, shard) in self.cache:
return self.cache[(shard.model_meta.model_id, shard)]
if (shard.model_card.model_id, shard) in self.cache:
return self.cache[(shard.model_card.model_id, shard)]
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_meta.model_id, shard)] = target_dir
self.cache[(shard.model_card.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(
@@ -148,7 +147,7 @@ class ResumableShardDownloader(ShardDownloader):
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
async def _status_for_model(
model_id: str,
model_id: ModelId,
) -> tuple[Path, RepoDownloadProgress]:
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)

View File

@@ -5,8 +5,8 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -86,9 +86,8 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,

View File

@@ -1,7 +1,10 @@
import os
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Callable, Protocol, cast
from typing import TYPE_CHECKING, Any, Protocol, cast
import mlx.core as mx
import mlx.nn as nn
@@ -29,6 +32,40 @@ from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If an on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
@@ -46,11 +83,11 @@ class CustomMlxLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
object.__setattr__(self, "_original_layer", original_layer)
dict.__setitem__(self, "_original_layer", original_layer) # pyright: ignore[reportUnknownMemberType]
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
return cast(_LayerCallable, self["_original_layer"])
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -59,7 +96,7 @@ class CustomMlxLayer(nn.Module):
try:
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
original_layer = cast(_LayerCallable, self["_original_layer"])
return getattr(original_layer, name)
@@ -225,9 +262,37 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
return model
def patch_tensor_model[T](model: T) -> T:
"""Patch model's __call__ to ensure distributed ops sync during inference."""
cls = model.__class__
original_call = cls.__call__
call_signature = signature(original_call)
def patched_call(
self: T,
*args: object,
**kwargs: object,
) -> mx.array:
logits: mx.array = original_call(self, *args, **kwargs) # pyright: ignore[reportAny]
cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
"cache", None
)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
cls.__call__ = patched_call
return model
def tensor_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> nn.Module:
all_to_sharded_linear = partial(
shard_linear,
@@ -269,10 +334,10 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard"):
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
try:
model.shard(group) # type: ignore
return model
return patch_tensor_model(model)
except (AttributeError, TypeError, NameError):
pass
@@ -318,11 +383,13 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
return tensor_parallel_sharding_strategy.shard_model(model)
model = tensor_parallel_sharding_strategy.shard_model(
model, timeout_seconds, on_timeout
)
return patch_tensor_model(model)
class TensorParallelShardingStrategy(ABC):
@@ -342,13 +409,27 @@ class TensorParallelShardingStrategy(ABC):
self.N = group.size()
@abstractmethod
def shard_model(self, model: nn.Module) -> nn.Module: ...
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module: ...
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(LlamaModel, model)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -391,9 +472,17 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -445,9 +534,17 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -468,15 +565,23 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Qwen3MoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -520,10 +625,18 @@ class ShardedQwenMoE(CustomMlxLayer):
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -547,7 +660,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model

View File

@@ -2,9 +2,7 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -25,6 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
@@ -59,6 +58,8 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
eval_with_timeout,
pipeline_auto_parallel,
tensor_auto_parallel,
)
@@ -75,7 +76,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
/ model_shard_meta.n_layers
* model_shard_meta.model_meta.storage_size.in_kb
* model_shard_meta.model_card.storage_size.in_kb
/ (
1
if isinstance(model_shard_meta, PipelineShardMetadata)
@@ -88,41 +89,6 @@ class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -204,19 +170,14 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
case _:
raise ValueError(
f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
)
logger.info(f"Rank {rank} mlx distributed initialization complete")
return group
@@ -246,7 +207,7 @@ def load_mlx_items(
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
@@ -274,7 +235,7 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
@@ -301,14 +262,6 @@ def shard_and_load(
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
@@ -317,7 +270,15 @@ def shard_and_load(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)
@@ -333,10 +294,10 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
@@ -352,12 +313,17 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-4.7-flash" in model_id_lower:
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
return None
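# For example, with the mapping above (repo ids taken from the model cards in this diff):
#   get_eos_token_ids_for_model(ModelId("mlx-community/GLM-4.7-Flash-4bit")) -> [154820, 154827, 154829]
#   get_eos_token_ids_for_model(ModelId("mlx-community/GLM-4.5-Air-8bit"))   -> [151336, 151329, 151338]
# Unrecognized model ids fall through to None.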
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.

View File

@@ -8,6 +8,7 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
@@ -22,7 +23,6 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -186,11 +186,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_meta.model_id not in self.download_status:
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -205,7 +205,7 @@ class Worker:
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -339,7 +339,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -356,7 +356,7 @@ class Worker:
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -376,7 +376,7 @@ class Worker:
progress
),
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -413,11 +413,6 @@ class Worker:
)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
@@ -438,6 +433,9 @@ class Worker:
for conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn.edge, SocketConnection):
continue
# ignore mDNS discovered connections
if conn.edge.sink_multiaddr.port != 52415:
continue
if (
conn.sink not in conns
or conn.edge.sink_multiaddr.ip_address
@@ -451,7 +449,7 @@ class Worker:
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
logger.debug("Fetching and emitting existing download progress...")
async for (
_,
progress,
@@ -478,11 +476,11 @@ class Worker:
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
logger.debug("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -2,8 +2,8 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -21,11 +21,7 @@ from exo.shared.types.worker.downloads import (
DownloadOngoing,
DownloadProgress,
)
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
InstanceId,
)
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
from exo.shared.types.worker.runners import (
RunnerConnected,
RunnerConnecting,
@@ -54,16 +50,6 @@ def plan(
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
) -> Task | None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
# Check plugin tasks first
for plugin in registry.all_plugins():
task = plugin.plan_task(runners, instances)
if task is not None:
return task
# Python short circuiting OR logic should evaluate these sequentially.
return (
_kill_runner(runners, all_runners, instances)
@@ -127,19 +113,8 @@ def _model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
from exo.plugins.registry import PluginRegistry
registry = PluginRegistry.get()
for runner in runners.values():
instance = runner.bound_instance.instance
# Check if any plugin wants to skip download for this instance
plugin = registry.get_plugin_for_instance(instance)
if plugin is not None and plugin.should_skip_download(instance):
continue
model_id = runner.bound_instance.bound_shard.model_meta.model_id
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
@@ -216,7 +191,7 @@ def _load_model(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner

View File

@@ -4,10 +4,7 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import (
BoundInstance,
MlxJacclInstance,
)
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -20,7 +17,6 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
# Set FAST_SYNCH based on env var or JACCL device count
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
@@ -38,26 +34,11 @@ def entrypoint(
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Route based on instance type (plugins or default MLX)
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.plugins.registry import PluginRegistry, discover_plugins
from exo.worker.runner.runner import main
# Discover plugins in subprocess (they aren't inherited from main process)
discover_plugins()
registry = PluginRegistry.get()
instance = bound_instance.instance
# Check if a plugin handles this instance type
plugin = registry.get_plugin_for_instance(instance)
if plugin is not None:
# Delegate to plugin runner
plugin.create_runner(bound_instance, event_sender, task_receiver)
else:
# MLX runner (default)
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver)
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:
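
The entrypoint above keeps the EXO_FAST_SYNCH handling while dropping the plugin routing. A sketch of the three-state override it implements; the auto-detection rule used when the variable is unset is not visible in this hunk, so the device-count check below is an assumption.

```python
import os


def resolve_fast_synch(jaccl_device_count: int) -> bool:
    # Decide whether to set MLX_METAL_FAST_SYNCH:
    #   EXO_FAST_SYNCH=on   -> force-enable
    #   EXO_FAST_SYNCH=off  -> force-disable
    #   unset / other value -> auto-detect; the exact rule is not shown in the
    #                          diff, so the `> 1` check here is an assumption.
    override = os.environ.get("EXO_FAST_SYNCH")
    if override == "on":
        return True
    if override == "off":
        return False
    return jaccl_device_count > 1  # assumed auto-detection rule


if resolve_fast_synch(jaccl_device_count=2):
    os.environ["MLX_METAL_FAST_SYNCH"] = "1"
```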


@@ -213,7 +213,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
@@ -230,7 +230,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_meta.model_id,
model=shard_metadata.model_card.model_id,
text="",
token_id=0,
finish_reason="error",


@@ -1,7 +1,7 @@
from typing import Final
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import TaskId
from exo.shared.types.worker.instances import InstanceId, RunnerId


@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -32,9 +32,8 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,


@@ -11,9 +11,10 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
@@ -81,9 +82,8 @@ def run_gpt_oss_pipeline_device(
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
@@ -151,9 +151,8 @@ def run_gpt_oss_tensor_parallel_device(
# For tensor parallelism, all devices run all layers
shard_meta = TensorShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
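
Both harness functions above carve GPT-OSS 20B across devices, one by pipeline stage and one by tensor sharding. A rough sketch of how the pipeline `layer_splits` might be derived; only the `layer_splits[rank]` lookup appears in the diff, so the even-split rule here is an assumption.

```python
def pipeline_layer_splits(n_layers: int, world_size: int) -> list[tuple[int, int]]:
    # Contiguous [start, end) layer ranges per pipeline rank. An even split with
    # the remainder assigned to the last rank is assumed.
    base = n_layers // world_size
    return [
        (rank * base, n_layers if rank == world_size - 1 else (rank + 1) * base)
        for rank in range(world_size)
    ]


# GPT-OSS 20B is described by the ModelCard above as having 24 layers:
print(pipeline_layer_splits(24, 2))  # [(0, 12), (12, 24)]

# Tensor parallelism takes the other axis: every device runs all 24 layers and
# the weights inside each layer are sharded instead, which is why the
# tensor-parallel harness builds a TensorShardMetadata covering the full layer
# range on every rank.
```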


@@ -18,6 +18,7 @@ def _check_model_exists() -> bool:
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",


@@ -11,7 +11,7 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
return False
async def download_tokenizer_files(model_id: str) -> Path:
async def download_tokenizer_files(model_id: ModelId) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir = await ensure_models_dir() / model_id.normalize()
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
@@ -72,22 +72,24 @@ async def download_tokenizer_files(model_id: str) -> Path:
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
def get_test_models() -> list[ModelCard]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for short_id, card in MODEL_CARDS.items():
families: dict[str, ModelCard] = {}
for card in MODEL_CARDS.values():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = short_id.split("-")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (short_id, card)
families[family] = card
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
TEST_MODELS: list[ModelCard] = get_test_models()
pytestmark = pytest.mark.slow
@pytest.fixture(scope="module")
@@ -99,14 +101,13 @@ def event_loop():
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
@@ -165,16 +166,15 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -207,19 +207,18 @@ async def test_tokenizer_has_required_attributes(
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -299,16 +298,14 @@ async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) ->
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_card = kimi_models[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -347,17 +344,15 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
glm_model_cards = [
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
]
if not glm_models:
if not glm_model_cards:
pytest.skip("No GLM models found in MODEL_CARDS")
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_card = glm_model_cards[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
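
These test changes lean on two ModelId conveniences: `normalize()` for on-disk directory names and `short()` for grouping cards into families. A minimal sketch of what they plausibly do; the implementations below are assumptions, and only the `/` to `--` substitution is confirmed by the line it replaces.

```python
class ModelId(str):
    # Sketch only: the real ModelId lives in exo.shared.models.model_cards and
    # may be implemented differently.

    def normalize(self) -> str:
        # Filesystem-safe directory name, replacing the old
        # model_id.replace("/", "--") call sites.
        return self.replace("/", "--")

    def short(self) -> str:
        # Short id such as "llama-3.1-8b" (assumed: repo name, lowercased),
        # used by get_test_models() to group cards into families.
        return self.split("/")[-1].lower()


mid = ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit")
print(mid.normalize())  # mlx-community--Llama-3.2-3B-Instruct-4bit
print(mid.short())      # llama-3.2-3b-instruct-4bit
```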


@@ -1,7 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance


@@ -82,7 +82,7 @@ async def tb_detection():
send, recv = channel[GatheredInfo]()
ig = InfoGatherer(send)
with anyio.move_on_after(1):
await ig._monitor_system_profiler() # pyright: ignore[reportPrivateUsage]
await ig._monitor_system_profiler_thunderbolt_data() # pyright: ignore[reportPrivateUsage]
with recv:
return recv.collect()
@@ -135,7 +135,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
meta = MODEL_CARDS[test.model_id].metadata
card = MODEL_CARDS[test.model_id]
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
@@ -145,15 +145,15 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
model_meta=meta,
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=(meta.n_layers // world_size) * i,
start_layer=(card.n_layers // world_size) * i,
end_layer=min(
meta.n_layers, (meta.n_layers // world_size) * (i + 1)
card.n_layers, (card.n_layers // world_size) * (i + 1)
),
n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
- (meta.n_layers // world_size) * i,
n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
- (card.n_layers // world_size) * i,
)
for i in range(world_size)
},
@@ -224,7 +224,7 @@ async def jaccl_backend(test: Tests):
def jaccl_instance(test: Tests, iid: InstanceId):
meta = MODEL_CARDS[test.model_id].metadata
card = MODEL_CARDS[test.model_id]
world_size = len(test.devs)
return MlxJacclInstance(
@@ -239,12 +239,12 @@ def jaccl_instance(test: Tests, iid: InstanceId):
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_meta=meta,
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=meta.n_layers,
end_layer=meta.n_layers,
n_layers=meta.n_layers,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
},
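
The ring fixture derives each rank's slice directly from `card.n_layers`. A worked version of that arithmetic, using the same expressions as the fixture; the 32-layer, four-device example is illustrative.

```python
def ring_shard_ranges(n_layers: int, world_size: int) -> list[tuple[int, int, int]]:
    # (start_layer, end_layer, n_layers) per rank, matching the
    # PipelineShardMetadata construction in ring_instance() above.
    per_rank = n_layers // world_size
    ranges = []
    for i in range(world_size):
        start = per_rank * i
        end = min(n_layers, per_rank * (i + 1))
        ranges.append((start, end, end - start))
    return ranges


# For example, a 32-layer model split across four devices:
print(ring_shard_ranges(32, 4))
# [(0, 8, 8), (8, 16, 8), (16, 24, 8), (24, 32, 8)]

# The jaccl_instance() fixture uses the tensor-parallel convention instead:
# every rank records start_layer == end_layer == n_layers == card.n_layers,
# since each device holds the full layer stack and only the per-layer weights
# are sharded.
```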

uv.lock (generated, 1508 changed lines): file diff suppressed because it is too large.