Route DownloadCoordinator events through worker's event channel

This commit is contained in:
ciaranbor
2026-02-19 12:12:26 +00:00
parent 94848bd5bd
commit 20ccf097bb
3 changed files with 35 additions and 57 deletions

View File

@@ -371,7 +371,11 @@ impl NetworkBehaviour for Behaviour {
Ok(rtt) => {
// Reset failure counter on successful ping
if self.ping_failures.remove(&e.connection).is_some() {
log::debug!("Ping recovered for peer {:?} (rtt={:?}), reset failure counter", e.peer, rtt);
log::debug!(
"Ping recovered for peer {:?} (rtt={:?}), reset failure counter",
e.peer,
rtt
);
}
log::trace!("Ping OK for peer {:?}: rtt={:?}", e.peer, rtt);
}

View File

@@ -1,7 +1,6 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
@@ -22,10 +21,9 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -36,33 +34,27 @@ from exo.shared.types.worker.downloads import (
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.channels import Receiver, Sender
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
event_sender: Sender[Event]
offline: bool = False
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
# Per-model throttle for download progress events
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
@@ -117,7 +109,6 @@ class DownloadCoordinator:
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
if not self.offline:
tg.start_soon(self._check_internet_connection)
@@ -297,21 +288,6 @@ class DownloadCoordinator:
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:

View File

@@ -57,23 +57,8 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create the event index counter used by the Worker (DownloadCoordinator events share its sequence via the Worker's event channel)
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -100,6 +85,20 @@ class Node:
else:
worker = None
# DownloadCoordinator sends events through the Worker's event channel
# so they get the same index sequence and retry mechanism
if not args.no_downloads:
assert worker is not None, "DownloadCoordinator requires a Worker"
download_coordinator = DownloadCoordinator(
node_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
event_sender=worker.event_sender.clone(),
offline=args.offline,
)
else:
download_coordinator = None
# We start every node with a master
master = Master(
node_id,
@@ -214,20 +213,6 @@ class Node:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
@@ -245,6 +230,19 @@ class Node:
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.download_coordinator:
self.download_coordinator.shutdown()
assert self.worker is not None
self.download_coordinator = DownloadCoordinator(
self.node_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
event_sender=self.worker.event_sender.clone(),
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
if self.api:
self.api.reset(result.session_id, result.won_clock)
else: