From 20ccf097bb2aecfed27b467afe41bafda626ce92 Mon Sep 17 00:00:00 2001 From: ciaranbor Date: Thu, 19 Feb 2026 12:12:26 +0000 Subject: [PATCH] Route DownloadCoordinator events through worker's event channel --- rust/networking/src/discovery.rs | 6 +++- src/exo/download/coordinator.py | 30 ++--------------- src/exo/main.py | 56 +++++++++++++++----------------- 3 files changed, 35 insertions(+), 57 deletions(-) diff --git a/rust/networking/src/discovery.rs b/rust/networking/src/discovery.rs index 8db848236..01230c6ab 100644 --- a/rust/networking/src/discovery.rs +++ b/rust/networking/src/discovery.rs @@ -371,7 +371,11 @@ impl NetworkBehaviour for Behaviour { Ok(rtt) => { // Reset failure counter on successful ping if self.ping_failures.remove(&e.connection).is_some() { - log::debug!("Ping recovered for peer {:?} (rtt={:?}), reset failure counter", e.peer, rtt); + log::debug!( + "Ping recovered for peer {:?} (rtt={:?}), reset failure counter", + e.peer, + rtt + ); } log::trace!("Ping OK for peer {:?}: rtt={:?}", e.peer, rtt); } diff --git a/src/exo/download/coordinator.py b/src/exo/download/coordinator.py index 30e45a082..b1dc647d5 100644 --- a/src/exo/download/coordinator.py +++ b/src/exo/download/coordinator.py @@ -1,7 +1,6 @@ import asyncio import socket from dataclasses import dataclass, field -from typing import Iterator import anyio from anyio import current_time @@ -22,10 +21,9 @@ from exo.shared.types.commands import ( ForwarderDownloadCommand, StartDownload, ) -from exo.shared.types.common import NodeId, SessionId +from exo.shared.types.common import NodeId from exo.shared.types.events import ( Event, - ForwarderEvent, NodeDownloadProgress, ) from exo.shared.types.worker.downloads import ( @@ -36,33 +34,27 @@ from exo.shared.types.worker.downloads import ( DownloadProgress, ) from exo.shared.types.worker.shards import ShardMetadata -from exo.utils.channels import Receiver, Sender, channel +from exo.utils.channels import Receiver, Sender @dataclass class DownloadCoordinator: node_id: NodeId - session_id: SessionId shard_downloader: ShardDownloader download_command_receiver: Receiver[ForwarderDownloadCommand] - local_event_sender: Sender[ForwarderEvent] - event_index_counter: Iterator[int] + event_sender: Sender[Event] offline: bool = False # Local state download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict) active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict) - # Internal event channel for forwarding (initialized in __post_init__) - event_sender: Sender[Event] = field(init=False) - event_receiver: Receiver[Event] = field(init=False) _tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group) # Per-model throttle for download progress events _last_progress_time: dict[ModelId, float] = field(default_factory=dict) def __post_init__(self) -> None: - self.event_sender, self.event_receiver = channel[Event]() if self.offline: self.shard_downloader.set_internet_connection(False) self.shard_downloader.on_progress(self._download_progress_callback) @@ -117,7 +109,6 @@ class DownloadCoordinator: self._test_internet_connection() async with self._tg as tg: tg.start_soon(self._command_processor) - tg.start_soon(self._forward_events) tg.start_soon(self._emit_existing_download_progress) if not self.offline: tg.start_soon(self._check_internet_connection) @@ -297,21 +288,6 @@ class DownloadCoordinator: ) del self.download_status[model_id] - async def _forward_events(self) -> None: - with self.event_receiver as events: - async for event in events: - idx = next(self.event_index_counter) - fe = ForwarderEvent( - origin_idx=idx, - origin=self.node_id, - session=self.session_id, - event=event, - ) - logger.debug( - f"DownloadCoordinator published event {idx}: {str(event)[:100]}" - ) - await self.local_event_sender.send(fe) - async def _emit_existing_download_progress(self) -> None: try: while True: diff --git a/src/exo/main.py b/src/exo/main.py index 27c78165a..624fce212 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -57,23 +57,8 @@ class Node: logger.info(f"Starting node {node_id}") - # Create shared event index counter for Worker and DownloadCoordinator event_index_counter = itertools.count() - # Create DownloadCoordinator (unless --no-downloads) - if not args.no_downloads: - download_coordinator = DownloadCoordinator( - node_id, - session_id, - exo_shard_downloader(), - download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), - local_event_sender=router.sender(topics.LOCAL_EVENTS), - event_index_counter=event_index_counter, - offline=args.offline, - ) - else: - download_coordinator = None - if args.spawn_api: api = API( node_id, @@ -100,6 +85,20 @@ class Node: else: worker = None + # DownloadCoordinator sends events through the Worker's event channel + # so they get the same index sequence and retry mechanism + if not args.no_downloads: + assert worker is not None, "DownloadCoordinator requires a Worker" + download_coordinator = DownloadCoordinator( + node_id, + exo_shard_downloader(), + download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS), + event_sender=worker.event_sender.clone(), + offline=args.offline, + ) + else: + download_coordinator = None + # We start every node with a master master = Master( node_id, @@ -214,20 +213,6 @@ class Node: await anyio.sleep(0) # Fresh counter for new session (buffer expects indices from 0) self.event_index_counter = itertools.count() - if self.download_coordinator: - self.download_coordinator.shutdown() - self.download_coordinator = DownloadCoordinator( - self.node_id, - result.session_id, - exo_shard_downloader(), - download_command_receiver=self.router.receiver( - topics.DOWNLOAD_COMMANDS - ), - local_event_sender=self.router.sender(topics.LOCAL_EVENTS), - event_index_counter=self.event_index_counter, - offline=self.offline, - ) - self._tg.start_soon(self.download_coordinator.run) if self.worker: self.worker.shutdown() # TODO: add profiling etc to resource monitor @@ -245,6 +230,19 @@ class Node: event_index_counter=self.event_index_counter, ) self._tg.start_soon(self.worker.run) + if self.download_coordinator: + self.download_coordinator.shutdown() + assert self.worker is not None + self.download_coordinator = DownloadCoordinator( + self.node_id, + exo_shard_downloader(), + download_command_receiver=self.router.receiver( + topics.DOWNLOAD_COMMANDS + ), + event_sender=self.worker.event_sender.clone(), + offline=self.offline, + ) + self._tg.start_soon(self.download_coordinator.run) if self.api: self.api.reset(result.session_id, result.won_clock) else: