Compare commits

...

1 Commits

Author SHA1 Message Date
Jake Hillion
3e9c29618e WIP: move downloads from Worker to Node 2026-01-12 18:30:21 +00:00
4 changed files with 210 additions and 161 deletions

View File

@@ -0,0 +1,3 @@
from exo.download.manager import DownloadManager
__all__ = ["DownloadManager"]

186
src/exo/download/manager.py Normal file
View File

@@ -0,0 +1,186 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, NodeDownloadProgress, TaskStatusUpdated
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import DownloadModel, TaskStatus
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Sender
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
if TYPE_CHECKING:
pass
@dataclass
class DownloadManager:
    """Node-level owner of model shard downloads.

    Keeps the most recent download state per model in ``download_status`` and
    publishes ``NodeDownloadProgress`` / ``TaskStatusUpdated`` events as
    downloads advance.  Actual downloads are scheduled on the task group that
    ``run`` creates, so ``run`` must be active before ``start_download`` is
    asked to fetch a shard that is not already complete.
    """

    node_id: NodeId
    shard_downloader: ShardDownloader
    # Latest known progress, keyed by model id.
    download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
    # Task group opened by run(); None until run() has entered it.
    _tg: TaskGroup | None = field(default=None, init=False)
    # Sender handed to run(); used only by the periodic progress emitter.
    _event_sender: Sender[Event] | None = field(default=None, init=False)

    async def run(self, event_sender: Sender[Event]) -> None:
        """Run download manager background tasks.

        Stores ``event_sender`` for the periodic emitter, opens the task group
        that downloads are scheduled on, and starts the emitter.  The coroutine
        stays inside the task group until its tasks finish or are cancelled.
        """
        self._event_sender = event_sender
        async with anyio.create_task_group() as tg:
            self._tg = tg
            tg.start_soon(self._emit_existing_download_progress)

    def shutdown(self) -> None:
        """Cancel the task group created by ``run``; no-op if there is none."""
        if self._tg:
            self._tg.cancel_scope.cancel()

    def _require_task_group(self) -> TaskGroup:
        """Return the task group created by ``run``.

        Raises:
            RuntimeError: if ``run`` has not yet created the task group.
        """
        # An explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an opaque AttributeError.
        if self._tg is None:
            raise RuntimeError(
                "DownloadManager.run() must be running before a download can be started"
            )
        return self._tg

    async def start_download(
        self,
        task: DownloadModel,
        event_sender: Sender[Event],
    ) -> None:
        """Initiate a download for the given task.

        This is called by Worker when it encounters a DownloadModel task.

        Args:
            task: The download task; its ``shard_metadata`` identifies the shard.
            event_sender: Channel on which progress and task-status events for
                this download are published.

        Raises:
            RuntimeError: if the shard still needs downloading but ``run`` has
                not created the task group yet.
        """
        shard = task.shard_metadata
        model_id = shard.model_meta.model_id

        # First time this model is seen: announce it as pending.
        if model_id not in self.download_status:
            progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
            self.download_status[model_id] = progress
            event_sender.send_nowait(NodeDownloadProgress(download_progress=progress))

        initial_progress = (
            await self.shard_downloader.get_shard_download_status_for_shard(shard)
        )

        if initial_progress.status == "complete":
            # Nothing to fetch: record completion and finish the task at once.
            progress = DownloadCompleted(
                shard_metadata=shard,
                node_id=self.node_id,
                total_bytes=initial_progress.total_bytes,
            )
            self.download_status[model_id] = progress
            event_sender.send_nowait(NodeDownloadProgress(download_progress=progress))
            event_sender.send_nowait(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
            )
        else:
            # Fail before announcing "Running" for a download that cannot be
            # scheduled.
            self._require_task_group()
            event_sender.send_nowait(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
            )
            self._handle_shard_download_process(task, initial_progress, event_sender)

    def _handle_shard_download_process(
        self,
        task: DownloadModel,
        initial_progress: RepoDownloadProgress,
        event_sender: Sender[Event],
    ) -> None:
        """Manages the shard download process with progress tracking.

        Records the download as ongoing, registers a progress callback on the
        shard downloader, and schedules ``ensure_shard`` on the task group.

        Raises:
            RuntimeError: if ``run`` has not created the task group yet.
        """
        # Validate up front so no state is mutated and no event is sent for a
        # download that can never be scheduled.
        tg = self._require_task_group()

        status = DownloadOngoing(
            node_id=self.node_id,
            shard_metadata=task.shard_metadata,
            download_progress=map_repo_download_progress_to_download_progress_data(
                initial_progress
            ),
        )
        self.download_status[task.shard_metadata.model_meta.model_id] = status
        event_sender.send_nowait(NodeDownloadProgress(download_progress=status))

        # Intermediate progress events are emitted at most once per interval;
        # completion is always emitted.
        last_progress_time = 0.0
        throttle_interval_secs = 1.0

        # NOTE(review): the callback reports ``task.task_id`` as complete for
        # whichever shard it is invoked with. Confirm how ``on_progress``
        # dispatches callbacks when several downloads are active.
        def download_progress_callback(
            shard: ShardMetadata, progress: RepoDownloadProgress
        ) -> None:
            nonlocal last_progress_time
            if progress.status == "complete":
                status = DownloadCompleted(
                    shard_metadata=shard,
                    node_id=self.node_id,
                    total_bytes=progress.total_bytes,
                )
                self.download_status[shard.model_meta.model_id] = status
                event_sender.send_nowait(
                    NodeDownloadProgress(download_progress=status)
                )
                event_sender.send_nowait(
                    TaskStatusUpdated(
                        task_id=task.task_id, task_status=TaskStatus.Complete
                    )
                )
            elif (
                progress.status == "in_progress"
                and current_time() - last_progress_time > throttle_interval_secs
            ):
                status = DownloadOngoing(
                    node_id=self.node_id,
                    shard_metadata=shard,
                    download_progress=map_repo_download_progress_to_download_progress_data(
                        progress
                    ),
                )
                self.download_status[shard.model_meta.model_id] = status
                event_sender.send_nowait(
                    NodeDownloadProgress(download_progress=status)
                )
                last_progress_time = current_time()

        self.shard_downloader.on_progress(download_progress_callback)
        tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)

    def _status_from_repo_progress(
        self, progress: RepoDownloadProgress
    ) -> DownloadProgress | None:
        """Map a downloader progress report to a status, or None to skip it."""
        if progress.status == "complete":
            return DownloadCompleted(
                node_id=self.node_id,
                shard_metadata=progress.shard,
                total_bytes=progress.total_bytes,
            )
        if progress.status in ("in_progress", "not_started"):
            # No bytes fetched this session is reported as pending, not ongoing.
            if progress.downloaded_bytes_this_session.in_bytes == 0:
                return DownloadPending(
                    node_id=self.node_id, shard_metadata=progress.shard
                )
            return DownloadOngoing(
                node_id=self.node_id,
                shard_metadata=progress.shard,
                download_progress=map_repo_download_progress_to_download_progress_data(
                    progress
                ),
            )
        return None

    async def _emit_existing_download_progress(self) -> None:
        """Periodically emit existing download progress.

        Every five minutes, asks the shard downloader for the status of the
        shards it reports, refreshes ``download_status``, and publishes each
        status on the sender stored by ``run`` (when one is set).  A failure
        in one pass is logged and the next pass still runs.
        """
        interval_secs = 5 * 60  # 5 minutes
        while True:
            # The try lives inside the loop so one failed pass does not end
            # the periodic emitter for good. Cancellation is raised as a
            # BaseException subclass, so ``except Exception`` lets it through.
            try:
                logger.info("Fetching and emitting existing download progress...")
                async for (
                    _,
                    progress,
                ) in self.shard_downloader.get_shard_download_status():
                    status = self._status_from_repo_progress(progress)
                    if status is None:
                        continue
                    self.download_status[progress.shard.model_meta.model_id] = status
                    if self._event_sender:
                        await self._event_sender.send(
                            NodeDownloadProgress(download_progress=status)
                        )
                logger.info("Done emitting existing download progress.")
            except Exception as e:
                logger.error(f"Error emitting existing download progress: {e}")
            await anyio.sleep(interval_secs)

