Mirror of https://github.com/exo-explore/exo.git
placement: pass different ibv_coordinator per node
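In short: instead of computing one coordinator address and handing the same string to every rank, placement now builds a map keyed by NodeId and each runner reads its own entry. A minimal sketch of the shape change (types taken from the hunks below; the ids, port, and addresses are made-up examples):

    # before: every rank received the same string
    ibv_coordinator: str = "192.168.1.10:52415"  # example only

    # after: one entry per node in the selected cycle
    ibv_coordinators: dict[NodeId, str] = {
        rank_0_node_id: "0.0.0.0:52415",     # rank 0 listens on all interfaces
        other_node_id: "192.168.1.10:52415",  # others dial an IP they share with rank 0
    }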
@@ -8,7 +8,7 @@ from loguru import logger
 from exo.master.placement_utils import (
     filter_cycles_by_memory,
     get_hosts_from_subgraph,
-    get_mlx_ibv_coordinator,
+    get_mlx_ibv_coordinators,
     get_mlx_ibv_devices_matrix,
     get_shard_assignments,
     get_smallest_cycles,
@@ -110,15 +110,16 @@ def get_instance_placements_after_create(
                 selected_cycle,
                 cycle_digraph,
             )
-            mlx_ibv_coordinator = get_mlx_ibv_coordinator(
+            mlx_ibv_coordinators = get_mlx_ibv_coordinators(
                 selected_cycle,
                 coordinator_port=random_ephemeral_port(),
+                cycle_digraph=cycle_digraph,
             )
             target_instances[instance_id] = MlxJacclInstance(
                 instance_id=instance_id,
                 shard_assignments=shard_assignments,
                 ibv_devices=mlx_ibv_devices,
-                ibv_coordinator=mlx_ibv_coordinator,
+                ibv_coordinators=mlx_ibv_coordinators,
             )
         case InstanceMeta.MlxRing:
             hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
@@ -269,20 +269,31 @@ def _find_interface_name_for_ip(
     return None


-def get_mlx_ibv_coordinator(
+def get_mlx_ibv_coordinators(
     selected_cycle: list[NodeInfo],
     coordinator_port: int,
-) -> str:
-    """Get the coordinator address for MLX IBV (rank 0 device).
+    cycle_digraph: Topology,
+) -> dict[NodeId, str]:
+    """Get the coordinator addresses for MLX IBV (rank 0 device).

-    Selects a non-thunderbolt IP address from rank 0 node as a heuristic for
-    ethernet accessibility. Returns address in format "X.X.X.X:PORT".
+    Select an IP address that each node can reach for the rank 0 node. Returns
+    address in format "X.X.X.X:PORT" per node.
     """
     rank_0_node = selected_cycle[0]
     logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
-    assert rank_0_node.node_profile is not None
-    for iface in rank_0_node.node_profile.network_interfaces:
-        if iface.name == "en0" and "." in iface.ip_address:
-            return f"{iface.ip_address}:{coordinator_port}"

-    raise ValueError("No en0 iface found for device")
+    def get_ip_for_node(n: NodeInfo) -> str:
+        if n.node_id == rank_0_node.node_id:
+            return "0.0.0.0"
+
+        for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
+            return ip
+
+        logger.warning(
+            f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
+        )
+        raise ValueError("Current ibv backend requires all-to-all rdma connections")
+
+    return {
+        n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
+    }
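To make the new flow concrete, here is an illustrative sketch (not code from the commit) of how the map produced above is meant to be consumed, using only names that appear in these hunks; the first part runs on the master during placement, the lookup runs later on each worker (see the mlx_distributed_init hunk further down):

    # master: one map for the whole instance
    coordinators = get_mlx_ibv_coordinators(
        selected_cycle,
        coordinator_port=random_ephemeral_port(),
        cycle_digraph=cycle_digraph,
    )
    # worker: each rank picks out its own address
    ibv_coordinator = coordinators[bound_instance.bound_node_id]
    os.environ["MLX_IBV_COORDINATOR"] = ibv_coordinator  # "0.0.0.0:PORT" on rank 0, a reachable IP of rank 0 elsewhere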
@@ -166,28 +166,28 @@ async def test_master():
         events[1].event.instance.shard_assignments.runner_to_shard.keys()
     )[0]
     assert events[1].event.instance == MlxRingInstance(
         instance_id=events[1].event.instance.instance_id,
         shard_assignments=ShardAssignments(
             model_id=ModelId("llama-3.2-1b"),
             runner_to_shard={
                 (runner_id): PipelineShardMetadata(
                     start_layer=0,
                     end_layer=16,
                     n_layers=16,
                     model_meta=ModelMetadata(
                         model_id=ModelId("llama-3.2-1b"),
                         pretty_name="Llama 3.2 1B",
                         n_layers=16,
                         storage_size=Memory.from_bytes(678948),
                     ),
                     device_rank=0,
                     world_size=1,
                 )
             },
             node_to_runner={node_id: runner_id},
         ),
         hosts=[],
     )
     assert isinstance(events[2].event, TaskCreated)
     assert events[2].event.task.task_status == TaskStatus.Pending
     assert isinstance(events[2].event.task, ChatCompletionTask)
@@ -437,7 +437,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
     assert isinstance(instance, MlxJacclInstance)

     assert instance.ibv_devices is not None
-    assert instance.ibv_coordinator is not None
+    assert instance.ibv_coordinators is not None

     matrix = instance.ibv_devices
     assert len(matrix) == 3
@@ -458,5 +458,17 @@ def test_tensor_rdma_backend_connectivity_matrix(
     assert matrix[idx_b][idx_c] == "rdma_en3"
     assert matrix[idx_c][idx_a] == "rdma_en3"

-    assert ":" in instance.ibv_coordinator
-    assert not instance.ibv_coordinator.startswith("169.254")
+    # Verify coordinators are set for all nodes
+    assert len(instance.ibv_coordinators) == 3
+    for node_id in assigned_nodes:
+        assert node_id in instance.ibv_coordinators
+        coordinator = instance.ibv_coordinators[node_id]
+        assert ":" in coordinator
+        # Rank 0 node should use 0.0.0.0, others should use connection-specific IPs
+        if node_id == assigned_nodes[0]:
+            assert coordinator.startswith("0.0.0.0:")
+        else:
+            # Non-rank-0 nodes should have valid IP addresses (can be link-local)
+            ip_part = coordinator.split(":")[0]
+            # Just verify it's a valid IP format
+            assert len(ip_part.split(".")) == 4
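For orientation, ibv_devices is an N x N table of RDMA device names indexed by sender and receiver rank. A hypothetical 3-node example consistent with the assertions above (only the 3x3 shape and the "rdma_en3" entries come from the test; the None diagonal and the remaining entries are assumptions for illustration):

    ibv_devices = [
        #  to a        to b        to c
        [None,        "rdma_en3", "rdma_en3"],  # from a
        ["rdma_en3",  None,       "rdma_en3"],  # from b
        ["rdma_en3",  "rdma_en3", None],        # from c
    ]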
@@ -5,6 +5,7 @@ import pytest
 from exo.master.placement_utils import (
     filter_cycles_by_memory,
     get_hosts_from_subgraph,
+    get_mlx_ibv_coordinators,
     get_shard_assignments,
     get_smallest_cycles,
 )
@@ -12,6 +13,7 @@ from exo.shared.topology import Topology
 from exo.shared.types.common import Host, NodeId
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
+from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
 from exo.shared.types.topology import Connection, NodeInfo
 from exo.shared.types.worker.shards import Sharding

@@ -261,3 +263,135 @@ def test_get_hosts_from_subgraph(
     ]
     for expected_host in expected_hosts:
         assert expected_host in hosts
+
+
+def test_get_mlx_ibv_coordinators(
+    topology: Topology,
+    create_node: Callable[[int, NodeId | None], NodeInfo],
+    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
+):
+    # arrange
+    node_a_id = NodeId()
+    node_b_id = NodeId()
+    node_c_id = NodeId()
+
+    node_a = create_node(500 * 1024, node_a_id)
+    node_b = create_node(500 * 1024, node_b_id)
+    node_c = create_node(1000 * 1024, node_c_id)
+
+    conn_a_b = create_connection(node_a_id, node_b_id, 5001)
+    conn_b_a = create_connection(node_b_id, node_a_id, 5002)
+    conn_b_c = create_connection(node_b_id, node_c_id, 5003)
+    conn_c_b = create_connection(node_c_id, node_b_id, 5004)
+    conn_c_a = create_connection(node_c_id, node_a_id, 5005)
+    conn_a_c = create_connection(node_a_id, node_c_id, 5006)
+
+    # Update node profiles with network interfaces before adding to topology
+    assert node_a.node_profile is not None
+    assert node_b.node_profile is not None
+    assert node_c.node_profile is not None
+
+    node_a.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_a.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_a_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_a_c.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_a.node_profile.system,
+    )
+    node_b.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_b.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_b_a.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_b_c.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_b.node_profile.system,
+    )
+    node_c.node_profile = NodePerformanceProfile(
+        model_id="test",
+        chip_id="test",
+        friendly_name="test",
+        memory=node_c.node_profile.memory,
+        network_interfaces=[
+            NetworkInterfaceInfo(
+                name="en3",
+                ip_address=conn_c_b.send_back_multiaddr.ip_address,
+            ),
+            NetworkInterfaceInfo(
+                name="en4",
+                ip_address=conn_c_a.send_back_multiaddr.ip_address,
+            ),
+        ],
+        system=node_c.node_profile.system,
+    )
+
+    topology.add_node(node_a)
+    topology.add_node(node_b)
+    topology.add_node(node_c)
+
+    topology.add_connection(conn_a_b)
+    topology.add_connection(conn_b_a)
+    topology.add_connection(conn_b_c)
+    topology.add_connection(conn_c_b)
+    topology.add_connection(conn_c_a)
+    topology.add_connection(conn_a_c)
+
+    cycle = [node_a, node_b, node_c]
+
+    # act
+    coordinators = get_mlx_ibv_coordinators(
+        cycle, coordinator_port=5000, cycle_digraph=topology
+    )
+
+    # assert
+    assert len(coordinators) == 3
+    assert node_a_id in coordinators
+    assert node_b_id in coordinators
+    assert node_c_id in coordinators
+
+    # All coordinators should have IP:PORT format
+    for node_id, coordinator in coordinators.items():
+        assert ":" in coordinator, (
+            f"Coordinator for {node_id} should have ':' separator"
+        )
+
+    # Verify port is correct
+    for node_id, coordinator in coordinators.items():
+        assert coordinator.endswith(":5000"), (
+            f"Coordinator for {node_id} should use port 5000"
+        )
+
+    # Rank 0 (node_a) treats this as the listen socket so should listen on all
+    # IPs
+    assert coordinators[node_a_id].startswith("0.0.0.0:"), (
+        "Rank 0 node should use localhost as coordinator"
+    )
+
+    # Non-rank-0 nodes should use the specific IP from their connection to rank 0
+    # node_b uses the IP from conn_b_a (node_b -> node_a)
+    assert coordinators[node_b_id] == (
+        f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
+    ), "node_b should use the IP from conn_b_a"
+
+    # node_c uses the IP from conn_c_a (node_c -> node_a)
+    assert coordinators[node_c_id] == (
+        f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
+    ), "node_c should use the IP from conn_c_a"
@@ -95,7 +95,7 @@ async def test_ingest_drops_duplicate_indices(buffer: OrderedBuffer[Event]):

     buffer.ingest(*make_indexed_event(0))
     buffer.ingest(*event2_first)

     with pytest.raises(AssertionError):
         buffer.ingest(*event2_second)  # This duplicate should be ignored

@@ -18,6 +18,7 @@ from exo.utils.pydantic_ext import CamelCaseModel

 DEFAULT_ELECTION_TIMEOUT = 3.0

+
 class ElectionMessage(CamelCaseModel):
     clock: int
     seniority: int
@@ -152,7 +153,9 @@ class Election:
                 self._candidates = candidates
                 logger.debug(f"New candidates: {self._candidates}")
                 logger.debug("Starting new campaign")
-                self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
+                self._tg.start_soon(
+                    self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
+                )
                 logger.debug("Campaign started")
                 continue
             # Dismiss old messages
@@ -181,7 +184,9 @@ class Election:
             candidates: list[ElectionMessage] = []
             self._candidates = candidates
             logger.debug("Starting new campaign")
-            self._tg.start_soon(self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT)
+            self._tg.start_soon(
+                self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
+            )
             logger.debug("Campaign started")
             self._connection_messages.append(first)
             self._connection_messages.extend(rest)
@@ -40,6 +40,7 @@ def em(
 # TESTS #
 # ======================================= #

+
 @pytest.fixture(autouse=True)
 def fast_election_timeout(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setattr("exo.shared.election.DEFAULT_ELECTION_TIMEOUT", 0.1)
@@ -2,7 +2,7 @@ from enum import Enum

 from pydantic import model_validator

-from exo.shared.types.common import Host, Id
+from exo.shared.types.common import Host, Id, NodeId
 from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
 from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):

 class MlxJacclInstance(BaseInstance):
     ibv_devices: list[list[str | None]]
-    ibv_coordinator: str
+    ibv_coordinators: dict[NodeId, str]


 # TODO: Single node instance
@@ -40,6 +40,7 @@ Instance = MlxRingInstance | MlxJacclInstance
 class BoundInstance(CamelCaseModel):
     instance: Instance
     bound_runner_id: RunnerId
+    bound_node_id: NodeId

     @property
     def bound_shard(self) -> ShardMetadata:
@@ -128,7 +128,9 @@ def mlx_distributed_init(
             os.environ["MLX_RING_VERBOSE"] = "1"
             group = mx.distributed.init(backend="ring", strict=True)

-        case MlxJacclInstance(ibv_devices=ibv_devices, ibv_coordinator=ibv_coordinator):
+        case MlxJacclInstance(
+            ibv_devices=ibv_devices, ibv_coordinators=ibv_coordinators
+        ):
             # Use RDMA connectivity matrix
             devices_file = f"./hosts_{rank}.json"
             ibv_devices_json = json.dumps(ibv_devices)
@@ -136,6 +138,8 @@ def mlx_distributed_init(
             with open(devices_file, "w") as f:
                 _ = f.write(ibv_devices_json)

+            ibv_coordinator = ibv_coordinators[bound_instance.bound_node_id]
+
             logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
             logger.info(f"rank {rank} MLX_IBV_COORDINATOR: {ibv_coordinator}")
             os.environ["MLX_IBV_DEVICES"] = devices_file
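As a usage note (a sketch, not the commit's code): with this lookup, the environment handed to the MLX RDMA backend differs across ranks only in MLX_IBV_COORDINATOR; the matrix written to each hosts_{rank}.json is the same. Assuming two ranks and made-up values:

    # rank 0
    os.environ["MLX_IBV_DEVICES"] = "./hosts_0.json"
    os.environ["MLX_IBV_COORDINATOR"] = "0.0.0.0:52415"       # listens for the other ranks

    # rank 1
    os.environ["MLX_IBV_DEVICES"] = "./hosts_1.json"
    os.environ["MLX_IBV_COORDINATOR"] = "192.168.1.10:52415"  # the IP rank 1 shares with rank 0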
@@ -95,7 +95,9 @@ def _create_runner(

     return CreateRunner(
         instance_id=instance.instance_id,
-        bound_instance=BoundInstance(instance=instance, bound_runner_id=runner_id),
+        bound_instance=BoundInstance(
+            instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
+        ),
     )

@@ -35,7 +35,9 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -76,7 +78,9 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -126,7 +130,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -173,7 +179,9 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerWaitingForModel()
     )
@@ -36,7 +36,9 @@ def test_plan_kills_runner_when_instance_missing():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())

     runners = {RUNNER_1_ID: runner}
@@ -71,7 +73,9 @@ def test_plan_kills_runner_when_sibling_failed():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())

     runners = {RUNNER_1_ID: runner}
@@ -143,7 +147,9 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerReady())

     runners = {RUNNER_1_ID: runner}
@@ -40,7 +40,9 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -86,7 +88,9 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard1, RUNNER_2_ID: shard2},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -131,7 +135,9 @@ def test_plan_does_not_forward_tasks_for_other_instances():
         node_to_runner={NODE_A: RUNNER_1_ID},
         runner_to_shard={RUNNER_1_ID: shard},
     )
-    bound_instance = BoundInstance(instance=local_instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=local_instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
     )
@@ -175,7 +181,9 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )

     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerReady()
@@ -236,7 +244,9 @@ def test_plan_returns_none_when_nothing_to_do():
         node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )
-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerRunning()
     )
@@ -35,7 +35,9 @@ def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )
@@ -75,7 +77,9 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )
@@ -114,7 +118,9 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_2_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )
@@ -153,7 +159,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
         runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
     )

-    bound_instance = BoundInstance(instance=instance, bound_runner_id=RUNNER_1_ID)
+    bound_instance = BoundInstance(
+        instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
+    )
     local_runner = FakeRunnerSupervisor(
         bound_instance=bound_instance, status=RunnerLoaded()
     )