Compare commits

...

2 Commits

Author SHA1 Message Date
Jake Hillion
3fba7b0b7d downloads: add download and delete buttons to downloads UI
The downloads page showed model download progress but provided no way
for users to trigger downloads or remove completed models from disk.

Added API endpoints (POST /download/start, DELETE /download/{node_id}/{model_id})
that send StartDownload and DeleteDownload commands via the download_command_sender.
Updated the dashboard downloads page with per-model buttons: a download button
for incomplete downloads and a delete button for completed ones.

This allows users to manage downloads directly from the UI without needing
to trigger downloads through other means.

Test plan:
- Ran basedpyright with 0 errors
- Manual verification: start exo, navigate to downloads page, verify buttons appear
2026-01-20 11:59:02 +00:00
Jake Hillion
93759f00d4 downloads: refactor to run at node level
The Worker previously owned the ShardDownloader directly via dependency
injection, which prevented --no-worker nodes from downloading and made
it impossible for multiple Workers to share a single downloader instance.

Moved download functionality to a new DownloadCoordinator component at
the Node level that communicates via the DOWNLOAD_COMMANDS pub/sub topic.
Workers now send StartDownload commands instead of calling the downloader
directly, and receive progress updates through the event-sourced state.

This decouples downloads from the Worker lifecycle and enables future
features like UI-triggered downloads to specific nodes and multi-worker
download sharing.

Test plan:
- basedpyright passes with 0 errors
- ruff check passes
- Manual verification of imports and module structure
2026-01-20 11:59:02 +00:00
18 changed files with 573 additions and 182 deletions

View File

@@ -1600,6 +1600,47 @@ class AppStore {
this.conversations.find((c) => c.id === this.activeConversationId) || null
);
}
/**
* Start a download on a specific node
*/
async startDownload(nodeId: string, shardMetadata: object): Promise<void> {
try {
const response = await fetch("/download/start", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
targetNodeId: nodeId,
shardMetadata: shardMetadata,
}),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`Failed to start download: ${response.status} - ${errorText}`);
}
} catch (error) {
console.error("Error starting download:", error);
throw error;
}
}
/**
* Delete a downloaded model from a specific node
*/
async deleteDownload(nodeId: string, modelId: string): Promise<void> {
try {
const response = await fetch(`/download/${encodeURIComponent(nodeId)}/${encodeURIComponent(modelId)}`, {
method: "DELETE",
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`Failed to delete download: ${response.status} - ${errorText}`);
}
} catch (error) {
console.error("Error deleting download:", error);
throw error;
}
}
}
export const appStore = new AppStore();
@@ -1678,3 +1719,9 @@ export const toggleChatSidebarVisible = () =>
export const setChatSidebarVisible = (visible: boolean) =>
appStore.setChatSidebarVisible(visible);
export const refreshState = () => appStore.fetchState();
// Download actions
export const startDownload = (nodeId: string, shardMetadata: object) =>
appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
appStore.deleteDownload(nodeId, modelId);

View File

@@ -5,7 +5,9 @@
downloads,
type DownloadProgress,
refreshState,
lastUpdate as lastUpdateStore
lastUpdate as lastUpdateStore,
startDownload,
deleteDownload
} from '$lib/stores/app.svelte';
import HeaderNav from '$lib/components/HeaderNav.svelte';
@@ -28,6 +30,7 @@
etaMs: number;
status: 'completed' | 'downloading';
files: FileProgress[];
shardMetadata?: Record<string, unknown>;
};
type NodeEntry = {
@@ -230,6 +233,9 @@
}
}
// Extract shard_metadata for use with download actions
const shardMetadata = (downloadPayload.shard_metadata ?? downloadPayload.shardMetadata) as Record<string, unknown> | undefined;
const entry: ModelEntry = {
modelId,
prettyName,
@@ -239,7 +245,8 @@
speed,
etaMs,
status: downloadKind === 'DownloadCompleted' ? 'completed' : 'downloading',
files
files,
shardMetadata
};
const existing = modelMap.get(modelId);
@@ -374,6 +381,30 @@
<span class="text-xs font-mono {pct >= 100 ? 'text-green-400' : pct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">
{pct.toFixed(1)}%
</span>
{#if model.status !== 'completed' && model.shardMetadata}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
onclick={() => startDownload(node.nodeId, model.shardMetadata!)}
title="Start download"
>
<svg class="w-4 h-4" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
<path d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14" stroke-linecap="round" stroke-linejoin="round"></path>
</svg>
</button>
{/if}
{#if model.status === 'completed'}
<button
type="button"
class="text-exo-light-gray hover:text-red-400 transition-colors"
onclick={() => deleteDownload(node.nodeId, model.modelId)}
title="Delete download"
>
<svg class="w-4 h-4" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
<path d="M6 6l8 8M6 14l8-8" stroke-linecap="round" stroke-linejoin="round"></path>
</svg>
</button>
{/if}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"

View File

@@ -0,0 +1 @@
# Download package - centralized download management for exo

View File

@@ -0,0 +1,304 @@
import asyncio
import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger
from exo.download.download_utils import (
RepoDownloadProgress,
delete_model,
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.types.commands import (
CancelDownload,
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
class DownloadCoordinator:
def __init__(
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
download_command_receiver: Receiver[ForwarderDownloadCommand],
local_event_sender: Sender[ForwarderEvent],
):
self.node_id = node_id
self.session_id = session_id
self.shard_downloader = shard_downloader
self.download_command_receiver = download_command_receiver
self.local_event_sender = local_event_sender
# Local state
self.download_status: dict[ModelId, DownloadProgress] = {}
self.active_downloads: dict[ModelId, asyncio.Task[None]] = {}
# Internal event channel for forwarding
self.event_sender, self.event_receiver = channel[Event]()
self.local_event_index = 0
self._tg: TaskGroup = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
# Only process commands targeting this node
if cmd.command.target_node_id != self.node_id:
continue
match cmd.command:
case StartDownload(shard_metadata=shard):
await self._start_download(shard)
case CancelDownload(model_id=model_id):
await self._cancel_download(model_id)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id
# Check if already downloading or complete
if model_id in self.download_status:
status = self.download_status[model_id]
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
logger.debug(
f"Download for {model_id} already in progress or complete, skipping"
)
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
# Check initial status from downloader
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
return
# Start actual download
self._start_download_task(shard, initial_progress)
def _start_download_task(
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
) -> None:
model_id = shard.model_card.model_id
# Emit ongoing status
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal last_progress_time
if progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[callback_shard.model_card.model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
# Clean up active download tracking
if callback_shard.model_card.model_id in self.active_downloads:
del self.active_downloads[callback_shard.model_card.model_id]
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
ongoing = DownloadOngoing(
node_id=self.node_id,
shard_metadata=callback_shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[callback_shard.model_card.model_id] = ongoing
await self.event_sender.send(
NodeDownloadProgress(download_progress=ongoing)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_wrapper() -> None:
try:
await self.shard_downloader.ensure_shard(shard)
except Exception as e:
logger.error(f"Download failed for {model_id}: {e}")
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
)
self.download_status[model_id] = failed
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed)
)
finally:
if model_id in self.active_downloads:
del self.active_downloads[model_id]
# Track and start the download task
# asyncio.create_task() immediately starts the coroutine, no need for start_soon
task = asyncio.create_task(download_wrapper())
self.active_downloads[model_id] = task
async def _cancel_download(self, model_id: ModelId) -> None:
if model_id in self.active_downloads:
logger.info(f"Cancelling download for {model_id}")
self.active_downloads[model_id].cancel()
del self.active_downloads[model_id]
# Update status if we have shard metadata
if model_id in self.download_status:
current_status = self.download_status[model_id]
if hasattr(current_status, "shard_metadata"):
failed = DownloadFailed(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
error_message="Download cancelled",
)
self.download_status[model_id] = failed
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed)
)
async def _delete_download(self, model_id: ModelId) -> None:
# Cancel if active
if model_id in self.active_downloads:
logger.info(f"Cancelling active download for {model_id} before deletion")
self.active_downloads[model_id].cancel()
del self.active_downloads[model_id]
# Delete from disk
logger.info(f"Deleting model files for {model_id}")
deleted = await delete_model(str(model_id))
if deleted:
logger.info(f"Successfully deleted model {model_id}")
else:
logger.warning(f"Model {model_id} was not found on disk")
# Remove from status tracking
if model_id in self.download_status:
del self.download_status[model_id]
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
fe = ForwarderEvent(
origin_idx=self.local_event_index,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"DownloadCoordinator published event {self.local_event_index}: {str(event)[:100]}"
)
self.local_event_index += 1
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"
)

View File

@@ -25,16 +25,16 @@ from pydantic import (
TypeAdapter,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
from exo.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
class ModelSafetensorsIndexMetadata(BaseModel):

View File

@@ -3,14 +3,14 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
from exo.worker.download.shard_downloader import ShardDownloader
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:

View File

@@ -5,13 +5,13 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.download.download_utils import RepoDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?

View File

@@ -12,6 +12,8 @@ from loguru import logger
from pydantic import PositiveInt
import exo.routing.topics as topics
from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
@@ -21,7 +23,6 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
@@ -29,6 +30,7 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
@@ -49,8 +51,22 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
logger.info(f"Starting node {node_id}")
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -58,6 +74,7 @@ class Node:
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
)
else:
@@ -67,11 +84,11 @@ class Node:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
)
else:
worker = None
@@ -99,13 +116,24 @@ class Node:
election_result_sender=er_send,
)
return cls(router, worker, election, er_recv, master, api, node_id)
return cls(
router,
download_coordinator,
worker,
election,
er_recv,
master,
api,
node_id,
)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
@@ -170,13 +198,24 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
@@ -185,6 +224,9 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
)
self._tg.start_soon(self.worker.run)
if self.api:
@@ -226,6 +268,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -268,6 +311,11 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
parser.add_argument(
"--no-downloads",
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -29,6 +29,7 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
@@ -39,6 +40,8 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
@@ -46,9 +49,13 @@ from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteDownload,
DeleteInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
PlaceInstance,
StartDownload,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
@@ -104,12 +111,14 @@ class API:
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
) -> None:
self.state = State()
self._event_log: list[Event] = []
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
@@ -193,6 +202,8 @@ class API:
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -677,3 +688,30 @@ class API:
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
while self.paused:
await self.paused_ev.wait()
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self.node_id, command=command)
)
async def start_download(
self, payload: StartDownloadParams
) -> StartDownloadResponse:
command = StartDownload(
target_node_id=payload.target_node_id,
shard_metadata=payload.shard_metadata,
)
await self._send_download(command)
return StartDownloadResponse(command_id=command.command_id)
async def delete_download(
self, node_id: NodeId, model_id: str
) -> DeleteDownloadResponse:
command = DeleteDownload(
target_node_id=node_id,
model_id=ModelId(model_id),
)
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)

