Files
exo/tests/framework.py
ciaranbor fa57131374 Integration tests infra (#1995)
## Motivation

No automated integration tests exist for exo. Manual testing against
real hardware clusters is slow and error-prone. We need a pytest
framework that deploys clusters via `eco`, runs inference scenarios, and
tears down cleanly.

## Changes

- **`tools/src/exo_tools/`** — New workspace member shared by bench,
eval, and tests:
- `client.py` — `ExoClient` HTTP client (extracted from
`bench/harness.py`)
- `harness.py` — instance lifecycle helpers (placement, wait-for-ready,
etc.)
- `cluster.py` — `EcoSession` for eco cluster lifecycle
(deploy/stop/start/release/logs/exec) with unique `USER=<prefix>-<uuid>`
per session and atexit/signal cleanup
- **`tests/integration/`** — 17 pytest tests across 5 files:
- `test_1node.py` — place, chat, multi-turn, delete, state/models
endpoints, cluster snapshot, download-from-scratch
- `test_2node.py` — parametrized tensor/jaccl + pipeline/ring inference
and multi-turn
- `test_4node.py` — parametrized 4-node pipeline/ring inference, cluster
state
- `test_resilience.py` — full disconnect/reconnect cycle (2-node →
disconnect → 1-node → reconnect → 2-node)
- `test_dashboard.py` — Playwright: dashboard loads, shows node info,
chat flow
- `helpers.py` — placement/inference helpers, re-exports from
`exo_tools`
- `conftest.py` — session-scoped cluster fixtures with constraint-based
eco reservations; `--hosts` override; `EXO_REF` env var for CI
deployments from a GitHub branch
- **`bench/`** — Updated imports from `exo_tools.client` /
`exo_tools.harness`
- **`pyproject.toml`** — Added `tools` workspace member, `playwright`
dev dep, `--ignore=tests/integration`

## Why It Works

Tests use `eco` for cluster lifecycle and `ExoClient` for API
interactions — same tools humans use. Session-scoped fixtures deploy
once per file. Unique eco users prevent test runs from interfering with
each other or manual usage.

## Test Plan

### Automated Testing

- `uv run pytest tests/integration/ -v -s` — full suite (~4-5 min, 17/17
passing)
- `uv run pytest tests/integration/ -v -s --hosts s4,s9,s10,s22` — pin
specific hosts
- `EXO_REF=main uv run pytest tests/integration/ -v` — deploy from a
GitHub branch (CI)
- `uv run pytest` — confirms integration tests are excluded from default
runs
2026-05-08 17:15:08 +01:00

200 lines
6.5 KiB
Python

"""Marker-driven test framework for exo integration tests.
Test authors declare requirements via markers:
@pytest.mark.cluster(count=2, thunderbolt='a2a')
@pytest.mark.instance('mlx-community/Llama-3.2-1B-Instruct-4bit',
sharding='tensor', comm='jaccl')
def test_jaccl_inference(session):
resp = session.chat('What is 2+2?')
assert '4' in resp
The `session` fixture reads the markers, deploys the cluster, places the
instance, and provides a `Session` object. All cluster/instance orchestration
lives in `exo_tools.harness`; this module is purely the pytest-facing layer.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any
from exo_tools.client import ExoClient
from exo_tools.cluster import (
Chip,
ClusterInfo,
EcoSession,
Thunderbolt,
make_client_from_url,
)
from exo_tools.harness import Comm, Sharding
from exo.api.types.api import (
ChatCompletionChoice,
ChatCompletionRequest,
ChatCompletionResponse,
)
DEFAULT_MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit"
def _extract_content(resp: ChatCompletionResponse) -> str:
"""Extract plain-text content from a non-streaming chat completion."""
choice = resp.choices[0]
if not isinstance(choice, ChatCompletionChoice):
raise RuntimeError(
f"Expected non-streaming choice, got {type(choice).__name__}"
)
content = choice.message.content
if not isinstance(content, str):
raise RuntimeError(f"Expected string content, got {type(content).__name__}")
return content
@dataclass(frozen=True)
class ClusterSpec:
count: int = 1
thunderbolt: Thunderbolt | None = None
min_memory_gb: float | None = None
chip: Chip | None = None
@dataclass(frozen=True)
class InstanceSpec:
model_id: str
sharding: Sharding = Sharding.PIPELINE
comm: Comm = Comm.RING
min_nodes: int = 1
def parse_cluster_marker(marker) -> ClusterSpec:
if marker is None:
return ClusterSpec()
return ClusterSpec(
count=marker.kwargs.get("count", 1),
thunderbolt=marker.kwargs.get("thunderbolt"),
min_memory_gb=marker.kwargs.get("min_memory"),
chip=marker.kwargs.get("chip"),
)
def parse_instance_marker(marker) -> InstanceSpec | None:
if marker is None:
return None
if not marker.args:
raise ValueError(
"@pytest.mark.instance requires a positional model_id argument"
)
return InstanceSpec(
model_id=marker.args[0],
sharding=marker.kwargs.get("sharding", Sharding.PIPELINE),
comm=marker.kwargs.get("comm", Comm.RING),
min_nodes=marker.kwargs.get("min_nodes", 1),
)
@dataclass
class Session:
cluster: ClusterInfo
eco: EcoSession
instance_spec: InstanceSpec | None = None
instance_id: str | None = None
_stopped_hosts: set[str] = field(default_factory=set)
@property
def client(self) -> ExoClient:
for host in self.cluster.hosts:
if host not in self._stopped_hosts:
return make_client_from_url(self.cluster.api_endpoints[host])
return self.cluster.make_client()
@property
def state(self) -> dict[str, Any]:
return self.client.request_json("GET", "/state") or {}
@property
def instances(self) -> dict[str, Any]:
return self.state.get("instances", {})
# ---- Inference ----
def chat(self, prompt: str, max_tokens: int = 100) -> str:
resp = self.chat_raw(prompt, max_tokens=max_tokens)
return _extract_content(resp)
def chat_raw(self, prompt: str, **kwargs: Any) -> ChatCompletionResponse:
if not self.instance_spec:
raise RuntimeError(
"No instance placed; add @pytest.mark.instance to the test"
)
max_tokens = kwargs.pop("max_tokens", 100)
request = ChatCompletionRequest.model_validate(
{
"model": self.instance_spec.model_id,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": max_tokens,
**kwargs,
}
)
return self._post_chat(request)
def multi_turn(self, messages: list[dict[str, str]], max_tokens: int = 100) -> str:
if not self.instance_spec:
raise RuntimeError(
"No instance placed; add @pytest.mark.instance to the test"
)
request = ChatCompletionRequest.model_validate(
{
"model": self.instance_spec.model_id,
"messages": messages,
"max_tokens": max_tokens,
}
)
return _extract_content(self._post_chat(request))
def _post_chat(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
raw = self.client.request_json(
"POST",
"/v1/chat/completions",
body=request.model_dump(exclude_none=True),
)
return ChatCompletionResponse.model_validate(raw)
def disconnect_node(self, index: int) -> None:
"""Stop exo on a node and wait for the cluster to observe the disconnect."""
host = self.cluster.hosts[index]
self.eco.stop([host], keep=True)
self._stopped_hosts.add(host)
def reconnect_node(self, index: int) -> None:
"""Restart a previously disconnected node into the existing namespace."""
host = self.cluster.hosts[index]
self.eco.start_hosts([host], namespace=self.cluster.namespace)
self._stopped_hosts.discard(host)
def wait_ready(
self, expected_nodes: int | None = None, timeout: float = 60
) -> None:
"""Wait until the cluster has exactly `expected_nodes` visible and reporting memory.
Defaults to the count of non-stopped hosts. Use this after
`disconnect_node` / `reconnect_node` to wait for the cluster to settle.
"""
if expected_nodes is None:
expected_nodes = len(self.cluster.hosts) - len(self._stopped_hosts)
start = time.time()
while time.time() - start < timeout:
try:
state = self.state
identities = len(state.get("nodeIdentities", {}))
memory = len(state.get("nodeMemory", {}))
if identities == expected_nodes and memory == expected_nodes:
return
except Exception:
pass
time.sleep(2.0)
raise TimeoutError(
f"Cluster did not reach exactly {expected_nodes} ready nodes within {timeout}s"
)