Compare commits

1 commit

Alex Cheema · 44845be117 · 2026-02-17 14:51:33 -08:00

fix instance type mismatch in /instance/previews endpoint

Derive instance_meta from the actual instance type returned by
place_instance() instead of using the requested instance_meta from
the loop variable. place_instance() overrides single-node placements
to MlxRing, but the preview response was still reporting the original
requested type (e.g., MlxJaccl), causing a mismatch.

Closes #1426

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
5 changed files with 29 additions and 65 deletions
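
To make the mismatch concrete, here is a minimal, self-contained sketch of the bug and the fix. The place_instance() body below is a hypothetical reduction of the real placement logic; only InstanceMeta, MlxJacclInstance, and the isinstance-based derivation mirror the diff that follows.

from enum import Enum

class InstanceMeta(Enum):
    MlxRing = "MlxRing"
    MlxJaccl = "MlxJaccl"

class MlxRingInstance: ...
class MlxJacclInstance: ...

def place_instance(requested: InstanceMeta, num_nodes: int):
    # Single-node placements are always overridden to MlxRing,
    # regardless of which instance type was requested.
    if num_nodes == 1 or requested is InstanceMeta.MlxRing:
        return MlxRingInstance()
    return MlxJacclInstance()

instance = place_instance(InstanceMeta.MlxJaccl, num_nodes=1)

# Buggy: report the requested meta (the loop variable) -> MlxJaccl.
# Fixed: derive the meta from what place_instance() actually returned.
actual_instance_meta = (
    InstanceMeta.MlxJaccl
    if isinstance(instance, MlxJacclInstance)
    else InstanceMeta.MlxRing
)
assert actual_instance_meta is InstanceMeta.MlxRing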

File 1 of 5

@@ -94,7 +94,6 @@ class Node:
                 command_sender=router.sender(topics.COMMANDS),
                 download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
                 event_index_counter=event_index_counter,
-                no_downloads=args.no_downloads,
             )
         else:
             worker = None
@@ -226,7 +225,6 @@ class Node:
         )
         self._tg.start_soon(self.download_coordinator.run)
         if self.worker:
-            no_downloads = self.worker.no_downloads
             self.worker.shutdown()
             # TODO: add profiling etc to resource monitor
             self.worker = Worker(
@@ -241,7 +239,6 @@ class Node:
                     topics.DOWNLOAD_COMMANDS
                 ),
                 event_index_counter=self.event_index_counter,
-                no_downloads=no_downloads,
             )
             self._tg.start_soon(self.worker.run)
         if self.api:

File 2 of 5

@@ -143,7 +143,12 @@ from exo.shared.types.openai_responses import (
     ResponsesResponse,
 )
 from exo.shared.types.state import State
-from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
+from exo.shared.types.worker.instances import (
+    Instance,
+    InstanceId,
+    InstanceMeta,
+    MlxJacclInstance,
+)
 from exo.shared.types.worker.shards import Sharding
 from exo.utils.banner import print_startup_banner
 from exo.utils.channels import Receiver, Sender, channel
@@ -467,6 +472,14 @@ class API:
         shard_assignments = instance.shard_assignments
         placement_node_ids = list(shard_assignments.node_to_runner.keys())
+        # Derive instance_meta from the actual instance type, since
+        # place_instance() may override it (e.g., single-node → MlxRing)
+        actual_instance_meta = (
+            InstanceMeta.MlxJaccl
+            if isinstance(instance, MlxJacclInstance)
+            else InstanceMeta.MlxRing
+        )
         memory_delta_by_node: dict[str, int] = {}
         if placement_node_ids:
             total_bytes = model_card.storage_size.in_bytes
@@ -479,14 +492,14 @@ class API:
             if (
                 model_card.model_id,
                 sharding,
-                instance_meta,
+                actual_instance_meta,
                 len(placement_node_ids),
             ) not in seen:
                 previews.append(
                     PlacementPreview(
                         model_id=model_card.model_id,
                         sharding=sharding,
-                        instance_meta=instance_meta,
+                        instance_meta=actual_instance_meta,
                         instance=instance,
                         memory_delta_by_node=memory_delta_by_node or None,
                         error=None,
@@ -496,7 +509,7 @@ class API:
                 (
                     model_card.model_id,
                     sharding,
-                    instance_meta,
+                    actual_instance_meta,
                     len(placement_node_ids),
                 )
             )
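
A side effect of keying the seen set on the derived meta (presumably the point of the last two hunks): two requests that placement collapses to the same concrete type now also collapse to a single preview entry. A toy illustration with the same kind of stand-ins as above; everything here is hypothetical except the InstanceMeta values and the isinstance-based derivation.

from enum import Enum

class InstanceMeta(Enum):
    MlxRing = "MlxRing"
    MlxJaccl = "MlxJaccl"

class MlxRingInstance: ...
class MlxJacclInstance: ...

seen: set[tuple[str, InstanceMeta, int]] = set()
previews: list[tuple[str, InstanceMeta, int]] = []

for requested in (InstanceMeta.MlxJaccl, InstanceMeta.MlxRing):
    # Stand-in for place_instance(): a single-node placement comes
    # back as MlxRing no matter which meta was requested.
    instance = MlxRingInstance()
    actual_instance_meta = (
        InstanceMeta.MlxJaccl
        if isinstance(instance, MlxJacclInstance)
        else InstanceMeta.MlxRing
    )
    key = ("model-a", actual_instance_meta, 1)
    if key not in seen:
        previews.append(key)
        seen.add(key)

# Both single-node requests dedupe to one MlxRing preview.
assert len(previews) == 1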

File 3 of 5

@@ -65,7 +65,6 @@ class Worker:
         command_sender: Sender[ForwarderCommand],
         download_command_sender: Sender[ForwarderDownloadCommand],
         event_index_counter: Iterator[int],
-        no_downloads: bool = False,
     ):
         self.node_id: NodeId = node_id
         self.session_id: SessionId = session_id
@@ -75,7 +74,6 @@ class Worker:
         self.event_index_counter = event_index_counter
         self.command_sender = command_sender
         self.download_command_sender = download_command_sender
-        self.no_downloads = no_downloads
         self.event_buffer = OrderedBuffer[Event]()
         self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
@@ -184,7 +182,6 @@ class Worker:
                 self.state.tasks,
                 self.input_chunk_buffer,
                 self.input_chunk_counts,
-                no_downloads=self.no_downloads,
             )
             if task is None:
                 continue

File 4 of 5

@@ -51,20 +51,15 @@ def plan(
     tasks: Mapping[TaskId, Task],
     input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
     input_chunk_counts: Mapping[CommandId, int] | None = None,
-    no_downloads: bool = False,
 ) -> Task | None:
     # Python short circuiting OR logic should evaluate these sequentially.
     return (
         _cancel_tasks(runners, tasks)
         or _kill_runner(runners, all_runners, instances)
         or _create_runner(node_id, runners, instances)
-        or (
-            None
-            if no_downloads
-            else _model_needs_download(node_id, runners, global_download_status)
-        )
+        or _model_needs_download(node_id, runners, global_download_status)
         or _init_distributed_backend(runners, all_runners)
-        or _load_model(runners, all_runners, global_download_status, no_downloads)
+        or _load_model(runners, all_runners, global_download_status)
         or _ready_to_warmup(runners, all_runners)
         or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
     )
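
The return expression above is a priority chain: each helper returns a Task or None, and Python's short-circuiting or yields the first non-None result. A generic sketch of the pattern, with hypothetical step functions standing in for the _cancel_tasks/_kill_runner/... helpers:

from typing import Callable, Optional

def first_task(steps: list[Callable[[], Optional[str]]]) -> Optional[str]:
    # Equivalent to `steps[0]() or steps[1]() or ...`: evaluate in
    # priority order and stop at the first step that yields a task.
    for step in steps:
        task = step()
        if task is not None:
            return task
    return None

# Higher-priority steps run first; once one produces a task, the
# remaining steps are never evaluated at all.
result = first_task([lambda: None, lambda: "DownloadModel", lambda: "LoadModel"])
assert result == "DownloadModel"
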
@@ -197,25 +192,22 @@ def _load_model(
     runners: Mapping[RunnerId, RunnerSupervisor],
     all_runners: Mapping[RunnerId, RunnerStatus],
     global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
-    no_downloads: bool = False,
 ) -> LoadModel | None:
     for runner in runners.values():
         instance = runner.bound_instance.instance
         shard_assignments = instance.shard_assignments
-        if not no_downloads:
-            all_local_downloads_complete = all(
-                nid in global_download_status
-                and any(
-                    isinstance(dp, DownloadCompleted)
-                    and dp.shard_metadata.model_card.model_id
-                    == shard_assignments.model_id
-                    for dp in global_download_status[nid]
-                )
-                for nid in shard_assignments.node_to_runner
-            )
-            if not all_local_downloads_complete:
-                continue
+        all_local_downloads_complete = all(
+            nid in global_download_status
+            and any(
+                isinstance(dp, DownloadCompleted)
+                and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
+                for dp in global_download_status[nid]
+            )
+            for nid in shard_assignments.node_to_runner
+        )
+        if not all_local_downloads_complete:
+            continue
         is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
         if is_single_node_instance and isinstance(runner.status, RunnerIdle):
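
The reworked gate in _load_model is an all/any nesting: every node in the shard assignment must report a completed download for this model before LoadModel fires. A self-contained sketch with simplified progress types; DownloadCompleted mirrors the diff, the rest is hypothetical scaffolding.

from dataclasses import dataclass

@dataclass
class DownloadCompleted:
    model_id: str

@dataclass
class DownloadInProgress:
    model_id: str

def all_local_downloads_complete(
    node_ids: list[str],
    global_download_status: dict[str, list[object]],
    model_id: str,
) -> bool:
    # A node with no status entry at all, or with no DownloadCompleted
    # for this model, fails the check and defers LoadModel.
    return all(
        nid in global_download_status
        and any(
            isinstance(dp, DownloadCompleted) and dp.model_id == model_id
            for dp in global_download_status[nid]
        )
        for nid in node_ids
    )

status = {
    "node-a": [DownloadCompleted("m")],
    "node-b": [DownloadInProgress("m")],
}
assert all_local_downloads_complete(["node-a"], status, "m") is True
assert all_local_downloads_complete(["node-a", "node-b"], status, "m") is False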

File 5 of 5

@@ -157,41 +157,6 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
     assert not isinstance(result, plan_mod.DownloadModel)
-
-
-def test_plan_loads_model_with_no_downloads_flag():
-    """
-    When no_downloads=True, plan() should skip download checks and proceed
-    to load the model even with empty global_download_status.
-    """
-    shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
-    instance = get_mlx_ring_instance(
-        instance_id=INSTANCE_1_ID,
-        model_id=MODEL_A_ID,
-        node_to_runner={NODE_A: RUNNER_1_ID},
-        runner_to_shard={RUNNER_1_ID: shard},
-    )
-    bound_instance = BoundInstance(
-        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
-    )
-    runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
-    runners = {RUNNER_1_ID: runner}
-    instances = {INSTANCE_1_ID: instance}
-    all_runners = {RUNNER_1_ID: RunnerIdle()}
-    result = plan_mod.plan(
-        node_id=NODE_A,
-        runners=runners,  # type: ignore
-        global_download_status={},
-        instances=instances,
-        all_runners=all_runners,
-        tasks={},
-        no_downloads=True,
-    )
-    assert isinstance(result, LoadModel)
-    assert result.instance_id == INSTANCE_1_ID
-
-
 def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
     """
     LoadModel should not be emitted while some shards are still missing from