mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-17 10:29:26 -05:00
Compare commits
4 Commits
aiohttp
...
leo/handle
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d1d414c57a | ||
|
|
82ad11d971 | ||
|
|
049bc84e8b | ||
|
|
83c5285a80 |
@@ -189,7 +189,8 @@ def _load_model(
|
||||
|
||||
all_local_downloads_complete = all(
|
||||
nid in global_download_status
|
||||
and any(
|
||||
and global_download_status[nid] != [] # non-empty
|
||||
and all(
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
|
||||
for dp in global_download_status[nid]
|
||||
|
||||
@@ -3,7 +3,12 @@ from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import LoadModel
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
DownloadProgressData,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
@@ -254,3 +259,80 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_plan_does_not_load_model_when_download_still_ongoing():
|
||||
"""
|
||||
LoadModel should not be emitted when a node has a DownloadOngoing entry,
|
||||
even if the same node also has a DownloadCompleted entry.
|
||||
All downloads for the model must be DownloadCompleted.
|
||||
"""
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
|
||||
)
|
||||
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerConnected()
|
||||
)
|
||||
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerConnected(),
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# NODE_A completed, NODE_B has BOTH a completed and an ongoing download
|
||||
# This tests that ALL downloads must be completed, not just ANY
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
|
||||
),
|
||||
DownloadOngoing(
|
||||
shard_metadata=shard2,
|
||||
node_id=NODE_B,
|
||||
download_progress=DownloadProgressData(
|
||||
total_bytes=Memory(in_bytes=1000),
|
||||
downloaded_bytes=Memory(in_bytes=500),
|
||||
downloaded_bytes_this_session=Memory(in_bytes=500),
|
||||
completed_files=0,
|
||||
total_files=1,
|
||||
speed=100.0,
|
||||
eta_ms=5000,
|
||||
files={},
|
||||
),
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -22,7 +22,6 @@ async def check_reachability(
|
||||
url = f"http://{target_ip}:52415/node_id"
|
||||
|
||||
remote_node_id = None
|
||||
|
||||
last_error = None
|
||||
|
||||
for _ in range(REACHABILITY_ATTEMPTS):
|
||||
@@ -40,24 +39,24 @@ async def check_reachability(
|
||||
remote_node_id = NodeId(body)
|
||||
break
|
||||
|
||||
# expected failure cases
|
||||
except (
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ReadTimeout,
|
||||
httpx.RemoteProtocolError,
|
||||
) as e:
|
||||
httpx.TimeoutException,
|
||||
httpx.NetworkError,
|
||||
):
|
||||
await anyio.sleep(1)
|
||||
|
||||
# other failures should be logged on last attempt
|
||||
except httpx.HTTPError as e:
|
||||
last_error = e
|
||||
await anyio.sleep(1)
|
||||
|
||||
else:
|
||||
if last_error is not None:
|
||||
logger.warning(
|
||||
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"malformed response from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
|
||||
)
|
||||
if last_error is not None:
|
||||
logger.warning(
|
||||
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
|
||||
)
|
||||
|
||||
if remote_node_id is None:
|
||||
return
|
||||
|
||||
if remote_node_id != expected_node_id:
|
||||
|
||||
Reference in New Issue
Block a user