mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-25 10:48:26 -05:00
Compare commits
1 Commits
remove-pyt
...
JakeHillio
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e9c29618e |
3
src/exo/download/__init__.py
Normal file
3
src/exo/download/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from exo.download.manager import DownloadManager
|
||||
|
||||
__all__ = ["DownloadManager"]
|
||||
186
src/exo/download/manager.py
Normal file
186
src/exo/download/manager.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, NodeDownloadProgress, TaskStatusUpdated
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import DownloadModel, TaskStatus
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Sender
|
||||
from exo.worker.download.download_utils import (
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadManager:
|
||||
node_id: NodeId
|
||||
shard_downloader: ShardDownloader
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
_event_sender: Sender[Event] | None = field(default=None, init=False)
|
||||
|
||||
async def run(self, event_sender: Sender[Event]) -> None:
|
||||
"""Run download manager background tasks."""
|
||||
self._event_sender = event_sender
|
||||
async with anyio.create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if self._tg:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_download(
|
||||
self,
|
||||
task: DownloadModel,
|
||||
event_sender: Sender[Event],
|
||||
) -> None:
|
||||
"""Initiate a download for the given task.
|
||||
|
||||
This is called by Worker when it encounters a DownloadModel task.
|
||||
"""
|
||||
shard = task.shard_metadata
|
||||
model_id = shard.model_meta.model_id
|
||||
|
||||
if model_id not in self.download_status:
|
||||
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
|
||||
self.download_status[model_id] = progress
|
||||
event_sender.send_nowait(NodeDownloadProgress(download_progress=progress))
|
||||
|
||||
initial_progress = await self.shard_downloader.get_shard_download_status_for_shard(
|
||||
shard
|
||||
)
|
||||
|
||||
if initial_progress.status == "complete":
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
)
|
||||
self.download_status[model_id] = progress
|
||||
event_sender.send_nowait(NodeDownloadProgress(download_progress=progress))
|
||||
event_sender.send_nowait(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
else:
|
||||
event_sender.send_nowait(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
self._handle_shard_download_process(task, initial_progress, event_sender)
|
||||
|
||||
def _handle_shard_download_process(
|
||||
self,
|
||||
task: DownloadModel,
|
||||
initial_progress: RepoDownloadProgress,
|
||||
event_sender: Sender[Event],
|
||||
) -> None:
|
||||
"""Manages the shard download process with progress tracking."""
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=task.shard_metadata,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[task.shard_metadata.model_meta.model_id] = status
|
||||
event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
def download_progress_callback(
|
||||
shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
event_sender.send_nowait(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
elif (
|
||||
progress.status == "in_progress"
|
||||
and current_time() - last_progress_time > throttle_interval_secs
|
||||
):
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
assert self._tg
|
||||
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
"""Periodically emit existing download progress."""
|
||||
try:
|
||||
while True:
|
||||
logger.info("Fetching and emitting existing download progress...")
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status: DownloadProgress = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_meta.model_id] = status
|
||||
if self._event_sender:
|
||||
await self._event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.info("Done emitting existing download progress.")
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting existing download progress: {e}")
|
||||
@@ -20,6 +20,7 @@ from exo.shared.logging import logger_cleanup, logger_setup
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.utils.channels import Receiver, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.download import DownloadManager
|
||||
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.worker.main import Worker
|
||||
|
||||
@@ -33,6 +34,7 @@ class Node:
|
||||
election_result_receiver: Receiver[ElectionResult]
|
||||
master: Master | None
|
||||
api: API | None
|
||||
download_manager: DownloadManager
|
||||
|
||||
node_id: NodeId
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
@@ -62,11 +64,16 @@ class Node:
|
||||
else:
|
||||
api = None
|
||||
|
||||
download_manager = DownloadManager(
|
||||
node_id=node_id,
|
||||
shard_downloader=exo_shard_downloader(),
|
||||
)
|
||||
|
||||
if not args.no_worker:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
download_manager,
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
@@ -98,13 +105,17 @@ class Node:
|
||||
election_result_sender=er_send,
|
||||
)
|
||||
|
||||
return cls(router, worker, election, er_recv, master, api, node_id)
|
||||
return cls(router, worker, election, er_recv, master, api, download_manager, node_id)
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
tg.start_soon(
|
||||
self.download_manager.run,
|
||||
self.router.sender(topics.LOCAL_EVENTS),
|
||||
)
|
||||
if self.worker:
|
||||
tg.start_soon(self.worker.run)
|
||||
if self.master:
|
||||
@@ -175,7 +186,7 @@ class Node:
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
self.download_manager,
|
||||
connection_message_receiver=self.router.receiver(
|
||||
topics.CONNECTION_MESSAGES
|
||||
),
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime, timezone
|
||||
from random import random
|
||||
|
||||
import anyio
|
||||
from anyio import CancelScope, create_task_group, current_time, fail_after
|
||||
from anyio import CancelScope, create_task_group, fail_after
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
@@ -15,7 +15,6 @@ from exo.shared.types.events import (
|
||||
EventId,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
NodeDownloadProgress,
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
TaskCreated,
|
||||
@@ -23,7 +22,6 @@ from exo.shared.types.events import (
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
@@ -35,20 +33,10 @@ from exo.shared.types.tasks import (
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.topology import Connection
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.worker.download.download_utils import (
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
||||
from exo.download import DownloadManager
|
||||
from exo.worker.plan import plan
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
|
||||
@@ -60,7 +48,7 @@ class Worker:
|
||||
self,
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
shard_downloader: ShardDownloader,
|
||||
download_manager: DownloadManager,
|
||||
*,
|
||||
connection_message_receiver: Receiver[ConnectionMessage],
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
@@ -72,8 +60,7 @@ class Worker:
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
|
||||
self.shard_downloader: ShardDownloader = shard_downloader
|
||||
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
|
||||
self.download_manager: DownloadManager = download_manager
|
||||
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.local_event_sender = local_event_sender
|
||||
@@ -84,7 +71,6 @@ class Worker:
|
||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
||||
|
||||
self.state: State = State()
|
||||
self.download_status: dict[ModelId, DownloadProgress] = {}
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
@@ -129,7 +115,6 @@ class Worker:
|
||||
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
|
||||
|
||||
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
@@ -180,7 +165,7 @@ class Worker:
|
||||
task: Task | None = plan(
|
||||
self.node_id,
|
||||
self.runners,
|
||||
self.download_status,
|
||||
self.download_manager.download_status,
|
||||
self.state.downloads,
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
@@ -201,43 +186,8 @@ class Worker:
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
if shard.model_meta.model_id not in self.download_status:
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
initial_progress = (
|
||||
await self.shard_downloader.get_shard_download_status_for_shard(
|
||||
shard
|
||||
)
|
||||
)
|
||||
if initial_progress.status == "complete":
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Complete,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
)
|
||||
self._handle_shard_download_process(task, initial_progress)
|
||||
case DownloadModel():
|
||||
await self.download_manager.start_download(task, self.event_sender)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
try:
|
||||
with fail_after(3):
|
||||
@@ -340,68 +290,6 @@ class Worker:
|
||||
self._tg.start_soon(runner.run)
|
||||
return runner
|
||||
|
||||
def _handle_shard_download_process(
|
||||
self,
|
||||
task: DownloadModel,
|
||||
initial_progress: RepoDownloadProgress,
|
||||
):
|
||||
"""Manages the shard download process with progress tracking."""
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=task.shard_metadata,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[task.shard_metadata.model_meta.model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
# TODO: i hate callbacks
|
||||
def download_progress_callback(
|
||||
shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
nonlocal self
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
# Footgun!
|
||||
self.event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
self.event_sender.send_nowait(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
elif (
|
||||
progress.status == "in_progress"
|
||||
and current_time() - last_progress_time > throttle_interval_secs
|
||||
):
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
self.event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
assert self._tg
|
||||
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
@@ -452,42 +340,3 @@ class Worker:
|
||||
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.info("Fetching and emitting existing download progress...")
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_meta.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.info("Done emitting existing download progress.")
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting existing download progress: {e}")
|
||||
|
||||
Reference in New Issue
Block a user