placement: pass different ibv_coordinator per node
@@ -8,7 +8,7 @@ from loguru import logger

 from exo.master.placement_utils import (
     filter_cycles_by_memory,
     get_hosts_from_subgraph,
-    get_mlx_ibv_coordinator,
+    get_mlx_ibv_coordinators,
     get_mlx_ibv_devices_matrix,
     get_shard_assignments,
     get_smallest_cycles,
@@ -110,15 +110,16 @@ def get_instance_placements_after_create(
                 selected_cycle,
                 cycle_digraph,
             )
-            mlx_ibv_coordinator = get_mlx_ibv_coordinator(
+            mlx_ibv_coordinators = get_mlx_ibv_coordinators(
                 selected_cycle,
                 coordinator_port=random_ephemeral_port(),
+                cycle_digraph=cycle_digraph,
             )
             target_instances[instance_id] = MlxJacclInstance(
                 instance_id=instance_id,
                 shard_assignments=shard_assignments,
                 ibv_devices=mlx_ibv_devices,
-                ibv_coordinator=mlx_ibv_coordinator,
+                ibv_coordinators=mlx_ibv_coordinators,
             )
         case InstanceMeta.MlxRing:
             hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
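The shape change at this call site, in isolation: placement previously computed one coordinator string for the whole instance, and now computes one address per node. A minimal sketch with simplified stand-in values (plain strings instead of NodeId; the addresses are made up):

# Old: a single coordinator string shared by every rank in the instance.
ibv_coordinator = "192.168.1.10:54321"

# New: a per-node map; each node gets the address it should use to reach
# rank 0, while rank 0 itself binds on all interfaces.
ibv_coordinators = {
    "node-a": "0.0.0.0:54321",       # rank 0 listens here
    "node-b": "192.168.1.10:54321",  # IP of node-b's direct link to node-a
    "node-c": "10.0.5.1:54321",      # IP of node-c's direct link to node-a
}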
@@ -269,20 +269,31 @@ def _find_interface_name_for_ip(
     return None


-def get_mlx_ibv_coordinator(
+def get_mlx_ibv_coordinators(
     selected_cycle: list[NodeInfo],
     coordinator_port: int,
-) -> str:
-    """Get the coordinator address for MLX IBV (rank 0 device).
+    cycle_digraph: Topology,
+) -> dict[NodeId, str]:
+    """Get the coordinator addresses for MLX IBV (rank 0 device).

-    Selects a non-thunderbolt IP address from rank 0 node as a heuristic for
-    ethernet accessibility. Returns address in format "X.X.X.X:PORT".
+    Select an IP address that each node can reach for the rank 0 node. Returns
+    address in format "X.X.X.X:PORT" per node.
     """
     rank_0_node = selected_cycle[0]
     logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
     assert rank_0_node.node_profile is not None
-    for iface in rank_0_node.node_profile.network_interfaces:
-        if iface.name == "en0" and "." in iface.ip_address:
-            return f"{iface.ip_address}:{coordinator_port}"
-
-    raise ValueError("No en0 iface found for device")
+    def get_ip_for_node(n: NodeInfo) -> str:
+        if n.node_id == rank_0_node.node_id:
+            return "0.0.0.0"
+
+        for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
+            return ip
+
+        logger.warning(
+            f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
+        )
+        raise ValueError("Current ibv backend requires all-to-all rdma connections")
+
+    return {
+        n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
+    }
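The rule implemented above reduces to: rank 0 binds on 0.0.0.0, and every other node uses the IP of a direct connection it has to the rank 0 node, failing when no such link exists. A standalone restatement over a plain adjacency map (hypothetical helper, not the exo API):

def pick_coordinators(
    cycle: list[str],
    links: dict[tuple[str, str], str],  # (src, dst) -> IP that src uses to reach dst
    port: int,
) -> dict[str, str]:
    rank_0 = cycle[0]

    def ip_for(node: str) -> str:
        if node == rank_0:
            return "0.0.0.0"  # rank 0 is the listener; bind on all interfaces
        ip = links.get((node, rank_0))
        if ip is None:
            # Mirrors the ValueError above: the ibv backend needs every
            # node directly connected to rank 0.
            raise ValueError("Current ibv backend requires all-to-all rdma connections")
        return ip

    return {node: f"{ip_for(node)}:{port}" for node in cycle}


# Three fully connected nodes; "a" is rank 0.
links = {("b", "a"): "192.168.1.10", ("c", "a"): "10.0.5.1"}
print(pick_coordinators(["a", "b", "c"], links, 5000))
# {'a': '0.0.0.0:5000', 'b': '192.168.1.10:5000', 'c': '10.0.5.1:5000'}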
@@ -166,28 +166,28 @@ async def test_master():
         events[1].event.instance.shard_assignments.runner_to_shard.keys()
     )[0]
     assert events[1].event.instance == MlxRingInstance(
-        instance_id=events[1].event.instance.instance_id,
-        shard_assignments=ShardAssignments(
-            model_id=ModelId("llama-3.2-1b"),
-            runner_to_shard={
-                (runner_id): PipelineShardMetadata(
-                    start_layer=0,
-                    end_layer=16,
-                    n_layers=16,
-                    model_meta=ModelMetadata(
-                        model_id=ModelId("llama-3.2-1b"),
-                        pretty_name="Llama 3.2 1B",
-                        n_layers=16,
-                        storage_size=Memory.from_bytes(678948),
-                    ),
-                    device_rank=0,
-                    world_size=1,
-                )
-            },
-            node_to_runner={node_id: runner_id},
-        ),
-        hosts=[],
-    )
+        instance_id=events[1].event.instance.instance_id,
+        shard_assignments=ShardAssignments(
+            model_id=ModelId("llama-3.2-1b"),
+            runner_to_shard={
+                (runner_id): PipelineShardMetadata(
+                    start_layer=0,
+                    end_layer=16,
+                    n_layers=16,
+                    model_meta=ModelMetadata(
+                        model_id=ModelId("llama-3.2-1b"),
+                        pretty_name="Llama 3.2 1B",
+                        n_layers=16,
+                        storage_size=Memory.from_bytes(678948),
+                    ),
+                    device_rank=0,
+                    world_size=1,
+                )
+            },
+            node_to_runner={node_id: runner_id},
+        ),
+        hosts=[],
+    )
     assert isinstance(events[2].event, TaskCreated)
     assert events[2].event.task.task_status == TaskStatus.Pending
     assert isinstance(events[2].event.task, ChatCompletionTask)
@@ -437,7 +437,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
     assert isinstance(instance, MlxJacclInstance)

     assert instance.ibv_devices is not None
-    assert instance.ibv_coordinator is not None
+    assert instance.ibv_coordinators is not None

     matrix = instance.ibv_devices
     assert len(matrix) == 3
@@ -458,5 +458,17 @@ def test_tensor_rdma_backend_connectivity_matrix(
     assert matrix[idx_b][idx_c] == "rdma_en3"
     assert matrix[idx_c][idx_a] == "rdma_en3"

-    assert ":" in instance.ibv_coordinator
-    assert not instance.ibv_coordinator.startswith("169.254")
+    # Verify coordinators are set for all nodes
+    assert len(instance.ibv_coordinators) == 3
+    for node_id in assigned_nodes:
+        assert node_id in instance.ibv_coordinators
+        coordinator = instance.ibv_coordinators[node_id]
+        assert ":" in coordinator
+        # Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
+        if node_id == assigned_nodes[0]:
+            assert coordinator.startswith("0.0.0.0:")
+        else:
+            # Non-rank-0 nodes should have valid IP addresses (can be link-local)
+            ip_part = coordinator.split(":")[0]
+            # Just verify it's a valid IP format
+            assert len(ip_part.split(".")) == 4
@@ -5,6 +5,7 @@ import pytest
 from exo.master.placement_utils import (
     filter_cycles_by_memory,
     get_hosts_from_subgraph,
+    get_mlx_ibv_coordinators,
     get_shard_assignments,
     get_smallest_cycles,
 )
@@ -12,6 +13,7 @@ from exo.shared.topology import Topology
 from exo.shared.types.common import Host, NodeId
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
+from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
 from exo.shared.types.topology import Connection, NodeInfo
 from exo.shared.types.worker.shards import Sharding

@@ -261,3 +263,135 @@ def test_get_hosts_from_subgraph(
     ]
     for expected_host in expected_hosts:
         assert expected_host in hosts
+
+
+def test_get_mlx_ibv_coordinators(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
+):
+    # arrange
+    node_a_id = NodeId()
+    node_b_id = NodeId()
+    node_c_id = NodeId()
+
+    node_a = create_node(500 * 1024, node_a_id)
+    node_b = create_node(500 * 1024, node_b_id)
+    node_c = create_node(1000 * 1024, node_c_id)
+
+    conn_a_b = create_connection(node_a_id, node_b_id, 5001)
+    conn_b_a = create_connection(node_b_id, node_a_id, 5002)
+    conn_b_c = create_connection(node_b_id, node_c_id, 5003)
+    conn_c_b = create_connection(node_c_id, node_b_id, 5004)
+    conn_c_a = create_connection(node_c_id, node_a_id, 5005)
+    conn_a_c = create_connection(node_a_id, node_c_id, 5006)
+
+    # Update node profiles with network interfaces before adding to topology
+    assert node_a.node_profile is not None
+    assert node_b.node_profile is not None
+    assert node_c.node_profile is not None
+
+    node_a.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_a.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_a_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_a_c.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_a.node_profile.system,
+    )
+    node_b.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_b.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_b_a.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_b_c.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_b.node_profile.system,
+    )
+    node_c.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_c.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_c_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_c_a.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_c.node_profile.system,
+    )
+
+    topology.add_node(node_a)
+    topology.add_node(node_b)
+    topology.add_node(node_c)
+
+    topology.add_connection(conn_a_b)
+    topology.add_connection(conn_b_a)
+    topology.add_connection(conn_b_c)
+    topology.add_connection(conn_c_b)
+    topology.add_connection(conn_c_a)
+    topology.add_connection(conn_a_c)
+
+    cycle = [node_a, node_b, node_c]
+
+    # act
+    coordinators = get_mlx_ibv_coordinators(
+        cycle, coordinator_port=5000, cycle_digraph=topology
+    )
+
+    # assert
+    assert len(coordinators) == 3
+    assert node_a_id in coordinators
+    assert node_b_id in coordinators
+    assert node_c_id in coordinators
+
+    # All coordinators should have IP:PORT format
+    for node_id, coordinator in coordinators.items():
+        assert ":" in coordinator, (
+            f"Coordinator for {node_id} should have ':' separator"
+        )
+
+    # Verify port is correct
+    for node_id, coordinator in coordinators.items():
+        assert coordinator.endswith(":5000"), (
+            f"Coordinator for {node_id} should use port 5000"
+        )
+
+    # Rank 0 (node_a) treats this address as its listen socket, so it should
+    # bind on all interfaces
+    assert coordinators[node_a_id].startswith("0.0.0.0:"), (
+        "Rank 0 node should use 0.0.0.0 (all interfaces) as coordinator"
+    )
+
+    # Non-rank-0 nodes should use the specific IP from their connection to rank 0
+    # node_b uses the IP from conn_b_a (node_b -> node_a)
+    assert coordinators[node_b_id] == (
+        f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
+    ), "node_b should use the IP from conn_b_a"
+
+    # node_c uses the IP from conn_c_a (node_c -> node_a)
+    assert coordinators[node_c_id] == (
+        f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
+    ), "node_c should use the IP from conn_c_a"
@@ -95,7 +95,9 @@ async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):

     buffer.ingest(*make_indexed_event(0))
     buffer.ingest(*event2_first)


-    buffer.ingest(*event2_second)  # This duplicate should be ignored
+    with pytest.raises(AssertionError):
+        buffer.ingest(*event2_second)  # Ingesting the duplicate should raise

@@ -18,6 +18,7 @@ from exo.utils.pydantic_ext import CamelCaseModel

 DEFAULT_ELECTION_TIMEOUT = 3.0

+
 class ElectionMessage(CamelCaseModel):
     clock: int
     seniority: int
@@ -152,7 +153,9 @@ class Election:
                 self._candidates = candidates
                 logger.debug(f"New candidates: {self._candidates}")
                 logger.debug("Starting new campaign")
-                self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
+                self._tg.start_soon(
+                    self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
+                )
                 logger.debug("Campaign started")
                 continue
             # Dismiss old messages
@@ -181,7 +184,9 @@ class Election:
             candidates: list[ElectionMessage] = []
             self._candidates = candidates
             logger.debug("Starting new campaign")
-            self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
+            self._tg.start_soon(
+                self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
+            )
             logger.debug("Campaign started")
             self._connection_messages.append(first)
             self._connection_messages.extend(rest)
@@ -40,6 +40,7 @@ def em(
 # TESTS #
 # ======================================= #

+
 @pytest.fixture(autouse=True)
 def fast_election_timeout(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1)
@@ -2,7 +2,7 @@ from enum import Enum

 from pydantic import model_validator

-from exo.shared.types.common import Host, Id
+from exo.shared.types.common import Host, Id, NodeId
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
 from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):

 class MlxJacclInstance(BaseInstance):
     ibv_devices: list[list[str | None]]
-    ibv_coordinator: str
+    ibv_coordinators: dict[NodeId, str]


 # TODO: Single node instance
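The model change in isolation, as a minimal pydantic sketch (NodeId simplified to str here; the real NodeId lives in exo.shared.types.common, and the real base class is the project's CamelCaseModel rather than BaseModel):

from pydantic import BaseModel

NodeId = str  # simplified stand-in


class MlxJacclInstanceSketch(BaseModel):
    ibv_devices: list[list[str | None]]
    # Before this commit: ibv_coordinator: str (one address for all ranks).
    # After: one coordinator address per node in the instance.
    ibv_coordinators: dict[NodeId, str]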
@@ -40,6 +40,7 @@ Instance = MlxRingInstance | MlxJacclInstance
 class BoundInstance(CamelCaseModel):
     instance: Instance
     bound_runner_id: RunnerId
+    bound_node_id: NodeId

     @property
     def bound_shard(self) -> ShardMetadata:
@@ -128,7 +128,9 @@ def mlx_distributed_init(
             os.environ["MLX_RING_VERBOSE"] = "1"
             group = mx.distributed.init(backend="ring", strict=True)

-        case MlxJacclInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator):
+        case MlxJacclInstance(
+            ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
+        ):
             # Use RDMA connectivity matrix
             devices_file = f"./hosts_{rank}.json"
             ibv_devices_json = json.dumps(ibv_devices)
@@ -136,6 +138,8 @@ def mlx_distributed_init(
             with open(devices_file, "w") as f:
                 _ = f.write(ibv_devices_json)

+            ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
+
             logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
             logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
             os.environ["MLX_IBV_DEVICES"] = devices_file
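With the per-node map, each rank resolves its own entry, keyed by the runner's bound_node_id, before exporting the MLX environment. A sketch of that wiring under the same assumptions; the diff shows the lookup, the MLX_IBV_DEVICES export, and the MLX_IBV_COORDINATOR log line, and the coordinator export is presumed to follow the same pattern:

import os

def export_ibv_env(
    ibv_coordinators: dict[str, str],  # per-node map from the instance
    bound_node_id: str,                # this runner's own node id
    devices_file: str,                 # JSON connectivity matrix on disk
) -> None:
    # Each rank picks the coordinator address computed for *its* node,
    # instead of all ranks sharing one string.
    ibv_coordinator = ibv_coordinators[bound_node_id]
    os.environ["MLX_IBV_DEVICES"] = devices_file
    os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator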
@@ -95,7 +95,9 @@ def _create_runner(

     return CreateRunner(
         instance_id=instance.instance_id,
-        bound_instance=BoundInstance(instance=instance, bound_runner_id=runner_id),
+        bound_instance=BoundInstance(
+            instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
+        ),
     )


@@ -35,7 +35,9 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -76,7 +78,9 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -126,7 +130,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -173,7 +179,9 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -36,7 +36,9 @@ def test_plan_kills_runner_when_instance_missing():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())

     runners = {RUNNER_1_ID: runner}
@@ -71,7 +73,9 @@ def test_plan_kills_runner_when_sibling_failed():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())

     runners = {RUNNER_1_ID: runner}
@@ -143,7 +147,9 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())

     runners = {RUNNER_1_ID: runner}
@@ -40,7 +40,9 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -86,7 +88,9 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
    )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -131,7 +135,9 @@ def test_plan_does_not_forward_tasks_for_other_instances():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=local_instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=local_instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -175,7 +181,9 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )

     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
@@ -236,7 +244,9 @@ def test_plan_returns_none_when_nothing_to_do():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerRunning()
     )
@@ -35,7 +35,9 @@ def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )
@@ -75,7 +77,9 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )
@@ -114,7 +118,9 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )
@@ -153,7 +159,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )