Compare commits

...

1 Commits

Author SHA1 Message Date
Evan
4c9bc26c1a add system ids 2026-02-18 21:33:48 +00:00
10 changed files with 46 additions and 51 deletions

View File

@@ -1,7 +1,6 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
@@ -22,7 +21,7 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -46,8 +45,8 @@ class DownloadCoordinator:
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
@@ -295,15 +294,16 @@ class DownloadCoordinator:
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)

View File

@@ -1,11 +1,10 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Iterator, Self
from typing import Self
import anyio
from anyio.abc import TaskGroup
@@ -38,12 +37,11 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
async def create(cls, args: "Args") -> "Self":
async def create(cls, args: "Args") -> Self:
keypair = get_node_id_keypair()
node_id = NodeId(keypair.to_peer_id().to_base58())
session_id = SessionId(master_node_id=node_id, election_clock=0)
@@ -57,9 +55,6 @@ class Node:
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
@@ -68,7 +63,6 @@ class Node:
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
offline=args.offline,
)
else:
@@ -95,7 +89,6 @@ class Node:
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -133,7 +126,6 @@ class Node:
master,
api,
node_id,
event_index_counter,
args.offline,
)
@@ -213,7 +205,6 @@ class Node:
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
@@ -224,7 +215,6 @@ class Node:
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
offline=self.offline,
)
self._tg.start_soon(self.download_coordinator.run)
@@ -242,7 +232,6 @@ class Node:
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:

View File

@@ -131,7 +131,7 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -183,6 +183,7 @@ class API:
) -> None:
self.state = State()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
@@ -233,6 +234,7 @@ class API:
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
@@ -546,7 +548,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -891,7 +893,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -977,7 +979,7 @@ class API:
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
raise
finally:
@@ -1408,7 +1410,7 @@ class API:
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.origin != self.session_id.master_node_id:
if f_event.session != self.session_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
for idx, event in self.event_buffer.drain_indexed():
@@ -1472,12 +1474,12 @@ class API:
while self.paused:
await self.paused_ev.wait()
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
ForwarderCommand(origin=self._system_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self.node_id, command=command)
ForwarderDownloadCommand(origin=self._system_id, command=command)
)
async def start_download(

View File

@@ -29,7 +29,7 @@ from exo.shared.types.commands import (
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
ForwarderEvent,
@@ -90,7 +90,8 @@ class Master:
self._loopback_event_sender: Sender[ForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
@@ -288,7 +289,7 @@ class Master:
):
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id, command=cmd
origin=self._system_id, command=cmd
)
)
generated_events.extend(transition_events)
@@ -415,7 +416,7 @@ class Master:
async for event in events:
await self._loopback_event_sender.send(
ForwarderEvent(
origin=NodeId(f"master_{self.node_id}"),
origin=self._system_id,
origin_idx=local_index,
session=self.session_id,
event=event,
@@ -428,7 +429,7 @@ class Master:
# Convenience method since this line is ugly
await self.global_event_sender.send(
ForwarderEvent(
origin=self.node_id,
origin=self._system_id,
origin_idx=event.idx,
session=self.session_id,
event=event.event,

View File

@@ -15,7 +15,7 @@ from exo.shared.types.commands import (
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -75,13 +75,12 @@ async def test_master():
async with anyio.create_task_group() as tg:
tg.start_soon(master.run)
sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
# inject a NodeGatheredInfo event
logger.info("inject a NodeGatheredInfo event")
await local_event_sender.send(
ForwarderEvent(
origin_idx=0,
origin=sender_node_id,
origin=SystemId("Worker"),
session=session_id,
event=(
NodeGatheredInfo(
@@ -108,7 +107,7 @@ async def test_master():
logger.info("inject a CreateInstance Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
origin=SystemId("API"),
command=(
PlaceInstance(
command_id=CommandId(),
@@ -133,7 +132,7 @@ async def test_master():
logger.info("inject a TextGeneration Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
origin=SystemId("API"),
command=(
TextGeneration(
command_id=CommandId(),

View File

@@ -4,7 +4,7 @@ from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.utils.channels import channel
# ======= #
@@ -384,7 +384,7 @@ async def test_tie_breaker_prefers_node_with_more_commands_seen() -> None:
# Pump local commands so our commands_seen is high before the round starts
for _ in range(50):
await co_tx.send(
ForwarderCommand(origin=NodeId("SOMEONE"), command=TestCommand())
ForwarderCommand(origin=SystemId("SOMEONE"), command=TestCommand())
)
# Trigger a round at clock=1 with a peer of equal seniority but fewer commands

View File

@@ -6,7 +6,7 @@ from exo.shared.types.api import (
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -100,10 +100,10 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
origin: SystemId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
origin: SystemId
command: DownloadCommand

View File

@@ -25,6 +25,10 @@ class NodeId(Id):
pass
class SystemId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -166,6 +166,6 @@ class ForwarderEvent(CamelCaseModel):
"""An event the forwarder will serialize and send over the network"""
origin_idx: int = Field(ge=0)
origin: NodeId
origin: SystemId
session: SessionId
event: Event

View File

@@ -1,7 +1,6 @@
from collections import defaultdict
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, fail_after
@@ -17,7 +16,7 @@ from exo.shared.types.commands import (
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
@@ -64,14 +63,12 @@ class Worker:
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]()
@@ -86,6 +83,8 @@ class Worker:
self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
@@ -132,7 +131,7 @@ class Worker:
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
if f_event.origin != self.session_id.master_node_id:
if f_event.session != self.session_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
event_id = f_event.event.event_id
@@ -212,7 +211,7 @@ class Worker:
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
origin=self._system_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
@@ -312,7 +311,7 @@ class Worker:
)
await self.command_sender.send(
ForwarderCommand(
origin=self.node_id,
origin=self._system_id,
command=RequestEventLog(since_idx=since_idx),
)
)
@@ -339,15 +338,16 @@ class Worker:
return runner
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe