Compare commits

...

1 Commit

Author SHA1 Message Date
Alex Cheema
ce45af58e3 fix: --no-downloads no longer blocks model loading (#1510)
When --no-downloads is passed, skip download checks in plan() so
pre-staged models can be loaded without DownloadCompleted events.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:40:12 -08:00
4 changed files with 61 additions and 12 deletions

View File

@@ -94,6 +94,7 @@ class Node:
command_sender=router.sender(topics.COMMANDS), command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS), download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter, event_index_counter=event_index_counter,
no_downloads=args.no_downloads,
) )
else: else:
worker = None worker = None
@@ -225,6 +226,7 @@ class Node:
) )
self._tg.start_soon(self.download_coordinator.run) self._tg.start_soon(self.download_coordinator.run)
if self.worker: if self.worker:
no_downloads = self.worker.no_downloads
self.worker.shutdown() self.worker.shutdown()
# TODO: add profiling etc to resource monitor # TODO: add profiling etc to resource monitor
self.worker = Worker( self.worker = Worker(
@@ -239,6 +241,7 @@ class Node:
topics.DOWNLOAD_COMMANDS topics.DOWNLOAD_COMMANDS
), ),
event_index_counter=self.event_index_counter, event_index_counter=self.event_index_counter,
no_downloads=no_downloads,
) )
self._tg.start_soon(self.worker.run) self._tg.start_soon(self.worker.run)
if self.api: if self.api:

View File

@@ -65,6 +65,7 @@ class Worker:
command_sender: Sender[ForwarderCommand], command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand], download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int], event_index_counter: Iterator[int],
no_downloads: bool = False,
): ):
self.node_id: NodeId = node_id self.node_id: NodeId = node_id
self.session_id: SessionId = session_id self.session_id: SessionId = session_id
@@ -74,6 +75,7 @@ class Worker:
self.event_index_counter = event_index_counter self.event_index_counter = event_index_counter
self.command_sender = command_sender self.command_sender = command_sender
self.download_command_sender = download_command_sender self.download_command_sender = download_command_sender
self.no_downloads = no_downloads
self.event_buffer = OrderedBuffer[Event]() self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {} self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
@@ -182,6 +184,7 @@ class Worker:
self.state.tasks, self.state.tasks,
self.input_chunk_buffer, self.input_chunk_buffer,
self.input_chunk_counts, self.input_chunk_counts,
no_downloads=self.no_downloads,
) )
if task is None: if task is None:
continue continue

View File

@@ -51,15 +51,20 @@ def plan(
tasks: Mapping[TaskId, Task], tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None, input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None, input_chunk_counts: Mapping[CommandId, int] | None = None,
no_downloads: bool = False,
) -> Task | None: ) -> Task | None:
# Python short circuiting OR logic should evaluate these sequentially. # Python short circuiting OR logic should evaluate these sequentially.
return ( return (
_cancel_tasks(runners, tasks) _cancel_tasks(runners, tasks)
or _kill_runner(runners, all_runners, instances) or _kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances) or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status) or (
None
if no_downloads
else _model_needs_download(node_id, runners, global_download_status)
)
or _init_distributed_backend(runners, all_runners) or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status) or _load_model(runners, all_runners, global_download_status, no_downloads)
or _ready_to_warmup(runners, all_runners) or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {}) or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
) )
@@ -192,22 +197,25 @@ def _load_model(
runners: Mapping[RunnerId, RunnerSupervisor], runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus], all_runners: Mapping[RunnerId, RunnerStatus],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]], global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
no_downloads: bool = False,
) -> LoadModel | None: ) -> LoadModel | None:
for runner in runners.values(): for runner in runners.values():
instance = runner.bound_instance.instance instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments shard_assignments = instance.shard_assignments
all_local_downloads_complete = all( if not no_downloads:
nid in global_download_status all_local_downloads_complete = all(
and any( nid in global_download_status
isinstance(dp, DownloadCompleted) and any(
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id isinstance(dp, DownloadCompleted)
for dp in global_download_status[nid] and dp.shard_metadata.model_card.model_id
== shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner
) )
for nid in shard_assignments.node_to_runner if not all_local_downloads_complete:
) continue
if not all_local_downloads_complete:
continue
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1 is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle): if is_single_node_instance and isinstance(runner.status, RunnerIdle):

View File

@@ -157,6 +157,41 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
assert not isinstance(result, plan_mod.DownloadModel) assert not isinstance(result, plan_mod.DownloadModel)
def test_plan_loads_model_with_no_downloads_flag():
"""
When no_downloads=True, plan() should skip download checks and proceed
to load the model even with empty global_download_status.
"""
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
instance = get_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
node_to_runner={NODE_A: RUNNER_1_ID},
runner_to_shard={RUNNER_1_ID: shard},
)
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
)
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
runners = {RUNNER_1_ID: runner}
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
global_download_status={},
instances=instances,
all_runners=all_runners,
tasks={},
no_downloads=True,
)
assert isinstance(result, LoadModel)
assert result.instance_id == INSTANCE_1_ID
def test_plan_does_not_load_model_until_all_shards_downloaded_globally(): def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
""" """
LoadModel should not be emitted while some shards are still missing from LoadModel should not be emitted while some shards are still missing from