Compare commits

..

3 Commits

Author     SHA1        Message                                                           Date
ciaranbor  12ace705fc  Test event dropping                                               2026-02-19 18:55:52 +00:00
ciaranbor  20ccf097bb  Route DownloadCoordinator events through worker's event channel  2026-02-19 18:55:52 +00:00
ciaranbor  94848bd5bd  Use n-strike ping tolerance                                       2026-02-19 18:55:52 +00:00
4 changed files with 323 additions and 61 deletions

View File

@@ -23,6 +23,8 @@ use util::wakerdeque::WakerDeque;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
const MAX_PING_FAILURES: u32 = 3;
mod managed {
use libp2p::swarm::NetworkBehaviour;
use libp2p::{identity, mdns, ping};
@@ -31,8 +33,8 @@ mod managed {
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
const PING_TIMEOUT: Duration = Duration::from_secs(10);
const PING_INTERVAL: Duration = Duration::from_secs(5);
#[derive(NetworkBehaviour)]
pub struct Behaviour {
@@ -109,6 +111,9 @@ pub struct Behaviour {
// pending events to emit => waker-backed Deque to control polling
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
// track consecutive ping failures per connection for N-strike tolerance
ping_failures: HashMap<ConnectionId, u32>,
}
impl Behaviour {
@@ -118,6 +123,7 @@ impl Behaviour {
mdns_discovered: HashMap::new(),
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
ping_failures: HashMap::new(),
})
}
@@ -308,6 +314,7 @@ impl NetworkBehaviour for Behaviour {
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
self.ping_failures.remove(&connection_id);
// handle connection closed event which is filtered correctly
self.on_connection_closed(peer_id, connection_id, ip, port)
}
@@ -337,10 +344,41 @@ impl NetworkBehaviour for Behaviour {
}
},
// handle ping events => if error then disconnect
// handle ping events => disconnect after N consecutive failures
managed::BehaviourEvent::Ping(e) => {
if let Err(_) = e.result {
self.close_connection(e.peer, e.connection.clone())
match &e.result {
Err(err) => {
let count = self.ping_failures.entry(e.connection).or_insert(0);
*count += 1;
log::warn!(
"Ping failed for peer {:?} (connection {:?}): {:?} — failure {}/{}",
e.peer,
e.connection,
err,
count,
MAX_PING_FAILURES
);
if *count >= MAX_PING_FAILURES {
log::warn!(
"Closing connection to peer {:?} after {} consecutive ping failures",
e.peer,
MAX_PING_FAILURES
);
self.ping_failures.remove(&e.connection);
self.close_connection(e.peer, e.connection);
}
}
Ok(rtt) => {
// Reset failure counter on successful ping
if self.ping_failures.remove(&e.connection).is_some() {
log::debug!(
"Ping recovered for peer {:?} (rtt={:?}), reset failure counter",
e.peer,
rtt
);
}
log::trace!("Ping OK for peer {:?}: rtt={:?}", e.peer, rtt);
}
}
}
}
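
The hunk above is the "Use n-strike ping tolerance" commit: rather than closing a connection on the first failed ping, the behaviour counts consecutive failures per connection and only disconnects once MAX_PING_FAILURES (3) is reached; any successful ping clears the counter. A minimal Python sketch of the same policy, with a plain integer standing in for libp2p's ConnectionId, illustrative only:

# Sketch of the n-strike tolerance policy; names are stand-ins, not libp2p types.
from dataclasses import dataclass, field

MAX_PING_FAILURES = 3

@dataclass
class PingTolerance:
    # consecutive ping failures per connection id
    failures: dict[int, int] = field(default_factory=dict)

    def on_ping_ok(self, connection_id: int) -> None:
        # any successful ping resets the strike count
        self.failures.pop(connection_id, None)

    def on_ping_err(self, connection_id: int) -> bool:
        # returns True when the connection should be closed
        count = self.failures.get(connection_id, 0) + 1
        self.failures[connection_id] = count
        if count >= MAX_PING_FAILURES:
            del self.failures[connection_id]
            return True
        return False

Together with the relaxed PING_TIMEOUT (10s) and PING_INTERVAL (5s), a flaky peer now gets three chances spanning tens of seconds before the connection is torn down, instead of a single 2.5s window.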

View File

@@ -1,7 +1,6 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
@@ -22,10 +21,9 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -36,33 +34,27 @@ from exo.shared.types.worker.downloads import (
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.channels import Receiver, Sender
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
event_sender: Sender[Event]
offline: bool = False
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
# Per-model throttle for download progress events
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
@@ -117,7 +109,6 @@ class DownloadCoordinator:
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
if not self.offline:
tg.start_soon(self._check_internet_connection)
@@ -297,21 +288,6 @@ class DownloadCoordinator:
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
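
This file drops the DownloadCoordinator's private event channel and its _forward_events loop: the coordinator now receives a Sender[Event] (a clone of the Worker's) and emits bare events, leaving index assignment and ForwarderEvent wrapping to a single downstream forwarder. A rough, self-contained sketch of that emit path, with an asyncio.Queue standing in for exo's Sender and a dict standing in for the event types:

# Illustrative only, not the real exo classes: the coordinator pushes plain
# events into an injected sender; wrapping and indexing happen elsewhere.
import asyncio
from dataclasses import dataclass, field
from typing import Any

@dataclass
class SketchDownloadCoordinator:
    event_sender: "asyncio.Queue[Any]"  # stand-in for Sender[Event]
    download_status: dict[str, str] = field(default_factory=dict)

    async def _download_progress_callback(self, model_id: str, status: str) -> None:
        # emit the bare progress event; no ForwarderEvent, no index counter here
        self.download_status[model_id] = status
        await self.event_sender.put(
            {"type": "NodeDownloadProgress", "model_id": model_id, "status": status}
        )

Because the coordinator no longer wraps events itself, it also no longer needs session_id, and its events automatically share whatever ordering and retry handling the owner of the channel provides.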

View File

@@ -57,23 +57,8 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -100,6 +85,20 @@ class Node:
else:
worker = None
# DownloadCoordinator sends events through the Worker's event channel
# so they get the same index sequence and retry mechanism
if not args.no_downloads:
assert worker is not None, "DownloadCoordinator requires a Worker"
download_coordinator = DownloadCoordinator(
node_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
event_sender=worker.event_sender.clone(),
offline=args.offline,
)
else:
download_coordinator = None
# We start every node with a master
master = Master(
node_id,
@@ -214,20 +213,6 @@ class Node:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
@@ -245,6 +230,19 @@ class Node:
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.download_coordinator:
self.download_coordinator.shutdown()
assert self.worker is not None
self.download_coordinator = DownloadCoordinator(
self.node_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
event_sender=self.worker.event_sender.clone(),
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.api:
self.api.reset(result.session_id, result.won_clock)
else:
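
The wiring above is why DownloadCoordinator no longer takes session_id or event_index_counter: its events ride the Worker's channel, so one forwarding loop can stamp everything (Worker-internal and download events alike) with a single index sequence and track it for retry. A hedged sketch of that single-forwarder idea, with queues standing in for exo's channels and a dict envelope standing in for ForwarderEvent:

# Illustrative only: one loop wraps every event from the shared channel, so all
# producers that clone the same sender share one origin_idx sequence.
import asyncio
import itertools
from typing import Any

async def forward_events(
    events: "asyncio.Queue[Any]",               # fed by Worker and DownloadCoordinator
    outbound: "asyncio.Queue[dict[str, Any]]",  # toward the Master
    origin: str,
) -> None:
    counter = itertools.count()                 # one counter for all producers
    while True:
        event = await events.get()
        envelope = {"origin_idx": next(counter), "origin": origin, "event": event}
        await outbound.put(envelope)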

View File

@@ -0,0 +1,250 @@
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from itertools import count
from pathlib import Path
from typing import AsyncIterator
import anyio
import pytest
from exo.download.coordinator import DownloadCoordinator
from exo.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.master.main import Master
from exo.master.tests.conftest import create_node_memory
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
NodeDownloadProgress,
NodeGatheredInfo,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.main import Worker
def _complete_progress(shard: ShardMetadata) -> RepoDownloadProgress:
return RepoDownloadProgress(
repo_id=str(shard.model_card.model_id),
repo_revision="test",
shard=shard,
completed_files=0,
total_files=0,
downloaded_bytes=Memory.from_bytes(0),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(0),
overall_speed=0,
overall_eta=timedelta(seconds=0),
status="complete",
)
class _TestShardDownloader(ShardDownloader):
"""Shard downloader that reports every shard as already complete."""
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
return Path("/tmp/test_shard")
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
) -> None:
pass
async def get_shard_download_status(
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
# Yield nothing — no pre-existing downloads
return
yield # make this an async generator
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata
) -> RepoDownloadProgress:
return _complete_progress(shard)
def _make_heartbeat(node_id: NodeId) -> NodeGatheredInfo:
return NodeGatheredInfo(
node_id=node_id,
when=str(datetime.now(tz=timezone.utc)),
info=create_node_memory(500),
)
class _PartitionSwitch:
"""Mutable boolean flag shared with the partition proxy coroutine."""
def __init__(self) -> None:
self.connected = True
async def _partition_proxy(
source: Receiver[ForwarderEvent],
dest: Sender[ForwarderEvent],
switch: _PartitionSwitch,
) -> None:
"""Forward events when ``switch.connected`` is True; drop otherwise."""
with source as events:
async for event in events:
if switch.connected:
await dest.send(event)
async def _wait_until(
predicate: Callable[[], object], *, timeout: float = 5.0, poll: float = 0.02
) -> None:
"""Poll *predicate* until truthy, raising on timeout."""
with anyio.fail_after(timeout):
while not predicate():
await anyio.sleep(poll)
# ---------------------------------------------------------------------------
# Test 1 same master: Worker + DC retry recovers lost events
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_partition_recovery_same_master() -> None:
"""Worker's out_for_delivery retry fills the Master's buffer gap after a
partition heals, even when DownloadCoordinator events are interleaved."""
master_node = NodeId("master-node")
worker_node = NodeId("worker-node")
session = SessionId(master_node_id=master_node, election_clock=1)
switch = _PartitionSwitch()
# --- channels --------------------------------------------------------
# Worker → proxy → Master (local events)
worker_local_send, proxy_local_recv = channel[ForwarderEvent]()
proxy_local_send, master_local_recv = channel[ForwarderEvent]()
# Master → proxy → Worker (global events)
master_global_send, proxy_global_recv = channel[ForwarderEvent]()
proxy_global_send, worker_global_recv = channel[ForwarderEvent]()
# Commands (required by constructors)
cmd_send, cmd_recv = channel[ForwarderCommand]()
dl_cmd_send, dl_cmd_recv = channel[ForwarderDownloadCommand]()
# --- components ------------------------------------------------------
worker = Worker(
worker_node,
session,
global_event_receiver=worker_global_recv,
local_event_sender=worker_local_send,
command_sender=cmd_send.clone(),
download_command_sender=dl_cmd_send.clone(),
event_index_counter=count(),
)
dc = DownloadCoordinator(
node_id=worker_node,
shard_downloader=_TestShardDownloader(),
download_command_receiver=dl_cmd_recv,
event_sender=worker.event_sender.clone(),
offline=True,
)
master = Master(
master_node,
session,
command_receiver=cmd_recv,
local_event_receiver=master_local_recv,
global_event_sender=master_global_send,
download_command_sender=dl_cmd_send.clone(),
)
async with anyio.create_task_group() as tg:
tg.start_soon(_partition_proxy, proxy_local_recv, proxy_local_send, switch)
tg.start_soon(_partition_proxy, proxy_global_recv, proxy_global_send, switch)
tg.start_soon(master.run)
tg.start_soon(dc.run)
tg.start_soon(worker.run)
# 1. Pre-partition: heartbeat reaches master
await worker.event_sender.send(_make_heartbeat(worker_node))
await _wait_until(lambda: worker_node in master.state.last_seen)
initial_last_seen = master.state.last_seen[worker_node]
# 2. Partition — proxy drops everything
switch.connected = False
# Worker heartbeat during partition — lost at proxy, kept in
# out_for_delivery.
await worker.event_sender.send(_make_heartbeat(worker_node))
# Trigger a download via DC's command channel. _TestShardDownloader
# returns status="complete" for any shard, so _start_download emits
# NodeDownloadProgress(DownloadPending) then
# NodeDownloadProgress(DownloadCompleted) through worker.event_sender.
# These go through _forward_events → proxy (dropped) → out_for_delivery.
# Use a unique model ID so the DC doesn't skip it as already-completed
# (it pre-emits progress for the default "noop" model at startup).
test_shard = PipelineShardMetadata(
model_card=ModelCard(
model_id=ModelId("test-partition-model"),
n_layers=1,
storage_size=Memory.from_bytes(0),
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextGeneration],
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
)
await dl_cmd_send.send(
ForwarderDownloadCommand(
origin=worker_node,
command=StartDownload(
target_node_id=worker_node,
shard_metadata=test_shard,
),
)
)
# Wait for DC events to flow through worker's _forward_events
# (poll instead of sleeping a fixed duration to avoid flakiness on slow CI)
await _wait_until(lambda: len(worker.out_for_delivery) >= 3)
# Verify at least one is a download progress event
has_download_event = any(
isinstance(fe.event, NodeDownloadProgress)
for fe in worker.out_for_delivery.values()
)
assert has_download_event, (
"out_for_delivery should contain DC-originated download events"
)
# 3. Heal partition
switch.connected = True
# Worker's _resend_out_for_delivery runs every ~1-2s.
await _wait_until(
lambda: master.state.last_seen.get(worker_node, initial_last_seen)
> initial_last_seen,
timeout=8.0,
)
# 4. All events recovered — both worker heartbeats and DC download
# progress events were retried and accepted by master.
await _wait_until(lambda: len(worker.out_for_delivery) == 0, timeout=8.0)
# Master state reflects the download
assert worker_node in master.state.downloads
await master.shutdown()
worker.shutdown()
dc.shutdown()
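
The recovery assertions rely on the Worker keeping every sent envelope in out_for_delivery until the Master acknowledges it, and periodically resending whatever is still outstanding (the test allows up to ~8s for this). The sketch below shows that at-least-once pattern in isolation; it is an assumption about the surrounding Worker code, not a copy of it:

# Assumed at-least-once delivery pattern, sketched with asyncio primitives.
import asyncio
from typing import Any

class SketchReliableSender:
    def __init__(self, outbound: "asyncio.Queue[dict[str, Any]]") -> None:
        self.outbound = outbound
        self.out_for_delivery: dict[int, dict[str, Any]] = {}

    async def send(self, envelope: dict[str, Any]) -> None:
        # remember the envelope until it is acknowledged
        self.out_for_delivery[envelope["origin_idx"]] = envelope
        await self.outbound.put(envelope)

    def ack(self, idx: int) -> None:
        # the master applied this index; stop retrying it
        self.out_for_delivery.pop(idx, None)

    async def resend_loop(self, interval: float = 1.5) -> None:
        # periodically resend anything not yet acknowledged; duplicates are
        # expected and must be deduplicated by index on the receiving side
        while True:
            await asyncio.sleep(interval)
            for envelope in list(self.out_for_delivery.values()):
                await self.outbound.put(envelope)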