Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
2ee0bce898 state compaction
introduces a new topic ("state_catchup") over which a full state can be
sent. currently the master sends the worker + api this new state, and
they update only if they have no other events applied - otherwise usual
NACK systems function

## testing

manually tested on two nodes
2026-01-23 14:15:06 +00:00
27 changed files with 346 additions and 739 deletions

View File

@@ -3,45 +3,6 @@
perSystem =
{ pkgs, lib, ... }:
let
# Stub source with lockfiles and minimal files for build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${inputs.self}/dashboard/package.json $out/
cp ${inputs.self}/dashboard/package-lock.json $out/
# Minimal files so vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src
touch $out/src/app.html
'';
# Deps-only build using stub source (for prettier-svelte)
# Only rebuilds when package.json or package-lock.json change
dashboardDeps = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
{
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
}
# Override build phases to skip the actual build - just need node_modules
{
mkDerivation = {
buildPhase = lib.mkForce "true";
installPhase = lib.mkForce ''
runHook preInstall
runHook postInstall
'';
};
}
];
};
# Filter source to only include dashboard directory
dashboardSrc = lib.cleanSourceWith {
src = inputs.self;
@@ -81,12 +42,11 @@
'';
# Prettier with svelte plugin for treefmt
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
exec ${pkgs.nodejs}/bin/node \
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
"$@"
'';
};

View File

@@ -216,8 +216,6 @@ export interface Message {
attachments?: MessageAttachment[];
ttftMs?: number; // Time to first token in ms (for assistant messages)
tps?: number; // Tokens per second (for assistant messages)
requestType?: "chat" | "image-generation" | "image-editing";
sourceImageDataUrl?: string; // For image editing regeneration
}
export interface Conversation {
@@ -1272,46 +1270,10 @@ class AppStore {
if (lastUserIndex === -1) return;
const lastUserMessage = this.messages[lastUserIndex];
const requestType = lastUserMessage.requestType || "chat";
const prompt = lastUserMessage.content;
// Remove any messages after the user message
this.messages = this.messages.slice(0, lastUserIndex + 1);
// Remove messages after user message (including the user message for image requests
// since generateImage/editImage will re-add it)
this.messages = this.messages.slice(0, lastUserIndex);
switch (requestType) {
case "image-generation":
await this.generateImage(prompt);
break;
case "image-editing":
if (lastUserMessage.sourceImageDataUrl) {
await this.editImage(prompt, lastUserMessage.sourceImageDataUrl);
} else {
// Can't regenerate edit without source image - restore user message and show error
this.messages.push(lastUserMessage);
const errorMessage = this.addMessage("assistant", "");
const idx = this.messages.findIndex((m) => m.id === errorMessage.id);
if (idx !== -1) {
this.messages[idx].content =
"Error: Cannot regenerate image edit - source image not found";
}
this.updateActiveConversation();
}
break;
case "chat":
default:
// Restore the user message for chat regeneration
this.messages.push(lastUserMessage);
await this.regenerateChatCompletion();
break;
}
}
/**
* Helper method to regenerate a chat completion response
*/
private async regenerateChatCompletion(): Promise<void> {
// Resend the message to get a new response
this.isLoading = true;
this.currentResponse = "";
@@ -1826,7 +1788,6 @@ class AppStore {
role: "user",
content: prompt,
timestamp: Date.now(),
requestType: "image-generation",
};
this.messages.push(userMessage);
@@ -2037,8 +1998,6 @@ class AppStore {
role: "user",
content: prompt,
timestamp: Date.now(),
requestType: "image-editing",
sourceImageDataUrl: imageDataUrl,
};
this.messages.push(userMessage);
@@ -2228,54 +2187,6 @@ class AppStore {
this.conversations.find((c) => c.id === this.activeConversationId) || null
);
}
/**
* Start a download on a specific node
*/
async startDownload(nodeId: string, shardMetadata: object): Promise<void> {
try {
const response = await fetch("/download/start", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
targetNodeId: nodeId,
shardMetadata: shardMetadata,
}),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(
`Failed to start download: ${response.status} - ${errorText}`,
);
}
} catch (error) {
console.error("Error starting download:", error);
throw error;
}
}
/**
* Delete a downloaded model from a specific node
*/
async deleteDownload(nodeId: string, modelId: string): Promise<void> {
try {
const response = await fetch(
`/download/${encodeURIComponent(nodeId)}/${encodeURIComponent(modelId)}`,
{
method: "DELETE",
},
);
if (!response.ok) {
const errorText = await response.text();
throw new Error(
`Failed to delete download: ${response.status} - ${errorText}`,
);
}
} catch (error) {
console.error("Error deleting download:", error);
throw error;
}
}
}
export const appStore = new AppStore();
@@ -2381,9 +2292,3 @@ export const setImageGenerationParams = (
) => appStore.setImageGenerationParams(params);
export const resetImageGenerationParams = () =>
appStore.resetImageGenerationParams();
// Download actions
export const startDownload = (nodeId: string, shardMetadata: object) =>
appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
appStore.deleteDownload(nodeId, modelId);

