mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-25 18:58:39 -05:00
Compare commits
1 Commits
remove-pyt
...
fix/no-dow
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce45af58e3 |
@@ -94,6 +94,7 @@ class Node:
|
|||||||
command_sender=router.sender(topics.COMMANDS),
|
command_sender=router.sender(topics.COMMANDS),
|
||||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||||
event_index_counter=event_index_counter,
|
event_index_counter=event_index_counter,
|
||||||
|
no_downloads=args.no_downloads,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
worker = None
|
worker = None
|
||||||
@@ -225,6 +226,7 @@ class Node:
|
|||||||
)
|
)
|
||||||
self._tg.start_soon(self.download_coordinator.run)
|
self._tg.start_soon(self.download_coordinator.run)
|
||||||
if self.worker:
|
if self.worker:
|
||||||
|
no_downloads = self.worker.no_downloads
|
||||||
self.worker.shutdown()
|
self.worker.shutdown()
|
||||||
# TODO: add profiling etc to resource monitor
|
# TODO: add profiling etc to resource monitor
|
||||||
self.worker = Worker(
|
self.worker = Worker(
|
||||||
@@ -239,6 +241,7 @@ class Node:
|
|||||||
topics.DOWNLOAD_COMMANDS
|
topics.DOWNLOAD_COMMANDS
|
||||||
),
|
),
|
||||||
event_index_counter=self.event_index_counter,
|
event_index_counter=self.event_index_counter,
|
||||||
|
no_downloads=no_downloads,
|
||||||
)
|
)
|
||||||
self._tg.start_soon(self.worker.run)
|
self._tg.start_soon(self.worker.run)
|
||||||
if self.api:
|
if self.api:
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class Worker:
|
|||||||
command_sender: Sender[ForwarderCommand],
|
command_sender: Sender[ForwarderCommand],
|
||||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||||
event_index_counter: Iterator[int],
|
event_index_counter: Iterator[int],
|
||||||
|
no_downloads: bool = False,
|
||||||
):
|
):
|
||||||
self.node_id: NodeId = node_id
|
self.node_id: NodeId = node_id
|
||||||
self.session_id: SessionId = session_id
|
self.session_id: SessionId = session_id
|
||||||
@@ -74,6 +75,7 @@ class Worker:
|
|||||||
self.event_index_counter = event_index_counter
|
self.event_index_counter = event_index_counter
|
||||||
self.command_sender = command_sender
|
self.command_sender = command_sender
|
||||||
self.download_command_sender = download_command_sender
|
self.download_command_sender = download_command_sender
|
||||||
|
self.no_downloads = no_downloads
|
||||||
self.event_buffer = OrderedBuffer[Event]()
|
self.event_buffer = OrderedBuffer[Event]()
|
||||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
||||||
|
|
||||||
@@ -182,6 +184,7 @@ class Worker:
|
|||||||
self.state.tasks,
|
self.state.tasks,
|
||||||
self.input_chunk_buffer,
|
self.input_chunk_buffer,
|
||||||
self.input_chunk_counts,
|
self.input_chunk_counts,
|
||||||
|
no_downloads=self.no_downloads,
|
||||||
)
|
)
|
||||||
if task is None:
|
if task is None:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -51,15 +51,20 @@ def plan(
|
|||||||
tasks: Mapping[TaskId, Task],
|
tasks: Mapping[TaskId, Task],
|
||||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||||
input_chunk_counts: Mapping[CommandId, int] | None = None,
|
input_chunk_counts: Mapping[CommandId, int] | None = None,
|
||||||
|
no_downloads: bool = False,
|
||||||
) -> Task | None:
|
) -> Task | None:
|
||||||
# Python short circuiting OR logic should evaluate these sequentially.
|
# Python short circuiting OR logic should evaluate these sequentially.
|
||||||
return (
|
return (
|
||||||
_cancel_tasks(runners, tasks)
|
_cancel_tasks(runners, tasks)
|
||||||
or _kill_runner(runners, all_runners, instances)
|
or _kill_runner(runners, all_runners, instances)
|
||||||
or _create_runner(node_id, runners, instances)
|
or _create_runner(node_id, runners, instances)
|
||||||
or _model_needs_download(node_id, runners, global_download_status)
|
or (
|
||||||
|
None
|
||||||
|
if no_downloads
|
||||||
|
else _model_needs_download(node_id, runners, global_download_status)
|
||||||
|
)
|
||||||
or _init_distributed_backend(runners, all_runners)
|
or _init_distributed_backend(runners, all_runners)
|
||||||
or _load_model(runners, all_runners, global_download_status)
|
or _load_model(runners, all_runners, global_download_status, no_downloads)
|
||||||
or _ready_to_warmup(runners, all_runners)
|
or _ready_to_warmup(runners, all_runners)
|
||||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
|
||||||
)
|
)
|
||||||
@@ -192,22 +197,25 @@ def _load_model(
|
|||||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||||
|
no_downloads: bool = False,
|
||||||
) -> LoadModel | None:
|
) -> LoadModel | None:
|
||||||
for runner in runners.values():
|
for runner in runners.values():
|
||||||
instance = runner.bound_instance.instance
|
instance = runner.bound_instance.instance
|
||||||
shard_assignments = instance.shard_assignments
|
shard_assignments = instance.shard_assignments
|
||||||
|
|
||||||
all_local_downloads_complete = all(
|
if not no_downloads:
|
||||||
nid in global_download_status
|
all_local_downloads_complete = all(
|
||||||
and any(
|
nid in global_download_status
|
||||||
isinstance(dp, DownloadCompleted)
|
and any(
|
||||||
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
|
isinstance(dp, DownloadCompleted)
|
||||||
for dp in global_download_status[nid]
|
and dp.shard_metadata.model_card.model_id
|
||||||
|
== shard_assignments.model_id
|
||||||
|
for dp in global_download_status[nid]
|
||||||
|
)
|
||||||
|
for nid in shard_assignments.node_to_runner
|
||||||
)
|
)
|
||||||
for nid in shard_assignments.node_to_runner
|
if not all_local_downloads_complete:
|
||||||
)
|
continue
|
||||||
if not all_local_downloads_complete:
|
|
||||||
continue
|
|
||||||
|
|
||||||
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
|
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
|
||||||
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
|
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
|
||||||
|
|||||||
@@ -157,6 +157,41 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
|||||||
assert not isinstance(result, plan_mod.DownloadModel)
|
assert not isinstance(result, plan_mod.DownloadModel)
|
||||||
|
|
||||||
|
|
||||||
|
def test_plan_loads_model_with_no_downloads_flag():
|
||||||
|
"""
|
||||||
|
When no_downloads=True, plan() should skip download checks and proceed
|
||||||
|
to load the model even with empty global_download_status.
|
||||||
|
"""
|
||||||
|
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
|
||||||
|
instance = get_mlx_ring_instance(
|
||||||
|
instance_id=INSTANCE_1_ID,
|
||||||
|
model_id=MODEL_A_ID,
|
||||||
|
node_to_runner={NODE_A: RUNNER_1_ID},
|
||||||
|
runner_to_shard={RUNNER_1_ID: shard},
|
||||||
|
)
|
||||||
|
bound_instance = BoundInstance(
|
||||||
|
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||||
|
)
|
||||||
|
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
|
||||||
|
|
||||||
|
runners = {RUNNER_1_ID: runner}
|
||||||
|
instances = {INSTANCE_1_ID: instance}
|
||||||
|
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||||
|
|
||||||
|
result = plan_mod.plan(
|
||||||
|
node_id=NODE_A,
|
||||||
|
runners=runners, # type: ignore
|
||||||
|
global_download_status={},
|
||||||
|
instances=instances,
|
||||||
|
all_runners=all_runners,
|
||||||
|
tasks={},
|
||||||
|
no_downloads=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, LoadModel)
|
||||||
|
assert result.instance_id == INSTANCE_1_ID
|
||||||
|
|
||||||
|
|
||||||
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||||
"""
|
"""
|
||||||
LoadModel should not be emitted while some shards are still missing from
|
LoadModel should not be emitted while some shards are still missing from
|
||||||
|
|||||||
Reference in New Issue
Block a user