diff --git a/src/exo/master/placement.py b/src/exo/master/placement.py index 98742924..c0862c10 100644 --- a/src/exo/master/placement.py +++ b/src/exo/master/placement.py @@ -8,7 +8,7 @@ from loguru import logger from exo.master.placement_utils import ( filter_cycles_by_memory, get_hosts_from_subgraph, - get_mlx_ibv_coordinator, + get_mlx_ibv_coordinators, get_mlx_ibv_devices_matrix, get_shard_assignments, get_smallest_cycles, @@ -110,15 +110,16 @@ def get_instance_placements_after_create( selected_cycle, cycle_digraph, ) - mlx_ibv_coordinator = get_mlx_ibv_coordinator( + mlx_ibv_coordinators = get_mlx_ibv_coordinators( selected_cycle, coordinator_port=random_ephemeral_port(), + cycle_digraph=cycle_digraph, ) target_instances[instance_id] = MlxJacclInstance( instance_id=instance_id, shard_assignments=shard_assignments, ibv_devices=mlx_ibv_devices, - ibv_coordinator=mlx_ibv_coordinator, + ibv_coordinators=mlx_ibv_coordinators, ) case InstanceMeta.MlxRing: hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph) diff --git a/src/exo/master/placement_utils.py b/src/exo/master/placement_utils.py index 8cb81adb..24461b42 100644 --- a/src/exo/master/placement_utils.py +++ b/src/exo/master/placement_utils.py @@ -269,20 +269,31 @@ def _find_interface_name_for_ip( return None -def get_mlx_ibv_coordinator( +def get_mlx_ibv_coordinators( selected_cycle: list[NodeInfo], coordinator_port: int, -) -> str: - """Get the coordinator address for MLX IBV (rank 0 device). + cycle_digraph: Topology, +) -> dict[NodeId, str]: + """Get the coordinator addresses for MLX IBV (rank 0 device). - Selects a non-thunderbolt IP address from rank 0 node as a heuristic for - ethernet accessibility. Returns address in format "X.X.X.X:PORT". + Select an IP address that each node can reach for the rank 0 node. Returns + address in format "X.X.X.X:PORT" per node. """ rank_0_node = selected_cycle[0] logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}") - assert rank_0_node.node_profile is not None - for iface in rank_0_node.node_profile.network_interfaces: - if iface.name == "en0" and "." in iface.ip_address: - return f"{iface.ip_address}:{coordinator_port}" - raise ValueError("No en0 iface found for device") + def get_ip_for_node(n: NodeInfo) -> str: + if n.node_id == rank_0_node.node_id: + return "0.0.0.0" + + for ip in _find_connection_ip(n, rank_0_node, cycle_digraph): + return ip + + logger.warning( + f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}" + ) + raise ValueError("Current ibv backend requires all-to-all rdma connections") + + return { + n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle + } diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py index 90c55c5b..a87abc34 100644 --- a/src/exo/master/tests/test_master.py +++ b/src/exo/master/tests/test_master.py @@ -166,28 +166,28 @@ async def test_master(): events[1].event.instance.shard_assignments.runner_to_shard.keys() )[0] assert events[1].event.instance == MlxRingInstance( - instance_id=events[1].event.instance.instance_id, - shard_assignments=ShardAssignments( - model_id=ModelId("llama-3.2-1b"), - runner_to_shard={ - (runner_id): PipelineShardMetadata( - start_layer=0, - end_layer=16, + instance_id=events[1].event.instance.instance_id, + shard_assignments=ShardAssignments( + model_id=ModelId("llama-3.2-1b"), + runner_to_shard={ + (runner_id): PipelineShardMetadata( + start_layer=0, + end_layer=16, + n_layers=16, + model_meta=ModelMetadata( + model_id=ModelId("llama-3.2-1b"), + pretty_name="Llama 3.2 1B", n_layers=16, - model_meta=ModelMetadata( - model_id=ModelId("llama-3.2-1b"), - pretty_name="Llama 3.2 1B", - n_layers=16, - storage_size=Memory.from_bytes(678948), - ), - device_rank=0, - world_size=1, - ) - }, - node_to_runner={node_id: runner_id}, - ), - hosts=[], - ) + storage_size=Memory.from_bytes(678948), + ), + device_rank=0, + world_size=1, + ) + }, + node_to_runner={node_id: runner_id}, + ), + hosts=[], + ) assert isinstance(events[2].event, TaskCreated) assert events[2].event.task.task_status == TaskStatus.Pending assert isinstance(events[2].event.task, ChatCompletionTask) diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 95cb33bc..1bfdf4e2 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -437,7 +437,7 @@ def test_tensor_rdma_backend_connectivity_matrix( assert isinstance(instance, MlxJacclInstance) assert instance.ibv_devices is not None - assert instance.ibv_coordinator is not None + assert instance.ibv_coordinators is not None matrix = instance.ibv_devices assert len(matrix) == 3 @@ -458,5 +458,17 @@ def test_tensor_rdma_backend_connectivity_matrix( assert matrix[idx_b][idx_c] == "rdma_en3" assert matrix[idx_c][idx_a] == "rdma_en3" - assert ":" in instance.ibv_coordinator - assert not instance.ibv_coordinator.startswith("169.254") + # Verify coordinators are set for all nodes + assert len(instance.ibv_coordinators) == 3 + for node_id in assigned_nodes: + assert node_id in instance.ibv_coordinators + coordinator = instance.ibv_coordinators[node_id] + assert ":" in coordinator + # Rank 0 node should use 0.0.0.0, others should use connection-specific IPs + if node_id == assigned_nodes[0]: + assert coordinator.startswith("0.0.0.0:") + else: + # Non-rank-0 nodes should have valid IP addresses (can be link-local) + ip_part = coordinator.split(":")[0] + # Just verify it's a valid IP format + assert len(ip_part.split(".")) == 4 diff --git a/src/exo/master/tests/test_placement_utils.py b/src/exo/master/tests/test_placement_utils.py index eb1d4e10..ff6de72c 100644 --- a/src/exo/master/tests/test_placement_utils.py +++ b/src/exo/master/tests/test_placement_utils.py @@ -5,6 +5,7 @@ import pytest from exo.master.placement_utils import ( filter_cycles_by_memory, get_hosts_from_subgraph, + get_mlx_ibv_coordinators, get_shard_assignments, get_smallest_cycles, ) @@ -12,6 +13,7 @@ from exo.shared.topology import Topology from exo.shared.types.common import Host, NodeId from exo.shared.types.memory import Memory from exo.shared.types.models import ModelId, ModelMetadata +from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile from exo.shared.types.topology import Connection, NodeInfo from exo.shared.types.worker.shards import Sharding @@ -261,3 +263,135 @@ def test_get_hosts_from_subgraph( ] for expected_host in expected_hosts: assert expected_host in hosts + + +def test_get_mlx_ibv_coordinators( + topology: Topology, + create_node: Callable[[int, NodeId | None], NodeInfo], + create_connection: Callable[[NodeId, NodeId, int | None], Connection], +): + # arrange + node_a_id = NodeId() + node_b_id = NodeId() + node_c_id = NodeId() + + node_a = create_node(500 * 1024, node_a_id) + node_b = create_node(500 * 1024, node_b_id) + node_c = create_node(1000 * 1024, node_c_id) + + conn_a_b = create_connection(node_a_id, node_b_id, 5001) + conn_b_a = create_connection(node_b_id, node_a_id, 5002) + conn_b_c = create_connection(node_b_id, node_c_id, 5003) + conn_c_b = create_connection(node_c_id, node_b_id, 5004) + conn_c_a = create_connection(node_c_id, node_a_id, 5005) + conn_a_c = create_connection(node_a_id, node_c_id, 5006) + + # Update node profiles with network interfaces before adding to topology + assert node_a.node_profile is not None + assert node_b.node_profile is not None + assert node_c.node_profile is not None + + node_a.node_profile = NodePerformanceProfile( + model_id="test", + chip_id="test", + friendly_name="test", + memory=node_a.node_profile.memory, + network_interfaces=[ + NetworkInterfaceInfo( + name="en3", + ip_address=conn_a_b.send_back_multiaddr.ip_address, + ), + NetworkInterfaceInfo( + name="en4", + ip_address=conn_a_c.send_back_multiaddr.ip_address, + ), + ], + system=node_a.node_profile.system, + ) + node_b.node_profile = NodePerformanceProfile( + model_id="test", + chip_id="test", + friendly_name="test", + memory=node_b.node_profile.memory, + network_interfaces=[ + NetworkInterfaceInfo( + name="en3", + ip_address=conn_b_a.send_back_multiaddr.ip_address, + ), + NetworkInterfaceInfo( + name="en4", + ip_address=conn_b_c.send_back_multiaddr.ip_address, + ), + ], + system=node_b.node_profile.system, + ) + node_c.node_profile = NodePerformanceProfile( + model_id="test", + chip_id="test", + friendly_name="test", + memory=node_c.node_profile.memory, + network_interfaces=[ + NetworkInterfaceInfo( + name="en3", + ip_address=conn_c_b.send_back_multiaddr.ip_address, + ), + NetworkInterfaceInfo( + name="en4", + ip_address=conn_c_a.send_back_multiaddr.ip_address, + ), + ], + system=node_c.node_profile.system, + ) + + topology.add_node(node_a) + topology.add_node(node_b) + topology.add_node(node_c) + + topology.add_connection(conn_a_b) + topology.add_connection(conn_b_a) + topology.add_connection(conn_b_c) + topology.add_connection(conn_c_b) + topology.add_connection(conn_c_a) + topology.add_connection(conn_a_c) + + cycle = [node_a, node_b, node_c] + + # act + coordinators = get_mlx_ibv_coordinators( + cycle, coordinator_port=5000, cycle_digraph=topology + ) + + # assert + assert len(coordinators) == 3 + assert node_a_id in coordinators + assert node_b_id in coordinators + assert node_c_id in coordinators + + # All coordinators should have IP:PORT format + for node_id, coordinator in coordinators.items(): + assert ":" in coordinator, ( + f"Coordinator for {node_id} should have ':' separator" + ) + + # Verify port is correct + for node_id, coordinator in coordinators.items(): + assert coordinator.endswith(":5000"), ( + f"Coordinator for {node_id} should use port 5000" + ) + + # Rank 0 (node_a) treats this as the listen socket so should listen on all + # IPs + assert coordinators[node_a_id].startswith("0.0.0.0:"), ( + "Rank 0 node should use localhost as coordinator" + ) + + # Non-rank-0 nodes should use the specific IP from their connection to rank 0 + # node_b uses the IP from conn_b_a (node_b -> node_a) + assert coordinators[node_b_id] == ( + f"{conn_b_a.send_back_multiaddr.ip_address}:5000" + ), "node_b should use the IP from conn_b_a" + + # node_c uses the IP from conn_c_a (node_c -> node_a) + assert coordinators[node_c_id] == ( + f"{conn_c_a.send_back_multiaddr.ip_address}:5000" + ), "node_c should use the IP from conn_c_a" diff --git a/src/exo/routing/tests/test_event_buffer.py b/src/exo/routing/tests/test_event_buffer.py index 0e3e458c..215f53e2 100644 --- a/src/exo/routing/tests/test_event_buffer.py +++ b/src/exo/routing/tests/test_event_buffer.py @@ -95,7 +95,7 @@ async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]): buffer.ingest(*make_indexed_event(0)) buffer.ingest(*event2_first) - + with pytest.raises(AssertionError): buffer.ingest(*event2_second) # This duplicate should be ignored diff --git a/src/exo/shared/election.py b/src/exo/shared/election.py index ccbbee52..b4dc36b6 100644 --- a/src/exo/shared/election.py +++ b/src/exo/shared/election.py @@ -18,6 +18,7 @@ from exo.utils.pydantic_ext import CamelCaseModel DEFAULT_ELECTION_TIMEOUT = 3.0 + class ElectionMessage(CamelCaseModel): clock: int seniority: int @@ -152,7 +153,9 @@ class Election: self._candidates = candidates logger.debug(f"New candidates: {self._candidates}") logger.debug("Starting new campaign") - self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT) + self._tg.start_soon( + self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT + ) logger.debug("Campaign started") continue # Dismiss old messages @@ -181,7 +184,9 @@ class Election: candidates: list[ElectionMessage] = [] self._candidates = candidates logger.debug("Starting new campaign") - self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT) + self._tg.start_soon( + self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT + ) logger.debug("Campaign started") self._connection_messages.append(first) self._connection_messages.extend(rest) diff --git a/src/exo/shared/tests/test_election.py b/src/exo/shared/tests/test_election.py index 525b35a2..77686a0c 100644 --- a/src/exo/shared/tests/test_election.py +++ b/src/exo/shared/tests/test_election.py @@ -40,6 +40,7 @@ def em( # TESTS # # ======================================= # + @pytest.fixture(autouse=True) def fast_election_timeout(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1) diff --git a/src/exo/shared/types/worker/instances.py b/src/exo/shared/types/worker/instances.py index e36c4fb0..ea8e7887 100644 --- a/src/exo/shared/types/worker/instances.py +++ b/src/exo/shared/types/worker/instances.py @@ -2,7 +2,7 @@ from enum import Enum from pydantic import model_validator -from exo.shared.types.common import Host, Id +from exo.shared.types.common import Host, Id, NodeId from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel @@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance): class MlxJacclInstance(BaseInstance): ibv_devices: list[list[str | None]] - ibv_coordinator: str + ibv_coordinators: dict[NodeId, str] # TODO: Single node instance @@ -40,6 +40,7 @@ Instance = MlxRingInstance | MlxJacclInstance class BoundInstance(CamelCaseModel): instance: Instance bound_runner_id: RunnerId + bound_node_id: NodeId @property def bound_shard(self) -> ShardMetadata: diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index dc6d1e45..3606b90b 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -128,7 +128,9 @@ def mlx_distributed_init( os.environ["MLX_RING_VERBOSE"] = "1" group = mx.distributed.init(backend="ring", strict=True) - case MlxJacclInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator): + case MlxJacclInstance( + ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators + ): # Use RDMA connectivity matrix devices_file = f"./hosts_{rank}.json" ibv_devices_json = json.dumps(ibv_devices) @@ -136,6 +138,8 @@ def mlx_distributed_init( with open(devices_file, "w") as f: _ = f.write(ibv_devices_json) + ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id] + logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}") logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}") os.environ["MLX_IBV_DEVICES"] = devices_file diff --git a/src/exo/worker/plan.py b/src/exo/worker/plan.py index 9d1806ad..01106d24 100644 --- a/src/exo/worker/plan.py +++ b/src/exo/worker/plan.py @@ -95,7 +95,9 @@ def _create_runner( return CreateRunner( instance_id=instance.instance_id, - bound_instance=BoundInstance(instance=instance, bound_runner_id=runner_id), + bound_instance=BoundInstance( + instance=instance, bound_runner_id=runner_id, bound_node_id=node_id + ), ) diff --git a/src/exo/worker/tests/unittests/test_plan/test_download_and_loading.py b/src/exo/worker/tests/unittests/test_plan/test_download_and_loading.py index d64df456..5d6e4e2c 100644 --- a/src/exo/worker/tests/unittests/test_plan/test_download_and_loading.py +++ b/src/exo/worker/tests/unittests/test_plan/test_download_and_loading.py @@ -35,7 +35,9 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded(): node_to_runner={NODE_A: RUNNER_1_ID}, runner_to_shard={RUNNER_1_ID: shard}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerWaitingForModel() ) @@ -76,7 +78,9 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting(): node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerWaitingForModel() ) @@ -126,7 +130,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded(): node_to_runner={NODE_A: RUNNER_1_ID}, runner_to_shard={RUNNER_1_ID: shard}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerWaitingForModel() ) @@ -173,7 +179,9 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally(): runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerWaitingForModel() ) diff --git a/src/exo/worker/tests/unittests/test_plan/test_runner_lifecycle.py b/src/exo/worker/tests/unittests/test_plan/test_runner_lifecycle.py index 056de505..944cb6db 100644 --- a/src/exo/worker/tests/unittests/test_plan/test_runner_lifecycle.py +++ b/src/exo/worker/tests/unittests/test_plan/test_runner_lifecycle.py @@ -36,7 +36,9 @@ def test_plan_kills_runner_when_instance_missing(): node_to_runner={NODE_A: RUNNER_1_ID}, runner_to_shard={RUNNER_1_ID: shard}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady()) runners = {RUNNER_1_ID: runner} @@ -71,7 +73,9 @@ def test_plan_kills_runner_when_sibling_failed(): node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady()) runners = {RUNNER_1_ID: runner} @@ -143,7 +147,9 @@ def test_plan_does_not_create_runner_when_supervisor_already_present(): node_to_runner={NODE_A: RUNNER_1_ID}, runner_to_shard={RUNNER_1_ID: shard}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady()) runners = {RUNNER_1_ID: runner} diff --git a/src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py b/src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py index b1500e74..1bf985ac 100644 --- a/src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py +++ b/src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py @@ -40,7 +40,9 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready(): node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerReady() ) @@ -86,7 +88,9 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready(): node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerReady() ) @@ -131,7 +135,9 @@ def test_plan_does_not_forward_tasks_for_other_instances(): node_to_runner={NODE_A: RUNNER_1_ID}, runner_to_shard={RUNNER_1_ID: shard}, ) - bound_instance = BoundInstance(instance=local_instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=local_instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerReady() ) @@ -175,7 +181,9 @@ def test_plan_ignores_non_pending_or_non_chat_tasks(): node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerReady() @@ -236,7 +244,9 @@ def test_plan_returns_none_when_nothing_to_do(): node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID}, runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerRunning() ) diff --git a/src/exo/worker/tests/unittests/test_plan/test_warmup.py b/src/exo/worker/tests/unittests/test_plan/test_warmup.py index ed0f0d2b..f47d24c9 100644 --- a/src/exo/worker/tests/unittests/test_plan/test_warmup.py +++ b/src/exo/worker/tests/unittests/test_plan/test_warmup.py @@ -35,7 +35,9 @@ def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming(): runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerLoaded() ) @@ -75,7 +77,9 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming(): runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerLoaded() ) @@ -114,7 +118,9 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerLoaded() ) @@ -153,7 +159,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming(): runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1}, ) - bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID) + bound_instance = BoundInstance( + instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A + ) local_runner = FakeRunnerSupervisor( bound_instance=bound_instance, status=RunnerLoaded() )