Compare commits

..

16 Commits

Author SHA1 Message Date
Alex Cheema
8bf4d1f585 fix: enable MLX CPU inference on x86_64 Linux in Docker
Two issues prevented MLX CPU from working on x86_64 in Docker:

1. Missing BLAS/LAPACK libraries: MLX CPU backend requires libblas-dev,
   liblapack-dev, and liblapacke-dev on Linux. Added to apt-get install.

2. g++ wrapper ordering: The -fpermissive wrapper for GCC 14 was installed
   AFTER uv sync, but MLX may compile extensions during install. Moved
   the wrapper BEFORE uv sync so both build-time and runtime JIT
   compilation benefit from the fix.

MLX publishes manylinux_2_35_x86_64 wheels, so this uses the native
CPU backend — no alternative inference framework needed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 12:25:11 -08:00
Alex Cheema
5e27e4e719 Add multi-model snapshot tests for model diversity
Add e2e snapshot test that exercises 3 different model architectures
to catch model-specific regressions:
- SmolLM2-135M-Instruct (tiny llama, bf16, ~269MB)
- Llama-3.2-1B-Instruct-4bit (small llama, 4bit, ~730MB)
- gemma-2-2b-it-4bit (gemma2 architecture, 4bit, ~1.5GB)

Each model gets its own snapshot file. All use the same prompt
("What is the capital of France?"), seed=42, max_tokens=32.

Also adds model cards for SmolLM2-135M-Instruct and gemma-2-2b-it-4bit
(Llama-3.2-1B-Instruct-4bit already had one).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 12:12:40 -08:00
Alex Cheema
b249757116 feat: add Docker layer caching to e2e CI with buildx + GHA cache
Pre-build the Docker image using docker/build-push-action with GitHub
Actions cache (type=gha). On cache hit, the image loads from cache
instead of rebuilding (~12min → seconds).

Changes:
- CI: set up buildx, build image with --cache-from/--cache-to type=gha
- docker-compose.yml: add image tag (exo-e2e:latest) so compose uses
  the pre-built image instead of rebuilding
- conftest.py: Cluster.build() skips if exo-e2e:latest already exists
  (pre-built in CI), falls back to docker compose build for local dev

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 11:46:56 -08:00
Alex Cheema
5c0b769bf8 feat: make snapshot tests run on x86 Ubuntu CI without GPU
MLX already supports x86 CPU via mlx[cpu] and the Dockerfile has the
GCC workaround for CPU JIT. The only barriers were the 'slow' markers
causing tests to be skipped in CI.

Changes:
- Remove 'slow' marker from all snapshot tests so they run by default
- Make snapshots architecture-aware (snapshots/{arch}/{name}.json) since
  floating-point results differ between x86_64 and arm64
- Store architecture in snapshot metadata
- Increase CI timeout from 30 to 45 minutes for model download + CPU inference
- Update docstrings to remove Apple Silicon requirement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 11:39:56 -08:00
Alex Cheema
702886d147 feat: add snapshot test cases for code gen, reasoning, long output, and edge cases
Expand e2e snapshot coverage beyond the single 'What is 2+2?' test:
- test_snapshot_code_gen.py: code generation prompt (max_tokens=64)
- test_snapshot_reasoning.py: step-by-step math reasoning (max_tokens=64)
- test_snapshot_long_output.py: longer response with max_tokens=128
- test_snapshot_edge.py: single word, special chars, and unicode prompts

All use seed=42 and the shared assert_snapshot() infrastructure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 10:50:28 -08:00
Alex Cheema
2526b7d166 feat: add reusable snapshot regression testing to e2e framework
Add e2e/snapshot.py with assert_snapshot() for deterministic regression
testing. On first run, saves inference output as the expected snapshot.
On subsequent runs, compares against it with unified diff on mismatch.
Set UPDATE_SNAPSHOTS=1 or pass --update-snapshots to regenerate.

Refactor test_inference_snapshot.py to use the shared infrastructure
and drop temperature=0 in favor of seed-only determinism.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 10:40:54 -08:00
Alex Cheema
ffb79d88ca fix: add root conftest.py to exclude start_distributed_test from pytest collection
The tests/start_distributed_test.py script calls sys.exit() at module
level, which crashes pytest collection. Exclude it via collect_ignore.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 10:27:21 -08:00
Alex Cheema
4f32b9f180 Merge remote-tracking branch 'origin/main' into e2e-tests 2026-02-13 10:26:21 -08:00
Alex Cheema
e8203596ab fix: ruff lint and formatting for e2e test files
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 16:03:52 -08:00
Alex Cheema
b88749a6c5 Merge remote-tracking branch 'origin/main' into e2e-tests 2026-02-12 15:58:04 -08:00
Alex Cheema
4a446b2779 fix: skip slow inference test in CI, run with --slow
MLX CPU inference on x86_64 is too slow for CI runners (~10min+ for
a single request). Mark the inference snapshot test as slow so it's
skipped by default. Run with --slow or E2E_SLOW=1 on Apple Silicon.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 11:33:20 -08:00
Alex Cheema
a82feed8e3 feat: add deterministic inference snapshot test
Launch mlx-community/Qwen3-0.6B-4bit on the cluster, send a chat
completion with seed=42 and temperature=0, and verify the output
matches a committed snapshot. Tests inference determinism end-to-end.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 10:58:54 -08:00
Alex Cheema
da6e626f6f fix: make no_internet test actually block internet with iptables
Use iptables to block all outbound traffic except private subnets and
multicast (for mDNS discovery). Verify internet is blocked by curling
huggingface.co from inside each container and checking exo logs for
"Internet connectivity: False".

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 10:19:47 -08:00
Alex Cheema
cf23916b8b fix: reduce Docker image size and free more CI disk space
Clean up Rust target/ and cargo registry after uv sync in the same RUN
command so build artifacts aren't committed to the layer (~1-2 GB saved).
Also remove more unused toolchains from the CI runner.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 09:52:48 -08:00
Alex Cheema
80b29ba0d9 fix: free disk space in CI before Docker build
The runner was running out of disk space during the Docker image build
(Rust compilation + Python deps). Remove unused toolchains first.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 09:32:44 -08:00
Alex Cheema
b6214c297f feat: add Docker-based E2E test framework
Add a Python/asyncio E2E test framework that spins up 2-node exo clusters
in Docker Compose and verifies cluster formation, discovery, election, and
API health. Includes a no-internet chaos test using DNS blocking.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 09:16:57 -08:00
33 changed files with 1033 additions and 1083 deletions

15
.dockerignore Normal file
View File

@@ -0,0 +1,15 @@
.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/

42
.github/workflows/e2e.yml vendored Normal file
View File

@@ -0,0 +1,42 @@
name: e2e-tests
on:
push:
pull_request:
branches:
- staging
- main
jobs:
e2e:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Free up disk space
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
/opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
/opt/microsoft /opt/az
docker system prune -af
df -h /
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build E2E image with cache
uses: docker/build-push-action@v6
with:
context: .
file: e2e/Dockerfile
tags: exo-e2e:latest
load: true
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Run E2E tests
run: python3 e2e/run_all.py

View File

