Compare commits

...

1 Commits

Author SHA1 Message Date
Alex Cheema
73599fbaec fix: download gets permanently stuck after cancellation
_cancel_download() cancelled the asyncio task but never cleared the
coordinator's local download_status or emitted an event to update global
state. Both remained stuck on DownloadOngoing, causing any new
StartDownload for the same model to be silently skipped.

Also fixes apply_node_download_progress to match by model_id instead of
full shard_metadata equality, preventing duplicate download entries when
instance type changes (e.g. MlxJaccl -> MlxRing).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 11:13:55 -08:00
4 changed files with 76 additions and 5 deletions

View File

@@ -199,9 +199,22 @@ class DownloadCoordinator:
await self._cancel_download(model_id)
async def _cancel_download(self, model_id: ModelId) -> None:
if model_id in self.active_downloads and model_id in self.download_status:
if model_id in self.active_downloads:
logger.info(f"Cancelling download for {model_id}")
self.active_downloads.pop(model_id).cancel()
if model_id in self.download_status:
current_status = self.download_status[model_id]
if isinstance(current_status, DownloadOngoing):
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
)
del self.download_status[model_id]
self._last_progress_time.pop(model_id, None)
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id
@@ -298,6 +311,11 @@ class DownloadCoordinator:
async def download_wrapper() -> None:
try:
await self.shard_downloader.ensure_shard(shard)
except asyncio.CancelledError:
logger.info(f"Download cancelled for {model_id}")
if model_id in self.download_status:
del self.download_status[model_id]
raise
except Exception as e:
logger.error(f"Download failed for {model_id}: {e}")
failed = DownloadFailed(

View File

@@ -118,7 +118,7 @@ def apply_node_download_progress(event: NodeDownloadProgress, state: State) -> S
replaced = False
for i, existing_dp in enumerate(current):
if existing_dp.shard_metadata == dp.shard_metadata:
if existing_dp.shard_metadata.model_card.model_id == dp.shard_metadata.model_card.model_id:
current[i] = dp
replaced = True
break

View File

@@ -9,7 +9,11 @@ from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
TensorShardMetadata,
)
@pytest.fixture(scope="session")
@@ -47,6 +51,26 @@ def get_pipeline_shard_metadata(
)
def get_tensor_shard_metadata(
    model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
    """Build a TensorShardMetadata test fixture for *model_id*.

    Counterpart to get_pipeline_shard_metadata: produces a tensor-parallel
    shard spanning all 32 layers of a synthetic tensor-capable model card.
    """
    card = ModelCard(
        model_id=model_id,
        storage_size=Memory.from_mb(100000),
        n_layers=32,
        hidden_size=1000,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
    )
    return TensorShardMetadata(
        model_card=card,
        device_rank=device_rank,
        world_size=world_size,
        start_layer=0,
        end_layer=32,
        n_layers=32,
    )
@pytest.fixture
def caplog(caplog: LogCaptureFixture):
handler_id = logger.add(

View File

@@ -1,10 +1,13 @@
from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.tests.conftest import (
get_pipeline_shard_metadata,
get_tensor_shard_metadata,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadPending
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
@@ -44,3 +47,29 @@ def test_apply_two_node_download_progress():
)
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}
def test_apply_download_progress_replaces_when_shard_metadata_type_changes():
    """A progress event for a model already tracked under a different shard
    type must replace the existing entry rather than append a second one."""
    node = NodeId("node-1")
    # Existing state: the model finished downloading as a pipeline shard.
    completed = DownloadCompleted(
        node_id=node,
        shard_metadata=get_pipeline_shard_metadata(
            MODEL_A_ID, device_rank=0, world_size=1
        ),
        total=Memory(),
    )
    state = State(downloads={node: [completed]})
    # Incoming event: the same model re-created as a tensor shard.
    pending = DownloadPending(
        node_id=node,
        shard_metadata=get_tensor_shard_metadata(
            MODEL_A_ID, device_rank=1, world_size=2
        ),
    )
    new_state = apply_node_download_progress(
        NodeDownloadProgress(download_progress=pending), state
    )
    # Should replace, not append - one entry per model per node
    assert len(new_state.downloads[node]) == 1
    assert new_state.downloads[node][0] == pending