Compare commits

...

4 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
d1d414c57a More explicit condition 2026-01-16 20:27:12 +00:00
Ryuichi Leo Takashige
82ad11d971 Wait for ALL download progresses to be downloaded to begin LoadModel 2026-01-16 20:21:56 +00:00
Ryuichi Leo Takashige
049bc84e8b Failing test 2026-01-16 20:21:29 +00:00
Evan
83c5285a80 reduce logs
previous commits logs were too verbose, this tones them down a bit
2026-01-16 14:05:47 +00:00
3 changed files with 99 additions and 17 deletions

View File

@@ -189,7 +189,8 @@ def _load_model(
all_local_downloads_complete = all(
nid in global_download_status
and any(
and global_download_status[nid] != [] # non-empty
and all(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
for dp in global_download_status[nid]

View File

@@ -3,7 +3,12 @@ from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadOngoing,
DownloadProgress,
DownloadProgressData,
)
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -254,3 +259,80 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
)
assert result is not None
def test_plan_does_not_load_model_when_download_still_ongoing():
"""
LoadModel should not be emitted when a node has a DownloadOngoing entry,
even if the same node also has a DownloadCompleted entry.
All downloads for the model must be DownloadCompleted.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
local_runner = FakeRunnerSupervisor(
bound_instance=bound_instance, status=RunnerConnected()
)
runners = {RUNNER_1_ID: local_runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {
RUNNER_1_ID: RunnerConnected(),
RUNNER_2_ID: RunnerConnected(),
}
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
# NODE_A completed, NODE_B has BOTH a completed and an ongoing download
# This tests that ALL downloads must be completed, not just ANY
global_download_status = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
),
DownloadOngoing(
shard_metadata=shard2,
node_id=NODE_B,
download_progress=DownloadProgressData(
total_bytes=Memory(in_bytes=1000),
downloaded_bytes=Memory(in_bytes=500),
downloaded_bytes_this_session=Memory(in_bytes=500),
completed_files=0,
total_files=1,
speed=100.0,
eta_ms=5000,
files={},
),
),
],
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert result is None

View File

@@ -22,7 +22,6 @@ async def check_reachability(
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
@@ -40,24 +39,24 @@ async def check_reachability(
remote_node_id = NodeId(body)
break
# expected failure cases
except (
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadTimeout,
httpx.RemoteProtocolError,
) as e:
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
else:
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
else:
logger.warning(
f"malformed response from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
if remote_node_id is None:
return
if remote_node_id != expected_node_id: