diff --git a/rust/exo_pyo3_bindings/Cargo.toml b/rust/exo_pyo3_bindings/Cargo.toml index 4895ecf4..cab3b731 100644 --- a/rust/exo_pyo3_bindings/Cargo.toml +++ b/rust/exo_pyo3_bindings/Cargo.toml @@ -25,7 +25,7 @@ workspace = true networking = { workspace = true } # interop -pyo3 = { version = "0.25.1", features = [# TODO: migrate to v0.26 soon!! +pyo3 = { version = "0.27.1", features = [ # "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11 "nightly", # enables better-supported GIL integration "experimental-async", # async support in #[pyfunction] & #[pymethods] @@ -38,8 +38,9 @@ pyo3 = { version = "0.25.1", features = [# TODO: migrate to v0.26 soon!! "ordered-float", "rust_decimal", "smallvec", # "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde", ] } -pyo3-stub-gen = { version = "0.13.1" } -pyo3-async-runtimes = { version = "0.25", features = ["attributes", "tokio-runtime", "testing"] } +pyo3-stub-gen = { version = "0.16.1" } +pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] } +pyo3-log = "0.13.2" # macro dependencies extend = { workspace = true } @@ -70,7 +71,6 @@ thiserror = { workspace = true } #tracing-log = "0.2.0" log = { workspace = true } env_logger = "0.11" -pyo3-log = "0.12" # Networking diff --git a/rust/exo_pyo3_bindings/src/networking.rs b/rust/exo_pyo3_bindings/src/networking.rs index 021fc90e..3c480e08 100644 --- a/rust/exo_pyo3_bindings/src/networking.rs +++ b/rust/exo_pyo3_bindings/src/networking.rs @@ -166,6 +166,8 @@ async fn networking_task( IdentTopic::new(topic), data); let pyresult: PyResult = if let Err(PublishError::NoPeersSubscribedToTopic) = result { Err(exception::PyNoPeersSubscribedToTopicError::new_err()) + } else if let Err(PublishError::AllQueuesFull(_)) = result { + Err(exception::PyNoPeersSubscribedToTopicError::new_err()) } else { result.pyerr() }; diff --git a/rust/networking/src/swarm.rs b/rust/networking/src/swarm.rs index 24750558..eaeae467 100644 --- a/rust/networking/src/swarm.rs +++ b/rust/networking/src/swarm.rs @@ -95,6 +95,7 @@ mod transport { mod behaviour { use crate::{alias, discovery}; + use std::time::Duration; use libp2p::swarm::NetworkBehaviour; use libp2p::{gossipsub, identity}; @@ -124,6 +125,7 @@ mod behaviour { gossipsub::Behaviour::new( MessageAuthenticity::Signed(keypair.clone()), ConfigBuilder::default() + .publish_queue_duration(Duration::from_secs(15)) .validation_mode(ValidationMode::Strict) .build() .expect("the configuration should always be valid"), diff --git a/src/exo/main.py b/src/exo/main.py index 988a861b..280d5eaa 100644 --- a/src/exo/main.py +++ b/src/exo/main.py @@ -14,7 +14,7 @@ from exo.routing.router import Router, get_node_id_keypair from exo.shared.constants import EXO_LOG from exo.shared.election import Election, ElectionResult from exo.shared.logging import logger_cleanup, logger_setup -from exo.shared.types.common import NodeId +from exo.shared.types.common import NodeId, SessionId from exo.utils.channels import Receiver, channel from exo.utils.pydantic_ext import CamelCaseModel from exo.worker.download.impl_shard_downloader import exo_shard_downloader @@ -40,6 +40,7 @@ class Node: async def create(cls, args: "Args") -> "Self": keypair = get_node_id_keypair() node_id = NodeId(keypair.to_peer_id().to_base58()) + session_id = SessionId(master_node_id=node_id, election_clock=0) router = Router.create(keypair) await 
router.register_topic(topics.GLOBAL_EVENTS) await router.register_topic(topics.LOCAL_EVENTS) @@ -50,16 +51,19 @@ class Node: logger.info(f"Starting node {node_id}") if args.spawn_api: api = API( - node_id=node_id, + node_id, + session_id, port=args.api_port, global_event_receiver=router.receiver(topics.GLOBAL_EVENTS), command_sender=router.sender(topics.COMMANDS), + election_receiver=router.receiver(topics.ELECTION_MESSAGES), ) else: api = None worker = Worker( node_id, + session_id, exo_shard_downloader(), initial_connection_messages=[], connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES), @@ -70,22 +74,24 @@ class Node: # We start every node with a master master = Master( node_id, + session_id, global_event_sender=router.sender(topics.GLOBAL_EVENTS), local_event_receiver=router.receiver(topics.LOCAL_EVENTS), command_receiver=router.receiver(topics.COMMANDS), tb_only=args.tb_only, ) - # If someone manages to assemble 1 MILLION devices into an exo cluster then. well done. good job champ. er_send, er_recv = channel[ElectionResult]() election = Election( node_id, + # If someone manages to assemble 1 MILLION devices into an exo cluster then. well done. good job champ. seniority=1_000_000 if args.force_master else 0, # nb: this DOES feedback right now. i have thoughts on how to address this, # but ultimately it seems not worth the complexity election_message_sender=router.sender(topics.ELECTION_MESSAGES), election_message_receiver=router.receiver(topics.ELECTION_MESSAGES), connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES), + command_receiver=router.receiver(topics.COMMANDS), election_result_sender=er_send, ) @@ -107,6 +113,9 @@ class Node: assert self._tg with self.election_result_receiver as results: async for result in results: + # This function continues to have a lot of very specific entangled logic + # At least it's somewhat contained + # I don't like this duplication, but it's manageable for now. 
# TODO: This function needs refactoring generally @@ -116,23 +125,35 @@ class Node: # - Shutdown and re-create the worker # - Shut down and re-create the API - if result.node_id == self.node_id and self.master is not None: + if ( + result.session_id.master_node_id == self.node_id + and self.master is not None + ): logger.info("Node elected Master") - elif result.node_id == self.node_id and self.master is None: + elif ( + result.session_id.master_node_id == self.node_id + and self.master is None + ): logger.info("Node elected Master - promoting self") self.master = Master( self.node_id, + result.session_id, global_event_sender=self.router.sender(topics.GLOBAL_EVENTS), local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS), command_receiver=self.router.receiver(topics.COMMANDS), ) self._tg.start_soon(self.master.run) - elif result.node_id != self.node_id and self.master is not None: - logger.info(f"Node {result.node_id} elected master - demoting self") + elif ( + result.session_id.master_node_id != self.node_id + and self.master is not None + ): + logger.info( + f"Node {result.session_id.master_node_id} elected master - demoting self" + ) await self.master.shutdown() self.master = None else: - logger.info(f"Node {result.node_id} elected master") + logger.info(f"Node {result.session_id.master_node_id} elected master") if result.is_new_master: await anyio.sleep(0) if self.worker: @@ -140,6 +161,7 @@ class Node: # TODO: add profiling etc to resource monitor self.worker = Worker( self.node_id, + result.session_id, exo_shard_downloader(), initial_connection_messages=result.historic_messages, connection_message_receiver=self.router.receiver( @@ -153,7 +175,10 @@ class Node: ) self._tg.start_soon(self.worker.run) if self.api: - self.api.reset() + self.api.reset(result.session_id) + else: + if self.api: + self.api.unpause() def main(): diff --git a/src/exo/master/api.py b/src/exo/master/api.py index a4ad65cd..df3782bc 100644 --- a/src/exo/master/api.py +++ b/src/exo/master/api.py @@ -5,6 +5,7 @@ from collections.abc import AsyncGenerator from typing import final import uvicorn +from anyio import Event as AsyncTaskEvent from anyio import create_task_group from anyio.abc import TaskGroup from fastapi import FastAPI, HTTPException @@ -14,6 +15,7 @@ from fastapi.staticfiles import StaticFiles from loguru import logger from exo.shared.apply import apply +from exo.shared.election import ElectionMessage from exo.shared.models.model_cards import MODEL_CARDS from exo.shared.models.model_meta import get_model_meta from exo.shared.types.api import ( @@ -36,7 +38,7 @@ from exo.shared.types.commands import ( # TODO: SpinUpInstance TaskFinished, ) -from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.common import CommandId, NodeId, SessionId from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent from exo.shared.types.models import ModelMetadata from exo.shared.types.state import State @@ -74,20 +76,28 @@ async def resolve_model_meta(model_id: str) -> ModelMetadata: class API: def __init__( self, - *, node_id: NodeId, + session_id: SessionId, + *, port: int = 8000, # Ideally this would be a MasterForwarderEvent but type system says no :( global_event_receiver: Receiver[ForwarderEvent], command_sender: Sender[ForwarderCommand], + # This lets us pause the API if an election is running + election_receiver: Receiver[ElectionMessage], ) -> None: self.state = State() self.command_sender = command_sender self.global_event_receiver = global_event_receiver + 
self.election_receiver = election_receiver self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]() self.node_id: NodeId = node_id + self.session_id: SessionId = session_id self.port = port + self.paused: bool = False + self.paused_ev: AsyncTaskEvent = AsyncTaskEvent() + self.app = FastAPI() self._setup_cors() self._setup_routes() @@ -111,10 +121,17 @@ class API: ] = {} self._tg: TaskGroup | None = None - def reset(self): + def reset(self, new_session_id: SessionId): self.state = State() + self.session_id = new_session_id self.event_buffer = OrderedBuffer[Event]() self._chat_completion_queues = {} + self.unpause() + + def unpause(self): + self.paused = False + self.paused_ev.set() + self.paused_ev = AsyncTaskEvent() def _setup_cors(self) -> None: self.app.add_middleware( @@ -160,10 +177,9 @@ class API: ) def get_instance(self, instance_id: InstanceId) -> Instance: - state = self.state - if instance_id not in state.instances: + if instance_id not in self.state.instances: raise HTTPException(status_code=404, detail="Instance not found") - return state.instances[instance_id] + return self.state.instances[instance_id] async def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse: if instance_id not in self.state.instances: @@ -299,6 +315,7 @@ class API: logger.info("Starting API") tg.start_soon(uvicorn_server.serve) tg.start_soon(self._apply_state) + tg.start_soon(self._pause_on_new_election) self.command_sender.close() self.global_event_receiver.close() @@ -314,7 +331,15 @@ class API: ): self._chat_completion_queues[event.command_id].put_nowait(event) + async def _pause_on_new_election(self): + with self.election_receiver as ems: + async for message in ems: + if message.clock > self.session_id.election_clock: + self.paused = True + async def _send(self, command: Command): + while self.paused: + await self.paused_ev.wait() await self.command_sender.send( ForwarderCommand(origin=self.node_id, command=command) ) diff --git a/src/exo/master/main.py b/src/exo/master/main.py index b60b263a..15cd79e9 100644 --- a/src/exo/master/main.py +++ b/src/exo/master/main.py @@ -16,8 +16,9 @@ from exo.shared.types.commands import ( RequestEventLog, SpinUpInstance, TaskFinished, + TestCommand, ) -from exo.shared.types.common import CommandId, NodeId +from exo.shared.types.common import CommandId, NodeId, SessionId from exo.shared.types.events import ( Event, ForwarderEvent, @@ -38,6 +39,7 @@ class Master: def __init__( self, node_id: NodeId, + session_id: SessionId, *, command_receiver: Receiver[ForwarderCommand], # Receiving indexed events from the forwarder to be applied to state @@ -51,6 +53,7 @@ class Master: self.state = State() self._tg: TaskGroup | None = None self.node_id = node_id + self.session_id = session_id self.command_task_mapping: dict[CommandId, TaskId] = {} self.command_receiver = command_receiver self.local_event_receiver = local_event_receiver @@ -93,6 +96,8 @@ class Master: generated_events: list[Event] = [] command = forwarder_command.command match command: + case TestCommand(): + pass case ChatCompletion(): instance_task_counts: dict[InstanceId, int] = {} for instance in self.state.instances.values(): @@ -184,6 +189,9 @@ class Master: async def _event_processor(self) -> None: with self.local_event_receiver as local_events: async for local_event in local_events: + # Discard all events not from our session + if local_event.session != self.session_id: + continue self._multi_buffer.ingest( local_event.origin_idx, local_event.event, @@ -221,6 +229,7 @@ class 
Master: ForwarderEvent( origin=NodeId(f"master_{self.node_id}"), origin_idx=local_index, + session=self.session_id, event=event, ) ) @@ -233,6 +242,7 @@ class Master: ForwarderEvent( origin=self.node_id, origin_idx=event.idx, + session=self.session_id, event=event.event, ) ) diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py index a1b6c0b6..1e2750b5 100644 --- a/src/exo/master/tests/test_master.py +++ b/src/exo/master/tests/test_master.py @@ -13,7 +13,7 @@ from exo.shared.types.commands import ( CreateInstance, ForwarderCommand, ) -from exo.shared.types.common import NodeId +from exo.shared.types.common import NodeId, SessionId from exo.shared.types.events import ( ForwarderEvent, IndexedEvent, @@ -38,6 +38,7 @@ from exo.utils.channels import channel async def test_master(): keypair = get_node_id_keypair() node_id = NodeId(keypair.to_peer_id().to_base58()) + session_id = SessionId(master_node_id=node_id, election_clock=0) ge_sender, global_event_receiver = channel[ForwarderEvent]() command_sender, co_receiver = channel[ForwarderCommand]() @@ -58,6 +59,7 @@ async def test_master(): master = Master( node_id, + session_id, global_event_sender=ge_sender, local_event_receiver=le_receiver, command_receiver=co_receiver, @@ -74,6 +76,7 @@ async def test_master(): ForwarderEvent( origin_idx=0, origin=sender_node_id, + session=session_id, event=( NodePerformanceMeasured( node_id=node_id, diff --git a/src/exo/routing/router.py b/src/exo/routing/router.py index cf89e75f..335d7200 100644 --- a/src/exo/routing/router.py +++ b/src/exo/routing/router.py @@ -200,15 +200,15 @@ class Router: await router.publish(message) async def _networking_publish(self): - # This with/for pattern ensures this method doesn't return until after the receiver closes - # This is good for safety, but is mostly a redundant check. with self.networking_receiver as networked_items: async for topic, data in networked_items: try: logger.trace(f"Sending message on {topic} with payload {data}") await self._net.gossipsub_publish(topic, data) + # As a hack, this also catches AllQueuesFull + # Need to fix that ASAP. 
except NoPeersSubscribedToTopicError: - logger.trace(f"Failed to send over {topic} - No peers found.") + pass def get_node_id_keypair( diff --git a/src/exo/shared/election.py b/src/exo/shared/election.py index a5f94c66..70e5efc3 100644 --- a/src/exo/shared/election.py +++ b/src/exo/shared/election.py @@ -11,7 +11,8 @@ from anyio.abc import TaskGroup from loguru import logger from exo.routing.connection_message import ConnectionMessage -from exo.shared.types.common import NodeId +from exo.shared.types.commands import ForwarderCommand +from exo.shared.types.common import NodeId, SessionId from exo.utils.channels import Receiver, Sender from exo.utils.pydantic_ext import CamelCaseModel @@ -21,18 +22,24 @@ ELECTION_TIMEOUT = 3.0 class ElectionMessage(CamelCaseModel): clock: int seniority: int - node_id: NodeId + proposed_session: SessionId + commands_seen: int # Could eventually include a list of neighbour nodes for centrality - def __lt__(self, other: Self): + def __lt__(self, other: Self) -> bool: if self.seniority != other.seniority: return self.seniority < other.seniority + elif self.commands_seen != other.commands_seen: + return self.commands_seen < other.commands_seen else: - return self.node_id < other.node_id + return ( + self.proposed_session.master_node_id + < other.proposed_session.master_node_id + ) class ElectionResult(CamelCaseModel): - node_id: NodeId + session_id: SessionId is_new_master: bool historic_messages: list[ConnectionMessage] @@ -41,11 +48,12 @@ class Election: def __init__( self, node_id: NodeId, + *, election_message_receiver: Receiver[ElectionMessage], election_message_sender: Sender[ElectionMessage], election_result_sender: Sender[ElectionResult], connection_message_receiver: Receiver[ConnectionMessage], - *, + command_receiver: Receiver[ForwarderCommand], is_candidate: bool = True, seniority: int = 0, ): @@ -55,13 +63,18 @@ class Election: self.seniority = seniority if is_candidate else -1 self.clock = 0 self.node_id = node_id + self.commands_seen = 0 # Every node spawns as master - self.master_node_id: NodeId = node_id + self.current_session: SessionId = SessionId( + master_node_id=node_id, election_clock=0 + ) + # Senders/Receivers self._em_sender = election_message_sender self._em_receiver = election_message_receiver self._er_sender = election_result_sender self._cm_receiver = connection_message_receiver + self._co_receiver = command_receiver # Campaign state self._candidates: list[ElectionMessage] = [] @@ -76,6 +89,7 @@ class Election: self._tg = tg tg.start_soon(self._election_receiver) tg.start_soon(self._connection_receiver) + tg.start_soon(self._command_counter) await self._campaign(None) if self._campaign_cancel_scope is not None: @@ -84,12 +98,12 @@ class Election: if self._campaign_done is not None: await self._campaign_done.wait() - async def elect(self, node_id: NodeId) -> None: - is_new_master = node_id != self.master_node_id - self.master_node_id = node_id + async def elect(self, em: ElectionMessage) -> None: + is_new_master = em.proposed_session != self.current_session + self.current_session = em.proposed_session await self._er_sender.send( ElectionResult( - node_id=node_id, + session_id=em.proposed_session, is_new_master=is_new_master, historic_messages=self._connection_messages, ) @@ -106,7 +120,7 @@ class Election: async def _election_receiver(self) -> None: with self._em_receiver as election_messages: async for message in election_messages: - if message.node_id == self.node_id: + if message.proposed_session.master_node_id == self.node_id: # 
Drop messages from us (See exo.routing.router) continue # If a new round is starting, we participate @@ -129,6 +143,11 @@ class Election: await self._campaign(None) self._connection_messages.append(msg) + async def _command_counter(self) -> None: + with self._co_receiver as commands: + async for _command in commands: + self.commands_seen += 1 + async def _campaign(self, initial_message: ElectionMessage | None) -> None: # Kill the old campaign if self._campaign_cancel_scope: @@ -167,10 +186,15 @@ class Election: candidates = sorted(candidates) logger.debug(f"Election queue {candidates}") elected = candidates[-1] - logger.info("Election finished") - if self.node_id == elected.node_id and self.seniority >= 0: + if ( + self.node_id == elected.proposed_session.master_node_id + and self.seniority >= 0 + ): self.seniority = max(self.seniority, len(candidates)) - await self.elect(elected.node_id) + logger.info( + f"Election finished, new SessionId({elected.proposed_session})" + ) + await self.elect(elected) except get_cancelled_exc_class(): logger.info("Election cancelled") finally: @@ -180,4 +204,13 @@ class Election: def _election_status(self, clock: int | None = None) -> ElectionMessage: c = self.clock if clock is None else clock - return ElectionMessage(clock=c, seniority=self.seniority, node_id=self.node_id) + return ElectionMessage( + proposed_session=( + self.current_session + if self.current_session.master_node_id == self.node_id + else SessionId(master_node_id=self.node_id, election_clock=c) + ), + clock=c, + seniority=self.seniority, + commands_seen=self.commands_seen, + ) diff --git a/src/exo/shared/tests/test_election.py b/src/exo/shared/tests/test_election.py index 1c04e5c1..ae8c833f 100644 --- a/src/exo/shared/tests/test_election.py +++ b/src/exo/shared/tests/test_election.py @@ -3,7 +3,8 @@ from anyio import create_task_group, fail_after, move_on_after from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType from exo.shared.election import Election, ElectionMessage, ElectionResult -from exo.shared.types.common import NodeId +from exo.shared.types.commands import ForwarderCommand, TestCommand +from exo.shared.types.common import NodeId, SessionId from exo.utils.channels import channel # ======= # @@ -11,8 +12,28 @@ from exo.utils.channels import channel # ======= # -def em(clock: int, seniority: int, node_id: str) -> ElectionMessage: - return ElectionMessage(clock=clock, seniority=seniority, node_id=NodeId(node_id)) +def em( + clock: int, + seniority: int, + node_id: str, + commands_seen: int = 0, + election_clock: int | None = None, +) -> ElectionMessage: + """ + Helper to build ElectionMessages for a given proposer node. + + The new API carries a proposed SessionId (master_node_id + election_clock). + By default we use the same value for election_clock as the 'clock' of the round. 
+ """ + return ElectionMessage( + clock=clock, + seniority=seniority, + proposed_session=SessionId( + master_node_id=NodeId(node_id), + election_clock=clock if election_clock is None else election_clock, + ), + commands_seen=commands_seen, + ) @pytest.fixture @@ -43,8 +64,10 @@ async def test_single_round_broadcasts_and_updates_seniority_on_self_win( em_in_tx, em_in_rx = channel[ElectionMessage]() # Election results produced by the Election (we'll observe these) er_tx, er_rx = channel[ElectionResult]() - # Connection messages (unused in this test but required by ctor) + # Connection messages cm_tx, cm_rx = channel[ConnectionMessage]() + # Commands + co_tx, co_rx = channel[ForwarderCommand]() election = Election( node_id=NodeId("B"), @@ -52,6 +75,7 @@ async def test_single_round_broadcasts_and_updates_seniority_on_self_win( election_message_sender=em_out_tx, election_result_sender=er_tx, connection_message_receiver=cm_rx, + command_receiver=co_rx, is_candidate=True, ) @@ -64,18 +88,21 @@ async def test_single_round_broadcasts_and_updates_seniority_on_self_win( # Expect our broadcast back to the peer side for this round only while True: got = await em_out_rx.receive() - if got.clock == 1 and got.node_id == NodeId("B"): + if got.clock == 1 and got.proposed_session.master_node_id == NodeId( + "B" + ): break # Wait for the round to finish and produce an ElectionResult result = await er_rx.receive() - assert result.node_id == NodeId("B") + assert result.session_id.master_node_id == NodeId("B") # We spawned as master; electing ourselves again is not "new master". assert result.is_new_master is False # Close inbound streams to end the receivers (and run()) - await em_in_tx.aclose() - await cm_tx.aclose() + em_in_tx.close() + cm_tx.close() + co_tx.close() # We should have updated seniority to 2 (A + B). 
assert election.seniority == 2 @@ -93,6 +120,7 @@ async def test_peer_with_higher_seniority_wins_and_we_switch_master( em_in_tx, em_in_rx = channel[ElectionMessage]() er_tx, er_rx = channel[ElectionResult]() cm_tx, cm_rx = channel[ConnectionMessage]() + co_tx, co_rx = channel[ForwarderCommand]() election = Election( node_id=NodeId("ME"), @@ -100,6 +128,7 @@ async def test_peer_with_higher_seniority_wins_and_we_switch_master( election_message_sender=em_out_tx, election_result_sender=er_tx, connection_message_receiver=cm_rx, + command_receiver=co_rx, is_candidate=True, ) @@ -117,13 +146,19 @@ async def test_peer_with_higher_seniority_wins_and_we_switch_master( assert got.seniority == 0 break - # After the timeout, election result should report the peer as master - result = await er_rx.receive() - assert result.node_id == NodeId("PEER") + # After the timeout, election result for clock=1 should report the peer as master + # (Skip any earlier result from the boot campaign at clock=0 by filtering on election_clock) + while True: + result = await er_rx.receive() + if result.session_id.election_clock == 1: + break + + assert result.session_id.master_node_id == NodeId("PEER") assert result.is_new_master is True - await em_in_tx.aclose() - await cm_tx.aclose() + em_in_tx.close() + cm_tx.close() + co_tx.close() # We lost → seniority unchanged assert election.seniority == 0 @@ -139,6 +174,7 @@ async def test_ignores_older_messages(fast_timeout: None) -> None: em_in_tx, em_in_rx = channel[ElectionMessage]() er_tx, _er_rx = channel[ElectionResult]() cm_tx, cm_rx = channel[ConnectionMessage]() + co_tx, co_rx = channel[ForwarderCommand]() election = Election( node_id=NodeId("ME"), @@ -146,6 +182,7 @@ async def test_ignores_older_messages(fast_timeout: None) -> None: election_message_sender=em_out_tx, election_result_sender=er_tx, connection_message_receiver=cm_rx, + command_receiver=co_rx, is_candidate=True, ) @@ -169,8 +206,9 @@ async def test_ignores_older_messages(fast_timeout: None) -> None: got_second = True assert not got_second, "Should not receive a broadcast for an older round" - await em_in_tx.aclose() - await cm_tx.aclose() + em_in_tx.close() + cm_tx.close() + co_tx.close() # Not asserting on the result; focus is on ignore behavior. @@ -186,6 +224,7 @@ async def test_two_rounds_emit_two_broadcasts_and_increment_clock( em_in_tx, em_in_rx = channel[ElectionMessage]() er_tx, _er_rx = channel[ElectionResult]() cm_tx, cm_rx = channel[ConnectionMessage]() + co_tx, co_rx = channel[ForwarderCommand]() election = Election( node_id=NodeId("ME"), @@ -193,6 +232,7 @@ async def test_two_rounds_emit_two_broadcasts_and_increment_clock( election_message_sender=em_out_tx, election_result_sender=er_tx, connection_message_receiver=cm_rx, + command_receiver=co_rx, is_candidate=True, ) @@ -214,8 +254,9 @@ async def test_two_rounds_emit_two_broadcasts_and_increment_clock( if m2.clock == 2: break - await em_in_tx.aclose() - await cm_tx.aclose() + em_in_tx.close() + cm_tx.close() + co_tx.close() # Not asserting on who won; just that both rounds were broadcast. 
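To make the ordering these tests exercise concrete, here is a minimal runnable sketch of the comparison that `ElectionMessage.__lt__` now implements: seniority first, then `commands_seen` as the new tie-breaker, then the proposed master's node id. The `Msg` dataclass below is a hypothetical stand-in for the real pydantic `CamelCaseModel`; the election winner is the maximum, i.e. `sorted(candidates)[-1]`.

```python
from dataclasses import dataclass


@dataclass
class Msg:
    # Hypothetical stand-in for ElectionMessage (the real class carries a
    # proposed SessionId rather than a bare node id).
    seniority: int
    commands_seen: int
    master_node_id: str

    def __lt__(self, other: "Msg") -> bool:
        # Mirrors ElectionMessage.__lt__ from the diff above.
        if self.seniority != other.seniority:
            return self.seniority < other.seniority
        if self.commands_seen != other.commands_seen:
            return self.commands_seen < other.commands_seen
        return self.master_node_id < other.master_node_id


candidates = [
    Msg(seniority=0, commands_seen=5, master_node_id="PEER"),
    Msg(seniority=0, commands_seen=50, master_node_id="ME"),
]
# Equal seniority, so commands_seen decides the round; "ME" wins.
assert sorted(candidates)[-1].master_node_id == "ME"
```

This is exactly the shape of `test_tie_breaker_prefers_node_with_more_commands_seen` below: the node that has observed more cluster traffic is preferred when seniority is level.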
@@ -230,6 +271,7 @@ async def test_promotion_new_seniority_counts_participants(fast_timeout: None) - em_in_tx, em_in_rx = channel[ElectionMessage]() er_tx, er_rx = channel[ElectionResult]() cm_tx, cm_rx = channel[ConnectionMessage]() + co_tx, co_rx = channel[ForwarderCommand]() election = Election( node_id=NodeId("ME"), @@ -237,6 +279,7 @@ async def test_promotion_new_seniority_counts_participants(fast_timeout: None) - election_message_sender=em_out_tx, election_result_sender=er_tx, connection_message_receiver=cm_rx, + command_receiver=co_rx, is_candidate=True, ) @@ -251,14 +294,17 @@ async def test_promotion_new_seniority_counts_participants(fast_timeout: None) - # We should see exactly one broadcast from us for this round while True: got = await em_out_rx.receive() - if got.clock == 7 and got.node_id == NodeId("ME"): + if got.clock == 7 and got.proposed_session.master_node_id == NodeId( + "ME" + ): break # Wait for the election to finish so seniority updates _ = await er_rx.receive() - await em_in_tx.aclose() - await cm_tx.aclose() + em_in_tx.close() + cm_tx.close() + co_tx.close() # We + A + B = 3 → new seniority expected to be 3 assert election.seniority == 3 @@ -276,6 +322,7 @@ async def test_connection_message_triggers_new_round_broadcast( em_in_tx, em_in_rx = channel[ElectionMessage]() er_tx, _er_rx = channel[ElectionResult]() cm_tx, cm_rx = channel[ConnectionMessage]() + co_tx, co_rx = channel[ForwarderCommand]() election = Election( node_id=NodeId("ME"), @@ -283,6 +330,7 @@ async def test_connection_message_triggers_new_round_broadcast( election_message_sender=em_out_tx, election_result_sender=er_tx, connection_message_receiver=cm_rx, + command_receiver=co_rx, is_candidate=True, ) @@ -303,11 +351,75 @@ async def test_connection_message_triggers_new_round_broadcast( # Expect a broadcast for the new round at clock=1 while True: got = await em_out_rx.receive() - if got.clock == 1 and got.node_id == NodeId("ME"): + if got.clock == 1 and got.proposed_session.master_node_id == NodeId( + "ME" + ): break # Close promptly to avoid waiting for campaign completion - await em_in_tx.aclose() - await cm_tx.aclose() + em_in_tx.close() + cm_tx.close() + co_tx.close() # After cancellation (before election finishes), no seniority changes asserted here. + + +@pytest.mark.anyio +async def test_tie_breaker_prefers_node_with_more_commands_seen( + fast_timeout: None, +) -> None: + """ + With equal seniority, the node that has seen more commands should win the election. + We increase our local 'commands_seen' by sending TestCommand()s before triggering the round. 
+ """ + em_out_tx, em_out_rx = channel[ElectionMessage]() + em_in_tx, em_in_rx = channel[ElectionMessage]() + er_tx, er_rx = channel[ElectionResult]() + cm_tx, cm_rx = channel[ConnectionMessage]() + co_tx, co_rx = channel[ForwarderCommand]() + + me = NodeId("ME") + + election = Election( + node_id=me, + election_message_receiver=em_in_rx, + election_message_sender=em_out_tx, + election_result_sender=er_tx, + connection_message_receiver=cm_rx, + command_receiver=co_rx, + is_candidate=True, + seniority=0, + ) + + async with create_task_group() as tg: + with fail_after(2): + tg.start_soon(election.run) + + # Pump local commands so our commands_seen is high before the round starts + for _ in range(50): + await co_tx.send( + ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand()) + ) + + # Trigger a round at clock=1 with a peer of equal seniority but fewer commands + await em_in_tx.send( + em(clock=1, seniority=0, node_id="PEER", commands_seen=5) + ) + + # Observe our broadcast for this round (to ensure we've joined the round) + while True: + got = await em_out_rx.receive() + if got.clock == 1 and got.proposed_session.master_node_id == me: + # We don't assert exact count, just that we've participated this round. + break + + # The elected result for clock=1 should be us due to higher commands_seen + while True: + result = await er_rx.receive() + if result.session_id.master_node_id == me: + assert result.session_id.election_clock in (0, 1) + break + + em_in_tx.close() + cm_tx.close() + co_tx.close() diff --git a/src/exo/shared/types/commands.py b/src/exo/shared/types/commands.py index d7f5da87..b2f7a97b 100644 --- a/src/exo/shared/types/commands.py +++ b/src/exo/shared/types/commands.py @@ -1,5 +1,3 @@ -from enum import Enum - from pydantic import Field from exo.shared.types.api import ChatCompletionTaskParams @@ -10,19 +8,14 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel # TODO: We need to have a distinction between create instance and spin up instance. 
-class CommandType(str, Enum): - ChatCompletion = "ChatCompletion" - CreateInstance = "CreateInstance" - SpinUpInstance = "SpinUpInstance" - DeleteInstance = "DeleteInstance" - TaskFinished = "TaskFinished" - RequestEventLog = "RequestEventLog" - - class BaseCommand(TaggedModel): command_id: CommandId = Field(default_factory=CommandId) +class TestCommand(BaseCommand): + pass + + class ChatCompletion(BaseCommand): request_params: ChatCompletionTaskParams @@ -48,7 +41,8 @@ class RequestEventLog(BaseCommand): Command = ( - RequestEventLog + TestCommand + | RequestEventLog | ChatCompletion | CreateInstance | SpinUpInstance diff --git a/src/exo/shared/types/common.py b/src/exo/shared/types/common.py index e34fc7ef..42b682dc 100644 --- a/src/exo/shared/types/common.py +++ b/src/exo/shared/types/common.py @@ -23,6 +23,11 @@ class NodeId(Id): pass +class SessionId(CamelCaseModel): + master_node_id: NodeId + election_clock: int + + class CommandId(Id): pass diff --git a/src/exo/shared/types/events.py b/src/exo/shared/types/events.py index a910ea93..0de5612d 100644 --- a/src/exo/shared/types/events.py +++ b/src/exo/shared/types/events.py @@ -4,7 +4,7 @@ from pydantic import Field from exo.shared.topology import Connection, NodePerformanceProfile from exo.shared.types.chunks import CommandId, GenerationChunk -from exo.shared.types.common import Id, NodeId +from exo.shared.types.common import Id, NodeId, SessionId from exo.shared.types.profiling import MemoryPerformanceProfile from exo.shared.types.tasks import Task, TaskId, TaskStatus from exo.shared.types.worker.common import InstanceId, WorkerStatus @@ -177,4 +177,5 @@ class ForwarderEvent(CamelCaseModel): origin_idx: int = Field(ge=0) origin: NodeId + session: SessionId event: Event diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index e4374dd5..f19db835 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -13,7 +13,7 @@ from loguru import logger from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType from exo.shared.apply import apply from exo.shared.types.commands import ForwarderCommand, RequestEventLog -from exo.shared.types.common import NodeId +from exo.shared.types.common import NodeId, SessionId from exo.shared.types.events import ( ChunkGenerated, Event, @@ -75,6 +75,7 @@ class Worker: def __init__( self, node_id: NodeId, + session_id: SessionId, shard_downloader: ShardDownloader, *, initial_connection_messages: list[ConnectionMessage], @@ -91,6 +92,7 @@ class Worker: command_sender: Sender[ForwarderCommand], ): self.node_id: NodeId = node_id + self.session_id: SessionId = session_id self.shard_downloader: ShardDownloader = shard_downloader self.global_event_receiver = global_event_receiver self.local_event_sender = local_event_sender @@ -634,6 +636,7 @@ class Worker: fe = ForwarderEvent( origin_idx=self.local_event_index, origin=self.node_id, + session=self.session_id, event=event, ) logger.debug( diff --git a/src/exo/worker/tests/worker_management.py b/src/exo/worker/tests/worker_management.py index ad7e346d..220665e6 100644 --- a/src/exo/worker/tests/worker_management.py +++ b/src/exo/worker/tests/worker_management.py @@ -5,12 +5,15 @@ from anyio import fail_after from exo.routing.topics import ConnectionMessage, ForwarderCommand, ForwarderEvent from exo.shared.types.chunks import TokenChunk -from exo.shared.types.common import NodeId +from exo.shared.types.common import NodeId, SessionId from exo.shared.types.events import ChunkGenerated, Event, TaskStateUpdated from 
exo.shared.types.tasks import TaskId, TaskStatus from exo.utils.channels import Receiver, Sender, channel from exo.worker.download.shard_downloader import NoopShardDownloader, ShardDownloader from exo.worker.main import Worker +from exo.worker.tests.constants import MASTER_NODE_ID + +session = SessionId(master_node_id=MASTER_NODE_ID, election_clock=0) @dataclass @@ -19,11 +22,17 @@ class WorkerMailbox: receiver: Receiver[ForwarderEvent] counter: int = 0 - async def append_events(self, events: list[Event], *, origin: NodeId): + async def append_events( + self, + events: list[Event], + *, + origin: NodeId, + ): for event in events: await self.sender.send( ForwarderEvent( origin=origin, + session=session, event=event, origin_idx=self.counter, ) @@ -45,6 +54,7 @@ def create_worker_void_mailbox( shard_downloader = NoopShardDownloader() return Worker( node_id, + session_id=session, shard_downloader=shard_downloader, initial_connection_messages=[], connection_message_receiver=channel[ConnectionMessage]()[1], @@ -64,6 +74,7 @@ def create_worker_and_mailbox( sender, grecv = channel[ForwarderEvent]() worker = Worker( node_id, + session_id=session, shard_downloader=shard_downloader, initial_connection_messages=[], connection_message_receiver=channel[ConnectionMessage]()[1], @@ -84,6 +95,7 @@ def create_worker_with_old_mailbox( # This function is subtly complex, come talk to Evan if you want to know what it's actually doing. worker = Worker( node_id, + session_id=session, shard_downloader=shard_downloader, initial_connection_messages=[], connection_message_receiver=channel[ConnectionMessage]()[1],
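Finally, the reason `SessionId` is threaded through every `ForwarderEvent` (including the fixed `session` these worker-test helpers use): events are scoped to the election session that produced them, so a master can drop stale traffic with a single equality check, as the new guard in the Master's event processor does. A minimal sketch under the same simplification as above, with plain dataclasses standing in for the pydantic models:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Session:
    # Hypothetical stand-in for SessionId (master_node_id + election_clock).
    master_node_id: str
    election_clock: int


@dataclass
class Forwarded:
    # Hypothetical stand-in for ForwarderEvent.
    origin: str
    origin_idx: int
    session: Session
    event: str


current = Session(master_node_id="A", election_clock=1)
incoming = [
    Forwarded("w1", 0, Session("A", 0), "stale"),  # pre-election event
    Forwarded("w1", 1, Session("A", 1), "fresh"),  # current session
]
# Mirrors the new guard: "if local_event.session != self.session_id: continue".
accepted = [e for e in incoming if e.session == current]
assert [e.event for e in accepted] == ["fresh"]
```

The same comparison drives `is_new_master` in the election result and the API's pause gate: a session differs if either the master node or the election clock changed, so re-electing the same master in a newer round still produces a fresh session.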