Compare commits

...

3 Commits

Author     SHA1        Message                                                            Date
ciaranbor  584fe5b270  Route DownloadCoordinator events through worker's event channel   2026-02-19 12:12:26 +00:00
ciaranbor  2947e7ca62  Add test for multi source buffer resets                           2026-02-18 20:06:10 +00:00
ciaranbor  33547125ee  Use n-strike ping tolerance                                       2026-02-18 20:06:10 +00:00
4 changed files with 90 additions and 62 deletions

View File

@@ -24,6 +24,8 @@ use util::wakerdeque::WakerDeque;
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
const MAX_PING_FAILURES: u32 = 3;
mod managed {
use libp2p::swarm::NetworkBehaviour;
use libp2p::{identity, mdns, ping};
@@ -32,8 +34,8 @@ mod managed {
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
const PING_TIMEOUT: Duration = Duration::from_secs(10);
const PING_INTERVAL: Duration = Duration::from_secs(5);
#[derive(NetworkBehaviour)]
pub struct Behaviour {
@@ -110,6 +112,9 @@ pub struct Behaviour {
// pending events to emit => waker-backed deque to control polling
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
// track consecutive ping failures per connection for N-strike tolerance
ping_failures: HashMap<ConnectionId, u32>,
}
impl Behaviour {
@@ -119,6 +124,7 @@ impl Behaviour {
mdns_discovered: HashMap::new(),
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
pending_events: WakerDeque::new(),
ping_failures: HashMap::new(),
})
}
@@ -309,6 +315,7 @@ impl NetworkBehaviour for Behaviour {
};
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
self.ping_failures.remove(&connection_id);
// handle connection-closed events whose remote address maps to a TCP (ip, port)
self.on_connection_closed(peer_id, connection_id, ip, port)
}
@@ -338,10 +345,37 @@ impl NetworkBehaviour for Behaviour {
}
},
// handle ping events => if error then disconnect
// handle ping events => disconnect after N consecutive failures
managed::BehaviourEvent::Ping(e) => {
if let Err(_) = e.result {
self.close_connection(e.peer, e.connection.clone())
match &e.result {
Err(err) => {
let count = self.ping_failures.entry(e.connection).or_insert(0);
*count += 1;
log::warn!(
"Ping failed for peer {:?} (connection {:?}): {:?} — failure {}/{}",
e.peer,
e.connection,
err,
count,
MAX_PING_FAILURES
);
if *count >= MAX_PING_FAILURES {
log::warn!(
"Closing connection to peer {:?} after {} consecutive ping failures",
e.peer,
MAX_PING_FAILURES
);
self.ping_failures.remove(&e.connection);
self.close_connection(e.peer, e.connection);
}
}
Ok(rtt) => {
// Reset failure counter on successful ping
if self.ping_failures.remove(&e.connection).is_some() {
log::debug!("Ping recovered for peer {:?} (rtt={:?}), reset failure counter", e.peer, rtt);
}
log::trace!("Ping OK for peer {:?}: rtt={:?}", e.peer, rtt);
}
}
}
}
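Condensed, the n-strike tolerance added above is three rules: count consecutive ping failures per connection, reset the count on any successful ping, and close the connection on the Nth strike. A minimal sketch of the same pattern in Python (the connection handle and the close_connection callback are stand-ins for libp2p's ConnectionId and the Behaviour method, not real APIs):

MAX_PING_FAILURES = 3  # mirrors the Rust constant above


class PingTolerance:
    """Per-connection strike counter; illustrative only."""

    def __init__(self, close_connection):
        self.close_connection = close_connection  # stand-in callback
        self.failures: dict[str, int] = {}

    def on_ping_result(self, conn: str, ok: bool) -> None:
        if ok:
            # Any successful ping resets the strike count.
            self.failures.pop(conn, None)
            return
        strikes = self.failures.get(conn, 0) + 1
        self.failures[conn] = strikes
        if strikes >= MAX_PING_FAILURES:
            # Nth consecutive failure: drop the counter and disconnect.
            del self.failures[conn]
            self.close_connection(conn)


tolerance = PingTolerance(close_connection=lambda c: print(f"closing {c}"))
for _ in range(MAX_PING_FAILURES):
    tolerance.on_ping_result("conn-1", ok=False)  # third call prints "closing conn-1"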

View File

@@ -1,7 +1,6 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
@@ -21,10 +20,9 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -35,36 +33,27 @@ from exo.shared.types.worker.downloads import (
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.channels import Receiver, Sender
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
event_sender: Sender[Event]
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
@@ -258,21 +247,6 @@ class DownloadCoordinator:
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
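Net effect of this hunk: DownloadCoordinator loses its private channel, SessionId, index counter, and the _forward_events loop, and instead pushes bare Event values into a Sender[Event] injected at construction; index stamping and ForwarderEvent wrapping now happen once, at the receiving end. A rough sketch of the resulting flow, assuming channel() from exo.utils.channels behaves as the removed code used it (the consume_and_wrap loop is illustrative; per node.py below, the real receiving end lives in the Worker):

import itertools

from exo.shared.types.events import Event, ForwarderEvent
from exo.utils.channels import channel

event_sender, event_receiver = channel[Event]()  # owned by the Worker
event_index_counter = itertools.count()          # one sequence for all producers

# Producers (Worker internals, DownloadCoordinator) only ever do:
#     await event_sender.send(event)

async def consume_and_wrap(node_id, session_id, publish) -> None:
    # A single consumer stamps indices, so Worker events and download
    # events share one gap-free sequence and one retry mechanism.
    with event_receiver as events:
        async for event in events:
            await publish(ForwarderEvent(
                origin_idx=next(event_index_counter),
                origin=node_id,
                session=session_id,
                event=event,
            ))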

View File

@@ -56,22 +56,8 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -98,6 +84,19 @@ class Node:
else:
worker = None
# DownloadCoordinator sends events through the Worker's event channel
# so they get the same index sequence and retry mechanism
if not args.no_downloads:
assert worker is not None, "DownloadCoordinator requires a Worker"
download_coordinator = DownloadCoordinator(
node_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
event_sender=worker.event_sender.clone(),
)
else:
download_coordinator = None
# We start every node with a master
master = Master(
node_id,
@@ -211,19 +210,6 @@ class Node:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
@@ -241,6 +227,18 @@ class Node:
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.download_coordinator:
self.download_coordinator.shutdown()
assert self.worker is not None
self.download_coordinator = DownloadCoordinator(
self.node_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
event_sender=self.worker.event_sender.clone(),
)
self._tg.start_soon(self.download_coordinator.run)
if self.api:
self.api.reset(result.session_id, result.won_clock)
else:
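Both construction sites above rely on the same wiring trick: Sender.clone() gives the coordinator a second producer handle onto the Worker's event stream, so events from either side interleave into the single indexed sequence. A toy demonstration of that multi-producer pattern, using anyio memory object streams as stand-ins for exo.utils.channels (the real Sender/Receiver are assumed to clone analogously):

import anyio


async def main() -> None:
    send, recv = anyio.create_memory_object_stream(10)
    send2 = send.clone()  # second producer handle, same underlying stream

    await send.send("worker event")
    await send2.send("download progress event")

    async with recv:
        print(await recv.receive())  # worker event
        print(await recv.receive())  # download progress event


anyio.run(main)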

View File

@@ -1,7 +1,7 @@
import pytest
from exo.shared.types.events import Event, TestEvent
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.event_buffer import MultiSourceBuffer, OrderedBuffer
def make_indexed_event(idx: int) -> tuple[int, Event]:
@@ -124,6 +124,28 @@ async def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]):
assert buffer.drain() == []
@pytest.mark.asyncio
async def test_multi_source_buffer_resets_on_index_zero():
"""Tests that MultiSourceBuffer resets a source's buffer when it sends idx=0 after progressing."""
msb: MultiSourceBuffer[str, Event] = MultiSourceBuffer()
msb.ingest(0, make_indexed_event(0)[1], "source_a")
msb.ingest(1, make_indexed_event(1)[1], "source_a")
msb.drain()
assert msb.stores["source_a"].next_idx_to_release == 2
new_event0 = make_indexed_event(0)[1]
new_event1 = make_indexed_event(1)[1]
msb.ingest(0, new_event0, "source_a")
msb.ingest(1, new_event1, "source_a")
assert msb.stores["source_a"].next_idx_to_release == 0
result = msb.drain()
assert len(result) == 2
assert msb.stores["source_a"].next_idx_to_release == 2
@pytest.mark.asyncio
async def test_drain_and_ingest_with_new_sequence(buffer: OrderedBuffer[Event]):
"""Tests reusing the buffer after it has been fully drained."""