From ee218d1f4c8ffa9ba6dab4edda38ab3b5d5dd8a5 Mon Sep 17 00:00:00 2001 From: Evan Date: Tue, 16 Dec 2025 18:10:44 +0000 Subject: [PATCH] format --- src/exo/worker/download/download_utils.py | 28 ++++++++++----------- src/exo/worker/download/shard_downloader.py | 18 ++++++------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/src/exo/worker/download/download_utils.py b/src/exo/worker/download/download_utils.py index 234bfd67..cff1c173 100644 --- a/src/exo/worker/download/download_utils.py +++ b/src/exo/worker/download/download_utils.py @@ -20,11 +20,11 @@ from pydantic import ( TypeAdapter, ) -from exo.utils.channels import Sender, Receiver, channel from exo.shared.constants import EXO_HOME, EXO_MODELS_DIR from exo.shared.types.memory import Memory from exo.shared.types.worker.downloads import DownloadProgressData from exo.shared.types.worker.shards import ShardMetadata +from exo.utils.channels import Receiver, Sender, channel from exo.worker.download.huggingface_utils import ( filter_repo_objects, get_allow_patterns, @@ -321,7 +321,7 @@ async def download_file_with_retry( revision: str, path: str, target_dir: Path, - progress_sender: Sender[tuple[int, int, bool]] | None + progress_sender: Sender[tuple[int, int, bool]] | None, ) -> Path: n_attempts = 30 for attempt in range(n_attempts): @@ -568,9 +568,7 @@ async def download_shard( ) file_progress: dict[str, RepoFileDownloadProgress] = {} - async def huh( - file: FileListEntry, recv: Receiver[tuple[int, int, bool]] - ): + async def huh(file: FileListEntry, recv: Receiver[tuple[int, int, bool]]): async with recv: async for curr, total, done in recv: await progress_sender.send(on_progress_wrapper(file, curr, total, done)) @@ -613,14 +611,12 @@ async def download_shard( else "in_progress", start_time=start_time, ) - return ( - calculate_repo_progress( - shard, - str(shard.model_meta.model_id), - revision, - file_progress, - all_start_time, - ) + return calculate_repo_progress( + shard, + str(shard.model_meta.model_id), + revision, + file_progress, + all_start_time, ) for file in filtered_file_list: @@ -640,7 +636,9 @@ async def download_shard( semaphore = anyio.Semaphore(max_parallel_downloads) - async def download_with_semaphore(file: FileListEntry, sender: Sender[tuple[int, int, bool]]): + async def download_with_semaphore( + file: FileListEntry, sender: Sender[tuple[int, int, bool]] + ): async with semaphore: await download_file_with_retry( str(shard.model_meta.model_id), @@ -653,7 +651,7 @@ async def download_shard( if not skip_download: async with anyio.create_task_group() as tg: for file in filtered_file_list: - send, recv = channel[tuple[int,int,bool]](1) + send, recv = channel[tuple[int, int, bool]](1) tg.start_soon(download_with_semaphore, file, send) tg.start_soon(huh, file, recv) diff --git a/src/exo/worker/download/shard_downloader.py b/src/exo/worker/download/shard_downloader.py index fdc27436..cb1abd5b 100644 --- a/src/exo/worker/download/shard_downloader.py +++ b/src/exo/worker/download/shard_downloader.py @@ -5,7 +5,7 @@ from typing import AsyncIterator, Callable, Self import anyio from anyio import Path, create_task_group -from anyio.abc import TaskGroup, CancelScope +from anyio.abc import CancelScope, TaskGroup from exo.shared.types.memory import Memory from exo.shared.types.models import ModelId, ModelMetadata @@ -13,17 +13,19 @@ from exo.shared.types.worker.shards import ( PipelineShardMetadata, ShardMetadata, ) -from exo.utils.channels import Sender, Receiver, channel +from exo.utils.channels import Receiver, Sender, channel from exo.worker.download.download_utils import RepoDownloadProgress, download_shard @dataclass class ShardDownloader2: progress_sender: Sender[RepoDownloadProgress] - max_parallel_downloads=8 + max_parallel_downloads = 8 # The last item on the shard stack is currently being downloaded - shard_stack: list[tuple[ShardMetadata, bool]] = field(init=False, default_factory = list) + shard_stack: list[tuple[ShardMetadata, bool]] = field( + init=False, default_factory=list + ) _top_scope: CancelScope | None = field(init=False, default=None) _tg: TaskGroup = field(init=False, default_factory=create_task_group) @@ -35,13 +37,10 @@ class ShardDownloader2: # Create a new scope self._top_scope = CancelScope() - - async def run(self): - async with self._tg as tg: + async with self._tg: await anyio.sleep_forever() - def shutdown(self): self.progress_sender.close() self._tg.cancel_scope.cancel() @@ -58,9 +57,6 @@ class ShardDownloader2: ) return target_dir - - - @classmethod def default(cls) -> tuple[Self, Receiver[RepoDownloadProgress]]: send, recv = channel[RepoDownloadProgress](10)