mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-19 07:17:30 -05:00
Compare commits
3 Commits
sami/iOS-a
...
ciaran/han
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
584fe5b270 | ||
|
|
2947e7ca62 | ||
|
|
33547125ee |
@@ -24,6 +24,8 @@ use util::wakerdeque::WakerDeque;
|
||||
|
||||
const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
const MAX_PING_FAILURES: u32 = 3;
|
||||
|
||||
mod managed {
|
||||
use libp2p::swarm::NetworkBehaviour;
|
||||
use libp2p::{identity, mdns, ping};
|
||||
@@ -32,8 +34,8 @@ mod managed {
|
||||
|
||||
const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
|
||||
const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
|
||||
const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
|
||||
const PING_INTERVAL: Duration = Duration::from_millis(2_500);
|
||||
const PING_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
const PING_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
#[derive(NetworkBehaviour)]
|
||||
pub struct Behaviour {
|
||||
@@ -110,6 +112,9 @@ pub struct Behaviour {
|
||||
|
||||
// pending events to emmit => waker-backed Deque to control polling
|
||||
pending_events: WakerDeque<ToSwarm<Event, Infallible>>,
|
||||
|
||||
// track consecutive ping failures per connection for N-strike tolerance
|
||||
ping_failures: HashMap<ConnectionId, u32>,
|
||||
}
|
||||
|
||||
impl Behaviour {
|
||||
@@ -119,6 +124,7 @@ impl Behaviour {
|
||||
mdns_discovered: HashMap::new(),
|
||||
retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
|
||||
pending_events: WakerDeque::new(),
|
||||
ping_failures: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -309,6 +315,7 @@ impl NetworkBehaviour for Behaviour {
|
||||
};
|
||||
|
||||
if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
|
||||
self.ping_failures.remove(&connection_id);
|
||||
// handle connection closed event which is filtered correctly
|
||||
self.on_connection_closed(peer_id, connection_id, ip, port)
|
||||
}
|
||||
@@ -338,10 +345,37 @@ impl NetworkBehaviour for Behaviour {
|
||||
}
|
||||
},
|
||||
|
||||
// handle ping events => if error then disconnect
|
||||
// handle ping events => disconnect after N consecutive failures
|
||||
managed::BehaviourEvent::Ping(e) => {
|
||||
if let Err(_) = e.result {
|
||||
self.close_connection(e.peer, e.connection.clone())
|
||||
match &e.result {
|
||||
Err(err) => {
|
||||
let count = self.ping_failures.entry(e.connection).or_insert(0);
|
||||
*count += 1;
|
||||
log::warn!(
|
||||
"Ping failed for peer {:?} (connection {:?}): {:?} — failure {}/{}",
|
||||
e.peer,
|
||||
e.connection,
|
||||
err,
|
||||
count,
|
||||
MAX_PING_FAILURES
|
||||
);
|
||||
if *count >= MAX_PING_FAILURES {
|
||||
log::warn!(
|
||||
"Closing connection to peer {:?} after {} consecutive ping failures",
|
||||
e.peer,
|
||||
MAX_PING_FAILURES
|
||||
);
|
||||
self.ping_failures.remove(&e.connection);
|
||||
self.close_connection(e.peer, e.connection);
|
||||
}
|
||||
}
|
||||
Ok(rtt) => {
|
||||
// Reset failure counter on successful ping
|
||||
if self.ping_failures.remove(&e.connection).is_some() {
|
||||
log::debug!("Ping recovered for peer {:?} (rtt={:?}), reset failure counter", e.peer, rtt);
|
||||
}
|
||||
log::trace!("Ping OK for peer {:?}: rtt={:?}", e.peer, rtt);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import socket
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
@@ -21,10 +20,9 @@ from exo.shared.types.commands import (
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
@@ -35,36 +33,27 @@ from exo.shared.types.worker.downloads import (
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadCoordinator:
|
||||
node_id: NodeId
|
||||
session_id: SessionId
|
||||
shard_downloader: ShardDownloader
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[ForwarderEvent]
|
||||
event_index_counter: Iterator[int]
|
||||
event_sender: Sender[Event]
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
|
||||
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
self._test_internet_connection()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._check_internet_connection)
|
||||
|
||||
@@ -258,21 +247,6 @@ class DownloadCoordinator:
|
||||
)
|
||||
del self.download_status[model_id]
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.debug(
|
||||
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
|
||||
)
|
||||
await self.local_event_sender.send(fe)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -56,22 +56,8 @@ class Node:
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
|
||||
# Create shared event index counter for Worker and DownloadCoordinator
|
||||
event_index_counter = itertools.count()
|
||||
|
||||
# Create DownloadCoordinator (unless --no-downloads)
|
||||
if not args.no_downloads:
|
||||
download_coordinator = DownloadCoordinator(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=event_index_counter,
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
|
||||
if args.spawn_api:
|
||||
api = API(
|
||||
node_id,
|
||||
@@ -98,6 +84,19 @@ class Node:
|
||||
else:
|
||||
worker = None
|
||||
|
||||
# DownloadCoordinator sends events through the Worker's event channel
|
||||
# so they get the same index sequence and retry mechanism
|
||||
if not args.no_downloads:
|
||||
assert worker is not None, "DownloadCoordinator requires a Worker"
|
||||
download_coordinator = DownloadCoordinator(
|
||||
node_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
event_sender=worker.event_sender.clone(),
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
|
||||
# We start every node with a master
|
||||
master = Master(
|
||||
node_id,
|
||||
@@ -211,19 +210,6 @@ class Node:
|
||||
await anyio.sleep(0)
|
||||
# Fresh counter for new session (buffer expects indices from 0)
|
||||
self.event_index_counter = itertools.count()
|
||||
if self.download_coordinator:
|
||||
self.download_coordinator.shutdown()
|
||||
self.download_coordinator = DownloadCoordinator(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=self.router.receiver(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
self.worker.shutdown()
|
||||
# TODO: add profiling etc to resource monitor
|
||||
@@ -241,6 +227,18 @@ class Node:
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.download_coordinator:
|
||||
self.download_coordinator.shutdown()
|
||||
assert self.worker is not None
|
||||
self.download_coordinator = DownloadCoordinator(
|
||||
self.node_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=self.router.receiver(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
event_sender=self.worker.event_sender.clone(),
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.api:
|
||||
self.api.reset(result.session_id, result.won_clock)
|
||||
else:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.events import Event, TestEvent
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.event_buffer import MultiSourceBuffer, OrderedBuffer
|
||||
|
||||
|
||||
def make_indexed_event(idx: int) -> tuple[int, Event]:
|
||||
@@ -124,6 +124,28 @@ async def test_ingest_drops_stale_events(buffer: OrderedBuffer[Event]):
|
||||
assert buffer.drain() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_source_buffer_resets_on_index_zero():
|
||||
"""Tests that MultiSourceBuffer resets a source's buffer when it sends idx=0 after progressing."""
|
||||
msb: MultiSourceBuffer[str, Event] = MultiSourceBuffer()
|
||||
|
||||
msb.ingest(0, make_indexed_event(0)[1], "source_a")
|
||||
msb.ingest(1, make_indexed_event(1)[1], "source_a")
|
||||
msb.drain()
|
||||
|
||||
assert msb.stores["source_a"].next_idx_to_release == 2
|
||||
|
||||
new_event0 = make_indexed_event(0)[1]
|
||||
new_event1 = make_indexed_event(1)[1]
|
||||
msb.ingest(0, new_event0, "source_a")
|
||||
msb.ingest(1, new_event1, "source_a")
|
||||
|
||||
assert msb.stores["source_a"].next_idx_to_release == 0
|
||||
result = msb.drain()
|
||||
assert len(result) == 2
|
||||
assert msb.stores["source_a"].next_idx_to_release == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_and_ingest_with_new_sequence(buffer: OrderedBuffer[Event]):
|
||||
"""Tests reusing the buffer after it has been fully drained."""
|
||||
|
||||
Reference in New Issue
Block a user