@@ -194,40 +194,3 @@ GitHub's API doesn't support direct image upload for PR comments. Workaround:
git push origin <branch>
```
The images still render in the PR comment because they reference the permanent commit SHA.
## Running exo Remotely via SSH (macOS mDNS)
**CRITICAL: On macOS, mDNS multicast (used for peer discovery) only works when the process runs in a proper macOS user session.** Background processes started via `nohup ... &`, `screen`, or plain SSH commands will NOT send mDNS packets and nodes will never discover each other.
### The Problem
When you SSH into a Mac and run `nohup uv run exo &`, the process runs in a detached session without access to macOS multicast networking. The exo node will start but will never discover peers, even if they're on the same network.
### The Solution: Use `open` with a `.command` wrapper
Create a `.command` script that `open` will execute in the proper macOS GUI session context:
```bash
# 1. Create wrapper script on the remote machine
ssh user@remote-mac "cat > /tmp/run_exo.command << 'SCRIPT'
#!/bin/bash
export PATH=/opt/homebrew/bin:\$HOME/.local/bin:\$PATH
export EXO_LIBP2P_NAMESPACE=your-namespace # must match across all nodes
cd ~/path/to/exo
exec uv run exo -vv 2>&1 | tee /tmp/exo.log
SCRIPT
chmod +x /tmp/run_exo.command"
# 2. Launch it via `open` (runs in macOS GUI session with proper mDNS)
ssh user@remote-mac "open /tmp/run_exo.command"
# 3. Check logs
ssh user@remote-mac "tail -f /tmp/exo.log"
```
### Key Details
- **`EXO_LIBP2P_NAMESPACE`**: All nodes in a cluster MUST use the same namespace value. The EXO.app uses a build-specific namespace (check with `ps eww <pid> | grep NAMESPACE`). If mixing dev builds with EXO.app, set the dev build's namespace to match.
- **`open *.command`**: This is the macOS equivalent of double-clicking the script in Finder. It runs in the user's GUI session with full network access.
- **Do NOT use**: `nohup ... &`, `screen -dm`, `tmux new-session -d`, or `sshpass`. These all create detached sessions where mDNS won't work.
- **Killing**: `ssh user@remote-mac "pkill -f 'python.*exo'"` works fine for stopping.
- **Dashboard**: Must be built before running: `cd dashboard && npm install && npm run build && cd ..`. Node.js is at `/opt/homebrew/bin/node` on Apple Silicon Macs.
- **Verifying cluster**: `curl -s http://localhost:52415/state | python3 -c "import json,sys; s=json.load(sys.stdin); print(len(s['topology']['nodes']), 'nodes')"` — should show 2+ nodes.

1
conftest.py Normal file
View File

@@ -0,0 +1 @@
collect_ignore = ["tests/start_distributed_test.py"]

58
e2e/Dockerfile Normal file
View File

