Squash merge merging_clusters into tensor_parallel94

Author: Evan Quiney
Date: 2025-10-31 17:41:57 +00:00 (committed by GitHub)
parent d46c7e6a76
commit 3b409647ba
15 changed files with 306 additions and 79 deletions

View File

@@ -25,7 +25,7 @@ workspace = true
networking = { workspace = true }
# interop
pyo3 = { version = "0.25.1", features = [# TODO: migrate to v0.26 soon!!
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
@@ -38,8 +38,9 @@ pyo3 = { version = "0.25.1", features = [# TODO: migrate to v0.26 soon!!
"ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.13.1" }
pyo3-async-runtimes = { version = "0.25", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-stub-gen = { version = "0.16.1" }
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = "0.13.2"
# macro dependencies
extend = { workspace = true }
@@ -70,7 +71,6 @@ thiserror = { workspace = true }
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
pyo3-log = "0.12"
# Networking

View File

@@ -166,6 +166,8 @@ async fn networking_task(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else {
result.pyerr()
};

View File

@@ -95,6 +95,7 @@ mod transport {
mod behaviour {
use crate::{alias, discovery};
use std::time::Duration;
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity};
@@ -124,6 +125,7 @@ mod behaviour {
gossipsub::Behaviour::new(
MessageAuthenticity::Signed(keypair.clone()),
ConfigBuilder::default()
.publish_queue_duration(Duration::from_secs(15))
.validation_mode(ValidationMode::Strict)
.build()
.expect("the configuration should always be valid"),

View File

@@ -14,7 +14,7 @@ from exo.routing.router import Router, get_node_id_keypair
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
@@ -40,6 +40,7 @@ class Node:
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)
await router.register_topic(topics.LOCAL_EVENTS)
@@ -50,16 +51,19 @@ class Node:
logger.info(f"Starting node {node_id}")
if args.spawn_api:
api = API(
node_id=node_id,
node_id,
session_id,
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
)
else:
api = None
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
initial_connection_messages=[],
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
@@ -70,22 +74,24 @@ class Node:
# We start every node with a master
master = Master(
node_id,
session_id,
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
tb_only=args.tb_only,
)
# If someone manages to assemble 1 MILLION devices into an exo cluster then. well done. good job champ.
er_send, er_recv = channel[ElectionResult]()
election = Election(
node_id,
# If someone manages to assemble 1 MILLION devices into an exo cluster then. well done. good job champ.
seniority=1_000_000 if args.force_master else 0,
# nb: this DOES feedback right now. i have thoughts on how to address this,
# but ultimately it seems not worth the complexity
election_message_sender=router.sender(topics.ELECTION_MESSAGES),
election_message_receiver=router.receiver(topics.ELECTION_MESSAGES),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
command_receiver=router.receiver(topics.COMMANDS),
election_result_sender=er_send,
)
@@ -107,6 +113,9 @@ class Node:
assert self._tg
with self.election_result_receiver as results:
async for result in results:
# This function continues to have a lot of very specific entangled logic
# At least it's somewhat contained
# I don't like this duplication, but it's manageable for now.
# TODO: This function needs refactoring generally
@@ -116,23 +125,35 @@ class Node:
# - Shutdown and re-create the worker
# - Shut down and re-create the API
if result.node_id == self.node_id and self.master is not None:
if (
result.session_id.master_node_id == self.node_id
and self.master is not None
):
logger.info("Node elected Master")
elif result.node_id == self.node_id and self.master is None:
elif (
result.session_id.master_node_id == self.node_id
and self.master is None
):
logger.info("Node elected Master - promoting self")
self.master = Master(
self.node_id,
result.session_id,
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
)
self._tg.start_soon(self.master.run)
elif result.node_id != self.node_id and self.master is not None:
logger.info(f"Node {result.node_id} elected master - demoting self")
elif (
result.session_id.master_node_id != self.node_id
and self.master is not None
):
logger.info(
f"Node {result.session_id.master_node_id} elected master - demoting self"
)
await self.master.shutdown()
self.master = None
else:
logger.info(f"Node {result.node_id} elected master")
logger.info(f"Node {result.session_id.master_node_id} elected master")
if result.is_new_master:
await anyio.sleep(0)
if self.worker:
@@ -140,6 +161,7 @@ class Node:
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
initial_connection_messages=result.historic_messages,
connection_message_receiver=self.router.receiver(
@@ -153,7 +175,10 @@ class Node:
)
self._tg.start_soon(self.worker.run)
if self.api:
self.api.reset()
self.api.reset(result.session_id)
else:
if self.api:
self.api.unpause()
def main():
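Reviewer note: the promote/demote branches above boil down to a small decision on (who won, do we currently run a Master). A minimal sketch of that decision, using plain strings and a hypothetical helper rather than the real Master/Worker machinery:

# Illustrative only: condenses the election-result handling above.
def next_master_action(elected_master: str, node_id: str, has_local_master: bool) -> str:
    """Return which transition the node should take after an ElectionResult."""
    elected_here = elected_master == node_id
    if elected_here and has_local_master:
        return "stay-master"   # already master for this session, nothing to do
    if elected_here and not has_local_master:
        return "promote"       # spawn a Master bound to result.session_id
    if not elected_here and has_local_master:
        return "demote"        # shut down the local Master and drop it
    return "observe"           # someone else won; just log it

# e.g. a demoted node:
assert next_master_action("node-B", node_id="node-A", has_local_master=True) == "demote"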

View File

@@ -5,6 +5,7 @@ from collections.abc import AsyncGenerator
from typing import final
import uvicorn
from anyio import Event as AsyncTaskEvent
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException
@@ -14,6 +15,7 @@ from fastapi.staticfiles import StaticFiles
from loguru import logger
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
@@ -36,7 +38,7 @@ from exo.shared.types.commands import (
# TODO: SpinUpInstance
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.models import ModelMetadata
from exo.shared.types.state import State
@@ -74,20 +76,28 @@ async def resolve_model_meta(model_id: str) -> ModelMetadata:
class API:
def __init__(
self,
*,
node_id: NodeId,
session_id: SessionId,
*,
port: int = 8000,
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
) -> None:
self.state = State()
self.command_sender = command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.port = port
self.paused: bool = False
self.paused_ev: AsyncTaskEvent = AsyncTaskEvent()
self.app = FastAPI()
self._setup_cors()
self._setup_routes()
@@ -111,10 +121,17 @@ class API:
] = {}
self._tg: TaskGroup | None = None
def reset(self):
def reset(self, new_session_id: SessionId):
self.state = State()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self.unpause()
def unpause(self):
self.paused = False
self.paused_ev.set()
self.paused_ev = AsyncTaskEvent()
def _setup_cors(self) -> None:
self.app.add_middleware(
@@ -160,10 +177,9 @@ class API:
)
def get_instance(self, instance_id: InstanceId) -> Instance:
state = self.state
if instance_id not in state.instances:
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
return state.instances[instance_id]
return self.state.instances[instance_id]
async def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
if instance_id not in self.state.instances:
@@ -299,6 +315,7 @@ class API:
logger.info("Starting API")
tg.start_soon(uvicorn_server.serve)
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
self.command_sender.close()
self.global_event_receiver.close()
@@ -314,7 +331,15 @@ class API:
):
self._chat_completion_queues[event.command_id].put_nowait(event)
async def _pause_on_new_election(self):
with self.election_receiver as ems:
async for message in ems:
if message.clock > self.session_id.election_clock:
self.paused = True
async def _send(self, command: Command):
while self.paused:
await self.paused_ev.wait()
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
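Reviewer note: the pause gate added here is a flag-plus-event pattern: _pause_on_new_election flips `paused` when it sees an election clock newer than the current session, and `_send` spins on an anyio Event until `unpause()` (or `reset()`) releases it. A rough, self-contained sketch of the same pattern, with illustrative names rather than the exo API:

# Sketch of the pause/unpause gate, assuming anyio.Event semantics as used above.
import anyio

class PausableSender:
    def __init__(self) -> None:
        self.paused = False
        self._resume = anyio.Event()

    def pause(self) -> None:
        self.paused = True

    def unpause(self) -> None:
        self.paused = False
        self._resume.set()            # wake anyone blocked in send()
        self._resume = anyio.Event()  # fresh event for the next pause cycle

    async def send(self, payload: object) -> None:
        while self.paused:            # re-check the flag after every wake-up
            await self._resume.wait()
        ...  # forward payload to the underlying command sender here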

View File

@@ -16,8 +16,9 @@ from exo.shared.types.commands import (
RequestEventLog,
SpinUpInstance,
TaskFinished,
TestCommand,
)
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -38,6 +39,7 @@ class Master:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
# Receiving indexed events from the forwarder to be applied to state
@@ -51,6 +53,7 @@ class Master:
self.state = State()
self._tg: TaskGroup | None = None
self.node_id = node_id
self.session_id = session_id
self.command_task_mapping: dict[CommandId, TaskId] = {}
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
@@ -93,6 +96,8 @@ class Master:
generated_events: list[Event] = []
command = forwarder_command.command
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
@@ -184,6 +189,9 @@ class Master:
async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
async for local_event in local_events:
# Discard all events not from our session
if local_event.session != self.session_id:
continue
self._multi_buffer.ingest(
local_event.origin_idx,
local_event.event,
@@ -221,6 +229,7 @@ class Master:
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
@@ -233,6 +242,7 @@ class Master:
ForwarderEvent(
origin=self.node_id,
origin_idx=event.idx,
session=self.session_id,
event=event.event,
)
)
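Reviewer note: the session check in _event_processor is the heart of this change: every ForwarderEvent now carries the SessionId it was produced under, and the Master silently drops anything from a different session. A minimal sketch of that filter with stand-in types (not the real event classes):

# Illustrative filter, mirroring "discard all events not from our session" above.
from dataclasses import dataclass

@dataclass(frozen=True)
class SessionId:
    master_node_id: str
    election_clock: int

@dataclass
class ForwarderEvent:
    origin: str
    origin_idx: int
    session: SessionId
    event: object

def accept(event: ForwarderEvent, current: SessionId) -> bool:
    # Events produced under an older or foreign session are ignored,
    # so a stale master cannot feed state into the new session.
    return event.session == current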

View File

@@ -13,7 +13,7 @@ from exo.shared.types.commands import (
CreateInstance,
ForwarderCommand,
)
from exo.shared.types.common import NodeId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -38,6 +38,7 @@ from exo.utils.channels import channel
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
@@ -58,6 +59,7 @@ async def test_master():
master = Master(
node_id,
session_id,
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
@@ -74,6 +76,7 @@ async def test_master():
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
session=session_id,
event=(
NodePerformanceMeasured(
node_id=node_id,

View File

@@ -200,15 +200,15 @@ class Router:
await router.publish(message)
async def _networking_publish(self):
# This with/for pattern ensures this method doesn't return until after the receiver closes
# This is good for safety, but is mostly a redundant check.
with self.networking_receiver as networked_items:
async for topic, data in networked_items:
try:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
# As a hack, this also catches AllQueuesFull
# Need to fix that ASAP.
except NoPeersSubscribedToTopicError:
logger.trace(f"Failed to send over {topic} - No peers found.")
pass
def get_node_id_keypair(
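Reviewer note: the with/for loop kept here is the channel-consumer shape used throughout the codebase: entering the receiver as a context manager ties the loop's lifetime to the channel, so the coroutine only returns once the sending side closes. A rough sketch of the shape, with a generic exception standing in for NoPeersSubscribedToTopicError:

# Sketch of the receiver-consumer pattern used by _networking_publish.
# `receiver` and `publish` are stand-ins, not the real exo.utils.channels API.
class PublishFailed(Exception):
    pass

async def publish_loop(receiver, publish) -> None:
    with receiver as items:              # exits cleanly when the sender closes
        async for topic, data in items:
            try:
                await publish(topic, data)
            except PublishFailed:
                # Best-effort delivery: nobody subscribed yet, drop and move on.
                pass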

View File

@@ -11,7 +11,8 @@ from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage
from exo.shared.types.common import NodeId
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, Sender
from exo.utils.pydantic_ext import CamelCaseModel
@@ -21,18 +22,24 @@ ELECTION_TIMEOUT = 3.0
class ElectionMessage(CamelCaseModel):
clock: int
seniority: int
node_id: NodeId
proposed_session: SessionId
commands_seen: int
# Could eventually include a list of neighbour nodes for centrality
def __lt__(self, other: Self):
def __lt__(self, other: Self) -> bool:
if self.seniority != other.seniority:
return self.seniority < other.seniority
elif self.commands_seen != other.commands_seen:
return self.commands_seen < other.commands_seen
else:
return self.node_id < other.node_id
return (
self.proposed_session.master_node_id
< other.proposed_session.master_node_id
)
class ElectionResult(CamelCaseModel):
node_id: NodeId
session_id: SessionId
is_new_master: bool
historic_messages: list[ConnectionMessage]
@@ -41,11 +48,12 @@ class Election:
def __init__(
self,
node_id: NodeId,
*,
election_message_receiver: Receiver[ElectionMessage],
election_message_sender: Sender[ElectionMessage],
election_result_sender: Sender[ElectionResult],
connection_message_receiver: Receiver[ConnectionMessage],
*,
command_receiver: Receiver[ForwarderCommand],
is_candidate: bool = True,
seniority: int = 0,
):
@@ -55,13 +63,18 @@ class Election:
self.seniority = seniority if is_candidate else -1
self.clock = 0
self.node_id = node_id
self.commands_seen = 0
# Every node spawns as master
self.master_node_id: NodeId = node_id
self.current_session: SessionId = SessionId(
master_node_id=node_id, election_clock=0
)
# Senders/Receivers
self._em_sender = election_message_sender
self._em_receiver = election_message_receiver
self._er_sender = election_result_sender
self._cm_receiver = connection_message_receiver
self._co_receiver = command_receiver
# Campaign state
self._candidates: list[ElectionMessage] = []
@@ -76,6 +89,7 @@ class Election:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
await self._campaign(None)
if self._campaign_cancel_scope is not None:
@@ -84,12 +98,12 @@ class Election:
if self._campaign_done is not None:
await self._campaign_done.wait()
async def elect(self, node_id: NodeId) -> None:
is_new_master = node_id != self.master_node_id
self.master_node_id = node_id
async def elect(self, em: ElectionMessage) -> None:
is_new_master = em.proposed_session != self.current_session
self.current_session = em.proposed_session
await self._er_sender.send(
ElectionResult(
node_id=node_id,
session_id=em.proposed_session,
is_new_master=is_new_master,
historic_messages=self._connection_messages,
)
@@ -106,7 +120,7 @@ class Election:
async def _election_receiver(self) -> None:
with self._em_receiver as election_messages:
async for message in election_messages:
if message.node_id == self.node_id:
if message.proposed_session.master_node_id == self.node_id:
# Drop messages from us (See exo.routing.router)
continue
# If a new round is starting, we participate
@@ -129,6 +143,11 @@ class Election:
await self._campaign(None)
self._connection_messages.append(msg)
async def _command_counter(self) -> None:
with self._co_receiver as commands:
async for _command in commands:
self.commands_seen += 1
async def _campaign(self, initial_message: ElectionMessage | None) -> None:
# Kill the old campaign
if self._campaign_cancel_scope:
@@ -167,10 +186,15 @@ class Election:
candidates = sorted(candidates)
logger.debug(f"Election queue {candidates}")
elected = candidates[-1]
logger.info("Election finished")
if self.node_id == elected.node_id and self.seniority >= 0:
if (
self.node_id == elected.proposed_session.master_node_id
and self.seniority >= 0
):
self.seniority = max(self.seniority, len(candidates))
await self.elect(elected.node_id)
logger.info(
f"Election finished, new SessionId({elected.proposed_session})"
)
await self.elect(elected)
except get_cancelled_exc_class():
logger.info("Election cancelled")
finally:
@@ -180,4 +204,13 @@ class Election:
def _election_status(self, clock: int | None = None) -> ElectionMessage:
c = self.clock if clock is None else clock
return ElectionMessage(clock=c, seniority=self.seniority, node_id=self.node_id)
return ElectionMessage(
proposed_session=(
self.current_session
if self.current_session.master_node_id == self.node_id
else SessionId(master_node_id=self.node_id, election_clock=c)
),
clock=c,
seniority=self.seniority,
commands_seen=self.commands_seen,
)
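Reviewer note: the new __lt__ gives ElectionMessage a three-level ordering: seniority first, commands_seen as the tie-breaker, and the proposing node's id as a deterministic last resort; _campaign then just sorts the candidates and takes the last element. A small worked example of that ordering, using a plain dataclass in place of the real ElectionMessage:

# Worked example of the candidate ordering; not the real ElectionMessage class.
from dataclasses import dataclass

@dataclass
class Candidate:
    seniority: int
    commands_seen: int
    master_node_id: str

    def __lt__(self, other: "Candidate") -> bool:
        if self.seniority != other.seniority:
            return self.seniority < other.seniority
        if self.commands_seen != other.commands_seen:
            return self.commands_seen < other.commands_seen
        return self.master_node_id < other.master_node_id

candidates = [
    Candidate(seniority=0, commands_seen=5,  master_node_id="PEER"),
    Candidate(seniority=0, commands_seen=50, master_node_id="ME"),
    Candidate(seniority=2, commands_seen=0,  master_node_id="OLD-MASTER"),
]
# sorted(...)[-1] is the winner, as in _campaign: seniority dominates, so
# OLD-MASTER wins here even though it has seen no commands.
assert sorted(candidates)[-1].master_node_id == "OLD-MASTER"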

View File

@@ -3,7 +3,8 @@ from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.common import NodeId
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import channel
# ======= #
@@ -11,8 +12,28 @@ from exo.utils.channels import channel
# ======= #
def em(clock: int, seniority: int, node_id: str) -> ElectionMessage:
return ElectionMessage(clock=clock, seniority=seniority, node_id=NodeId(node_id))
def em(
clock: int,
seniority: int,
node_id: str,
commands_seen: int = 0,
election_clock: int | None = None,
) -> ElectionMessage:
"""
Helper to build ElectionMessages for a given proposer node.
The new API carries a proposed SessionId (master_node_id + election_clock).
By default we use the same value for election_clock as the 'clock' of the round.
"""
return ElectionMessage(
clock=clock,
seniority=seniority,
proposed_session=SessionId(
master_node_id=NodeId(node_id),
election_clock=clock if election_clock is None else election_clock,
),
commands_seen=commands_seen,
)
@pytest.fixture
@@ -43,8 +64,10 @@ async def test_single_round_broadcasts_and_updates_seniority_on_self_win(
em_in_tx, em_in_rx = channel[ElectionMessage]()
# Election results produced by the Election (we'll observe these)
er_tx, er_rx = channel[ElectionResult]()
# Connection messages (unused in this test but required by ctor)
# Connection messages
cm_tx, cm_rx = channel[ConnectionMessage]()
# Commands
co_tx, co_rx = channel[ForwarderCommand]()
election = Election(
node_id=NodeId("B"),
@@ -52,6 +75,7 @@ async def test_single_round_broadcasts_and_updates_seniority_on_self_win(
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)
@@ -64,18 +88,21 @@ async def test_single_round_broadcasts_and_updates_seniority_on_self_win(
# Expect our broadcast back to the peer side for this round only
while True:
got = await em_out_rx.receive()
if got.clock == 1 and got.node_id == NodeId("B"):
if got.clock == 1 and got.proposed_session.master_node_id == NodeId(
"B"
):
break
# Wait for the round to finish and produce an ElectionResult
result = await er_rx.receive()
assert result.node_id == NodeId("B")
assert result.session_id.master_node_id == NodeId("B")
# We spawned as master; electing ourselves again is not "new master".
assert result.is_new_master is False
# Close inbound streams to end the receivers (and run())
await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()
# We should have updated seniority to 2 (A + B).
assert election.seniority == 2
@@ -93,6 +120,7 @@ async def test_peer_with_higher_seniority_wins_and_we_switch_master(
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()
election = Election(
node_id=NodeId("ME"),
@@ -100,6 +128,7 @@ async def test_peer_with_higher_seniority_wins_and_we_switch_master(
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)
@@ -117,13 +146,19 @@ async def test_peer_with_higher_seniority_wins_and_we_switch_master(
assert got.seniority == 0
break
# After the timeout, election result should report the peer as master
result = await er_rx.receive()
assert result.node_id == NodeId("PEER")
# After the timeout, election result for clock=1 should report the peer as master
# (Skip any earlier result from the boot campaign at clock=0 by filtering on election_clock)
while True:
result = await er_rx.receive()
if result.session_id.election_clock == 1:
break
assert result.session_id.master_node_id == NodeId("PEER")
assert result.is_new_master is True
await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()
# We lost → seniority unchanged
assert election.seniority == 0
@@ -139,6 +174,7 @@ async def test_ignores_older_messages(fast_timeout: None) -> None:
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, _er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()
election = Election(
node_id=NodeId("ME"),
@@ -146,6 +182,7 @@ async def test_ignores_older_messages(fast_timeout: None) -> None:
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)
@@ -169,8 +206,9 @@ async def test_ignores_older_messages(fast_timeout: None) -> None:
got_second = True
assert not got_second, "Should not receive a broadcast for an older round"
await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()
# Not asserting on the result; focus is on ignore behavior.
@@ -186,6 +224,7 @@ async def test_two_rounds_emit_two_broadcasts_and_increment_clock(
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, _er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()
election = Election(
node_id=NodeId("ME"),
@@ -193,6 +232,7 @@ async def test_two_rounds_emit_two_broadcasts_and_increment_clock(
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)
@@ -214,8 +254,9 @@ async def test_two_rounds_emit_two_broadcasts_and_increment_clock(
if m2.clock == 2:
break
await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()
# Not asserting on who won; just that both rounds were broadcast.
@@ -230,6 +271,7 @@ async def test_promotion_new_seniority_counts_participants(fast_timeout: None) -
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()
election = Election(
node_id=NodeId("ME"),
@@ -237,6 +279,7 @@ async def test_promotion_new_seniority_counts_participants(fast_timeout: None) -
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)
@@ -251,14 +294,17 @@ async def test_promotion_new_seniority_counts_participants(fast_timeout: None) -
# We should see exactly one broadcast from us for this round
while True:
got = await em_out_rx.receive()
if got.clock == 7 and got.node_id == NodeId("ME"):
if got.clock == 7 and got.proposed_session.master_node_id == NodeId(
"ME"
):
break
# Wait for the election to finish so seniority updates
_ = await er_rx.receive()
await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()
# We + A + B = 3 → new seniority expected to be 3
assert election.seniority == 3
@@ -276,6 +322,7 @@ async def test_connection_message_triggers_new_round_broadcast(
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, _er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()
election = Election(
node_id=NodeId("ME"),
@@ -283,6 +330,7 @@ async def test_connection_message_triggers_new_round_broadcast(
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)
@@ -303,11 +351,75 @@ async def test_connection_message_triggers_new_round_broadcast(
# Expect a broadcast for the new round at clock=1
while True:
got = await em_out_rx.receive()
if got.clock == 1 and got.node_id == NodeId("ME"):
if got.clock == 1 and got.proposed_session.master_node_id == NodeId(
"ME"
):
break
# Close promptly to avoid waiting for campaign completion
await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()
# After cancellation (before election finishes), no seniority changes asserted here.
@pytest.mark.anyio
async def test_tie_breaker_prefers_node_with_more_commands_seen(
fast_timeout: None,
) -> None:
"""
With equal seniority, the node that has seen more commands should win the election.
We increase our local 'commands_seen' by sending TestCommand()s before triggering the round.
"""
em_out_tx, em_out_rx = channel[ElectionMessage]()
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()
me = NodeId("ME")
election = Election(
node_id=me,
election_message_receiver=em_in_rx,
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
seniority=0,
)
async with create_task_group() as tg:
with fail_after(2):
tg.start_soon(election.run)
# Pump local commands so our commands_seen is high before the round starts
for _ in range(50):
await co_tx.send(
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
)
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands
await em_in_tx.send(
em(clock=1, seniority=0, node_id="PEER", commands_seen=5)
)
# Observe our broadcast for this round (to ensure we've joined the round)
while True:
got = await em_out_rx.receive()
if got.clock == 1 and got.proposed_session.master_node_id == me:
# We don't assert exact count, just that we've participated this round.
break
# The elected result for clock=1 should be us due to higher commands_seen
while True:
result = await er_rx.receive()
if result.session_id.master_node_id == me:
assert result.session_id.election_clock in (0, 1)
break
em_in_tx.close()
cm_tx.close()
co_tx.close()

View File

@@ -1,5 +1,3 @@
from enum import Enum
from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
@@ -10,19 +8,14 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
# TODO: We need to have a distinction between create instance and spin up instance.
class CommandType(str, Enum):
ChatCompletion = "ChatCompletion"
CreateInstance = "CreateInstance"
SpinUpInstance = "SpinUpInstance"
DeleteInstance = "DeleteInstance"
TaskFinished = "TaskFinished"
RequestEventLog = "RequestEventLog"
class BaseCommand(TaggedModel):
command_id: CommandId = Field(default_factory=CommandId)
class TestCommand(BaseCommand):
pass
class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
@@ -48,7 +41,8 @@ class RequestEventLog(BaseCommand):
Command = (
RequestEventLog
TestCommand
| RequestEventLog
| ChatCompletion
| CreateInstance
| SpinUpInstance

View File

@@ -23,6 +23,11 @@ class NodeId(Id):
pass
class SessionId(CamelCaseModel):
master_node_id: NodeId
election_clock: int
class CommandId(Id):
pass
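Reviewer note: SessionId is the small value type threaded through most of this diff: the elected master's node id plus the election clock it won at, compared as a whole pair. A quick sketch of how it is built and compared, assuming CamelCaseModel behaves like a pydantic BaseModel and using plain str in place of NodeId:

# Hypothetical usage, modeled directly on the class above.
from pydantic import BaseModel

class SessionId(BaseModel):
    master_node_id: str
    election_clock: int

boot = SessionId(master_node_id="node-A", election_clock=0)
re_elected = SessionId(master_node_id="node-A", election_clock=3)

# A different clock (or a different master) makes the pair unequal,
# which is exactly the comparison elect() uses to set is_new_master.
assert boot != re_elected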

View File

@@ -4,7 +4,7 @@ from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import CommandId, GenerationChunk
from exo.shared.types.common import Id, NodeId
from exo.shared.types.common import Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.common import InstanceId, WorkerStatus
@@ -177,4 +177,5 @@ class ForwarderEvent(CamelCaseModel):
origin_idx: int = Field(ge=0)
origin: NodeId
session: SessionId
event: Event

View File

@@ -13,7 +13,7 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -75,6 +75,7 @@ class Worker:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
initial_connection_messages: list[ConnectionMessage],
@@ -91,6 +92,7 @@ class Worker:
command_sender: Sender[ForwarderCommand],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.shard_downloader: ShardDownloader = shard_downloader
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
@@ -634,6 +636,7 @@ class Worker:
fe = ForwarderEvent(
origin_idx=self.local_event_index,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(

View File

@@ -5,12 +5,15 @@ from anyio import fail_after
from exo.routing.topics import ConnectionMessage, ForwarderCommand, ForwarderEvent
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import NodeId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, TaskStateUpdated
from exo.shared.types.tasks import TaskId, TaskStatus
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.download.shard_downloader import NoopShardDownloader, ShardDownloader
from exo.worker.main import Worker
from exo.worker.tests.constants import MASTER_NODE_ID
session = SessionId(master_node_id=MASTER_NODE_ID, election_clock=0)
@dataclass
@@ -19,11 +22,17 @@ class WorkerMailbox:
receiver: Receiver[ForwarderEvent]
counter: int = 0
async def append_events(self, events: list[Event], *, origin: NodeId):
async def append_events(
self,
events: list[Event],
*,
origin: NodeId,
):
for event in events:
await self.sender.send(
ForwarderEvent(
origin=origin,
session=session,
event=event,
origin_idx=self.counter,
)
@@ -45,6 +54,7 @@ def create_worker_void_mailbox(
shard_downloader = NoopShardDownloader()
return Worker(
node_id,
session_id=session,
shard_downloader=shard_downloader,
initial_connection_messages=[],
connection_message_receiver=channel[ConnectionMessage]()[1],
@@ -64,6 +74,7 @@ def create_worker_and_mailbox(
sender, grecv = channel[ForwarderEvent]()
worker = Worker(
node_id,
session_id=session,
shard_downloader=shard_downloader,
initial_connection_messages=[],
connection_message_receiver=channel[ConnectionMessage]()[1],
@@ -84,6 +95,7 @@ def create_worker_with_old_mailbox(
# This function is subtly complex, come talk to Evan if you want to know what it's actually doing.
worker = Worker(
node_id,
session_id=session,
shard_downloader=shard_downloader,
initial_connection_messages=[],
connection_message_receiver=channel[ConnectionMessage]()[1],