This commit is contained in:
Evan
2025-12-16 18:10:44 +00:00
parent 39a4857189
commit ee218d1f4c
2 changed files with 20 additions and 26 deletions

View File

@@ -20,11 +20,11 @@ from pydantic import (
TypeAdapter,
)
from exo.utils.channels import Sender, Receiver, channel
from exo.shared.constants import EXO_HOME, EXO_MODELS_DIR
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
@@ -321,7 +321,7 @@ async def download_file_with_retry(
revision: str,
path: str,
target_dir: Path,
progress_sender: Sender[tuple[int, int, bool]] | None
progress_sender: Sender[tuple[int, int, bool]] | None,
) -> Path:
n_attempts = 30
for attempt in range(n_attempts):
@@ -568,9 +568,7 @@ async def download_shard(
)
file_progress: dict[str, RepoFileDownloadProgress] = {}
async def huh(
file: FileListEntry, recv: Receiver[tuple[int, int, bool]]
):
async def huh(file: FileListEntry, recv: Receiver[tuple[int, int, bool]]):
async with recv:
async for curr, total, done in recv:
await progress_sender.send(on_progress_wrapper(file, curr, total, done))
@@ -613,14 +611,12 @@ async def download_shard(
else "in_progress",
start_time=start_time,
)
return (
calculate_repo_progress(
shard,
str(shard.model_meta.model_id),
revision,
file_progress,
all_start_time,
)
return calculate_repo_progress(
shard,
str(shard.model_meta.model_id),
revision,
file_progress,
all_start_time,
)
for file in filtered_file_list:
@@ -640,7 +636,9 @@ async def download_shard(
semaphore = anyio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file: FileListEntry, sender: Sender[tuple[int, int, bool]]):
async def download_with_semaphore(
file: FileListEntry, sender: Sender[tuple[int, int, bool]]
):
async with semaphore:
await download_file_with_retry(
str(shard.model_meta.model_id),
@@ -653,7 +651,7 @@ async def download_shard(
if not skip_download:
async with anyio.create_task_group() as tg:
for file in filtered_file_list:
send, recv = channel[tuple[int,int,bool]](1)
send, recv = channel[tuple[int, int, bool]](1)
tg.start_soon(download_with_semaphore, file, send)
tg.start_soon(huh, file, recv)

View File

@@ -5,7 +5,7 @@ from typing import AsyncIterator, Callable, Self
import anyio
from anyio import Path, create_task_group
from anyio.abc import TaskGroup, CancelScope
from anyio.abc import CancelScope, TaskGroup
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
@@ -13,17 +13,19 @@ from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.utils.channels import Sender, Receiver, channel
from exo.utils.channels import Receiver, Sender, channel
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
@dataclass
class ShardDownloader2:
progress_sender: Sender[RepoDownloadProgress]
max_parallel_downloads=8
max_parallel_downloads = 8
# The last item on the shard stack is currently being downloaded
shard_stack: list[tuple[ShardMetadata, bool]] = field(init=False, default_factory = list)
shard_stack: list[tuple[ShardMetadata, bool]] = field(
init=False, default_factory=list
)
_top_scope: CancelScope | None = field(init=False, default=None)
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
@@ -35,13 +37,10 @@ class ShardDownloader2:
# Create a new scope
self._top_scope = CancelScope()
async def run(self):
async with self._tg as tg:
async with self._tg:
await anyio.sleep_forever()
def shutdown(self):
self.progress_sender.close()
self._tg.cancel_scope.cancel()
@@ -58,9 +57,6 @@ class ShardDownloader2:
)
return target_dir
@classmethod
def default(cls) -> tuple[Self, Receiver[RepoDownloadProgress]]:
send, recv = channel[RepoDownloadProgress](10)