diff --git a/src/exo/download/coordinator.py b/src/exo/download/coordinator.py
index 0b6f9fcd3..0579499fe 100644
--- a/src/exo/download/coordinator.py
+++ b/src/exo/download/coordinator.py
@@ -1,7 +1,6 @@
 import asyncio
 import socket
 from dataclasses import dataclass, field
-from random import random
 
 import anyio
 from anyio import current_time
@@ -22,13 +21,9 @@ from exo.shared.types.commands import (
     ForwarderDownloadCommand,
     StartDownload,
 )
-from exo.shared.types.common import NodeId, SessionId, SystemId
+from exo.shared.types.common import NodeId, SystemId
 from exo.shared.types.events import (
     Event,
-    EventId,
-    # TODO(evan): just for acks, should delete this ASAP
-    GlobalForwarderEvent,
-    LocalForwarderEvent,
     NodeDownloadProgress,
 )
 from exo.shared.types.worker.downloads import (
@@ -46,15 +41,9 @@ from exo.utils.lazy_task_group import LazyTaskGroup
 @dataclass
 class DownloadCoordinator:
     node_id: NodeId
-    session_id: SessionId
     shard_downloader: ShardDownloader
     download_command_receiver: Receiver[ForwarderDownloadCommand]
-    local_event_sender: Sender[LocalForwarderEvent]
-
-    # ack stuff
-    _global_event_receiver: Receiver[GlobalForwarderEvent]
-    _out_for_delivery: dict[EventId, LocalForwarderEvent] = field(default_factory=dict)
-
+    event_sender: Sender[Event]
     offline: bool = False
 
     _system_id: SystemId = field(default_factory=SystemId)
@@ -63,16 +52,12 @@ class DownloadCoordinator:
     download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
     active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
 
-    # Internal event channel for forwarding (initialized in __post_init__)
-    event_sender: Sender[Event] = field(init=False)
-    event_receiver: Receiver[Event] = field(init=False)
     _tg: LazyTaskGroup = field(init=False, default_factory=LazyTaskGroup)
 
     # Per-model throttle for download progress events
    _last_progress_time: dict[ModelId, float] = field(default_factory=dict)
 
     def __post_init__(self) -> None:
-        self.event_sender, self.event_receiver = channel[Event]()
         if self.offline:
             self.shard_downloader.set_internet_connection(False)
         self.shard_downloader.on_progress(self._download_progress_callback)
@@ -128,10 +113,7 @@ class DownloadCoordinator:
         try:
             async with self._tg as tg:
                 tg.start_soon(self._command_processor)
-                tg.start_soon(self._forward_events)
                 tg.start_soon(self._emit_existing_download_progress)
-                tg.start_soon(self._resend_out_for_delivery)
-                tg.start_soon(self._clear_ofd)
                 if not self.offline:
                     tg.start_soon(self._check_internet_connection)
         finally:
@@ -169,20 +151,6 @@ class DownloadCoordinator:
     def shutdown(self) -> None:
         self._tg.cancel_scope.cancel()
 
-    # directly copied from worker
-    async def _resend_out_for_delivery(self) -> None:
-        # This can also be massively tightened, we should check events are at least a certain age before resending.
-        # Exponential backoff would also certainly help here.
-        while True:
-            await anyio.sleep(1 + random())
-            for event in self._out_for_delivery.copy().values():
-                await self.local_event_sender.send(event)
-
-    async def _clear_ofd(self) -> None:
-        with self._global_event_receiver as events:
-            async for event in events:
-                self._out_for_delivery.pop(event.event.event_id, None)
-
     async def _command_processor(self) -> None:
         with self.download_command_receiver as commands:
             async for cmd in commands:
@@ -355,23 +323,6 @@ class DownloadCoordinator:
             )
             del self.download_status[model_id]
 
-    async def _forward_events(self) -> None:
-        idx = 0
-        with self.event_receiver as events:
-            async for event in events:
-                fe = LocalForwarderEvent(
-                    origin_idx=idx,
-                    origin=self._system_id,
-                    session=self.session_id,
-                    event=event,
-                )
-                idx += 1
-                logger.debug(
-                    f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
-                )
-                await self.local_event_sender.send(fe)
-                self._out_for_delivery[event.event_id] = fe
-
     async def _emit_existing_download_progress(self) -> None:
         try:
             while True:
