mirror of https://github.com/exo-explore/exo.git
synced 2026-02-19 15:27:02 -05:00

Compare commits: main...ciaran/han (3 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 12ace705fc |  |
|  | 20ccf097bb |  |
|  | 94848bd5bd |  |
@@ -23,6 +23,8 @@ use util::wakerdeque::WakerDeque;

const RETRY_CONNECT_INTERVAL: Duration = Duration::from_secs(5);

const MAX_PING_FAILURES: u32 = 3;

mod managed {
    use libp2p::swarm::NetworkBehaviour;
    use libp2p::{identity, mdns, ping};
@@ -31,8 +33,8 @@ mod managed {

    const MDNS_RECORD_TTL: Duration = Duration::from_secs(2_500);
    const MDNS_QUERY_INTERVAL: Duration = Duration::from_secs(1_500);
    const PING_TIMEOUT: Duration = Duration::from_millis(2_500);
    const PING_INTERVAL: Duration = Duration::from_millis(2_500);
    const PING_TIMEOUT: Duration = Duration::from_secs(10);
    const PING_INTERVAL: Duration = Duration::from_secs(5);

    #[derive(NetworkBehaviour)]
    pub struct Behaviour {
@@ -109,6 +111,9 @@ pub struct Behaviour {

    // pending events to emit => waker-backed Deque to control polling
    pending_events: WakerDeque<ToSwarm<Event, Infallible>>,

    // track consecutive ping failures per connection for N-strike tolerance
    ping_failures: HashMap<ConnectionId, u32>,
}

impl Behaviour {
@@ -118,6 +123,7 @@ impl Behaviour {
            mdns_discovered: HashMap::new(),
            retry_delay: Delay::new(RETRY_CONNECT_INTERVAL),
            pending_events: WakerDeque::new(),
            ping_failures: HashMap::new(),
        })
    }

@@ -308,6 +314,7 @@ impl NetworkBehaviour for Behaviour {
        };

        if let Some((ip, port)) = remote_address.try_to_tcp_addr() {
            self.ping_failures.remove(&connection_id);
            // handle connection closed event which is filtered correctly
            self.on_connection_closed(peer_id, connection_id, ip, port)
        }
@@ -337,10 +344,41 @@ impl NetworkBehaviour for Behaviour {
                }
            },

            // handle ping events => if error then disconnect
            // handle ping events => disconnect after N consecutive failures
            managed::BehaviourEvent::Ping(e) => {
                if let Err(_) = e.result {
                    self.close_connection(e.peer, e.connection.clone())
                match &e.result {
                    Err(err) => {
                        let count = self.ping_failures.entry(e.connection).or_insert(0);
                        *count += 1;
                        log::warn!(
                            "Ping failed for peer {:?} (connection {:?}): {:?} — failure {}/{}",
                            e.peer,
                            e.connection,
                            err,
                            count,
                            MAX_PING_FAILURES
                        );
                        if *count >= MAX_PING_FAILURES {
                            log::warn!(
                                "Closing connection to peer {:?} after {} consecutive ping failures",
                                e.peer,
                                MAX_PING_FAILURES
                            );
                            self.ping_failures.remove(&e.connection);
                            self.close_connection(e.peer, e.connection);
                        }
                    }
                    Ok(rtt) => {
                        // Reset failure counter on successful ping
                        if self.ping_failures.remove(&e.connection).is_some() {
                            log::debug!(
                                "Ping recovered for peer {:?} (rtt={:?}), reset failure counter",
                                e.peer,
                                rtt
                            );
                        }
                        log::trace!("Ping OK for peer {:?}: rtt={:?}", e.peer, rtt);
                    }
                }
            }
        }
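
The net effect of these hunks: instead of dropping a connection on the first failed ping, the behaviour now counts consecutive failures per `ConnectionId`, clears the count on the first successful ping (and when a connection closes), and only disconnects once `MAX_PING_FAILURES` (3) failures accumulate, with the longer 10-second timeout and 5-second interval making each strike cheaper to survive. A minimal sketch of just that bookkeeping, written in Python purely for illustration (hypothetical names, not the exo code):

```python
# Illustrative sketch only: mirrors the counting logic of the Rust hunk above.
# A connection id is just an int here; MAX_PING_FAILURES matches the constant above.
MAX_PING_FAILURES = 3


class PingFailureTracker:
    def __init__(self) -> None:
        self.failures: dict[int, int] = {}  # connection id -> consecutive failures

    def record_failure(self, connection_id: int) -> bool:
        """Return True when the connection should be closed."""
        count = self.failures.get(connection_id, 0) + 1
        self.failures[connection_id] = count
        if count >= MAX_PING_FAILURES:
            del self.failures[connection_id]
            return True
        return False

    def record_success(self, connection_id: int) -> None:
        # A single successful ping clears the strike count.
        self.failures.pop(connection_id, None)


tracker = PingFailureTracker()
assert not tracker.record_failure(1)
assert not tracker.record_failure(1)
tracker.record_success(1)          # recovery resets the counter
assert not tracker.record_failure(1)
assert not tracker.record_failure(1)
assert tracker.record_failure(1)   # third consecutive failure => disconnect
```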
@@ -1,7 +1,6 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator

import anyio
from anyio import current_time
@@ -22,10 +21,9 @@ from exo.shared.types.commands import (
    ForwarderDownloadCommand,
    StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
    Event,
    ForwarderEvent,
    NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -36,33 +34,27 @@ from exo.shared.types.worker.downloads import (
    DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.channels import Receiver, Sender


@dataclass
class DownloadCoordinator:
    node_id: NodeId
    session_id: SessionId
    shard_downloader: ShardDownloader
    download_command_receiver: Receiver[ForwarderDownloadCommand]
    local_event_sender: Sender[ForwarderEvent]
    event_index_counter: Iterator[int]
    event_sender: Sender[Event]
    offline: bool = False

    # Local state
    download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
    active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)

    # Internal event channel for forwarding (initialized in __post_init__)
    event_sender: Sender[Event] = field(init=False)
    event_receiver: Receiver[Event] = field(init=False)
    _tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)

    # Per-model throttle for download progress events
    _last_progress_time: dict[ModelId, float] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self.event_sender, self.event_receiver = channel[Event]()
        if self.offline:
            self.shard_downloader.set_internet_connection(False)
        self.shard_downloader.on_progress(self._download_progress_callback)
@@ -117,7 +109,6 @@ class DownloadCoordinator:
        self._test_internet_connection()
        async with self._tg as tg:
            tg.start_soon(self._command_processor)
            tg.start_soon(self._forward_events)
            tg.start_soon(self._emit_existing_download_progress)
            if not self.offline:
                tg.start_soon(self._check_internet_connection)
@@ -297,21 +288,6 @@ class DownloadCoordinator:
        )
        del self.download_status[model_id]

    async def _forward_events(self) -> None:
        with self.event_receiver as events:
            async for event in events:
                idx = next(self.event_index_counter)
                fe = ForwarderEvent(
                    origin_idx=idx,
                    origin=self.node_id,
                    session=self.session_id,
                    event=event,
                )
                logger.debug(
                    f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
                )
                await self.local_event_sender.send(fe)

    async def _emit_existing_download_progress(self) -> None:
        try:
            while True:
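
Reading the interleaved old and new lines above together with the constructor calls later in this compare, the coordinator appears to stop owning a private event channel, index counter, and `_forward_events` loop, and instead publishes plain `Event`s into an injected `Sender[Event]` (the Worker's, cloned). The sketch below illustrates that shared-sender pattern using anyio memory object streams as a stand-in for `exo.utils.channels`; all names in it are illustrative, not exo's API:

```python
# Sketch: two producers share one send stream via clone(); a single consumer
# stamps every event with one monotonically increasing index, so interleaved
# producers still yield a gap-free sequence.
import itertools

import anyio


async def producer(name: str, send_stream, n: int) -> None:
    async with send_stream:
        for i in range(n):
            await send_stream.send(f"{name}-{i}")


async def main() -> None:
    send, receive = anyio.create_memory_object_stream()
    counter = itertools.count()

    async with anyio.create_task_group() as tg:
        # Each producer gets its own clone; closing the original plus all
        # clones is what ends the receive side.
        tg.start_soon(producer, "worker", send.clone(), 3)
        tg.start_soon(producer, "downloader", send.clone(), 3)
        send.close()

        async with receive:
            async for event in receive:
                print(next(counter), event)  # one shared index sequence


anyio.run(main)
```

Because every producer feeds the same consumer, a single counter can stamp one gap-free index sequence over events from both sources, which is what lets the master-side buffer treat Worker and DownloadCoordinator events uniformly.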
@@ -57,23 +57,8 @@ class Node:

        logger.info(f"Starting node {node_id}")

        # Create shared event index counter for Worker and DownloadCoordinator
        event_index_counter = itertools.count()

        # Create DownloadCoordinator (unless --no-downloads)
        if not args.no_downloads:
            download_coordinator = DownloadCoordinator(
                node_id,
                session_id,
                exo_shard_downloader(),
                download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
                local_event_sender=router.sender(topics.LOCAL_EVENTS),
                event_index_counter=event_index_counter,
                offline=args.offline,
            )
        else:
            download_coordinator = None

        if args.spawn_api:
            api = API(
                node_id,
@@ -100,6 +85,20 @@ class Node:
        else:
            worker = None

        # DownloadCoordinator sends events through the Worker's event channel
        # so they get the same index sequence and retry mechanism
        if not args.no_downloads:
            assert worker is not None, "DownloadCoordinator requires a Worker"
            download_coordinator = DownloadCoordinator(
                node_id,
                exo_shard_downloader(),
                download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
                event_sender=worker.event_sender.clone(),
                offline=args.offline,
            )
        else:
            download_coordinator = None

        # We start every node with a master
        master = Master(
            node_id,
@@ -214,20 +213,6 @@ class Node:
            await anyio.sleep(0)
            # Fresh counter for new session (buffer expects indices from 0)
            self.event_index_counter = itertools.count()
            if self.download_coordinator:
                self.download_coordinator.shutdown()
                self.download_coordinator = DownloadCoordinator(
                    self.node_id,
                    result.session_id,
                    exo_shard_downloader(),
                    download_command_receiver=self.router.receiver(
                        topics.DOWNLOAD_COMMANDS
                    ),
                    local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
                    event_index_counter=self.event_index_counter,
                    offline=self.offline,
                )
                self._tg.start_soon(self.download_coordinator.run)
            if self.worker:
                self.worker.shutdown()
                # TODO: add profiling etc to resource monitor
@@ -245,6 +230,19 @@ class Node:
                    event_index_counter=self.event_index_counter,
                )
                self._tg.start_soon(self.worker.run)
                if self.download_coordinator:
                    self.download_coordinator.shutdown()
                    assert self.worker is not None
                    self.download_coordinator = DownloadCoordinator(
                        self.node_id,
                        exo_shard_downloader(),
                        download_command_receiver=self.router.receiver(
                            topics.DOWNLOAD_COMMANDS
                        ),
                        event_sender=self.worker.event_sender.clone(),
                        offline=self.offline,
                    )
                    self._tg.start_soon(self.download_coordinator.run)
            if self.api:
                self.api.reset(result.session_id, result.won_clock)
            else:
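
The comment above ("same index sequence and retry mechanism") leans on the Worker keeping every published `ForwarderEvent` in its `out_for_delivery` buffer until the master acknowledges it, and periodically resending whatever is still outstanding; the new test below exercises exactly that path across a simulated partition. A schematic sketch of such a retry buffer, with invented names and no claim to match exo's `Worker` internals:

```python
# Schematic sketch of an "out for delivery" retry buffer: events stay in the
# buffer until acknowledged and are periodically resent. Invented names; not
# the exo Worker implementation.
from dataclasses import dataclass, field

import anyio


@dataclass
class RetryBuffer:
    out_for_delivery: dict[int, str] = field(default_factory=dict)

    async def publish(self, idx: int, event: str, send) -> None:
        self.out_for_delivery[idx] = event  # keep until acked
        await send(idx, event)

    def ack(self, idx: int) -> None:
        self.out_for_delivery.pop(idx, None)

    async def resend_loop(self, send, interval: float = 1.0) -> None:
        while self.out_for_delivery:
            await anyio.sleep(interval)
            for idx, event in list(self.out_for_delivery.items()):
                await send(idx, event)  # retry anything not yet acknowledged


async def main() -> None:
    delivered: dict[int, str] = {}
    connected = False  # simulate a partition: sends are dropped at first

    async def send(idx: int, event: str) -> None:
        if connected:
            delivered[idx] = event
            buffer.ack(idx)  # in exo the ack would come back from the master

    buffer = RetryBuffer()
    await buffer.publish(0, "heartbeat", send)  # lost: partition is open
    assert buffer.out_for_delivery

    connected = True                            # heal the partition
    await buffer.resend_loop(send, interval=0.01)
    assert delivered == {0: "heartbeat"} and not buffer.out_for_delivery


anyio.run(main)
```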
src/exo/master/tests/test_partition_recovery.py (new file, 250 lines)
@@ -0,0 +1,250 @@
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from itertools import count
from pathlib import Path
from typing import AsyncIterator

import anyio
import pytest

from exo.download.coordinator import DownloadCoordinator
from exo.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.master.main import Master
from exo.master.tests.conftest import create_node_memory
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.commands import (
    ForwarderCommand,
    ForwarderDownloadCommand,
    StartDownload,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
    ForwarderEvent,
    NodeDownloadProgress,
    NodeGatheredInfo,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.main import Worker


def _complete_progress(shard: ShardMetadata) -> RepoDownloadProgress:
    return RepoDownloadProgress(
        repo_id=str(shard.model_card.model_id),
        repo_revision="test",
        shard=shard,
        completed_files=0,
        total_files=0,
        downloaded_bytes=Memory.from_bytes(0),
        downloaded_bytes_this_session=Memory.from_bytes(0),
        total_bytes=Memory.from_bytes(0),
        overall_speed=0,
        overall_eta=timedelta(seconds=0),
        status="complete",
    )


class _TestShardDownloader(ShardDownloader):
    """Shard downloader that reports every shard as already complete."""

    async def ensure_shard(
        self, shard: ShardMetadata, config_only: bool = False
    ) -> Path:
        return Path("/tmp/test_shard")

    def on_progress(
        self,
        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
    ) -> None:
        pass

    async def get_shard_download_status(
        self,
    ) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
        # Yield nothing — no pre-existing downloads
        return
        yield  # make this an async generator

    async def get_shard_download_status_for_shard(
        self, shard: ShardMetadata
    ) -> RepoDownloadProgress:
        return _complete_progress(shard)


def _make_heartbeat(node_id: NodeId) -> NodeGatheredInfo:
    return NodeGatheredInfo(
        node_id=node_id,
        when=str(datetime.now(tz=timezone.utc)),
        info=create_node_memory(500),
    )


class _PartitionSwitch:
    """Mutable boolean flag shared with the partition proxy coroutine."""

    def __init__(self) -> None:
        self.connected = True


async def _partition_proxy(
    source: Receiver[ForwarderEvent],
    dest: Sender[ForwarderEvent],
    switch: _PartitionSwitch,
) -> None:
    """Forward events when ``switch.connected`` is True; drop otherwise."""
    with source as events:
        async for event in events:
            if switch.connected:
                await dest.send(event)


async def _wait_until(
    predicate: Callable[[], object], *, timeout: float = 5.0, poll: float = 0.02
) -> None:
    """Poll *predicate* until truthy, raising on timeout."""
    with anyio.fail_after(timeout):
        while not predicate():
            await anyio.sleep(poll)


# ---------------------------------------------------------------------------
# Test 1 – same master: Worker + DC retry recovers lost events
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_partition_recovery_same_master() -> None:
    """Worker's out_for_delivery retry fills the Master's buffer gap after a
    partition heals, even when DownloadCoordinator events are interleaved."""

    master_node = NodeId("master-node")
    worker_node = NodeId("worker-node")
    session = SessionId(master_node_id=master_node, election_clock=1)
    switch = _PartitionSwitch()

    # --- channels --------------------------------------------------------
    # Worker → proxy → Master (local events)
    worker_local_send, proxy_local_recv = channel[ForwarderEvent]()
    proxy_local_send, master_local_recv = channel[ForwarderEvent]()

    # Master → proxy → Worker (global events)
    master_global_send, proxy_global_recv = channel[ForwarderEvent]()
    proxy_global_send, worker_global_recv = channel[ForwarderEvent]()

    # Commands (required by constructors)
    cmd_send, cmd_recv = channel[ForwarderCommand]()
    dl_cmd_send, dl_cmd_recv = channel[ForwarderDownloadCommand]()

    # --- components ------------------------------------------------------
    worker = Worker(
        worker_node,
        session,
        global_event_receiver=worker_global_recv,
        local_event_sender=worker_local_send,
        command_sender=cmd_send.clone(),
        download_command_sender=dl_cmd_send.clone(),
        event_index_counter=count(),
    )

    dc = DownloadCoordinator(
        node_id=worker_node,
        shard_downloader=_TestShardDownloader(),
        download_command_receiver=dl_cmd_recv,
        event_sender=worker.event_sender.clone(),
        offline=True,
    )

    master = Master(
        master_node,
        session,
        command_receiver=cmd_recv,
        local_event_receiver=master_local_recv,
        global_event_sender=master_global_send,
        download_command_sender=dl_cmd_send.clone(),
    )

    async with anyio.create_task_group() as tg:
        tg.start_soon(_partition_proxy, proxy_local_recv, proxy_local_send, switch)
        tg.start_soon(_partition_proxy, proxy_global_recv, proxy_global_send, switch)
        tg.start_soon(master.run)
        tg.start_soon(dc.run)
        tg.start_soon(worker.run)

        # 1. Pre-partition: heartbeat reaches master
        await worker.event_sender.send(_make_heartbeat(worker_node))
        await _wait_until(lambda: worker_node in master.state.last_seen)
        initial_last_seen = master.state.last_seen[worker_node]

        # 2. Partition — proxy drops everything
        switch.connected = False

        # Worker heartbeat during partition — lost at proxy, kept in
        # out_for_delivery.
        await worker.event_sender.send(_make_heartbeat(worker_node))

        # Trigger a download via DC's command channel. NoopShardDownloader
        # returns status="complete" for any shard, so _start_download emits
        # NodeDownloadProgress(DownloadPending) then
        # NodeDownloadProgress(DownloadCompleted) through worker.event_sender.
        # These go through _forward_events → proxy (dropped) → out_for_delivery.
        # Use a unique model ID so the DC doesn't skip it as already-completed
        # (it pre-emits progress for the default "noop" model at startup).
        test_shard = PipelineShardMetadata(
            model_card=ModelCard(
                model_id=ModelId("test-partition-model"),
                n_layers=1,
                storage_size=Memory.from_bytes(0),
                hidden_size=1,
                supports_tensor=False,
                tasks=[ModelTask.TextGeneration],
            ),
            device_rank=0,
            world_size=1,
            start_layer=0,
            end_layer=1,
            n_layers=1,
        )
        await dl_cmd_send.send(
            ForwarderDownloadCommand(
                origin=worker_node,
                command=StartDownload(
                    target_node_id=worker_node,
                    shard_metadata=test_shard,
                ),
            )
        )

        # Wait for DC events to flow through worker's _forward_events
        # (poll instead of sleeping a fixed duration to avoid flakiness on slow CI)
        await _wait_until(lambda: len(worker.out_for_delivery) >= 3)

        # Verify at least one is a download progress event
        has_download_event = any(
            isinstance(fe.event, NodeDownloadProgress)
            for fe in worker.out_for_delivery.values()
        )
        assert has_download_event, (
            "out_for_delivery should contain DC-originated download events"
        )

        # 3. Heal partition
        switch.connected = True

        # Worker's _resend_out_for_delivery runs every ~1-2s.
        await _wait_until(
            lambda: master.state.last_seen.get(worker_node, initial_last_seen)
            > initial_last_seen,
            timeout=8.0,
        )

        # 4. All events recovered — both worker heartbeats and DC download
        # progress events were retried and accepted by master.
        await _wait_until(lambda: len(worker.out_for_delivery) == 0, timeout=8.0)

        # Master state reflects the download
        assert worker_node in master.state.downloads

        await master.shutdown()
        worker.shutdown()
        dc.shutdown()