Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-20 07:46:42 -05:00)

Compare commits: session-id...e2e-tests (23 commits)
Commits in this compare (SHA1 prefixes only; author and date columns were not captured in this mirror view):

a288401a7f, b36721e6d9, 671e5de248, 5ec7b35841, 92b20128a7, 1fa4d3b087,
d9c884b9df, ba934cf237, 54a036223b, 199fb1c9fa, cc7850180d, 37af43e52f,
effafc1d48, 3fb663ec25, c2f9034914, 2cce4e8f04, ce82486a79, 6e7dc0b042,
acdf2751a3, 0035ac3cf2, 26621503c8, b9dcea2b4e, 9bcc0f968b
.dockerignore (new file, 15 lines):

.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/
.github/workflows/e2e.yml (new file, 44 lines):

name: e2e-tests

on:
  push:
    branches:
      - e2e-tests
  pull_request:
    branches:
      - staging
      - main

jobs:
  e2e:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - name: Free up disk space
        run: |
          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
            /opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
            /opt/microsoft /opt/az
          docker system prune -af
          df -h /

      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Build E2E image with cache
        uses: docker/build-push-action@v6
        with:
          context: .
          file: e2e/Dockerfile
          tags: exo-e2e:latest
          load: true
          cache-from: type=gha
          cache-to: type=gha,mode=max

      - name: Run E2E tests
        run: python3 e2e/run_all.py
conftest.py (new file, 1 line):

collect_ignore = ["tests/start_distributed_test.py"]
e2e/Dockerfile (new file, 58 lines):

# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build

# Stage 2: Build and run exo
FROM python:3.13-slim

# Install system dependencies
# libblas-dev/liblapack-dev/liblapacke-dev are required by the MLX CPU backend on Linux
RUN apt-get update && apt-get install -y \
    build-essential \
    pkg-config \
    libssl-dev \
    libblas-dev \
    liblapack-dev \
    liblapacke-dev \
    curl \
    protobuf-compiler \
    iptables \
    && rm -rf /var/lib/apt/lists/*

# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"

# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
# Must be done BEFORE uv sync so any source builds also get the fix
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
    printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
    chmod +x /usr/bin/g++

# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv

WORKDIR /app

# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml

# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/

# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/

# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git

CMD [".venv/bin/exo", "-v"]
e2e/conftest.py (new file, 195 lines):

"""Shared E2E test infrastructure for exo cluster tests."""

import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen

E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))


class Cluster:
    """Async wrapper around a docker compose exo cluster."""

    def __init__(self, name: str, overrides: list[str] | None = None):
        self.name = name
        self.project = f"e2e-{name}"
        compose_files = [str(E2E_DIR / "docker-compose.yml")]
        for path in overrides or []:
            compose_files.append(str(E2E_DIR / path))
        self._compose_base = [
            "docker",
            "compose",
            "-p",
            self.project,
            *[arg for f in compose_files for arg in ("-f", f)],
        ]

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        await self.stop()

    async def _run(self, *args: str, check: bool = True) -> str:
        proc = await asyncio.create_subprocess_exec(
            *self._compose_base,
            *args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        stdout, _ = await proc.communicate()
        output = stdout.decode()
        if check and proc.returncode != 0:
            print(output, file=sys.stderr)
            raise RuntimeError(
                f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
            )
        return output

    async def build(self):
        # Skip build if the image was pre-built (e.g. in CI with buildx cache)
        proc = await asyncio.create_subprocess_exec(
            "docker",
            "image",
            "inspect",
            "exo-e2e:latest",
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.DEVNULL,
        )
        await proc.wait()
        if proc.returncode == 0:
            print("  Using pre-built image (exo-e2e:latest)")
            return
        print("  Building images...")
        await self._run("build", "--quiet")

    async def start(self):
        print("  Starting cluster...")
        await self._run("up", "-d")

    async def stop(self):
        print("  Cleaning up...")
        await self._run("down", "--timeout", "5", check=False)

    async def logs(self) -> str:
        return await self._run("logs", check=False)

    async def exec(
        self, service: str, *cmd: str, check: bool = True
    ) -> tuple[int, str]:
        """Run a command inside a running container. Returns (returncode, output)."""
        proc = await asyncio.create_subprocess_exec(
            *self._compose_base,
            "exec",
            "-T",
            service,
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        stdout, _ = await proc.communicate()
        output = stdout.decode()
        if check and proc.returncode != 0:
            raise RuntimeError(
                f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
            )
        return proc.returncode, output

    async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
        """Poll check_fn every 2s until it returns True or timeout expires."""
        print(f"  Waiting for {description}...")
        deadline = asyncio.get_event_loop().time() + timeout
        while asyncio.get_event_loop().time() < deadline:
            if await check_fn():
                print(f"  {description}")
                return
            await asyncio.sleep(2)
        output = await self.logs()
        print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
        raise TimeoutError(f"Timed out waiting for {description}")

    async def assert_healthy(self):
        """Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""

        async def both_nodes_started():
            log = await self.logs()
            return log.count("Starting node") >= 2

        async def nodes_discovered():
            log = await self.logs()
            return log.count("ConnectionMessageType.Connected") >= 2

        async def master_elected():
            log = await self.logs()
            return "demoting self" in log

        async def api_responding():
            try:
                with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
                    return resp.status == 200
            except (URLError, OSError):
                return False

        await self.wait_for("Both nodes started", both_nodes_started)
        await self.wait_for("Nodes discovered each other", nodes_discovered)
        await self.wait_for("Master election resolved", master_elected)
        await self.wait_for("API responding", api_responding)

    async def _api(
        self, method: str, path: str, body: dict | None = None, timeout: int = 30
    ) -> dict:
        """Make an API request to the cluster. Returns parsed JSON."""
        url = f"http://localhost:52415{path}"
        data = json.dumps(body).encode() if body else None
        req = Request(
            url, data=data, headers={"Content-Type": "application/json"}, method=method
        )
        loop = asyncio.get_event_loop()
        resp_bytes = await loop.run_in_executor(
            None, lambda: urlopen(req, timeout=timeout).read()
        )
        return json.loads(resp_bytes)

    async def place_model(self, model: str, timeout: int = 600):
        """Place a model instance on the cluster (triggers download) and wait until it's ready."""
        await self._api("POST", "/place_instance", {"model_id": model})

        async def model_ready():
            try:
                resp = await self._api("GET", "/v1/models")
                return any(m.get("id") == model for m in resp.get("data", []))
            except Exception:
                return False

        await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)

    async def chat(
        self, model: str, messages: list[dict], timeout: int = 600, **kwargs
    ) -> dict:
        """Send a chat completion request. Retries until model is downloaded and inference completes."""
        body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
        deadline = asyncio.get_event_loop().time() + timeout
        last_error = None

        while asyncio.get_event_loop().time() < deadline:
            try:
                req = Request(
                    "http://localhost:52415/v1/chat/completions",
                    data=body,
                    headers={"Content-Type": "application/json"},
                )
                loop = asyncio.get_event_loop()
                resp_bytes = await loop.run_in_executor(
                    None, lambda r=req: urlopen(r, timeout=300).read()
                )
                return json.loads(resp_bytes)
            except Exception as e:
                last_error = e
                await asyncio.sleep(5)

        raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")
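For orientation, here is a minimal sketch of how a test script drives this helper. The cluster name "smoke" and the prompt are placeholders; the model id and every method shown come from the files in this diff:

import asyncio

from conftest import Cluster  # the e2e/conftest.py above


async def smoke():
    async with Cluster("smoke") as cluster:
        await cluster.build()           # no-op if exo-e2e:latest already exists
        await cluster.start()           # docker compose -p e2e-smoke up -d
        await cluster.assert_healthy()  # discovery, election, API checks
        await cluster.place_model("mlx-community/Qwen3-0.6B-4bit")
        resp = await cluster.chat(
            model="mlx-community/Qwen3-0.6B-4bit",
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=8,
        )
        print(resp["choices"][0]["message"]["content"])


asyncio.run(smoke())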
e2e/docker-compose.yml (new file, 20 lines):

services:
  exo-node-1:
    image: exo-e2e:latest
    build:
      context: ..
      dockerfile: e2e/Dockerfile
    environment:
      - EXO_LIBP2P_NAMESPACE=docker-e2e
    command: [".venv/bin/exo", "-v"]
    ports:
      - "52415:52415"

  exo-node-2:
    image: exo-e2e:latest
    build:
      context: ..
      dockerfile: e2e/Dockerfile
    environment:
      - EXO_LIBP2P_NAMESPACE=docker-e2e
    command: [".venv/bin/exo", "-v"]
e2e/run_all.py (new file, 83 lines):

#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.

Tests whose module docstring contains a line starting with 'slow' are
skipped unless --slow is passed or E2E_SLOW=1 is set.
"""

import os
import subprocess
import sys
from pathlib import Path

E2E_DIR = Path(__file__).parent.resolve()


def is_slow(test_file: Path) -> bool:
    """Check if the test file is marked slow (a line starting with 'slow' in its module docstring)."""
    with open(test_file) as f:
        for line in f:
            if line.strip().startswith("#"):
                continue
            if line.strip().startswith('"""') or line.strip().startswith("'''"):
                # Read into the docstring
                for doc_line in f:
                    if "slow" in doc_line.lower() and doc_line.strip().startswith(
                        "slow"
                    ):
                        return True
                    if '"""' in doc_line or "'''" in doc_line:
                        break
                break
    return False


def main():
    run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
    if "--update-snapshots" in sys.argv:
        os.environ["UPDATE_SNAPSHOTS"] = "1"
    test_files = sorted(E2E_DIR.glob("test_*.py"))
    if not test_files:
        print("No test files found")
        sys.exit(1)

    passed = 0
    failed = 0
    skipped = 0
    failures = []

    for test_file in test_files:
        name = test_file.stem
        if is_slow(test_file) and not run_slow:
            print(f"=== {name} === SKIPPED (slow, use --slow to run)")
            skipped += 1
            continue

        print(f"=== {name} ===")
        result = subprocess.run([sys.executable, str(test_file)])
        if result.returncode == 0:
            passed += 1
        else:
            # Retry once — Docker networking (mDNS) can be slow on first boot
            print(f"\n=== {name} === RETRYING (attempt 2/2)")
            result = subprocess.run([sys.executable, str(test_file)])
            if result.returncode == 0:
                passed += 1
            else:
                failed += 1
                failures.append(name)
        print()

    total = passed + failed + skipped
    print("================================")
    print(
        f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
    )

    if failed:
        print(f"Failed: {' '.join(failures)}")
        sys.exit(1)


if __name__ == "__main__":
    main()
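Because the runner simply globs e2e/test_*.py and executes each file with the current interpreter, a new test is just a script that exits non-zero on failure. A minimal skeleton under those assumptions (the file name and cluster name are placeholders; the 'slow' marker line is optional):

"""Test: my new scenario.
slow

The 'slow' line above is the optional marker run_all.py looks for.
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster


async def main():
    async with Cluster("my_scenario") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()
        print("PASSED: my_scenario")  # exit code 0 on success; an exception fails the run


if __name__ == "__main__":
    asyncio.run(main())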
e2e/snapshot.py (new file, 78 lines):

"""Snapshot testing infrastructure for E2E tests.

Provides deterministic regression testing by comparing inference output
against committed baseline snapshots. Tests FAIL if no baseline exists —
baselines must be explicitly generated and committed.

Generate baselines: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow
Update after intentional changes: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow

Snapshots are stored per-architecture (e.g. snapshots/x86_64/, snapshots/arm64/)
since floating-point results differ between CPU architectures.
"""

import difflib
import json
import os
import platform
from pathlib import Path

ARCH = platform.machine()
SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" / ARCH


def assert_snapshot(
    name: str,
    content: str,
    metadata: dict,
) -> None:
    """Compare content against a saved snapshot, failing if no baseline exists.

    Args:
        name: Snapshot identifier (used as filename: snapshots/{arch}/{name}.json).
        content: The actual inference output to compare.
        metadata: Additional context stored alongside content (model, seed, etc.).
            Not used for comparison -- purely documentary.

    Raises:
        AssertionError: If content doesn't match the saved snapshot.

    Environment:
        UPDATE_SNAPSHOTS=1: Overwrite existing snapshot with actual content.
    """
    snapshot_file = SNAPSHOTS_DIR / f"{name}.json"
    update = os.environ.get("UPDATE_SNAPSHOTS") == "1"

    if update:
        # Explicitly regenerate snapshot
        SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
        snapshot_data = {**metadata, "arch": ARCH, "content": content}
        snapshot_file.write_text(json.dumps(snapshot_data, indent=2) + "\n")
        print(f"  Updated snapshot: {ARCH}/{snapshot_file.name}")
    elif not snapshot_file.exists():
        raise AssertionError(
            f"No baseline snapshot for '{name}' on {ARCH}.\n"
            f"Expected file: {snapshot_file}\n\n"
            f"Generate baselines with: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow"
        )
    else:
        snapshot = json.loads(snapshot_file.read_text())
        expected = snapshot["content"]
        if content != expected:
            diff = "\n".join(
                difflib.unified_diff(
                    expected.splitlines(),
                    content.splitlines(),
                    fromfile=f"expected ({snapshot_file.relative_to(SNAPSHOTS_DIR.parent.parent)})",
                    tofile="actual",
                    lineterm="",
                )
            )
            raise AssertionError(
                f"Snapshot mismatch for '{name}' on {ARCH}!\n\n"
                f"{diff}\n\n"
                f"Expected: {expected!r}\n"
                f"Actual: {content!r}\n\n"
                f"To update: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow"
            )
        print(f"  Output matches snapshot ({ARCH}/{snapshot_file.name})")
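Given the write path above ({**metadata, "arch": ARCH, "content": content} serialized with indent=2), a committed baseline is a small JSON file. For the inference test below it would look like this; the content value is a placeholder for whatever the model actually produced when the baseline was generated:

e2e/snapshots/x86_64/inference_snapshot.json:

{
  "model": "mlx-community/Qwen3-0.6B-4bit",
  "seed": 42,
  "prompt": "What is 2+2? Reply with just the number.",
  "max_tokens": 32,
  "arch": "x86_64",
  "content": "<captured model output>"
}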
e2e/test_cluster_formation.py (new file, 22 lines):

"""Test: Basic cluster formation.

Verifies two nodes discover each other, elect a master, and the API responds.
"""

import asyncio
import sys

sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster


async def main():
    async with Cluster("cluster_formation") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()
        print("PASSED: cluster_formation")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_inference_snapshot.py (new file, 61 lines):

"""Test: Deterministic inference output (snapshot test).

Sends a chat completion request with a fixed seed,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.

Uses MLX CPU backend in Docker on x86 Linux.
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32


async def main():
    async with Cluster("inference_snapshot") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f"  Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        print(f"  Sending chat completion (seed={SEED})...")
        resp = await cluster.chat(
            model=MODEL,
            messages=[{"role": "user", "content": PROMPT}],
            seed=SEED,
            temperature=0,
            max_tokens=MAX_TOKENS,
        )

        content = resp["choices"][0]["message"]["content"]
        print(f"  Response: {content!r}")

        assert_snapshot(
            name="inference_snapshot",
            content=content,
            metadata={
                "model": MODEL,
                "seed": SEED,
                "prompt": PROMPT,
                "max_tokens": MAX_TOKENS,
            },
        )

        print("PASSED: inference_snapshot")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_no_internet.py (new file, 47 lines):

"""Test: Cluster works without internet access.

Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""

import asyncio
import sys

sys.path.insert(0, str(__import__("pathlib").Path(__file__).parent))
from conftest import Cluster


async def main():
    async with Cluster(
        "no_internet",
        overrides=["tests/no_internet/docker-compose.override.yml"],
    ) as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        # Verify internet is actually blocked from inside the containers
        for node in ["exo-node-1", "exo-node-2"]:
            rc, _ = await cluster.exec(
                node,
                "curl",
                "-sf",
                "--max-time",
                "3",
                "https://huggingface.co",
                check=False,
            )
            assert rc != 0, f"{node} should not be able to reach the internet"
            print(f"  {node}: internet correctly blocked")

        # Verify exo detected no internet connectivity
        log = await cluster.logs()
        assert "Internet connectivity: False" in log, "exo should detect no internet"
        print("  exo correctly detected no internet connectivity")

        print("PASSED: no_internet")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_runner_chaos.py (new file, 65 lines):

"""Test: Runner chaos — abrupt runner death detection.
slow

Sends a chat completion with the EXO_RUNNER_MUST_DIE trigger, which causes
the runner process to call os._exit(1) (simulating an OOM kill). Verifies that
the RunnerSupervisor health check detects the death and the system doesn't hang.

Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""

import asyncio
import contextlib
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"


async def main():
    async with Cluster("runner_chaos") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        # Place the model so a runner is loaded and ready
        print(f"  Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        # Send a chat request with the die trigger.
        # The runner will call os._exit(1) mid-inference, simulating OOM kill.
        # The chat request itself will fail — that's expected.
        print("  Sending EXO_RUNNER_MUST_DIE trigger...")
        with contextlib.suppress(Exception):
            await cluster.chat(
                model=MODEL,
                messages=[{"role": "user", "content": "EXO RUNNER MUST DIE"}],
                timeout=60,
            )

        # Wait for the health check to detect the death and emit RunnerFailed
        async def health_check_detected():
            log = await cluster.logs()
            return "runner process died unexpectedly" in log

        await cluster.wait_for(
            "Health check detected runner death",
            health_check_detected,
            timeout=30,
        )

        # Verify RunnerFailed was emitted (visible in logs)
        log = await cluster.logs()
        assert "runner process died unexpectedly" in log, (
            f"Expected health check to detect runner death but it didn't.\nLogs:\n{log}"
        )

        print("PASSED: runner_chaos")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_snapshot_code_gen.py (new file, 60 lines):

"""Test: Code generation snapshot.
slow

Verifies deterministic output for a code generation prompt.
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = (
    "Write a Python function to reverse a string. Only output the code, no explanation."
)
MAX_TOKENS = 64


async def main():
    async with Cluster("snapshot_code_gen") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f"  Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        print(f"  Sending chat completion (seed={SEED})...")
        resp = await cluster.chat(
            model=MODEL,
            messages=[{"role": "user", "content": PROMPT}],
            seed=SEED,
            temperature=0,
            max_tokens=MAX_TOKENS,
        )

        content = resp["choices"][0]["message"]["content"]
        print(f"  Response: {content!r}")

        assert_snapshot(
            name="snapshot_code_gen",
            content=content,
            metadata={
                "model": MODEL,
                "seed": SEED,
                "prompt": PROMPT,
                "max_tokens": MAX_TOKENS,
            },
        )

        print("PASSED: snapshot_code_gen")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_snapshot_edge.py (new file, 65 lines):

"""Test: Edge case snapshots.
slow

Verifies deterministic output for edge-case prompts: single word input,
special characters, and unicode.
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
MAX_TOKENS = 32

CASES = [
    ("edge_single_word", "Hi"),
    ("edge_special_chars", "What does 2 * (3 + 4) / 7 - 1 equal? Use <math> tags."),
    ("edge_unicode", "Translate 'hello' to Japanese, Chinese, and Korean."),
]


async def main():
    async with Cluster("snapshot_edge") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f"  Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        for snapshot_name, prompt in CASES:
            print(f"  [{snapshot_name}] Sending: {prompt!r}")
            resp = await cluster.chat(
                model=MODEL,
                messages=[{"role": "user", "content": prompt}],
                seed=SEED,
                temperature=0,
                max_tokens=MAX_TOKENS,
            )

            content = resp["choices"][0]["message"]["content"]
            print(f"  [{snapshot_name}] Response: {content!r}")

            assert_snapshot(
                name=snapshot_name,
                content=content,
                metadata={
                    "model": MODEL,
                    "seed": SEED,
                    "prompt": prompt,
                    "max_tokens": MAX_TOKENS,
                },
            )

        print("PASSED: snapshot_edge")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_snapshot_long_output.py (new file, 58 lines):

"""Test: Longer output snapshot.
slow

Verifies deterministic output with a higher max_tokens (128).
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "Explain how a binary search algorithm works."
MAX_TOKENS = 128


async def main():
    async with Cluster("snapshot_long_output") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f"  Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        print(f"  Sending chat completion (seed={SEED}, max_tokens={MAX_TOKENS})...")
        resp = await cluster.chat(
            model=MODEL,
            messages=[{"role": "user", "content": PROMPT}],
            seed=SEED,
            temperature=0,
            max_tokens=MAX_TOKENS,
        )

        content = resp["choices"][0]["message"]["content"]
        print(f"  Response: {content!r}")

        assert_snapshot(
            name="snapshot_long_output",
            content=content,
            metadata={
                "model": MODEL,
                "seed": SEED,
                "prompt": PROMPT,
                "max_tokens": MAX_TOKENS,
            },
        )

        print("PASSED: snapshot_long_output")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_snapshot_multi_model.py (new file, 73 lines):

"""Test: Multi-model snapshot tests.
slow

Verifies deterministic output across different model architectures to catch
model-specific regressions. Each model uses its own snapshot file.
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

SEED = 42
PROMPT = "What is the capital of France?"
MAX_TOKENS = 32

MODELS = [
    "mlx-community/SmolLM2-135M-Instruct",
    "mlx-community/Llama-3.2-1B-Instruct-4bit",
    "mlx-community/gemma-2-2b-it-4bit",
]


async def main():
    async with Cluster("snapshot_multi_model") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        for model in MODELS:
            short_name = (
                model.split("/")[-1].lower().replace("-", "_").replace(".", "_")
            )
            snapshot_name = f"snapshot_multi_{short_name}"

            print(f"  Launching model {model}...")
            await cluster.place_model(model)

            print(f"  Sending chat completion (seed={SEED})...")
            resp = await cluster.chat(
                model=model,
                messages=[{"role": "user", "content": PROMPT}],
                seed=SEED,
                temperature=0,
                max_tokens=MAX_TOKENS,
            )

            content = resp["choices"][0]["message"]["content"]
            print(f"  [{short_name}] Response: {content!r}")

            assert_snapshot(
                name=snapshot_name,
                content=content,
                metadata={
                    "model": model,
                    "seed": SEED,
                    "prompt": PROMPT,
                    "max_tokens": MAX_TOKENS,
                },
            )

            print(f"  [{short_name}] PASSED")

        print("PASSED: snapshot_multi_model")


if __name__ == "__main__":
    asyncio.run(main())
e2e/test_snapshot_reasoning.py (new file, 58 lines):

"""Test: Reasoning/math snapshot.
slow

Verifies deterministic output for a simple reasoning prompt.
"""

import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot

from conftest import Cluster

MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "If I have 3 apples and give away 1, how many do I have? Think step by step."
MAX_TOKENS = 64


async def main():
    async with Cluster("snapshot_reasoning") as cluster:
        await cluster.build()
        await cluster.start()
        await cluster.assert_healthy()

        print(f"  Launching model {MODEL}...")
        await cluster.place_model(MODEL)

        print(f"  Sending chat completion (seed={SEED})...")
        resp = await cluster.chat(
            model=MODEL,
            messages=[{"role": "user", "content": PROMPT}],
            seed=SEED,
            temperature=0,
            max_tokens=MAX_TOKENS,
        )

        content = resp["choices"][0]["message"]["content"]
        print(f"  Response: {content!r}")

        assert_snapshot(
            name="snapshot_reasoning",
            content=content,
            metadata={
                "model": MODEL,
                "seed": SEED,
                "prompt": PROMPT,
                "max_tokens": MAX_TOKENS,
            },
        )

        print("PASSED: snapshot_reasoning")


if __name__ == "__main__":
    asyncio.run(main())
e2e/tests/no_internet/docker-compose.override.yml (new file, 32 lines):

# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
  exo-node-1:
    cap_add:
      - NET_ADMIN
    entrypoint: ["/bin/sh", "-c"]
    command:
      - |
        iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
        iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
        iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
        iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
        iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
        iptables -A OUTPUT -j REJECT
        exec .venv/bin/exo -v
  exo-node-2:
    cap_add:
      - NET_ADMIN
    entrypoint: ["/bin/sh", "-c"]
    command:
      - |
        iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
        iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
        iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
        iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
        iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
        iptables -A OUTPUT -j REJECT
        exec .venv/bin/exo -v
New model card (TOML, 12 lines; file path not shown in this view):

@@ -0,0 +1,12 @@
model_id = "mlx-community/SmolLM2-135M-Instruct"
n_layers = 30
hidden_size = 576
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "SmolLM2 135M"
capabilities = ["text"]

[storage_size]
in_bytes = 269060381
New model card (TOML, 12 lines; file path not shown in this view):

@@ -0,0 +1,12 @@
model_id = "mlx-community/gemma-2-2b-it-4bit"
n_layers = 26
hidden_size = 2304
supports_tensor = false
tasks = ["TextGeneration"]
family = "gemma2"
quantization = "4bit"
base_model = "Gemma 2 2B"
capabilities = ["text"]

[storage_size]
in_bytes = 1492755242
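The two 12-line files above are TOML model cards. Purely as an illustration, and without implying anything about how exo actually consumes them, a card like this parses with the standard library:

import tomllib  # stdlib since Python 3.11

# Hypothetical path; the real location is not visible in this diff view.
with open("smollm2_135m_instruct.toml", "rb") as f:
    card = tomllib.load(f)

assert card["model_id"] == "mlx-community/SmolLM2-135M-Instruct"
assert card["n_layers"] == 30 and card["supports_tensor"] is True
assert card["storage_size"]["in_bytes"] == 269060381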
Changes to the runner bootstrap module (path not shown in this view; the supervisor diff below imports it as exo.worker.runner.bootstrap). Indentation of context lines is reconstructed, since the original view flattened it:

@@ -1,4 +1,8 @@
+from __future__ import annotations
+
 import os
+import threading
+from multiprocessing.sharedctypes import Synchronized

 import loguru

@@ -10,6 +14,15 @@ from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender

 logger: "loguru.Logger" = loguru.logger

+HEARTBEAT_INTERVAL_SECONDS = 0.5
+
+
+def _heartbeat_loop(heartbeat: Synchronized[int], stop: threading.Event) -> None:
+    """Daemon thread that periodically increments the heartbeat counter."""
+    while not stop.is_set():
+        heartbeat.value += 1
+        stop.wait(HEARTBEAT_INTERVAL_SECONDS)
+
+
 def entrypoint(
     bound_instance: BoundInstance,

@@ -17,6 +30,7 @@ def entrypoint(
     task_receiver: MpReceiver[Task],
     cancel_receiver: MpReceiver[TaskId],
     _logger: "loguru.Logger",
+    heartbeat: Synchronized[int] | None = None,
 ) -> None:
     fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
     if fast_synch_override == "on" or (

@@ -35,6 +49,17 @@ def entrypoint(

     logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")

+    # Start heartbeat thread so the supervisor can detect if we freeze.
+    stop_heartbeat = threading.Event()
+    heartbeat_thread: threading.Thread | None = None
+    if heartbeat is not None:
+        heartbeat_thread = threading.Thread(
+            target=_heartbeat_loop,
+            args=(heartbeat, stop_heartbeat),
+            daemon=True,
+        )
+        heartbeat_thread.start()
+
     # Import main after setting global logger - this lets us just import logger from this module
     try:
         from exo.worker.runner.runner import main

@@ -53,6 +78,9 @@ def entrypoint(
             )
         )
     finally:
+        stop_heartbeat.set()
+        if heartbeat_thread is not None:
+            heartbeat_thread.join(timeout=1)
         try:
             event_sender.close()
             task_receiver.close()
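The heartbeat added here is the producer half of a liveness check; the supervisor half follows two diffs below. As a self-contained sketch of the pattern, reusing the constants from these diffs but with everything else simplified and hypothetical:

import multiprocessing
import threading
import time

HEARTBEAT_INTERVAL_SECONDS = 0.5   # child increments the counter this often
HEALTH_CHECK_INTERVAL_SECONDS = 1  # parent polls this often
HEARTBEAT_STALE_CHECKS = 10        # consecutive stale polls before declaring a freeze


def child(heartbeat) -> None:
    stop = threading.Event()

    def beat() -> None:
        while not stop.is_set():
            heartbeat.value += 1  # single writer, so a plain increment is fine
            stop.wait(HEARTBEAT_INTERVAL_SECONDS)

    threading.Thread(target=beat, daemon=True).start()
    time.sleep(3)  # stand-in for real work; a freeze here would stall the counter


if __name__ == "__main__":
    hb = multiprocessing.Value("Q", 0)  # shared unsigned 64-bit counter
    proc = multiprocessing.Process(target=child, args=(hb,), daemon=True)
    proc.start()

    last, stale = 0, 0
    while proc.is_alive():
        time.sleep(HEALTH_CHECK_INTERVAL_SECONDS)
        current = hb.value
        if current > 0 and current == last:
            stale += 1
            if stale >= HEARTBEAT_STALE_CHECKS:
                print("child unresponsive; a real supervisor would emit RunnerFailed")
                proc.kill()
                break
        else:
            stale = 0
            last = current
    print("child exited")

With these constants, a hard freeze is declared after roughly HEALTH_CHECK_INTERVAL_SECONDS * HEARTBEAT_STALE_CHECKS = 10 seconds of silence, while outright process death is caught within a single polling interval.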
Changes to the runner task module (path not shown in this view; it defines the EXO_RUNNER_MUST_* debug triggers):

@@ -1,6 +1,7 @@
 import base64
 import json
 import math
+import os
 import resource
 import time
 from collections.abc import Generator

@@ -999,6 +1000,7 @@ def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
 EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
 EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
 EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
+EXO_RUNNER_MUST_DIE = "EXO RUNNER MUST DIE"


 def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:

@@ -1014,6 +1016,9 @@ def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
     if not prompt:
         return

+    if EXO_RUNNER_MUST_DIE in prompt:
+        logger.info("Abrupt process death triggered (simulates OOM kill)")
+        os._exit(1)
     if EXO_RUNNER_MUST_FAIL in prompt:
         logger.info("raising exception")
         raise Exception("Artificial runner exception - for testing purposes only.")
Changes to the runner supervisor (path not shown in this view; the tests below import it as exo.worker.runner.runner_supervisor). Indentation of context lines is reconstructed:

@@ -1,12 +1,17 @@
+from __future__ import annotations
+
 import contextlib
+import multiprocessing
 import signal
 from dataclasses import dataclass, field
 from multiprocessing import Process
+from multiprocessing.sharedctypes import Synchronized
 from typing import Self

 import anyio
 from anyio import (
     BrokenResourceError,
+    CancelScope,
     ClosedResourceError,
     to_thread,
 )

@@ -26,6 +31,7 @@ from exo.shared.types.worker.runners import (
     RunnerIdle,
     RunnerLoading,
     RunnerRunning,
+    RunnerShutdown,
     RunnerShuttingDown,
     RunnerStatus,
     RunnerWarmingUp,

@@ -36,6 +42,8 @@ from exo.worker.runner.bootstrap import entrypoint

 PREFILL_TIMEOUT_SECONDS = 60
 DECODE_TIMEOUT_SECONDS = 5
+HEALTH_CHECK_INTERVAL_SECONDS = 1
+HEARTBEAT_STALE_CHECKS = 10


 @dataclass(eq=False)

@@ -48,10 +56,14 @@ class RunnerSupervisor:
     _task_sender: MpSender[Task]
     _event_sender: Sender[Event]
     _cancel_sender: MpSender[TaskId]
+    _heartbeat: Synchronized[int]
     status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
     pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
     completed: set[TaskId] = field(default_factory=set, init=False)
     cancelled: set[TaskId] = field(default_factory=set, init=False)
+    _death_handled: bool = field(default=False, init=False)
+    _last_heartbeat_value: int = field(default=0, init=False)
+    _heartbeat_stale_count: int = field(default=0, init=False)

     @classmethod
     def create(

@@ -65,6 +77,8 @@ class RunnerSupervisor:
         task_sender, task_recv = mp_channel[Task]()
         cancel_sender, cancel_recv = mp_channel[TaskId]()

+        heartbeat: Synchronized[int] = multiprocessing.Value("Q", 0)
+
         runner_process = Process(
             target=entrypoint,
             args=(

@@ -73,6 +87,7 @@ class RunnerSupervisor:
                 task_recv,
                 cancel_recv,
                 logger,
+                heartbeat,
             ),
             daemon=True,
         )

@@ -88,13 +103,16 @@ class RunnerSupervisor:
             _task_sender=task_sender,
             _cancel_sender=cancel_sender,
             _event_sender=event_sender,
+            _heartbeat=heartbeat,
         )

         return self

     async def run(self):
         self.runner_process.start()
-        await self._forward_events()
+        async with anyio.create_task_group() as tg:
+            tg.start_soon(self._forward_events)
+            tg.start_soon(self._health_check, tg.cancel_scope)

     def shutdown(self):
         logger.info("Runner supervisor shutting down")

@@ -177,9 +195,99 @@ class RunnerSupervisor:
                 self.completed.add(event.task_id)
                 await self._event_sender.send(event)
         except (ClosedResourceError, BrokenResourceError) as e:
-            await self._check_runner(e)
-            for tid in self.pending:
-                self.pending[tid].set()
+            if not self._death_handled:
+                self._death_handled = True
+                await self._check_runner(e)
+                for tid in self.pending:
+                    self.pending[tid].set()
+
+    async def _health_check(self, cancel_scope: CancelScope) -> None:
+        """Periodically check if the runner process is alive and responsive.
+
+        Detects two failure modes:
+        1. Process death (e.g. OOM kill) without cleanly closing the event
+           channel, which would leave _forward_events blocked on queue.get().
+        2. Unresponsive process (e.g. frozen by OS memory pressure, deadlock)
+           detected via a stale heartbeat counter.
+        """
+        while True:
+            await anyio.sleep(HEALTH_CHECK_INTERVAL_SECONDS)
+
+            if not self.runner_process.is_alive():
+                self._handle_process_exit(cancel_scope)
+                return
+
+            # Check heartbeat counter — if it hasn't changed between
+            # consecutive checks, the subprocess may be frozen.
+            current = self._heartbeat.value
+            if current > 0:
+                if current == self._last_heartbeat_value:
+                    self._heartbeat_stale_count += 1
+                    if self._heartbeat_stale_count >= HEARTBEAT_STALE_CHECKS:
+                        logger.error(
+                            f"Health check: runner process unresponsive "
+                            f"(heartbeat stale for {self._heartbeat_stale_count} checks), killing"
+                        )
+                        self._handle_unresponsive(cancel_scope)
+                        return
+                else:
+                    self._heartbeat_stale_count = 0
+                    self._last_heartbeat_value = current

+    def _handle_process_exit(self, cancel_scope: CancelScope) -> None:
+        """Handle runner process that has exited."""
+        if not self._death_handled:
+            self._death_handled = True
+            if isinstance(
+                self.status, (RunnerShutdown, RunnerShuttingDown, RunnerFailed)
+            ):
+                logger.info("Health check: runner process exited (expected)")
+            else:
+                rc = self.runner_process.exitcode
+                if isinstance(rc, int) and rc < 0:
+                    sig = -rc
+                    try:
+                        cause = f"signal={sig} ({signal.strsignal(sig)})"
+                    except Exception:
+                        cause = f"signal={sig}"
+                else:
+                    cause = f"exitcode={rc}"
+
+                logger.error(
+                    f"Health check: runner process died unexpectedly ({cause})"
+                )
+                self._event_sender.send_nowait(
+                    RunnerStatusUpdated(
+                        runner_id=self.bound_instance.bound_runner_id,
+                        runner_status=RunnerFailed(
+                            error_message=f"Terminated ({cause})"
+                        ),
+                    )
+                )
+            self.shutdown()
+
+            for tid in self.pending:
+                self.pending[tid].set()
+
+        cancel_scope.cancel()
+
+    def _handle_unresponsive(self, cancel_scope: CancelScope) -> None:
+        """Handle runner process that is alive but unresponsive."""
+        if not self._death_handled:
+            self._death_handled = True
+            self._event_sender.send_nowait(
+                RunnerStatusUpdated(
+                    runner_id=self.bound_instance.bound_runner_id,
+                    runner_status=RunnerFailed(
+                        error_message="Runner process unresponsive (heartbeat timeout)"
+                    ),
+                )
+            )
+            for tid in self.pending:
+                self.pending[tid].set()
+            self.shutdown()
+
+        cancel_scope.cancel()

     def __del__(self) -> None:
         if self.runner_process.is_alive():
New tests for the supervisor health check (path not shown in this view; the diff replaces a one-line "# TODO:" placeholder with a 204-line test module):

from __future__ import annotations

import multiprocessing
import os
import signal as signal_module
from collections.abc import Callable
from multiprocessing.sharedctypes import Synchronized
from typing import Any

import anyio

from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.runners import RunnerFailed, RunnerIdle, RunnerShutdown
from exo.utils.channels import Receiver, Sender, channel, mp_channel
from exo.worker.runner.runner_supervisor import (
    HEALTH_CHECK_INTERVAL_SECONDS,
    HEARTBEAT_STALE_CHECKS,
    RunnerSupervisor,
)

from ...constants import (
    INSTANCE_1_ID,
    MODEL_A_ID,
    NODE_A,
    RUNNER_1_ID,
)
from ..conftest import get_bound_mlx_ring_instance


def _die_immediately() -> None:
    """Subprocess target that exits with a non-zero code."""
    os._exit(1)


def _die_with_signal() -> None:
    """Subprocess target that kills itself with SIGKILL (simulates OOM)."""
    os.kill(os.getpid(), signal_module.SIGKILL)


def _exit_cleanly() -> None:
    """Subprocess target that exits with code 0."""
    os._exit(0)


def _hang_forever() -> None:
    """Subprocess target that hangs without updating heartbeat (simulates freeze)."""
    import time

    # Never touch the heartbeat; the test pre-seeds the counter so it
    # already looks stale to the supervisor.
    time.sleep(100000)


def _build_supervisor(
    event_sender: Sender[Event],
    target: Callable[..., Any],
) -> RunnerSupervisor:
    """Build a RunnerSupervisor with a custom subprocess target.

    Uses a clone of event_sender (matching real Worker behavior) so that
    closing the supervisor's copy doesn't close the test's receiver.
    """
    bound_instance = get_bound_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        runner_id=RUNNER_1_ID,
        node_id=NODE_A,
    )

    _ev_send, ev_recv = mp_channel[Event]()
    task_sender, _task_recv = mp_channel[Task]()
    cancel_sender, _cancel_recv = mp_channel[TaskId]()
    runner_process = multiprocessing.Process(target=target, daemon=True)
    heartbeat: Synchronized[int] = multiprocessing.Value("Q", 0)

    return RunnerSupervisor(
        bound_instance=bound_instance,
        shard_metadata=bound_instance.bound_shard,
        runner_process=runner_process,
        initialize_timeout=10,
        _ev_recv=ev_recv,
        _task_sender=task_sender,
        _cancel_sender=cancel_sender,
        _event_sender=event_sender.clone(),
        _heartbeat=heartbeat,
    )


def _collect_failed_events(
    event_receiver: Receiver[Event],
) -> list[RunnerFailed]:
    """Drain the receiver and return all RunnerFailed statuses."""
    out: list[RunnerFailed] = []
    while True:
        try:
            event = event_receiver.receive_nowait()
        except Exception:
            break
        if isinstance(event, RunnerStatusUpdated) and isinstance(
            event.runner_status, RunnerFailed
        ):
            out.append(event.runner_status)
    return out


async def test_health_check_detects_dead_process():
    """When the runner process dies with a non-zero exit code, the health check
    should emit a RunnerFailed event and run() should return."""
    event_sender, event_receiver = channel[Event]()
    supervisor = _build_supervisor(event_sender, _die_immediately)

    with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
        await supervisor.run()

    failures = _collect_failed_events(event_receiver)
    assert len(failures) == 1
    assert failures[0].error_message is not None
    assert "exitcode=1" in failures[0].error_message


async def test_health_check_detects_signal_death():
    """When the runner process is killed by a signal (e.g. OOM -> SIGKILL),
    the health check should report the signal in the failure message."""
    event_sender, event_receiver = channel[Event]()
    supervisor = _build_supervisor(event_sender, _die_with_signal)

    with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
        await supervisor.run()

    failures = _collect_failed_events(event_receiver)
    assert len(failures) == 1
    assert failures[0].error_message is not None
    assert "signal=9" in failures[0].error_message


async def test_health_check_releases_pending_tasks():
    """When the runner dies, any pending start_task() waiters should be unblocked."""
    event_sender, _event_receiver = channel[Event]()
    supervisor = _build_supervisor(event_sender, _die_immediately)

    # Register a pending waiter as if start_task() was waiting for acknowledgement
    task_event = anyio.Event()
    tid = TaskId("pending-task")
    supervisor.pending[tid] = task_event

    with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
        await supervisor.run()

    assert task_event.is_set()


async def test_clean_exit_no_failure_when_shutdown_status():
    """When the runner was in RunnerShutdown status and exits with code 0,
    no RunnerFailed event should be emitted."""
    event_sender, event_receiver = channel[Event]()
    supervisor = _build_supervisor(event_sender, _exit_cleanly)

    # Simulate that the runner had already reported shutdown via events
    supervisor.status = RunnerShutdown()

    with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
        await supervisor.run()

    failures = _collect_failed_events(event_receiver)
    assert len(failures) == 0


async def test_unexpected_exit_code_zero_emits_failure():
    """When the runner exits with code 0 but was NOT in a shutdown state,
    this is unexpected and should still emit RunnerFailed."""
    event_sender, event_receiver = channel[Event]()
    supervisor = _build_supervisor(event_sender, _exit_cleanly)

    assert isinstance(supervisor.status, RunnerIdle)

    with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
        await supervisor.run()

    failures = _collect_failed_events(event_receiver)
    assert len(failures) == 1
    assert failures[0].error_message is not None
    assert "exitcode=0" in failures[0].error_message


async def test_heartbeat_timeout_detects_unresponsive_process():
    """When the runner process is alive but its heartbeat goes stale,
    the health check should kill it and emit RunnerFailed."""
    event_sender, event_receiver = channel[Event]()
    supervisor = _build_supervisor(event_sender, _hang_forever)

    # Pre-seed the heartbeat counter with a non-zero value and set the
    # supervisor's last-seen value to match so it appears stale immediately.
    # Set stale count to HEARTBEAT_STALE_CHECKS - 1 so a single check triggers.
    supervisor._heartbeat.value = 42  # pyright: ignore[reportPrivateUsage]
    supervisor._last_heartbeat_value = 42  # pyright: ignore[reportPrivateUsage]
    supervisor._heartbeat_stale_count = HEARTBEAT_STALE_CHECKS - 1  # pyright: ignore[reportPrivateUsage]

    with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
        await supervisor.run()

    failures = _collect_failed_events(event_receiver)
    assert len(failures) == 1
    assert failures[0].error_message is not None
    assert "unresponsive" in failures[0].error_message.lower()