Compare commits

2 Commits

Alex Cheema
922e8075d3 debug: add logging for NodeGatheredInfo event flow
Track when NodeGatheredInfo events are sent and applied to help
diagnose why joining nodes stay as "unknown".

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 16:03:22 -08:00
Alex Cheema
6ee745246d fix: preserve early-arriving events during state catchup
When a node joins a cluster and catches up state, events arriving before
catchup completes were being lost because the buffer was cleared entirely.
This could cause nodes to remain "unknown" if their NodeGatheredInfo event
arrived during this window.

Now we preserve events with idx >= new_idx instead of clearing all events.
Added debug logging to help diagnose any remaining issues.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-23 15:56:33 -08:00
6 changed files with 123 additions and 17 deletions
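
The fix itself is small and easiest to see in isolation. The sketch below is a minimal, self-contained illustration of the behaviour described in the second commit message; ToyOrderedBuffer, ingest and catch_up are made-up names, not exo's actual OrderedBuffer API, but the store dict and next_idx_to_release counter mirror the fields visible in the diff. Previously, catchup cleared the buffer outright, so an event that arrived ahead of the state snapshot (such as a joining node's NodeGatheredInfo) was silently dropped; now only entries below the new release index are discarded.

# Minimal sketch of the catchup behaviour; field names mirror the diff
# (store, next_idx_to_release), but this is not exo's real OrderedBuffer.
class ToyOrderedBuffer:
    def __init__(self) -> None:
        self.store: dict[int, str] = {}
        self.next_idx_to_release = 0

    def ingest(self, idx: int, event: str) -> None:
        self.store[idx] = event

    def drain_indexed(self) -> list[tuple[int, str]]:
        # Release events strictly in index order, stopping at the first gap.
        released: list[tuple[int, str]] = []
        while self.next_idx_to_release in self.store:
            idx = self.next_idx_to_release
            released.append((idx, self.store.pop(idx)))
            self.next_idx_to_release += 1
        return released


def catch_up(buffer: ToyOrderedBuffer, last_event_applied_idx: int) -> None:
    # Jump the buffer forward to a state snapshot while keeping early arrivals.
    new_idx = last_event_applied_idx + 1
    buffer.next_idx_to_release = new_idx
    # Old behaviour: buffer.store.clear()  -> events with idx >= new_idx were lost.
    # New behaviour: drop only stale entries (idx < new_idx).
    buffer.store = {k: v for k, v in buffer.store.items() if k >= new_idx}


buf = ToyOrderedBuffer()
buf.ingest(7, "NodeGatheredInfo")    # arrives while catchup is still in flight
catch_up(buf, last_event_applied_idx=5)
assert buf.drain_indexed() == []     # idx 6 is still missing; idx 7 stays buffered
buf.ingest(6, "some other event")
assert buf.drain_indexed() == [(6, "some other event"), (7, "NodeGatheredInfo")]
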

View File

@@ -53,6 +53,7 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.STATE_CATCHUP)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
logger.info(f"Starting node {node_id}")
@@ -82,6 +83,7 @@ class Node:
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
)
else:
api = None
@@ -94,6 +96,7 @@ class Node:
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
@@ -107,6 +110,7 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
state_catchup_sender=router.sender(topics.STATE_CATCHUP),
)
er_send, er_recv = channel[ElectionResult]()
@@ -189,6 +193,7 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
state_catchup_sender=self.router.sender(topics.STATE_CATCHUP),
)
self._tg.start_soon(self.master.run)
elif (
@@ -235,6 +240,9 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
state_catchup_receiver=self.router.receiver(
topics.STATE_CATCHUP
),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),

View File

@@ -166,6 +166,7 @@ class API:
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
state_catchup_receiver: Receiver[State],
) -> None:
self.state = State()
self._event_log: list[Event] = []
@@ -173,6 +174,7 @@ class API:
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.state_catchup_receiver = state_catchup_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
@@ -1249,6 +1251,7 @@ class API:
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
tg.start_soon(self._state_catchup)
print_startup_banner(self.port)
await serve(
cast(ASGIFramework, self.app),
@@ -1259,6 +1262,37 @@ class API:
        self.command_sender.close()
        self.global_event_receiver.close()

+    async def _state_catchup(self):
+        with self.state_catchup_receiver as states:
+            async for state in states:
+                if (
+                    self.state.last_event_applied_idx == -1
+                    and state.last_event_applied_idx > self.state.last_event_applied_idx
+                ):
+                    # DEBUG: Log buffer state BEFORE clearing
+                    logger.warning(
+                        f"STATE_CATCHUP: About to catch up. "
+                        f"Current buffer indices: {sorted(self.event_buffer.store.keys())}, "
+                        f"next_idx_to_release: {self.event_buffer.next_idx_to_release}, "
+                        f"catching up to idx: {state.last_event_applied_idx}"
+                    )
+                    new_idx = state.last_event_applied_idx + 1
+                    self.event_buffer.next_idx_to_release = new_idx
+                    # Preserve events that arrived early but are still valid (idx >= new_idx)
+                    # Remove stale events (idx < new_idx) to prevent memory growth
+                    self.event_buffer.store = {
+                        k: v for k, v in self.event_buffer.store.items() if k >= new_idx
+                    }
+                    self.state = state
+                    # DEBUG: Log buffer state AFTER clearing
+                    logger.warning(
+                        f"STATE_CATCHUP: Catchup complete. "
+                        f"Buffer preserved indices: {sorted(self.event_buffer.store.keys())}, "
+                        f"new next_idx_to_release: {self.event_buffer.next_idx_to_release}"
+                    )

    async def _apply_state(self):
        with self.global_event_receiver as events:
            async for f_event in events:

View File

@@ -68,6 +68,8 @@ class Master:
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
# not a fan, but: send the entire state to a node so it can catch up without the whole event log.
state_catchup_sender: Sender[State],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -77,6 +79,7 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.state_catchup_sender = state_catchup_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -84,7 +87,6 @@ class Master:
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
async def run(self):
@@ -291,11 +293,17 @@ class Master:
                            command.finished_command_id
                        ]
                    case RequestEventLog():
-                        # We should just be able to send everything, since other buffers will ignore old messages
-                        for i in range(command.since_idx, len(self._event_log)):
-                            await self._send_event(
-                                IndexedEvent(idx=i, event=self._event_log[i])
-                            )
+                        if command.since_idx == 0:
+                            # This is an optimization, and should not be relied upon in theory.
+                            logger.info(
+                                f"Master sending catchup state for index {self.state.last_event_applied_idx}"
+                            )
+                            await self.state_catchup_sender.send(self.state)
+                        else:
+                            for i in range(command.since_idx, len(self._event_log)):
+                                await self._send_event(
+                                    IndexedEvent(idx=i, event=self._event_log[i])
+                                )
                for event in generated_events:
                    await self.event_sender.send(event)
            except ValueError as e:
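
For intuition on why the since_idx == 0 snapshot path above is safe: applying a state snapshot taken at index N and then replaying only the events after N must leave a joining node in the same state as replaying the whole event log from 0. The sketch below illustrates that invariant with toy stand-ins (ToyState, ToyEvent and toy_apply are hypothetical, not exo's State, Event or apply). The else branch keeps the original replay path for nodes that already hold a prefix of the log.

from dataclasses import dataclass, field

# Toy stand-ins for exo's State / Event / apply; illustrative only.
@dataclass
class ToyState:
    last_event_applied_idx: int = -1
    nodes: dict[str, str] = field(default_factory=dict)

@dataclass
class ToyEvent:
    node_id: str
    info: str

def toy_apply(state: ToyState, idx: int, event: ToyEvent) -> ToyState:
    # Pure state transition: record the node's info and advance the index.
    return ToyState(
        last_event_applied_idx=idx,
        nodes={**state.nodes, event.node_id: event.info},
    )

event_log = [ToyEvent("a", "gpu"), ToyEvent("b", "cpu"), ToyEvent("c", "tpu")]

# Old path: replay every event from idx 0.
replayed = ToyState()
for i, ev in enumerate(event_log):
    replayed = toy_apply(replayed, i, ev)

# New path: the master ships its state at idx 1; the joiner only needs idx >= 2.
snapshot = ToyState()
for i, ev in enumerate(event_log[:2]):
    snapshot = toy_apply(snapshot, i, ev)
joiner = snapshot
for i, ev in enumerate(event_log[2:], start=2):
    joiner = toy_apply(joiner, i, ev)

assert joiner == replayed
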

View File

@@ -27,6 +27,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.worker.instances import (
@@ -47,6 +48,7 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
st_s, _st_r = channel[State]()
all_events: list[IndexedEvent] = []
@@ -67,6 +69,7 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
state_catchup_sender=st_s,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:

View File

@@ -7,6 +7,7 @@ from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
ForwarderEvent,
)
from exo.shared.types.state import State
from exo.utils.pydantic_ext import CamelCaseModel
@@ -45,6 +46,7 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
STATE_CATCHUP = TypedTopic("state_catchup", PublishPolicy.Always, State)
DOWNLOAD_COMMANDS = TypedTopic(
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)
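
The new STATE_CATCHUP topic ties the earlier files together: the payload type declared on the topic (State) is what the master's state_catchup_sender publishes and what the API and worker receivers consume. The toy pub/sub below sketches that pattern; ToyTopic, ToyRouter and ToyState are illustrative stand-ins, not exo's TypedTopic or router, whose internals are not part of this diff.

import asyncio
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyTopic:
    # Rough analogue of a typed topic: a name plus the payload type it carries.
    name: str
    payload_type: type

class ToyRouter:
    def __init__(self) -> None:
        self._queues: dict[str, list[asyncio.Queue]] = {}

    def register_topic(self, topic: ToyTopic) -> None:
        self._queues.setdefault(topic.name, [])

    def receiver(self, topic: ToyTopic) -> asyncio.Queue:
        q: asyncio.Queue = asyncio.Queue()
        self._queues[topic.name].append(q)
        return q

    def sender(self, topic: ToyTopic):
        async def send(payload) -> None:
            # Fan out to every registered receiver for this topic.
            for q in self._queues[topic.name]:
                await q.put(payload)
        return send

@dataclass
class ToyState:
    last_event_applied_idx: int = -1

async def main() -> None:
    STATE_CATCHUP = ToyTopic("state_catchup", ToyState)
    router = ToyRouter()
    router.register_topic(STATE_CATCHUP)
    receive_states = router.receiver(STATE_CATCHUP)   # worker / API side
    send_state = router.sender(STATE_CATCHUP)         # master side
    await send_state(ToyState(last_event_applied_idx=41))
    print(await receive_states.get())

asyncio.run(main())
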

View File

@@ -60,9 +60,8 @@ class Worker:
        connection_message_receiver: Receiver[ConnectionMessage],
        global_event_receiver: Receiver[ForwarderEvent],
        local_event_sender: Sender[ForwarderEvent],
-        # This is for requesting updates. It doesn't need to be a general command sender right now,
-        # but I think it's the correct way to be thinking about commands
        command_sender: Sender[ForwarderCommand],
+        state_catchup_receiver: Receiver[State],
        download_command_sender: Sender[ForwarderDownloadCommand],
        event_index_counter: Iterator[int],
    ):
@@ -71,6 +70,8 @@ class Worker:
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.state_catchup_receiver = state_catchup_receiver
self.local_event_index = 0
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
@@ -110,6 +111,7 @@ class Worker:
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
tg.start_soon(self._check_catchup_state)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
@@ -121,13 +123,47 @@ class Worker:
    async def _forward_info(self, recv: Receiver[GatheredInfo]):
        with recv as info_stream:
            async for info in info_stream:
-                await self.event_sender.send(
-                    NodeGatheredInfo(
-                        node_id=self.node_id,
-                        when=str(datetime.now(tz=timezone.utc)),
-                        info=info,
-                    )
-                )
+                event = NodeGatheredInfo(
+                    node_id=self.node_id,
+                    when=str(datetime.now(tz=timezone.utc)),
+                    info=info,
+                )
+                logger.warning(
+                    f"NODE_GATHERED_INFO: Sending event for node {self.node_id}, "
+                    f"event_id={event.event_id}"
+                )
+                await self.event_sender.send(event)

+    async def _check_catchup_state(self):
+        with self.state_catchup_receiver as states:
+            async for state in states:
+                if (
+                    self.state.last_event_applied_idx == -1
+                    and state.last_event_applied_idx > self.state.last_event_applied_idx
+                ):
+                    # DEBUG: Log buffer state BEFORE clearing
+                    logger.warning(
+                        f"STATE_CATCHUP: About to catch up. "
+                        f"Current buffer indices: {sorted(self.event_buffer.store.keys())}, "
+                        f"next_idx_to_release: {self.event_buffer.next_idx_to_release}, "
+                        f"catching up to idx: {state.last_event_applied_idx}"
+                    )
+                    new_idx = state.last_event_applied_idx + 1
+                    self.event_buffer.next_idx_to_release = new_idx
+                    # Preserve events that arrived early but are still valid (idx >= new_idx)
+                    # Remove stale events (idx < new_idx) to prevent memory growth
+                    self.event_buffer.store = {
+                        k: v for k, v in self.event_buffer.store.items() if k >= new_idx
+                    }
+                    self.state = state
+                    # DEBUG: Log buffer state AFTER clearing
+                    logger.warning(
+                        f"STATE_CATCHUP: Catchup complete. "
+                        f"Buffer preserved indices: {sorted(self.event_buffer.store.keys())}, "
+                        f"new next_idx_to_release: {self.event_buffer.next_idx_to_release}"
+                    )

    async def _event_applier(self):
        with self.global_event_receiver as events:
@@ -139,8 +175,20 @@ class Worker:
                if event_id in self.out_for_delivery:
                    del self.out_for_delivery[event_id]
+                # DEBUG: Log what was ingested
+                logger.warning(
+                    f"EVENT_APPLIER: Ingested event idx={f_event.origin_idx}, "
+                    f"buffer keys now: {sorted(self.event_buffer.store.keys())}"
+                )
                # 2. for each event, apply it to the state
                indexed_events = self.event_buffer.drain_indexed()
+                # DEBUG: Log drain results
+                logger.warning(
+                    f"EVENT_APPLIER: Drained {len(indexed_events)} events, "
+                    f"next_idx_to_release now: {self.event_buffer.next_idx_to_release}"
+                )
                if indexed_events:
                    self._nack_attempts = 0
@@ -157,6 +205,12 @@ class Worker:
                    self._nack_cancel_scope.cancel()
                for idx, event in indexed_events:
+                    # DEBUG: Log NodeGatheredInfo events
+                    if isinstance(event, NodeGatheredInfo):
+                        logger.warning(
+                            f"NODE_GATHERED_INFO: Applying event idx={idx} for node {event.node_id}, "
+                            f"event_id={event.event_id}"
+                        )
                    self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
                    # Buffer input image chunks for image editing
@@ -318,10 +372,7 @@ class Worker:
        # We request all events after (and including) the missing index.
        # This function is started whenever we receive an event that is out of sequence.
        # It is cancelled as soon as we receive an event that is in sequence.
-        if since_idx < 0:
-            logger.warning(f"Negative value encountered for nack request {since_idx=}")
-            since_idx = 0
+        assert since_idx >= 0
        with CancelScope() as scope:
            self._nack_cancel_scope = scope
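
The comment block in this last hunk describes the nack mechanism; the sketch below shows the same pattern in a self-contained form using anyio's CancelScope. The names nack_loop and requests are made up, not the Worker's actual methods: an out-of-sequence event starts a loop that keeps requesting events from the missing index, and the loop is cancelled as soon as the gap is filled.

import anyio

# Toy sketch of the nack pattern (not the Worker's actual methods): a gap in
# the event sequence starts a loop that keeps requesting events from the
# missing index; it runs inside a CancelScope so that receiving the missing
# event can cancel it immediately.

async def main() -> None:
    requests: list[int] = []
    nack_scope = None

    async def nack_loop(since_idx: int) -> None:
        nonlocal nack_scope
        assert since_idx >= 0
        with anyio.CancelScope() as scope:
            nack_scope = scope
            while True:
                requests.append(since_idx)   # stand-in for RequestEventLog(since_idx=...)
                await anyio.sleep(0.05)

    async with anyio.create_task_group() as tg:
        # Event idx 7 arrived but idx 6 is missing -> start nacking from 6.
        tg.start_soon(nack_loop, 6)
        await anyio.sleep(0.12)
        # The missing event finally arrives in sequence -> cancel the nack loop.
        assert nack_scope is not None
        nack_scope.cancel()

    print(f"sent {len(requests)} nack request(s) for idx 6")

anyio.run(main)
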