Mirror of https://github.com/exo-explore/exo.git (synced 2025-12-23 22:27:50 -05:00)

Squash merge merging_clusters into tensor_parallel94
@@ -25,7 +25,7 @@ workspace = true
networking = { workspace = true }

# interop
pyo3 = { version = "0.25.1", features = [# TODO: migrate to v0.26 soon!!
pyo3 = { version = "0.27.1", features = [
# "abi3-py311", # tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.11
"nightly", # enables better-supported GIL integration
"experimental-async", # async support in #[pyfunction] & #[pymethods]
@@ -38,8 +38,9 @@ pyo3 = { version = "0.25.1", features = [# TODO: migrate to v0.26 soon!!
"ordered-float", "rust_decimal", "smallvec",
# "anyhow", "chrono", "chrono-local", "chrono-tz", "eyre", "jiff-02", "lock_api", "parking-lot", "time", "serde",
] }
pyo3-stub-gen = { version = "0.13.1" }
pyo3-async-runtimes = { version = "0.25", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-stub-gen = { version = "0.16.1" }
pyo3-async-runtimes = { version = "0.27.0", features = ["attributes", "tokio-runtime", "testing"] }
pyo3-log = "0.13.2"

# macro dependencies
extend = { workspace = true }
@@ -70,7 +71,6 @@ thiserror = { workspace = true }
#tracing-log = "0.2.0"
log = { workspace = true }
env_logger = "0.11"
pyo3-log = "0.12"


# Networking
@@ -166,6 +166,8 @@ async fn networking_task(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else {
result.pyerr()
};
@@ -95,6 +95,7 @@ mod transport {

mod behaviour {
use crate::{alias, discovery};
use std::time::Duration;
use libp2p::swarm::NetworkBehaviour;
use libp2p::{gossipsub, identity};

@@ -124,6 +125,7 @@ mod behaviour {
gossipsub::Behaviour::new(
MessageAuthenticity::Signed(keypair.clone()),
ConfigBuilder::default()
.publish_queue_duration(Duration::from_secs(15))
.validation_mode(ValidationMode::Strict)
.build()
.expect("the configuration should always be valid"),
@@ -14,7 +14,7 @@ from exo.routing.router import Router, get_node_id_keypair
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
@@ -40,6 +40,7 @@ class Node:
async def create(cls, args: "Args") -> "Self":
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
router = Router.create(keypair)
await router.register_topic(topics.GLOBAL_EVENTS)
await router.register_topic(topics.LOCAL_EVENTS)
@@ -50,16 +51,19 @@ class Node:
logger.info(f"Starting node {node_id}")
if args.spawn_api:
api = API(
node_id=node_id,
node_id,
session_id,
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
)
else:
api = None

worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
initial_connection_messages=[],
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
@@ -70,22 +74,24 @@ class Node:
# We start every node with a master
master = Master(
node_id,
session_id,
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
tb_only=args.tb_only,
)

# If someone manages to assemble 1 MILLION devices into an exo cluster then. well done. good job champ.
er_send, er_recv = channel[ElectionResult]()
election = Election(
node_id,
# If someone manages to assemble 1 MILLION devices into an exo cluster then. well done. good job champ.
seniority=1_000_000 if args.force_master else 0,
# nb: this DOES feedback right now. i have thoughts on how to address this,
# but ultimately it seems not worth the complexity
election_message_sender=router.sender(topics.ELECTION_MESSAGES),
election_message_receiver=router.receiver(topics.ELECTION_MESSAGES),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
command_receiver=router.receiver(topics.COMMANDS),
election_result_sender=er_send,
)

@@ -107,6 +113,9 @@ class Node:
assert self._tg
with self.election_result_receiver as results:
async for result in results:
# This function continues to have a lot of very specific entangled logic
# At least it's somewhat contained

# I don't like this duplication, but it's manageable for now.
# TODO: This function needs refactoring generally

@@ -116,23 +125,35 @@ class Node:
# - Shutdown and re-create the worker
# - Shut down and re-create the API

if result.node_id == self.node_id and self.master is not None:
if (
result.session_id.master_node_id == self.node_id
and self.master is not None
):
logger.info("Node elected Master")
elif result.node_id == self.node_id and self.master is None:
elif (
result.session_id.master_node_id == self.node_id
and self.master is None
):
logger.info("Node elected Master - promoting self")
self.master = Master(
self.node_id,
result.session_id,
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
)
self._tg.start_soon(self.master.run)
elif result.node_id != self.node_id and self.master is not None:
logger.info(f"Node {result.node_id} elected master - demoting self")
elif (
result.session_id.master_node_id != self.node_id
and self.master is not None
):
logger.info(
f"Node {result.session_id.master_node_id} elected master - demoting self"
)
await self.master.shutdown()
self.master = None
else:
logger.info(f"Node {result.node_id} elected master")
logger.info(f"Node {result.session_id.master_node_id} elected master")
if result.is_new_master:
await anyio.sleep(0)
if self.worker:
@@ -140,6 +161,7 @@ class Node:
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
initial_connection_messages=result.historic_messages,
connection_message_receiver=self.router.receiver(
@@ -153,7 +175,10 @@ class Node:
)
self._tg.start_soon(self.worker.run)
if self.api:
self.api.reset()
self.api.reset(result.session_id)
else:
if self.api:
self.api.unpause()


def main():
|
||||
from typing import final
|
||||
|
||||
import uvicorn
|
||||
from anyio import Event as AsyncTaskEvent
|
||||
from anyio import create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException
|
||||
@@ -14,6 +15,7 @@ from fastapi.staticfiles import StaticFiles
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
@@ -36,7 +38,7 @@ from exo.shared.types.commands import (
|
||||
# TODO: SpinUpInstance
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.state import State
|
||||
@@ -74,20 +76,28 @@ async def resolve_model_meta(model_id: str) -> ModelMetadata:
|
||||
class API:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
*,
|
||||
port: int = 8000,
|
||||
# Ideally this would be a MasterForwarderEvent but type system says no :(
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
# This lets us pause the API if an election is running
|
||||
election_receiver: Receiver[ElectionMessage],
|
||||
) -> None:
|
||||
self.state = State()
|
||||
self.command_sender = command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.election_receiver = election_receiver
|
||||
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
self.port = port
|
||||
|
||||
self.paused: bool = False
|
||||
self.paused_ev: AsyncTaskEvent = AsyncTaskEvent()
|
||||
|
||||
self.app = FastAPI()
|
||||
self._setup_cors()
|
||||
self._setup_routes()
|
||||
@@ -111,10 +121,17 @@ class API:
|
||||
] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
def reset(self):
|
||||
def reset(self, new_session_id: SessionId):
|
||||
self.state = State()
|
||||
self.session_id = new_session_id
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._chat_completion_queues = {}
|
||||
self.unpause()
|
||||
|
||||
def unpause(self):
|
||||
self.paused = False
|
||||
self.paused_ev.set()
|
||||
self.paused_ev = AsyncTaskEvent()
|
||||
|
||||
def _setup_cors(self) -> None:
|
||||
self.app.add_middleware(
|
||||
@@ -160,10 +177,9 @@ class API:
|
||||
)
|
||||
|
||||
def get_instance(self, instance_id: InstanceId) -> Instance:
|
||||
state = self.state
|
||||
if instance_id not in state.instances:
|
||||
if instance_id not in self.state.instances:
|
||||
raise HTTPException(status_code=404, detail="Instance not found")
|
||||
return state.instances[instance_id]
|
||||
return self.state.instances[instance_id]
|
||||
|
||||
async def delete_instance(self, instance_id: InstanceId) -> DeleteInstanceResponse:
|
||||
if instance_id not in self.state.instances:
|
||||
@@ -299,6 +315,7 @@ class API:
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(uvicorn_server.serve)
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
@@ -314,7 +331,15 @@ class API:
|
||||
):
|
||||
self._chat_completion_queues[event.command_id].put_nowait(event)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
async for message in ems:
|
||||
if message.clock > self.session_id.election_clock:
|
||||
self.paused = True
|
||||
|
||||
async def _send(self, command: Command):
|
||||
while self.paused:
|
||||
await self.paused_ev.wait()
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
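The new pause machinery above is an event gate: _pause_on_new_election flips paused as soon as it sees an election message with a clock newer than the current session's election_clock, and _send spins on paused_ev.wait() until reset()/unpause() release it with a fresh event. A self-contained sketch of that gate pattern with anyio (PauseGate and the demo names are illustrative, not the project's API):

import anyio

class PauseGate:
    # Illustrative stand-in for the paused / paused_ev pair on API.
    def __init__(self) -> None:
        self.paused = False
        self._ev = anyio.Event()

    def pause(self) -> None:
        self.paused = True

    def unpause(self) -> None:
        self.paused = False
        self._ev.set()            # wake any waiter
        self._ev = anyio.Event()  # fresh event for the next pause

    async def wait_until_open(self) -> None:
        # Loop rather than wait once: a wake-up may be followed by another pause.
        while self.paused:
            await self._ev.wait()

async def main() -> None:
    gate = PauseGate()
    gate.pause()  # an election round has started

    async def send_command() -> None:
        await gate.wait_until_open()
        print("command forwarded after the election settled")

    async with anyio.create_task_group() as tg:
        tg.start_soon(send_command)
        await anyio.sleep(0.1)  # election finishes
        gate.unpause()

anyio.run(main)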
@@ -16,8 +16,9 @@ from exo.shared.types.commands import (
RequestEventLog,
SpinUpInstance,
TaskFinished,
TestCommand,
)
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -38,6 +39,7 @@ class Master:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
# Receiving indexed events from the forwarder to be applied to state
@@ -51,6 +53,7 @@ class Master:
self.state = State()
self._tg: TaskGroup | None = None
self.node_id = node_id
self.session_id = session_id
self.command_task_mapping: dict[CommandId, TaskId] = {}
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
@@ -93,6 +96,8 @@ class Master:
generated_events: list[Event] = []
command = forwarder_command.command
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
@@ -184,6 +189,9 @@ class Master:
async def _event_processor(self) -> None:
with self.local_event_receiver as local_events:
async for local_event in local_events:
# Discard all events not from our session
if local_event.session != self.session_id:
continue
self._multi_buffer.ingest(
local_event.origin_idx,
local_event.event,
@@ -221,6 +229,7 @@ class Master:
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
@@ -233,6 +242,7 @@ class Master:
ForwarderEvent(
origin=self.node_id,
origin_idx=event.idx,
session=self.session_id,
event=event.event,
)
)
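With the session stamp on ForwarderEvent, the master's _event_processor above becomes session-scoped: anything tagged with a different SessionId is dropped before it reaches the ordering buffer. A tiny illustration of that filter with stand-in dataclasses (not the real pydantic models):

from dataclasses import dataclass

@dataclass(frozen=True)
class SessionId:
    master_node_id: str
    election_clock: int

@dataclass
class ForwarderEvent:
    origin: str
    origin_idx: int
    session: SessionId
    event: str

def accept(events: list[ForwarderEvent], current: SessionId) -> list[ForwarderEvent]:
    # Keep only events produced under the session this master was elected for.
    return [e for e in events if e.session == current]

old = SessionId("A", 0)
new = SessionId("A", 1)  # same master node, later election round -> different session
events = [ForwarderEvent("w1", 0, old, "stale"), ForwarderEvent("w1", 0, new, "fresh")]
assert [e.event for e in accept(events, new)] == ["fresh"]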
@@ -13,7 +13,7 @@ from exo.shared.types.commands import (
CreateInstance,
ForwarderCommand,
)
from exo.shared.types.common import NodeId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -38,6 +38,7 @@ from exo.utils.channels import channel
async def test_master():
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)

ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
@@ -58,6 +59,7 @@ async def test_master():

master = Master(
node_id,
session_id,
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
@@ -74,6 +76,7 @@ async def test_master():
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
session=session_id,
event=(
NodePerformanceMeasured(
node_id=node_id,
@@ -200,15 +200,15 @@ class Router:
await router.publish(message)

async def _networking_publish(self):
# This with/for pattern ensures this method doesn't return until after the receiver closes
# This is good for safety, but is mostly a redundant check.
with self.networking_receiver as networked_items:
async for topic, data in networked_items:
try:
logger.trace(f"Sending message on {topic} with payload {data}")
await self._net.gossipsub_publish(topic, data)
# As a hack, this also catches AllQueuesFull
# Need to fix that ASAP.
except NoPeersSubscribedToTopicError:
logger.trace(f"Failed to send over {topic} - No peers found.")
pass


def get_node_id_keypair(
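_networking_publish above treats NoPeersSubscribedToTopicError as a non-fatal, best-effort outcome (and, per the bindings change earlier in this diff, that exception currently also covers the all-queues-full case). A hedged sketch of the same swallow-and-continue pattern, with a stand-in exception and publish function rather than the real bindings:

import anyio

class NoPeersSubscribedToTopicError(Exception):
    # Stand-in for the exception raised by the gossipsub bindings.
    pass

async def publish(topic: str, data: bytes) -> None:
    # Placeholder for the real network publish; always "fails" here for the demo.
    raise NoPeersSubscribedToTopicError(topic)

async def publish_best_effort(topic: str, data: bytes) -> bool:
    # Same shape as the router loop: a publish with no subscribers is not an error,
    # the message is simply dropped and the loop moves on to the next item.
    try:
        await publish(topic, data)
        return True
    except NoPeersSubscribedToTopicError:
        return False

assert anyio.run(publish_best_effort, "global_events", b"{}") is False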
@@ -11,7 +11,8 @@ from anyio.abc import TaskGroup
from loguru import logger

from exo.routing.connection_message import ConnectionMessage
from exo.shared.types.common import NodeId
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, Sender
from exo.utils.pydantic_ext import CamelCaseModel

@@ -21,18 +22,24 @@ ELECTION_TIMEOUT = 3.0
class ElectionMessage(CamelCaseModel):
clock: int
seniority: int
node_id: NodeId
proposed_session: SessionId
commands_seen: int

# Could eventually include a list of neighbour nodes for centrality
def __lt__(self, other: Self):
def __lt__(self, other: Self) -> bool:
if self.seniority != other.seniority:
return self.seniority < other.seniority
elif self.commands_seen != other.commands_seen:
return self.commands_seen < other.commands_seen
else:
return self.node_id < other.node_id
return (
self.proposed_session.master_node_id
< other.proposed_session.master_node_id
)


class ElectionResult(CamelCaseModel):
node_id: NodeId
session_id: SessionId
is_new_master: bool
historic_messages: list[ConnectionMessage]

@@ -41,11 +48,12 @@ class Election:
def __init__(
self,
node_id: NodeId,
*,
election_message_receiver: Receiver[ElectionMessage],
election_message_sender: Sender[ElectionMessage],
election_result_sender: Sender[ElectionResult],
connection_message_receiver: Receiver[ConnectionMessage],
*,
command_receiver: Receiver[ForwarderCommand],
is_candidate: bool = True,
seniority: int = 0,
):
@@ -55,13 +63,18 @@ class Election:
self.seniority = seniority if is_candidate else -1
self.clock = 0
self.node_id = node_id
self.commands_seen = 0
# Every node spawns as master
self.master_node_id: NodeId = node_id
self.current_session: SessionId = SessionId(
master_node_id=node_id, election_clock=0
)

# Senders/Receivers
self._em_sender = election_message_sender
self._em_receiver = election_message_receiver
self._er_sender = election_result_sender
self._cm_receiver = connection_message_receiver
self._co_receiver = command_receiver

# Campaign state
self._candidates: list[ElectionMessage] = []
@@ -76,6 +89,7 @@ class Election:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
await self._campaign(None)

if self._campaign_cancel_scope is not None:
@@ -84,12 +98,12 @@ class Election:
if self._campaign_done is not None:
await self._campaign_done.wait()

async def elect(self, node_id: NodeId) -> None:
is_new_master = node_id != self.master_node_id
self.master_node_id = node_id
async def elect(self, em: ElectionMessage) -> None:
is_new_master = em.proposed_session != self.current_session
self.current_session = em.proposed_session
await self._er_sender.send(
ElectionResult(
node_id=node_id,
session_id=em.proposed_session,
is_new_master=is_new_master,
historic_messages=self._connection_messages,
)
@@ -106,7 +120,7 @@ class Election:
async def _election_receiver(self) -> None:
with self._em_receiver as election_messages:
async for message in election_messages:
if message.node_id == self.node_id:
if message.proposed_session.master_node_id == self.node_id:
# Drop messages from us (See exo.routing.router)
continue
# If a new round is starting, we participate
@@ -129,6 +143,11 @@ class Election:
await self._campaign(None)
self._connection_messages.append(msg)

async def _command_counter(self) -> None:
with self._co_receiver as commands:
async for _command in commands:
self.commands_seen += 1

async def _campaign(self, initial_message: ElectionMessage | None) -> None:
# Kill the old campaign
if self._campaign_cancel_scope:
@@ -167,10 +186,15 @@ class Election:
candidates = sorted(candidates)
logger.debug(f"Election queue {candidates}")
elected = candidates[-1]
logger.info("Election finished")
if self.node_id == elected.node_id and self.seniority >= 0:
if (
self.node_id == elected.proposed_session.master_node_id
and self.seniority >= 0
):
self.seniority = max(self.seniority, len(candidates))
await self.elect(elected.node_id)
logger.info(
f"Election finished, new SessionId({elected.proposed_session})"
)
await self.elect(elected)
except get_cancelled_exc_class():
logger.info("Election cancelled")
finally:
@@ -180,4 +204,13 @@ class Election:

def _election_status(self, clock: int | None = None) -> ElectionMessage:
c = self.clock if clock is None else clock
return ElectionMessage(clock=c, seniority=self.seniority, node_id=self.node_id)
return ElectionMessage(
proposed_session=(
self.current_session
if self.current_session.master_node_id == self.node_id
else SessionId(master_node_id=self.node_id, election_clock=c)
),
clock=c,
seniority=self.seniority,
commands_seen=self.commands_seen,
)
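The reworked __lt__ above makes the winner the maximum of the candidate list under a three-level ordering: seniority first, then commands_seen, then the proposed master's node id as a deterministic tie-break (sorted(candidates)[-1] in _campaign). Because Python compares tuples lexicographically, the same ordering can be checked with plain tuples; the node names and numbers below are illustrative only:

# Each candidate reduced to its effective sort key:
# (seniority, commands_seen, proposed master node id).
# max() over these keys picks the same winner as sorted(candidates)[-1] above.
candidates = [
    ("A", (1, 10, "A")),
    ("B", (1, 42, "B")),  # equal seniority to A, but it has seen more commands
    ("C", (0, 99, "C")),  # lower seniority loses regardless of commands_seen
]
winner = max(candidates, key=lambda c: c[1])[0]
assert winner == "B"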
@@ -3,7 +3,8 @@ from anyio import create_task_group, fail_after, move_on_after

from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.common import NodeId
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import channel

# ======= #
@@ -11,8 +12,28 @@ from exo.utils.channels import channel
# ======= #


def em(clock: int, seniority: int, node_id: str) -> ElectionMessage:
return ElectionMessage(clock=clock, seniority=seniority, node_id=NodeId(node_id))
def em(
clock: int,
seniority: int,
node_id: str,
commands_seen: int = 0,
election_clock: int | None = None,
) -> ElectionMessage:
"""
Helper to build ElectionMessages for a given proposer node.

The new API carries a proposed SessionId (master_node_id + election_clock).
By default we use the same value for election_clock as the 'clock' of the round.
"""
return ElectionMessage(
clock=clock,
seniority=seniority,
proposed_session=SessionId(
master_node_id=NodeId(node_id),
election_clock=clock if election_clock is None else election_clock,
),
commands_seen=commands_seen,
)


@pytest.fixture
@@ -43,8 +64,10 @@ async def test_single_round_broadcasts_and_updates_seniority_on_self_win(
em_in_tx, em_in_rx = channel[ElectionMessage]()
# Election results produced by the Election (we'll observe these)
er_tx, er_rx = channel[ElectionResult]()
# Connection messages (unused in this test but required by ctor)
# Connection messages
cm_tx, cm_rx = channel[ConnectionMessage]()
# Commands
co_tx, co_rx = channel[ForwarderCommand]()

election = Election(
node_id=NodeId("B"),
@@ -52,6 +75,7 @@
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)

@@ -64,18 +88,21 @@
# Expect our broadcast back to the peer side for this round only
while True:
got = await em_out_rx.receive()
if got.clock == 1 and got.node_id == NodeId("B"):
if got.clock == 1 and got.proposed_session.master_node_id == NodeId(
"B"
):
break

# Wait for the round to finish and produce an ElectionResult
result = await er_rx.receive()
assert result.node_id == NodeId("B")
assert result.session_id.master_node_id == NodeId("B")
# We spawned as master; electing ourselves again is not "new master".
assert result.is_new_master is False

# Close inbound streams to end the receivers (and run())
await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()

# We should have updated seniority to 2 (A + B).
assert election.seniority == 2
@@ -93,6 +120,7 @@ async def test_peer_with_higher_seniority_wins_and_we_switch_master(
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()

election = Election(
node_id=NodeId("ME"),
@@ -100,6 +128,7 @@
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)

@@ -117,13 +146,19 @@
assert got.seniority == 0
break

# After the timeout, election result should report the peer as master
result = await er_rx.receive()
assert result.node_id == NodeId("PEER")
# After the timeout, election result for clock=1 should report the peer as master
# (Skip any earlier result from the boot campaign at clock=0 by filtering on election_clock)
while True:
result = await er_rx.receive()
if result.session_id.election_clock == 1:
break

assert result.session_id.master_node_id == NodeId("PEER")
assert result.is_new_master is True

await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()

# We lost → seniority unchanged
assert election.seniority == 0
@@ -139,6 +174,7 @@ async def test_ignores_older_messages(fast_timeout: None) -> None:
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, _er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()

election = Election(
node_id=NodeId("ME"),
@@ -146,6 +182,7 @@
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)

@@ -169,8 +206,9 @@
got_second = True
assert not got_second, "Should not receive a broadcast for an older round"

await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()

# Not asserting on the result; focus is on ignore behavior.

@@ -186,6 +224,7 @@ async def test_two_rounds_emit_two_broadcasts_and_increment_clock(
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, _er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()

election = Election(
node_id=NodeId("ME"),
@@ -193,6 +232,7 @@
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)

@@ -214,8 +254,9 @@
if m2.clock == 2:
break

await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()

# Not asserting on who won; just that both rounds were broadcast.

@@ -230,6 +271,7 @@ async def test_promotion_new_seniority_counts_participants(fast_timeout: None) -
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()

election = Election(
node_id=NodeId("ME"),
@@ -237,6 +279,7 @@
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)

@@ -251,14 +294,17 @@
# We should see exactly one broadcast from us for this round
while True:
got = await em_out_rx.receive()
if got.clock == 7 and got.node_id == NodeId("ME"):
if got.clock == 7 and got.proposed_session.master_node_id == NodeId(
"ME"
):
break

# Wait for the election to finish so seniority updates
_ = await er_rx.receive()

await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()

# We + A + B = 3 → new seniority expected to be 3
assert election.seniority == 3
@@ -276,6 +322,7 @@ async def test_connection_message_triggers_new_round_broadcast(
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, _er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()

election = Election(
node_id=NodeId("ME"),
@@ -283,6 +330,7 @@
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
)

@@ -303,11 +351,75 @@
# Expect a broadcast for the new round at clock=1
while True:
got = await em_out_rx.receive()
if got.clock == 1 and got.node_id == NodeId("ME"):
if got.clock == 1 and got.proposed_session.master_node_id == NodeId(
"ME"
):
break

# Close promptly to avoid waiting for campaign completion
await em_in_tx.aclose()
await cm_tx.aclose()
em_in_tx.close()
cm_tx.close()
co_tx.close()

# After cancellation (before election finishes), no seniority changes asserted here.


@pytest.mark.anyio
async def test_tie_breaker_prefers_node_with_more_commands_seen(
fast_timeout: None,
) -> None:
"""
With equal seniority, the node that has seen more commands should win the election.
We increase our local 'commands_seen' by sending TestCommand()s before triggering the round.
"""
em_out_tx, em_out_rx = channel[ElectionMessage]()
em_in_tx, em_in_rx = channel[ElectionMessage]()
er_tx, er_rx = channel[ElectionResult]()
cm_tx, cm_rx = channel[ConnectionMessage]()
co_tx, co_rx = channel[ForwarderCommand]()

me = NodeId("ME")

election = Election(
node_id=me,
election_message_receiver=em_in_rx,
election_message_sender=em_out_tx,
election_result_sender=er_tx,
connection_message_receiver=cm_rx,
command_receiver=co_rx,
is_candidate=True,
seniority=0,
)

async with create_task_group() as tg:
with fail_after(2):
tg.start_soon(election.run)

# Pump local commands so our commands_seen is high before the round starts
for _ in range(50):
await co_tx.send(
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
)

# Trigger a round at clock=1 with a peer of equal seniority but fewer commands
await em_in_tx.send(
em(clock=1, seniority=0, node_id="PEER", commands_seen=5)
)

# Observe our broadcast for this round (to ensure we've joined the round)
while True:
got = await em_out_rx.receive()
if got.clock == 1 and got.proposed_session.master_node_id == me:
# We don't assert exact count, just that we've participated this round.
break

# The elected result for clock=1 should be us due to higher commands_seen
while True:
result = await er_rx.receive()
if result.session_id.master_node_id == me:
assert result.session_id.election_clock in (0, 1)
break

em_in_tx.close()
cm_tx.close()
co_tx.close()
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
@@ -10,19 +8,14 @@ from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
# TODO: We need to have a distinction between create instance and spin up instance.
|
||||
class CommandType(str, Enum):
|
||||
ChatCompletion = "ChatCompletion"
|
||||
CreateInstance = "CreateInstance"
|
||||
SpinUpInstance = "SpinUpInstance"
|
||||
DeleteInstance = "DeleteInstance"
|
||||
TaskFinished = "TaskFinished"
|
||||
RequestEventLog = "RequestEventLog"
|
||||
|
||||
|
||||
class BaseCommand(TaggedModel):
|
||||
command_id: CommandId = Field(default_factory=CommandId)
|
||||
|
||||
|
||||
class TestCommand(BaseCommand):
|
||||
pass
|
||||
|
||||
|
||||
class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams
|
||||
|
||||
@@ -48,7 +41,8 @@ class RequestEventLog(BaseCommand):
|
||||
|
||||
|
||||
Command = (
|
||||
RequestEventLog
|
||||
TestCommand
|
||||
| RequestEventLog
|
||||
| ChatCompletion
|
||||
| CreateInstance
|
||||
| SpinUpInstance
|
||||
|
||||
@@ -23,6 +23,11 @@ class NodeId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class SessionId(CamelCaseModel):
|
||||
master_node_id: NodeId
|
||||
election_clock: int
|
||||
|
||||
|
||||
class CommandId(Id):
|
||||
pass
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection, NodePerformanceProfile
|
||||
from exo.shared.types.chunks import CommandId, GenerationChunk
|
||||
from exo.shared.types.common import Id, NodeId
|
||||
from exo.shared.types.common import Id, NodeId, SessionId
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.common import InstanceId, WorkerStatus
|
||||
@@ -177,4 +177,5 @@ class ForwarderEvent(CamelCaseModel):
|
||||
|
||||
origin_idx: int = Field(ge=0)
|
||||
origin: NodeId
|
||||
session: SessionId
|
||||
event: Event
|
||||
|
||||
@@ -13,7 +13,7 @@ from loguru import logger
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -75,6 +75,7 @@ class Worker:
|
||||
def __init__(
|
||||
self,
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
shard_downloader: ShardDownloader,
|
||||
*,
|
||||
initial_connection_messages: list[ConnectionMessage],
|
||||
@@ -91,6 +92,7 @@ class Worker:
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
self.shard_downloader: ShardDownloader = shard_downloader
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.local_event_sender = local_event_sender
|
||||
@@ -634,6 +636,7 @@ class Worker:
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=self.local_event_index,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.debug(
|
||||
|
||||
@@ -5,12 +5,15 @@ from anyio import fail_after
|
||||
|
||||
from exo.routing.topics import ConnectionMessage, ForwarderCommand, ForwarderEvent
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, TaskStateUpdated
|
||||
from exo.shared.types.tasks import TaskId, TaskStatus
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.worker.download.shard_downloader import NoopShardDownloader, ShardDownloader
|
||||
from exo.worker.main import Worker
|
||||
from exo.worker.tests.constants import MASTER_NODE_ID
|
||||
|
||||
session = SessionId(master_node_id=MASTER_NODE_ID, election_clock=0)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -19,11 +22,17 @@ class WorkerMailbox:
|
||||
receiver: Receiver[ForwarderEvent]
|
||||
counter: int = 0
|
||||
|
||||
async def append_events(self, events: list[Event], *, origin: NodeId):
|
||||
async def append_events(
|
||||
self,
|
||||
events: list[Event],
|
||||
*,
|
||||
origin: NodeId,
|
||||
):
|
||||
for event in events:
|
||||
await self.sender.send(
|
||||
ForwarderEvent(
|
||||
origin=origin,
|
||||
session=session,
|
||||
event=event,
|
||||
origin_idx=self.counter,
|
||||
)
|
||||
@@ -45,6 +54,7 @@ def create_worker_void_mailbox(
|
||||
shard_downloader = NoopShardDownloader()
|
||||
return Worker(
|
||||
node_id,
|
||||
session_id=session,
|
||||
shard_downloader=shard_downloader,
|
||||
initial_connection_messages=[],
|
||||
connection_message_receiver=channel[ConnectionMessage]()[1],
|
||||
@@ -64,6 +74,7 @@ def create_worker_and_mailbox(
|
||||
sender, grecv = channel[ForwarderEvent]()
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id=session,
|
||||
shard_downloader=shard_downloader,
|
||||
initial_connection_messages=[],
|
||||
connection_message_receiver=channel[ConnectionMessage]()[1],
|
||||
@@ -84,6 +95,7 @@ def create_worker_with_old_mailbox(
|
||||
# This function is subtly complex, come talk to Evan if you want to know what it's actually doing.
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id=session,
|
||||
shard_downloader=shard_downloader,
|
||||
initial_connection_messages=[],
|
||||
connection_message_receiver=channel[ConnectionMessage]()[1],
|
||||
|
||||