View File

@@ -6,8 +6,6 @@
type DownloadProgress,
refreshState,
lastUpdate as lastUpdateStore,
startDownload,
deleteDownload,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
@@ -30,7 +28,6 @@
etaMs: number;
status: "completed" | "downloading";
files: FileProgress[];
shardMetadata?: Record<string, unknown>;
};
type NodeEntry = {
@@ -272,12 +269,6 @@
}
}
// Extract shard_metadata for use with download actions
const shardMetadata = (downloadPayload.shard_metadata ??
downloadPayload.shardMetadata) as
| Record<string, unknown>
| undefined;
const entry: ModelEntry = {
modelId,
prettyName,
@@ -294,7 +285,6 @@
? "completed"
: "downloading",
files,
shardMetadata,
};
const existing = modelMap.get(modelId);
@@ -479,52 +469,6 @@
>
{pct.toFixed(1)}%
</span>
{#if model.status !== "completed" && model.shardMetadata}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
onclick={() =>
startDownload(node.nodeId, model.shardMetadata!)}
title="Start download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
{#if model.status === "completed"}
<button
type="button"
class="text-exo-light-gray hover:text-red-400 transition-colors"
onclick={() =>
deleteDownload(node.nodeId, model.modelId)}
title="Delete download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"

View File

@@ -1,284 +0,0 @@
import asyncio
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger
from exo.download.download_utils import (
RepoDownloadProgress,
delete_model,
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
# Only process commands targeting this node
if cmd.command.target_node_id != self.node_id:
continue
match cmd.command:
case StartDownload(shard_metadata=shard):
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id
# Check if already downloading or complete
if model_id in self.download_status:
status = self.download_status[model_id]
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
logger.debug(
f"Download for {model_id} already in progress or complete, skipping"
)
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
# Check initial status from downloader
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
return
# Start actual download
self._start_download_task(shard, initial_progress)
def _start_download_task(
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
) -> None:
model_id = shard.model_card.model_id
# Emit ongoing status
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal last_progress_time
if progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[callback_shard.model_card.model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
# Clean up active download tracking
if callback_shard.model_card.model_id in self.active_downloads:
del self.active_downloads[callback_shard.model_card.model_id]
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
ongoing = DownloadOngoing(
node_id=self.node_id,
shard_metadata=callback_shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[callback_shard.model_card.model_id] = ongoing
await self.event_sender.send(
NodeDownloadProgress(download_progress=ongoing)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_wrapper() -> None:
try:
await self.shard_downloader.ensure_shard(shard)
except Exception as e:
logger.error(f"Download failed for {model_id}: {e}")
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
)
self.download_status[model_id] = failed
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed)
)
finally:
if model_id in self.active_downloads:
del self.active_downloads[model_id]
task = asyncio.create_task(download_wrapper())
self.active_downloads[model_id] = task
async def _delete_download(self, model_id: ModelId) -> None:
# Cancel if active
if model_id in self.active_downloads:
logger.info(f"Cancelling active download for {model_id} before deletion")
self.active_downloads[model_id].cancel()
del self.active_downloads[model_id]
# Delete from disk
logger.info(f"Deleting model files for {model_id}")
deleted = await delete_model(model_id)
if deleted:
logger.info(f"Successfully deleted model {model_id}")
else:
logger.warning(f"Model {model_id} was not found on disk")
# Emit pending status to reset UI state, then remove from local tracking
if model_id in self.download_status:
current_status = self.download_status[model_id]
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"
)

View File

@@ -1,11 +1,10 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Iterator, Self
from typing import Self
import anyio
from anyio.abc import TaskGroup
@@ -13,8 +12,6 @@ from loguru import logger
from pydantic import PositiveInt
import exo.routing.topics as topics
from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
@@ -24,6 +21,7 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
@@ -31,7 +29,6 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
@@ -39,7 +36,6 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -53,26 +49,9 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
await router.register_topic(topics.STATE_CATCHUP)
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -80,8 +59,8 @@ class Node:
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
)
else:
api = None
@@ -90,12 +69,12 @@ class Node:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
state_catchup_receiver=router.receiver(topics.STATE_CATCHUP),
)
else:
worker = None
@@ -107,6 +86,7 @@ class Node:
global_event_sender=router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=router.receiver(topics.LOCAL_EVENTS),
command_receiver=router.receiver(topics.COMMANDS),
state_catchup_sender=router.sender(topics.STATE_CATCHUP),
)
er_send, er_recv = channel[ElectionResult]()
@@ -123,25 +103,13 @@ class Node:
election_result_sender=er_send,
)
return cls(
router,
download_coordinator,
worker,
election,
er_recv,
master,
api,
node_id,
event_index_counter,
)
return cls(router, worker, election, er_recv, master, api, node_id)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
@@ -189,6 +157,7 @@ class Node:
global_event_sender=self.router.sender(topics.GLOBAL_EVENTS),
local_event_receiver=self.router.receiver(topics.LOCAL_EVENTS),
command_receiver=self.router.receiver(topics.COMMANDS),
state_catchup_sender=self.router.sender(topics.STATE_CATCHUP),
)
self._tg.start_soon(self.master.run)
elif (
@@ -206,27 +175,13 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
@@ -235,10 +190,9 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
state_catchup_receiver=self.router.receiver(
topics.STATE_CATCHUP
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:
@@ -280,7 +234,6 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -323,11 +276,6 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
parser.add_argument(
"--no-downloads",
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -44,7 +44,6 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
@@ -62,8 +61,6 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
ToolCall,
)
@@ -78,16 +75,12 @@ from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteDownload,
DeleteInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
StartDownload,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -163,16 +156,16 @@ class API:
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
state_catchup_receiver: Receiver[State],
) -> None:
self.state = State()
self._event_log: list[Event] = []
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.state_catchup_receiver = state_catchup_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
@@ -269,8 +262,6 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -1242,6 +1233,7 @@ class API:
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
tg.start_soon(self._state_catchup)
print_startup_banner(self.port)
await serve(
cast(ASGIFramework, self.app),
@@ -1252,6 +1244,22 @@ class API:
self.command_sender.close()
self.global_event_receiver.close()
async def _state_catchup(self):
with self.state_catchup_receiver as states:
async for state in states:
if (
self.state.last_event_applied_idx == -1
and state.last_event_applied_idx > self.state.last_event_applied_idx
):
logger.info(
f"API catching up state to idx {state.last_event_applied_idx}"
)
self.event_buffer.store = {}
self.event_buffer.next_idx_to_release = (
state.last_event_applied_idx + 1
)
self.state = state
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -1303,28 +1311,3 @@ class API:
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self.node_id, command=command)
)
async def start_download(
self, payload: StartDownloadParams
) -> StartDownloadResponse:
command = StartDownload(
target_node_id=payload.target_node_id,
shard_metadata=payload.shard_metadata,
)
await self._send_download(command)
return StartDownloadResponse(command_id=command.command_id)
async def delete_download(
self, node_id: NodeId, model_id: ModelId
) -> DeleteDownloadResponse:
command = DeleteDownload(
target_node_id=node_id,
model_id=ModelId(model_id),
)
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)