View File

@@ -3,7 +3,7 @@ from enum import Enum
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.events import (
ForwarderEvent,
)
@@ -45,3 +45,6 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
DOWNLOAD_COMMANDS = TypedTopic(
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)

View File

@@ -6,13 +6,13 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.worker.download.download_utils import (
from exo.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
class ConfigData(BaseModel):

View File

@@ -5,10 +5,11 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import ShardMetadata, Sharding
from exo.utils.pydantic_ext import CamelCaseModel
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -213,3 +214,16 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class StartDownloadParams(CamelCaseModel):
target_node_id: NodeId
shard_metadata: ShardMetadata
class StartDownloadResponse(CamelCaseModel):
command_id: CommandId
class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId

View File

@@ -1,10 +1,10 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -43,6 +43,24 @@ class RequestEventLog(BaseCommand):
since_idx: int
class StartDownload(BaseCommand):
target_node_id: NodeId
shard_metadata: ShardMetadata
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
class DeleteDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | CancelDownload | DeleteDownload
Command = (
TestCommand
| RequestEventLog
@@ -57,3 +75,8 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
command: DownloadCommand

View File

@@ -40,6 +40,7 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
@@ -54,7 +55,6 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,

View File

@@ -2,21 +2,25 @@ from datetime import datetime, timezone
from random import random
import anyio
from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -33,22 +37,12 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -58,7 +52,6 @@ class Worker:
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
@@ -66,23 +59,21 @@ class Worker:
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.shard_downloader: ShardDownloader = shard_downloader
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.local_event_index = 0
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.connection_message_receiver = connection_message_receiver
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
@@ -103,7 +94,6 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -113,6 +103,7 @@ class Worker:
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
@@ -157,14 +148,22 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
def _get_local_download_status(self) -> dict[ModelId, DownloadProgress]:
"""Extract this node's download status from global state."""
downloads = self.state.downloads.get(self.node_id, [])
return {dp.shard_metadata.model_card.model_id: dp for dp in downloads}
async def plan_step(self):
while True:
await anyio.sleep(0.1)
# Get download status from state (event-sourced)
local_download_status = self._get_local_download_status()
# 3. based on the updated state, we plan & execute an operation.
task: Task | None = plan(
self.node_id,
self.runners,
self.download_status,
local_download_status,
self.state.downloads,
self.state.instances,
self.state.runners,
@@ -186,42 +185,23 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
# Send StartDownload command to DownloadCoordinator
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
),
)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
# Mark task as running - completion will be detected
# when download status in state changes to DownloadCompleted
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -326,65 +306,6 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
def _handle_shard_download_process(
self,
task: DownloadModel,
initial_progress: RepoDownloadProgress,
):
"""Manages the shard download process with progress tracking."""
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
@@ -447,42 +368,3 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -11,12 +11,12 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
from exo.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,

View File

@@ -11,6 +11,10 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
@@ -36,10 +40,6 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.worker.runner.bootstrap import entrypoint