Compare commits

...

9 Commits

Author SHA1 Message Date
Alex Cheema
e5a8f39db6 fix: add missing cancel_sender param to test supervisor builder
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 11:11:30 -08:00
Alex Cheema
4fa6d05651 fix: resolve lint/format issues after merging main and fix pytest collection
Add root conftest.py to exclude tests/start_distributed_test.py from
pytest collection (it calls sys.exit at module level). Fix ruff lint
issues (import sorting, f-string without placeholders, lambda loop
variable capture) and apply nix fmt formatting to e2e files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:05:13 -08:00
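The "lambda loop variable capture" lint issue fixed above is Python's classic late-binding gotcha; a minimal illustration of the bug and the default-argument fix (standalone sketch, not code from this diff):

```python
# Lambdas capture variables by reference, so every callback created in a loop
# sees the loop variable's final value once the loop has finished.
late = [lambda: i for i in range(3)]
assert [f() for f in late] == [2, 2, 2]

# Binding the current value as a default argument freezes it per iteration,
# the same pattern as the `lambda r=req: ...` used in the e2e chat helper.
bound = [lambda i=i: i for i in range(3)]
assert [f() for f in bound] == [0, 1, 2]
```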
Alex Cheema
587f06cb3e fix: add health check and heartbeat to RunnerSupervisor
Add proactive monitoring to detect runner process death and unresponsiveness:

- Health check loop polls is_alive() every 1s, detects unexpected exits
- Counter-based heartbeat detects frozen/unresponsive processes
- Emits RunnerFailed event and releases pending task waiters on failure
- Add EXO_RUNNER_MUST_DIE debug trigger for testing abrupt process death
- Add chaos E2E test that kills runner mid-inference

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:05:13 -08:00
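The counter-based heartbeat works by having the runner process increment a shared counter on a fixed interval while the supervisor polls it; if the value stops changing for several consecutive polls, the process is presumed frozen. A minimal single-process sketch of the staleness logic (illustrative names, not the actual RunnerSupervisor API):

```python
STALE_CHECKS = 10  # consecutive unchanged polls before declaring the runner frozen

def poll(counter: int, last_seen: int, stale_count: int) -> tuple[int, int, bool]:
    """One health-check tick: returns (new_last_seen, new_stale_count, is_frozen)."""
    if counter == last_seen:
        stale_count += 1
        return last_seen, stale_count, stale_count >= STALE_CHECKS
    return counter, 0, False

# A live runner keeps bumping the counter, so staleness never accumulates.
last, stale = 0, 0
for tick in range(1, 20):
    last, stale, frozen = poll(tick, last, stale)
    assert not frozen

# A frozen runner leaves the counter stuck; after STALE_CHECKS polls we flag it.
for _ in range(STALE_CHECKS):
    last, stale, frozen = poll(19, last, stale)
assert frozen
```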
Alex Cheema
386c78dc99 fix: skip slow inference test in CI, run with --slow
MLX CPU inference on x86_64 is too slow for CI runners (~10min+ for
a single request). Mark the inference snapshot test as slow so it's
skipped by default. Run with --slow or E2E_SLOW=1 on Apple Silicon.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:04:07 -08:00
Alex Cheema
9eed3bc467 feat: add deterministic inference snapshot test
Launch mlx-community/Qwen3-0.6B-4bit on the cluster, send a chat
completion with seed=42 and temperature=0, and verify the output
matches a committed snapshot. Tests inference determinism end-to-end.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:04:07 -08:00
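The snapshot flow is create-on-first-run: when the snapshot file is missing, the observed output is recorded; otherwise the output must match exactly. A condensed sketch of that pattern (simplified from the test added in this diff, not its exact code):

```python
import json
import tempfile
from pathlib import Path

def check_snapshot(path: Path, content: str) -> str:
    """Compare `content` against the snapshot at `path`, creating it if absent.
    Returns "created" or "matched"; raises AssertionError on mismatch."""
    if path.exists():
        expected = json.loads(path.read_text())["content"]
        assert content == expected, f"Snapshot mismatch: {expected!r} != {content!r}"
        return "matched"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps({"content": content}, indent=2) + "\n")
    return "created"

# First run records the output; subsequent identical runs match it.
with tempfile.TemporaryDirectory() as d:
    snap = Path(d) / "inference.json"
    assert check_snapshot(snap, "4") == "created"
    assert check_snapshot(snap, "4") == "matched"
```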
Alex Cheema
edee9c9306 fix: make no_internet test actually block internet with iptables
Use iptables to block all outbound traffic except private subnets and
multicast (for mDNS discovery). Verify internet is blocked by curling
huggingface.co from inside each container and checking exo logs for
"Internet connectivity: False".

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:04:07 -08:00
Alex Cheema
a454cd3da3 fix: reduce Docker image size and free more CI disk space
Clean up Rust target/ and cargo registry after uv sync in the same RUN
command so build artifacts aren't committed to the layer (~1-2 GB saved).
Also remove more unused toolchains from the CI runner.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:04:07 -08:00
Alex Cheema
d477dc9f67 fix: free disk space in CI before Docker build
The runner was running out of disk space during the Docker image build
(Rust compilation + Python deps). Remove unused toolchains first.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:04:07 -08:00
Alex Cheema
79862ccfd2 feat: add Docker-based E2E test framework
Add a Python/asyncio E2E test framework that spins up 2-node exo clusters
in Docker Compose and verifies cluster formation, discovery, election, and
API health. Includes a no-internet chaos test using DNS blocking.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 10:04:07 -08:00
17 changed files with 978 additions and 5 deletions

15
.dockerignore Normal file

@@ -0,0 +1,15 @@
.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/

29
.github/workflows/e2e.yml vendored Normal file

@@ -0,0 +1,29 @@
name: e2e-tests
on:
push:
pull_request:
branches:
- staging
- main
jobs:
e2e:
runs-on: ubuntu-latest
timeout-minutes: 30
steps:
- name: Free up disk space
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
/opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
/opt/microsoft /opt/az
docker system prune -af
df -h /
- name: Checkout repository
uses: actions/checkout@v4
with:
lfs: false
- name: Run E2E tests
run: python3 e2e/run_all.py

1
conftest.py Normal file

@@ -0,0 +1 @@
collect_ignore = ["tests/start_distributed_test.py"]

53
e2e/Dockerfile Normal file

@@ -0,0 +1,53 @@
# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build
# Stage 2: Build and run exo
FROM python:3.13-slim
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
pkg-config \
libssl-dev \
curl \
protobuf-compiler \
iptables \
&& rm -rf /var/lib/apt/lists/*
# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml
# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/
# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
chmod +x /usr/bin/g++
CMD [".venv/bin/exo", "-v"]

182
e2e/conftest.py Normal file

@@ -0,0 +1,182 @@
"""Shared E2E test infrastructure for exo cluster tests."""
import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
class Cluster:
"""Async wrapper around a docker compose exo cluster."""
def __init__(self, name: str, overrides: list[str] | None = None):
self.name = name
self.project = f"e2e-{name}"
compose_files = [str(E2E_DIR / "docker-compose.yml")]
for path in overrides or []:
compose_files.append(str(E2E_DIR / path))
self._compose_base = [
"docker",
"compose",
"-p",
self.project,
*[arg for f in compose_files for arg in ("-f", f)],
]
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
await self.stop()
async def _run(self, *args: str, check: bool = True) -> str:
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
print(output, file=sys.stderr)
raise RuntimeError(
f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
)
return output
async def build(self):
print(" Building images...")
await self._run("build", "--quiet")
async def start(self):
print(" Starting cluster...")
await self._run("up", "-d")
async def stop(self):
print(" Cleaning up...")
await self._run("down", "--timeout", "5", check=False)
async def logs(self) -> str:
return await self._run("logs", check=False)
async def exec(
self, service: str, *cmd: str, check: bool = True
) -> tuple[int, str]:
"""Run a command inside a running container. Returns (returncode, output)."""
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
"exec",
"-T",
service,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
raise RuntimeError(
f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
)
return proc.returncode, output
async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
"""Poll check_fn every 2s until it returns True or timeout expires."""
print(f" Waiting for {description}...")
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if await check_fn():
print(f" {description}")
return
await asyncio.sleep(2)
output = await self.logs()
print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
raise TimeoutError(f"Timed out waiting for {description}")
async def assert_healthy(self):
"""Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""
async def both_nodes_started():
log = await self.logs()
return log.count("Starting node") >= 2
async def nodes_discovered():
log = await self.logs()
return log.count("ConnectionMessageType.Connected") >= 2
async def master_elected():
log = await self.logs()
return "demoting self" in log
async def api_responding():
try:
with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
return resp.status == 200
except (URLError, OSError):
return False
await self.wait_for("Both nodes started", both_nodes_started)
await self.wait_for("Nodes discovered each other", nodes_discovered)
await self.wait_for("Master election resolved", master_elected)
await self.wait_for("API responding", api_responding)
async def _api(
self, method: str, path: str, body: dict | None = None, timeout: int = 30
) -> dict:
"""Make an API request to the cluster. Returns parsed JSON."""
url = f"http://localhost:52415{path}"
data = json.dumps(body).encode() if body else None
req = Request(
url, data=data, headers={"Content-Type": "application/json"}, method=method
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda: urlopen(req, timeout=timeout).read()
)
return json.loads(resp_bytes)
async def place_model(self, model: str, timeout: int = 600):
"""Place a model instance on the cluster (triggers download) and wait until it's ready."""
await self._api("POST", "/place_instance", {"model_id": model})
async def model_ready():
try:
resp = await self._api("GET", "/v1/models")
return any(m.get("id") == model for m in resp.get("data", []))
except Exception:
return False
await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)
async def chat(
self, model: str, messages: list[dict], timeout: int = 600, **kwargs
) -> dict:
"""Send a chat completion request. Retries until model is downloaded and inference completes."""
body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
deadline = asyncio.get_event_loop().time() + timeout
last_error = None
while asyncio.get_event_loop().time() < deadline:
try:
req = Request(
"http://localhost:52415/v1/chat/completions",
data=body,
headers={"Content-Type": "application/json"},
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda r=req: urlopen(r, timeout=300).read()
)
return json.loads(resp_bytes)
except Exception as e:
last_error = e
await asyncio.sleep(5)
raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")

18
e2e/docker-compose.yml Normal file

@@ -0,0 +1,18 @@
services:
exo-node-1:
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]
ports:
- "52415:52415"
exo-node-2:
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]

75
e2e/run_all.py Normal file

@@ -0,0 +1,75 @@
#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.
Tests that put 'slow' on its own line in their module docstring are skipped
unless --slow is passed or E2E_SLOW=1 is set.
"""
import os
import subprocess
import sys
from pathlib import Path
E2E_DIR = Path(__file__).parent.resolve()
def is_slow(test_file: Path) -> bool:
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
with open(test_file) as f:
for line in f:
if line.strip().startswith("#"):
continue
if line.strip().startswith('"""') or line.strip().startswith("'''"):
# Read into the docstring
for doc_line in f:
if "slow" in doc_line.lower() and doc_line.strip().startswith(
"slow"
):
return True
if '"""' in doc_line or "'''" in doc_line:
break
break
return False
def main():
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
test_files = sorted(E2E_DIR.glob("test_*.py"))
if not test_files:
print("No test files found")
sys.exit(1)
passed = 0
failed = 0
skipped = 0
failures = []
for test_file in test_files:
name = test_file.stem
if is_slow(test_file) and not run_slow:
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
skipped += 1
continue
print(f"=== {name} ===")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
failed += 1
failures.append(name)
print()
total = passed + failed + skipped
print("================================")
print(
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
)
if failed:
print(f"Failed: {' '.join(failures)}")
sys.exit(1)
if __name__ == "__main__":
main()


@@ -0,0 +1,8 @@
{
"model": "mlx-community/Qwen3-0.6B-4bit",
"seed": 42,
"temperature": 0,
"prompt": "What is 2+2? Reply with just the number.",
"max_tokens": 32,
"content": "<think>\nOkay, so I need to figure out what 2+2 is. Let me think. Well, if you add 2 and 2 together"
}


@@ -0,0 +1,22 @@
"""Test: Basic cluster formation.
Verifies two nodes discover each other, elect a master, and the API responds.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster("cluster_formation") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print("PASSED: cluster_formation")
if __name__ == "__main__":
asyncio.run(main())


@@ -0,0 +1,82 @@
"""Test: Deterministic inference output (snapshot test).
slow
Sends a chat completion request with a fixed seed and temperature=0,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.
Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32
SNAPSHOT_FILE = Path(__file__).parent / "snapshots" / "inference.json"
async def main():
async with Cluster("inference_snapshot") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Launch the model instance (triggers download + placement)
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED}, temperature=0)...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
# Load or create snapshot
if SNAPSHOT_FILE.exists():
snapshot = json.loads(SNAPSHOT_FILE.read_text())
expected = snapshot["content"]
assert content == expected, (
f"Snapshot mismatch!\n"
f" Expected: {expected!r}\n"
f" Got: {content!r}\n"
f" Delete {SNAPSHOT_FILE} to regenerate."
)
print(" Output matches snapshot")
else:
SNAPSHOT_FILE.parent.mkdir(parents=True, exist_ok=True)
SNAPSHOT_FILE.write_text(
json.dumps(
{
"model": MODEL,
"seed": SEED,
"temperature": 0,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
"content": content,
},
indent=2,
)
+ "\n"
)
print(f" Snapshot created: {SNAPSHOT_FILE}")
print("PASSED: inference_snapshot")
if __name__ == "__main__":
asyncio.run(main())

47
e2e/test_no_internet.py Normal file

@@ -0,0 +1,47 @@
"""Test: Cluster works without internet access.
Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster(
"no_internet",
overrides=["tests/no_internet/docker-compose.override.yml"],
) as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Verify internet is actually blocked from inside the containers
for node in ["exo-node-1", "exo-node-2"]:
rc, _ = await cluster.exec(
node,
"curl",
"-sf",
"--max-time",
"3",
"https://huggingface.co",
check=False,
)
assert rc != 0, f"{node} should not be able to reach the internet"
print(f" {node}: internet correctly blocked")
# Verify exo detected no internet connectivity
log = await cluster.logs()
assert "Internet connectivity: False" in log, "exo should detect no internet"
print(" exo correctly detected no internet connectivity")
print("PASSED: no_internet")
if __name__ == "__main__":
asyncio.run(main())

65
e2e/test_runner_chaos.py Normal file

@@ -0,0 +1,65 @@
"""Test: Runner chaos — abrupt runner death detection.
slow
Sends a chat completion with the EXO_RUNNER_MUST_DIE trigger, which causes
the runner process to call os._exit(1) (simulating an OOM kill). Verifies that
the RunnerSupervisor health check detects the death and the system doesn't hang.
Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import contextlib
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
async def main():
async with Cluster("runner_chaos") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Place the model so a runner is loaded and ready
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
# Send a chat request with the die trigger.
# The runner will call os._exit(1) mid-inference, simulating OOM kill.
# The chat request itself will fail — that's expected.
print(" Sending EXO_RUNNER_MUST_DIE trigger...")
with contextlib.suppress(Exception):
await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": "EXO RUNNER MUST DIE"}],
timeout=60,
)
# Wait for the health check to detect the death and emit RunnerFailed
async def health_check_detected():
log = await cluster.logs()
return "runner process died unexpectedly" in log
await cluster.wait_for(
"Health check detected runner death",
health_check_detected,
timeout=30,
)
# Verify RunnerFailed was emitted (visible in logs)
log = await cluster.logs()
assert "runner process died unexpectedly" in log, (
f"Expected health check to detect runner death but it didn't.\nLogs:\n{log}"
)
print("PASSED: runner_chaos")
if __name__ == "__main__":
asyncio.run(main())


@@ -0,0 +1,32 @@
# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
exo-node-1:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v
exo-node-2:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v


@@ -1,4 +1,8 @@
from __future__ import annotations
import os
import threading
from multiprocessing.sharedctypes import Synchronized
import loguru
@@ -10,6 +14,15 @@ from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
logger: "loguru.Logger" = loguru.logger
HEARTBEAT_INTERVAL_SECONDS = 0.5
def _heartbeat_loop(heartbeat: Synchronized[int], stop: threading.Event) -> None:
"""Daemon thread that periodically increments the heartbeat counter."""
while not stop.is_set():
heartbeat.value += 1
stop.wait(HEARTBEAT_INTERVAL_SECONDS)
def entrypoint(
bound_instance: BoundInstance,
@@ -17,6 +30,7 @@ def entrypoint(
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
heartbeat: Synchronized[int] | None = None,
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
@@ -35,6 +49,17 @@ def entrypoint(
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Start heartbeat thread so the supervisor can detect if we freeze.
stop_heartbeat = threading.Event()
heartbeat_thread: threading.Thread | None = None
if heartbeat is not None:
heartbeat_thread = threading.Thread(
target=_heartbeat_loop,
args=(heartbeat, stop_heartbeat),
daemon=True,
)
heartbeat_thread.start()
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main
@@ -53,6 +78,9 @@ def entrypoint(
)
)
finally:
stop_heartbeat.set()
if heartbeat_thread is not None:
heartbeat_thread.join(timeout=1)
try:
event_sender.close()
task_receiver.close()


@@ -1,6 +1,7 @@
import base64
import json
import math
import os
import resource
import time
from collections.abc import Generator
@@ -999,6 +1000,7 @@ def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
EXO_RUNNER_MUST_DIE = "EXO RUNNER MUST DIE"
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
@@ -1014,6 +1016,9 @@ def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
if not prompt:
return
if EXO_RUNNER_MUST_DIE in prompt:
logger.info("Abrupt process death triggered (simulates OOM kill)")
os._exit(1)
if EXO_RUNNER_MUST_FAIL in prompt:
logger.info("raising exception")
raise Exception("Artificial runner exception - for testing purposes only.")


@@ -1,12 +1,17 @@
from __future__ import annotations
import contextlib
import multiprocessing
import signal
from dataclasses import dataclass, field
from multiprocessing import Process
from multiprocessing.sharedctypes import Synchronized
from typing import Self
import anyio
from anyio import (
BrokenResourceError,
CancelScope,
ClosedResourceError,
to_thread,
)
@@ -26,6 +31,7 @@ from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerLoading,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
@@ -36,6 +42,8 @@ from exo.worker.runner.bootstrap import entrypoint
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
HEALTH_CHECK_INTERVAL_SECONDS = 1
HEARTBEAT_STALE_CHECKS = 10
@dataclass(eq=False)
@@ -48,10 +56,14 @@ class RunnerSupervisor:
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_heartbeat: Synchronized[int]
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
_death_handled: bool = field(default=False, init=False)
_last_heartbeat_value: int = field(default=0, init=False)
_heartbeat_stale_count: int = field(default=0, init=False)
@classmethod
def create(
@@ -65,6 +77,8 @@ class RunnerSupervisor:
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
heartbeat: Synchronized[int] = multiprocessing.Value("Q", 0)
runner_process = Process(
target=entrypoint,
args=(
@@ -73,6 +87,7 @@ class RunnerSupervisor:
task_recv,
cancel_recv,
logger,
heartbeat,
),
daemon=True,
)
@@ -88,13 +103,16 @@ class RunnerSupervisor:
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
_heartbeat=heartbeat,
)
return self
async def run(self):
self.runner_process.start()
await self._forward_events()
async with anyio.create_task_group() as tg:
tg.start_soon(self._forward_events)
tg.start_soon(self._health_check, tg.cancel_scope)
def shutdown(self):
logger.info("Runner supervisor shutting down")
@@ -177,9 +195,99 @@ class RunnerSupervisor:
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
if not self._death_handled:
self._death_handled = True
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
async def _health_check(self, cancel_scope: CancelScope) -> None:
"""Periodically check if the runner process is alive and responsive.
Detects two failure modes:
1. Process death (e.g. OOM kill) without cleanly closing the event
channel, which would leave _forward_events blocked on queue.get().
2. Unresponsive process (e.g. frozen by OS memory pressure, deadlock)
detected via a stale heartbeat counter.
"""
while True:
await anyio.sleep(HEALTH_CHECK_INTERVAL_SECONDS)
if not self.runner_process.is_alive():
self._handle_process_exit(cancel_scope)
return
# Check heartbeat counter — if it hasn't changed between
# consecutive checks, the subprocess may be frozen.
current = self._heartbeat.value
if current > 0:
if current == self._last_heartbeat_value:
self._heartbeat_stale_count += 1
if self._heartbeat_stale_count >= HEARTBEAT_STALE_CHECKS:
logger.error(
f"Health check: runner process unresponsive "
f"(heartbeat stale for {self._heartbeat_stale_count} checks), killing"
)
self._handle_unresponsive(cancel_scope)
return
else:
self._heartbeat_stale_count = 0
self._last_heartbeat_value = current
def _handle_process_exit(self, cancel_scope: CancelScope) -> None:
"""Handle runner process that has exited."""
if not self._death_handled:
self._death_handled = True
if isinstance(
self.status, (RunnerShutdown, RunnerShuttingDown, RunnerFailed)
):
logger.info("Health check: runner process exited (expected)")
else:
rc = self.runner_process.exitcode
if isinstance(rc, int) and rc < 0:
sig = -rc
try:
cause = f"signal={sig} ({signal.strsignal(sig)})"
except Exception:
cause = f"signal={sig}"
else:
cause = f"exitcode={rc}"
logger.error(
f"Health check: runner process died unexpectedly ({cause})"
)
self._event_sender.send_nowait(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(
error_message=f"Terminated ({cause})"
),
)
)
self.shutdown()
for tid in self.pending:
self.pending[tid].set()
cancel_scope.cancel()
def _handle_unresponsive(self, cancel_scope: CancelScope) -> None:
"""Handle runner process that is alive but unresponsive."""
if not self._death_handled:
self._death_handled = True
self._event_sender.send_nowait(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(
error_message="Runner process unresponsive (heartbeat timeout)"
),
)
)
for tid in self.pending:
self.pending[tid].set()
self.shutdown()
cancel_scope.cancel()
def __del__(self) -> None:
if self.runner_process.is_alive():


@@ -1 +1,204 @@
# TODO:
from __future__ import annotations
import multiprocessing
import os
import signal as signal_module
from collections.abc import Callable
from multiprocessing.sharedctypes import Synchronized
from typing import Any
import anyio
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.runners import RunnerFailed, RunnerIdle, RunnerShutdown
from exo.utils.channels import Receiver, Sender, channel, mp_channel
from exo.worker.runner.runner_supervisor import (
HEALTH_CHECK_INTERVAL_SECONDS,
HEARTBEAT_STALE_CHECKS,
RunnerSupervisor,
)
from ...constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from ..conftest import get_bound_mlx_ring_instance
def _die_immediately() -> None:
"""Subprocess target that exits with a non-zero code."""
os._exit(1)
def _die_with_signal() -> None:
"""Subprocess target that kills itself with SIGKILL (simulates OOM)."""
os.kill(os.getpid(), signal_module.SIGKILL)
def _exit_cleanly() -> None:
"""Subprocess target that exits with code 0."""
os._exit(0)
def _hang_forever() -> None:
"""Subprocess target that hangs without ever touching the heartbeat
(simulates a freeze; the test pre-seeds the counter itself)."""
import time
time.sleep(100000)
def _build_supervisor(
event_sender: Sender[Event],
target: Callable[..., Any],
) -> RunnerSupervisor:
"""Build a RunnerSupervisor with a custom subprocess target.
Uses a clone of event_sender (matching real Worker behavior) so that
closing the supervisor's copy doesn't close the test's receiver.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NODE_A,
)
_ev_send, ev_recv = mp_channel[Event]()
task_sender, _task_recv = mp_channel[Task]()
cancel_sender, _cancel_recv = mp_channel[TaskId]()
runner_process = multiprocessing.Process(target=target, daemon=True)
heartbeat: Synchronized[int] = multiprocessing.Value("Q", 0)
return RunnerSupervisor(
bound_instance=bound_instance,
shard_metadata=bound_instance.bound_shard,
runner_process=runner_process,
initialize_timeout=10,
_ev_recv=ev_recv,
_task_sender=task_sender,
_event_sender=event_sender.clone(),
_cancel_sender=cancel_sender,
_heartbeat=heartbeat,
)
def _collect_failed_events(
event_receiver: Receiver[Event],
) -> list[RunnerFailed]:
"""Drain the receiver and return all RunnerFailed statuses."""
out: list[RunnerFailed] = []
while True:
try:
event = event_receiver.receive_nowait()
except Exception:
break
if isinstance(event, RunnerStatusUpdated) and isinstance(
event.runner_status, RunnerFailed
):
out.append(event.runner_status)
return out
async def test_health_check_detects_dead_process():
"""When the runner process dies with a non-zero exit code, the health check
should emit a RunnerFailed event and run() should return."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _die_immediately)
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "exitcode=1" in failures[0].error_message
async def test_health_check_detects_signal_death():
"""When the runner process is killed by a signal (e.g. OOM -> SIGKILL),
the health check should report the signal in the failure message."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _die_with_signal)
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "signal=9" in failures[0].error_message
async def test_health_check_releases_pending_tasks():
"""When the runner dies, any pending start_task() waiters should be unblocked."""
event_sender, _event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _die_immediately)
# Register a pending waiter as if start_task() was waiting for acknowledgement
task_event = anyio.Event()
tid = TaskId("pending-task")
supervisor.pending[tid] = task_event
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
assert task_event.is_set()
async def test_clean_exit_no_failure_when_shutdown_status():
"""When the runner was in RunnerShutdown status and exits with code 0,
no RunnerFailed event should be emitted."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _exit_cleanly)
# Simulate that the runner had already reported shutdown via events
supervisor.status = RunnerShutdown()
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 0
async def test_unexpected_exit_code_zero_emits_failure():
"""When the runner exits with code 0 but was NOT in a shutdown state,
this is unexpected and should still emit RunnerFailed."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _exit_cleanly)
assert isinstance(supervisor.status, RunnerIdle)
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "exitcode=0" in failures[0].error_message
async def test_heartbeat_timeout_detects_unresponsive_process():
"""When the runner process is alive but its heartbeat goes stale,
the health check should kill it and emit RunnerFailed."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _hang_forever)
# Pre-seed the heartbeat counter with a non-zero value and set the
# supervisor's last-seen value to match so it appears stale immediately.
# Set stale count to HEARTBEAT_STALE_CHECKS - 1 so a single check triggers.
supervisor._heartbeat.value = 42 # pyright: ignore[reportPrivateUsage]
supervisor._last_heartbeat_value = 42 # pyright: ignore[reportPrivateUsage]
supervisor._heartbeat_stale_count = HEARTBEAT_STALE_CHECKS - 1 # pyright: ignore[reportPrivateUsage]
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "unresponsive" in failures[0].error_message.lower()