View File

@@ -68,6 +68,8 @@ class Master:
# Send events to the forwarder to be indexed (usually from command processing)
# Ideally these would be MasterForwarderEvents but type system says no :(
global_event_sender: Sender[ForwarderEvent],
# not a fan but - send the entire state to a node so it can catchup without the whole event log.
state_catchup_sender: Sender[State],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
@@ -77,6 +79,7 @@ class Master:
self.command_receiver = command_receiver
self.local_event_receiver = local_event_receiver
self.global_event_sender = global_event_sender
self.state_catchup_sender = state_catchup_sender
send, recv = channel[Event]()
self.event_sender: Sender[Event] = send
self._loopback_event_receiver: Receiver[Event] = recv
@@ -84,7 +87,6 @@ class Master:
local_event_receiver.clone_sender()
)
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
async def run(self):
@@ -291,11 +293,17 @@ class Master:
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
if command.since_idx == 0:
# This is an optimization, and should not be relied upon in theory.
logger.info(
f"Master sending catchup state for index {self.state.last_event_applied_idx}"
)
await self.state_catchup_sender.send(self.state)
else:
for i in range(command.since_idx, len(self._event_log)):
await self._send_event(
IndexedEvent(idx=i, event=self._event_log[i])
)
for event in generated_events:
await self.event_sender.send(event)
except ValueError as e:

View File

@@ -27,6 +27,7 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.worker.instances import (
@@ -47,6 +48,7 @@ async def test_master():
ge_sender, global_event_receiver = channel[ForwarderEvent]()
command_sender, co_receiver = channel[ForwarderCommand]()
local_event_sender, le_receiver = channel[ForwarderEvent]()
st_s, _st_r = channel[State]()
all_events: list[IndexedEvent] = []
@@ -67,6 +69,7 @@ async def test_master():
global_event_sender=ge_sender,
local_event_receiver=le_receiver,
command_receiver=co_receiver,
state_catchup_sender=st_s,
)
logger.info("run the master")
async with anyio.create_task_group() as tg:

View File

@@ -3,10 +3,11 @@ from enum import Enum
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.events import (
ForwarderEvent,
)
from exo.shared.types.state import State
from exo.utils.pydantic_ext import CamelCaseModel
@@ -45,6 +46,4 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
DOWNLOAD_COMMANDS = TypedTopic(
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)
STATE_CATCHUP = TypedTopic("state_catchup", PublishPolicy.Always, State)

View File

@@ -621,7 +621,7 @@ class ConfigData(BaseModel):
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
from exo.download.download_utils import (
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
@@ -643,11 +643,11 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
from exo.download.download_utils import (
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)

View File

@@ -7,11 +7,10 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -353,16 +352,3 @@ class ImageListItem(BaseModel, frozen=True):
class ImageListResponse(BaseModel, frozen=True):
data: list[ImageListItem]
class StartDownloadParams(CamelCaseModel):
target_node_id: NodeId
shard_metadata: ShardMetadata
class StartDownloadResponse(CamelCaseModel):
command_id: CommandId
class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId

View File

@@ -1,6 +1,6 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
@@ -9,7 +9,7 @@ from exo.shared.types.api import (
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -62,19 +62,6 @@ class RequestEventLog(BaseCommand):
since_idx: int
class StartDownload(BaseCommand):
target_node_id: NodeId
shard_metadata: ShardMetadata
class DeleteDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
Command = (
TestCommand
| RequestEventLog
@@ -92,8 +79,3 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
command: DownloadCommand

View File

@@ -1,32 +0,0 @@
import time
from typing import Generic, TypeVar
K = TypeVar("K")
class KeyedBackoff(Generic[K]):
"""Tracks exponential backoff state per key."""
def __init__(self, base: float = 0.5, cap: float = 10.0):
self._base = base
self._cap = cap
self._attempts: dict[K, int] = {}
self._last_time: dict[K, float] = {}
def should_proceed(self, key: K) -> bool:
"""Returns True if enough time has elapsed since last attempt."""
now = time.monotonic()
last = self._last_time.get(key, 0.0)
attempts = self._attempts.get(key, 0)
delay = min(self._cap, self._base * (2.0**attempts))
return now - last >= delay
def record_attempt(self, key: K) -> None:
"""Record that an attempt was made for this key."""
self._last_time[key] = time.monotonic()
self._attempts[key] = self._attempts.get(key, 0) + 1
def reset(self, key: K) -> None:
"""Reset backoff state for a key (e.g., on success)."""
self._attempts.pop(key, None)
self._last_time.pop(key, None)

View File

@@ -24,13 +24,6 @@ from pydantic import (
TypeAdapter,
)
from exo.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
@@ -42,6 +35,13 @@ from exo.shared.types.worker.downloads import (
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
class HuggingFaceAuthenticationError(Exception):

View File

@@ -5,13 +5,13 @@ from typing import AsyncIterator, Callable
from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
from exo.worker.download.shard_downloader import ShardDownloader
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:

View File

@@ -5,13 +5,13 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.download.download_utils import RepoDownloadProgress
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?

View File

@@ -6,10 +6,10 @@ import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image
from exo.download.download_utils import build_model_path
from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,

View File

@@ -41,7 +41,6 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
@@ -56,6 +55,7 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,

View File

@@ -1,9 +1,8 @@
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, fail_after
from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio.abc import TaskGroup
from loguru import logger
@@ -11,12 +10,7 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
@@ -24,6 +18,7 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -41,12 +36,23 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -56,29 +62,31 @@ class Worker:
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
local_event_sender: Sender[ForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
state_catchup_receiver: Receiver[State],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.shard_downloader: ShardDownloader = shard_downloader
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.state_catchup_receiver = state_catchup_receiver
self.local_event_index = 0
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.connection_message_receiver = connection_message_receiver
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
@@ -93,8 +101,6 @@ class Worker:
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
async def run(self):
logger.info("Starting Worker")
@@ -105,16 +111,17 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
tg.start_soon(self._check_catchup_state)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
@@ -129,6 +136,22 @@ class Worker:
)
)
async def _check_catchup_state(self):
with self.state_catchup_receiver as states:
async for state in states:
if (
self.state.last_event_applied_idx == -1
and state.last_event_applied_idx > self.state.last_event_applied_idx
):
logger.info(
f"Worker catching up state to idx {state.last_event_applied_idx}"
)
self.event_buffer.store = {}
self.event_buffer.next_idx_to_release = (
state.last_event_applied_idx + 1
)
self.state = state
async def _event_applier(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -173,9 +196,11 @@ class Worker:
async def plan_step(self):
while True:
await anyio.sleep(0.1)
# 3. based on the updated state, we plan & execute an operation.
task: Task | None = plan(
self.node_id,
self.runners,
self.download_status,
self.state.downloads,
self.state.instances,
self.state.runners,
@@ -199,26 +224,42 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
model_id = shard.model_card.model_id
if not self._download_backoff.should_proceed(model_id):
continue
self._download_backoff.record_attempt(model_id)
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
),
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -318,10 +359,7 @@ class Worker:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.
# It is cancelled as soon as we receiver an event that is in sequence.
if since_idx < 0:
logger.warning(f"Negative value encountered for nack request {since_idx=}")
since_idx = 0
assert since_idx >= 0
with CancelScope() as scope:
self._nack_cancel_scope = scope
@@ -363,17 +401,104 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
def _handle_shard_download_process(
self,
task: DownloadModel,
initial_progress: RepoDownloadProgress,
):
"""Manages the shard download process with progress tracking."""
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_with_error_handling() -> None:
try:
await self.shard_downloader.ensure_shard(task.shard_metadata)
except Exception as e:
error_message = str(e)
logger.error(
f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
)
failed_status = DownloadFailed(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
error_message=error_message,
)
self.download_status[task.shard_metadata.model_card.model_id] = (
failed_status
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed_status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Failed
)
)
self._tg.start_soon(download_with_error_handling)
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin_idx=self.local_event_index,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
logger.debug(
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
)
self.local_event_index += 1
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
@@ -421,3 +546,42 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -2,6 +2,7 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
ChatCompletion,
@@ -44,6 +45,9 @@ def plan(
node_id: NodeId,
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ModelId, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
@@ -55,7 +59,7 @@ def plan(
return (
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
@@ -111,15 +115,9 @@ def _create_runner(
def _model_needs_download(
node_id: NodeId,
runners: Mapping[RunnerId, RunnerSupervisor],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
local_downloads = global_download_status.get(node_id, [])
download_status = {
dp.shard_metadata.model_card.model_id: dp for dp in local_downloads
}
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (

View File

@@ -11,12 +11,12 @@ from pathlib import Path
import pytest
from exo.download.download_utils import (
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,

View File

@@ -1,5 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -45,9 +45,13 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ModelId, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=download_status,
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -88,6 +92,14 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
RUNNER_2_ID: RunnerConnected(),
}
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view has completed downloads for both nodes
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -104,6 +116,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -135,19 +148,23 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# Global state shows shard is downloaded for NODE_A
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_A: [],
NODE_B: [],
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -185,6 +202,12 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -197,6 +220,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -221,6 +245,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,

View File

@@ -47,7 +47,8 @@ def test_plan_kills_runner_when_instance_missing():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore[arg-type]
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -86,7 +87,8 @@ def test_plan_kills_runner_when_sibling_failed():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore[arg-type]
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -118,6 +120,7 @@ def test_plan_creates_runner_when_missing_for_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners,
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -155,7 +158,8 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore[arg-type]
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -185,6 +189,7 @@ def test_plan_does_not_create_runner_for_unassigned_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,

View File

@@ -65,6 +65,7 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -112,6 +113,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -156,6 +158,7 @@ def test_plan_does_not_forward_tasks_for_other_instances():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -218,6 +221,7 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -257,6 +261,7 @@ def test_plan_returns_none_when_nothing_to_do():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -57,6 +57,7 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -98,6 +99,7 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -138,6 +140,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -182,6 +185,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -198,6 +202,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -241,6 +246,7 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -283,6 +289,7 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -324,6 +331,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -11,10 +11,6 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
@@ -40,6 +36,10 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.worker.runner.bootstrap import entrypoint