mirror of
https://github.com/exo-explore/exo.git
synced 2026-06-02 19:27:55 -04:00
## Motivation

No automated integration tests exist for exo. Manual testing against real hardware clusters is slow and error-prone. We need a pytest framework that deploys clusters via `eco`, runs inference scenarios, and tears down cleanly.

## Changes

- **`tools/src/exo_tools/`** — New workspace member shared by bench, eval, and tests:
  - `client.py` — `ExoClient` HTTP client (extracted from `bench/harness.py`)
  - `harness.py` — instance lifecycle helpers (placement, wait-for-ready, etc.)
  - `cluster.py` — `EcoSession` for eco cluster lifecycle (deploy/stop/start/release/logs/exec) with unique `USER=<prefix>-<uuid>` per session and atexit/signal cleanup
- **`tests/integration/`** — 17 pytest tests across 5 files:
  - `test_1node.py` — place, chat, multi-turn, delete, state/models endpoints, cluster snapshot, download-from-scratch
  - `test_2node.py` — parametrized tensor/jaccl + pipeline/ring inference and multi-turn
  - `test_4node.py` — parametrized 4-node pipeline/ring inference, cluster state
  - `test_resilience.py` — full disconnect/reconnect cycle (2-node → disconnect → 1-node → reconnect → 2-node)
  - `test_dashboard.py` — Playwright: dashboard loads, shows node info, chat flow
  - `helpers.py` — placement/inference helpers, re-exports from `exo_tools`
  - `conftest.py` — session-scoped cluster fixtures with constraint-based eco reservations; `--hosts` override; `EXO_REF` env var for CI deployments from a GitHub branch
- **`bench/`** — Updated imports from `exo_tools.client` / `exo_tools.harness`
- **`pyproject.toml`** — Added `tools` workspace member, `playwright` dev dep, `--ignore=tests/integration`

## Why It Works

Tests use `eco` for cluster lifecycle and `ExoClient` for API interactions — same tools humans use. Session-scoped fixtures deploy once per file. Unique eco users prevent test runs from interfering with each other or manual usage.

## Test Plan

### Automated Testing

- `uv run pytest tests/integration/ -v -s` — full suite (~4-5 min, 17/17 passing)
- `uv run pytest tests/integration/ -v -s --hosts s4,s9,s10,s22` — pin specific hosts
- `EXO_REF=main uv run pytest tests/integration/ -v` — deploy from a GitHub branch (CI)
- `uv run pytest` — confirms integration tests are excluded from default runs
265 lines
8.0 KiB
Python
265 lines
8.0 KiB
Python
import socket
|
|
from typing import Literal
|
|
|
|
import anyio
|
|
from fastapi import FastAPI
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from hypercorn import Config
|
|
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
|
from loguru import logger
|
|
from pydantic import BaseModel
|
|
|
|
from exo.shared.constants import EXO_DEFAULT_MODELS_DIR
|
|
from exo.shared.models.model_cards import ModelCard, ModelId
|
|
from exo.shared.types.chunks import TokenChunk
|
|
from exo.shared.types.commands import CommandId
|
|
from exo.shared.types.common import Host, NodeId
|
|
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
|
|
from exo.shared.types.tasks import (
|
|
ConnectToGroup,
|
|
LoadModel,
|
|
Shutdown,
|
|
StartWarmup,
|
|
Task,
|
|
TextGeneration,
|
|
)
|
|
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
|
from exo.shared.types.worker.instances import (
|
|
BoundInstance,
|
|
Instance,
|
|
InstanceId,
|
|
MlxJacclInstance,
|
|
MlxRingInstance,
|
|
)
|
|
from exo.shared.types.worker.runners import (
|
|
RunnerFailed,
|
|
RunnerId,
|
|
RunnerShutdown,
|
|
ShardAssignments,
|
|
)
|
|
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
|
from exo.utils.channels import channel, mp_channel
|
|
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
|
from exo.worker.runner.bootstrap import entrypoint
|
|
|
|
|
|
class Tests(BaseModel):
    """Request payload describing one runner test across a set of devices."""

    # One entry per device, in rank order: [hostname, ip addr]
    devs: list[list[str]]
    # Handed to the jaccl instance as its device table; required for jaccl runs
    ibv_devs: list[list[str | None]] | None
    # Model that gets loaded and prompted
    model_id: ModelId
    # Which instance flavour(s) to build and execute
    kind: Literal["ring", "jaccl", "both"]