View File

@@ -20,6 +20,7 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.download import DownloadManager
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
@@ -33,6 +34,7 @@ class Node:
election_result_receiver: Receiver[ElectionResult]
master: Master | None
api: API | None
download_manager: DownloadManager
node_id: NodeId
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@@ -62,11 +64,16 @@ class Node:
else:
api = None
download_manager = DownloadManager(
node_id=node_id,
shard_downloader=exo_shard_downloader(),
)
if not args.no_worker:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
download_manager,
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
@@ -98,13 +105,17 @@ class Node:
election_result_sender=er_send,
)
return cls(router, worker, election, er_recv, master, api, node_id)
return cls(router, worker, election, er_recv, master, api, download_manager, node_id)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
tg.start_soon(
self.download_manager.run,
self.router.sender(topics.LOCAL_EVENTS),
)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
@@ -175,7 +186,7 @@ class Node:
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
self.download_manager,
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),

View File

@@ -2,7 +2,7 @@ from datetime import datetime, timezone
from random import random
import anyio
from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from loguru import logger
@@ -15,7 +15,6 @@ from exo.shared.types.events import (
EventId,
ForwarderEvent,
IndexedEvent,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
TaskCreated,
@@ -23,7 +22,6 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
@@ -35,20 +33,10 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.topology import Connection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.download import DownloadManager
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
@@ -60,7 +48,7 @@ class Worker:
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
download_manager: DownloadManager,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
@@ -72,8 +60,7 @@ class Worker:
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.shard_downloader: ShardDownloader = shard_downloader
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
self.download_manager: DownloadManager = download_manager
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
@@ -84,7 +71,6 @@ class Worker:
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup | None = None
@@ -129,7 +115,6 @@ class Worker:
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -180,7 +165,7 @@ class Worker:
task: Task | None = plan(
self.node_id,
self.runners,
self.download_status,
self.download_manager.download_status,
self.state.downloads,
self.state.instances,
self.state.runners,
@@ -201,43 +186,8 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
case DownloadModel():
await self.download_manager.start_download(task, self.event_sender)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -340,68 +290,6 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
def _handle_shard_download_process(
    self,
    task: DownloadModel,
    initial_progress: RepoDownloadProgress,
):
    """Track a shard download and schedule it on the worker's task group.

    Marks the task's shard as ongoing (seeded from ``initial_progress``),
    registers a progress callback with the shard downloader, and starts
    ``ensure_shard`` for the shard on ``self._tg``.
    """
    shard_meta = task.shard_metadata
    ongoing = DownloadOngoing(
        node_id=self.node_id,
        shard_metadata=shard_meta,
        download_progress=map_repo_download_progress_to_download_progress_data(
            initial_progress
        ),
    )
    self.download_status[shard_meta.model_meta.model_id] = ongoing
    self.event_sender.send_nowait(NodeDownloadProgress(download_progress=ongoing))

    # In-progress events are emitted at most once per interval.
    min_interval_secs = 1.0
    last_emit_time = 0.0

    # TODO: callback-based progress reporting is disliked here (original note).
    def on_shard_progress(
        shard: ShardMetadata, progress: RepoDownloadProgress
    ) -> None:
        nonlocal last_emit_time

        if progress.status == "complete":
            finished = DownloadCompleted(
                shard_metadata=shard,
                node_id=self.node_id,
                total_bytes=progress.total_bytes,
            )
            self.download_status[shard.model_meta.model_id] = finished
            # Flagged as a footgun in the original: send_nowait is called from
            # a synchronous callback.
            self.event_sender.send_nowait(
                NodeDownloadProgress(download_progress=finished)
            )
            self.event_sender.send_nowait(
                TaskStatusUpdated(
                    task_id=task.task_id, task_status=TaskStatus.Complete
                )
            )
            return

        if progress.status != "in_progress":
            return
        if not (current_time() - last_emit_time > min_interval_secs):
            # Throttled: too soon since the last in-progress event.
            return

        update = DownloadOngoing(
            node_id=self.node_id,
            shard_metadata=shard,
            download_progress=map_repo_download_progress_to_download_progress_data(
                progress
            ),
        )
        self.download_status[shard.model_meta.model_id] = update
        self.event_sender.send_nowait(NodeDownloadProgress(download_progress=update))
        last_emit_time = current_time()

    self.shard_downloader.on_progress(on_shard_progress)
    assert self._tg
    self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
@@ -452,42 +340,3 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
    """Re-publish the downloader's known shard statuses every five minutes.

    Each pass iterates ``shard_downloader.get_shard_download_status()``,
    stores a status per model in ``self.download_status`` and sends it as a
    ``NodeDownloadProgress`` event.  Any exception is logged and ends the
    loop.
    """
    try:
        while True:
            logger.info("Fetching and emitting existing download progress...")
            async for (
                _,
                repo_progress,
            ) in self.shard_downloader.get_shard_download_status():
                state = repo_progress.status
                # Statuses other than these three are not reported.
                if state != "complete" and state not in (
                    "in_progress",
                    "not_started",
                ):
                    continue

                if state == "complete":
                    snapshot = DownloadCompleted(
                        node_id=self.node_id,
                        shard_metadata=repo_progress.shard,
                        total_bytes=repo_progress.total_bytes,
                    )
                elif repo_progress.downloaded_bytes_this_session.in_bytes == 0:
                    # Nothing downloaded this session: report as pending.
                    snapshot = DownloadPending(
                        node_id=self.node_id, shard_metadata=repo_progress.shard
                    )
                else:
                    snapshot = DownloadOngoing(
                        node_id=self.node_id,
                        shard_metadata=repo_progress.shard,
                        download_progress=map_repo_download_progress_to_download_progress_data(
                            repo_progress
                        ),
                    )

                self.download_status[repo_progress.shard.model_meta.model_id] = (
                    snapshot
                )
                await self.event_sender.send(
                    NodeDownloadProgress(download_progress=snapshot)
                )
            logger.info("Done emitting existing download progress.")
            await anyio.sleep(5 * 60)  # 5 minutes
    except Exception as e:
        logger.error(f"Error emitting existing download progress: {e}")