@@ -0,0 +1,58 @@
# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build
# Stage 2: Build and run exo
FROM python:3.13-slim
# Install system dependencies
# libblas-dev/liblapack-dev/liblapacke-dev are required by MLX CPU backend on Linux
RUN apt-get update && apt-get install -y \
build-essential \
pkg-config \
libssl-dev \
libblas-dev \
liblapack-dev \
liblapacke-dev \
curl \
protobuf-compiler \
iptables \
&& rm -rf /var/lib/apt/lists/*
# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
# Must be done BEFORE uv sync so any source builds also get the fix
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
chmod +x /usr/bin/g++
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml
# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/
# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
CMD [".venv/bin/exo", "-v"]

195
e2e/conftest.py Normal file
View File

@@ -0,0 +1,195 @@
"""Shared E2E test infrastructure for exo cluster tests."""
import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
class Cluster:
"""Async wrapper around a docker compose exo cluster."""
def __init__(self, name: str, overrides: list[str] | None = None):
self.name = name
self.project = f"e2e-{name}"
compose_files = [str(E2E_DIR / "docker-compose.yml")]
for path in overrides or []:
compose_files.append(str(E2E_DIR / path))
self._compose_base = [
"docker",
"compose",
"-p",
self.project,
*[arg for f in compose_files for arg in ("-f", f)],
]
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
await self.stop()
async def _run(self, *args: str, check: bool = True) -> str:
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
print(output, file=sys.stderr)
raise RuntimeError(
f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
)
return output
async def build(self):
# Skip build if the image was pre-built (e.g. in CI with buildx cache)
proc = await asyncio.create_subprocess_exec(
"docker",
"image",
"inspect",
"exo-e2e:latest",
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await proc.wait()
if proc.returncode == 0:
print(" Using pre-built image (exo-e2e:latest)")
return
print(" Building images...")
await self._run("build", "--quiet")
async def start(self):
print(" Starting cluster...")
await self._run("up", "-d")
async def stop(self):
print(" Cleaning up...")
await self._run("down", "--timeout", "5", check=False)
async def logs(self) -> str:
return await self._run("logs", check=False)
async def exec(
self, service: str, *cmd: str, check: bool = True
) -> tuple[int, str]:
"""Run a command inside a running container. Returns (returncode, output)."""
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
"exec",
"-T",
service,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
raise RuntimeError(
f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
)
return proc.returncode, output
async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
"""Poll check_fn every 2s until it returns True or timeout expires."""
print(f" Waiting for {description}...")
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if await check_fn():
print(f" {description}")
return
await asyncio.sleep(2)
output = await self.logs()
print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
raise TimeoutError(f"Timed out waiting for {description}")
async def assert_healthy(self):
"""Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""
async def both_nodes_started():
log = await self.logs()
return log.count("Starting node") >= 2
async def nodes_discovered():
log = await self.logs()
return log.count("ConnectionMessageType.Connected") >= 2
async def master_elected():
log = await self.logs()
return "demoting self" in log
async def api_responding():
try:
with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
return resp.status == 200
except (URLError, OSError):
return False
await self.wait_for("Both nodes started", both_nodes_started)
await self.wait_for("Nodes discovered each other", nodes_discovered)
await self.wait_for("Master election resolved", master_elected)
await self.wait_for("API responding", api_responding)
async def _api(
self, method: str, path: str, body: dict | None = None, timeout: int = 30
) -> dict:
"""Make an API request to the cluster. Returns parsed JSON."""
url = f"http://localhost:52415{path}"
data = json.dumps(body).encode() if body else None
req = Request(
url, data=data, headers={"Content-Type": "application/json"}, method=method
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda: urlopen(req, timeout=timeout).read()
)
return json.loads(resp_bytes)
async def place_model(self, model: str, timeout: int = 600):
"""Place a model instance on the cluster (triggers download) and wait until it's ready."""
await self._api("POST", "/place_instance", {"model_id": model})
async def model_ready():
try:
resp = await self._api("GET", "/v1/models")
return any(m.get("id") == model for m in resp.get("data", []))
except Exception:
return False
await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)
async def chat(
self, model: str, messages: list[dict], timeout: int = 600, **kwargs
) -> dict:
"""Send a chat completion request. Retries until model is downloaded and inference completes."""
body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
deadline = asyncio.get_event_loop().time() + timeout
last_error = None
while asyncio.get_event_loop().time() < deadline:
try:
req = Request(
"http://localhost:52415/v1/chat/completions",
data=body,
headers={"Content-Type": "application/json"},
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda r=req: urlopen(r, timeout=300).read()
)
return json.loads(resp_bytes)
except Exception as e:
last_error = e
await asyncio.sleep(5)
raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")

20
e2e/docker-compose.yml Normal file
View File

@@ -0,0 +1,20 @@
services:
exo-node-1:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]
ports:
- "52415:52415"
exo-node-2:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]

77
e2e/run_all.py Normal file
View File

@@ -0,0 +1,77 @@
#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.
Tests with '# slow' on the first line of their docstring are skipped
unless --slow is passed or E2E_SLOW=1 is set.
"""
import os
import subprocess
import sys
from pathlib import Path
E2E_DIR = Path(__file__).parent.resolve()
def is_slow(test_file: Path) -> bool:
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
with open(test_file) as f:
for line in f:
if line.strip().startswith("#"):
continue
if line.strip().startswith('"""') or line.strip().startswith("'''"):
# Read into the docstring
for doc_line in f:
if "slow" in doc_line.lower() and doc_line.strip().startswith(
"slow"
):
return True
if '"""' in doc_line or "'''" in doc_line:
break
break
return False
def main():
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
if "--update-snapshots" in sys.argv:
os.environ["UPDATE_SNAPSHOTS"] = "1"
test_files = sorted(E2E_DIR.glob("test_*.py"))
if not test_files:
print("No test files found")
sys.exit(1)
passed = 0
failed = 0
skipped = 0
failures = []
for test_file in test_files:
name = test_file.stem
if is_slow(test_file) and not run_slow:
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
skipped += 1
continue
print(f"=== {name} ===")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
failed += 1
failures.append(name)
print()
total = passed + failed + skipped
print("================================")
print(
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
)
if failed:
print(f"Failed: {' '.join(failures)}")
sys.exit(1)
if __name__ == "__main__":
main()

69
e2e/snapshot.py Normal file
View File

@@ -0,0 +1,69 @@
"""Snapshot testing infrastructure for E2E tests.
Provides deterministic regression testing by comparing inference output
against saved snapshots. On first run, snapshots are created automatically.
Set UPDATE_SNAPSHOTS=1 to regenerate snapshots when output intentionally changes.
Snapshots are stored per-architecture (e.g. snapshots/x86_64/, snapshots/arm64/)
since floating-point results differ between CPU architectures.
"""
import difflib
import json
import os
import platform
from pathlib import Path
ARCH = platform.machine()
SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" / ARCH
def assert_snapshot(
name: str,
content: str,
metadata: dict,
) -> None:
"""Compare content against a saved snapshot, or create one if missing.
Args:
name: Snapshot identifier (used as filename: snapshots/{arch}/{name}.json).
content: The actual inference output to compare.
metadata: Additional context stored alongside content (model, seed, etc.).
Not used for comparison -- purely documentary.
Raises:
AssertionError: If content doesn't match the saved snapshot.
Environment:
UPDATE_SNAPSHOTS=1: Overwrite existing snapshot with actual content.
"""
snapshot_file = SNAPSHOTS_DIR / f"{name}.json"
update = os.environ.get("UPDATE_SNAPSHOTS") == "1"
if snapshot_file.exists() and not update:
snapshot = json.loads(snapshot_file.read_text())
expected = snapshot["content"]
if content != expected:
diff = "\n".join(
difflib.unified_diff(
expected.splitlines(),
content.splitlines(),
fromfile=f"expected ({snapshot_file.relative_to(SNAPSHOTS_DIR.parent.parent)})",
tofile="actual",
lineterm="",
)
)
raise AssertionError(
f"Snapshot mismatch for '{name}' on {ARCH}!\n\n"
f"{diff}\n\n"
f"Expected: {expected!r}\n"
f"Actual: {content!r}\n\n"
f"To update: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py"
)
print(f" Output matches snapshot ({ARCH}/{snapshot_file.name})")
else:
SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
snapshot_data = {**metadata, "arch": ARCH, "content": content}
snapshot_file.write_text(json.dumps(snapshot_data, indent=2) + "\n")
action = "Updated" if update else "Created"
print(f" {action} snapshot: {ARCH}/{snapshot_file.name}")

View File

@@ -0,0 +1,22 @@
"""Test: Basic cluster formation.
Verifies two nodes discover each other, elect a master, and the API responds.
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster("cluster_formation") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print("PASSED: cluster_formation")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,60 @@
"""Test: Deterministic inference output (snapshot test).
Sends a chat completion request with a fixed seed,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.
Uses MLX CPU backend in Docker on x86 Linux.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32
async def main():
async with Cluster("inference_snapshot") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="inference_snapshot",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: inference_snapshot")
if __name__ == "__main__":
asyncio.run(main())

47
e2e/test_no_internet.py Normal file
View File

@@ -0,0 +1,47 @@
"""Test: Cluster works without internet access.
Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""
import asyncio
import sys
sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster(
"no_internet",
overrides=["tests/no_internet/docker-compose.override.yml"],
) as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Verify internet is actually blocked from inside the containers
for node in ["exo-node-1", "exo-node-2"]:
rc, _ = await cluster.exec(
node,
"curl",
"-sf",
"--max-time",
"3",
"https://huggingface.co",
check=False,
)
assert rc != 0, f"{node} should not be able to reach the internet"
print(f" {node}: internet correctly blocked")
# Verify exo detected no internet connectivity
log = await cluster.logs()
assert "Internet connectivity: False" in log, "exo should detect no internet"
print(" exo correctly detected no internet connectivity")
print("PASSED: no_internet")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,58 @@
"""Test: Code generation snapshot.
Verifies deterministic output for a code generation prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = (
"Write a Python function to reverse a string. Only output the code, no explanation."
)
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_code_gen") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_code_gen",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_code_gen")
if __name__ == "__main__":
asyncio.run(main())

63
e2e/test_snapshot_edge.py Normal file
View File

@@ -0,0 +1,63 @@
"""Test: Edge case snapshots.
Verifies deterministic output for edge-case prompts: single word input,
special characters, and unicode.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
MAX_TOKENS = 32
CASES = [
("edge_single_word", "Hi"),
("edge_special_chars", "What does 2 * (3 + 4) / 7 - 1 equal? Use <math> tags."),
("edge_unicode", "Translate 'hello' to Japanese, Chinese, and Korean."),
]
async def main():
async with Cluster("snapshot_edge") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
for snapshot_name, prompt in CASES:
print(f" [{snapshot_name}] Sending: {prompt!r}")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": prompt}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{snapshot_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": prompt,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_edge")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,56 @@
"""Test: Longer output snapshot.
Verifies deterministic output with a higher max_tokens (128).
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "Explain how a binary search algorithm works."
MAX_TOKENS = 128
async def main():
async with Cluster("snapshot_long_output") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED}, max_tokens={MAX_TOKENS})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_long_output",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_long_output")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,72 @@
"""Test: Multi-model snapshot tests.
slow
Verifies deterministic output across different model architectures to catch
model-specific regressions. Each model uses its own snapshot file.
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
SEED = 42
PROMPT = "What is the capital of France?"
MAX_TOKENS = 32
MODELS = [
"mlx-community/SmolLM2-135M-Instruct",
"mlx-community/Llama-3.2-1B-Instruct-4bit",
"mlx-community/gemma-2-2b-it-4bit",
]
async def main():
async with Cluster("snapshot_multi_model") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
for model in MODELS:
short_name = (
model.split("/")[-1].lower().replace("-", "_").replace(".", "_")
)
snapshot_name = f"snapshot_multi_{short_name}"
print(f" Launching model {model}...")
await cluster.place_model(model)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=model,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{short_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": model,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print(f" [{short_name}] PASSED")
print("PASSED: snapshot_multi_model")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,56 @@
"""Test: Reasoning/math snapshot.
Verifies deterministic output for a simple reasoning prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "If I have 3 apples and give away 1, how many do I have? Think step by step."
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_reasoning") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_reasoning",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_reasoning")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,32 @@
# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
exo-node-1:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v
exo-node-2:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v

View File

@@ -132,7 +132,7 @@ markers = [
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/SmolLM2-135M-Instruct"
n_layers = 30
hidden_size = 576
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "SmolLM2 135M"
capabilities = ["text"]
[storage_size]
in_bytes = 269060381

View File

@@ -0,0 +1,12 @@
model_id = "mlx-community/gemma-2-2b-it-4bit"
n_layers = 26
hidden_size = 2304
supports_tensor = false
tasks = ["TextGeneration"]
family = "gemma2"
quantization = "4bit"
base_model = "Gemma 2 2B"
capabilities = ["text"]
[storage_size]
in_bytes = 1492755242

View File

@@ -73,8 +73,6 @@ from exo.shared.types.api import (
CreateInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
DistributeModelParams,
DistributeModelResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
@@ -119,7 +117,6 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteDownload,
DeleteInstance,
DistributeModel,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
@@ -145,7 +142,6 @@ from exo.shared.types.openai_responses import (
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -302,7 +298,6 @@ class API:
self.app.get("/events")(self.stream_events)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
self.app.post("/v1/models/{model_id:path}/distribute")(self.distribute_model)
self.app.get("/v1/traces")(self.list_traces)
self.app.get("/v1/traces/{task_id}")(self.get_trace)
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
@@ -1482,57 +1477,6 @@ class API:
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)
async def distribute_model(
self, model_id: ModelId, payload: DistributeModelParams
) -> DistributeModelResponse:
"""Distribute model files from one node to others via MLX distributed."""
# Find a source node that has the model downloaded
source_node_id: NodeId | None = None
for nid, downloads in self.state.downloads.items():
for dp in downloads:
if (
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == model_id
):
source_node_id = nid
break
if source_node_id is not None:
break
if source_node_id is None:
raise HTTPException(
status_code=404,
detail=f"No node has model {model_id} downloaded",
)
# Determine target nodes
if payload.target_node_ids is not None:
target_node_ids = [
nid for nid in payload.target_node_ids if nid != source_node_id
]
else:
target_node_ids = [
nid for nid in self.state.topology.list_nodes() if nid != source_node_id
]
if not target_node_ids:
raise HTTPException(
status_code=400,
detail="No target nodes to distribute to",
)
command = DistributeModel(
model_id=model_id,
source_node_id=source_node_id,
target_node_ids=target_node_ids,
)
await self._send(command)
return DistributeModelResponse(
command_id=command.command_id,
message=f"Distributing {model_id} from {source_node_id} to {len(target_node_ids)} node(s)",
)
def _get_trace_path(self, task_id: str) -> Path:
return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"

View File

@@ -17,7 +17,6 @@ from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
DistributeModel,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
@@ -313,37 +312,6 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case DistributeModel():
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
model_card = await ModelCard.load(command.model_id)
all_node_ids = set(
[command.source_node_id] + list(command.target_node_ids)
)
place_command = PlaceInstance(
model_card=model_card,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=len(all_node_ids),
)
placement = place_instance(
place_command,
self.state.topology,
self.state.instances,
self.state.node_memory,
self.state.node_network,
required_nodes=all_node_ids,
)
# Mark new instances as transfer-only
for instance_id, instance in placement.items():
if instance_id not in self.state.instances:
instance.shard_assignments.transfer_only = True
transition_events = get_transition_events(
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(

View File

@@ -374,15 +374,6 @@ class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId
class DistributeModelParams(CamelCaseModel):
target_node_ids: list[NodeId] | None = None # None = all connected nodes
class DistributeModelResponse(CamelCaseModel):
command_id: CommandId
message: str
class TraceEventResponse(CamelCaseModel):
name: str
start_us: int

View File

@@ -77,14 +77,6 @@ class CancelDownload(BaseCommand):
model_id: ModelId
class DistributeModel(BaseCommand):
"""Distribute model files from one node to others via MLX distributed."""
model_id: ModelId
source_node_id: NodeId
target_node_ids: list[NodeId]
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
@@ -99,7 +91,6 @@ Command = (
| DeleteInstance
| TaskFinished
| SendInputChunk
| DistributeModel
)

View File

@@ -41,7 +41,7 @@ class DownloadModel(BaseTask): # emitted by Worker
class LoadModel(BaseTask): # emitted by Worker
has_local_model: bool = Field(default=True)
pass
class ConnectToGroup(BaseTask): # emitted by Worker
@@ -76,13 +76,6 @@ class ImageEdits(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class TransferModelToDisk(BaseTask): # emitted by Worker
"""Transfer all model files from source to receivers' disk via MLX distributed."""
shard_metadata: ShardMetadata
has_local_model: bool = Field(default=True)
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -92,7 +85,6 @@ Task = (
| DownloadModel
| ConnectToGroup
| LoadModel
| TransferModelToDisk
| StartWarmup
| TextGeneration
| ImageGeneration

View File

@@ -84,7 +84,6 @@ class ShardAssignments(CamelCaseModel):
model_id: ModelId
runner_to_shard: Mapping[RunnerId, ShardMetadata]
node_to_runner: Mapping[NodeId, RunnerId]
transfer_only: bool = False
@model_validator(mode="after")
def validate_runners_exist(self) -> "ShardAssignments":

View File

@@ -47,7 +47,6 @@ if TYPE_CHECKING:
from mlx_lm.models.cache import Cache
TimeoutCallback = Callable[[], None]
WeightLoader = Callable[[nn.Module, int], None] | None
def eval_with_timeout(
@@ -347,7 +346,6 @@ def tensor_auto_parallel(
group: mx.distributed.Group,
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
weight_loader: WeightLoader = None,
) -> nn.Module:
all_to_sharded_linear = partial(
shard_linear,
@@ -457,7 +455,7 @@ def tensor_auto_parallel(
raise ValueError(f"Unsupported model type: {type(model)}")
model = tensor_parallel_sharding_strategy.shard_model(
model, timeout_seconds, on_timeout, weight_loader
model, timeout_seconds, on_timeout
)
return patch_tensor_model(model)
@@ -484,7 +482,6 @@ class TensorParallelShardingStrategy(ABC):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module: ...
@@ -494,12 +491,9 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(LlamaModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -551,12 +545,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -629,12 +620,9 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(GLM4MoeLiteModel, model)
for i, layer in enumerate(model.layers): # type: ignore
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers: # type: ignore
layer = cast(Glm4MoeLiteDecoderLayer, layer)
eval_with_timeout(
layer.parameters(),
@@ -774,12 +762,9 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -817,12 +802,9 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -944,12 +926,9 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(Glm4MoeModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -993,13 +972,10 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(GptOssMoeModel, model)
for i, layer in enumerate(model.layers):
if weight_loader is not None:
weight_loader(model, i)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
@@ -1037,7 +1013,6 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy):
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
weight_loader: WeightLoader = None,
) -> nn.Module:
model = cast(Step35Model, model)

View File

@@ -1,507 +0,0 @@
"""
Model transfer via MLX distributed all_sum.
Three transfer modes:
1. Metadata file transfer: broadcast small files (config.json, tokenizer, etc.) to disk
2. Weight tensor broadcast: stream weight tensors directly into memory via all_sum
3. Full file transfer: broadcast all files (including safetensors) to disk
All functions are collective operations — every rank in the group must call them.
Protocol relies on all_sum: source has real data, receivers have zeros.
all_sum(source + zeros) = source data on all ranks.
"""
from __future__ import annotations
import json
import os
import re
import shutil
import tempfile
from functools import partial
from pathlib import Path
from typing import Any, Final, cast
import mlx.core as mx
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
CHUNK_SIZE: Final[int] = 100 * 1024 * 1024 # 100 MB
_LAYER_RE: Final[re.Pattern[str]] = re.compile(r"(?:^|\.)(layers|h)\.(\d+)\.")
def _all_sum_cpu(x: mx.array, group: Group) -> mx.array:
"""all_sum on CPU stream to avoid GPU memory pressure."""
return mx.distributed.all_sum(
x, stream=mx.default_stream(mx.Device(mx.cpu)), group=group
)
def _is_metadata_file(filename: str) -> bool:
"""A metadata file is anything that isn't a weight file or weight index.
Weight indices (.safetensors.index.json) reference safetensors shard paths.
Transferring them to a receiver that has no safetensors files is harmless
today (load_model's glob doesn't match them), but excluding them avoids
stale references and keeps the transfer minimal.
"""
if filename.endswith(".safetensors"):
return False
return not filename.endswith(".safetensors.index.json")
def model_path_for_id(model_id: ModelId) -> Path:
"""Get model path without requiring directory to exist (unlike build_model_path)."""
return EXO_MODELS_DIR / model_id.normalize()
def coordinate_transfer(group: Group, has_local_model: bool) -> tuple[bool, int]:
"""
Determine if a transfer is needed and which rank is the source.
All ranks must call this function (uses collective all_sum).
Returns:
(needs_transfer, source_rank) — source_rank is the lowest rank
that has the model. needs_transfer is True if any rank is missing it.
"""
all_sum = partial(_all_sum_cpu, group=group)
world_size = group.size()
# Each rank broadcasts a one-hot vector at its position if it has the model
bitmask = mx.zeros(world_size, dtype=mx.int32)
if has_local_model:
bitmask = bitmask.at[group.rank()].add(1)
summed = all_sum(bitmask)
mx.eval(summed)
has_model_flags: list[int] = summed.tolist() # type: ignore[assignment]
total_have = sum(has_model_flags)
if total_have == 0:
raise RuntimeError(
"No rank has the model files — cannot transfer. "
"At least one node must have downloaded the model."
)
if total_have == world_size:
logger.info("All ranks have model files, no transfer needed")
return False, 0
source_rank = next(i for i, flag in enumerate(has_model_flags) if flag > 0)
logger.info(
f"Transfer needed: source_rank={source_rank}, "
f"{total_have}/{world_size} ranks have model"
)
return True, source_rank
def _broadcast_json(obj: object, group: Group, is_source: bool) -> object:
"""Broadcast a JSON-serializable object from source to all ranks."""
all_sum = partial(_all_sum_cpu, group=group)
data = json.dumps(obj, separators=(",", ":")).encode("utf-8") if is_source else b""
# Broadcast length
len_arr = mx.array([len(data) if is_source else 0], dtype=mx.int64)
len_result = all_sum(len_arr)
mx.eval(len_result)
length = int(len_result.item())
if length == 0:
return None
# Broadcast payload
if is_source:
arr = mx.array(list(data), dtype=mx.uint8)
else:
arr = mx.zeros(length, dtype=mx.uint8)
result = all_sum(arr)
mx.eval(result)
return json.loads(bytes(cast(list[int], result.tolist()))) # pyright: ignore[reportAny]
def _build_manifest(
model_path: Path, metadata_only: bool = False
) -> list[dict[str, str | int]]:
"""Build a list of files in the model directory with their relative paths and sizes."""
manifest: list[dict[str, str | int]] = []
for root, _dirs, files in os.walk(model_path):
for fname in sorted(files):
if metadata_only and not _is_metadata_file(fname):
continue
full_path = Path(root) / fname
rel_path = str(full_path.relative_to(model_path))
manifest.append(
{
"path": rel_path,
"size": full_path.stat().st_size,
}
)
return manifest
def _transfer_file_to_disk(
source_path: Path,
rel_path: str,
file_size: int,
group: Group,
is_source: bool,
dest_path: Path,
) -> None:
"""Transfer a single file chunk-by-chunk via all_sum. Source reads from disk, receivers write to dest_path."""
all_sum = partial(_all_sum_cpu, group=group)
if is_source:
src_file = source_path / rel_path
with open(src_file, "rb") as f:
offset = 0
while offset < file_size:
chunk_bytes = min(CHUNK_SIZE, file_size - offset)
data = f.read(chunk_bytes)
if not data:
break
size_arr = mx.array([len(data)], dtype=mx.int64)
mx.eval(all_sum(size_arr))
chunk_arr = mx.array(list(data), dtype=mx.uint8)
result = all_sum(chunk_arr)
mx.eval(result)
offset += len(data)
# Signal end of file
mx.eval(all_sum(mx.array([0], dtype=mx.int64)))
else:
dst_file = dest_path / rel_path
os.makedirs(dst_file.parent, exist_ok=True)
with open(dst_file, "wb") as f:
while True:
size_arr = all_sum(mx.zeros(1, dtype=mx.int64))
mx.eval(size_arr)
chunk_size = int(size_arr.item())
if chunk_size == 0:
break
chunk_data = all_sum(mx.zeros(chunk_size, dtype=mx.uint8))
mx.eval(chunk_data)
f.write(bytes(cast(list[int], chunk_data.tolist())))
def _transfer_files_to_disk(
model_path: Path,
group: Group,
is_source: bool,
metadata_only: bool = False,
) -> None:
"""
Transfer files from source to all receivers' disk.
Source broadcasts a manifest then each file. Receivers write to a temp dir
then atomically move files to model_path.
"""
if is_source:
source_manifest = _build_manifest(model_path, metadata_only=metadata_only)
else:
source_manifest = []
manifest = cast(
list[dict[str, str | int]],
_broadcast_json(source_manifest if is_source else None, group, is_source),
)
if not manifest:
logger.info("No files to transfer")
return
logger.info(
f"Transferring {len(manifest)} files ({'metadata only' if metadata_only else 'all'})"
)
temp_dir: Path | None = None
if not is_source:
os.makedirs(model_path.parent, exist_ok=True)
temp_dir = Path(
tempfile.mkdtemp(
dir=model_path.parent,
prefix=f".transfer_{model_path.name}_",
)
)
try:
for entry in manifest:
rel_path = str(entry["path"])
file_size = int(entry["size"])
logger.info(f" {rel_path} ({file_size} bytes)")
_transfer_file_to_disk(
source_path=model_path,
rel_path=rel_path,
file_size=file_size,
group=group,
is_source=is_source,
dest_path=temp_dir if temp_dir is not None else model_path,
)
if temp_dir is not None:
os.makedirs(model_path, exist_ok=True)
for entry in manifest:
rel_path = str(entry["path"])
src = temp_dir / rel_path
dst = model_path / rel_path
os.makedirs(dst.parent, exist_ok=True)
os.replace(src, dst)
logger.info(
f"Transfer complete: {len(manifest)} files moved to {model_path}"
)
finally:
if temp_dir is not None and temp_dir.exists():
shutil.rmtree(temp_dir, ignore_errors=True)
def transfer_metadata_files(model_path: Path, group: Group, is_source: bool) -> None:
"""
Transfer metadata files (config.json, tokenizer files, etc.) to receivers' disk.
All ranks must call this function (collective operation).
Only the designated source (is_source=True) should send; all others receive.
"""
_transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=True)
def transfer_all_files(model_path: Path, group: Group, is_source: bool) -> None:
"""
Transfer ALL model files (including safetensors) to receivers' disk.
All ranks must call this function (collective operation).
Only the designated source (is_source=True) should send; all others receive.
"""
_transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=False)
def _parse_mx_dtype(dtype_str: str) -> mx.Dtype:
"""Convert a dtype string like 'float16' or 'mlx.core.float16' to mx.Dtype."""
name = dtype_str.split(".")[-1]
dtype = getattr(mx, name, None)
if dtype is None:
raise ValueError(f"Unknown MLX dtype: {dtype_str}")
return dtype # type: ignore[return-value]
def _extract_layer_index(name: str) -> int | None:
"""Extract layer index from a weight name, or None for non-layer weights.
Matches patterns like ``model.layers.5.self_attn.q_proj.weight``
or ``transformer.h.12.mlp.gate_proj.scales``.
"""
m = _LAYER_RE.search(name)
return int(m.group(2)) if m else None
class WeightBroadcastState:
"""Holds state for layer-by-layer weight broadcasting.
Created by :func:`prepare_weight_broadcast`. Callers stream weights
incrementally via :meth:`broadcast_non_layer_weights` and
:meth:`broadcast_layer` so that at most one layer's worth of un-sharded
weight data is resident at a time.
"""
def __init__(
self,
meta: dict[str, dict[str, Any]],
source_weights: dict[str, mx.array] | None,
group: Group,
is_source: bool,
) -> None:
self.meta = meta
self.source_weights = source_weights
self.group = group
self.is_source = is_source
# Partition weight names into layer vs. non-layer
self.layer_names: dict[int, list[str]] = {}
self.non_layer_names: list[str] = []
for name in sorted(meta.keys()):
layer_idx = _extract_layer_index(name)
if layer_idx is not None:
self.layer_names.setdefault(layer_idx, []).append(name)
else:
self.non_layer_names.append(name)
logger.info(
f"WeightBroadcastState: {len(self.non_layer_names)} non-layer weights, "
f"{len(self.layer_names)} layers"
)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _broadcast_names(self, names: list[str]) -> dict[str, mx.array]:
"""Broadcast a specific set of weight tensors by name."""
all_sum = partial(_all_sum_cpu, group=self.group)
result: dict[str, mx.array] = {}
for name in names:
info = self.meta[name]
shape = cast(list[int], info["s"])
dtype = _parse_mx_dtype(cast(str, info["d"]))
if self.is_source:
assert self.source_weights is not None
tensor = self.source_weights.pop(name)
mx.eval(tensor) # loads from disk (lazy)
else:
tensor = mx.zeros(shape, dtype=dtype)
broadcasted = all_sum(tensor)
mx.eval(broadcasted)
result[name] = broadcasted
return result
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def broadcast_non_layer_weights(self) -> dict[str, mx.array]:
"""Broadcast non-layer weights (embeddings, norms, lm_head)."""
if not self.non_layer_names:
return {}
logger.info(
f"Broadcasting {len(self.non_layer_names)} non-layer weight tensors"
)
return self._broadcast_names(self.non_layer_names)
def broadcast_layer(self, layer_idx: int) -> dict[str, mx.array]:
"""Broadcast weights for a single transformer layer."""
names = self.layer_names.get(layer_idx, [])
if not names:
return {}
return self._broadcast_names(names)
def prepare_weight_broadcast(
model_path: Path,
group: Group,
is_source: bool,
) -> WeightBroadcastState:
"""Prepare for layer-by-layer weight broadcasting.
Source loads safetensors lazily and broadcasts weight metadata (names,
shapes, dtypes) as JSON. Returns a :class:`WeightBroadcastState` that
can then stream weights incrementally via ``broadcast_layer()``.
All ranks must call this function (collective operation).
"""
source_weights: dict[str, mx.array] | None = None
if is_source:
source_weights = {}
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
weight_files = sorted(model_path.glob("**/*.safetensors"))
for wf in weight_files:
try:
loaded = cast(
dict[str, mx.array],
mx.load(str(wf), lazy=True), # pyright: ignore[reportCallIssue]
)
except TypeError:
loaded = cast(dict[str, mx.array], mx.load(str(wf)))
source_weights.update(loaded)
logger.info(
f"Source loaded {len(source_weights)} weight tensors (lazy) "
f"from {len(weight_files)} files"
)
# Broadcast metadata
if is_source and source_weights is not None:
source_meta: dict[str, dict[str, Any]] = {
name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
for name, tensor in source_weights.items()
}
else:
source_meta = {}
meta = cast(
dict[str, dict[str, Any]],
_broadcast_json(source_meta if is_source else None, group, is_source),
)
logger.info(f"Weight broadcast prepared: {len(meta)} tensors")
return WeightBroadcastState(meta, source_weights, group, is_source)
def broadcast_model_weights(
model_path: Path,
group: Group,
is_source: bool,
) -> dict[str, mx.array]:
"""
Broadcast model weight tensors from source rank to all receivers' memory.
Source loads weights from .safetensors files on disk and broadcasts each
tensor via all_sum. Receivers receive tensors directly as mx.arrays in
memory — no disk write for weight data.
All ranks must call this function (collective operation).
Only the designated source (is_source=True) should send; all others receive.
Returns:
dict mapping weight names to mx.arrays (on all ranks).
"""
all_sum = partial(_all_sum_cpu, group=group)
# Source loads weights (lazy if supported, so only one tensor in memory at a time)
weights: dict[str, mx.array] = {}
if is_source:
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
weight_files = sorted(model_path.glob("**/*.safetensors"))
for wf in weight_files:
try:
loaded = cast(dict[str, mx.array], mx.load(str(wf), lazy=True)) # pyright: ignore[reportCallIssue]
except TypeError:
loaded = cast(dict[str, mx.array], mx.load(str(wf)))
weights.update(loaded)
logger.info(
f"Source loaded {len(weights)} weight tensors from {len(weight_files)} files"
)
# Broadcast weight metadata: {name: {shape, dtype}}
if is_source:
source_meta: dict[str, dict[str, Any]] = {
name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
for name, tensor in weights.items()
}
else:
source_meta = {}
meta = cast(
dict[str, dict[str, Any]],
_broadcast_json(source_meta if is_source else None, group, is_source),
)
logger.info(f"Broadcasting {len(meta)} weight tensors")
# Broadcast each tensor in sorted order (deterministic across ranks).
# Source loads one tensor at a time from disk (lazy), broadcasts it,
# then drops the reference so only one tensor is in flight at a time.
result: dict[str, mx.array] = {}
for i, name in enumerate(sorted(meta.keys())):
info = meta[name]
shape = cast(list[int], info["s"])
dtype_str = cast(str, info["d"])
dtype = _parse_mx_dtype(dtype_str)
if is_source:
tensor = weights.pop(name) # pop to free lazy ref after broadcast
mx.eval(tensor) # loads from disk
else:
tensor = mx.zeros(shape, dtype=dtype)
broadcasted = all_sum(tensor)
mx.eval(broadcasted)
result[name] = broadcasted
if (i + 1) % 100 == 0:
logger.info(f" Broadcast {i + 1}/{len(meta)} tensors")
logger.info(f"Weight broadcast complete: {len(result)} tensors")
return result