|
|
|
|
|
|
# Fixed instance id reused by every instance and task this module creates.
iid = InstanceId("im testing here")
|
|
|
|
|
|
async def main():
    """Serve the test-control HTTP API until ``/kill`` is requested.

    Registers the ``/run_test``, ``/kill``, ``/tb_detection`` and ``/models``
    routes on a fresh FastAPI app and runs it under hypercorn on port 52414.
    """
    logger.info("starting cool server majig")

    # Set by the /kill route; hypercorn awaits it as its shutdown trigger.
    shutdown_requested = anyio.Event()

    app = FastAPI()
    app.post("/run_test")(run_test)
    app.post("/kill")(lambda: kill(shutdown_requested))
    app.get("/tb_detection")(tb_detection)
    app.get("/models")(list_models)

    server_config = Config()
    server_config.bind = "0.0.0.0:52414"
    # nb: shared.logging needs updating if any of this changes
    server_config.accesslog = "-"
    server_config.errorlog = "-"

    await serve(
        app,  # type: ignore
        server_config,
        shutdown_trigger=shutdown_requested.wait,
    )
|
|
|
|
|
|
def kill(ev: anyio.Event):
    """Set ``ev`` to request shutdown and answer with an empty 204 response."""
    ev.set()
    no_content = Response(status_code=204)
    return no_content
|
|
|
|
|
|
async def tb_detection():
    """Run the InfoGatherer thunderbolt monitor briefly and return its output.

    The monitor is cancelled after at most one second; whatever it pushed onto
    the channel by then is collected and returned.
    """
    info_send, info_recv = channel[GatheredInfo]()
    gatherer = InfoGatherer(info_send)

    # The monitor is given a one-second budget and then cancelled.
    with anyio.move_on_after(1):
        await gatherer._monitor_system_profiler_thunderbolt_data()  # pyright: ignore[reportPrivateUsage]

    with info_recv:
        gathered = info_recv.collect()
    return gathered
|
|
|
|
|
|
def list_models():
    """Yield the id of every model that has weight files on disk, once each.

    A model is detected by a ``model-*.safetensors`` file under
    ``EXO_DEFAULT_MODELS_DIR`` whose parent directory name contains ``--``;
    the id is that directory name with ``--`` replaced by ``/``.
    """
    seen = set[str]()
    for weights_file in EXO_DEFAULT_MODELS_DIR.rglob("model-*.safetensors"):
        folder = weights_file.parent.name
        if "--" not in folder:
            continue
        model_name = folder.replace("--", "/")
        # Sharded checkpoints have several matching files per directory.
        if model_name not in seen:
            seen.add(model_name)
            yield ModelId(model_name)
|
|
|
|
|
|
async def run_test(test: Tests):
    """Run the requested instance kinds on this host and stream the outcome.

    The local hostname is matched against ``test.devs`` by prefix in either
    direction; the matching entry's name is used as this node's identity.

    Raises:
        ValueError: if no entry of ``test.devs`` matches the local hostname.
    """
    local_hostname = socket.gethostname()
    hn = next(
        (
            dev[0]
            for dev in test.devs
            if local_hostname.startswith(dev[0]) or dev[0].startswith(local_hostname)
        ),
        None,
    )
    if hn is None:
        raise ValueError(f"{local_hostname} not in {test.devs}")

    async def run():
        logger.info(f"testing {test.model_id}")

        instances: list[Instance] = []

        if test.kind in ("ring", "both"):
            ring = await ring_instance(test, hn)
            if ring is None:
                yield "no model found"
                return
            instances.append(ring)

        if test.kind in ("jaccl", "both"):
            jaccl = await jaccl_instance(test)
            if jaccl is None:
                yield "no model found"
                return
            instances.append(jaccl)

        for instance in instances:
            events = await execute_test(test, instance, hn)

            # Generated text seen so far for this instance.
            pieces: list[str] = []

            for event in events:
                if isinstance(event, ChunkGenerated):
                    assert isinstance(event.chunk, TokenChunk)
                    pieces.append(event.chunk.text)

                if not isinstance(event, RunnerStatusUpdated):
                    continue
                if isinstance(event.runner_status, (RunnerFailed, RunnerShutdown)):
                    # Emit the text so far, then the terminal status event.
                    yield "".join(pieces) + "\n"
                    yield event.model_dump_json() + "\n"

    return StreamingResponse(run())
|
|
|
|
|
|
def _pipeline_layer_range(n_layers: int, world_size: int, rank: int) -> tuple[int, int]:
    """Return the half-open ``(start, end)`` layer range owned by ``rank``."""
    per_rank = n_layers // world_size
    start = per_rank * rank
    # The last rank absorbs the remainder so that no trailing layers are left
    # unassigned when n_layers is not a multiple of world_size.
    end = n_layers if rank == world_size - 1 else per_rank * (rank + 1)
    return start, end


