Compare commits

..

1 Commits

Author SHA1 Message Date
Alex Cheema
dc9000290b fix instance type mismatch in /instance/previews endpoint
Derive instance_meta from the actual instance type returned by
place_instance() instead of using the requested instance_meta from
the loop variable. place_instance() overrides single-node placements
to MlxRing, but the preview response was still reporting the original
requested type (e.g., MlxJaccl), causing a mismatch.

Closes #1426

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 18:04:38 +00:00
10 changed files with 396 additions and 328 deletions

View File

@@ -469,11 +469,11 @@
<td class="px-4 py-3 text-center align-middle">
{#if cell.kind === "completed"}
<div
class="flex flex-col items-center gap-1"
class="flex flex-col items-center gap-0.5"
title="Completed ({formatBytes(cell.totalBytes)})"
>
<svg
class="w-7 h-7 text-green-400"
class="w-5 h-5 text-green-400"
viewBox="0 0 20 20"
fill="currentColor"
>
@@ -483,18 +483,18 @@
clip-rule="evenodd"
></path>
</svg>
<span class="text-xs text-exo-light-gray/70"
<span class="text-[10px] text-exo-light-gray/70"
>{formatBytes(cell.totalBytes)}</span
>
<button
type="button"
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5 cursor-pointer"
class="text-exo-light-gray/40 hover:text-red-400 transition-colors mt-0.5"
onclick={() =>
deleteDownload(col.nodeId, row.modelId)}
title="Delete from this node"
>
<svg
class="w-5 h-5"
class="w-3.5 h-3.5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
@@ -517,11 +517,11 @@
cell.speed,
)} - ETA {formatEta(cell.etaMs)}"
>
<span class="text-exo-yellow text-sm font-medium"
<span class="text-exo-yellow text-xs font-medium"
>{clampPercent(cell.percentage).toFixed(1)}%</span
>
<div
class="w-16 h-2 bg-exo-black/60 rounded-sm overflow-hidden"
class="w-14 h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="h-full bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
@@ -530,25 +530,25 @@
).toFixed(1)}%"
></div>
</div>
<span class="text-[10px] text-exo-light-gray/60"
<span class="text-[9px] text-exo-light-gray/60"
>{formatSpeed(cell.speed)}</span
>
</div>
{:else if cell.kind === "pending"}
<div
class="flex flex-col items-center gap-1"
class="flex flex-col items-center gap-0.5"
title={cell.downloaded > 0
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded (paused)`
? `${formatBytes(cell.downloaded)} / ${formatBytes(cell.total)} downloaded`
: "Download pending"}
>
{#if cell.downloaded > 0 && cell.total > 0}
<span class="text-exo-light-gray/70 text-xs"
<span class="text-exo-light-gray/70 text-[10px]"
>{formatBytes(cell.downloaded)} / {formatBytes(
cell.total,
)}</span
>
<div
class="w-full h-1.5 bg-white/10 rounded-full overflow-hidden"
class="w-full h-1 bg-white/10 rounded-full overflow-hidden"
>
<div
class="h-full bg-exo-light-gray/40 rounded-full"
@@ -558,55 +558,9 @@
).toFixed(1)}%"
></div>
</div>
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/50 hover:text-exo-yellow transition-colors cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Resume download on this node"
>
<svg
class="w-5 h-5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{:else}
<span class="text-exo-light-gray/40 text-[10px]"
>paused</span
>
{/if}
{:else if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors cursor-pointer"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Start download on this node"
<span class="text-exo-light-gray/40 text-[9px]"
>paused</span
>
<svg
class="w-6 h-6"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{:else}
<span class="text-exo-light-gray/50 text-sm">...</span
>
@@ -614,11 +568,11 @@
</div>
{:else if cell.kind === "failed"}
<div
class="flex flex-col items-center gap-1"
class="flex flex-col items-center gap-0.5"
title="Download failed"
>
<svg
class="w-7 h-7 text-red-400"
class="w-5 h-5 text-red-400"
viewBox="0 0 20 20"
fill="currentColor"
>
@@ -631,13 +585,13 @@
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors cursor-pointer"
class="text-exo-light-gray/40 hover:text-exo-yellow transition-colors"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Retry download on this node"
>
<svg
class="w-5 h-5"
class="w-3.5 h-3.5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
@@ -663,13 +617,13 @@
{#if row.shardMetadata}
<button
type="button"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100 cursor-pointer"
class="text-exo-light-gray/30 hover:text-exo-yellow transition-colors mt-0.5 opacity-0 group-hover:opacity-100"
onclick={() =>
startDownload(col.nodeId, row.shardMetadata!)}
title="Download to this node"
>
<svg
class="w-5 h-5"
class="w-3.5 h-3.5"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"

View File

@@ -1,6 +1,7 @@
import asyncio
import socket
from dataclasses import dataclass, field
from random import random
import anyio
from anyio import current_time
@@ -21,9 +22,13 @@ from exo.shared.types.commands import (
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId
from exo.shared.types.common import NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
# TODO(evan): just for acks, should delete this ASAP
GlobalForwarderEvent,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
@@ -34,28 +39,40 @@ from exo.shared.types.worker.downloads import (
DownloadProgress,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.task_group import TaskGroup
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
event_sender: Sender[Event]
local_event_sender: Sender[LocalForwarderEvent]
# ack stuff
_global_event_receiver: Receiver[GlobalForwarderEvent]
_out_for_delivery: dict[EventId, LocalForwarderEvent] = field(default_factory=dict)
offline: bool = False
_system_id: SystemId = field(default_factory=SystemId)
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
# Per-model throttle for download progress events
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
if self.offline:
self.shard_downloader.set_internet_connection(False)
self.shard_downloader.on_progress(self._download_progress_callback)
@@ -111,7 +128,10 @@ class DownloadCoordinator:
try:
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._clear_ofd)
if not self.offline:
tg.start_soon(self._check_internet_connection)
finally:
@@ -149,6 +169,20 @@ class DownloadCoordinator:
def shutdown(self) -> None:
self._tg.cancel_tasks()
# directly copied from worker
async def _resend_out_for_delivery(self) -> None:
# This can also be massively tightened, we should check events are at least a certain age before resending.
# Exponential backoff would also certainly help here.
while True:
await anyio.sleep(1 + random())
for event in self._out_for_delivery.copy().values():
await self.local_event_sender.send(event)
async def _clear_ofd(self) -> None:
with self._global_event_receiver as events:
async for event in events:
self._out_for_delivery.pop(event.event.event_id, None)
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
@@ -321,6 +355,23 @@ class DownloadCoordinator:
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
self._out_for_delivery[event.event_id] = fe
async def _emit_existing_download_progress(self) -> None:
try:
while True:

View File

@@ -823,7 +823,6 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
final_file_exists = await aios.path.exists(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=shard.model_card.model_id,
repo_revision=revision,
@@ -833,9 +832,7 @@ async def download_shard(
total=Memory.from_bytes(file.size or 0),
speed=0,
eta=timedelta(0),
status="complete"
if final_file_exists and downloaded_bytes == file.size
else "not_started",
status="complete" if downloaded_bytes == file.size else "not_started",
start_time=time.time(),
)

View File

@@ -0,0 +1,98 @@
from typing import Any
import anyio
import pytest
from exo.download.coordinator import DownloadCoordinator
from exo.download.shard_downloader import NoopShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.events import (
GlobalForwarderEvent,
LocalForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
DownloadPending,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.utils.channels import channel
# Use the builtin NoopShardDownloader directly it already implements the required abstract interface.
# No additional subclass is needed for this test.
@pytest.mark.anyio
async def test_ack_behaviour():
# Create channels (type Any for simplicity)
_, command_receiver = channel[Any]()
local_sender, _ = channel[Any]()
global_sender, global_receiver = channel[Any]()
# Minimal identifiers
node_id = NodeId()
session_id = SessionId(master_node_id=node_id, election_clock=0)
# Create a dummy model card and shard metadata
model_id = ModelId("test/model")
model_card = ModelCard(
model_id=model_id,
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
)
shard = PipelineShardMetadata(
model_card=model_card,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
)
# Instantiate the coordinator with the dummy downloader
coord = DownloadCoordinator(
node_id=node_id,
session_id=session_id,
shard_downloader=NoopShardDownloader(),
download_command_receiver=command_receiver,
local_event_sender=local_sender,
_global_event_receiver=global_receiver,
)
async with anyio.create_task_group() as tg:
# Start the forwarding and ackclearing loops
tg.start_soon(coord._forward_events) # pyright: ignore[reportPrivateUsage]
tg.start_soon(coord._clear_ofd) # pyright: ignore[reportPrivateUsage]
# Send a pending download progress event via the internal event sender
pending = DownloadPending(
node_id=node_id,
shard_metadata=shard,
model_directory="/tmp/model",
)
await coord.event_sender.send(NodeDownloadProgress(download_progress=pending))
# Allow the forwarder to process the event
await anyio.sleep(0.1)
# There should be exactly one entry awaiting ACK
assert len(coord._out_for_delivery) == 1 # pyright: ignore[reportPrivateUsage]
# Retrieve the stored LocalForwarderEvent
stored_fe: LocalForwarderEvent = next(iter(coord._out_for_delivery.values())) # pyright: ignore[reportPrivateUsage]
# Simulate receiving a global ack for this event
ack = GlobalForwarderEvent(
origin_idx=0,
origin=node_id,
session=session_id,
event=stored_fe.event,
)
await global_sender.send(ack)
# Give the clearofd task a moment to process the ack
await anyio.sleep(0.1)
# The outfordelivery map should now be empty
assert len(coord._out_for_delivery) == 0 # pyright: ignore[reportPrivateUsage]
# Cancel background tasks
tg.cancel_scope.cancel()

View File

@@ -15,7 +15,6 @@ from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.event_router import EventRouter
from exo.routing.router import Router, get_node_id_keypair
from exo.shared.constants import EXO_LOG
from exo.shared.election import Election, ElectionResult
@@ -30,7 +29,6 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
event_router: EventRouter
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
@@ -54,12 +52,6 @@ class Node:
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
event_router = EventRouter(
session_id,
command_sender=router.sender(topics.COMMANDS),
external_outbound=router.sender(topics.LOCAL_EVENTS),
external_inbound=router.receiver(topics.GLOBAL_EVENTS),
)
logger.info(f"Starting node {node_id}")
@@ -67,10 +59,13 @@ class Node:
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
event_sender=event_router.sender(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
offline=args.offline,
# TODO(evan): remove
_global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
)
else:
download_coordinator = None
@@ -78,8 +73,9 @@ class Node:
if args.spawn_api:
api = API(
node_id,
session_id,
port=args.api_port,
event_receiver=event_router.receiver(),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
@@ -90,8 +86,9 @@ class Node:
if not args.no_worker:
worker = Worker(
node_id,
event_receiver=event_router.receiver(),
event_sender=event_router.sender(),
session_id,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
@@ -102,7 +99,6 @@ class Node:
master = Master(
node_id,
session_id,
event_sender=event_router.sender(),
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
@@ -125,7 +121,6 @@ class Node:
return cls(
router,
event_router,
download_coordinator,
worker,
election,
@@ -141,7 +136,6 @@ class Node:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.event_router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
@@ -189,7 +183,6 @@ class Node:
self.master = Master(
self.node_id,
result.session_id,
event_sender=self.event_router.sender(),
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
@@ -213,24 +206,21 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
self.event_router.shutdown()
self.event_router = EventRouter(
result.session_id,
self.router.sender(topics.COMMANDS),
self.router.receiver(topics.GLOBAL_EVENTS),
self.router.sender(topics.LOCAL_EVENTS),
)
self._tg.start_soon(self.event_router.run)
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
event_sender=self.event_router.sender(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
offline=self.offline,
# TODO(evan): remove
_global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
@@ -238,8 +228,11 @@ class Node:
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
event_receiver=self.event_router.receiver(),
event_sender=self.event_router.sender(),
result.session_id,
global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
@@ -247,7 +240,7 @@ class Node:
)
self._tg.start_soon(self.worker.run)
if self.api:
self.api.reset(result.won_clock, self.event_router.receiver())
self.api.reset(result.session_id, result.won_clock)
else:
if self.api:
self.api.unpause(result.won_clock)

View File

@@ -140,10 +140,11 @@ from exo.shared.types.commands import (
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SystemId
from exo.shared.types.common import CommandId, Id, NodeId, SessionId, SystemId
from exo.shared.types.events import (
ChunkGenerated,
Event,
GlobalForwarderEvent,
IndexedEvent,
TracesMerged,
)
@@ -167,10 +168,16 @@ from exo.shared.types.openai_responses import (
)
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.instances import (
Instance,
InstanceId,
InstanceMeta,
MlxJacclInstance,
)
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.task_group import TaskGroup
_API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api"
@@ -194,9 +201,10 @@ class API:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
*,
port: int,
event_receiver: Receiver[IndexedEvent],
global_event_receiver: Receiver[GlobalForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
@@ -207,9 +215,11 @@ class API:
self._system_id = SystemId()
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_receiver = event_receiver
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.last_completed_election: int = 0
self.port = port
@@ -249,18 +259,17 @@ class API:
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup = TaskGroup()
def reset(self, result_clock: int, event_receiver: Receiver[IndexedEvent]):
def reset(self, new_session_id: SessionId, result_clock: int):
logger.info("Resetting API State")
self._event_log.close()
self._event_log = DiskEventLog(_API_EVENT_LOG_DIR)
self.state = State()
self._system_id = SystemId()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
self.event_receiver.close()
self.event_receiver = event_receiver
self._tg.start_soon(self._apply_state)
def unpause(self, result_clock: int):
logger.info("Unpausing API")
@@ -509,6 +518,14 @@ class API:
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
# Derive instance_meta from the actual instance type, since
# place_instance() may override it (e.g., single-node → MlxRing)
actual_instance_meta = (
InstanceMeta.MlxJaccl
if isinstance(instance, MlxJacclInstance)
else InstanceMeta.MlxRing
)
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
total_bytes = model_card.storage_size.in_bytes
@@ -521,14 +538,14 @@ class API:
if (
model_card.model_id,
sharding,
instance_meta,
actual_instance_meta,
len(placement_node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance_meta=actual_instance_meta,
instance=instance,
memory_delta_by_node=memory_delta_by_node or None,
error=None,
@@ -538,7 +555,7 @@ class API:
(
model_card.model_id,
sharding,
instance_meta,
actual_instance_meta,
len(placement_node_ids),
)
)
@@ -1602,7 +1619,7 @@ class API:
finally:
self._event_log.close()
self.command_sender.close()
self.event_receiver.close()
self.global_event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
@@ -1619,33 +1636,38 @@ class API:
)
async def _apply_state(self):
idx = 0
with self.event_receiver as events:
async for event in events:
self._event_log.append(event.event)
self.state = apply(self.state, event)
idx += 1
event = event.event
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
if queue := self._image_generation_queues.get(
event.command_id, None
):
assert isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(event.command_id, None)
if queue := self._text_generation_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
if isinstance(event, ChunkGenerated):
if queue := self._image_generation_queues.get(
event.command_id, None
):
assert isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._image_generation_queues.pop(
event.command_id, None
)
if queue := self._text_generation_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
def _save_merged_trace(self, event: TracesMerged) -> None:
traces = [

View File

@@ -60,7 +60,7 @@ from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
from exo.utils.task_group import TaskGroup
@@ -72,21 +72,25 @@ class Master:
session_id: SessionId,
*,
command_receiver: Receiver[ForwarderCommand],
event_sender: Sender[Event],
local_event_receiver: Receiver[LocalForwarderEvent],
global_event_sender: Sender[GlobalForwarderEvent],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id = node_id
self.session_id = session_id
self.state = State()
self._tg: TaskGroup = TaskGroup()
self.node_id = node_id
self.session_id = session_id
self.command_task_mapping: dict[CommandId, TaskId] = {}
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.download_command_sender = download_command_sender
self.event_sender = event_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
self._loopback_event_sender: Sender[LocalForwarderEvent] = (
local_event_receiver.clone_sender()
)
self._system_id = SystemId()
self._multi_buffer = MultiSourceBuffer[SystemId, Event]()
self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
@@ -100,12 +104,15 @@ class Master:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self._event_log.close()
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")
@@ -402,6 +409,22 @@ class Master:
self._event_log.append(event)
await self._send_event(indexed)
async def _loopback_processor(self) -> None:
# this would ideally not be necessary.
# this is WAY less hacky than how I was working around this before
local_index = 0
with self._loopback_event_receiver as events:
async for event in events:
await self._loopback_event_sender.send(
LocalForwarderEvent(
origin=self._system_id,
origin_idx=local_index,
session=self.session_id,
event=event,
)
)
local_index += 1
# This function is re-entrant, take care!
async def _send_event(self, event: IndexedEvent):
# Convenience method since this line is ugly

View File

@@ -17,7 +17,6 @@ from exo.shared.types.commands import (
)
from exo.shared.types.common import ModelId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
GlobalForwarderEvent,
IndexedEvent,
InstanceCreated,
@@ -51,7 +50,6 @@ async def test_master():
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[LocalForwarderEvent]()
fcds, _fcdr = channel[ForwarderDownloadCommand]()
ev_send, _ev_recv = channel[Event]()
all_events: list[IndexedEvent] = []
@@ -69,7 +67,6 @@ async def test_master():
master = Master(
node_id,
session_id,
event_sender=ev_send,
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,

View File

@@ -1,161 +0,0 @@
from dataclasses import dataclass, field
from random import random
import anyio
from anyio import BrokenResourceError, ClosedResourceError
from anyio.abc import CancelScope
from loguru import logger
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
GlobalForwarderEvent,
IndexedEvent,
LocalForwarderEvent,
)
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.task_group import TaskGroup
@dataclass
class EventRouter:
session_id: SessionId
command_sender: Sender[ForwarderCommand]
external_inbound: Receiver[GlobalForwarderEvent]
external_outbound: Sender[LocalForwarderEvent]
_system_id: SystemId = field(init=False, default_factory=SystemId)
internal_outbound: list[Sender[IndexedEvent]] = field(
init=False, default_factory=list
)
event_buffer: OrderedBuffer[Event] = field(
init=False, default_factory=OrderedBuffer
)
out_for_delivery: dict[EventId, tuple[float, LocalForwarderEvent]] = field(
init=False, default_factory=dict
)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
_nack_cancel_scope: CancelScope | None = field(init=False, default=None)
_nack_attempts: int = field(init=False, default=0)
_nack_base_seconds: float = field(init=False, default=0.5)
_nack_cap_seconds: float = field(init=False, default=10.0)
async def run(self):
try:
async with self._tg as tg:
tg.start_soon(self._run_ext_in)
tg.start_soon(self._simple_retry)
finally:
self.external_outbound.close()
for send in self.internal_outbound:
send.close()
# can make this better in future
async def _simple_retry(self):
while True:
await anyio.sleep(1 + random())
# list here is a shallow clone for shared mutation
for e_id, (time, event) in list(self.out_for_delivery.items()):
if anyio.current_time() > time + 5:
self.out_for_delivery[e_id] = (anyio.current_time(), event)
await self.external_outbound.send(event)
def sender(self) -> Sender[Event]:
send, recv = channel[Event]()
if self._tg.is_running():
self._tg.start_soon(self._ingest, SystemId(), recv)
else:
self._tg.queue(self._ingest, SystemId(), recv)
return send
def receiver(self) -> Receiver[IndexedEvent]:
send, recv = channel[IndexedEvent]()
self.internal_outbound.append(send)
return recv
def shutdown(self) -> None:
self._tg.cancel_tasks()
async def _ingest(self, system_id: SystemId, recv: Receiver[Event]):
idx = 0
with recv as events:
async for event in events:
f_ev = LocalForwarderEvent(
origin_idx=idx,
origin=system_id,
session=self.session_id,
event=event,
)
idx += 1
await self.external_outbound.send(f_ev)
self.out_for_delivery[event.event_id] = (anyio.current_time(), f_ev)
async def _run_ext_in(self):
buf = OrderedBuffer[Event]()
with self.external_inbound as events:
async for event in events:
if event.session != self.session_id:
continue
if event.origin != self.session_id.master_node_id:
continue
buf.ingest(event.origin_idx, event.event)
event_id = event.event.event_id
if event_id in self.out_for_delivery:
self.out_for_delivery.pop(event_id)
drained = buf.drain_indexed()
if drained:
self._nack_attempts = 0
if self._nack_cancel_scope:
self._nack_cancel_scope.cancel()
if not drained and (
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
# Request the next index.
self._tg.start_soon(self._nack_request, buf.next_idx_to_release)
continue
for idx, event in drained:
to_clear = set[int]()
for i, sender in enumerate(self.internal_outbound):
try:
await sender.send(IndexedEvent(idx=idx, event=event))
except (ClosedResourceError, BrokenResourceError):
to_clear.add(i)
for i in sorted(to_clear, reverse=True):
self.internal_outbound.pop(i)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.
# It is cancelled as soon as we receiver an event that is in sequence.
if since_idx < 0:
logger.warning(f"Negative value encountered for nack request {since_idx=}")
since_idx = 0
with CancelScope() as scope:
self._nack_cancel_scope = scope
delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
delay = min(self._nack_cap_seconds, delay)
self._nack_attempts += 1
try:
await anyio.sleep(delay)
logger.info(
f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
)
await self.command_sender.send(
ForwarderCommand(
origin=self._system_id,
command=RequestEventLog(since_idx=since_idx),
)
)
finally:
if self._nack_cancel_scope is scope:
self._nack_cancel_scope = None

View File

@@ -1,8 +1,9 @@
from collections import defaultdict
from datetime import datetime, timezone
from random import random
import anyio
from anyio import fail_after
from anyio import CancelScope, fail_after
from loguru import logger
from exo.download.download_utils import resolve_model_in_path
@@ -12,12 +13,14 @@ from exo.shared.types.api import ImageEditsTaskParams
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import CommandId, NodeId, SystemId
from exo.shared.types.common import CommandId, NodeId, SessionId, SystemId
from exo.shared.types.events import (
Event,
EventId,
GlobalForwarderEvent,
IndexedEvent,
InputChunkReceived,
LocalForwarderEvent,
@@ -43,6 +46,7 @@ from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
@@ -55,27 +59,38 @@ class Worker:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
*,
event_receiver: Receiver[IndexedEvent],
event_sender: Sender[Event],
global_event_receiver: Receiver[GlobalForwarderEvent],
local_event_sender: Sender[LocalForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id: NodeId = node_id
self.event_receiver = event_receiver
self.event_sender = event_sender
self.session_id: SessionId = session_id
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, LocalForwarderEvent] = {}
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = TaskGroup()
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
self._nack_base_seconds: float = 0.5
self._nack_cap_seconds: float = 10.0
self._system_id = SystemId()
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
@@ -93,12 +108,14 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker")
self.event_sender.close()
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
@@ -116,22 +133,47 @@ class Worker:
)
async def _event_applier(self):
with self.event_receiver as events:
async for event in events:
with self.global_event_receiver as events:
async for f_event in events:
if f_event.session != self.session_id:
continue
if f_event.origin != self.session_id.master_node_id:
continue
self.event_buffer.ingest(f_event.origin_idx, f_event.event)
event_id = f_event.event.event_id
if event_id in self.out_for_delivery:
del self.out_for_delivery[event_id]
# 2. for each event, apply it to the state
self.state = apply(self.state, event=event)
event = event.event
indexed_events = self.event_buffer.drain_indexed()
if indexed_events:
self._nack_attempts = 0
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
if not indexed_events and (
self._nack_cancel_scope is None
or self._nack_cancel_scope.cancel_called
):
# Request the next index.
self._tg.start_soon(
self._nack_request, self.state.last_event_applied_idx + 1
)
continue
elif indexed_events and self._nack_cancel_scope:
self._nack_cancel_scope.cancel()
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
async def plan_step(self):
while True:
@@ -283,6 +325,43 @@ class Worker:
instance.shard_assignments.node_to_runner[self.node_id]
].start_task(task)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.
# It is cancelled as soon as we receiver an event that is in sequence.
if since_idx < 0:
logger.warning(f"Negative value encountered for nack request {since_idx=}")
since_idx = 0
with CancelScope() as scope:
self._nack_cancel_scope = scope
delay: float = self._nack_base_seconds * (2.0**self._nack_attempts)
delay = min(self._nack_cap_seconds, delay)
self._nack_attempts += 1
try:
await anyio.sleep(delay)
logger.info(
f"Nack attempt {self._nack_attempts}: Requesting Event Log from {since_idx}"
)
await self.command_sender.send(
ForwarderCommand(
origin=self._system_id,
command=RequestEventLog(since_idx=since_idx),
)
)
finally:
if self._nack_cancel_scope is scope:
self._nack_cancel_scope = None
async def _resend_out_for_delivery(self) -> None:
# This can also be massively tightened, we should check events are at least a certain age before resending.
# Exponential backoff would also certainly help here.
while True:
await anyio.sleep(1 + random())
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(
@@ -293,6 +372,21 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
async def _forward_events(self) -> None:
idx = 0
with self.event_receiver as events:
async for event in events:
fe = LocalForwarderEvent(
origin_idx=idx,
origin=self._system_id,
session=self.session_id,
event=event,
)
idx += 1
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
async def _poll_connection_updates(self):
while True:
edges = set(