Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-14 16:15:43 -05:00)

Compare commits: 16 commits by alexcheema on branch e2e-tests

| Author | SHA1 | Date |
|---|---|---|
|  | 8bf4d1f585 |  |
|  | 5e27e4e719 |  |
|  | b249757116 |  |
|  | 5c0b769bf8 |  |
|  | 702886d147 |  |
|  | 2526b7d166 |  |
|  | ffb79d88ca |  |
|  | 4f32b9f180 |  |
|  | e8203596ab |  |
|  | b88749a6c5 |  |
|  | 4a446b2779 |  |
|  | a82feed8e3 |  |
|  | da6e626f6f |  |
|  | cf23916b8b |  |
|  | 80b29ba0d9 |  |
|  | b6214c297f |  |
.dockerignore (Normal file, 15 lines)
@@ -0,0 +1,15 @@
.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/
.github/workflows/e2e.yml (vendored, Normal file, 42 lines)
@@ -0,0 +1,42 @@
name: e2e-tests

on:
  push:
  pull_request:
    branches:
      - staging
      - main

jobs:
  e2e:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - name: Free up disk space
        run: |
          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
            /opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
            /opt/microsoft /opt/az
          docker system prune -af
          df -h /

      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Build E2E image with cache
        uses: docker/build-push-action@v6
        with:
          context: .
          file: e2e/Dockerfile
          tags: exo-e2e:latest
          load: true
          cache-from: type=gha
          cache-to: type=gha,mode=max

      - name: Run E2E tests
        run: python3 e2e/run_all.py
AGENTS.md (37 lines)
@@ -194,40 +194,3 @@ GitHub's API doesn't support direct image upload for PR comments. Workaround:
git push origin <branch>
```
The images still render in the PR comment because they reference the permanent commit SHA.

## Running exo Remotely via SSH (macOS mDNS)

**CRITICAL: On macOS, mDNS multicast (used for peer discovery) only works when the process runs in a proper macOS user session.** Background processes started via `nohup ... &`, `screen`, or plain SSH commands will NOT send mDNS packets and nodes will never discover each other.

### The Problem

When you SSH into a Mac and run `nohup uv run exo &`, the process runs in a detached session without access to macOS multicast networking. The exo node will start but will never discover peers, even if they're on the same network.

### The Solution: Use `open` with a `.command` wrapper

Create a `.command` script that `open` will execute in the proper macOS GUI session context:

```bash
# 1. Create wrapper script on the remote machine
ssh user@remote-mac "cat > /tmp/run_exo.command << 'SCRIPT'
#!/bin/bash
export PATH=/opt/homebrew/bin:\$HOME/.local/bin:\$PATH
export EXO_LIBP2P_NAMESPACE=your-namespace # must match across all nodes
cd ~/path/to/exo
exec uv run exo -vv 2>&1 | tee /tmp/exo.log
SCRIPT
chmod +x /tmp/run_exo.command"

# 2. Launch it via `open` (runs in macOS GUI session with proper mDNS)
ssh user@remote-mac "open /tmp/run_exo.command"

# 3. Check logs
ssh user@remote-mac "tail -f /tmp/exo.log"
```

### Key Details
- **`EXO_LIBP2P_NAMESPACE`**: All nodes in a cluster MUST use the same namespace value. The EXO.app uses a build-specific namespace (check with `ps eww <pid> | grep NAMESPACE`). If mixing dev builds with EXO.app, set the dev build's namespace to match.
- **`open *.command`**: This is the macOS equivalent of double-clicking the script in Finder. It runs in the user's GUI session with full network access.
- **Do NOT use**: `nohup ... &`, `screen -dm`, `tmux new-session -d`, or `sshpass`. These all create detached sessions where mDNS won't work.
- **Killing**: `ssh user@remote-mac "pkill -f 'python.*exo'"` works fine for stopping.
- **Dashboard**: Must be built before running: `cd dashboard && npm install && npm run build && cd ..`. Node.js is at `/opt/homebrew/bin/node` on Apple Silicon Macs.
- **Verifying cluster**: `curl -s http://localhost:52415/state | python3 -c "import json,sys; s=json.load(sys.stdin); print(len(s['topology']['nodes']), 'nodes')"` — should show 2+ nodes.
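The verification one-liner above can also be scripted. Below is a minimal Python sketch that assumes the `/state` response shape implied by that curl command; `wait_for_cluster` is a hypothetical helper name, not part of the codebase.

```python
# Hypothetical helper mirroring the documented curl check; assumes /state
# returns {"topology": {"nodes": {...}}} as in the one-liner above.
import json
import time
from urllib.request import urlopen


def wait_for_cluster(expected_nodes: int = 2, timeout_s: int = 120) -> int:
    """Poll the local exo API until the topology reports enough nodes."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            with urlopen("http://localhost:52415/state", timeout=3) as resp:
                state = json.load(resp)
            n = len(state["topology"]["nodes"])
            if n >= expected_nodes:
                return n
        except OSError:
            pass  # node not up yet, keep polling
        time.sleep(2)
    raise TimeoutError(f"cluster never reached {expected_nodes} nodes")
```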
conftest.py (Normal file, 1 line)
@@ -0,0 +1 @@
collect_ignore = ["tests/start_distributed_test.py"]
e2e/Dockerfile (Normal file, 58 lines)
@@ -0,0 +1,58 @@
# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build

# Stage 2: Build and run exo
FROM python:3.13-slim

# Install system dependencies
# libblas-dev/liblapack-dev/liblapacke-dev are required by MLX CPU backend on Linux
RUN apt-get update && apt-get install -y \
    build-essential \
    pkg-config \
    libssl-dev \
    libblas-dev \
    liblapack-dev \
    liblapacke-dev \
    curl \
    protobuf-compiler \
    iptables \
    && rm -rf /var/lib/apt/lists/*

# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"

# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
# Must be done BEFORE uv sync so any source builds also get the fix
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
    printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
    chmod +x /usr/bin/g++

# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv

WORKDIR /app

# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml

# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/

# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/

# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git

CMD [".venv/bin/exo", "-v"]
e2e/conftest.py (Normal file, 195 lines)
@@ -0,0 +1,195 @@
"""Shared E2E test infrastructure for exo cluster tests."""

import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen

E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))


class Cluster:
    """Async wrapper around a docker compose exo cluster."""

    def __init__(self, name: str, overrides: list[str] | None = None):
        self.name = name
        self.project = f"e2e-{name}"
        compose_files = [str(E2E_DIR / "docker-compose.yml")]
        for path in overrides or []:
            compose_files.append(str(E2E_DIR / path))
        self._compose_base = [
            "docker",
            "compose",
            "-p",
            self.project,
            *[arg for f in compose_files for arg in ("-f", f)],
        ]

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        await self.stop()

    async def _run(self, *args: str, check: bool = True) -> str:
        proc = await asyncio.create_subprocess_exec(
            *self._compose_base,
            *args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        stdout, _ = await proc.communicate()
        output = stdout.decode()
        if check and proc.returncode != 0:
            print(output, file=sys.stderr)
            raise RuntimeError(
                f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
            )
        return output

    async def build(self):
        # Skip build if the image was pre-built (e.g. in CI with buildx cache)
        proc = await asyncio.create_subprocess_exec(
            "docker",
            "image",
            "inspect",
            "exo-e2e:latest",
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.DEVNULL,
        )
        await proc.wait()
        if proc.returncode == 0:
            print(" Using pre-built image (exo-e2e:latest)")
            return
        print(" Building images...")
        await self._run("build", "--quiet")

    async def start(self):
        print(" Starting cluster...")
        await self._run("up", "-d")

    async def stop(self):
        print(" Cleaning up...")
        await self._run("down", "--timeout", "5", check=False)

    async def logs(self) -> str:
        return await self._run("logs", check=False)

    async def exec(
        self, service: str, *cmd: str, check: bool = True
    ) -> tuple[int, str]:
        """Run a command inside a running container. Returns (returncode, output)."""
        proc = await asyncio.create_subprocess_exec(
            *self._compose_base,
            "exec",
            "-T",
            service,
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        stdout, _ = await proc.communicate()
        output = stdout.decode()
        if check and proc.returncode != 0:
            raise RuntimeError(
                f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
            )
        return proc.returncode, output

    async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
        """Poll check_fn every 2s until it returns True or timeout expires."""
        print(f" Waiting for {description}...")
        deadline = asyncio.get_event_loop().time() + timeout
        while asyncio.get_event_loop().time() < deadline:
            if await check_fn():
                print(f" {description}")
                return
            await asyncio.sleep(2)
        output = await self.logs()
        print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
        raise TimeoutError(f"Timed out waiting for {description}")

    async def assert_healthy(self):
        """Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""

        async def both_nodes_started():
            log = await self.logs()
            return log.count("Starting node") >= 2

        async def nodes_discovered():
            log = await self.logs()
            return log.count("ConnectionMessageType.Connected") >= 2

        async def master_elected():
            log = await self.logs()
            return "demoting self" in log

        async def api_responding():
            try:
                with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
                    return resp.status == 200
            except (URLError, OSError):
                return False

        await self.wait_for("Both nodes started", both_nodes_started)
        await self.wait_for("Nodes discovered each other", nodes_discovered)
        await self.wait_for("Master election resolved", master_elected)
        await self.wait_for("API responding", api_responding)

    async def _api(
        self, method: str, path: str, body: dict | None = None, timeout: int = 30
    ) -> dict:
        """Make an API request to the cluster. Returns parsed JSON."""
        url = f"http://localhost:52415{path}"
        data = json.dumps(body).encode() if body else None
        req = Request(
            url, data=data, headers={"Content-Type": "application/json"}, method=method
        )
        loop = asyncio.get_event_loop()
        resp_bytes = await loop.run_in_executor(
            None, lambda: urlopen(req, timeout=timeout).read()
        )
        return json.loads(resp_bytes)

    async def place_model(self, model: str, timeout: int = 600):
        """Place a model instance on the cluster (triggers download) and wait until it's ready."""
        await self._api("POST", "/place_instance", {"model_id": model})

        async def model_ready():
            try:
                resp = await self._api("GET", "/v1/models")
                return any(m.get("id") == model for m in resp.get("data", []))
            except Exception:
                return False

        await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)

    async def chat(
        self, model: str, messages: list[dict], timeout: int = 600, **kwargs
    ) -> dict:
        """Send a chat completion request. Retries until model is downloaded and inference completes."""
        body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
        deadline = asyncio.get_event_loop().time() + timeout
        last_error = None

        while asyncio.get_event_loop().time() < deadline:
            try:
                req = Request(
                    "http://localhost:52415/v1/chat/completions",
                    data=body,
                    headers={"Content-Type": "application/json"},
                )
                loop = asyncio.get_event_loop()
                resp_bytes = await loop.run_in_executor(
                    None, lambda r=req: urlopen(r, timeout=300).read()
                )
                return json.loads(resp_bytes)
            except Exception as e:
                last_error = e
                await asyncio.sleep(5)

        raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")
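As a rough guide to extending the suite, the sketch below shows how a new test would typically drive this harness. The scenario name and log marker are hypothetical; the `Cluster` methods and the `exo-node-1` service name are the ones defined in this change.

```python
"""Sketch: writing a new E2E test against the Cluster harness."""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster


async def main():
    # "my_new_scenario" is a hypothetical project name for this sketch
    async with Cluster("my_new_scenario") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        # Run an ad-hoc command inside one container without failing the test
        rc, out = await cluster.exec("exo-node-1", "ls", "/app", check=False)
        print(f" ls rc={rc}")

        # Poll the combined compose logs for an arbitrary marker (hypothetical string)
        async def saw_marker():
            return "some log marker" in await cluster.logs()

        await cluster.wait_for("custom log marker", saw_marker, timeout=60)
        print("PASSED: my_new_scenario")


if __name__ == "__main__":
    asyncio.run(main())
```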
e2e/docker-compose.yml (Normal file, 20 lines)
@@ -0,0 +1,20 @@
services:
  exo-node-1:
    image: exo-e2e:latest
    build:
      context: ..
      dockerfile: e2e/Dockerfile
    environment:
      - EXO_LIBP2P_NAMESPACE=docker-e2e
    command: [".venv/bin/exo", "-v"]
    ports:
      - "52415:52415"

  exo-node-2:
    image: exo-e2e:latest
    build:
      context: ..
      dockerfile: e2e/Dockerfile
    environment:
      - EXO_LIBP2P_NAMESPACE=docker-e2e
    command: [".venv/bin/exo", "-v"]
e2e/run_all.py (Normal file, 77 lines)
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.

Tests with 'slow' on its own line in their module docstring are skipped
unless --slow is passed or E2E_SLOW=1 is set.
"""

import os
import subprocess
import sys
from pathlib import Path

E2E_DIR = Path(__file__).parent.resolve()


def is_slow(test_file: Path) -> bool:
    """Check if the test file is marked as slow (a docstring line starting with 'slow')."""
    with open(test_file) as f:
        for line in f:
            if line.strip().startswith("#"):
                continue
            if line.strip().startswith('"""') or line.strip().startswith("'''"):
                # Read into the docstring
                for doc_line in f:
                    if "slow" in doc_line.lower() and doc_line.strip().startswith(
                        "slow"
                    ):
                        return True
                    if '"""' in doc_line or "'''" in doc_line:
                        break
                break
    return False


def main():
    run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
    if "--update-snapshots" in sys.argv:
        os.environ["UPDATE_SNAPSHOTS"] = "1"
    test_files = sorted(E2E_DIR.glob("test_*.py"))
    if not test_files:
        print("No test files found")
        sys.exit(1)

    passed = 0
    failed = 0
    skipped = 0
    failures = []

    for test_file in test_files:
        name = test_file.stem
        if is_slow(test_file) and not run_slow:
            print(f"=== {name} === SKIPPED (slow, use --slow to run)")
            skipped += 1
            continue

        print(f"=== {name} ===")
        result = subprocess.run([sys.executable, str(test_file)])
        if result.returncode == 0:
            passed += 1
        else:
            failed += 1
            failures.append(name)
        print()

    total = passed + failed + skipped
    print("================================")
    print(
        f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
    )

    if failed:
        print(f"Failed: {' '.join(failures)}")
        sys.exit(1)


if __name__ == "__main__":
    main()
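Because `is_slow()` scans the module docstring for a line that starts with `slow`, a long-running test opts out of the default run simply by carrying that marker (as `test_snapshot_multi_model.py` does later in this diff). A minimal sketch of a hypothetical slow-marked test module:

```python
"""Test: some long-running scenario (hypothetical).
slow

The line above, starting with 'slow', is what run_all.py's is_slow() looks for,
so this file is skipped unless --slow is passed or E2E_SLOW=1 is set.
"""
```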
e2e/snapshot.py (Normal file, 69 lines)
@@ -0,0 +1,69 @@
"""Snapshot testing infrastructure for E2E tests.

Provides deterministic regression testing by comparing inference output
against saved snapshots. On first run, snapshots are created automatically.
Set UPDATE_SNAPSHOTS=1 to regenerate snapshots when output intentionally changes.

Snapshots are stored per-architecture (e.g. snapshots/x86_64/, snapshots/arm64/)
since floating-point results differ between CPU architectures.
"""

import difflib
import json
import os
import platform
from pathlib import Path

ARCH = platform.machine()
SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" / ARCH


def assert_snapshot(
    name: str,
    content: str,
    metadata: dict,
) -> None:
    """Compare content against a saved snapshot, or create one if missing.

    Args:
        name: Snapshot identifier (used as filename: snapshots/{arch}/{name}.json).
        content: The actual inference output to compare.
        metadata: Additional context stored alongside content (model, seed, etc.).
            Not used for comparison -- purely documentary.

    Raises:
        AssertionError: If content doesn't match the saved snapshot.

    Environment:
        UPDATE_SNAPSHOTS=1: Overwrite existing snapshot with actual content.
    """
    snapshot_file = SNAPSHOTS_DIR / f"{name}.json"
    update = os.environ.get("UPDATE_SNAPSHOTS") == "1"

    if snapshot_file.exists() and not update:
        snapshot = json.loads(snapshot_file.read_text())
        expected = snapshot["content"]
        if content != expected:
            diff = "\n".join(
                difflib.unified_diff(
                    expected.splitlines(),
                    content.splitlines(),
                    fromfile=f"expected ({snapshot_file.relative_to(SNAPSHOTS_DIR.parent.parent)})",
                    tofile="actual",
                    lineterm="",
                )
            )
            raise AssertionError(
                f"Snapshot mismatch for '{name}' on {ARCH}!\n\n"
                f"{diff}\n\n"
                f"Expected: {expected!r}\n"
                f"Actual: {content!r}\n\n"
                f"To update: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py"
            )
        print(f" Output matches snapshot ({ARCH}/{snapshot_file.name})")
    else:
        SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_data = {**metadata, "arch": ARCH, "content": content}
        snapshot_file.write_text(json.dumps(snapshot_data, indent=2) + "\n")
        action = "Updated" if update else "Created"
        print(f" {action} snapshot: {ARCH}/{snapshot_file.name}")
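For reference, a hedged sketch of how a caller uses `assert_snapshot` and what the resulting snapshot file roughly contains. The `content` value below is hypothetical; the file layout follows the `snapshot_data` dict built above (`{**metadata, "arch": ARCH, "content": content}`).

```python
# Illustrative only: values are hypothetical, layout follows snapshot_data above.
from snapshot import assert_snapshot

assert_snapshot(
    name="inference_snapshot",
    content="4",  # hypothetical model output
    metadata={"model": "mlx-community/Qwen3-0.6B-4bit", "seed": 42},
)
# First run creates e2e/snapshots/<arch>/inference_snapshot.json roughly like:
# {
#   "model": "mlx-community/Qwen3-0.6B-4bit",
#   "seed": 42,
#   "arch": "x86_64",
#   "content": "4"
# }
# Subsequent runs compare against "content" and raise AssertionError on mismatch.
```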
e2e/test_cluster_formation.py (Normal file, 22 lines)
@@ -0,0 +1,22 @@
"""Test: Basic cluster formation.

Verifies two nodes discover each other, elect a master, and the API responds.
"""

import asyncio
import sys

sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster


async def main():
    async with Cluster("cluster_formation") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()
        print("PASSED: cluster_formation")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_inference_snapshot.py (Normal file, 60 lines)
@@ -0,0 +1,60 @@
"""Test: Deterministic inference output (snapshot test).

Sends a chat completion request with a fixed seed,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.

Uses MLX CPU backend in Docker on x86 Linux.
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32


async def main():
    async with Cluster("inference_snapshot") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f" Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        print(f" Sending chat completion (seed={SEED})...")
        resp = await cluster.chat(
            model=MODEL,
            messages=[{"role": "user", "content": PROMPT}],
            seed=SEED,
            max_tokens=MAX_TOKENS,
        )

        content = resp["choices"][0]["message"]["content"]
        print(f" Response: {content!r}")

        assert_snapshot(
            name="inference_snapshot",
            content=content,
            metadata={
                "model": MODEL,
                "seed": SEED,
                "prompt": PROMPT,
                "max_tokens": MAX_TOKENS,
            },
        )

        print("PASSED: inference_snapshot")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_no_internet.py (Normal file, 47 lines)
@@ -0,0 +1,47 @@
"""Test: Cluster works without internet access.

Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""

import asyncio
import sys

sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster


async def main():
    async with Cluster(
        "no_internet",
        overrides=["tests/no_internet/docker-compose.override.yml"],
    ) as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        # Verify internet is actually blocked from inside the containers
        for node in ["exo-node-1", "exo-node-2"]:
            rc, _ = await cluster.exec(
                node,
                "curl",
                "-sf",
                "--max-time",
                "3",
                "https://huggingface.co",
                check=False,
            )
            assert rc != 0, f"{node} should not be able to reach the internet"
            print(f" {node}: internet correctly blocked")

        # Verify exo detected no internet connectivity
        log = await cluster.logs()
        assert "Internet connectivity: False" in log, "exo should detect no internet"
        print(" exo correctly detected no internet connectivity")

        print("PASSED: no_internet")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_snapshot_code_gen.py (Normal file, 58 lines)
@@ -0,0 +1,58 @@
"""Test: Code generation snapshot.

Verifies deterministic output for a code generation prompt.
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = (
    "Write a Python function to reverse a string. Only output the code, no explanation."
)
MAX_TOKENS = 64


async def main():
    async with Cluster("snapshot_code_gen") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f" Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        print(f" Sending chat completion (seed={SEED})...")
        resp = await cluster.chat(
            model=MODEL,
            messages=[{"role": "user", "content": PROMPT}],
            seed=SEED,
            max_tokens=MAX_TOKENS,
        )

        content = resp["choices"][0]["message"]["content"]
        print(f" Response: {content!r}")

        assert_snapshot(
            name="snapshot_code_gen",
            content=content,
            metadata={
                "model": MODEL,
                "seed": SEED,
                "prompt": PROMPT,
                "max_tokens": MAX_TOKENS,
            },
        )

        print("PASSED: snapshot_code_gen")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_snapshot_edge.py (Normal file, 63 lines)
@@ -0,0 +1,63 @@
"""Test: Edge case snapshots.

Verifies deterministic output for edge-case prompts: single word input,
special characters, and unicode.
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
MAX_TOKENS = 32

CASES = [
    ("edge_single_word", "Hi"),
    ("edge_special_chars", "What does 2 * (3 + 4) / 7 - 1 equal? Use <math> tags."),
    ("edge_unicode", "Translate 'hello' to Japanese, Chinese, and Korean."),
]


async def main():
    async with Cluster("snapshot_edge") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f" Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        for snapshot_name, prompt in CASES:
            print(f" [{snapshot_name}] Sending: {prompt!r}")
            resp = await cluster.chat(
                model=MODEL,
                messages=[{"role": "user", "content": prompt}],
                seed=SEED,
                max_tokens=MAX_TOKENS,
            )

            content = resp["choices"][0]["message"]["content"]
            print(f" [{snapshot_name}] Response: {content!r}")

            assert_snapshot(
                name=snapshot_name,
                content=content,
                metadata={
                    "model": MODEL,
                    "seed": SEED,
                    "prompt": prompt,
                    "max_tokens": MAX_TOKENS,
                },
            )

        print("PASSED: snapshot_edge")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_snapshot_long_output.py (Normal file, 56 lines)
@@ -0,0 +1,56 @@
"""Test: Longer output snapshot.

Verifies deterministic output with a higher max_tokens (128).
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "Explain how a binary search algorithm works."
MAX_TOKENS = 128


async def main():
    async with Cluster("snapshot_long_output") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f" Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        print(f" Sending chat completion (seed={SEED}, max_tokens={MAX_TOKENS})...")
        resp = await cluster.chat(
            model=MODEL,
            messages=[{"role": "user", "content": PROMPT}],
            seed=SEED,
            max_tokens=MAX_TOKENS,
        )

        content = resp["choices"][0]["message"]["content"]
        print(f" Response: {content!r}")

        assert_snapshot(
            name="snapshot_long_output",
            content=content,
            metadata={
                "model": MODEL,
                "seed": SEED,
                "prompt": PROMPT,
                "max_tokens": MAX_TOKENS,
            },
        )

        print("PASSED: snapshot_long_output")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_snapshot_multi_model.py (Normal file, 72 lines)
@@ -0,0 +1,72 @@
"""Test: Multi-model snapshot tests.
slow

Verifies deterministic output across different model architectures to catch
model-specific regressions. Each model uses its own snapshot file.
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

SEED = 42
PROMPT = "What is the capital of France?"
MAX_TOKENS = 32

MODELS = [
    "mlx-community/SmolLM2-135M-Instruct",
    "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "mlx-community/gemma-2-2b-it-4bit",
]


async def main():
    async with Cluster("snapshot_multi_model") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        for model in MODELS:
            short_name = (
                model.split("/")[-1].lower().replace("-", "_").replace(".", "_")
            )
            snapshot_name = f"snapshot_multi_{short_name}"

            print(f" Launching model {model}...")
            await cluster.place_model(model)

            print(f" Sending chat completion (seed={SEED})...")
            resp = await cluster.chat(
                model=model,
                messages=[{"role": "user", "content": PROMPT}],
                seed=SEED,
                max_tokens=MAX_TOKENS,
            )

            content = resp["choices"][0]["message"]["content"]
            print(f" [{short_name}] Response: {content!r}")

            assert_snapshot(
                name=snapshot_name,
                content=content,
                metadata={
                    "model": model,
                    "seed": SEED,
                    "prompt": PROMPT,
                    "max_tokens": MAX_TOKENS,
                },
            )

            print(f" [{short_name}] PASSED")

        print("PASSED: snapshot_multi_model")


if __name__ == "__main__":
    asyncio.run(main())
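The per-model snapshot file names this test writes under e2e/snapshots/<arch>/ follow directly from the `short_name` transformation above; a quick sketch of that mapping:

```python
# Sketch of the per-model snapshot naming used above (same string transformation):
for model in [
    "mlx-community/SmolLM2-135M-Instruct",
    "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "mlx-community/gemma-2-2b-it-4bit",
]:
    short = model.split("/")[-1].lower().replace("-", "_").replace(".", "_")
    print(f"snapshot_multi_{short}")
# -> snapshot_multi_smollm2_135m_instruct
# -> snapshot_multi_llama_3_2_1b_instruct_4bit
# -> snapshot_multi_gemma_2_2b_it_4bit
```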
e2e/test_snapshot_reasoning.py (Normal file, 56 lines)
@@ -0,0 +1,56 @@
"""Test: Reasoning/math snapshot.

Verifies deterministic output for a simple reasoning prompt.
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "If I have 3 apples and give away 1, how many do I have? Think step by step."
MAX_TOKENS = 64


async def main():
    async with Cluster("snapshot_reasoning") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f" Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        print(f" Sending chat completion (seed={SEED})...")
        resp = await cluster.chat(
            model=MODEL,
            messages=[{"role": "user", "content": PROMPT}],
            seed=SEED,
            max_tokens=MAX_TOKENS,
        )

        content = resp["choices"][0]["message"]["content"]
        print(f" Response: {content!r}")

        assert_snapshot(
            name="snapshot_reasoning",
            content=content,
            metadata={
                "model": MODEL,
                "seed": SEED,
                "prompt": PROMPT,
                "max_tokens": MAX_TOKENS,
            },
        )

        print("PASSED: snapshot_reasoning")


if __name__ == "__main__":
    asyncio.run(main())
e2e/tests/no_internet/docker-compose.override.yml (Normal file, 32 lines)
@@ -0,0 +1,32 @@
# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
  exo-node-1:
    cap_add:
      - NET_ADMIN
    entrypoint: ["/bin/sh", "-c"]
    command:
      - |
        iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
        iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
        iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
        iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
        iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
        iptables -A OUTPUT -j REJECT
        exec .venv/bin/exo -v
  exo-node-2:
    cap_add:
      - NET_ADMIN
    entrypoint: ["/bin/sh", "-c"]
    command:
      - |
        iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
        iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
        iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
        iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
        iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
        iptables -A OUTPUT -j REJECT
        exec .venv/bin/exo -v
@@ -132,7 +132,7 @@ markers = [
env = [
  "EXO_TESTS=1"
]
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
addopts = "-m 'not slow'"
filterwarnings = [
  "ignore:builtin type Swig:DeprecationWarning",
]
@@ -0,0 +1,12 @@
model_id = "mlx-community/SmolLM2-135M-Instruct"
n_layers = 30
hidden_size = 576
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "SmolLM2 135M"
capabilities = ["text"]

[storage_size]
in_bytes = 269060381
@@ -0,0 +1,12 @@
model_id = "mlx-community/gemma-2-2b-it-4bit"
n_layers = 26
hidden_size = 2304
supports_tensor = false
tasks = ["TextGeneration"]
family = "gemma2"
quantization = "4bit"
base_model = "Gemma 2 2B"
capabilities = ["text"]

[storage_size]
in_bytes = 1492755242
@@ -73,8 +73,6 @@ from exo.shared.types.api import (
|
||||
CreateInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
DistributeModelParams,
|
||||
DistributeModelResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
@@ -119,7 +117,6 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteDownload,
|
||||
DeleteInstance,
|
||||
DistributeModel,
|
||||
DownloadCommand,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
@@ -145,7 +142,6 @@ from exo.shared.types.openai_responses import (
|
||||
ResponsesResponse,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
@@ -302,7 +298,6 @@ class API:
|
||||
self.app.get("/events")(self.stream_events)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.post("/v1/models/{model_id:path}/distribute")(self.distribute_model)
|
||||
self.app.get("/v1/traces")(self.list_traces)
|
||||
self.app.get("/v1/traces/{task_id}")(self.get_trace)
|
||||
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
|
||||
@@ -1482,57 +1477,6 @@ class API:
|
||||
await self._send_download(command)
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
async def distribute_model(
|
||||
self, model_id: ModelId, payload: DistributeModelParams
|
||||
) -> DistributeModelResponse:
|
||||
"""Distribute model files from one node to others via MLX distributed."""
|
||||
# Find a source node that has the model downloaded
|
||||
source_node_id: NodeId | None = None
|
||||
for nid, downloads in self.state.downloads.items():
|
||||
for dp in downloads:
|
||||
if (
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata.model_card.model_id == model_id
|
||||
):
|
||||
source_node_id = nid
|
||||
break
|
||||
if source_node_id is not None:
|
||||
break
|
||||
|
||||
if source_node_id is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No node has model {model_id} downloaded",
|
||||
)
|
||||
|
||||
# Determine target nodes
|
||||
if payload.target_node_ids is not None:
|
||||
target_node_ids = [
|
||||
nid for nid in payload.target_node_ids if nid != source_node_id
|
||||
]
|
||||
else:
|
||||
target_node_ids = [
|
||||
nid for nid in self.state.topology.list_nodes() if nid != source_node_id
|
||||
]
|
||||
|
||||
if not target_node_ids:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No target nodes to distribute to",
|
||||
)
|
||||
|
||||
command = DistributeModel(
|
||||
model_id=model_id,
|
||||
source_node_id=source_node_id,
|
||||
target_node_ids=target_node_ids,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
return DistributeModelResponse(
|
||||
command_id=command.command_id,
|
||||
message=f"Distributing {model_id} from {source_node_id} to {len(target_node_ids)} node(s)",
|
||||
)
|
||||
|
||||
def _get_trace_path(self, task_id: str) -> Path:
|
||||
return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
|
||||
from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
DistributeModel,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
@@ -313,37 +312,6 @@ class Master:
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case DistributeModel():
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.types.worker.instances import InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
model_card = await ModelCard.load(command.model_id)
|
||||
all_node_ids = set(
|
||||
[command.source_node_id] + list(command.target_node_ids)
|
||||
)
|
||||
place_command = PlaceInstance(
|
||||
model_card=model_card,
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
min_nodes=len(all_node_ids),
|
||||
)
|
||||
placement = place_instance(
|
||||
place_command,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
self.state.node_memory,
|
||||
self.state.node_network,
|
||||
required_nodes=all_node_ids,
|
||||
)
|
||||
# Mark new instances as transfer-only
|
||||
for instance_id, instance in placement.items():
|
||||
if instance_id not in self.state.instances:
|
||||
instance.shard_assignments.transfer_only = True
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
generated_events.append(
|
||||
InputChunkReceived(
|
||||
|
||||
@@ -374,15 +374,6 @@ class DeleteDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
|
||||
class DistributeModelParams(CamelCaseModel):
|
||||
target_node_ids: list[NodeId] | None = None # None = all connected nodes
|
||||
|
||||
|
||||
class DistributeModelResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
message: str
|
||||
|
||||
|
||||
class TraceEventResponse(CamelCaseModel):
|
||||
name: str
|
||||
start_us: int
|
||||
|
||||
@@ -77,14 +77,6 @@ class CancelDownload(BaseCommand):
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
class DistributeModel(BaseCommand):
|
||||
"""Distribute model files from one node to others via MLX distributed."""
|
||||
|
||||
model_id: ModelId
|
||||
source_node_id: NodeId
|
||||
target_node_ids: list[NodeId]
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
|
||||
|
||||
|
||||
@@ -99,7 +91,6 @@ Command = (
|
||||
| DeleteInstance
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
| DistributeModel
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class DownloadModel(BaseTask): # emitted by Worker
|
||||
|
||||
|
||||
class LoadModel(BaseTask): # emitted by Worker
|
||||
has_local_model: bool = Field(default=True)
|
||||
pass
|
||||
|
||||
|
||||
class ConnectToGroup(BaseTask): # emitted by Worker
|
||||
@@ -76,13 +76,6 @@ class ImageEdits(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class TransferModelToDisk(BaseTask): # emitted by Worker
|
||||
"""Transfer all model files from source to receivers' disk via MLX distributed."""
|
||||
|
||||
shard_metadata: ShardMetadata
|
||||
has_local_model: bool = Field(default=True)
|
||||
|
||||
|
||||
class Shutdown(BaseTask): # emitted by Worker
|
||||
runner_id: RunnerId
|
||||
|
||||
@@ -92,7 +85,6 @@ Task = (
|
||||
| DownloadModel
|
||||
| ConnectToGroup
|
||||
| LoadModel
|
||||
| TransferModelToDisk
|
||||
| StartWarmup
|
||||
| TextGeneration
|
||||
| ImageGeneration
|
||||
|
||||
@@ -84,7 +84,6 @@ class ShardAssignments(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
runner_to_shard: Mapping[RunnerId, ShardMetadata]
|
||||
node_to_runner: Mapping[NodeId, RunnerId]
|
||||
transfer_only: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_runners_exist(self) -> "ShardAssignments":
|
||||
|
||||
@@ -47,7 +47,6 @@ if TYPE_CHECKING:
|
||||
from mlx_lm.models.cache import Cache
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
WeightLoader = Callable[[nn.Module, int], None] | None
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
@@ -347,7 +346,6 @@ def tensor_auto_parallel(
|
||||
group: mx.distributed.Group,
|
||||
timeout_seconds: float = 60.0,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
all_to_sharded_linear = partial(
|
||||
shard_linear,
|
||||
@@ -457,7 +455,7 @@ def tensor_auto_parallel(
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
model = tensor_parallel_sharding_strategy.shard_model(
|
||||
model, timeout_seconds, on_timeout, weight_loader
|
||||
model, timeout_seconds, on_timeout
|
||||
)
|
||||
return patch_tensor_model(model)
|
||||
|
||||
@@ -484,7 +482,6 @@ class TensorParallelShardingStrategy(ABC):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module: ...
|
||||
|
||||
|
||||
@@ -494,12 +491,9 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(LlamaModel, model)
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
for layer in model.layers:
|
||||
# Force load weights before sharding to avoid FAST_SYNCH deadlock
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
@@ -551,12 +545,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -629,12 +620,9 @@ class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(GLM4MoeLiteModel, model)
|
||||
for i, layer in enumerate(model.layers): # type: ignore
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
for layer in model.layers: # type: ignore
|
||||
layer = cast(Glm4MoeLiteDecoderLayer, layer)
|
||||
eval_with_timeout(
|
||||
layer.parameters(),
|
||||
@@ -774,12 +762,9 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -817,12 +802,9 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -944,12 +926,9 @@ class Glm4MoeShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(Glm4MoeModel, model)
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -993,13 +972,10 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for i, layer in enumerate(model.layers):
|
||||
if weight_loader is not None:
|
||||
weight_loader(model, i)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
@@ -1037,7 +1013,6 @@ class Step35ShardingStrategy(TensorParallelShardingStrategy):
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
weight_loader: WeightLoader = None,
|
||||
) -> nn.Module:
|
||||
model = cast(Step35Model, model)
|
||||
|
||||
|
||||
@@ -1,507 +0,0 @@
|
||||
"""
|
||||
Model transfer via MLX distributed all_sum.
|
||||
|
||||
Three transfer modes:
|
||||
1. Metadata file transfer: broadcast small files (config.json, tokenizer, etc.) to disk
|
||||
2. Weight tensor broadcast: stream weight tensors directly into memory via all_sum
|
||||
3. Full file transfer: broadcast all files (including safetensors) to disk
|
||||
|
||||
All functions are collective operations — every rank in the group must call them.
|
||||
|
||||
Protocol relies on all_sum: source has real data, receivers have zeros.
|
||||
all_sum(source + zeros) = source data on all ranks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Final, cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
Group = mx.distributed.Group
|
||||
|
||||
CHUNK_SIZE: Final[int] = 100 * 1024 * 1024 # 100 MB
|
||||
_LAYER_RE: Final[re.Pattern[str]] = re.compile(r"(?:^|\.)(layers|h)\.(\d+)\.")
|
||||
|
||||
|
||||
def _all_sum_cpu(x: mx.array, group: Group) -> mx.array:
|
||||
"""all_sum on CPU stream to avoid GPU memory pressure."""
|
||||
return mx.distributed.all_sum(
|
||||
x, stream=mx.default_stream(mx.Device(mx.cpu)), group=group
|
||||
)
|
||||
|
||||
|
||||
def _is_metadata_file(filename: str) -> bool:
|
||||
"""A metadata file is anything that isn't a weight file or weight index.
|
||||
|
||||
Weight indices (.safetensors.index.json) reference safetensors shard paths.
|
||||
Transferring them to a receiver that has no safetensors files is harmless
|
||||
today (load_model's glob doesn't match them), but excluding them avoids
|
||||
stale references and keeps the transfer minimal.
|
||||
"""
|
||||
if filename.endswith(".safetensors"):
|
||||
return False
|
||||
return not filename.endswith(".safetensors.index.json")
|
||||
|
||||
|
||||
def model_path_for_id(model_id: ModelId) -> Path:
|
||||
"""Get model path without requiring directory to exist (unlike build_model_path)."""
|
||||
return EXO_MODELS_DIR / model_id.normalize()
|
||||
|
||||
|
||||
def coordinate_transfer(group: Group, has_local_model: bool) -> tuple[bool, int]:
|
||||
"""
|
||||
Determine if a transfer is needed and which rank is the source.
|
||||
|
||||
All ranks must call this function (uses collective all_sum).
|
||||
|
||||
Returns:
|
||||
(needs_transfer, source_rank) — source_rank is the lowest rank
|
||||
that has the model. needs_transfer is True if any rank is missing it.
|
||||
"""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
world_size = group.size()
|
||||
|
||||
# Each rank broadcasts a one-hot vector at its position if it has the model
|
||||
bitmask = mx.zeros(world_size, dtype=mx.int32)
|
||||
if has_local_model:
|
||||
bitmask = bitmask.at[group.rank()].add(1)
|
||||
summed = all_sum(bitmask)
|
||||
mx.eval(summed)
|
||||
|
||||
has_model_flags: list[int] = summed.tolist() # type: ignore[assignment]
|
||||
total_have = sum(has_model_flags)
|
||||
|
||||
if total_have == 0:
|
||||
raise RuntimeError(
|
||||
"No rank has the model files — cannot transfer. "
|
||||
"At least one node must have downloaded the model."
|
||||
)
|
||||
|
||||
if total_have == world_size:
|
||||
logger.info("All ranks have model files, no transfer needed")
|
||||
return False, 0
|
||||
|
||||
source_rank = next(i for i, flag in enumerate(has_model_flags) if flag > 0)
|
||||
logger.info(
|
||||
f"Transfer needed: source_rank={source_rank}, "
|
||||
f"{total_have}/{world_size} ranks have model"
|
||||
)
|
||||
return True, source_rank
|
||||
|
||||
|
||||
def _broadcast_json(obj: object, group: Group, is_source: bool) -> object:
|
||||
"""Broadcast a JSON-serializable object from source to all ranks."""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
|
||||
data = json.dumps(obj, separators=(",", ":")).encode("utf-8") if is_source else b""
|
||||
|
||||
# Broadcast length
|
||||
len_arr = mx.array([len(data) if is_source else 0], dtype=mx.int64)
|
||||
len_result = all_sum(len_arr)
|
||||
mx.eval(len_result)
|
||||
length = int(len_result.item())
|
||||
if length == 0:
|
||||
return None
|
||||
|
||||
# Broadcast payload
|
||||
if is_source:
|
||||
arr = mx.array(list(data), dtype=mx.uint8)
|
||||
else:
|
||||
arr = mx.zeros(length, dtype=mx.uint8)
|
||||
result = all_sum(arr)
|
||||
mx.eval(result)
|
||||
return json.loads(bytes(cast(list[int], result.tolist()))) # pyright: ignore[reportAny]
|
||||
|
||||
|
||||
def _build_manifest(
|
||||
model_path: Path, metadata_only: bool = False
|
||||
) -> list[dict[str, str | int]]:
|
||||
"""Build a list of files in the model directory with their relative paths and sizes."""
|
||||
manifest: list[dict[str, str | int]] = []
|
||||
for root, _dirs, files in os.walk(model_path):
|
||||
for fname in sorted(files):
|
||||
if metadata_only and not _is_metadata_file(fname):
|
||||
continue
|
||||
full_path = Path(root) / fname
|
||||
rel_path = str(full_path.relative_to(model_path))
|
||||
manifest.append(
|
||||
{
|
||||
"path": rel_path,
|
||||
"size": full_path.stat().st_size,
|
||||
}
|
||||
)
|
||||
return manifest
|
||||
|
||||
|
||||
def _transfer_file_to_disk(
|
||||
source_path: Path,
|
||||
rel_path: str,
|
||||
file_size: int,
|
||||
group: Group,
|
||||
is_source: bool,
|
||||
dest_path: Path,
|
||||
) -> None:
|
||||
"""Transfer a single file chunk-by-chunk via all_sum. Source reads from disk, receivers write to dest_path."""
|
||||
all_sum = partial(_all_sum_cpu, group=group)
|
||||
|
||||
if is_source:
|
||||
src_file = source_path / rel_path
|
||||
with open(src_file, "rb") as f:
|
||||
offset = 0
|
||||
while offset < file_size:
|
||||
chunk_bytes = min(CHUNK_SIZE, file_size - offset)
|
||||
data = f.read(chunk_bytes)
|
||||
if not data:
|
||||
break
|
||||
size_arr = mx.array([len(data)], dtype=mx.int64)
|
||||
mx.eval(all_sum(size_arr))
|
||||
chunk_arr = mx.array(list(data), dtype=mx.uint8)
|
||||
result = all_sum(chunk_arr)
|
||||
mx.eval(result)
|
||||
offset += len(data)
|
||||
# Signal end of file
|
||||
mx.eval(all_sum(mx.array([0], dtype=mx.int64)))
|
||||
else:
|
||||
dst_file = dest_path / rel_path
|
||||
os.makedirs(dst_file.parent, exist_ok=True)
|
||||
with open(dst_file, "wb") as f:
|
||||
while True:
|
||||
size_arr = all_sum(mx.zeros(1, dtype=mx.int64))
|
||||
mx.eval(size_arr)
|
||||
chunk_size = int(size_arr.item())
|
||||
if chunk_size == 0:
|
||||
break
|
||||
chunk_data = all_sum(mx.zeros(chunk_size, dtype=mx.uint8))
|
||||
mx.eval(chunk_data)
|
||||
f.write(bytes(cast(list[int], chunk_data.tolist())))
|
||||
|
||||
|
||||
def _transfer_files_to_disk(
    model_path: Path,
    group: Group,
    is_source: bool,
    metadata_only: bool = False,
) -> None:
    """
    Transfer files from source to all receivers' disk.

    Source broadcasts a manifest then each file. Receivers write to a temp dir
    then atomically move files to model_path.
    """
    if is_source:
        source_manifest = _build_manifest(model_path, metadata_only=metadata_only)
    else:
        source_manifest = []
    manifest = cast(
        list[dict[str, str | int]],
        _broadcast_json(source_manifest if is_source else None, group, is_source),
    )

    if not manifest:
        logger.info("No files to transfer")
        return

    logger.info(
        f"Transferring {len(manifest)} files ({'metadata only' if metadata_only else 'all'})"
    )

    temp_dir: Path | None = None
    if not is_source:
        os.makedirs(model_path.parent, exist_ok=True)
        temp_dir = Path(
            tempfile.mkdtemp(
                dir=model_path.parent,
                prefix=f".transfer_{model_path.name}_",
            )
        )

    try:
        for entry in manifest:
            rel_path = str(entry["path"])
            file_size = int(entry["size"])
            logger.info(f" {rel_path} ({file_size} bytes)")
            _transfer_file_to_disk(
                source_path=model_path,
                rel_path=rel_path,
                file_size=file_size,
                group=group,
                is_source=is_source,
                dest_path=temp_dir if temp_dir is not None else model_path,
            )

        if temp_dir is not None:
            os.makedirs(model_path, exist_ok=True)
            for entry in manifest:
                rel_path = str(entry["path"])
                src = temp_dir / rel_path
                dst = model_path / rel_path
                os.makedirs(dst.parent, exist_ok=True)
                os.replace(src, dst)
            logger.info(
                f"Transfer complete: {len(manifest)} files moved to {model_path}"
            )
    finally:
        if temp_dir is not None and temp_dir.exists():
            shutil.rmtree(temp_dir, ignore_errors=True)


def transfer_metadata_files(model_path: Path, group: Group, is_source: bool) -> None:
    """
    Transfer metadata files (config.json, tokenizer files, etc.) to receivers' disk.

    All ranks must call this function (collective operation).
    Only the designated source (is_source=True) should send; all others receive.
    """
    _transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=True)


def transfer_all_files(model_path: Path, group: Group, is_source: bool) -> None:
    """
    Transfer ALL model files (including safetensors) to receivers' disk.

    All ranks must call this function (collective operation).
    Only the designated source (is_source=True) should send; all others receive.
    """
    _transfer_files_to_disk(model_path, group, is_source=is_source, metadata_only=False)


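# Usage sketch (illustrative only; `group` and `has_local` come from the caller's
# own setup): every rank in the group makes the same sequence of calls, with the
# designated source passing is_source=True and every other rank False.
#
#   transfer_metadata_files(model_path, group, is_source=has_local)
#   transfer_all_files(model_path, group, is_source=has_local)
#
# Because each transfer is built from collective all_sum operations, a rank that
# skips one of these calls would stall the whole group.
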
def _parse_mx_dtype(dtype_str: str) -> mx.Dtype:
    """Convert a dtype string like 'float16' or 'mlx.core.float16' to mx.Dtype."""
    name = dtype_str.split(".")[-1]
    dtype = getattr(mx, name, None)
    if dtype is None:
        raise ValueError(f"Unknown MLX dtype: {dtype_str}")
    return dtype  # type: ignore[return-value]


def _extract_layer_index(name: str) -> int | None:
    """Extract layer index from a weight name, or None for non-layer weights.

    Matches patterns like ``model.layers.5.self_attn.q_proj.weight``
    or ``transformer.h.12.mlp.gate_proj.scales``.
    """
    m = _LAYER_RE.search(name)
    return int(m.group(2)) if m else None


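# Expected behaviour of _extract_layer_index, assuming _LAYER_RE (defined earlier
# in this module) captures the numeric index in its second group:
#
#   _extract_layer_index("model.layers.5.self_attn.q_proj.weight")  -> 5
#   _extract_layer_index("model.embed_tokens.weight")               -> None
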
class WeightBroadcastState:
    """Holds state for layer-by-layer weight broadcasting.

    Created by :func:`prepare_weight_broadcast`. Callers stream weights
    incrementally via :meth:`broadcast_non_layer_weights` and
    :meth:`broadcast_layer` so that at most one layer's worth of un-sharded
    weight data is resident at a time.
    """

    def __init__(
        self,
        meta: dict[str, dict[str, Any]],
        source_weights: dict[str, mx.array] | None,
        group: Group,
        is_source: bool,
    ) -> None:
        self.meta = meta
        self.source_weights = source_weights
        self.group = group
        self.is_source = is_source

        # Partition weight names into layer vs. non-layer
        self.layer_names: dict[int, list[str]] = {}
        self.non_layer_names: list[str] = []
        for name in sorted(meta.keys()):
            layer_idx = _extract_layer_index(name)
            if layer_idx is not None:
                self.layer_names.setdefault(layer_idx, []).append(name)
            else:
                self.non_layer_names.append(name)

        logger.info(
            f"WeightBroadcastState: {len(self.non_layer_names)} non-layer weights, "
            f"{len(self.layer_names)} layers"
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _broadcast_names(self, names: list[str]) -> dict[str, mx.array]:
        """Broadcast a specific set of weight tensors by name."""
        all_sum = partial(_all_sum_cpu, group=self.group)
        result: dict[str, mx.array] = {}
        for name in names:
            info = self.meta[name]
            shape = cast(list[int], info["s"])
            dtype = _parse_mx_dtype(cast(str, info["d"]))

            if self.is_source:
                assert self.source_weights is not None
                tensor = self.source_weights.pop(name)
                mx.eval(tensor)  # loads from disk (lazy)
            else:
                tensor = mx.zeros(shape, dtype=dtype)

            broadcasted = all_sum(tensor)
            mx.eval(broadcasted)
            result[name] = broadcasted
        return result

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def broadcast_non_layer_weights(self) -> dict[str, mx.array]:
        """Broadcast non-layer weights (embeddings, norms, lm_head)."""
        if not self.non_layer_names:
            return {}
        logger.info(
            f"Broadcasting {len(self.non_layer_names)} non-layer weight tensors"
        )
        return self._broadcast_names(self.non_layer_names)

    def broadcast_layer(self, layer_idx: int) -> dict[str, mx.array]:
        """Broadcast weights for a single transformer layer."""
        names = self.layer_names.get(layer_idx, [])
        if not names:
            return {}
        return self._broadcast_names(names)


def prepare_weight_broadcast(
    model_path: Path,
    group: Group,
    is_source: bool,
) -> WeightBroadcastState:
    """Prepare for layer-by-layer weight broadcasting.

    Source loads safetensors lazily and broadcasts weight metadata (names,
    shapes, dtypes) as JSON. Returns a :class:`WeightBroadcastState` that
    can then stream weights incrementally via ``broadcast_layer()``.

    All ranks must call this function (collective operation).
    """
    source_weights: dict[str, mx.array] | None = None
    if is_source:
        source_weights = {}
        weight_files = sorted(model_path.glob("*.safetensors"))
        if not weight_files:
            weight_files = sorted(model_path.glob("**/*.safetensors"))
        for wf in weight_files:
            try:
                loaded = cast(
                    dict[str, mx.array],
                    mx.load(str(wf), lazy=True),  # pyright: ignore[reportCallIssue]
                )
            except TypeError:
                loaded = cast(dict[str, mx.array], mx.load(str(wf)))
            source_weights.update(loaded)
        logger.info(
            f"Source loaded {len(source_weights)} weight tensors (lazy) "
            f"from {len(weight_files)} files"
        )

    # Broadcast metadata
    if is_source and source_weights is not None:
        source_meta: dict[str, dict[str, Any]] = {
            name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
            for name, tensor in source_weights.items()
        }
    else:
        source_meta = {}

    meta = cast(
        dict[str, dict[str, Any]],
        _broadcast_json(source_meta if is_source else None, group, is_source),
    )

    logger.info(f"Weight broadcast prepared: {len(meta)} tensors")
    return WeightBroadcastState(meta, source_weights, group, is_source)


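# Per-rank streaming pattern (a sketch; the real call sites live in the runner's
# shard_and_load and supply their own model and group objects):
#
#   state = prepare_weight_broadcast(model_path, group, is_source)
#   model.load_weights(list(state.broadcast_non_layer_weights().items()), strict=False)
#   for layer_idx in sorted(state.layer_names.keys()):
#       layer_weights = state.broadcast_layer(layer_idx)
#       if layer_weights:
#           model.load_weights(list(layer_weights.items()), strict=False)
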
def broadcast_model_weights(
    model_path: Path,
    group: Group,
    is_source: bool,
) -> dict[str, mx.array]:
    """
    Broadcast model weight tensors from source rank to all receivers' memory.

    Source loads weights from .safetensors files on disk and broadcasts each
    tensor via all_sum. Receivers receive tensors directly as mx.arrays in
    memory — no disk write for weight data.

    All ranks must call this function (collective operation).
    Only the designated source (is_source=True) should send; all others receive.

    Returns:
        dict mapping weight names to mx.arrays (on all ranks).
    """
    all_sum = partial(_all_sum_cpu, group=group)

    # Source loads weights (lazy if supported, so only one tensor in memory at a time)
    weights: dict[str, mx.array] = {}
    if is_source:
        weight_files = sorted(model_path.glob("*.safetensors"))
        if not weight_files:
            weight_files = sorted(model_path.glob("**/*.safetensors"))
        for wf in weight_files:
            try:
                loaded = cast(dict[str, mx.array], mx.load(str(wf), lazy=True))  # pyright: ignore[reportCallIssue]
            except TypeError:
                loaded = cast(dict[str, mx.array], mx.load(str(wf)))
            weights.update(loaded)
        logger.info(
            f"Source loaded {len(weights)} weight tensors from {len(weight_files)} files"
        )

    # Broadcast weight metadata: {name: {shape, dtype}}
    if is_source:
        source_meta: dict[str, dict[str, Any]] = {
            name: {"s": list(tensor.shape), "d": str(tensor.dtype)}
            for name, tensor in weights.items()
        }
    else:
        source_meta = {}
    meta = cast(
        dict[str, dict[str, Any]],
        _broadcast_json(source_meta if is_source else None, group, is_source),
    )

    logger.info(f"Broadcasting {len(meta)} weight tensors")

    # Broadcast each tensor in sorted order (deterministic across ranks).
    # Source loads one tensor at a time from disk (lazy), broadcasts it,
    # then drops the reference so only one tensor is in flight at a time.
    result: dict[str, mx.array] = {}
    for i, name in enumerate(sorted(meta.keys())):
        info = meta[name]
        shape = cast(list[int], info["s"])
        dtype_str = cast(str, info["d"])
        dtype = _parse_mx_dtype(dtype_str)

        if is_source:
            tensor = weights.pop(name)  # pop to free lazy ref after broadcast
            mx.eval(tensor)  # loads from disk
        else:
            tensor = mx.zeros(shape, dtype=dtype)

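        # The sum below acts as a broadcast: only the source contributes a
        # non-zero tensor while every other rank contributes zeros, so the
        # all_sum result on every rank equals the source's tensor (e.g. with
        # three ranks, T + 0 + 0 == T).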
        broadcasted = all_sum(tensor)
        mx.eval(broadcasted)
        result[name] = broadcasted

        if (i + 1) % 100 == 0:
            logger.info(f" Broadcast {i + 1}/{len(meta)} tensors")

    logger.info(f"Weight broadcast complete: {len(result)} tensors")
    return result

@@ -2,7 +2,6 @@ import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast

@@ -60,13 +59,6 @@ from exo.worker.engines.mlx.auto_parallel import (
    pipeline_auto_parallel,
    tensor_auto_parallel,
)
from exo.worker.engines.mlx.model_transfer import (
    WeightBroadcastState,
    coordinate_transfer,
    model_path_for_id,
    prepare_weight_broadcast,
    transfer_metadata_files,
)
from exo.worker.runner.bootstrap import logger

Group = mx.distributed.Group
@@ -205,7 +197,6 @@ def load_mlx_items(
    bound_instance: BoundInstance,
    group: Group | None,
    on_timeout: TimeoutCallback | None = None,
    has_local_model: bool = True,
) -> tuple[Model, TokenizerWrapper]:
    if group is None:
        logger.info(f"Single device used for {bound_instance.instance}")
@@ -220,10 +211,7 @@ def load_mlx_items(
    logger.info("Starting distributed init")
    start_time = time.perf_counter()
    model, tokenizer = shard_and_load(
        bound_instance.bound_shard,
        group=group,
        on_timeout=on_timeout,
        has_local_model=has_local_model,
        bound_instance.bound_shard, group=group, on_timeout=on_timeout
    )
    end_time = time.perf_counter()
    logger.info(
@@ -239,89 +227,30 @@ def shard_and_load(
    shard_metadata: ShardMetadata,
    group: Group,
    on_timeout: TimeoutCallback | None = None,
    has_local_model: bool = True,
) -> tuple[nn.Module, TokenizerWrapper]:
    model_id = shard_metadata.model_card.model_id
    model_path = model_path_for_id(model_id)
    model_path = build_model_path(shard_metadata.model_card.model_id)

    # Coordinate: does any rank need a transfer?
    needs_transfer, source_rank = coordinate_transfer(group, has_local_model)
    is_source = group.rank() == source_rank

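    # coordinate_transfer is defined in model_transfer (not shown here); from this
    # call site it appears to agree, collectively, on whether any rank is missing
    # the model and which rank should act as the broadcast source.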
    # Step 1: Always ensure all nodes have metadata files (config, tokenizer, etc.).
    # This is cheap (~20MB, ~1s) and guarantees config.json is present for load_model().
    transfer_metadata_files(model_path, group, is_source)

    # Step 2: Only broadcast weights if some rank is missing the model
    broadcast_state: WeightBroadcastState | None = None
    if needs_transfer:
        logger.info(
            f"Model transfer needed (source_rank={source_rank}, "
            f"is_source={is_source}, local_weights={has_local_model})"
        )
        broadcast_state = prepare_weight_broadcast(model_path, group, is_source)

    # Create model architecture (all ranks have config.json on disk now).
    # Always use lazy=True when we have broadcast state: load_model's internal
    # nn.quantize skips quantization when weights dict is empty (no safetensors),
    # leaving the model un-quantized. lazy=False would then mx.eval() the full
    # fp16 model (~72GB for a 36B-param model), causing OOM on the receiver.
    # We handle quantization ourselves below before loading broadcast weights.
    use_lazy = has_local_model or broadcast_state is not None
    model, _ = load_model(model_path, lazy=use_lazy, strict=False)
    model, _ = load_model(model_path, lazy=True, strict=False)
    logger.debug(model)
    if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model):  # type: ignore
        pass
    # TODO: See if we should quantize the model.
    # def is_attention_layer(path: str) -> bool:
    # path = path.lower()

    # return "self_attn" in path and "layernorm" not in path

    # def quant_predicate(path: str, module: nn.Module):
    # if not isinstance(module, nn.Linear):
    # return False

    # return is_attention_layer(path)
    # model, config = quantize_model(
    # model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
    # )

    assert isinstance(model, nn.Module)

    if broadcast_state is not None:
        # When receiver has no weight files, load_model skips quantization
        # (its class_predicate checks `f"{p}.scales" in weights`, which is
        # always False when weights is empty). Apply quantization explicitly
        # using the broadcast metadata to determine which layers are quantized,
        # matching load_model's selective quantization logic exactly.
        if not has_local_model:
            config_path = model_path / "config.json"
            with open(config_path) as f:
                config = json.load(f)  # pyright: ignore[reportAny]
            quant_config: dict[str, Any] | None = config.get(  # pyright: ignore[reportAny]
                "quantization", None
            )
            if quant_config is not None:
                logger.info(f"Applying quantization to receiver model: {quant_config}")
                broadcast_weight_names = set(broadcast_state.meta.keys())

                def _class_predicate(p: str, m: nn.Module) -> bool | dict[str, Any]:
                    # Per-layer overrides from config (e.g. "lm_head": false)
                    assert quant_config is not None
                    if p in quant_config:
                        return quant_config[p]  # pyright: ignore[reportAny]
                    if not hasattr(m, "to_quantized"):
                        return False
                    # Only quantize layers whose .scales exist in broadcast weights
                    return f"{p}.scales" in broadcast_weight_names

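                # Illustrative behaviour of the predicate above (the paths are
                # examples only): a module path like "model.layers.0.mlp.gate_proj"
                # returns True when "model.layers.0.mlp.gate_proj.scales" appears
                # in the broadcast metadata, while norm or embedding paths without
                # a matching ".scales" entry return False and stay un-quantized.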
                group_size = int(quant_config.get("group_size", 64))  # pyright: ignore[reportAny]
                bits = int(quant_config.get("bits", 4))  # pyright: ignore[reportAny]
                mode: str = quant_config.get("mode", "affine")  # pyright: ignore[reportAny]
                nn.quantize(  # pyright: ignore[reportUnknownMemberType]
                    model,
                    group_size=group_size,
                    bits=bits,
                    mode=mode,
                    class_predicate=_class_predicate,
                )

        # Broadcast and load non-layer weights (embeddings, norms, lm_head) upfront.
        # These are small (~600MB) and needed before the sharding loop.
        non_layer_weights = broadcast_state.broadcast_non_layer_weights()
        if non_layer_weights:
            model.load_weights(list(non_layer_weights.items()), strict=False)
            logger.info(f"Loaded {len(non_layer_weights)} non-layer weight tensors")
        del non_layer_weights

    tokenizer = get_tokenizer(model_path, shard_metadata)

    logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
@@ -335,43 +264,12 @@ def shard_and_load(
        f"(model size: {model_size_gb:.1f}GB)"
    )

    # Build per-layer weight loader for streaming broadcast during sharding.
    # Each layer's weights are broadcast via all_sum just before that layer is
    # sharded, so at most one un-sharded layer is in memory at a time.
    weight_loader_fn: Callable[[nn.Module, int], None] | None = None
    if broadcast_state is not None:
        _state = broadcast_state  # capture for closure

        def _load_layer_weights(mdl: nn.Module, layer_idx: int) -> None:
            layer_weights = _state.broadcast_layer(layer_idx)
            if layer_weights:
                mdl.load_weights(list(layer_weights.items()), strict=False)

        weight_loader_fn = _load_layer_weights

    match shard_metadata:
        case TensorShardMetadata():
            logger.info(f"loading model from {model_path} with tensor parallelism")
            model = tensor_auto_parallel(
                model, group, timeout_seconds, on_timeout, weight_loader_fn
            )
            model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
        case PipelineShardMetadata():
            logger.info(f"loading model from {model_path} with pipeline parallelism")
            # Broadcast all layers (all_sum is collective — all ranks must
            # participate) but only load weights for layers this node will
            # keep after pipeline slicing. Out-of-range results are discarded,
            # keeping peak memory proportional to this node's layer count.
            if broadcast_state is not None:
                for layer_idx in sorted(broadcast_state.layer_names.keys()):
                    layer_weights = broadcast_state.broadcast_layer(layer_idx)
                    if (
                        shard_metadata.start_layer
                        <= layer_idx
                        < shard_metadata.end_layer
                        and layer_weights
                    ):
                        model.load_weights(list(layer_weights.items()), strict=False)
                    del layer_weights
            model = pipeline_auto_parallel(model, group, shard_metadata)
            eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
        case CfgShardMetadata():
@@ -380,8 +278,6 @@ def shard_and_load(
                "this metadata type is only for image generation models"
            )

    del broadcast_state

    # TODO: Do we need this?
    mx.eval(model)


@@ -2,7 +2,6 @@

from collections.abc import Mapping, Sequence

from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
    ConnectToGroup,
@@ -17,7 +16,6 @@ from exo.shared.types.tasks import (
    TaskId,
    TaskStatus,
    TextGeneration,
    TransferModelToDisk,
)
from exo.shared.types.worker.downloads import (
    DownloadCompleted,
@@ -36,11 +34,8 @@ from exo.shared.types.worker.runners import (
    RunnerLoading,
    RunnerReady,
    RunnerRunning,
    RunnerShutdown,
    RunnerShuttingDown,
    RunnerStatus,
    RunnerWarmingUp,
    ShardAssignments,
)
from exo.worker.runner.runner_supervisor import RunnerSupervisor

@@ -62,7 +57,6 @@ def plan(
        or _create_runner(node_id, runners, instances)
        or _model_needs_download(node_id, runners, global_download_status)
        or _init_distributed_backend(runners, all_runners)
        or _transfer_model_to_disk(runners, all_runners, global_download_status)
        or _load_model(runners, all_runners, global_download_status)
        or _ready_to_warmup(runners, all_runners)
        or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
@@ -127,10 +121,6 @@ def _model_needs_download(
    }

    for runner in runners.values():
        # Transfer-only instances don't need downloads
        if runner.bound_instance.instance.shard_assignments.transfer_only:
            continue

        model_id = runner.bound_instance.bound_shard.model_card.model_id
        if isinstance(runner.status, RunnerIdle) and (
            model_id not in download_status
@@ -139,15 +129,6 @@ def _model_needs_download(
                (DownloadOngoing, DownloadCompleted, DownloadFailed),
            )
        ):
            # For multi-node instances, skip download if a peer already has the model.
            # The model will be transferred via MLX distributed during LoadModel.
            instance = runner.bound_instance.instance
            is_multi_node = len(instance.shard_assignments.node_to_runner) > 1
            if is_multi_node and _any_peer_has_model(
                node_id, model_id, instance, global_download_status
            ):
                continue

            # We don't invalidate download_status randomly in case a file gets deleted on disk
            return DownloadModel(
                instance_id=runner.bound_instance.instance.instance_id,
@@ -205,43 +186,6 @@ def _init_distributed_backend(
    return None


def _transfer_model_to_disk(
    runners: Mapping[RunnerId, RunnerSupervisor],
    all_runners: Mapping[RunnerId, RunnerStatus],
    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> TransferModelToDisk | None:
    """For transfer-only instances: after all ranks are connected, emit TransferModelToDisk."""
    for runner in runners.values():
        instance = runner.bound_instance.instance
        shard_assignments = instance.shard_assignments

        if not shard_assignments.transfer_only:
            continue

        is_runner_connected = isinstance(runner.status, RunnerConnected)
        all_connected_or_further = all(
            isinstance(
                all_runners.get(global_runner_id, None),
                (RunnerConnected, RunnerLoading, RunnerShuttingDown, RunnerShutdown),
            )
            for global_runner_id in shard_assignments.runner_to_shard
        )

        if is_runner_connected and all_connected_or_further:
            has_local = _node_has_download(
                runner.bound_instance.bound_node_id,
                shard_assignments.model_id,
                global_download_status,
            )
            return TransferModelToDisk(
                instance_id=instance.instance_id,
                shard_metadata=runner.bound_instance.bound_shard,
                has_local_model=has_local,
            )

    return None


def _load_model(
    runners: Mapping[RunnerId, RunnerSupervisor],
    all_runners: Mapping[RunnerId, RunnerStatus],
@@ -251,97 +195,38 @@ def _load_model(
        instance = runner.bound_instance.instance
        shard_assignments = instance.shard_assignments

        # Transfer-only instances don't load models for inference
        if shard_assignments.transfer_only:
            all_local_downloads_complete = all(
                nid in global_download_status
                and any(
                    isinstance(dp, DownloadCompleted)
                    and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
                    for dp in global_download_status[nid]
                )
                for nid in shard_assignments.node_to_runner
            )
            if not all_local_downloads_complete:
                continue

        is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
        is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
        if is_single_node_instance and isinstance(runner.status, RunnerIdle):
            return LoadModel(instance_id=instance.instance_id)

        if is_single_node_instance:
            # Single-node: require local download complete
            if not _all_downloads_complete(shard_assignments, global_download_status):
                continue
            if isinstance(runner.status, RunnerIdle):
                return LoadModel(instance_id=instance.instance_id, has_local_model=True)
        else:
            # Multi-node: require at least one node to have the model downloaded.
            # Nodes without the model will receive it via MLX distributed transfer
            # during model loading.
            if not _any_download_complete(shard_assignments, global_download_status):
                continue
            is_runner_waiting = isinstance(runner.status, RunnerConnected)

        is_runner_waiting = isinstance(runner.status, RunnerConnected)
        all_ready_for_model = all(
            isinstance(
                all_runners.get(global_runner_id, None),
                (RunnerConnected, RunnerLoading, RunnerLoaded),
            )
            for global_runner_id in shard_assignments.runner_to_shard
            all_ready_for_model = all(
                isinstance(
                    all_runners.get(global_runner_id, None),
                    (RunnerConnected, RunnerLoading, RunnerLoaded),
                )
                for global_runner_id in shard_assignments.runner_to_shard
        )

        if is_runner_waiting and all_ready_for_model:
            has_local = _node_has_download(
                runner.bound_instance.bound_node_id,
                shard_assignments.model_id,
                global_download_status,
            )
            return LoadModel(
                instance_id=instance.instance_id,
                has_local_model=has_local,
            )
            if is_runner_waiting and all_ready_for_model:
                return LoadModel(instance_id=instance.instance_id)

    return None


def _node_has_download(
    nid: NodeId,
    model_id: ModelId,
    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
    """Check if a specific node has completed downloading the given model."""
    return any(
        isinstance(dp, DownloadCompleted)
        and dp.shard_metadata.model_card.model_id == model_id
        for dp in global_download_status.get(nid, [])
    )


def _any_peer_has_model(
    node_id: NodeId,
    model_id: ModelId,
    instance: Instance,
    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
    """Check if any other node in the instance already has the model downloaded."""
    return any(
        _node_has_download(nid, model_id, global_download_status)
        for nid in instance.shard_assignments.node_to_runner
        if nid != node_id
    )


def _all_downloads_complete(
    shard_assignments: ShardAssignments,
    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
    """Check if ALL nodes in the instance have completed downloading the model."""
    return all(
        _node_has_download(nid, shard_assignments.model_id, global_download_status)
        for nid in shard_assignments.node_to_runner
    )


def _any_download_complete(
    shard_assignments: ShardAssignments,
    global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
) -> bool:
    """Check if at least one node in the instance has completed downloading the model."""
    return any(
        _node_has_download(nid, shard_assignments.model_id, global_download_status)
        for nid in shard_assignments.node_to_runner
    )


def _ready_to_warmup(
    runners: Mapping[RunnerId, RunnerSupervisor],
    all_runners: Mapping[RunnerId, RunnerStatus],
@@ -349,11 +234,6 @@ def _ready_to_warmup(
    for runner in runners.values():
        instance = runner.bound_instance.instance
        shard_assignments = instance.shard_assignments

        # Transfer-only instances don't go through warmup
        if shard_assignments.transfer_only:
            continue

        shard = runner.bound_instance.bound_shard
        device_rank = shard.device_rank
        runner_id = runner.bound_instance.bound_runner_id

@@ -43,7 +43,6 @@ from exo.shared.types.tasks import (
    TaskId,
    TaskStatus,
    TextGeneration,
    TransferModelToDisk,
)
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import BoundInstance
@@ -83,11 +82,6 @@ from exo.worker.engines.image import (
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.model_transfer import (
    coordinate_transfer,
    model_path_for_id,
    transfer_all_files,
)
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
    detect_thinking_prompt_suffix,
@@ -198,10 +192,7 @@ def main(

    if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
        model, tokenizer = load_mlx_items(
            bound_instance,
            group,
            on_timeout=on_model_load_timeout,
            has_local_model=task.has_local_model,
            bound_instance, group, on_timeout=on_model_load_timeout
        )
        logger.info(
            f"model has_tool_calling={tokenizer.has_tool_calling}"
@@ -517,27 +508,6 @@ def main(

                    current_status = RunnerReady()
                    logger.info("runner ready")
                case TransferModelToDisk() if (
                    isinstance(current_status, RunnerConnected) and group is not None
                ):
                    logger.info("starting disk-to-disk model transfer")
                    event_sender.send(TaskAcknowledged(task_id=task.task_id))

                    model_path = model_path_for_id(
                        task.shard_metadata.model_card.model_id
                    )
                    _, source_rank = coordinate_transfer(group, task.has_local_model)
                    is_source = group.rank() == source_rank
                    transfer_all_files(model_path, group, is_source)

                    logger.info("disk-to-disk model transfer complete")
                    current_status = RunnerShuttingDown()
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    current_status = RunnerShutdown()
                case Shutdown():
                    current_status = RunnerShuttingDown()
                    logger.info("runner shutting down")

@@ -112,7 +112,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():

    assert isinstance(result, LoadModel)
    assert result.instance_id == INSTANCE_1_ID
    assert result.has_local_model is True


def test_plan_does_not_request_download_when_shard_already_downloaded():
@@ -158,11 +157,10 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
    assert not isinstance(result, plan_mod.DownloadModel)


def test_plan_loads_model_when_any_node_has_download_for_multi_node():
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
    """
    For multi-node instances, LoadModel should be emitted when at least one
    node has the model downloaded. Nodes without the model will receive it
    via MLX distributed transfer during model loading.
    LoadModel should not be emitted while some shards are still missing from
    the global_download_status.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
@@ -187,7 +185,6 @@ def test_plan_loads_model_when_any_node_has_download_for_multi_node():
        RUNNER_2_ID: RunnerConnected(),
    }

    # Only NODE_A has the model — LoadModel should still fire
    global_download_status = {
        NODE_A: [
            DownloadCompleted(
@@ -206,42 +203,19 @@ def test_plan_loads_model_when_any_node_has_download_for_multi_node():
        tasks={},
    )

    assert isinstance(result, LoadModel)
    assert result.instance_id == INSTANCE_1_ID
    assert result.has_local_model is True
    assert result is None


def test_plan_does_not_load_model_when_no_node_has_download():
    """
    LoadModel should not be emitted when no node has the model downloaded.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
    )

    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerConnected()
    )

    runners = {RUNNER_1_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerConnected(),
        RUNNER_2_ID: RunnerConnected(),
    }

    # No node has the model
    global_download_status: dict[NodeId, list[DownloadProgress]] = {
        NODE_A: [],
        NODE_B: [],
    global_download_status = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
        ],
        NODE_B: [
            DownloadCompleted(
                shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
            )
        ],  # NODE_B has no downloads completed yet
    }

    result = plan_mod.plan(
@@ -253,57 +227,4 @@ def test_plan_does_not_load_model_when_no_node_has_download():
        tasks={},
    )

    assert result is None


def test_plan_load_model_has_local_model_false_when_node_missing_download():
    """
    For multi-node instances, when the local node does NOT have the model
    but a peer does, LoadModel should be emitted with has_local_model=False.
    """
    shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
    shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
    instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
    )

    # NODE_B is the local node (bound_node_id=NODE_B), it does NOT have the model
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
    )
    local_runner = FakeRunnerSupervisor(
        bound_instance=bound_instance, status=RunnerConnected()
    )

    runners = {RUNNER_2_ID: local_runner}
    instances = {INSTANCE_1_ID: instance}
    all_runners = {
        RUNNER_1_ID: RunnerConnected(),
        RUNNER_2_ID: RunnerConnected(),
    }

    # Only NODE_A has the model, NODE_B does not
    global_download_status: dict[NodeId, list[DownloadProgress]] = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
            )
        ],
        NODE_B: [],
    }

    result = plan_mod.plan(
        node_id=NODE_B,
        runners=runners,  # type: ignore
        global_download_status=global_download_status,
        instances=instances,
        all_runners=all_runners,
        tasks={},
    )

    assert isinstance(result, LoadModel)
    assert result.instance_id == INSTANCE_1_ID
    assert result.has_local_model is False
    assert result is not None