async def ring_instance(test: Tests, hn: str) -> Instance | None:
    """Build the ``MlxRingInstance`` for this node as seen from host ``hn``.

    Ranks follow the order of ``test.devs``. In the host list, only this node
    (``0.0.0.0``) and its previous/next ring neighbours get real addresses;
    every other slot keeps a placeholder address.

    Args:
        test: The test description; ``test.devs`` entries are ``[hostname, ip]``.
        hn: Hostname of the local node; must equal the name of one entry.

    Returns:
        The ring instance with a pipeline shard assigned to each device.

    Raises:
        ValueError: if ``hn`` is not the hostname of any entry in ``test.devs``.
    """
    world_size = len(test.devs)
    # Placeholder for peers this node does not connect to directly.
    hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]

    for i, dev in enumerate(test.devs):
        if dev[0] != hn:
            continue
        # Wrap both neighbour indices so the first and last ranks (and a
        # single-node ring) resolve to valid entries instead of overrunning
        # the device list.
        prev_rank = (i - 1) % world_size
        next_rank = (i + 1) % world_size
        hbn[prev_rank] = Host(ip=test.devs[prev_rank][1], port=52417)
        hbn[next_rank] = Host(ip=test.devs[next_rank][1], port=52417)
        # Written last so that it wins when a neighbour index equals i.
        hbn[i] = Host(ip="0.0.0.0", port=52417)
        break
    else:
        raise ValueError(f"{hn} not in {test.devs}")

    card = await ModelCard.load(test.model_id)
    layer_ranges = [
        _pipeline_layer_range(card.n_layers, world_size, rank)
        for rank in range(world_size)
    ]

    instance = MlxRingInstance(
        instance_id=iid,
        ephemeral_port=52417,
        hosts_by_node={NodeId(hn): hbn},
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(test.devs[rank][0]): PipelineShardMetadata(
                    model_card=card,
                    device_rank=rank,
                    world_size=world_size,
                    start_layer=start,
                    end_layer=end,
                    n_layers=end - start,
                )
                for rank, (start, end) in enumerate(layer_ranges)
            },
        ),
    )

    return instance
|
|
|
|
|
|
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
    """Queue the full task script for one runner and invoke its entrypoint.

    The script is: (connect to group, multi-device only) -> load model ->
    warm up -> one text generation -> shutdown. All tasks are sent on the task
    channel before ``entrypoint`` is called.

    Args:
        test: The test description (model id and device list).
        instance: The instance the runner is bound to.
        hn: Local hostname, used as both runner id and node id.

    Returns:
        Currently always an empty list; event collection is still a TODO.
    """
    commands: list[Task] = []
    # A process group is only needed when more than one device takes part.
    if len(test.devs) > 1:
        commands.append(ConnectToGroup(instance_id=iid))

    generation_params = TextGenerationTaskParams(
        model=test.model_id,
        instructions="You are a helpful assistant",
        input=[InputMessage(role="user", content="What is the capital of France?")],
    )
    commands += [
        LoadModel(instance_id=iid),
        StartWarmup(instance_id=iid),
        TextGeneration(
            task_params=generation_params,
            command_id=CommandId("yo"),
            instance_id=iid,
        ),
        Shutdown(runner_id=RunnerId(hn), instance_id=iid),
    ]

    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
    )
    ev_send, _ev_recv = mp_channel[Event]()
    task_send, task_recv = mp_channel[Task]()

    for task in commands:
        task_send.send(task)

    entrypoint(bound_instance, ev_send, task_recv, logger)

    # TODO(evan): return ev_recv.collect()
    return []
|
|
|
|
|
|
async def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
    """Build the ``MlxJacclInstance`` covering every device in ``test``.

    Each device gets a tensor shard spanning all of the model's layers, with
    ranks following the order of ``test.devs``.

    Args:
        test: The test description; ``test.ibv_devs`` must be provided.

    Returns:
        The jaccl instance.

    Raises:
        ValueError: if ``test.ibv_devs`` is missing or empty.
    """
    # Explicit check rather than ``assert``: assertions are stripped under -O,
    # which would let a missing device table flow into the instance.
    if not test.ibv_devs:
        raise ValueError("ibv_devs is required to build a jaccl instance")

    card = await ModelCard.load(test.model_id)
    world_size = len(test.devs)

    return MlxJacclInstance(
        instance_id=iid,
        jaccl_devices=test.ibv_devs,
        # rank 0 is always coordinator
        jaccl_coordinators={
            NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
        },
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(host[0]): TensorShardMetadata(
                    model_card=card,
                    device_rank=i,
                    world_size=world_size,
                    start_layer=0,
                    end_layer=card.n_layers,
                    n_layers=card.n_layers,
                )
                for i, host in enumerate(test.devs)
            },
        ),
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async server under anyio.
    anyio.run(main)
|