diff --git a/src/exo/download/tests/test_coordinator_ack.py b/src/exo/download/tests/test_coordinator_ack.py
deleted file mode 100644
index e99ec9126..000000000
--- a/src/exo/download/tests/test_coordinator_ack.py
+++ /dev/null
@@ -1,98 +0,0 @@
-from typing import Any
-
-import anyio
-import pytest
-
-from exo.download.coordinator import DownloadCoordinator
-from exo.download.shard_downloader import NoopShardDownloader
-from exo.shared.models.model_cards import ModelCard, ModelTask
-from exo.shared.types.common import ModelId, NodeId, SessionId
-from exo.shared.types.events import (
-    GlobalForwarderEvent,
-    LocalForwarderEvent,
-    NodeDownloadProgress,
-)
-from exo.shared.types.memory import Memory
-from exo.shared.types.worker.downloads import (
-    DownloadPending,
-)
-from exo.shared.types.worker.shards import PipelineShardMetadata
-from exo.utils.channels import channel
-
-# Use the built‑in NoopShardDownloader directly – it already implements the required abstract interface.
-# No additional subclass is needed for this test.
-
-
-@pytest.mark.anyio
-async def test_ack_behaviour():
-    # Create channels (type Any for simplicity)
-    _, command_receiver = channel[Any]()
-    local_sender, _ = channel[Any]()
-    global_sender, global_receiver = channel[Any]()
-
-    # Minimal identifiers
-    node_id = NodeId()
-    session_id = SessionId(master_node_id=node_id, election_clock=0)
-
-    # Create a dummy model card and shard metadata
-    model_id = ModelId("test/model")
-    model_card = ModelCard(
-        model_id=model_id,
-        storage_size=Memory.from_bytes(0),
-        n_layers=1,
-        hidden_size=1,
-        supports_tensor=True,
-        tasks=[ModelTask.TextGeneration],
-    )
-    shard = PipelineShardMetadata(
-        model_card=model_card,
-        device_rank=0,
-        world_size=1,
-        start_layer=0,
-        end_layer=1,
-        n_layers=1,
-    )
-
-    # Instantiate the coordinator with the dummy downloader
-    coord = DownloadCoordinator(
-        node_id=node_id,
-        session_id=session_id,
-        shard_downloader=NoopShardDownloader(),
-        download_command_receiver=command_receiver,
-        local_event_sender=local_sender,
-        _global_event_receiver=global_receiver,
-    )
-
-    async with anyio.create_task_group() as tg:
-        # Start the forwarding and ack‑clearing loops
-        tg.start_soon(coord._forward_events)  # pyright: ignore[reportPrivateUsage]
-        tg.start_soon(coord._clear_ofd)  # pyright: ignore[reportPrivateUsage]
-
-        # Send a pending download progress event via the internal event sender
-        pending = DownloadPending(
-            node_id=node_id,
-            shard_metadata=shard,
-            model_directory="/tmp/model",
-        )
-        await coord.event_sender.send(NodeDownloadProgress(download_progress=pending))
-        # Allow the forwarder to process the event
-        await anyio.sleep(0.1)
-
-        # There should be exactly one entry awaiting ACK
-        assert len(coord._out_for_delivery) == 1  # pyright: ignore[reportPrivateUsage]
-        # Retrieve the stored LocalForwarderEvent
-        stored_fe: LocalForwarderEvent = next(iter(coord._out_for_delivery.values()))  # pyright: ignore[reportPrivateUsage]
-        # Simulate receiving a global ack for this event
-        ack = GlobalForwarderEvent(
-            origin_idx=0,
-            origin=node_id,
-            session=session_id,
-            event=stored_fe.event,
-        )
-        await global_sender.send(ack)
-        # Give the clear‑ofd task a moment to process the ack
-        await anyio.sleep(0.1)
-        # The out‑for‑delivery map should now be empty
-        assert len(coord._out_for_delivery) == 0  # pyright: ignore[reportPrivateUsage]
-        # Cancel background tasks
-        tg.cancel_scope.cancel()
diff --git a/src/exo/main.py b/src/exo/main.py
index b98f92272..f1c439fb8 100644
--- a/src/exo/main.py
+++ b/src/exo/main.py
@@ -15,6 +15,7 @@ from exo.download.coordinator import DownloadCoordinator
 from exo.download.impl_shard_downloader import exo_shard_downloader
 from exo.master.api import API  # TODO: should API be in master?
 from exo.master.main import Master
+from exo.routing.event_router import EventRouter
 from exo.routing.router import Router, get_node_id_keypair
 from exo.shared.constants import EXO_LOG
 from exo.shared.election import Election, ElectionResult
@@ -29,6 +30,7 @@ from exo.worker.main import Worker
 @dataclass
 class Node:
     router: Router
+    event_router: EventRouter
     download_coordinator: DownloadCoordinator | None
     worker: Worker | None
     election: Election  # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
@@ -52,6 +54,12 @@ class Node:
         await router.register_topic(topics.ELECTION_MESSAGES)
         await router.register_topic(topics.CONNECTION_MESSAGES)
         await router.register_topic(topics.DOWNLOAD_COMMANDS)
+        event_router = EventRouter(
+            session_id,
+            command_sender=router.sender(topics.COMMANDS),
+            external_outbound=router.sender(topics.LOCAL_EVENTS),
+            external_inbound=router.receiver(topics.GLOBAL_EVENTS),
+        )
 
         logger.info(f"Starting node {node_id}")
 
