mirror of https://github.com/exo-explore/exo.git
synced 2026-01-24 05:48:44 -05:00

Compare commits: 1 commit on `main` by alexcheema

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | ff4a2022f7 | |

The commit removes the state-catchup path: the `STATE_CATCHUP` topic and its `Sender[State]`/`Receiver[State]` plumbing disappear from `Node`, `API`, `Master`, `Worker`, and the topics module, leaving event-log replay as the only catch-up mechanism.
```diff
@@ -53,7 +53,6 @@ class Node:
         await router.register_topic(topics.COMMANDS)
         await router.register_topic(topics.ELECTION_MESSAGES)
         await router.register_topic(topics.CONNECTION_MESSAGES)
-        await router.register_topic(topics.STATE_CATCHUP)
         await router.register_topic(topics.DOWNLOAD_COMMANDS)
 
         logger.info(f"Starting node {node_id}")
@@ -83,7 +82,6 @@ class Node:
                 command_sender=router.sender(topics.COMMANDS),
                 download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
                 election_receiver=router.receiver(topics.ELECTION_MESSAGES),
-                state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
             )
         else:
             api = None
@@ -96,7 +94,6 @@ class Node:
             global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
             local_event_sender=router.sender(topics.LOCAL_EVENTS),
             command_sender=router.sender(topics.COMMANDS),
-            state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
             download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
             event_index_counter=event_index_counter,
         )
@@ -110,7 +107,6 @@ class Node:
             global_event_sender=router.sender(topics.GLOBAL_EVENTS),
             local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
             command_receiver=router.receiver(topics.COMMANDS),
-            state_catchup_sender=router.sender(topics.STATE_CATCHUP),
         )
 
         er_send, er_recv = channel[ElectionResult]()
@@ -193,7 +189,6 @@ class Node:
                 global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
                 local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
                 command_receiver=self.router.receiver(topics.COMMANDS),
-                state_catchup_sender=self.router.sender(topics.STATE_CATCHUP),
             )
             self._tg.start_soon(self.master.run)
         elif (
@@ -240,9 +235,6 @@ class Node:
                 ),
                 local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
                 command_sender=self.router.sender(topics.COMMANDS),
-                state_catchup_receiver=self.router.receiver(
-                    topics.STATE_CATCHUP
-                ),
                 download_command_sender=self.router.sender(
                     topics.DOWNLOAD_COMMANDS
                 ),
```
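The hunks above strip the `STATE_CATCHUP` wiring from `Node`. For orientation, here is a minimal sketch of the router pattern that wiring assumes: topics are registered once, then `sender`/`receiver` hand out the two ends of a per-topic in-memory channel. Only the method names (`register_topic`, `sender`, `receiver`) come from the diff; the implementation below is an illustrative assumption, not exo's router.

```python
from typing import Any

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream


class Router:
    """Illustrative stand-in for the router used in Node above."""

    def __init__(self) -> None:
        # One (send, receive) stream pair per registered topic name.
        self._channels: dict[
            str, tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]
        ] = {}

    async def register_topic(self, topic: str) -> None:
        # Create the backing stream once per topic; re-registration is a no-op.
        if topic not in self._channels:
            self._channels[topic] = anyio.create_memory_object_stream(max_buffer_size=64)

    def sender(self, topic: str) -> MemoryObjectSendStream[Any]:
        return self._channels[topic][0]

    def receiver(self, topic: str) -> MemoryObjectReceiveStream[Any]:
        return self._channels[topic][1]
```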
```diff
@@ -166,7 +166,6 @@ class API:
         download_command_sender: Sender[ForwarderDownloadCommand],
         # This lets us pause the API if an election is running
         election_receiver: Receiver[ElectionMessage],
-        state_catchup_receiver: Receiver[State],
     ) -> None:
         self.state = State()
         self._event_log: list[Event] = []
@@ -174,7 +173,6 @@ class API:
         self.download_command_sender = download_command_sender
         self.global_event_receiver = global_event_receiver
         self.election_receiver = election_receiver
-        self.state_catchup_receiver = state_catchup_receiver
         self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
         self.node_id: NodeId = node_id
         self.session_id: SessionId = session_id
@@ -1251,7 +1249,6 @@ class API:
             tg.start_soon(self._apply_state)
             tg.start_soon(self._pause_on_new_election)
             tg.start_soon(self._cleanup_expired_images)
-            tg.start_soon(self._state_catchup)
             print_startup_banner(self.port)
             await serve(
                 cast(ASGIFramework, self.app),
@@ -1262,22 +1259,6 @@ class API:
         self.command_sender.close()
         self.global_event_receiver.close()
 
-    async def _state_catchup(self):
-        with self.state_catchup_receiver as states:
-            async for state in states:
-                if (
-                    self.state.last_event_applied_idx == -1
-                    and state.last_event_applied_idx > self.state.last_event_applied_idx
-                ):
-                    logger.info(
-                        f"API catching up state to idx {state.last_event_applied_idx}"
-                    )
-                    self.event_buffer.store = {}
-                    self.event_buffer.next_idx_to_release = (
-                        state.last_event_applied_idx + 1
-                    )
-                    self.state = state
-
     async def _apply_state(self):
         with self.global_event_receiver as events:
             async for f_event in events:
```
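The deleted `_state_catchup` task (mirrored by the worker's `_check_catchup_state` further down) seeded a fresh node from a state snapshot by clearing the `OrderedBuffer` and bumping `next_idx_to_release` past the snapshot. A self-contained stand-in buffer showing why that reset works; only the `store` and `next_idx_to_release` attribute names come from the diff, the rest is invented for the sketch.

```python
class OrderedBuffer:
    """Illustrative in-order release buffer, assuming the two attributes above."""

    def __init__(self) -> None:
        self.store: dict[int, object] = {}  # out-of-order events keyed by index
        self.next_idx_to_release = 0        # next index handed to the applier

    def ingest(self, idx: int, event: object) -> list[object]:
        # Indices below next_idx_to_release are ignored, which is what made
        # "reset to snapshot_idx + 1" (and resending everything) safe.
        if idx >= self.next_idx_to_release:
            self.store[idx] = event
        released: list[object] = []
        while self.next_idx_to_release in self.store:
            released.append(self.store.pop(self.next_idx_to_release))
            self.next_idx_to_release += 1
        return released


buf = OrderedBuffer()
assert buf.ingest(1, "b") == []          # gap: idx 0 still missing
assert buf.ingest(0, "a") == ["a", "b"]  # gap filled, both release in order
```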
```diff
@@ -68,8 +68,6 @@ class Master:
         # Send events to the forwarder to be indexed (usually from command processing)
         # Ideally these would be MasterForwarderEvents but type system says no :(
         global_event_sender: Sender[ForwarderEvent],
-        # not a fan but - send the entire state to a node so it can catchup without the whole event log.
-        state_catchup_sender: Sender[State],
     ):
         self.state = State()
         self._tg: TaskGroup = anyio.create_task_group()
@@ -79,7 +77,6 @@ class Master:
         self.command_receiver = command_receiver
         self.local_event_receiver = local_event_receiver
         self.global_event_sender = global_event_sender
-        self.state_catchup_sender = state_catchup_sender
         send, recv = channel[Event]()
         self.event_sender: Sender[Event] = send
         self._loopback_event_receiver: Receiver[Event] = recv
@@ -87,6 +84,7 @@ class Master:
             local_event_receiver.clone_sender()
         )
         self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
+        # TODO: not have this
         self._event_log: list[Event] = []
 
     async def run(self):
@@ -293,17 +291,11 @@ class Master:
                             command.finished_command_id
                         ]
                     case RequestEventLog():
-                        if command.since_idx == 0:
-                            # This is an optimization, and should not be relied upon in theory.
-                            logger.info(
-                                f"Master sending catchup state for index {self.state.last_event_applied_idx}"
-                            )
-                            await self.state_catchup_sender.send(self.state)
-                        else:
-                            for i in range(command.since_idx, len(self._event_log)):
-                                await self._send_event(
-                                    IndexedEvent(idx=i, event=self._event_log[i])
-                                )
+                        # We should just be able to send everything, since other buffers will ignore old messages
+                        for i in range(command.since_idx, len(self._event_log)):
+                            await self._send_event(
+                                IndexedEvent(idx=i, event=self._event_log[i])
+                            )
                 for event in generated_events:
                     await self.event_sender.send(event)
         except ValueError as e:
```
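With the snapshot path gone, `RequestEventLog` is always answered by replaying the log from `since_idx`; as the new comment says, over-sending is safe because downstream in-order buffers drop indices they have already released. A minimal sketch of that replay, with an assumed `IndexedEvent` shape (only the name comes from the diff):

```python
from dataclasses import dataclass
from typing import Any, Awaitable, Callable


@dataclass
class IndexedEvent:  # name from the diff; field layout assumed
    idx: int
    event: Any


async def serve_request_event_log(
    event_log: list[Any],
    send_event: Callable[[IndexedEvent], Awaitable[None]],
    since_idx: int,
) -> None:
    # Resend every event at or after since_idx; receivers that already applied
    # an index ignore the duplicate, so replaying "too much" is harmless.
    for i in range(since_idx, len(event_log)):
        await send_event(IndexedEvent(idx=i, event=event_log[i]))
```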
```diff
@@ -27,7 +27,6 @@ from exo.shared.types.memory import Memory
 from exo.shared.types.profiling import (
     MemoryUsage,
 )
-from exo.shared.types.state import State
 from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
 from exo.shared.types.tasks import TaskStatus
 from exo.shared.types.worker.instances import (
@@ -48,7 +47,6 @@ async def test_master():
     ge_sender, global_event_receiver = channel[ForwarderEvent]()
     command_sender, co_receiver = channel[ForwarderCommand]()
     local_event_sender, le_receiver = channel[ForwarderEvent]()
-    st_s, _st_r = channel[State]()
 
     all_events: list[IndexedEvent] = []
 
@@ -69,7 +67,6 @@ async def test_master():
         global_event_sender=ge_sender,
         local_event_receiver=le_receiver,
         command_receiver=co_receiver,
-        state_catchup_sender=st_s,
     )
     logger.info("run the master")
     async with anyio.create_task_group() as tg:
```
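The test wires the master with the `channel[T]()` helper. A plausible sketch of such a helper over anyio's memory object streams follows; only the subscript-then-call shape (`channel[Event]()` yielding a sender/receiver pair) is taken from the diff, the rest is guesswork.

```python
from typing import Generic, TypeVar

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

T = TypeVar("T")


class channel(Generic[T]):
    """Hypothetical stand-in: channel[Event]() -> (sender, receiver) pair."""

    def __new__(cls) -> "tuple[MemoryObjectSendStream[T], MemoryObjectReceiveStream[T]]":
        # Returning a tuple from __new__ means channel[T]() yields the pair
        # directly, matching the unpacking seen in the test above.
        return anyio.create_memory_object_stream(max_buffer_size=64)
```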
```diff
@@ -7,7 +7,6 @@ from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
 from exo.shared.types.events import (
     ForwarderEvent,
 )
-from exo.shared.types.state import State
 from exo.utils.pydantic_ext import CamelCaseModel
 
 
@@ -46,7 +45,6 @@ ELECTION_MESSAGES = TypedTopic(
 CONNECTION_MESSAGES = TypedTopic(
     "connection_messages", PublishPolicy.Never, ConnectionMessage
 )
-STATE_CATCHUP = TypedTopic("state_catchup", PublishPolicy.Always, State)
 DOWNLOAD_COMMANDS = TypedTopic(
     "download_commands", PublishPolicy.Always, ForwarderDownloadCommand
 )
```
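For context, a plausible shape for `TypedTopic`/`PublishPolicy` consistent with the definitions above: each topic pairs a wire name, a policy controlling whether messages leave the local node, and a payload type. This is a sketch inferred from the three constructor arguments, not exo's actual definition.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Generic, TypeVar

M = TypeVar("M")


class PublishPolicy(Enum):
    Always = auto()  # forwarded between nodes every time
    Never = auto()   # kept node-local


@dataclass(frozen=True)
class TypedTopic(Generic[M]):
    name: str
    policy: PublishPolicy
    model: type[M]  # payload type used to validate messages on this topic


# Mirrors the surviving definition above, with a stand-in payload type:
DOWNLOAD_COMMANDS_EXAMPLE = TypedTopic("download_commands", PublishPolicy.Always, dict)
```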
```diff
@@ -60,8 +60,9 @@ class Worker:
         connection_message_receiver: Receiver[ConnectionMessage],
         global_event_receiver: Receiver[ForwarderEvent],
         local_event_sender: Sender[ForwarderEvent],
+        # This is for requesting updates. It doesn't need to be a general command sender right now,
+        # but I think it's the correct way to be thinking about commands
         command_sender: Sender[ForwarderCommand],
-        state_catchup_receiver: Receiver[State],
         download_command_sender: Sender[ForwarderDownloadCommand],
         event_index_counter: Iterator[int],
     ):
@@ -70,8 +71,6 @@ class Worker:
 
         self.global_event_receiver = global_event_receiver
         self.local_event_sender = local_event_sender
-        self.state_catchup_receiver = state_catchup_receiver
-        self.local_event_index = 0
         self.event_index_counter = event_index_counter
         self.command_sender = command_sender
         self.download_command_sender = download_command_sender
@@ -111,7 +110,6 @@ class Worker:
             tg.start_soon(self._event_applier)
             tg.start_soon(self._forward_events)
             tg.start_soon(self._poll_connection_updates)
-            tg.start_soon(self._check_catchup_state)
 
         # Actual shutdown code - waits for all tasks to complete before executing.
         self.local_event_sender.close()
@@ -131,22 +129,6 @@ class Worker:
             )
         )
 
-    async def _check_catchup_state(self):
-        with self.state_catchup_receiver as states:
-            async for state in states:
-                if (
-                    self.state.last_event_applied_idx == -1
-                    and state.last_event_applied_idx > self.state.last_event_applied_idx
-                ):
-                    logger.info(
-                        f"Worker catching up state to idx {state.last_event_applied_idx}"
-                    )
-                    self.event_buffer.store = {}
-                    self.event_buffer.next_idx_to_release = (
-                        state.last_event_applied_idx + 1
-                    )
-                    self.state = state
-
     async def _event_applier(self):
         with self.global_event_receiver as events:
             async for f_event in events:
@@ -336,7 +318,10 @@ class Worker:
         # We request all events after (and including) the missing index.
         # This function is started whenever we receive an event that is out of sequence.
         # It is cancelled as soon as we receive an event that is in sequence.
-        assert since_idx >= 0
+        if since_idx < 0:
+            logger.warning(f"Negative value encountered for nack request {since_idx=}")
+            since_idx = 0
+
         with CancelScope() as scope:
             self._nack_cancel_scope = scope
 
```
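The last hunk swaps a hard `assert` for a clamp-and-warn. For context on the nack pattern those comments describe, here is a runnable sketch using anyio's real `CancelScope`: on a sequence gap the worker starts a task that repeatedly re-requests events from the missing index, and an in-sequence arrival cancels it. The request callback is a placeholder, not exo's API.

```python
from __future__ import annotations

from typing import Awaitable, Callable

import anyio
from anyio import CancelScope


class NackRequester:
    """Illustrative gap-recovery loop; only the pattern mirrors the diff."""

    def __init__(self, request_since: Callable[[int], Awaitable[None]]) -> None:
        self._request_since = request_since
        self._scope: CancelScope | None = None

    async def run(self, since_idx: int) -> None:
        if since_idx < 0:  # same clamp the commit adds in place of the assert
            since_idx = 0
        with CancelScope() as scope:
            self._scope = scope
            while True:
                await self._request_since(since_idx)  # e.g. send a RequestEventLog command
                await anyio.sleep(1.0)  # back off, retry until cancelled

    def on_in_sequence_event(self) -> None:
        # Cancelling the scope stops the retry loop immediately.
        if self._scope is not None:
            self._scope.cancel()
```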