Compare commits

...

3 Commits

Author SHA1 Message Date
Alex Cheema
ad60cf09fa Merge remote-tracking branch 'origin/main' into alexcheema/event-index-liveness 2026-02-13 10:09:53 -08:00
Alex Cheema
ebe02c4722 Merge remote-tracking branch 'origin/main' into alexcheema/event-index-liveness 2026-02-11 16:25:38 -08:00
Alex Cheema
bf034b61e4 refactor: replace timestamp-based node liveness with event-index staleness
Remove all datetime timestamps from the event-sourced state. Instead of
tracking `last_seen: Mapping[NodeId, datetime]` and checking wall-clock
deltas, track `last_event_index_by_node: Mapping[NodeId, int]` — the
global event index at which each node was last heard from.

The master's planning loop now compares snapshots of each node's last
event index across consecutive cycles. If a node produces no new events
for 3 consecutive plan cycles (~30s), it is disconnected.

This eliminates false-positive node removals caused by slow info-gathering
tasks, since the timeout is now based purely on whether events are
flowing, not on the wall-clock timing of individual tasks.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 13:35:01 -08:00
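
Distilled from the diff below, here is a minimal runnable sketch of the staleness check (the `StalenessTracker` wrapper, plain-string node ids, and the assert-based walkthrough are illustrative only; the real check keys on NodeId and lives in the master's planning loop):

# Minimal sketch of the event-index staleness check described above.
# Assumption: check() is called once per plan cycle (~10s apart).
from dataclasses import dataclass, field


@dataclass
class StalenessTracker:
    threshold: int = 3  # consecutive stale plan cycles (~30s) before disconnect
    _last_checked: dict[str, int] = field(default_factory=dict)
    _stale_cycles: dict[str, int] = field(default_factory=dict)

    def check(self, last_event_index_by_node: dict[str, int]) -> list[str]:
        """Compare this cycle's indices against the previous snapshot."""
        current = dict(last_event_index_by_node)
        for node_id, idx in current.items():
            if idx == self._last_checked.get(node_id, -1):
                # No new events since the last cycle: count one more stale cycle.
                self._stale_cycles[node_id] = self._stale_cycles.get(node_id, 0) + 1
            else:
                # The node produced events; it is live again.
                self._stale_cycles.pop(node_id, None)
        self._last_checked = current
        stale = [n for n, c in self._stale_cycles.items() if c >= self.threshold]
        for node_id in stale:
            del self._stale_cycles[node_id]
        return stale


tracker = StalenessTracker()
assert tracker.check({"a": 5, "b": 7}) == []     # first snapshot
assert tracker.check({"a": 5, "b": 8}) == []     # "a" stale x1, "b" advanced
assert tracker.check({"a": 5, "b": 8}) == []     # "a" stale x2, "b" stale x1
assert tracker.check({"a": 5, "b": 9}) == ["a"]  # "a" stale x3 -> disconnect

Note that the counter resets whenever a node's index advances, so only nodes that are completely silent for three consecutive cycles get disconnected.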
6 changed files with 41 additions and 37 deletions

View File

@@ -1,5 +1,3 @@
-from datetime import datetime, timedelta, timezone
-
 import anyio
 from anyio.abc import TaskGroup
 from loguru import logger
@@ -35,8 +33,7 @@ from exo.shared.types.events import (
     IndexedEvent,
     InputChunkReceived,
     InstanceDeleted,
-    NodeGatheredInfo,
-    NodeTimedOut,
+    NodeDisconnected,
     TaskCreated,
     TaskDeleted,
     TraceEventData,
@@ -92,6 +89,8 @@ class Master:
         self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
         self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
         self._expected_ranks: dict[TaskId, set[int]] = {}
+        self._last_checked_indices: dict[NodeId, int] = {}
+        self._stale_cycles: dict[NodeId, int] = {}

     async def run(self):
         logger.info("Starting Master")
@@ -358,12 +357,25 @@ class Master:
                 )
                 break

-            # time out dead nodes
-            for node_id, time in self.state.last_seen.items():
-                now = datetime.now(tz=timezone.utc)
-                if now - time > timedelta(seconds=30):
-                    logger.info(f"Manually removing node {node_id} due to inactivity")
-                    await self.event_sender.send(NodeTimedOut(node_id=node_id))
+            # Check which nodes have gone stale
+            current_indices = dict(self.state.last_event_index_by_node)
+            for node_id in list(current_indices):
+                last_checked = self._last_checked_indices.get(node_id, -1)
+                if current_indices[node_id] == last_checked:
+                    self._stale_cycles[node_id] = self._stale_cycles.get(node_id, 0) + 1
+                else:
+                    self._stale_cycles.pop(node_id, None)
+
+            self._last_checked_indices = current_indices
+
+            # Disconnect nodes stale for >= 3 consecutive cycles (~30s)
+            for node_id, cycles in list(self._stale_cycles.items()):
+                if cycles >= 3:
+                    logger.info(
+                        f"Removing node {node_id}: no events for {cycles} plan cycles"
+                    )
+                    del self._stale_cycles[node_id]
+                    await self.event_sender.send(NodeDisconnected(node_id=node_id))

             await anyio.sleep(10)
@@ -387,10 +399,6 @@
         indexed = IndexedEvent(event=event, idx=len(self._event_log))
         self.state = apply(self.state, indexed)
-        event._master_time_stamp = datetime.now(tz=timezone.utc)  # pyright: ignore[reportPrivateUsage]
-        if isinstance(event, NodeGatheredInfo):
-            event.when = str(datetime.now(tz=timezone.utc))
-
         self._event_log.append(event)
         await self._send_event(indexed)

View File

@@ -1,4 +1,3 @@
-from datetime import datetime, timezone
 from typing import Sequence

 import anyio
@@ -85,7 +84,6 @@ async def test_master():
             session=session_id,
             event=(
                 NodeGatheredInfo(
-                    when=str(datetime.now(tz=timezone.utc)),
                     node_id=node_id,
                     info=MemoryUsage(
                         ram_total=Memory.from_bytes(678948 * 1024),

View File

@@ -1,6 +1,5 @@
 import copy
 from collections.abc import Mapping, Sequence
-from datetime import datetime

 from loguru import logger
@@ -12,9 +11,9 @@ from exo.shared.types.events import (
     InputChunkReceived,
     InstanceCreated,
     InstanceDeleted,
+    NodeDisconnected,
     NodeDownloadProgress,
     NodeGatheredInfo,
-    NodeTimedOut,
     RunnerDeleted,
     RunnerStatusUpdated,
     TaskAcknowledged,
@@ -72,8 +71,8 @@ def event_apply(event: Event, state: State) -> State:
            return apply_instance_created(event, state)
        case InstanceDeleted():
            return apply_instance_deleted(event, state)
-       case NodeTimedOut():
-           return apply_node_timed_out(event, state)
+       case NodeDisconnected():
+           return apply_node_disconnected(event, state)
        case NodeDownloadProgress():
            return apply_node_download_progress(event, state)
        case NodeGatheredInfo():
@@ -104,7 +103,13 @@ def apply(state: State, event: IndexedEvent) -> State:
     )
     assert state.last_event_applied_idx == event.idx - 1
     new_state: State = event_apply(event.event, state)
-    return new_state.model_copy(update={"last_event_applied_idx": event.idx})
+    update: dict[str, object] = {"last_event_applied_idx": event.idx}
+    if isinstance(event.event, NodeGatheredInfo):
+        update["last_event_index_by_node"] = {
+            **new_state.last_event_index_by_node,
+            event.event.node_id: event.idx,
+        }
+    return new_state.model_copy(update=update)


 def apply_node_download_progress(event: NodeDownloadProgress, state: State) -> State:
@@ -208,11 +213,13 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:
     return state.model_copy(update={"runners": new_runners})


-def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
+def apply_node_disconnected(event: NodeDisconnected, state: State) -> State:
     topology = copy.deepcopy(state.topology)
     topology.remove_node(event.node_id)
-    last_seen = {
-        key: value for key, value in state.last_seen.items() if key != event.node_id
+    last_event_index_by_node = {
+        key: value
+        for key, value in state.last_event_index_by_node.items()
+        if key != event.node_id
     }
     downloads = {
         key: value for key, value in state.downloads.items() if key != event.node_id
@@ -262,7 +269,7 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
         update={
             "downloads": downloads,
             "topology": topology,
-            "last_seen": last_seen,
+            "last_event_index_by_node": last_event_index_by_node,
             "node_identities": node_identities,
             "node_memory": node_memory,
             "node_disk": node_disk,
@@ -283,10 +290,6 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
     # Build update dict with only the mappings that change
     update: dict[str, object] = {
-        "last_seen": {
-            **state.last_seen,
-            event.node_id: datetime.fromisoformat(event.when),
-        },
         "topology": topology,
     }

View File

@@ -77,14 +77,12 @@ class RunnerDeleted(BaseEvent):
     runner_id: RunnerId


-class NodeTimedOut(BaseEvent):
+class NodeDisconnected(BaseEvent):
     node_id: NodeId


-# TODO: bikeshed this name
 class NodeGatheredInfo(BaseEvent):
     node_id: NodeId
-    when: str  # this is a manually cast datetime overrode by the master when the event is indexed, rather than the local time on the device
     info: GatheredInfo
@@ -143,7 +141,7 @@ Event = (
     | InstanceDeleted
     | RunnerStatusUpdated
     | RunnerDeleted
-    | NodeTimedOut
+    | NodeDisconnected
     | NodeGatheredInfo
     | NodeDownloadProgress
     | ChunkGenerated

View File

@@ -1,5 +1,4 @@
 from collections.abc import Mapping, Sequence
-from datetime import datetime
 from typing import Any, cast

 from pydantic import ConfigDict, Field, field_serializer, field_validator
@@ -44,7 +43,7 @@ class State(CamelCaseModel):
     runners: Mapping[RunnerId, RunnerStatus] = {}
     downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
     tasks: Mapping[TaskId, Task] = {}
-    last_seen: Mapping[NodeId, datetime] = {}
+    last_event_index_by_node: Mapping[NodeId, int] = {}
     topology: Topology = Field(default_factory=Topology)
     last_event_applied_idx: int = Field(default=-1, ge=-1)

View File

@@ -1,5 +1,4 @@
 from collections import defaultdict
-from datetime import datetime, timezone
 from random import random
 from typing import Iterator
@@ -123,7 +122,6 @@ class Worker:
         await self.event_sender.send(
             NodeGatheredInfo(
                 node_id=self.node_id,
-                when=str(datetime.now(tz=timezone.utc)),
                 info=info,
             )
         )