Compare commits

...

1 Commits

Author SHA1 Message Date
Evan
2ee0bce898 state compaction
introduces a new topic ("state_catchup") over which a full state can be
sent. currently the master sends the worker + api this new state, and
they update only if they have no other events applied - otherwise usual
NACK systems function

## testing

manually tested on two nodes
2026-01-23 14:15:06 +00:00
6 changed files with 65 additions and 11 deletions

View File

@@ -49,6 +49,7 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.STATE_CATCHUP)
logger.info(f"Starting node {node_id}")
if args.spawn_api:
@@ -59,6 +60,7 @@ class Node:
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
)
else:
api = None
@@ -72,6 +74,7 @@ class Node:
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
)
else:
worker = None
@@ -83,6 +86,7 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
state_catchup_sender=router.sender(topics.STATE_CATCHUP),
)
er_send, er_recv = channel[ElectionResult]()
@@ -153,6 +157,7 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
state_catchup_sender=self.router.sender(topics.STATE_CATCHUP),
)
self._tg.start_soon(self.master.run)
elif (
@@ -185,6 +190,9 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
state_catchup_receiver=self.router.receiver(
topics.STATE_CATCHUP
),
)
self._tg.start_soon(self.worker.run)
if self.api:

View File

@@ -158,12 +158,14 @@ class API:
command_sender: Sender[ForwarderCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
state_catchup_receiver: Receiver[State],
) -> None:
self.state = State()
self._event_log: list[Event] = []
self.command_sender = command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.state_catchup_receiver = state_catchup_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
@@ -1231,6 +1233,7 @@ class API:
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
tg.start_soon(self._state_catchup)
print_startup_banner(self.port)
await serve(
cast(ASGIFramework, self.app),
@@ -1241,6 +1244,22 @@ class API:
self.command_sender.close()
self.global_event_receiver.close()
async def _state_catchup(self):
with self.state_catchup_receiver as states:
async for state in states:
if (
self.state.last_event_applied_idx == -1
and state.last_event_applied_idx > self.state.last_event_applied_idx
):
logger.info(
f"API catching up state to idx {state.last_event_applied_idx}"
)
self.event_buffer.store = {}
self.event_buffer.next_idx_to_release = (
state.last_event_applied_idx + 1
)
self.state = state
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:

View File

@@ -68,6 +68,8 @@ class Master:
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
# not a fan but - send the entire state to a node so it can catchup without the whole event log.
state_catchup_sender: Sender[State],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -77,6 +79,7 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.state_catchup_sender = state_catchup_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -84,7 +87,6 @@ class Master:
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
async def run(self):
@@ -291,11 +293,17 @@ class Master:
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
if command.since_idx == 0:
# This is an optimization, and should not be relied upon in theory.
logger.info(
f"Master sending catchup state for index {self.state.last_event_applied_idx}"
)
await self.state_catchup_sender.send(self.state)
else:
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
for event in generated_events:
await self.event_sender.send(event)
except ValueError as e:

View File

@@ -27,6 +27,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.worker.instances import (
@@ -47,6 +48,7 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
st_s, _st_r = channel[State]()
all_events: list[IndexedEvent] = []
@@ -67,6 +69,7 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
state_catchup_sender=st_s,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:

View File

@@ -7,6 +7,7 @@ from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.events import (
ForwarderEvent,
)
from exo.shared.types.state import State
from exo.utils.pydantic_ext import CamelCaseModel
@@ -45,3 +46,4 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
STATE_CATCHUP = TypedTopic("state_catchup", PublishPolicy.Always, State)

View File

@@ -67,9 +67,8 @@ class Worker:
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
local_event_sender: Sender[ForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
state_catchup_receiver: Receiver[State],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
@@ -79,6 +78,7 @@ class Worker:
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.state_catchup_receiver = state_catchup_receiver
self.local_event_index = 0
self.command_sender = command_sender
self.connection_message_receiver = connection_message_receiver
@@ -117,6 +117,7 @@ class Worker:
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
tg.start_soon(self._check_catchup_state)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
@@ -135,6 +136,22 @@ class Worker:
)
)
async def _check_catchup_state(self):
with self.state_catchup_receiver as states:
async for state in states:
if (
self.state.last_event_applied_idx == -1
and state.last_event_applied_idx > self.state.last_event_applied_idx
):
logger.info(
f"Worker catching up state to idx {state.last_event_applied_idx}"
)
self.event_buffer.store = {}
self.event_buffer.next_idx_to_release = (
state.last_event_applied_idx + 1
)
self.state = state
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -342,10 +359,7 @@ class Worker:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.
# It is cancelled as soon as we receiver an event that is in sequence.
if since_idx < 0:
logger.warning(f"Negative value encountered for nack request {since_idx=}")
since_idx = 0
assert since_idx >= 0
with CancelScope() as scope:
self._nack_cancel_scope = scope