From 10c905c8ddb9a206e57cb8efd3b7bc3480484a61 Mon Sep 17 00:00:00 2001
From: Evan Quiney
Date: Tue, 2 Dec 2025 11:35:02 +0000
Subject: [PATCH] worker no longer gets stuck after shutdown
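
Shutting down a runner could previously wedge the worker: if the runner
never picked up its Shutdown task, the worker awaited it forever. Shutdown
delivery is now bounded by a 3-second timeout and surfaced through the new
TaskStatus.TimedOut, and a runner whose event channel closes or breaks now
releases all of its pending task events so nothing keeps waiting on a dead
runner. Along the way: the very chatty election logging drops from info to
debug, MLX_METAL_FAST_SYNCH is only set for IBV instances with at least two
devices, the election tests no longer patch ELECTION_TIMEOUT through a
fast_timeout fixture, the master and placement tests pass min_nodes=1
explicitly, and `just clean` uses sudo to remove rust/target.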
---
 justfile                                   |  2 +-
 src/exo/master/tests/test_master.py        |  1 +
 src/exo/master/tests/test_placement.py     |  1 +
 src/exo/shared/election.py                 | 77 +++++++++++-----------
 src/exo/shared/tests/test_election.py      | 33 ++--------
 src/exo/shared/types/tasks.py              |  1 +
 src/exo/worker/main.py                     | 27 +++++---
 src/exo/worker/runner/bootstrap.py         |  5 +-
 src/exo/worker/runner/runner_supervisor.py |  3 +
 9 files changed, 73 insertions(+), 77 deletions(-)

diff --git a/justfile b/justfile
index 2ef99049..676e66fc 100644
--- a/justfile
+++ b/justfile
@@ -22,5 +22,5 @@ rust-rebuild:
 
 clean:
 	rm -rf **/__pycache__
-	rm -rf rust/target
+	sudo rm -rf rust/target
 	rm -rf .venv
diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py
index 5aa26d48..c5d3ae47 100644
--- a/src/exo/master/tests/test_master.py
+++ b/src/exo/master/tests/test_master.py
@@ -125,6 +125,7 @@ async def test_master():
                 ),
                 sharding=Sharding.Pipeline,
                 instance_meta=InstanceMeta.MlxRing,
+                min_nodes=1,
             )
         ),
     )
diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py
index 41cd8360..3c4fe0ee 100644
--- a/src/exo/master/tests/test_placement.py
+++ b/src/exo/master/tests/test_placement.py
@@ -57,6 +57,7 @@ def create_instance_command(model_meta: ModelMetadata) -> CreateInstance:
         model_meta=model_meta,
         sharding=Sharding.Pipeline,
         instance_meta=InstanceMeta.MlxRing,
+        min_nodes=1,
     )
 
 
diff --git a/src/exo/shared/election.py b/src/exo/shared/election.py
index 071914fa..206dcf59 100644
--- a/src/exo/shared/election.py
+++ b/src/exo/shared/election.py
@@ -94,27 +94,26 @@ class Election:
 
         # And start an election immediately, that instantly resolves
         candidates: list[ElectionMessage] = []
-        logger.info("Starting initial campaign")
+        logger.debug("Starting initial campaign")
         self._candidates = candidates
-        logger.info("Campaign started")
         await self._campaign(candidates, campaign_timeout=0.0)
-        logger.info("Initial campaign finished")
+        logger.debug("Initial campaign finished")
 
         # Cancel and wait for the last election to end
         if self._campaign_cancel_scope is not None:
-            logger.info("Cancelling campaign")
+            logger.debug("Cancelling campaign")
             self._campaign_cancel_scope.cancel()
         if self._campaign_done is not None:
-            logger.info("Waiting for campaign to finish")
+            logger.debug("Waiting for campaign to finish")
             await self._campaign_done.wait()
-            logger.info("Campaign cancelled and finished")
+            logger.debug("Campaign cancelled and finished")
         logger.info("Election finished")
 
     async def elect(self, em: ElectionMessage) -> None:
-        logger.info(f"Electing: {em}")
+        logger.debug(f"Electing: {em}")
         is_new_master = em.proposed_session != self.current_session
         self.current_session = em.proposed_session
-        logger.info(f"Current session: {self.current_session}")
+        logger.debug(f"Current session: {self.current_session}")
         await self._er_sender.send(
             ElectionResult(
                 won_clock=em.clock,
@@ -135,29 +134,29 @@ class Election:
     async def _election_receiver(self) -> None:
         with self._em_receiver as election_messages:
             async for message in election_messages:
-                logger.info(f"Election message received: {message}")
+                logger.debug(f"Election message received: {message}")
                 if message.proposed_session.master_node_id == self.node_id:
-                    logger.info("Dropping message from ourselves")
+                    logger.debug("Dropping message from ourselves")
                     # Drop messages from us (See exo.routing.router)
                     continue
                 # If a new round is starting, we participate
                 if message.clock > self.clock:
                     self.clock = message.clock
-                    logger.info(f"New clock: {self.clock}")
+                    logger.debug(f"New clock: {self.clock}")
                     assert self._tg is not None
-                    logger.info("Starting new campaign")
+                    logger.debug("Starting new campaign")
                     candidates: list[ElectionMessage] = [message]
-                    logger.info(f"Candidates: {candidates}")
-                    logger.info(f"Current candidates: {self._candidates}")
+                    logger.debug(f"Candidates: {candidates}")
+                    logger.debug(f"Current candidates: {self._candidates}")
                     self._candidates = candidates
-                    logger.info(f"New candidates: {self._candidates}")
-                    logger.info("Starting new campaign")
+                    logger.debug(f"New candidates: {self._candidates}")
+                    logger.debug("Starting new campaign")
                     self._tg.start_soon(self._campaign, candidates)
-                    logger.info("Campaign started")
+                    logger.debug("Campaign started")
                     continue
                 # Dismiss old messages
                 if message.clock < self.clock:
-                    logger.info(f"Dropping old message: {message}")
+                    logger.debug(f"Dropping old message: {message}")
                     continue
                 logger.debug(f"Election added candidate {message}")
                 # Now we are processing this round's messages - including the message that triggered this round.
@@ -170,19 +169,19 @@ class Election:
                 await anyio.sleep(0.2)
                 rest = connection_messages.collect()
-                logger.info(f"Connection messages received: {first} followed by {rest}")
-                logger.info(f"Current clock: {self.clock}")
+                logger.debug(f"Connection messages received: {first} followed by {rest}")
+                logger.debug(f"Current clock: {self.clock}")
                 # These messages are strictly peer to peer
                 self.clock += 1
-                logger.info(f"New clock: {self.clock}")
+                logger.debug(f"New clock: {self.clock}")
                 assert self._tg is not None
                 candidates: list[ElectionMessage] = []
                 self._candidates = candidates
-                logger.info("Starting new campaign")
+                logger.debug("Starting new campaign")
                 self._tg.start_soon(self._campaign, candidates)
-                logger.info("Campaign started")
+                logger.debug("Campaign started")
                 self._connection_messages.append(first)
                 self._connection_messages.extend(rest)
-                logger.info("Connection message added")
+                logger.debug("Connection message added")
 
     async def _command_counter(self) -> None:
         with self._co_receiver as commands:
@@ -210,49 +209,49 @@ class Election:
             try:
                 with scope:
-                    logger.info(f"Election {clock} started")
+                    logger.debug(f"Election {clock} started")
                     status = self._election_status(clock)
                     candidates.append(status)
                     await self._em_sender.send(status)
-                    logger.info(f"Sleeping for {campaign_timeout} seconds")
+                    logger.debug(f"Sleeping for {campaign_timeout} seconds")
                     await anyio.sleep(campaign_timeout)
                     # minor hack - rebroadcast status in case anyone has missed it.
                     await self._em_sender.send(status)
-                    logger.info("Woke up from sleep")
+                    logger.debug("Woke up from sleep")
                     # add an anyio checkpoint - anyio.lowlevel.checkpoint() or checkpoint_if_cancelled() is preferred, but wasn't typechecking last I checked
                     await anyio.sleep(0)
 
                     # Election finished!
                     elected = max(candidates)
-                    logger.info(f"Election queue {candidates}")
-                    logger.info(f"Elected: {elected}")
+                    logger.debug(f"Election queue {candidates}")
+                    logger.debug(f"Elected: {elected}")
                     if (
                         self.node_id == elected.proposed_session.master_node_id
                         and self.seniority >= 0
                     ):
-                        logger.info(
+                        logger.debug(
                             f"Node is a candidate and seniority is {self.seniority}"
                         )
                         self.seniority = max(self.seniority, len(candidates))
-                        logger.info(f"New seniority: {self.seniority}")
+                        logger.debug(f"New seniority: {self.seniority}")
                     else:
-                        logger.info(
+                        logger.debug(
                             f"Node is not a candidate or seniority is not {self.seniority}"
                         )
-                    logger.info(
+                    logger.debug(
                         f"Election finished, new SessionId({elected.proposed_session}) with queue {candidates}"
                     )
-                    logger.info("Sending election result")
+                    logger.debug("Sending election result")
                     await self.elect(elected)
-                    logger.info("Election result sent")
+                    logger.debug("Election result sent")
             except get_cancelled_exc_class():
-                logger.info(f"Election {clock} cancelled")
+                logger.debug(f"Election {clock} cancelled")
             finally:
-                logger.info(f"Election {clock} finally")
+                logger.debug(f"Election {clock} finally")
                 if self._campaign_cancel_scope is scope:
                     self._campaign_cancel_scope = None
-                logger.info("Setting done event")
+                logger.debug("Setting done event")
                 done.set()
-                logger.info("Done event set")
+                logger.debug("Done event set")
 
     def _election_status(self, clock: int | None = None) -> ElectionMessage:
         c = self.clock if clock is None else clock
diff --git a/src/exo/shared/tests/test_election.py b/src/exo/shared/tests/test_election.py
index ae8c833f..894c55ce 100644
--- a/src/exo/shared/tests/test_election.py
+++ b/src/exo/shared/tests/test_election.py
@@ -36,24 +36,13 @@ def em(
     )
 
 
-@pytest.fixture
-def fast_timeout(monkeypatch: pytest.MonkeyPatch):
-    # Keep campaigns fast; user explicitly allows tests to shorten the timeout.
-    import exo.shared.election as election_mod
-
-    monkeypatch.setattr(election_mod, "ELECTION_TIMEOUT", 0.05, raising=True)
-    yield
-
-
 # ======================================= #
 # TESTS                                   #
 # ======================================= #
 
 
 @pytest.mark.anyio
-async def test_single_round_broadcasts_and_updates_seniority_on_self_win(
-    fast_timeout: None,
-) -> None:
+async def test_single_round_broadcasts_and_updates_seniority_on_self_win() -> None:
     """
     Start a round by injecting an ElectionMessage with higher clock.
     With only our node effectively 'winning', we should broadcast once and update seniority.
@@ -109,9 +98,7 @@
 
 
 @pytest.mark.anyio
-async def test_peer_with_higher_seniority_wins_and_we_switch_master(
-    fast_timeout: None,
-) -> None:
+async def test_peer_with_higher_seniority_wins_and_we_switch_master() -> None:
     """
     If a peer with clearly higher seniority participates in the round, they should win.
     We should broadcast our status exactly once for this round, then switch master.
@@ -165,7 +152,7 @@
 
 
 @pytest.mark.anyio
-async def test_ignores_older_messages(fast_timeout: None) -> None:
+async def test_ignores_older_messages() -> None:
     """
     Messages with a lower clock than the current round are ignored by the receiver.
     Expect exactly one broadcast for the higher clock round.
@@ -214,9 +201,7 @@
 
 
 @pytest.mark.anyio
-async def test_two_rounds_emit_two_broadcasts_and_increment_clock(
-    fast_timeout: None,
-) -> None:
+async def test_two_rounds_emit_two_broadcasts_and_increment_clock() -> None:
     """
     Two successive rounds → two broadcasts. Second round triggered by a higher-clock message.
     """
@@ -262,7 +247,7 @@
 
 
 @pytest.mark.anyio
-async def test_promotion_new_seniority_counts_participants(fast_timeout: None) -> None:
+async def test_promotion_new_seniority_counts_participants() -> None:
     """
     When we win against two peers in the same round, our seniority becomes
     max(existing, number_of_candidates). With existing=0: expect 3 (us + A + B).
@@ -311,9 +296,7 @@
 
 
 @pytest.mark.anyio
-async def test_connection_message_triggers_new_round_broadcast(
-    fast_timeout: None,
-) -> None:
+async def test_connection_message_triggers_new_round_broadcast() -> None:
     """
     A connection message increments the clock and starts a new campaign.
     We should observe a broadcast at the incremented clock.
@@ -365,9 +348,7 @@
 
 
 @pytest.mark.anyio
-async def test_tie_breaker_prefers_node_with_more_commands_seen(
-    fast_timeout: None,
-) -> None:
+async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
     """
     With equal seniority, the node that has seen more commands should win the election.
     We increase our local 'commands_seen' by sending TestCommand()s before triggering the round.
diff --git a/src/exo/shared/types/tasks.py b/src/exo/shared/types/tasks.py
index 40fb1611..4951bc4a 100644
--- a/src/exo/shared/types/tasks.py
+++ b/src/exo/shared/types/tasks.py
@@ -18,6 +18,7 @@ class TaskStatus(str, Enum):
     Pending = "Pending"
     Running = "Running"
     Complete = "Complete"
+    TimedOut = "TimedOut"
     Failed = "Failed"
 
 
diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py
index 073b1dbb..22df4d66 100644
--- a/src/exo/worker/main.py
+++ b/src/exo/worker/main.py
@@ -1,7 +1,7 @@
 from random import random
 
 import anyio
-from anyio import CancelScope, create_task_group, current_time
+from anyio import CancelScope, create_task_group, current_time, fail_after
 from anyio.abc import TaskGroup
 from loguru import logger
 
@@ -184,6 +184,7 @@ class Worker:
         assert task.task_status
         await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
 
+        # let's not kill the worker if a runner is unresponsive
         match task:
             case CreateRunner():
                 self._create_supervisor(task)
@@ -201,11 +202,8 @@ class Worker:
                     await self.event_sender.send(
                         NodeDownloadProgress(download_progress=progress)
                     )
-
-                initial_progress = (
-                    await self.shard_downloader.get_shard_download_status_for_shard(
-                        shard
-                    )
+                initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(
+                    shard
                 )
                 if initial_progress.status == "complete":
                     progress = DownloadCompleted(
@@ -217,7 +215,8 @@ class Worker:
                     )
                     await self.event_sender.send(
                         TaskStatusUpdated(
-                            task_id=task.task_id, task_status=TaskStatus.Complete
+                            task_id=task.task_id,
+                            task_status=TaskStatus.Complete,
                         )
                     )
                 else:
@@ -228,9 +227,19 @@ class Worker:
                     )
                 self._handle_shard_download_process(task, initial_progress)
             case Shutdown(runner_id=runner_id):
-                await self.runners.pop(runner_id).start_task(task)
+                try:
+                    with fail_after(3):
+                        await self.runners.pop(runner_id).start_task(task)
+                except TimeoutError:
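+                    # the runner never accepted the task; report TimedOut rather than hang the worker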
+                    await self.event_sender.send(
+                        TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.TimedOut)
+                    )
             case task:
-                await self.runners[self._task_to_runner_id(task)].start_task(task)
+                await self.runners[self._task_to_runner_id(task)].start_task(
+                    task
+                )
+
 
     def shutdown(self):
         if self._tg:
diff --git a/src/exo/worker/runner/bootstrap.py b/src/exo/worker/runner/bootstrap.py
index 22eab98a..3f703588 100644
--- a/src/exo/worker/runner/bootstrap.py
+++ b/src/exo/worker/runner/bootstrap.py
@@ -4,7 +4,7 @@ import loguru
 
 from exo.shared.types.events import Event
 from exo.shared.types.tasks import Task
-from exo.shared.types.worker.instances import BoundInstance
+from exo.shared.types.worker.instances import BoundInstance, MlxIbvInstance
 from exo.utils.channels import MpReceiver, MpSender
 
 logger: "loguru.Logger"
@@ -20,7 +20,8 @@ def entrypoint(
     task_receiver: MpReceiver[Task],
     _logger: "loguru.Logger",
 ) -> None:
-    os.environ["MLX_METAL_FAST_SYNCH"] = "1"
+    if isinstance(bound_instance.instance, MlxIbvInstance) and len(bound_instance.instance.ibv_devices) >= 2:
+        os.environ["MLX_METAL_FAST_SYNCH"] = "1"
 
     global logger
     logger = _logger
diff --git a/src/exo/worker/runner/runner_supervisor.py b/src/exo/worker/runner/runner_supervisor.py
index cda356ae..90f2d9b7 100644
--- a/src/exo/worker/runner/runner_supervisor.py
+++ b/src/exo/worker/runner/runner_supervisor.py
@@ -139,6 +139,9 @@ class RunnerSupervisor:
                 await self._event_sender.send(event)
         except (ClosedResourceError, BrokenResourceError) as e:
             await self._check_runner(e)
+            # the runner is gone; wake anything still waiting on its pending tasks
+            for tid in self.pending:
+                self.pending[tid].set()
 
     def __del__(self) -> None:
         if self.runner_process.is_alive():