mirror of
https://github.com/exo-explore/exo.git
synced 2026-05-24 06:35:32 -04:00
## Motivation No automated integration tests exist for exo. Manual testing against real hardware clusters is slow and error-prone. We need a pytest framework that deploys clusters via `eco`, runs inference scenarios, and tears down cleanly. ## Changes - **`tools/src/exo_tools/`** — New workspace member shared by bench, eval, and tests: - `client.py` — `ExoClient` HTTP client (extracted from `bench/harness.py`) - `harness.py` — instance lifecycle helpers (placement, wait-for-ready, etc.) - `cluster.py` — `EcoSession` for eco cluster lifecycle (deploy/stop/start/release/logs/exec) with unique `USER=<prefix>-<uuid>` per session and atexit/signal cleanup - **`tests/integration/`** — 17 pytest tests across 5 files: - `test_1node.py` — place, chat, multi-turn, delete, state/models endpoints, cluster snapshot, download-from-scratch - `test_2node.py` — parametrized tensor/jaccl + pipeline/ring inference and multi-turn - `test_4node.py` — parametrized 4-node pipeline/ring inference, cluster state - `test_resilience.py` — full disconnect/reconnect cycle (2-node → disconnect → 1-node → reconnect → 2-node) - `test_dashboard.py` — Playwright: dashboard loads, shows node info, chat flow - `helpers.py` — placement/inference helpers, re-exports from `exo_tools` - `conftest.py` — session-scoped cluster fixtures with constraint-based eco reservations; `--hosts` override; `EXO_REF` env var for CI deployments from a GitHub branch - **`bench/`** — Updated imports from `exo_tools.client` / `exo_tools.harness` - **`pyproject.toml`** — Added `tools` workspace member, `playwright` dev dep, `--ignore=tests/integration` ## Why It Works Tests use `eco` for cluster lifecycle and `ExoClient` for API interactions — same tools humans use. Session-scoped fixtures deploy once per file. Unique eco users prevent test runs from interfering with each other or manual usage. ## Test Plan ### Automated Testing - `uv run pytest tests/integration/ -v -s` — full suite (~4-5 min, 17/17 passing) - `uv run pytest tests/integration/ -v -s --hosts s4,s9,s10,s22` — pin specific hosts - `EXO_REF=main uv run pytest tests/integration/ -v` — deploy from a GitHub branch (CI) - `uv run pytest` — confirms integration tests are excluded from default runs
200 lines
6.5 KiB
Python
200 lines
6.5 KiB
Python
"""Marker-driven test framework for exo integration tests.
|
|
|
|
Test authors declare requirements via markers:
|
|
|
|
@pytest.mark.cluster(count=2, thunderbolt='a2a')
|
|
@pytest.mark.instance('mlx-community/Llama-3.2-1B-Instruct-4bit',
|
|
sharding='tensor', comm='jaccl')
|
|
def test_jaccl_inference(session):
|
|
resp = session.chat('What is 2+2?')
|
|
assert '4' in resp
|
|
|
|
The `session` fixture reads the markers, deploys the cluster, places the
|
|
instance, and provides a `Session` object. All cluster/instance orchestration
|
|
lives in `exo_tools.harness`; this module is purely the pytest-facing layer.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from exo_tools.client import ExoClient
|
|
from exo_tools.cluster import (
|
|
Chip,
|
|
ClusterInfo,
|
|
EcoSession,
|
|
Thunderbolt,
|
|
make_client_from_url,
|
|
)
|
|
from exo_tools.harness import Comm, Sharding
|
|
|
|
from exo.api.types.api import (
|
|
ChatCompletionChoice,
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
)
|
|
|
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
|
|
|
|
|
def _extract_content(resp: ChatCompletionResponse) -> str:
|
|
"""Extract plain-text content from a non-streaming chat completion."""
|
|
choice = resp.choices[0]
|
|
if not isinstance(choice, ChatCompletionChoice):
|
|
raise RuntimeError(
|
|
f"Expected non-streaming choice, got {type(choice).__name__}"
|
|
)
|
|
content = choice.message.content
|
|
if not isinstance(content, str):
|
|
raise RuntimeError(f"Expected string content, got {type(content).__name__}")
|
|
return content
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ClusterSpec:
|
|
count: int = 1
|
|
thunderbolt: Thunderbolt | None = None
|
|
min_memory_gb: float | None = None
|
|
chip: Chip | None = None
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class InstanceSpec:
|
|
model_id: str
|
|
sharding: Sharding = Sharding.PIPELINE
|
|
comm: Comm = Comm.RING
|
|
min_nodes: int = 1
|
|
|
|
|
|
def parse_cluster_marker(marker) -> ClusterSpec:
|
|
if marker is None:
|
|
return ClusterSpec()
|
|
return ClusterSpec(
|
|
count=marker.kwargs.get("count", 1),
|
|
thunderbolt=marker.kwargs.get("thunderbolt"),
|
|
min_memory_gb=marker.kwargs.get("min_memory"),
|
|
chip=marker.kwargs.get("chip"),
|
|
)
|
|
|
|
|
|
def parse_instance_marker(marker) -> InstanceSpec | None:
|
|
if marker is None:
|
|
return None
|
|
if not marker.args:
|
|
raise ValueError(
|
|
"@pytest.mark.instance requires a positional model_id argument"
|
|
)
|
|
return InstanceSpec(
|
|
model_id=marker.args[0],
|
|
sharding=marker.kwargs.get("sharding", Sharding.PIPELINE),
|
|
comm=marker.kwargs.get("comm", Comm.RING),
|
|
min_nodes=marker.kwargs.get("min_nodes", 1),
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class Session:
|
|
cluster: ClusterInfo
|
|
eco: EcoSession
|
|
instance_spec: InstanceSpec | None = None
|
|
instance_id: str | None = None
|
|
_stopped_hosts: set[str] = field(default_factory=set)
|
|
|
|
@property
|
|
def client(self) -> ExoClient:
|
|
for host in self.cluster.hosts:
|
|
if host not in self._stopped_hosts:
|
|
return make_client_from_url(self.cluster.api_endpoints[host])
|
|
return self.cluster.make_client()
|
|
|
|
@property
|
|
def state(self) -> dict[str, Any]:
|
|
return self.client.request_json("GET", "/state") or {}
|
|
|
|
@property
|
|
def instances(self) -> dict[str, Any]:
|
|
return self.state.get("instances", {})
|
|
|
|
# ---- Inference ----
|
|
|
|
def chat(self, prompt: str, max_tokens: int = 100) -> str:
|
|
resp = self.chat_raw(prompt, max_tokens=max_tokens)
|
|
return _extract_content(resp)
|
|
|
|
def chat_raw(self, prompt: str, **kwargs: Any) -> ChatCompletionResponse:
|
|
if not self.instance_spec:
|
|
raise RuntimeError(
|
|
"No instance placed; add @pytest.mark.instance to the test"
|
|
)
|
|
max_tokens = kwargs.pop("max_tokens", 100)
|
|
request = ChatCompletionRequest.model_validate(
|
|
{
|
|
"model": self.instance_spec.model_id,
|
|
"messages": [{"role": "user", "content": prompt}],
|
|
"max_tokens": max_tokens,
|
|
**kwargs,
|
|
}
|
|
)
|
|
return self._post_chat(request)
|
|
|
|
def multi_turn(self, messages: list[dict[str, str]], max_tokens: int = 100) -> str:
|
|
if not self.instance_spec:
|
|
raise RuntimeError(
|
|
"No instance placed; add @pytest.mark.instance to the test"
|
|
)
|
|
request = ChatCompletionRequest.model_validate(
|
|
{
|
|
"model": self.instance_spec.model_id,
|
|
"messages": messages,
|
|
"max_tokens": max_tokens,
|
|
}
|
|
)
|
|
return _extract_content(self._post_chat(request))
|
|
|
|
def _post_chat(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
|
raw = self.client.request_json(
|
|
"POST",
|
|
"/v1/chat/completions",
|
|
body=request.model_dump(exclude_none=True),
|
|
)
|
|
return ChatCompletionResponse.model_validate(raw)
|
|
|
|
def disconnect_node(self, index: int) -> None:
|
|
"""Stop exo on a node and wait for the cluster to observe the disconnect."""
|
|
host = self.cluster.hosts[index]
|
|
self.eco.stop([host], keep=True)
|
|
self._stopped_hosts.add(host)
|
|
|
|
def reconnect_node(self, index: int) -> None:
|
|
"""Restart a previously disconnected node into the existing namespace."""
|
|
host = self.cluster.hosts[index]
|
|
self.eco.start_hosts([host], namespace=self.cluster.namespace)
|
|
self._stopped_hosts.discard(host)
|
|
|
|
def wait_ready(
|
|
self, expected_nodes: int | None = None, timeout: float = 60
|
|
) -> None:
|
|
"""Wait until the cluster has exactly `expected_nodes` visible and reporting memory.
|
|
|
|
Defaults to the count of non-stopped hosts. Use this after
|
|
`disconnect_node` / `reconnect_node` to wait for the cluster to settle.
|
|
"""
|
|
if expected_nodes is None:
|
|
expected_nodes = len(self.cluster.hosts) - len(self._stopped_hosts)
|
|
start = time.time()
|
|
while time.time() - start < timeout:
|
|
try:
|
|
state = self.state
|
|
identities = len(state.get("nodeIdentities", {}))
|
|
memory = len(state.get("nodeMemory", {}))
|
|
if identities == expected_nodes and memory == expected_nodes:
|
|
return
|
|
except Exception:
|
|
pass
|
|
time.sleep(2.0)
|
|
raise TimeoutError(
|
|
f"Cluster did not reach exactly {expected_nodes} ready nodes within {timeout}s"
|
|
)
|