Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-19 15:27:02 -05:00)

Comparing commits: main...feat/e2e-c (1 commit)

Commit 891166ac36
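All of the new tests are marked @pytest.mark.slow and @pytest.mark.asyncio. Assuming the project's usual pytest setup (pytest-asyncio installed and the slow marker registered), the suite can presumably be run with a command along the lines of pytest src/exo/tests/e2e_chaos -m slow.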
src/exo/tests/__init__.py (new file, 0 lines)
src/exo/tests/e2e_chaos/__init__.py (new file, 0 lines)
src/exo/tests/e2e_chaos/conftest.py (new file, 307 lines)
@@ -0,0 +1,307 @@
"""Shared fixtures and helpers for E2E chaos/networking tests.

Provides a ``MiniCluster`` that wires Master + Worker(s) + Election together
using in-process channels. No Docker, no network, no GPU -- pure async
integration testing of the coordination layer.
"""

from collections.abc import Iterator
from datetime import datetime, timezone
from typing import Final

import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger

from exo.master.main import Master
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
    CommandId,
    ForwarderCommand,
    ForwarderDownloadCommand,
    PlaceInstance,
    TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
    IndexedEvent,
    NodeGatheredInfo,
    TopologyEdgeCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryUsage
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.main import Worker


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

TEST_MODEL_ID: Final[ModelId] = ModelId("test-model/chaos-test-1b")
TEST_MODEL_CARD: Final[ModelCard] = ModelCard(
    model_id=TEST_MODEL_ID,
    n_layers=16,
    storage_size=Memory.from_bytes(678_948),
    hidden_size=2048,
    supports_tensor=True,
    tasks=[ModelTask.TextGeneration],
)

FAST_ELECTION_TIMEOUT: Final[float] = 0.1


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def make_node_id(label: str) -> NodeId:
    return NodeId(f"node-{label}")


def make_session_id(master_node: NodeId) -> SessionId:
    return SessionId(master_node_id=master_node, election_clock=0)


def make_memory_info() -> MemoryUsage:
    return MemoryUsage(
        ram_total=Memory.from_bytes(16 * 1024 * 1024 * 1024),
        ram_available=Memory.from_bytes(8 * 1024 * 1024 * 1024),
        swap_total=Memory.from_bytes(0),
        swap_available=Memory.from_bytes(0),
    )


def make_gathered_info_event(
    node_id: NodeId, sender_id: NodeId, session_id: SessionId, origin_idx: int
) -> ForwarderEvent:
    return ForwarderEvent(
        origin_idx=origin_idx,
        origin=sender_id,
        session=session_id,
        event=NodeGatheredInfo(
            when=str(datetime.now(tz=timezone.utc)),
            node_id=node_id,
            info=make_memory_info(),
        ),
    )


def make_topology_edge_event(
    source: NodeId,
    sink: NodeId,
    sender_id: NodeId,
    session_id: SessionId,
    origin_idx: int,
    ip_suffix: int = 1,
) -> ForwarderEvent:
    """Create a ForwarderEvent wrapping a TopologyEdgeCreated event."""
    return ForwarderEvent(
        origin_idx=origin_idx,
        origin=sender_id,
        session=session_id,
        event=TopologyEdgeCreated(
            conn=Connection(
                source=source,
                sink=sink,
                edge=SocketConnection(
                    sink_multiaddr=Multiaddr(
                        address=f"/ip4/10.0.0.{ip_suffix}/tcp/52415"
                    )
                ),
            )
        ),
    )


class EventCollector:
    """Collects ForwarderEvents from a global event receiver."""

    def __init__(self, receiver: Receiver[ForwarderEvent]) -> None:
        self._receiver = receiver
        self.indexed_events: list[IndexedEvent] = []

    def collect(self) -> list[IndexedEvent]:
        raw = self._receiver.collect()
        for fe in raw:
            self.indexed_events.append(
                IndexedEvent(event=fe.event, idx=len(self.indexed_events))
            )
        return self.indexed_events

    async def wait_for_event_count(
        self, count: int, *, timeout: float = 5.0, poll_interval: float = 0.01
    ) -> list[IndexedEvent]:
        import anyio

        with anyio.fail_after(timeout):
            while len(self.collect()) < count:
                await anyio.sleep(poll_interval)
        return self.indexed_events


class MiniCluster:
    """An in-process cluster with one Master and N Workers wired via channels.

    No networking, no real model loading -- exercises the coordination logic
    (event sourcing, command routing, election) in a deterministic, fast
    test harness.
    """

    def __init__(self, node_count: int = 2) -> None:
        self.node_count = node_count
        self.master_node_id = make_node_id("master")
        self.session_id = make_session_id(self.master_node_id)

        # -- shared bus channels --
        self.global_event_sender: Sender[ForwarderEvent]
        self.global_event_internal_receiver: Receiver[ForwarderEvent]
        self.global_event_sender, self.global_event_internal_receiver = channel[
            ForwarderEvent
        ]()

        self.command_sender: Sender[ForwarderCommand]
        self.command_receiver: Receiver[ForwarderCommand]
        self.command_sender, self.command_receiver = channel[ForwarderCommand]()

        self.local_event_sender: Sender[ForwarderEvent]
        self.local_event_receiver: Receiver[ForwarderEvent]
        self.local_event_sender, self.local_event_receiver = channel[ForwarderEvent]()

        self.download_cmd_sender: Sender[ForwarderDownloadCommand]
        self._download_cmd_receiver: Receiver[ForwarderDownloadCommand]
        self.download_cmd_sender, self._download_cmd_receiver = channel[
            ForwarderDownloadCommand
        ]()

        # -- event collector (taps global events) --
        self.event_collector = EventCollector(
            self.global_event_internal_receiver.clone()
        )

        # -- master --
        self.master = Master(
            self.master_node_id,
            self.session_id,
            global_event_sender=self.global_event_sender.clone(),
            local_event_receiver=self.local_event_receiver.clone(),
            command_receiver=self.command_receiver.clone(),
            download_command_sender=self.download_cmd_sender.clone(),
        )

        # -- workers --
        self.worker_node_ids: list[NodeId] = []
        self.workers: list[Worker] = []
        for i in range(node_count):
            wid = make_node_id(f"worker-{i}")
            self.worker_node_ids.append(wid)

            counter: Iterator[int] = iter(range(1_000_000))
            worker = Worker(
                wid,
                self.session_id,
                global_event_receiver=self.global_event_internal_receiver.clone(),
                local_event_sender=self.local_event_sender.clone(),
                command_sender=self.command_sender.clone(),
                download_command_sender=self.download_cmd_sender.clone(),
                event_index_counter=counter,
            )
            self.workers.append(worker)

    async def inject_node_info(self, node_id: NodeId, sender_suffix: str = "") -> None:
        """Inject a NodeGatheredInfo event for a node into the local event bus."""
        sender_id = NodeId(f"{node_id}_sender{sender_suffix}")
        await self.local_event_sender.send(
            make_gathered_info_event(node_id, sender_id, self.session_id, 0)
        )

    async def wait_for_topology_nodes(
        self, count: int, *, timeout: float = 5.0
    ) -> None:
        import anyio

        with anyio.fail_after(timeout):
            while len(list(self.master.state.topology.list_nodes())) < count:
                await anyio.sleep(0.01)

    async def place_model(
        self,
        model_card: ModelCard | None = None,
        min_nodes: int = 1,
    ) -> None:
        card = model_card or TEST_MODEL_CARD
        await self.command_sender.send(
            ForwarderCommand(
                origin=self.master_node_id,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=card,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=min_nodes,
                ),
            )
        )

    async def wait_for_instances(self, count: int, *, timeout: float = 5.0) -> None:
        import anyio

        with anyio.fail_after(timeout):
            while len(self.master.state.instances) < count:
                await anyio.sleep(0.01)

    async def send_chat(
        self,
        message: str,
        model: ModelId | None = None,
    ) -> CommandId:
        cmd_id = CommandId()
        await self.command_sender.send(
            ForwarderCommand(
                origin=self.master_node_id,
                command=TextGeneration(
                    command_id=cmd_id,
                    task_params=TextGenerationTaskParams(
                        model=model or TEST_MODEL_ID,
                        input=[InputMessage(role="user", content=message)],
                    ),
                ),
            )
        )
        return cmd_id

    async def shutdown_master(self) -> None:
        await self.master.shutdown()

    def shutdown_workers(self) -> None:
        for w in self.workers:
            w.shutdown()


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def fast_election_timeout(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1)


@pytest.fixture
def caplog(caplog: LogCaptureFixture) -> Iterator[LogCaptureFixture]:
    handler_id = logger.add(
        caplog.handler,
        format="{message}",
        level=0,
        filter=lambda record: record["level"].no >= caplog.handler.level,
        enqueue=True,
    )
    yield caplog
    logger.remove(handler_id)
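For orientation, here is a minimal sketch (not part of the commit) of how a test might drive the MiniCluster helper defined above; the calls follow the helper's own API, and the test name is hypothetical:

import anyio
import pytest

from .conftest import MiniCluster


@pytest.mark.slow
@pytest.mark.asyncio
async def test_minicluster_smoke() -> None:
    # Illustrative only: wire a one-node cluster, place the test model,
    # send a chat command, and shut the master down again.
    cluster = MiniCluster(node_count=1)
    async with anyio.create_task_group() as tg:
        tg.start_soon(cluster.master.run)
        await cluster.inject_node_info(cluster.master_node_id)
        await cluster.wait_for_topology_nodes(1)
        await cluster.place_model()
        await cluster.wait_for_instances(1)
        await cluster.send_chat("smoke test")
        await cluster.shutdown_master()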
src/exo/tests/e2e_chaos/test_client_disconnect.py (new file, 255 lines)
@@ -0,0 +1,255 @@
"""E2E Chaos Test: Client disconnect.

Scenarios:
1. Task cancellation after client disconnect -- a TextGeneration command is
   sent, then immediately cancelled (simulating browser tab close).
   Verify the master correctly transitions the task to Cancelled status.
2. Multiple rapid cancellations -- several chat commands are sent and
   cancelled in quick succession; no tasks should remain in a stuck state.
"""

import anyio
import pytest

from exo.master.main import Master
from exo.shared.types.commands import (
    CommandId,
    ForwarderCommand,
    ForwarderDownloadCommand,
    PlaceInstance,
    TaskCancelled,
    TextGeneration,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
)
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.channels import channel

from .conftest import (
    TEST_MODEL_CARD,
    TEST_MODEL_ID,
    EventCollector,
    make_gathered_info_event,
    make_node_id,
)


@pytest.mark.slow
@pytest.mark.asyncio
async def test_task_cancelled_after_client_disconnect() -> None:
    """Simulate a browser tab close by sending a TextGeneration command
    followed immediately by a TaskCancelled command. Verify the task
    transitions to Cancelled status.
    """
    master_nid = make_node_id("master-cancel")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register node
        sender_id = NodeId(f"{master_nid}_sender")
        await le_sender.send(
            make_gathered_info_event(master_nid, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) == 0:
                await anyio.sleep(0.01)

        # Place instance
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        # Send a chat command
        chat_cmd_id = CommandId()
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=TextGeneration(
                    command_id=chat_cmd_id,
                    task_params=TextGenerationTaskParams(
                        model=TEST_MODEL_ID,
                        input=[InputMessage(role="user", content="Hello world")],
                    ),
                ),
            )
        )

        # Wait for the task to be created
        with anyio.fail_after(3):
            while len(master.state.tasks) == 0:
                await anyio.sleep(0.01)

        # Immediately cancel -- simulating browser tab close
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=TaskCancelled(
                    command_id=CommandId(),
                    cancelled_command_id=chat_cmd_id,
                ),
            )
        )

        # Wait for the task status to be updated to Cancelled
        with anyio.fail_after(3):
            while True:
                tasks_cancelled = [
                    t
                    for t in master.state.tasks.values()
                    if t.task_status == TaskStatus.Cancelled
                ]
                if tasks_cancelled:
                    break
                await anyio.sleep(0.01)

        assert len(tasks_cancelled) == 1

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_rapid_cancel_does_not_leave_stuck_tasks() -> None:
    """Send multiple chat commands and cancel them all rapidly.
    Verify no tasks remain in Pending or Running state.
    """
    master_nid = make_node_id("master-rapid-cancel")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register node and place instance
        sender_id = NodeId(f"{master_nid}_sender")
        await le_sender.send(
            make_gathered_info_event(master_nid, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) == 0:
                await anyio.sleep(0.01)

        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        # Send 5 chat commands and immediately cancel each
        chat_cmd_ids: list[CommandId] = []
        for i in range(5):
            cmd_id = CommandId()
            chat_cmd_ids.append(cmd_id)
            await cmd_sender.send(
                ForwarderCommand(
                    origin=master_nid,
                    command=TextGeneration(
                        command_id=cmd_id,
                        task_params=TextGenerationTaskParams(
                            model=TEST_MODEL_ID,
                            input=[InputMessage(role="user", content=f"Message {i}")],
                        ),
                    ),
                )
            )

        # Wait for all tasks to be created
        with anyio.fail_after(3):
            while len(master.state.tasks) < 5:
                await anyio.sleep(0.01)

        # Cancel all of them
        for cmd_id in chat_cmd_ids:
            await cmd_sender.send(
                ForwarderCommand(
                    origin=master_nid,
                    command=TaskCancelled(
                        command_id=CommandId(),
                        cancelled_command_id=cmd_id,
                    ),
                )
            )

        # Wait for all cancellations to be processed
        with anyio.fail_after(3):
            while True:
                cancelled_count = sum(
                    1
                    for t in master.state.tasks.values()
                    if t.task_status == TaskStatus.Cancelled
                )
                if cancelled_count == 5:
                    break
                await anyio.sleep(0.01)

        # No tasks should be Pending or Running
        stuck = [
            t
            for t in master.state.tasks.values()
            if t.task_status in (TaskStatus.Pending, TaskStatus.Running)
        ]
        assert len(stuck) == 0

        await master.shutdown()
src/exo/tests/e2e_chaos/test_concurrent_requests.py (new file, 395 lines)
@@ -0,0 +1,395 @@
"""E2E Chaos Test: Concurrent requests.

Scenarios:
1. Multiple simultaneous inference requests -- verify they are all created
   as tasks with no data corruption (unique task IDs, correct model IDs).
2. Concurrent requests across multiple model instances -- verify tasks are
   routed to the correct instances.
3. Concurrent requests with load balancing -- when multiple instances of
   the same model exist, verify tasks are distributed.
"""

import anyio
import pytest

from exo.master.main import Master
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
    CommandId,
    ForwarderCommand,
    ForwarderDownloadCommand,
    PlaceInstance,
    TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.channels import channel

from .conftest import (
    TEST_MODEL_CARD,
    TEST_MODEL_ID,
    EventCollector,
    make_gathered_info_event,
    make_node_id,
)


@pytest.mark.slow
@pytest.mark.asyncio
async def test_concurrent_chat_requests_no_corruption() -> None:
    """Send multiple TextGeneration commands concurrently and verify each
    results in a unique task with the correct model and content mapping.
    """
    master_nid = make_node_id("master-concurrent")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Set up node and instance
        sender_id = NodeId(f"{master_nid}_sender")
        await le_sender.send(
            make_gathered_info_event(master_nid, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) == 0:
                await anyio.sleep(0.01)

        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        # Send 10 concurrent chat requests
        num_requests = 10
        cmd_ids: list[CommandId] = []

        async def send_chat(index: int) -> None:
            cmd_id = CommandId()
            cmd_ids.append(cmd_id)
            await cmd_sender.send(
                ForwarderCommand(
                    origin=master_nid,
                    command=TextGeneration(
                        command_id=cmd_id,
                        task_params=TextGenerationTaskParams(
                            model=TEST_MODEL_ID,
                            input=[
                                InputMessage(
                                    role="user",
                                    content=f"Concurrent request #{index}",
                                )
                            ],
                        ),
                    ),
                )
            )

        async with anyio.create_task_group() as send_tg:
            for i in range(num_requests):
                send_tg.start_soon(send_chat, i)

        # Wait for all tasks to be created
        with anyio.fail_after(5):
            while len(master.state.tasks) < num_requests:
                await anyio.sleep(0.01)

        # Verify no corruption
        assert len(master.state.tasks) == num_requests

        # All task IDs should be unique
        task_ids = list(master.state.tasks.keys())
        assert len(set(task_ids)) == num_requests

        # All tasks should target the correct model
        for task in master.state.tasks.values():
            assert isinstance(task, TextGenerationTask)
            assert task.task_params.model == TEST_MODEL_ID
            assert task.task_status == TaskStatus.Pending

        # All tasks should reference the same instance
        instance_ids = {task.instance_id for task in master.state.tasks.values()}
        assert len(instance_ids) == 1

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_concurrent_requests_across_multiple_models() -> None:
    """Place two different models, then send concurrent requests for each.
    Verify tasks are routed to the correct model instances.
    """
    master_nid = make_node_id("master-multi-model")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register node
        sender_id = NodeId(f"{master_nid}_sender")
        await le_sender.send(
            make_gathered_info_event(master_nid, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) == 0:
                await anyio.sleep(0.01)

        # Place two different models
        model_a_id = ModelId("test-model/model-a")
        model_a_card = ModelCard(
            model_id=model_a_id,
            n_layers=16,
            storage_size=Memory.from_bytes(500_000),
            hidden_size=2048,
            supports_tensor=True,
            tasks=[ModelTask.TextGeneration],
        )

        model_b_id = ModelId("test-model/model-b")
        model_b_card = ModelCard(
            model_id=model_b_id,
            n_layers=32,
            storage_size=Memory.from_bytes(500_000),
            hidden_size=4096,
            supports_tensor=True,
            tasks=[ModelTask.TextGeneration],
        )

        for card in [model_a_card, model_b_card]:
            await cmd_sender.send(
                ForwarderCommand(
                    origin=master_nid,
                    command=PlaceInstance(
                        command_id=CommandId(),
                        model_card=card,
                        sharding=Sharding.Pipeline,
                        instance_meta=InstanceMeta.MlxRing,
                        min_nodes=1,
                    ),
                )
            )

        with anyio.fail_after(5):
            while len(master.state.instances) < 2:
                await anyio.sleep(0.01)

        # Map instance IDs to models
        instance_to_model: dict[str, ModelId] = {}
        for iid, inst in master.state.instances.items():
            instance_to_model[iid] = inst.shard_assignments.model_id

        # Send concurrent requests for both models
        async def send_for_model(model_id: ModelId, count: int) -> None:
            for i in range(count):
                await cmd_sender.send(
                    ForwarderCommand(
                        origin=master_nid,
                        command=TextGeneration(
                            command_id=CommandId(),
                            task_params=TextGenerationTaskParams(
                                model=model_id,
                                input=[
                                    InputMessage(
                                        role="user",
                                        content=f"Request for {model_id} #{i}",
                                    )
                                ],
                            ),
                        ),
                    )
                )

        async with anyio.create_task_group() as send_tg:
            send_tg.start_soon(send_for_model, model_a_id, 3)
            send_tg.start_soon(send_for_model, model_b_id, 3)

        # Wait for all 6 tasks
        with anyio.fail_after(5):
            while len(master.state.tasks) < 6:
                await anyio.sleep(0.01)

        # Verify task routing
        model_a_tasks = [
            t
            for t in master.state.tasks.values()
            if isinstance(t, TextGenerationTask) and t.task_params.model == model_a_id
        ]
        model_b_tasks = [
            t
            for t in master.state.tasks.values()
            if isinstance(t, TextGenerationTask) and t.task_params.model == model_b_id
        ]

        assert len(model_a_tasks) == 3
        assert len(model_b_tasks) == 3

        # All model_a tasks should reference the model_a instance
        model_a_instance_ids = {
            iid for iid, mid in instance_to_model.items() if mid == model_a_id
        }
        for task in model_a_tasks:
            assert task.instance_id in model_a_instance_ids

        # All model_b tasks should reference the model_b instance
        model_b_instance_ids = {
            iid for iid, mid in instance_to_model.items() if mid == model_b_id
        }
        for task in model_b_tasks:
            assert task.instance_id in model_b_instance_ids

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_event_index_monotonically_increases_under_load() -> None:
    """Under heavy concurrent command load, verify the master's event log
    index increases monotonically with no gaps or duplicates.
    """
    master_nid = make_node_id("master-monotonic")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register node and place instance
        sender_id = NodeId(f"{master_nid}_sender")
        await le_sender.send(
            make_gathered_info_event(master_nid, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) == 0:
                await anyio.sleep(0.01)

        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        # Blast 20 concurrent commands
        async def blast_commands(start: int, count: int) -> None:
            for i in range(count):
                await cmd_sender.send(
                    ForwarderCommand(
                        origin=master_nid,
                        command=TextGeneration(
                            command_id=CommandId(),
                            task_params=TextGenerationTaskParams(
                                model=TEST_MODEL_ID,
                                input=[
                                    InputMessage(
                                        role="user",
                                        content=f"Blast {start + i}",
                                    )
                                ],
                            ),
                        ),
                    )
                )

        async with anyio.create_task_group() as blast_tg:
            blast_tg.start_soon(blast_commands, 0, 10)
            blast_tg.start_soon(blast_commands, 10, 10)

        # Wait for all tasks
        with anyio.fail_after(5):
            while len(master.state.tasks) < 20:
                await anyio.sleep(0.01)

        # Collect all events and verify monotonic indexing
        # NodeGatheredInfo(0) + InstanceCreated(1) + 20 TaskCreated = 22 events
        await collector.wait_for_event_count(22, timeout=5.0)

        events = collector.indexed_events
        indices = [e.idx for e in events]

        # Should be 0, 1, 2, ..., N-1 with no gaps
        expected = list(range(len(indices)))
        assert indices == expected

        # last_event_applied_idx should match
        assert master.state.last_event_applied_idx == len(events) - 1

        await master.shutdown()
src/exo/tests/e2e_chaos/test_distributed_model_loading.py (new file, 356 lines)
@@ -0,0 +1,356 @@
"""E2E Chaos Test: Large model distributed loading.

Scenarios:
1. Multi-node sharding -- place a model with min_nodes > 1, verify sharding
   is distributed across multiple nodes with correct shard assignments.
2. Single-node gets all layers -- place on 1 node, verify full assignment.
3. Three-node sharding -- verify 3-way distribution.
"""

import anyio
import pytest

from exo.master.main import Master
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
    CommandId,
    ForwarderCommand,
    ForwarderDownloadCommand,
    PlaceInstance,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import InstanceMeta, MlxRingInstance
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
from exo.utils.channels import Sender, channel

from .conftest import (
    TEST_MODEL_CARD,
    make_gathered_info_event,
    make_node_id,
    make_topology_edge_event,
)

# A model large enough to need sharding but small enough to fit in test node memory
# Each test node has 8GB available, so 2 nodes = 16GB, 3 nodes = 24GB.
# storage_size < total cluster memory to pass the memory filter.
LARGE_MODEL_CARD = ModelCard(
    model_id=ModelId("test-model/large-70b-4bit"),
    n_layers=80,
    storage_size=Memory.from_bytes(4 * 1024 * 1024 * 1024),
    hidden_size=8192,
    supports_tensor=True,
    tasks=[ModelTask.TextGeneration],
)


async def _register_node(
    le_sender: Sender[ForwarderEvent],
    node_id: NodeId,
    session_id: SessionId,
) -> None:
    """Register a node by injecting NodeGatheredInfo."""
    sender_id = NodeId(f"{node_id}_sender")
    await le_sender.send(make_gathered_info_event(node_id, sender_id, session_id, 0))


async def _add_bidirectional_edge(
    le_sender: Sender[ForwarderEvent],
    node_a: NodeId,
    node_b: NodeId,
    session_id: SessionId,
    sender_id: NodeId,
    origin_idx_start: int,
    ip_a: int,
    ip_b: int,
) -> None:
    """Add bidirectional topology edges between two nodes."""
    await le_sender.send(
        make_topology_edge_event(
            node_a, node_b, sender_id, session_id, origin_idx_start, ip_suffix=ip_b
        )
    )
    await le_sender.send(
        make_topology_edge_event(
            node_b, node_a, sender_id, session_id, origin_idx_start + 1, ip_suffix=ip_a
        )
    )


@pytest.mark.slow
@pytest.mark.asyncio
async def test_multi_node_sharding_distributes_layers() -> None:
    """Place a model with min_nodes=2 on a cluster with 2 connected nodes.
    Verify the resulting instance has shard assignments spanning both nodes.
    """
    master_nid = make_node_id("master-shard")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        worker_a = make_node_id("shard-worker-a")
        worker_b = make_node_id("shard-worker-b")

        # Register both worker nodes (each sender uses origin_idx=0)
        for nid in [worker_a, worker_b]:
            await _register_node(le_sender, nid, session_id)

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 2:
                await anyio.sleep(0.01)

        # Add bidirectional edges to form a 2-node cycle (A <-> B)
        edge_sender = NodeId("edge_sender")
        await _add_bidirectional_edge(
            le_sender, worker_a, worker_b, session_id, edge_sender, 0, 1, 2
        )

        # Wait for edges to be processed
        with anyio.fail_after(3):
            while len(list(master.state.topology.list_connections())) < 2:
                await anyio.sleep(0.01)

        # Place a large model requiring 2 nodes
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=LARGE_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=2,
                ),
            )
        )

        with anyio.fail_after(5):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        instance_id = next(iter(master.state.instances))
        instance = master.state.instances[instance_id]
        assert isinstance(instance, MlxRingInstance)

        shard_assignments = instance.shard_assignments
        runner_shards = shard_assignments.runner_to_shard

        assert len(runner_shards) == 2

        assigned_nodes = set(shard_assignments.node_to_runner.keys())
        assert worker_a in assigned_nodes
        assert worker_b in assigned_nodes

        shards = list(runner_shards.values())
        assert all(isinstance(s, PipelineShardMetadata) for s in shards)
        pipeline_shards = [s for s in shards if isinstance(s, PipelineShardMetadata)]

        assert all(s.world_size == 2 for s in pipeline_shards)
        ranks = {s.device_rank for s in pipeline_shards}
        assert ranks == {0, 1}

        sorted_shards = sorted(pipeline_shards, key=lambda s: s.device_rank)
        assert sorted_shards[0].start_layer == 0
        assert sorted_shards[-1].end_layer == LARGE_MODEL_CARD.n_layers

        total_layers = sum(s.end_layer - s.start_layer for s in sorted_shards)
        assert total_layers == LARGE_MODEL_CARD.n_layers

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_single_node_gets_all_layers() -> None:
    """Place a model with min_nodes=1 on a single node. Verify the
    instance has one runner assigned all layers (world_size=1).
    """
    master_nid = make_node_id("master-single")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        worker_nid = make_node_id("single-worker")
        await _register_node(le_sender, worker_nid, session_id)

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 1:
                await anyio.sleep(0.01)

        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        instance_id = next(iter(master.state.instances))
        instance = master.state.instances[instance_id]
        assert isinstance(instance, MlxRingInstance)

        shards = list(instance.shard_assignments.runner_to_shard.values())
        assert len(shards) == 1

        shard = shards[0]
        assert isinstance(shard, PipelineShardMetadata)
        assert shard.world_size == 1
        assert shard.device_rank == 0
        assert shard.start_layer == 0
        assert shard.end_layer == TEST_MODEL_CARD.n_layers

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_three_node_sharding_distributes_evenly() -> None:
    """Place a model across 3 connected nodes. Verify all 3 get shard assignments."""
    master_nid = make_node_id("master-3way")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        workers: list[NodeId] = []
        for i in range(3):
            nid = make_node_id(f"three-worker-{i}")
            workers.append(nid)
            await _register_node(le_sender, nid, session_id)

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 3:
                await anyio.sleep(0.01)

        # Add bidirectional edges to form a fully connected 3-node cycle:
        # A <-> B, B <-> C, C <-> A
        edge_sender = NodeId("edge_sender_3way")
        idx = 0
        ip_counter = 10
        for i in range(3):
            source = workers[i]
            sink = workers[(i + 1) % 3]
            # Forward edge
            await le_sender.send(
                make_topology_edge_event(
                    source,
                    sink,
                    edge_sender,
                    session_id,
                    idx,
                    ip_suffix=ip_counter,
                )
            )
            idx += 1
            ip_counter += 1
            # Reverse edge
            await le_sender.send(
                make_topology_edge_event(
                    sink,
                    source,
                    edge_sender,
                    session_id,
                    idx,
                    ip_suffix=ip_counter,
                )
            )
            idx += 1
            ip_counter += 1

        # Wait for all 6 edges (3 pairs x 2 directions)
        with anyio.fail_after(3):
            while len(list(master.state.topology.list_connections())) < 6:
                await anyio.sleep(0.01)

        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=LARGE_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=3,
                ),
            )
        )

        with anyio.fail_after(5):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        instance = next(iter(master.state.instances.values()))
        assert isinstance(instance, MlxRingInstance)

        assignments = instance.shard_assignments
        assert len(assignments.runner_to_shard) == 3
        assert len(assignments.node_to_runner) == 3

        for w in workers:
            assert w in assignments.node_to_runner

        shards = list(assignments.runner_to_shard.values())
        ranks = {s.device_rank for s in shards if isinstance(s, PipelineShardMetadata)}
        assert ranks == {0, 1, 2}

        pipeline_shards = [s for s in shards if isinstance(s, PipelineShardMetadata)]
        total_layers = sum(s.end_layer - s.start_layer for s in pipeline_shards)
        assert total_layers == LARGE_MODEL_CARD.n_layers

        await master.shutdown()
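As an aside, the invariant these sharding tests assert is that the pipeline shards form contiguous, non-overlapping layer ranges that together cover every layer of the model. A minimal sketch of that arithmetic (illustrative only, assuming an even contiguous split; the real placement logic decides the actual boundaries):

def contiguous_even_split(n_layers: int, world_size: int) -> list[tuple[int, int]]:
    # Split n_layers into world_size contiguous (start, end) ranges.
    base, rem = divmod(n_layers, world_size)
    ranges, start = [], 0
    for rank in range(world_size):
        end = start + base + (1 if rank < rem else 0)
        ranges.append((start, end))
        start = end
    return ranges

assert contiguous_even_split(80, 2) == [(0, 40), (40, 80)]
assert sum(e - s for s, e in contiguous_even_split(80, 3)) == 80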
src/exo/tests/e2e_chaos/test_failure_recovery.py (new file, 272 lines)
@@ -0,0 +1,272 @@
"""E2E Chaos Test: Failure recovery.

Scenarios:
1. Master crash and re-election -- master shuts down, a new election round
   produces a new master, workers re-converge.
2. Worker crash during task execution -- runner death is detected, instance
   is cleaned up, and cluster recovers.
"""

import anyio
import pytest

from exo.master.main import Master
from exo.shared.types.commands import (
    CommandId,
    ForwarderCommand,
    ForwarderDownloadCommand,
    PlaceInstance,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
    RunnerStatusUpdated,
)
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.runners import RunnerFailed
from exo.shared.types.worker.shards import Sharding
from exo.utils.channels import channel

from .conftest import (
    TEST_MODEL_CARD,
    EventCollector,
    MiniCluster,
    make_gathered_info_event,
    make_node_id,
)


@pytest.mark.slow
@pytest.mark.asyncio
async def test_master_crash_and_reelection() -> None:
    """Simulate master crash by shutting it down, then verify a new master
    can be started with fresh state and begin accepting commands.

    This tests the scenario where the elected master dies and a new election
    must take place. We simulate the election result directly (since
    Election is tested separately) and verify the new master works.
    """
    cluster = MiniCluster(node_count=1)
    old_instance_id: str = ""

    async with anyio.create_task_group() as tg:
        tg.start_soon(cluster.master.run)

        # Set up initial state
        await cluster.inject_node_info(cluster.master_node_id)
        await cluster.wait_for_topology_nodes(1)
        await cluster.place_model()
        await cluster.wait_for_instances(1)

        # Verify initial state
        assert len(cluster.master.state.instances) == 1
        old_instance_id = next(iter(cluster.master.state.instances))

        # --- Crash the master ---
        await cluster.shutdown_master()

    # --- Start a new master (simulating re-election) ---
    new_master_nid = make_node_id("new-master")
    new_session_id = SessionId(master_node_id=new_master_nid, election_clock=1)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    new_master = Master(
        new_master_nid,
        new_session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _new_collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(new_master.run)

        # New master starts with clean state
        assert len(new_master.state.instances) == 0
        assert new_master.state.last_event_applied_idx == -1

        # Re-register node with the new master
        sender_id = NodeId(f"{new_master_nid}_sender_new")
        await le_sender.send(
            make_gathered_info_event(new_master_nid, sender_id, new_session_id, 0)
        )

        # Wait for topology to be rebuilt
        with anyio.fail_after(3):
            while len(list(new_master.state.topology.list_nodes())) == 0:
                await anyio.sleep(0.01)

        # Place a new model instance on the new master
        await cmd_sender.send(
            ForwarderCommand(
                origin=new_master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(new_master.state.instances) == 0:
                await anyio.sleep(0.01)

        # Verify new master is functional
        assert len(new_master.state.instances) == 1
        new_instance_id = next(iter(new_master.state.instances))
        # New instance should be different from old one
        assert new_instance_id != old_instance_id

        await new_master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_runner_failure_triggers_instance_cleanup() -> None:
    """Simulate a runner failure by injecting a RunnerStatusUpdated(RunnerFailed)
    event. Verify that the master's plan loop eventually detects the broken
    instance (no connected node for the runner) and cleans it up.
    """
    master_nid = make_node_id("master-runner-fail")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register a worker node
        worker_nid = make_node_id("worker-failing")
        sender_id = NodeId(f"{worker_nid}_sender")
        await le_sender.send(
            make_gathered_info_event(worker_nid, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) == 0:
                await anyio.sleep(0.01)

        # Place a model instance
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        instance_id = next(iter(master.state.instances))
        instance = master.state.instances[instance_id]
        runner_id = next(iter(instance.shard_assignments.runner_to_shard))

        # Inject a RunnerFailed event from the worker
        await le_sender.send(
            ForwarderEvent(
                origin_idx=1,
                origin=sender_id,
                session=session_id,
                event=RunnerStatusUpdated(
                    runner_id=runner_id,
                    runner_status=RunnerFailed(
                        error_message="Simulated OOM kill (exitcode=137)"
                    ),
                ),
            )
        )

        # Wait for the runner failure to be processed
        with anyio.fail_after(3):
            while runner_id not in master.state.runners:
                await anyio.sleep(0.01)

        # The runner status should be RunnerFailed
        assert isinstance(master.state.runners[runner_id], RunnerFailed)

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_election_recovers_after_multiple_node_joins() -> None:
    """Verify that the election protocol correctly handles rapid node
    join/leave events by running multiple election rounds.
    """
    from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
    from exo.shared.election import Election, ElectionMessage, ElectionResult

    em_out_tx, em_out_rx = channel[ElectionMessage]()
    em_in_tx, em_in_rx = channel[ElectionMessage]()
    er_tx, er_rx = channel[ElectionResult]()
    cm_tx, cm_rx = channel[ConnectionMessage]()
    co_tx, co_rx = channel[ForwarderCommand]()

    election = Election(
        node_id=NodeId("SURVIVOR"),
        election_message_receiver=em_in_rx,
        election_message_sender=em_out_tx,
        election_result_sender=er_tx,
        connection_message_receiver=cm_rx,
        command_receiver=co_rx,
        is_candidate=True,
    )

    async with anyio.create_task_group() as tg:
        with anyio.fail_after(5):
            tg.start_soon(election.run)

            # Simulate rapid node joins via connection messages
            for i in range(3):
                await cm_tx.send(
                    ConnectionMessage(
                        node_id=NodeId(f"joiner-{i}"),
                        connection_type=ConnectionMessageType.Connected,
                        remote_ipv4=f"10.0.0.{i + 1}",
                        remote_tcp_port=52415,
                    )
                )
                # Each connection triggers a new election round
                while True:
                    got = await em_out_rx.receive()
                    if got.proposed_session.master_node_id == NodeId("SURVIVOR"):
                        break

            # After all joins, an election result should eventually be produced
            result = await er_rx.receive()
            assert result.session_id.master_node_id == NodeId("SURVIVOR")

            em_in_tx.close()
            cm_tx.close()
            co_tx.close()
src/exo/tests/e2e_chaos/test_networking_resilience.py (new file, 227 lines)
@@ -0,0 +1,227 @@
|
||||
"""E2E Chaos Test: Networking resilience.
|
||||
|
||||
Scenarios:
|
||||
1. Node disconnect mid-inference -- a worker stops receiving global events, then
|
||||
reconnects and catches up via the event buffer / nack mechanism.
|
||||
2. Master detects stale node and times it out, then the node re-announces.
|
||||
"""
|
||||
|
||||
import anyio
|
||||
import pytest
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.shared.types.commands import (
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
InstanceCreated,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
)
|
||||
from exo.utils.channels import channel
|
||||
|
||||
from .conftest import (
|
||||
EventCollector,
|
||||
MiniCluster,
|
||||
make_gathered_info_event,
|
||||
make_node_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_disconnect_and_reconnect_event_replay() -> None:
|
||||
"""Simulate a node disconnecting by closing its global event receiver,
|
||||
then reconnecting with a fresh receiver.
|
||||
|
||||
After reconnection, events that were broadcast while the node was
|
||||
disconnected should be replayed to the new receiver via the shared
|
||||
channel state. The master's state should remain consistent.
|
||||
"""
    cluster = MiniCluster(node_count=1)

    async with anyio.create_task_group() as tg:
        tg.start_soon(cluster.master.run)

        # Register the master node so topology is populated
        await cluster.inject_node_info(cluster.master_node_id)
        await cluster.wait_for_topology_nodes(1)

        # Place a model instance
        await cluster.place_model()
        await cluster.wait_for_instances(1)

        # Verify instance was created
        assert len(cluster.master.state.instances) == 1

        # --- Simulate disconnection ---
        # The worker's global event receiver is independent; we just verify
        # that the master continues to accept commands while a worker is gone.
        _first_instance_id = next(iter(cluster.master.state.instances))

        # Send a chat command while the "disconnected" worker can't process it
        _cmd_id = await cluster.send_chat("Hello during disconnect")

        # Give the master time to process the command
        await cluster.event_collector.wait_for_event_count(3, timeout=3.0)

        events = cluster.event_collector.indexed_events
        # Should have: NodeGatheredInfo, InstanceCreated, TaskCreated
        assert any(isinstance(e.event, NodeGatheredInfo) for e in events)
        assert any(isinstance(e.event, InstanceCreated) for e in events)
        assert any(isinstance(e.event, TaskCreated) for e in events)

        # --- Simulate reconnection ---
        # A reconnecting node gets a fresh receiver clone and catches up
        reconnect_receiver = cluster.global_event_internal_receiver.clone()
        _reconnect_collector = EventCollector(reconnect_receiver)

        # The new receiver should see future events; existing events are in
        # the master's event log (which would be replayed via RequestEventLog
        # in production). Here we verify the channel infrastructure works.
        await cluster.send_chat("Hello after reconnect")
        await anyio.sleep(0.1)

        # Master state should now have 2 tasks
        assert len(cluster.master.state.tasks) == 2

        # The master's state is consistent throughout
        assert len(cluster.master.state.instances) == 1
        assert cluster.master.state.last_event_applied_idx >= 3

        await cluster.shutdown_master()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_master_detects_timed_out_node_and_cleans_state() -> None:
    """Verify that the master's plan loop detects a node that hasn't sent
    a heartbeat (NodeGatheredInfo) recently and emits NodeTimedOut, cleaning
    up topology and related state.
    """
    master_nid = make_node_id("master-timeout")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    _cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register two nodes
        stale_node = make_node_id("stale")
        alive_node = make_node_id("alive")

        for node_id, suffix in [(stale_node, "_s0"), (alive_node, "_a0")]:
            sender_id = NodeId(f"{node_id}_sender{suffix}")
            await le_sender.send(
                make_gathered_info_event(node_id, sender_id, session_id, 0)
            )

        # Wait for both nodes in topology
        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 2:
                await anyio.sleep(0.01)

        assert stale_node in master.state.last_seen
        assert alive_node in master.state.last_seen

        # Manually expire the stale node's last_seen time by patching the state
        # (in production, the _plan loop checks every 10s with a 30s threshold)
        from datetime import timedelta

        old_time = master.state.last_seen[stale_node] - timedelta(seconds=60)
        patched_last_seen = {**master.state.last_seen, stale_node: old_time}
        master.state = master.state.model_copy(update={"last_seen": patched_last_seen})

        # The plan loop periodically checks for stale nodes; rather than
        # triggering it manually, wait (bounded) for it to notice the expired
        # last_seen entry and emit NodeTimedOut for the stale node.
        with anyio.fail_after(15):
            while stale_node in master.state.last_seen:
                await anyio.sleep(0.1)

        # Stale node should be removed from topology
        assert stale_node not in set(master.state.topology.list_nodes())

        # Alive node should still be present
        assert alive_node in set(master.state.topology.list_nodes())
        assert alive_node in master.state.last_seen

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_event_ordering_preserved_under_concurrent_writers() -> None:
    """Multiple sources writing local events concurrently. Verify that the
    master's MultiSourceBuffer correctly sequences events from each source
    and the final state is consistent.
    """
    master_nid = make_node_id("master-ordering")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, ge_receiver = channel[ForwarderEvent]()
    _cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    _collector = EventCollector(ge_receiver.clone())

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Inject events from 3 different "worker" sources concurrently
        node_ids = [make_node_id(f"concurrent-{i}") for i in range(3)]

        async def inject_events(node_id: NodeId, count: int) -> None:
            for idx in range(count):
                sender_id = NodeId(f"{node_id}_sender")
                await le_sender.send(
                    make_gathered_info_event(node_id, sender_id, session_id, idx)
                )
                await anyio.sleep(0.001)  # slight jitter

        async with anyio.create_task_group() as inject_tg:
            for nid in node_ids:
                inject_tg.start_soon(inject_events, nid, 5)

        # Wait for master to process all events (3 nodes * 5 events each = 15)
        with anyio.fail_after(5):
            while master.state.last_event_applied_idx < 14:
                await anyio.sleep(0.01)

        # All 3 nodes should be visible in topology
        topo_nodes = set(master.state.topology.list_nodes())
        for nid in node_ids:
            assert nid in topo_nodes

        # Event indices should be sequential with no gaps
        assert master.state.last_event_applied_idx == 14

        await master.shutdown()
267
src/exo/tests/e2e_chaos/test_node_join_leave.py
Normal file
@@ -0,0 +1,267 @@
"""E2E Chaos Test: Node join/leave during operation.

Scenarios:
1. Add nodes dynamically -- register new nodes with the master while
   a model is already placed, verify topology grows.
2. Remove nodes -- simulate node timeout, verify instances on that node
   are cleaned up and remaining nodes are unaffected.
3. Rapid join/leave churn -- nodes join and leave quickly, verify state
   converges to a consistent snapshot.
"""

from datetime import timedelta

import anyio
import pytest

from exo.master.main import Master
from exo.shared.types.commands import (
    CommandId,
    ForwarderCommand,
    ForwarderDownloadCommand,
    PlaceInstance,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
)
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.channels import channel

from .conftest import (
    TEST_MODEL_CARD,
    make_gathered_info_event,
    make_node_id,
)


@pytest.mark.slow
@pytest.mark.asyncio
async def test_dynamic_node_registration_expands_topology() -> None:
    """Start with one node, then add more dynamically. Verify the topology
    grows and all nodes are visible in state.
    """
    master_nid = make_node_id("master-join")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register initial node
        initial_node = make_node_id("initial")
        sender_id = NodeId(f"{initial_node}_sender")
        await le_sender.send(
            make_gathered_info_event(initial_node, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 1:
                await anyio.sleep(0.01)

        # Place a model instance
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        # Dynamically add 3 more nodes
        new_nodes: list[NodeId] = []
        for i in range(3):
            new_nid = make_node_id(f"dynamic-{i}")
            new_nodes.append(new_nid)
            new_sender = NodeId(f"{new_nid}_sender")
            await le_sender.send(
                make_gathered_info_event(new_nid, new_sender, session_id, 0)
            )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 4:
                await anyio.sleep(0.01)

        # All 4 nodes should be in topology
        topo_nodes = set(master.state.topology.list_nodes())
        assert initial_node in topo_nodes
        for nid in new_nodes:
            assert nid in topo_nodes

        # Original instance should still exist
        assert len(master.state.instances) >= 1

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_node_removal_cleans_up_instances() -> None:
    """Place a model on a specific node, then time it out. Verify the
    instance assigned to that node is deleted by the master's plan loop.
    """
    master_nid = make_node_id("master-leave")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register a worker node
        worker_nid = make_node_id("worker-leaving")
        sender_id = NodeId(f"{worker_nid}_sender")
        await le_sender.send(
            make_gathered_info_event(worker_nid, sender_id, session_id, 0)
        )

        with anyio.fail_after(3):
            while len(list(master.state.topology.list_nodes())) < 1:
                await anyio.sleep(0.01)

        # Place instance on the worker node
        await cmd_sender.send(
            ForwarderCommand(
                origin=master_nid,
                command=PlaceInstance(
                    command_id=CommandId(),
                    model_card=TEST_MODEL_CARD,
                    sharding=Sharding.Pipeline,
                    instance_meta=InstanceMeta.MlxRing,
                    min_nodes=1,
                ),
            )
        )

        with anyio.fail_after(3):
            while len(master.state.instances) == 0:
                await anyio.sleep(0.01)

        assert len(master.state.instances) == 1

        # Simulate node leaving by expiring its last_seen
        old_time = master.state.last_seen[worker_nid] - timedelta(seconds=60)
        patched_last_seen = {**master.state.last_seen, worker_nid: old_time}
        master.state = master.state.model_copy(update={"last_seen": patched_last_seen})

        # The plan loop should detect the stale node and delete the instance
        # because the node assigned to the instance is no longer in the topology
        with anyio.fail_after(15):
            while worker_nid in master.state.last_seen:
                await anyio.sleep(0.1)

        # After timeout, the node should be removed from topology
        assert worker_nid not in set(master.state.topology.list_nodes())

        # The instance should eventually be deleted since the assigned node
        # is no longer connected (the _plan loop kills broken instances)
        with anyio.fail_after(15):
            while len(master.state.instances) > 0:
                await anyio.sleep(0.1)

        assert len(master.state.instances) == 0

        await master.shutdown()


@pytest.mark.slow
@pytest.mark.asyncio
async def test_rapid_join_leave_churn_converges() -> None:
    """Rapidly join and leave nodes. After the churn settles, verify the
    master's state reflects only the surviving nodes.
    """
    master_nid = make_node_id("master-churn")
    session_id = SessionId(master_node_id=master_nid, election_clock=0)

    ge_sender, _ge_receiver = channel[ForwarderEvent]()
    _cmd_sender, cmd_receiver = channel[ForwarderCommand]()
    le_sender, le_receiver = channel[ForwarderEvent]()
    dl_sender, _dl_receiver = channel[ForwarderDownloadCommand]()

    master = Master(
        master_nid,
        session_id,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=cmd_receiver,
        download_command_sender=dl_sender,
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(master.run)

        # Register 5 nodes rapidly
        all_nodes: list[NodeId] = []
        for i in range(5):
            nid = make_node_id(f"churn-{i}")
            all_nodes.append(nid)
            sender_id = NodeId(f"{nid}_sender")
            await le_sender.send(
                make_gathered_info_event(nid, sender_id, session_id, 0)
            )

        with anyio.fail_after(5):
            while len(list(master.state.topology.list_nodes())) < 5:
                await anyio.sleep(0.01)

        assert len(list(master.state.topology.list_nodes())) == 5

        # Expire the first 3 nodes (simulate leaving)
        leaving_nodes = all_nodes[:3]
        surviving_nodes = all_nodes[3:]

        patched_last_seen = dict(master.state.last_seen)
        for nid in leaving_nodes:
            patched_last_seen[nid] = patched_last_seen[nid] - timedelta(seconds=60)
        master.state = master.state.model_copy(update={"last_seen": patched_last_seen})

        # Wait for master's plan loop to time out the expired nodes
        with anyio.fail_after(15):
            while any(nid in master.state.last_seen for nid in leaving_nodes):
                await anyio.sleep(0.1)

        # Verify only surviving nodes remain
        topo_nodes = set(master.state.topology.list_nodes())
        for nid in leaving_nodes:
            assert nid not in topo_nodes
        for nid in surviving_nodes:
            assert nid in topo_nodes

        assert len(list(master.state.topology.list_nodes())) == 2

        await master.shutdown()