mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-25 18:58:39 -05:00
Compare commits
1 Commits
remove-pyt
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73599fbaec |
@@ -199,9 +199,22 @@ class DownloadCoordinator:
|
||||
await self._cancel_download(model_id)
|
||||
|
||||
async def _cancel_download(self, model_id: ModelId) -> None:
|
||||
if model_id in self.active_downloads and model_id in self.download_status:
|
||||
if model_id in self.active_downloads:
|
||||
logger.info(f"Cancelling download for {model_id}")
|
||||
self.active_downloads.pop(model_id).cancel()
|
||||
if model_id in self.download_status:
|
||||
current_status = self.download_status[model_id]
|
||||
if isinstance(current_status, DownloadOngoing):
|
||||
pending = DownloadPending(
|
||||
shard_metadata=current_status.shard_metadata,
|
||||
node_id=self.node_id,
|
||||
model_directory=self._model_dir(model_id),
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=pending)
|
||||
)
|
||||
del self.download_status[model_id]
|
||||
self._last_progress_time.pop(model_id, None)
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
@@ -298,6 +311,11 @@ class DownloadCoordinator:
|
||||
async def download_wrapper() -> None:
|
||||
try:
|
||||
await self.shard_downloader.ensure_shard(shard)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Download cancelled for {model_id}")
|
||||
if model_id in self.download_status:
|
||||
del self.download_status[model_id]
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Download failed for {model_id}: {e}")
|
||||
failed = DownloadFailed(
|
||||
|
||||
@@ -118,7 +118,7 @@ def apply_node_download_progress(event: NodeDownloadProgress, state: State) -> S
|
||||
|
||||
replaced = False
|
||||
for i, existing_dp in enumerate(current):
|
||||
if existing_dp.shard_metadata == dp.shard_metadata:
|
||||
if existing_dp.shard_metadata.model_card.model_id == dp.shard_metadata.model_card.model_id:
|
||||
current[i] = dp
|
||||
replaced = True
|
||||
break
|
||||
|
||||
@@ -9,7 +9,11 @@ from loguru import logger
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -47,6 +51,26 @@ def get_pipeline_shard_metadata(
|
||||
)
|
||||
|
||||
|
||||
def get_tensor_shard_metadata(
    model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
    """Build a ``TensorShardMetadata`` for a synthetic 32-layer test model.

    Args:
        model_id: Identifier stored on the generated ``ModelCard``.
        device_rank: Rank recorded on the shard.
        world_size: World size recorded on the shard (defaults to 1).

    Returns:
        A ``TensorShardMetadata`` spanning layers 0 through 32 whose model card
        declares tensor support and the text-generation task.
    """
    layer_count = 32

    # The card describes a fixed fake model; only the id varies per call.
    card = ModelCard(
        model_id=model_id,
        storage_size=Memory.from_mb(100000),
        n_layers=layer_count,
        hidden_size=1000,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
    )

    return TensorShardMetadata(
        model_card=card,
        device_rank=device_rank,
        world_size=world_size,
        start_layer=0,
        end_layer=layer_count,
        n_layers=layer_count,
    )
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def caplog(caplog: LogCaptureFixture):
|
||||
handler_id = logger.add(
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from exo.shared.apply import apply_node_download_progress
|
||||
from exo.shared.tests.conftest import get_pipeline_shard_metadata
|
||||
from exo.shared.tests.conftest import (
|
||||
get_pipeline_shard_metadata,
|
||||
get_tensor_shard_metadata,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import NodeDownloadProgress
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadPending
|
||||
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
|
||||
|
||||
|
||||
@@ -44,3 +47,29 @@ def test_apply_two_node_download_progress():
|
||||
)
|
||||
|
||||
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}
|
||||
|
||||
|
||||
def test_apply_download_progress_replaces_when_shard_metadata_type_changes():
    """A new progress event for the same model must overwrite the existing entry,
    even when its shard metadata is of a different type (pipeline -> tensor)."""
    node = NodeId("node-1")

    # Start from a state that already holds a completed pipeline-shard download.
    completed_pipeline = DownloadCompleted(
        node_id=node,
        shard_metadata=get_pipeline_shard_metadata(
            MODEL_A_ID, device_rank=0, world_size=1
        ),
        total=Memory(),
    )
    initial_state = State(downloads={node: [completed_pipeline]})

    # Same model and node, but now described by a tensor shard.
    new_entry = DownloadPending(
        node_id=node,
        shard_metadata=get_tensor_shard_metadata(
            MODEL_A_ID, device_rank=1, world_size=2
        ),
    )
    progress_event = NodeDownloadProgress(download_progress=new_entry)
    new_state = apply_node_download_progress(progress_event, initial_state)

    # Should replace, not append - one entry per model per node
    node_entries = new_state.downloads[node]
    assert len(node_entries) == 1
    assert node_entries[0] == new_entry
|
||||
|
||||
Reference in New Issue
Block a user