Compare commits

...

3 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
d1d414c57a More explicit condition 2026-01-16 20:27:12 +00:00
Ryuichi Leo Takashige
82ad11d971 Wait for ALL download progresses to be downloaded to begin LoadModel 2026-01-16 20:21:56 +00:00
Ryuichi Leo Takashige
049bc84e8b Failing test 2026-01-16 20:21:29 +00:00
2 changed files with 85 additions and 2 deletions

View File

@@ -189,7 +189,8 @@ def _load_model(
all_local_downloads_complete = all(
nid in global_download_status
and any(
and global_download_status[nid] != [] # non-empty
and all(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
for dp in global_download_status[nid]

View File

@@ -3,7 +3,12 @@ from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadProgress,
DownloadProgressData,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -254,3 +259,80 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
)
assert result is not None
def test_plan_does_not_load_model_when_download_still_ongoing():
    """
    A node reporting a mix of completed and in-flight downloads blocks LoadModel.

    NODE_B reports both a DownloadCompleted and a DownloadOngoing entry for
    its shard. The planner is expected to return None (no task) because every
    download entry for the model must be DownloadCompleted, not merely one
    of them.
    """
    # Two pipeline shards of the same model, one per node.
    first_shard = get_pipeline_shard_metadata(
        MODEL_A_ID, device_rank=0, world_size=2
    )
    second_shard = get_pipeline_shard_metadata(
        MODEL_A_ID, device_rank=1, world_size=2
    )

    ring_instance = get_mlx_ring_instance(
        instance_id=INSTANCE_1_ID,
        model_id=MODEL_A_ID,
        node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
        runner_to_shard={RUNNER_1_ID: first_shard, RUNNER_2_ID: second_shard},
    )

    # The planner runs from NODE_A's point of view, whose runner is connected.
    node_a_supervisor = FakeRunnerSupervisor(
        bound_instance=BoundInstance(
            instance=ring_instance,
            bound_runner_id=RUNNER_1_ID,
            bound_node_id=NODE_A,
        ),
        status=RunnerConnected(),
    )

    runner_statuses = {
        RUNNER_1_ID: RunnerConnected(),
        RUNNER_2_ID: RunnerConnected(),
    }

    # Locally, NODE_A has finished downloading its shard.
    node_a_local_downloads = {
        MODEL_A_ID: DownloadCompleted(
            shard_metadata=first_shard, node_id=NODE_A, total_bytes=Memory()
        )
    }

    # NODE_B's second entry: the same shard, still half-way through.
    node_b_in_flight = DownloadOngoing(
        shard_metadata=second_shard,
        node_id=NODE_B,
        download_progress=DownloadProgressData(
            total_bytes=Memory(in_bytes=1000),
            downloaded_bytes=Memory(in_bytes=500),
            downloaded_bytes_this_session=Memory(in_bytes=500),
            completed_files=0,
            total_files=1,
            speed=100.0,
            eta_ms=5000,
            files={},
        ),
    )

    # Cluster-wide view: NODE_A is fully done, NODE_B has BOTH a completed
    # and an ongoing entry. This is what distinguishes "all" from "any".
    cluster_downloads = {
        NODE_A: [
            DownloadCompleted(
                shard_metadata=first_shard, node_id=NODE_A, total_bytes=Memory()
            )
        ],
        NODE_B: [
            DownloadCompleted(
                shard_metadata=second_shard, node_id=NODE_B, total_bytes=Memory()
            ),
            node_b_in_flight,
        ],
    }

    planned_task = plan_mod.plan(
        node_id=NODE_A,
        runners={RUNNER_1_ID: node_a_supervisor},  # type: ignore
        download_status=node_a_local_downloads,
        global_download_status=cluster_downloads,
        instances={INSTANCE_1_ID: ring_instance},
        all_runners=runner_statuses,
        tasks={},
    )

    assert planned_task is None