@@ -59,13 +67,10 @@ class Node:
         if not args.no_downloads:
             download_coordinator = DownloadCoordinator(
                 node_id,
-                session_id,
                 exo_shard_downloader(),
+                event_sender=event_router.sender(),
                 download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
-                local_event_sender=router.sender(topics.LOCAL_EVENTS),
                 offline=args.offline,
-                # TODO(evan): remove
-                _global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
             )
         else:
             download_coordinator = None
@@ -73,9 +78,8 @@ class Node:
         if args.spawn_api:
             api = API(
                 node_id,
-                session_id,
                 port=args.api_port,
-                global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
+                event_receiver=event_router.receiver(),
                 command_sender=router.sender(topics.COMMANDS),
                 download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
                 election_receiver=router.receiver(topics.ELECTION_MESSAGES),
@@ -86,9 +90,8 @@ class Node:
         if not args.no_worker:
             worker = Worker(
                 node_id,
-                session_id,
-                global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
-                local_event_sender=router.sender(topics.LOCAL_EVENTS),
+                event_receiver=event_router.receiver(),
+                event_sender=event_router.sender(),
                 command_sender=router.sender(topics.COMMANDS),
                 download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
             )
@@ -99,6 +102,7 @@ class Node:
             master = Master(
                 node_id,
                 session_id,
+                event_sender=event_router.sender(),
                 global_event_sender=router.sender(topics.GLOBAL_EVENTS),
                 local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
                 command_receiver=router.receiver(topics.COMMANDS),
@@ -121,6 +125,7 @@ class Node:
         return cls(
             router,
+            event_router,
             download_coordinator,
             worker,
             election,
@@ -136,6 +141,7 @@ class Node:
             signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
             signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
             tg.start_soon(self.router.run)
+            tg.start_soon(self.event_router.run)
             tg.start_soon(self.election.run)
             if self.download_coordinator:
                 tg.start_soon(self.download_coordinator.run)
@@ -183,6 +189,7 @@ class Node:
             self.master = Master(
                 self.node_id,
                 result.session_id,
+                event_sender=self.event_router.sender(),
                 global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
                 local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
                 command_receiver=self.router.receiver(topics.COMMANDS),
@@ -206,21 +213,24 @@ class Node:
                 )
             if result.is_new_master:
                 await anyio.sleep(0)
+                self.event_router.shutdown()
+                self.event_router = EventRouter(
+                    result.session_id,
+                    self.router.sender(topics.COMMANDS),
+                    self.router.receiver(topics.GLOBAL_EVENTS),
+                    self.router.sender(topics.LOCAL_EVENTS),
+                )
+                self._tg.start_soon(self.event_router.run)
                 if self.download_coordinator:
                     self.download_coordinator.shutdown()
                     self.download_coordinator = DownloadCoordinator(
                         self.node_id,
-                        result.session_id,
                         exo_shard_downloader(),
+                        event_sender=self.event_router.sender(),
                         download_command_receiver=self.router.receiver(
                             topics.DOWNLOAD_COMMANDS
                         ),
-                        local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
                        offline=self.offline,
-                        # TODO(evan): remove
-                        _global_event_receiver=self.router.receiver(
-                            topics.GLOBAL_EVENTS
-                        ),
                     )
                    self._tg.start_soon(self.download_coordinator.run)
                if self.worker:
@@ -228,11 +238,8 @@ class Node:
                    # TODO: add profiling etc to resource monitor
                    self.worker = Worker(
                        self.node_id,
-                        result.session_id,
-                        global_event_receiver=self.router.receiver(
-                            topics.GLOBAL_EVENTS
-                        ),
-                        local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
+                        event_receiver=self.event_router.receiver(),
+                        event_sender=self.event_router.sender(),
                        command_sender=self.router.sender(topics.COMMANDS),
                        download_command_sender=self.router.sender(
                            topics.DOWNLOAD_COMMANDS
@@ -240,7 +247,7 @@ class Node:
                    )
                    self._tg.start_soon(self.worker.run)
                if self.api:
-                    self.api.reset(result.session_id, result.won_clock)
+                    self.api.reset(result.won_clock, self.event_router.receiver())
            else:
                if self.api:
                    self.api.unpause(result.won_clock)
diff --git a/src/exo/master/api.py b/src/exo/master/api.py
index a1c3f3613..c944b8596 100644
--- a/src/exo/master/api.py
+++ b/src/exo/master/api.py
@@ -140,11 +140,10 @@ from exo.shared.types.commands import (
     TaskFinished,
     TextGeneration,
 )
-from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
+from exo.shared.types.common import CommandId, Id, NodeId, SystemId
 from exo.shared.types.events import (
     ChunkGenerated,
     Event,
-    GlobalForwarderEvent,
     IndexedEvent,
     TracesMerged,
 )
@@ -172,7 +171,6 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
 from exo.shared.types.worker.shards import Sharding
 from exo.utils.banner import print_startup_banner
 from exo.utils.channels import Receiver, Sender, channel
-from exo.utils.event_buffer import OrderedBuffer
 from exo.utils.lazy_task_group import LazyTaskGroup
 
 _API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api"
@@ -196,10 +194,9 @@ class API:
     def __init__(
         self,
         node_id: NodeId,
-        session_id: SessionId,
         *,
         port: int,
-        global_event_receiver: Receiver[GlobalForwarderEvent],
+        event_receiver: Receiver[IndexedEvent],
         command_sender: Sender[ForwarderCommand],
         download_command_sender: Sender[ForwarderDownloadCommand],
         # This lets us pause the API if an election is running
@@ -210,11 +207,9 @@ class API:
         self._system_id = SystemId()
         self.command_sender = command_sender
         self.download_command_sender = download_command_sender
-        self.global_event_receiver = global_event_receiver
+        self.event_receiver = event_receiver
         self.election_receiver = election_receiver
-        self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
         self.node_id: NodeId = node_id
-        self.session_id: SessionId = session_id
 
         self.last_completed_election: int = 0
         self.port = port
@@ -254,17 +249,18 @@ class API:
         self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
         self._tg: LazyTaskGroup = LazyTaskGroup()
 
-    def reset(self, new_session_id: SessionId, result_clock: int):
+    def reset(self, result_clock: int, event_receiver: Receiver[IndexedEvent]):
         logger.info("Resetting API State")
         self._event_log.close()
         self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
         self.state = State()
         self._system_id = SystemId()
-        self.session_id = new_session_id
-        self.event_buffer = OrderedBuffer[Event]()
         self._text_generation_queues = {}
         self._image_generation_queues = {}
         self.unpause(result_clock)
+        self.event_receiver.close()
+        self.event_receiver = event_receiver
+        self._tg.start_soon(self._apply_state)
 
     def unpause(self, result_clock: int):
         logger.info("Unpausing API")
@@ -1606,7 +1602,7 @@ class API:
         finally:
             self._event_log.close()
             self.command_sender.close()
-            self.global_event_receiver.close()
+            self.event_receiver.close()
 
     async def run_api(self, ev: anyio.Event):
        cfg = Config()
@@ -1623,38 +1619,33 @@ class API:
        )

    async def _apply_state(self):
-        with self.global_event_receiver as events:
-            async for f_event in events:
-                if f_event.session != self.session_id:
-                    continue
-                if f_event.origin != self.session_id.master_node_id:
-                    continue
-                self.event_buffer.ingest(f_event.origin_idx, f_event.event)
-                for idx, event in self.event_buffer.drain_indexed():
-                    self._event_log.append(event)
-                    self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
+        idx = 0
+        with self.event_receiver as events:
+            async for event in events:
+                self._event_log.append(event.event)
+                self.state = apply(self.state, event)
+                idx += 1
+                event = event.event

-                    if isinstance(event, ChunkGenerated):
-                        if queue := self._image_generation_queues.get(
-                            event.command_id, None
-                        ):
-                            assert isinstance(event.chunk, ImageChunk)
-                            try:
-                                await queue.send(event.chunk)
-                            except BrokenResourceError:
-                                self._image_generation_queues.pop(
-                                    event.command_id, None
-                                )
-                        if queue := self._text_generation_queues.get(
-                            event.command_id, None
-                        ):
-                            assert not isinstance(event.chunk, ImageChunk)
-                            try:
-                                await queue.send(event.chunk)
-                            except BrokenResourceError:
-                                self._text_generation_queues.pop(event.command_id, None)
-                        if isinstance(event, TracesMerged):
-                            self._save_merged_trace(event)
+                if isinstance(event, ChunkGenerated):
+                    if queue := self._image_generation_queues.get(
+                        event.command_id, None
+                    ):
+                        assert isinstance(event.chunk, ImageChunk)
+                        try:
+                            await queue.send(event.chunk)
+                        except BrokenResourceError:
+                            self._image_generation_queues.pop(event.command_id, None)
+                    if queue := self._text_generation_queues.get(
+                        event.command_id, None
+                    ):
+                        assert not isinstance(event.chunk, ImageChunk)
+                        try:
+                            await queue.send(event.chunk)
+                        except BrokenResourceError:
+                            self._text_generation_queues.pop(event.command_id, None)
+                if isinstance(event, TracesMerged):
+                    self._save_merged_trace(event)

    def _save_merged_trace(self, event: TracesMerged) -> None:
        traces = [
diff --git a/src/exo/master/main.py b/src/exo/master/main.py
index 74230f7ba..06b991840 100644
--- a/src/exo/master/main.py
+++ b/src/exo/master/main.py
@@ -60,7 +60,7 @@ from exo.shared.types.tasks import (
    TextGeneration as TextGenerationTask,
)
 from exo.shared.types.worker.instances import InstanceId
-from exo.utils.channels import Receiver, Sender, channel
+from exo.utils.channels import Receiver, Sender
 from exo.utils.event_buffer import MultiSourceBuffer
 from exo.utils.lazy_task_group import LazyTaskGroup

@@ -72,25 +72,21 @@ class Master:
    def __init__(
        self,
        node_id: NodeId,
        session_id: SessionId,
        *,
        command_receiver: Receiver[ForwarderCommand],
+        event_sender: Sender[Event],
        local_event_receiver: Receiver[LocalForwarderEvent],
        global_event_sender: Sender[GlobalForwarderEvent],
        download_command_sender: Sender[ForwarderDownloadCommand],
    ):
-        self.state = State()
-        self._tg: LazyTaskGroup = LazyTaskGroup()
        self.node_id = node_id
        self.session_id = session_id
+        self.state = State()
+        self._tg: LazyTaskGroup = LazyTaskGroup()
        self.command_task_mapping: dict[CommandId, TaskId] = {}
        self.command_receiver = command_receiver
        self.local_event_receiver = local_event_receiver
        self.global_event_sender = global_event_sender
        self.download_command_sender = download_command_sender
-        send, recv = channel[Event]()
-        self.event_sender: Sender[Event] = send
-        self._loopback_event_receiver: Receiver[Event] = recv
-        self._loopback_event_sender: Sender[LocalForwarderEvent] = (
-            local_event_receiver.clone_sender()
-        )
+        self.event_sender = event_sender

        self._system_id = SystemId()
        self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
        self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
@@ -104,15 +100,12 @@ class Master:
            async with self._tg as tg:
                tg.start_soon(self._event_processor)
                tg.start_soon(self._command_processor)
-                tg.start_soon(self._loopback_processor)
                tg.start_soon(self._plan)
        finally:
            self._event_log.close()
            self.global_event_sender.close()
            self.local_event_receiver.close()
            self.command_receiver.close()
-            self._loopback_event_sender.close()
-            self._loopback_event_receiver.close()

    async def shutdown(self):
        logger.info("Stopping Master")
@@ -409,22 +402,6 @@ class Master:
            self._event_log.append(event)
            await self._send_event(indexed)

-    async def _loopback_processor(self) -> None:
-        # this would ideally not be necessary.
-        # this is WAY less hacky than how I was working around this before
-        local_index = 0
-        with self._loopback_event_receiver as events:
-            async for event in events:
-                await self._loopback_event_sender.send(
-                    LocalForwarderEvent(
-                        origin=self._system_id,
-                        origin_idx=local_index,
-                        session=self.session_id,
-                        event=event,
-                    )
-                )
-                local_index += 1
-
    # This function is re-entrant, take care!
    async def _send_event(self, event: IndexedEvent):
        # Convenience method since this line is ugly
diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py
index 21fabb215..f6f26dc94 100644
--- a/src/exo/master/tests/test_master.py
+++ b/src/exo/master/tests/test_master.py
@@ -17,6 +17,7 @@ from exo.shared.types.commands import (
 )
 from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
 from exo.shared.types.events import (
+    Event,
     GlobalForwarderEvent,
     IndexedEvent,
     InstanceCreated,
@@ -50,6 +51,7 @@ async def test_master():
    command_sender, co_receiver = channel[ForwarderCommand]()
    local_event_sender, le_receiver = channel[LocalForwarderEvent]()
    fcds, _fcdr = channel[ForwarderDownloadCommand]()
+    ev_send, _ev_recv = channel[Event]()

    all_events: list[IndexedEvent] = []

@@ -67,6 +69,7 @@ async def test_master():
    master = Master(
        node_id,
        session_id,
+        event_sender=ev_send,
        global_event_sender=ge_sender,
        local_event_receiver=le_receiver,
        command_receiver=co_receiver,
diff --git a/src/exo/routing/event_router.py b/src/exo/routing/event_router.py
index b27c6b1c2..4eac6f328 100644
--- a/src/exo/routing/event_router.py
+++ b/src/exo/routing/event_router.py
@@ -1,23 +1,47 @@
 from dataclasses import dataclass, field
 from random import random
+
 import anyio
-from anyio import ClosedResourceError, BrokenResourceError
-from exo.utils.channels import Sender, Receiver
-from exo.utils.lazy_task_group import LazyTaskGroup
-from exo.utils.event_buffer import OrderedBuffer
+from anyio import BrokenResourceError, ClosedResourceError
+from anyio.abc import CancelScope
+from loguru import logger
+
+from exo.shared.types.commands import ForwarderCommand, RequestEventLog
 from exo.shared.types.common import SessionId, SystemId
-from exo.shared.types.events import LocalForwarderEvent, GlobalForwarderEvent, Event, EventId
+from exo.shared.types.events import (
+    Event,
+    EventId,
+    GlobalForwarderEvent,
+    IndexedEvent,
+    LocalForwarderEvent,
+)
+from exo.utils.channels import Receiver, Sender, channel
+from exo.utils.event_buffer import OrderedBuffer
+from exo.utils.lazy_task_group import LazyTaskGroup
+

 @dataclass
 class EventRouter:
-    _tg: LazyTaskGroup = field(init=False, default_factory=LazyTaskGroup)
-    internal_outbound: dict[SystemId, Sender[Event]] = field(init=False,default_factory=dict)
+    session_id: SessionId
+    command_sender: Sender[ForwarderCommand]
     external_inbound: Receiver[GlobalForwarderEvent]
     external_outbound: Sender[LocalForwarderEvent]
-    event_buffer: OrderedBuffer[Event]
-    session_id: SessionId
-    out_for_delivery: dict[EventId, tuple[float, LocalForwarderEvent]] = field(init=False, default_factory=dict)
+
+    _system_id: SystemId = field(init=False, default_factory=SystemId)
+    internal_outbound: list[Sender[IndexedEvent]] = field(
+        init=False, default_factory=list
+    )
+    event_buffer: OrderedBuffer[Event] = field(
+        init=False, default_factory=OrderedBuffer
+    )
+    out_for_delivery: dict[EventId, tuple[float, LocalForwarderEvent]] = field(
+        init=False, default_factory=dict
+    )
+    _tg: LazyTaskGroup = field(init=False, default_factory=LazyTaskGroup)
+    _nack_cancel_scope: CancelScope | None = field(init=False, default=None)
+    _nack_attempts: int = field(init=False, default=0)
+    _nack_base_seconds: float = field(init=False, default=0.5)
+    _nack_cap_seconds: float = field(init=False, default=10.0)
 
     async def run(self):
         try:
@@ -26,7 +50,7 @@ class EventRouter:
                 tg.start_soon(self._simple_retry)
         finally:
             self.external_outbound.close()
-            for send in self.internal_outbound.values():
+            for send in self.internal_outbound:
                 send.close()
 
     # can make this better in future
@@ -34,27 +58,40 @@ class EventRouter:
         while True:
             await anyio.sleep(1 + random())
             # list here is a shallow clone for shared mutation
-            for id, (time, event) in list(self.out_for_delivery.items()):
+            for e_id, (time, event) in list(self.out_for_delivery.items()):
                 if anyio.current_time() > time + 5:
-                    self.out_for_delivery[id] = (anyio.current_time(), event)
+                    self.out_for_delivery[e_id] = (anyio.current_time(), event)
                     await self.external_outbound.send(event)
 
+    def sender(self) -> Sender[Event]:
+        send, recv = channel[Event]()
+        self._tg.start_soon(self._ingest, SystemId(), recv)
+        return send
 
-    def ingest(self, system_id: SystemId, recv: Receiver[Event]):
-        self._tg.start_soon(self._ingest, system_id, recv)
+    def receiver(self) -> Receiver[IndexedEvent]:
+        send, recv = channel[IndexedEvent]()
+        self.internal_outbound.append(send)
+        return recv
+
+    def shutdown(self) -> None:
+        self._tg.cancel_scope.cancel()
 
     async def _ingest(self, system_id: SystemId, recv: Receiver[Event]):
         idx = 0
         with recv as events:
             async for event in events:
-                f_ev = LocalForwarderEvent(origin_idx = idx, origin=system_id, session=self.session_id, event=event)
+                f_ev = LocalForwarderEvent(
+                    origin_idx=idx,
+                    origin=system_id,
+                    session=self.session_id,
+                    event=event,
+                )
                 idx += 1
                 await self.external_outbound.send(f_ev)
                 self.out_for_delivery[event.event_id] = (anyio.current_time(), f_ev)
-
-
     async def _run_ext_in(self):
+        buf = OrderedBuffer[Event]()
         with self.external_inbound as events:
             async for event in events:
                 if event.session != self.session_id:
@@ -62,19 +99,60 @@ class EventRouter:
                 if event.origin != self.session_id.master_node_id:
                     continue
 
-                self.event_buffer.ingest(event.origin_idx, event.event)
+                buf.ingest(event.origin_idx, event.event)
                 event_id = event.event.event_id
                 if event_id in self.out_for_delivery:
                     self.out_for_delivery.pop(event_id)
 
-                for event in self.event_buffer.drain():
-                    to_clear = set[SystemId]()
-                    for s_id, sender in self.internal_outbound.items():
+                drained = buf.drain_indexed()
+                if drained:
+                    self._nack_attempts = 0
+                    if self._nack_cancel_scope:
+                        self._nack_cancel_scope.cancel()
+
+                if not drained and (
+                    self._nack_cancel_scope is None
+                    or self._nack_cancel_scope.cancel_called
+                ):
+                    # Request the next index.
+                    self._tg.start_soon(self._nack_request, buf.next_idx_to_release)
+                    continue
+
+                for idx, event in drained:
+                    to_clear = set[int]()
+                    for i, sender in enumerate(self.internal_outbound):
                         try:
-                            await sender.send(event)
+                            await sender.send(IndexedEvent(idx=idx, event=event))
                         except (ClosedResourceError, BrokenResourceError):
-                            to_clear.add(s_id)
-                    for s_id in to_clear:
-                        self.internal_outbound.pop(s_id)
+                            to_clear.add(i)
+                    for i in sorted(to_clear, reverse=True):
+                        self.internal_outbound.pop(i)
+
+    async def _nack_request(self, since_idx: int) -> None:
+        # We request all events after (and including) the missing index.
+        # This function is started whenever we receive an event that is out of sequence.
+        # It is cancelled as soon as we receive an event that is in sequence.
+        if since_idx < 0:
+            logger.warning(f"Negative value encountered for nack request {since_idx=}")
+            since_idx = 0
+
+        with CancelScope() as scope:
+            self._nack_cancel_scope = scope
+            delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
+            delay = min(self._nack_cap_seconds, delay)
+            self._nack_attempts += 1
+            try:
+                await anyio.sleep(delay)
+                logger.info(
+                    f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
+                )
+                await self.command_sender.send(
+                    ForwarderCommand(
+                        origin=self._system_id,
+                        command=RequestEventLog(since_idx=since_idx),
+                    )
+                )
+            finally:
+                if self._nack_cancel_scope is scope:
+                    self._nack_cancel_scope = None
diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py
index df0060a12..ce5643531 100644
--- a/src/exo/worker/main.py
+++ b/src/exo/worker/main.py
@@ -1,9 +1,8 @@
 from collections import defaultdict
 from datetime import datetime, timezone
-from random import random
 
 import anyio
-from anyio import CancelScope, fail_after
+from anyio import fail_after
 from loguru import logger
 
 from exo.download.download_utils import resolve_model_in_path
@@ -13,14 +12,12 @@ from exo.shared.types.api import ImageEditsTaskParams
 from exo.shared.types.commands import (
     ForwarderCommand,
     ForwarderDownloadCommand,
-    RequestEventLog,
     StartDownload,
 )
-from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
+from exo.shared.types.common import CommandId, NodeId, SystemId
 from exo.shared.types.events import (
     Event,
     EventId,
-    GlobalForwarderEvent,
     IndexedEvent,
     InputChunkReceived,
     LocalForwarderEvent,
@@ -46,7 +43,6 @@ from exo.shared.types.topology import Connection, SocketConnection
 from exo.shared.types.worker.downloads import DownloadCompleted
 from exo.shared.types.worker.runners import RunnerId
 from exo.utils.channels import Receiver, Sender, channel
-from exo.utils.event_buffer import OrderedBuffer
 from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
 from exo.utils.info_gatherer.net_profile import check_reachable
 from exo.utils.keyed_backoff import KeyedBackoff
@@ -59,38 +55,27 @@ class Worker:
     def __init__(
         self,
         node_id: NodeId,
-        session_id: SessionId,
         *,
-        global_event_receiver: Receiver[GlobalForwarderEvent],
-        local_event_sender: Sender[LocalForwarderEvent],
+        event_receiver: Receiver[IndexedEvent],
+        event_sender: Sender[Event],
         # This is for requesting updates. It doesn't need to be a general command sender right now,
        # but I think it's the correct way to be thinking about commands
        command_sender: Sender[ForwarderCommand],
        download_command_sender: Sender[ForwarderDownloadCommand],
    ):
        self.node_id: NodeId = node_id
-        self.session_id: SessionId = session_id
-
-        self.global_event_receiver = global_event_receiver
-        self.local_event_sender = local_event_sender
+        self.event_receiver = event_receiver
+        self.event_sender = event_sender
        self.command_sender = command_sender
        self.download_command_sender = download_command_sender

-        self.event_buffer = OrderedBuffer[Event]()
        self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}

        self.state: State = State()
        self.runners: dict[RunnerId, RunnerSupervisor] = {}
        self._tg: LazyTaskGroup = LazyTaskGroup()

-        self._nack_cancel_scope: CancelScope | None = None
-        self._nack_attempts: int = 0
-        self._nack_base_seconds: float = 0.5
-        self._nack_cap_seconds: float = 10.0
-        self._system_id = SystemId()
-        self.event_sender, self.event_receiver = channel[Event]()
-
        # Buffer for input image chunks (for image editing)
        self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
        self.input_chunk_counts: dict[CommandId, int] = {}
@@ -108,14 +93,12 @@ class Worker:
                tg.start_soon(info_gatherer.run)
                tg.start_soon(self._forward_info, info_recv)
                tg.start_soon(self.plan_step)
-                tg.start_soon(self._resend_out_for_delivery)
                tg.start_soon(self._event_applier)
-                tg.start_soon(self._forward_events)
                tg.start_soon(self._poll_connection_updates)
        finally:
            # Actual shutdown code - waits for all tasks to complete before executing.
            logger.info("Stopping Worker")
-            self.local_event_sender.close()
+            self.event_sender.close()
            self.command_sender.close()
            self.download_command_sender.close()
            for runner in self.runners.values():
@@ -133,47 +116,22 @@ class Worker:
        )

    async def _event_applier(self):
-        with self.global_event_receiver as events:
-            async for f_event in events:
-                if f_event.session != self.session_id:
-                    continue
-                if f_event.origin != self.session_id.master_node_id:
-                    continue
-                self.event_buffer.ingest(f_event.origin_idx, f_event.event)
-                event_id = f_event.event.event_id
-                if event_id in self.out_for_delivery:
-                    del self.out_for_delivery[event_id]
-
+        with self.event_receiver as events:
+            async for event in events:
                # 2. for each event, apply it to the state
-                indexed_events = self.event_buffer.drain_indexed()
-                if indexed_events:
-                    self._nack_attempts = 0
+                self.state = apply(self.state, event=event)
+                event = event.event

-                if not indexed_events and (
-                    self._nack_cancel_scope is None
-                    or self._nack_cancel_scope.cancel_called
-                ):
-                    # Request the next index.
-                    self._tg.start_soon(
-                        self._nack_request, self.state.last_event_applied_idx + 1
+                # Buffer input image chunks for image editing
+                if isinstance(event, InputChunkReceived):
+                    cmd_id = event.command_id
+                    if cmd_id not in self.input_chunk_buffer:
+                        self.input_chunk_buffer[cmd_id] = {}
+                        self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
+
+                    self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
+                        event.chunk.data
                    )
-                    continue
-                elif indexed_events and self._nack_cancel_scope:
-                    self._nack_cancel_scope.cancel()
-
-                for idx, event in indexed_events:
-                    self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
-
-                    # Buffer input image chunks for image editing
-                    if isinstance(event, InputChunkReceived):
-                        cmd_id = event.command_id
-                        if cmd_id not in self.input_chunk_buffer:
-                            self.input_chunk_buffer[cmd_id] = {}
-                            self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
-
-                        self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
-                            event.chunk.data
-                        )

    async def plan_step(self):
        while True:
@@ -325,43 +283,6 @@ class Worker:
                instance.shard_assignments.node_to_runner[self.node_id]
            ].start_task(task)

-    async def _nack_request(self, since_idx: int) -> None:
-        # We request all events after (and including) the missing index.
-        # This function is started whenever we receive an event that is out of sequence.
-        # It is cancelled as soon as we receiver an event that is in sequence.
-
-        if since_idx < 0:
-            logger.warning(f"Negative value encountered for nack request {since_idx=}")
-            since_idx = 0
-
-        with CancelScope() as scope:
-            self._nack_cancel_scope = scope
-            delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
-            delay = min(self._nack_cap_seconds, delay)
-            self._nack_attempts += 1
-            try:
-                await anyio.sleep(delay)
-                logger.info(
-                    f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
-                )
-                await self.command_sender.send(
-                    ForwarderCommand(
-                        origin=self._system_id,
-                        command=RequestEventLog(since_idx=since_idx),
-                    )
-                )
-            finally:
-                if self._nack_cancel_scope is scope:
-                    self._nack_cancel_scope = None
-
-    async def _resend_out_for_delivery(self) -> None:
-        # This can also be massively tightened, we should check events are at least a certain age before resending.
-        # Exponential backoff would also certainly help here.
-        while True:
-            await anyio.sleep(1 + random())
-            for event in self.out_for_delivery.copy().values():
-                await self.local_event_sender.send(event)
-
    def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
        """Creates and stores a new AssignedRunner with initial downloading status."""
        runner = RunnerSupervisor.create(
@@ -372,21 +293,6 @@ class Worker:
        self._tg.start_soon(runner.run)
        return runner

-    async def _forward_events(self) -> None:
-        idx = 0
-        with self.event_receiver as events:
-            async for event in events:
-                fe = LocalForwarderEvent(
-                    origin_idx=idx,
-                    origin=self._system_id,
-                    session=self.session_id,
-                    event=event,
-                )
-                idx += 1
-                logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
-                await self.local_event_sender.send(fe)
-                self.out_for_delivery[event.event_id] = fe
-
    async def _poll_connection_updates(self):
        while True:
            edges = set(