mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 14:17:58 -05:00
format
This commit is contained in:
@@ -20,11 +20,11 @@ from pydantic import (
|
||||
TypeAdapter,
|
||||
)
|
||||
|
||||
from exo.utils.channels import Sender, Receiver, channel
|
||||
from exo.shared.constants import EXO_HOME, EXO_MODELS_DIR
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import DownloadProgressData
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.worker.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
get_allow_patterns,
|
||||
@@ -321,7 +321,7 @@ async def download_file_with_retry(
|
||||
revision: str,
|
||||
path: str,
|
||||
target_dir: Path,
|
||||
progress_sender: Sender[tuple[int, int, bool]] | None
|
||||
progress_sender: Sender[tuple[int, int, bool]] | None,
|
||||
) -> Path:
|
||||
n_attempts = 30
|
||||
for attempt in range(n_attempts):
|
||||
@@ -568,9 +568,7 @@ async def download_shard(
|
||||
)
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
async def huh(
|
||||
file: FileListEntry, recv: Receiver[tuple[int, int, bool]]
|
||||
):
|
||||
async def huh(file: FileListEntry, recv: Receiver[tuple[int, int, bool]]):
|
||||
async with recv:
|
||||
async for curr, total, done in recv:
|
||||
await progress_sender.send(on_progress_wrapper(file, curr, total, done))
|
||||
@@ -613,14 +611,12 @@ async def download_shard(
|
||||
else "in_progress",
|
||||
start_time=start_time,
|
||||
)
|
||||
return (
|
||||
calculate_repo_progress(
|
||||
shard,
|
||||
str(shard.model_meta.model_id),
|
||||
revision,
|
||||
file_progress,
|
||||
all_start_time,
|
||||
)
|
||||
return calculate_repo_progress(
|
||||
shard,
|
||||
str(shard.model_meta.model_id),
|
||||
revision,
|
||||
file_progress,
|
||||
all_start_time,
|
||||
)
|
||||
|
||||
for file in filtered_file_list:
|
||||
@@ -640,7 +636,9 @@ async def download_shard(
|
||||
|
||||
semaphore = anyio.Semaphore(max_parallel_downloads)
|
||||
|
||||
async def download_with_semaphore(file: FileListEntry, sender: Sender[tuple[int, int, bool]]):
|
||||
async def download_with_semaphore(
|
||||
file: FileListEntry, sender: Sender[tuple[int, int, bool]]
|
||||
):
|
||||
async with semaphore:
|
||||
await download_file_with_retry(
|
||||
str(shard.model_meta.model_id),
|
||||
@@ -653,7 +651,7 @@ async def download_shard(
|
||||
if not skip_download:
|
||||
async with anyio.create_task_group() as tg:
|
||||
for file in filtered_file_list:
|
||||
send, recv = channel[tuple[int,int,bool]](1)
|
||||
send, recv = channel[tuple[int, int, bool]](1)
|
||||
tg.start_soon(download_with_semaphore, file, send)
|
||||
tg.start_soon(huh, file, recv)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import AsyncIterator, Callable, Self
|
||||
|
||||
import anyio
|
||||
from anyio import Path, create_task_group
|
||||
from anyio.abc import TaskGroup, CancelScope
|
||||
from anyio.abc import CancelScope, TaskGroup
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
@@ -13,17 +13,19 @@ from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from exo.utils.channels import Sender, Receiver, channel
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShardDownloader2:
|
||||
progress_sender: Sender[RepoDownloadProgress]
|
||||
max_parallel_downloads=8
|
||||
max_parallel_downloads = 8
|
||||
|
||||
# The last item on the shard stack is currently being downloaded
|
||||
shard_stack: list[tuple[ShardMetadata, bool]] = field(init=False, default_factory = list)
|
||||
shard_stack: list[tuple[ShardMetadata, bool]] = field(
|
||||
init=False, default_factory=list
|
||||
)
|
||||
_top_scope: CancelScope | None = field(init=False, default=None)
|
||||
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
|
||||
|
||||
@@ -35,13 +37,10 @@ class ShardDownloader2:
|
||||
# Create a new scope
|
||||
self._top_scope = CancelScope()
|
||||
|
||||
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
async with self._tg:
|
||||
await anyio.sleep_forever()
|
||||
|
||||
|
||||
def shutdown(self):
|
||||
self.progress_sender.close()
|
||||
self._tg.cancel_scope.cancel()
|
||||
@@ -58,9 +57,6 @@ class ShardDownloader2:
|
||||
)
|
||||
return target_dir
|
||||
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> tuple[Self, Receiver[RepoDownloadProgress]]:
|
||||
send, recv = channel[RepoDownloadProgress](10)
|
||||
|
||||
Reference in New Issue
Block a user