View File

@@ -2,7 +2,6 @@ import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -60,13 +59,6 @@ from exo.worker.engines.mlx.auto_parallel import (
pipeline_auto_parallel,
tensor_auto_parallel,
)
from exo.worker.engines.mlx.model_transfer import (
WeightBroadcastState,
coordinate_transfer,
model_path_for_id,
prepare_weight_broadcast,
transfer_metadata_files,
)
from exo.worker.runner.bootstrap import logger
Group = mx.distributed.Group
@@ -205,7 +197,6 @@ def load_mlx_items(
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
has_local_model: bool = True,
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
@@ -220,10 +211,7 @@ def load_mlx_items(
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard,
group=group,
on_timeout=on_timeout,
has_local_model=has_local_model,
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
@@ -239,89 +227,30 @@ def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
has_local_model: bool = True,
) -> tuple[nn.Module, TokenizerWrapper]:
model_id = shard_metadata.model_card.model_id
model_path = model_path_for_id(model_id)
model_path = build_model_path(shard_metadata.model_card.model_id)
# Coordinate: does any rank need a transfer?
needs_transfer, source_rank = coordinate_transfer(group, has_local_model)
is_source = group.rank() == source_rank
# Step 1: Always ensure all nodes have metadata files (config, tokenizer, etc.).
# This is cheap (~20MB, ~1s) and guarantees config.json is present for load_model().
transfer_metadata_files(model_path, group, is_source)
# Step 2: Only broadcast weights if some rank is missing the model
broadcast_state: WeightBroadcastState | None = None
if needs_transfer:
logger.info(
f"Model transfer needed (source_rank={source_rank}, "
f"is_source={is_source}, local_weights={has_local_model})"
)
broadcast_state = prepare_weight_broadcast(model_path, group, is_source)
# Create model architecture (all ranks have config.json on disk now).
# Always use lazy=True when we have broadcast state: load_model's internal
# nn.quantize skips quantization when weights dict is empty (no safetensors),
# leaving the model un-quantized. lazy=False would then mx.eval() the full
# fp16 model (~72GB for a 36B-param model), causing OOM on the receiver.
# We handle quantization ourselves below before loading broadcast weights.
use_lazy = has_local_model or broadcast_state is not None
model, _ = load_model(model_path, lazy=use_lazy, strict=False)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
# TODO: See if we should quantize the model.
# def is_attention_layer(path: str) -> bool:
# path = path.lower()
# return "self_attn" in path and "layernorm" not in path
# def quant_predicate(path: str, module: nn.Module):
# if not isinstance(module, nn.Linear):
# return False
# return is_attention_layer(path)
# model, config = quantize_model(
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
# )
assert isinstance(model, nn.Module)
if broadcast_state is not None:
# When receiver has no weight files, load_model skips quantization
# (its class_predicate checks `f"{p}.scales" in weights`, which is
# always False when weights is empty). Apply quantization explicitly
# using the broadcast metadata to determine which layers are quantized,
# matching load_model's selective quantization logic exactly.
if not has_local_model:
config_path = model_path / "config.json"
with open(config_path) as f:
config = json.load(f) # pyright: ignore[reportAny]
quant_config: dict[str, Any] | None = config.get( # pyright: ignore[reportAny]
"quantization", None
)
if quant_config is not None:
logger.info(f"Applying quantization to receiver model: {quant_config}")
broadcast_weight_names = set(broadcast_state.meta.keys())
def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
# Per-layer overrides from config (e.g. "lm_head": false)
assert quant_config is not None
if p in quant_config:
return quant_config[p] # pyright: ignore[reportAny]
if not hasattr(m, "to_quantized"):
return False
# Only quantize layers whose .scales exist in broadcast weights
return f"{p}.scales" in broadcast_weight_names
group_size = int(quant_config.get("group_size", 64)) # pyright: ignore[reportAny]
bits = int(quant_config.get("bits", 4)) # pyright: ignore[reportAny]
mode: str = quant_config.get("mode", "affine") # pyright: ignore[reportAny]
nn.quantize( # pyright: ignore[reportUnknownMemberType]
model,
group_size=group_size,
bits=bits,
mode=mode,
class_predicate=_class_predicate,
)
# Broadcast and load non-layer weights (embeddings, norms, lm_head) upfront.
# These are small (~600MB) and needed before the sharding loop.
non_layer_weights = broadcast_state.broadcast_non_layer_weights()
if non_layer_weights:
model.load_weights(list(non_layer_weights.items()), strict=False)
logger.info(f"Loaded {len(non_layer_weights)} non-layer weight tensors")
del non_layer_weights
tokenizer = get_tokenizer(model_path, shard_metadata)
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
@@ -335,43 +264,12 @@ def shard_and_load(
f"(model size: {model_size_gb:.1f}GB)"
)
# Build per-layer weight loader for streaming broadcast during sharding.
# Each layer's weights are broadcast via all_sum just before that layer is
# sharded, so at most one un-sharded layer is in memory at a time.
weight_loader_fn: Callable[[nn.Module, int], None] | None = None
if broadcast_state is not None:
_state = broadcast_state # capture for closure
def _load_layer_weights(mdl: nn.Module, layer_idx: int) -> None:
layer_weights = _state.broadcast_layer(layer_idx)
if layer_weights:
mdl.load_weights(list(layer_weights.items()), strict=False)
weight_loader_fn = _load_layer_weights
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(
model, group, timeout_seconds, on_timeout, weight_loader_fn
)
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
# Broadcast all layers (all_sum is collective — all ranks must
# participate) but only load weights for layers this node will
# keep after pipeline slicing. Out-of-range results are discarded,
# keeping peak memory proportional to this node's layer count.
if broadcast_state is not None:
for layer_idx in sorted(broadcast_state.layer_names.keys()):
layer_weights = broadcast_state.broadcast_layer(layer_idx)
if (
shard_metadata.start_layer
<= layer_idx
< shard_metadata.end_layer
and layer_weights
):
model.load_weights(list(layer_weights.items()), strict=False)
del layer_weights
model = pipeline_auto_parallel(model, group, shard_metadata)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
case CfgShardMetadata():
@@ -380,8 +278,6 @@ def shard_and_load(
"this metadata type is only for image generation models"
)
del broadcast_state
# TODO: Do we need this?
mx.eval(model)

View File

@@ -2,7 +2,6 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
ConnectToGroup,
@@ -17,7 +16,6 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
TextGeneration,
TransferModelToDisk,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
@@ -36,11 +34,8 @@ from exo.shared.types.worker.runners import (
RunnerLoading,
RunnerReady,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
ShardAssignments,
)
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -62,7 +57,6 @@ def plan(
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _init_distributed_backend(runners, all_runners)
or _transfer_model_to_disk(runners, all_runners, global_download_status)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
@@ -127,10 +121,6 @@ def _model_needs_download(
}
for runner in runners.values():
# Transfer-only instances don't need downloads
if runner.bound_instance.instance.shard_assignments.transfer_only:
continue
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
@@ -139,15 +129,6 @@ def _model_needs_download(
(DownloadOngoing, DownloadCompleted, DownloadFailed),
)
):
# For multi-node instances, skip download if a peer already has the model.
# The model will be transferred via MLX distributed during LoadModel.
instance = runner.bound_instance.instance
is_multi_node = len(instance.shard_assignments.node_to_runner) > 1
if is_multi_node and _any_peer_has_model(
node_id, model_id, instance, global_download_status
):
continue
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
instance_id=runner.bound_instance.instance.instance_id,
@@ -205,43 +186,6 @@ def _init_distributed_backend(
return None
def _transfer_model_to_disk(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> TransferModelToDisk | None:
"""For transfer-only instances: after all ranks are connected, emit TransferModelToDisk."""
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
if not shard_assignments.transfer_only:
continue
is_runner_connected = isinstance(runner.status, RunnerConnected)
all_connected_or_further = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerShuttingDown, RunnerShutdown),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if is_runner_connected and all_connected_or_further:
has_local = _node_has_download(
runner.bound_instance.bound_node_id,
shard_assignments.model_id,
global_download_status,
)
return TransferModelToDisk(
instance_id=instance.instance_id,
shard_metadata=runner.bound_instance.bound_shard,
has_local_model=has_local,
)
return None
def _load_model(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -251,97 +195,38 @@ def _load_model(
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
# Transfer-only instances don't load models for inference
if shard_assignments.transfer_only:
all_local_downloads_complete = all(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner
)
if not all_local_downloads_complete:
continue
is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
if is_single_node_instance:
# Single-node: require local download complete
if not _all_downloads_complete(shard_assignments, global_download_status):
continue
if isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id, has_local_model=True)
else:
# Multi-node: require at least one node to have the model downloaded.
# Nodes without the model will receive it via MLX distributed transfer
# during model loading.
if not _any_download_complete(shard_assignments, global_download_status):
continue
is_runner_waiting = isinstance(runner.status, RunnerConnected)
is_runner_waiting = isinstance(runner.status, RunnerConnected)
all_ready_for_model = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
all_ready_for_model = all(
isinstance(
all_runners.get(global_runner_id, None),
(RunnerConnected, RunnerLoading, RunnerLoaded),
)
for global_runner_id in shard_assignments.runner_to_shard
)
if is_runner_waiting and all_ready_for_model:
has_local = _node_has_download(
runner.bound_instance.bound_node_id,
shard_assignments.model_id,
global_download_status,
)
return LoadModel(
instance_id=instance.instance_id,
has_local_model=has_local,
)
if is_runner_waiting and all_ready_for_model:
return LoadModel(instance_id=instance.instance_id)
return None
def _node_has_download(
nid: NodeId,
model_id: ModelId,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if a specific node has completed downloading the given model."""
return any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_card.model_id == model_id
for dp in global_download_status.get(nid, [])
)
def _any_peer_has_model(
node_id: NodeId,
model_id: ModelId,
instance: Instance,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if any other node in the instance already has the model downloaded."""
return any(
_node_has_download(nid, model_id, global_download_status)
for nid in instance.shard_assignments.node_to_runner
if nid != node_id
)
def _all_downloads_complete(
shard_assignments: ShardAssignments,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if ALL nodes in the instance have completed downloading the model."""
return all(
_node_has_download(nid, shard_assignments.model_id, global_download_status)
for nid in shard_assignments.node_to_runner
)
def _any_download_complete(
shard_assignments: ShardAssignments,
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
"""Check if at least one node in the instance has completed downloading the model."""
return any(
_node_has_download(nid, shard_assignments.model_id, global_download_status)
for nid in shard_assignments.node_to_runner
)
def _ready_to_warmup(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -349,11 +234,6 @@ def _ready_to_warmup(
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
# Transfer-only instances don't go through warmup
if shard_assignments.transfer_only:
continue
shard = runner.bound_instance.bound_shard
device_rank = shard.device_rank
runner_id = runner.bound_instance.bound_runner_id

View File

@@ -43,7 +43,6 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
TextGeneration,
TransferModelToDisk,
)
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import BoundInstance
@@ -83,11 +82,6 @@ from exo.worker.engines.image import (
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.model_transfer import (
coordinate_transfer,
model_path_for_id,
transfer_all_files,
)
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
detect_thinking_prompt_suffix,
@@ -198,10 +192,7 @@ def main(
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
model, tokenizer = load_mlx_items(
bound_instance,
group,
on_timeout=on_model_load_timeout,
has_local_model=task.has_local_model,
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
@@ -517,27 +508,6 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case TransferModelToDisk() if (
isinstance(current_status, RunnerConnected) and group is not None
):
logger.info("starting disk-to-disk model transfer")
event_sender.send(TaskAcknowledged(task_id=task.task_id))
model_path = model_path_for_id(
task.shard_metadata.model_card.model_id
)
_, source_rank = coordinate_transfer(group, task.has_local_model)
is_source = group.rank() == source_rank
transfer_all_files(model_path, group, is_source)
logger.info("disk-to-disk model transfer complete")
current_status = RunnerShuttingDown()
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
current_status = RunnerShutdown()
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")

View File

@@ -112,7 +112,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.has_local_model is True
def test_plan_does_not_request_download_when_shard_already_downloaded():
@@ -158,11 +157,10 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
assert not isinstance(result, plan_mod.DownloadModel)
def test_plan_loads_model_when_any_node_has_download_for_multi_node():
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
"""
For multi-node instances, LoadModel should be emitted when at least one
node has the model downloaded. Nodes without the model will receive it
via MLX distributed transfer during model loading.
LoadModel should not be emitted while some shards are still missing from
the global_download_status.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -187,7 +185,6 @@ def test_plan_loads_model_when_any_node_has_download_for_multi_node():
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A has the model — LoadModel should still fire
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -206,42 +203,19 @@ def test_plan_loads_model_when_any_node_has_download_for_multi_node():
tasks={},
)
assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.has_local_model is True
assert result is None
def test_plan_does_not_load_model_when_no_node_has_download():
"""
LoadModel should not be emitted when no node has the model downloaded.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}
# No node has the model
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [],
NODE_B: [],
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
], # NODE_B has no downloads completed yet
}
result = plan_mod.plan(
@@ -253,57 +227,4 @@ def test_plan_does_not_load_model_when_no_node_has_download():
tasks={},
)
assert result is None
def test_plan_load_model_has_local_model_false_when_node_missing_download():
"""
For multi-node instances, when the local node does NOT have the model
but a peer does, LoadModel should be emitted with has_local_model=False.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
# NODE_B is the local node (bound_node_id=NODE_B), it does NOT have the model
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
)
runners = {RUNNER_2_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A has the model, NODE_B does not
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [],
}
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
assert result.has_local_model is False
assert result is not None