mirror of https://github.com/exo-explore/exo.git
synced 2026-01-17 02:18:47 -05:00

Compare commits: v1.0.63...feat/bandw (8 commits)
| Author | SHA1 | Date | |
|---|---|---|---|
| | 4d414556d5 | | |
| | d1f80c9e86 | | |
| | ae3086167f | | |
| | a480df40bf | | |
| | a8a0fa1bd8 | | |
| | 9c6f9a6080 | | |
| | ab31491786 | | |
| | 9e8d5b759c | | |
@@ -49,20 +49,22 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
    return [cycle for cycle in cycles if len(cycle) == min_nodes]


def get_shard_assignments_for_pipeline_parallel(
def _assign_layers_by_ram(
    model_meta: ModelMetadata,
    selected_cycle: list[NodeWithProfile],
):
    cycle_memory = sum(
        (node.node_profile.memory.ram_available for node in selected_cycle),
        start=Memory(),
    )
) -> ShardAssignments:
    """Assign layers proportionally based on available RAM."""
    total_layers = model_meta.n_layers
    world_size = len(selected_cycle)
    runner_to_shard: dict[RunnerId, ShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

    cycle_memory = sum(
        (node.node_profile.memory.ram_available for node in selected_cycle),
        start=Memory(),
    )
    layers_assigned = 0

    for i, node in enumerate(selected_cycle):
        if i == len(selected_cycle) - 1:
            node_layers = total_layers - layers_assigned
@@ -77,7 +79,6 @@ def get_shard_assignments_for_pipeline_parallel(
            node_layers = max(1, node_layers)

        runner_id = RunnerId()

        shard = PipelineShardMetadata(
            model_meta=model_meta,
            device_rank=i,
@@ -86,18 +87,143 @@ def get_shard_assignments_for_pipeline_parallel(
            end_layer=layers_assigned + node_layers,
            n_layers=total_layers,
        )

        runner_to_shard[runner_id] = shard
        node_to_runner[node.node_id] = runner_id
        layers_assigned += node_layers

    shard_assignments = ShardAssignments(
    return ShardAssignments(
        model_id=model_meta.model_id,
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )

    return shard_assignments
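
The hunks above elide the middle of the loop, where each node's layer count is computed from its share of the cycle's total available RAM. A minimal sketch of that idea, using a hypothetical helper split_layers_by_ram rather than the repository's exact code:

def split_layers_by_ram(total_layers: int, ram_per_node: list[int]) -> list[int]:
    """Illustrative only: proportional split with a floor of 1 layer per node."""
    total_ram = sum(ram_per_node)
    counts, assigned = [], 0
    for i, ram in enumerate(ram_per_node):
        if i == len(ram_per_node) - 1:
            node_layers = total_layers - assigned  # last node absorbs the rounding remainder
        else:
            node_layers = max(1, round(total_layers * ram / total_ram))
        counts.append(node_layers)
        assigned += node_layers
    return counts

# e.g. 30 layers across nodes with 16 GB, 8 GB and 8 GB free -> [15, 8, 7]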

def _reserve_base_layers(world_size: int, total_layers: int) -> dict[int, int]:
    """Reserve 1 layer per node to ensure connectivity."""
    assignments = {i: 0 for i in range(world_size)}
    remaining_layers = total_layers

    for i in range(world_size):
        assignments[i] = 1
        remaining_layers -= 1

    if remaining_layers < 0:
        logger.warning(
            "Fewer layers than nodes! Reducing to 1 layer per node where possible."
        )
        assignments = {i: 1 if i < total_layers else 0 for i in range(world_size)}
        remaining_layers = 0

    return assignments
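
For example, with illustrative numbers, four nodes and only three layers trigger the warning branch:

# Illustrative: 4 nodes, 3 layers
_reserve_base_layers(world_size=4, total_layers=3)
# -> {0: 1, 1: 1, 2: 1, 3: 0}, with the internal remaining_layers clamped to 0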

def _distribute_layers_by_bandwidth(
    selected_cycle: list[NodeWithProfile],
    assignments: dict[int, int],
    remaining_layers: int,
    model_meta: ModelMetadata,
) -> None:
    """Distribute remaining layers based on bandwidth and RAM capacity."""
    indexed_nodes = list(enumerate(selected_cycle))
    sorted_nodes = sorted(
        indexed_nodes,
        key=lambda x: x[1].node_profile.memory_bandwidth or 0,
        reverse=True,
    )

    for original_idx, node in sorted_nodes:
        if remaining_layers <= 0:
            break

        layer_size_bytes = model_meta.storage_size.in_bytes / model_meta.n_layers
        max_layers_by_ram = int(
            node.node_profile.memory.ram_available.in_bytes // layer_size_bytes
        )
        can_take = max(0, max_layers_by_ram - assignments[original_idx])
        take = min(can_take, remaining_layers)
        assignments[original_idx] += take
        remaining_layers -= take

    if remaining_layers > 0:
        logger.warning(
            "All nodes maxed out on RAM estimation, dumping remaining layers on fastest nodes."
        )
        for original_idx, _ in sorted_nodes:
            assignments[original_idx] += 1
            remaining_layers -= 1
            if remaining_layers == 0:
                break
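
A walk-through of the greedy pass with assumed numbers (not taken from the diff): 30 layers remain after the base reservation, each layer is estimated at 1 GB, and the node indices are already ordered fastest-first. The fastest node is capped by its RAM estimate and the next one absorbs the rest:

remaining = 30
assignments = {0: 1, 1: 1, 2: 1}   # after _reserve_base_layers
max_by_ram = {0: 12, 1: 24, 2: 8}  # ram_available // layer_size per node (assumed)
for idx in (0, 1, 2):              # already sorted by bandwidth, fastest first
    take = min(max(0, max_by_ram[idx] - assignments[idx]), remaining)
    assignments[idx] += take
    remaining -= take
assert assignments == {0: 12, 1: 20, 2: 1} and remaining == 0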

def _create_shard_assignments(
    model_meta: ModelMetadata,
    selected_cycle: list[NodeWithProfile],
    assignments: dict[int, int],
) -> ShardAssignments:
    """Create shard assignments from layer assignments."""
    world_size = len(selected_cycle)
    runner_to_shard: dict[RunnerId, ShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

    current_start = 0
    for i, node in enumerate(selected_cycle):
        count = assignments[i]
        runner_id = RunnerId()
        shard = PipelineShardMetadata(
            model_meta=model_meta,
            device_rank=i,
            world_size=world_size,
            start_layer=current_start,
            end_layer=current_start + count,
            n_layers=model_meta.n_layers,
        )
        runner_to_shard[runner_id] = shard
        node_to_runner[node.node_id] = runner_id
        current_start += count

    return ShardAssignments(
        model_id=model_meta.model_id,
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )

def _assign_layers_by_bandwidth(
    model_meta: ModelMetadata,
    selected_cycle: list[NodeWithProfile],
) -> ShardAssignments:
    """Assign layers based on memory bandwidth."""
    logger.info("Using bandwidth-aware shard assignment")

    total_layers = model_meta.n_layers
    world_size = len(selected_cycle)

    assignments = _reserve_base_layers(world_size, total_layers)
    remaining_layers = total_layers - sum(assignments.values())

    if remaining_layers > 0:
        _distribute_layers_by_bandwidth(
            selected_cycle, assignments, remaining_layers, model_meta
        )

    return _create_shard_assignments(model_meta, selected_cycle, assignments)
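
Putting the three helpers together on the scenario exercised by the test further down (30 layers, three nodes with ample RAM and bandwidths of 400/200/100 GB/s), the counts come out as 28/1/1 and are then laid out as contiguous [start_layer, end_layer) ranges. A small illustrative sketch of that bookkeeping, not the repository's code:

counts = [28, 1, 1]  # reserve 1 each, then the fastest node takes the remaining 27
ranges, start = [], 0
for count in counts:
    ranges.append((start, start + count))  # [start_layer, end_layer)
    start += count
assert ranges == [(0, 28), (28, 29), (29, 30)]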

def get_shard_assignments_for_pipeline_parallel(
    model_meta: ModelMetadata,
    selected_cycle: list[NodeWithProfile],
):
    has_bandwidth = all(
        node.node_profile.memory_bandwidth is not None for node in selected_cycle
    )

    if not has_bandwidth:
        logger.info(
            "Bandwidth data missing for some nodes, falling back to RAM-proportional assignment"
        )
        return _assign_layers_by_ram(model_meta, selected_cycle)

    return _assign_layers_by_bandwidth(model_meta, selected_cycle)


def get_shard_assignments_for_tensor_parallel(

@@ -397,3 +397,106 @@ def test_get_mlx_jaccl_coordinators(
    assert coordinators[node_c_id] == (
        f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
    ), "node_c should use the IP from conn_c_a"


def test_get_shard_assignments_bandwidth_aware(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()

    # Create nodes with identical RAM (plenty of it).
    # Using 1GB to ensure no RAM constraints (model is small).
    node_a = create_node(1024 * 1024 * 1024, node_a_id)
    node_b = create_node(1024 * 1024 * 1024, node_b_id)
    node_c = create_node(1024 * 1024 * 1024, node_c_id)

    # Set bandwidths: A=400 GB/s (fastest), B=200 GB/s, C=100 GB/s (slowest)
    assert node_a.node_profile is not None
    assert node_b.node_profile is not None
    assert node_c.node_profile is not None

    node_a.node_profile.memory_bandwidth = 400_000_000_000
    node_b.node_profile.memory_bandwidth = 200_000_000_000
    node_c.node_profile.memory_bandwidth = 100_000_000_000

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_a_id))

    # Other tests (e.g. test_filter_cycles_by_memory) add connections in both
    # directions, so follow the same pattern here.
    topology.add_connection(create_connection(node_b_id, node_a_id))
    topology.add_connection(create_connection(node_c_id, node_b_id))
    topology.add_connection(create_connection(node_a_id, node_c_id))

    model_meta = ModelMetadata(
        model_id=ModelId("test-model"),
        pretty_name="Test Model",
        n_layers=30,
        storage_size=Memory.from_kb(300),  # 10KB per layer; with ~1GB free per node, RAM never constrains the assignment
        hidden_size=1000,
        supports_tensor=True,
    )

    cycles = topology.get_cycles()
    # get_cycles may return multiple cycles; memory filtering normally happens
    # in the master, so just pick the first one here.
    selected_cycle = cycles[0]

    # act
    shard_assignments = get_shard_assignments(
        model_meta, selected_cycle, Sharding.Pipeline
    )

    # assert
    runner_id_a = shard_assignments.node_to_runner[node_a_id]
    runner_id_b = shard_assignments.node_to_runner[node_b_id]
    runner_id_c = shard_assignments.node_to_runner[node_c_id]

    # Get layer counts
    layers_a = (
        shard_assignments.runner_to_shard[runner_id_a].end_layer
        - shard_assignments.runner_to_shard[runner_id_a].start_layer
    )
    layers_b = (
        shard_assignments.runner_to_shard[runner_id_b].end_layer
        - shard_assignments.runner_to_shard[runner_id_b].start_layer
    )
    layers_c = (
        shard_assignments.runner_to_shard[runner_id_c].end_layer
        - shard_assignments.runner_to_shard[runner_id_c].start_layer
    )

    # Check total
    assert layers_a + layers_b + layers_c == 30

    # Check that the fastest node (A with 400 GB/s) gets saturated first.
    # With strict greedy assignment and plenty of RAM:
    # 1. Reserve: A=1, B=1, C=1. Remaining=27.
    # 2. Sort: [A, B, C]
    # 3. A takes min(remaining=27, capacity=huge) = 27.
    # 4. A=28, B=1, C=1.
    assert layers_a == 28
    assert layers_b == 1
    assert layers_c == 1

@@ -57,6 +57,7 @@ class NodePerformanceProfile(CamelCaseModel):
    chip_id: str
    friendly_name: str
    memory: MemoryPerformanceProfile
    memory_bandwidth: int | None = None
    network_interfaces: list[NetworkInterfaceInfo] = []
    system: SystemPerformanceProfile

@@ -4,6 +4,7 @@ import platform
from typing import Any, Callable, Coroutine

import anyio
from anyio import to_thread
from loguru import logger

from exo.shared.types.memory import Memory
@@ -24,8 +25,61 @@ from .system_info import (
    get_friendly_name,
    get_model_and_chip,
    get_network_interfaces,
    profile_memory_bandwidth,
)

# Module-level cache for memory bandwidth (doesn't change at runtime)
_cached_bandwidth: int | None = None
_bandwidth_profiled: bool = False
_bandwidth_profiling_task: asyncio.Task[int | None] | None = None


async def profile_bandwidth_once() -> int | None:
    """Profile bandwidth once in a background thread and cache the result.

    This function is non-blocking - it runs the profiling in a thread pool.
    Subsequent calls return the cached result immediately.
    """
    global _cached_bandwidth, _bandwidth_profiled, _bandwidth_profiling_task

    # Already profiled, return cached value
    if _bandwidth_profiled:
        return _cached_bandwidth

    # Profiling already in progress, wait for it
    if _bandwidth_profiling_task is not None:
        return await _bandwidth_profiling_task

    # Start profiling in background thread
    async def _do_profile() -> int | None:
        global _cached_bandwidth, _bandwidth_profiled
        try:
            logger.info("Starting memory bandwidth profiling in background thread...")
            bandwidth = await to_thread.run_sync(profile_memory_bandwidth, cancellable=True)
            _cached_bandwidth = bandwidth
            _bandwidth_profiled = True
            if bandwidth:
                logger.info(f"Memory bandwidth profiled: {bandwidth / 1e9:.1f} GB/s")
            else:
                logger.warning("Memory bandwidth profiling returned None")
            return bandwidth
        except Exception as e:
            logger.opt(exception=e).error("Memory bandwidth profiling failed")
            _bandwidth_profiled = True  # Mark as done to avoid retrying
            return None

    _bandwidth_profiling_task = asyncio.create_task(_do_profile())
    return await _bandwidth_profiling_task


def get_memory_bandwidth_cached() -> int | None:
    """Return cached bandwidth or None if not yet profiled.

    This is a non-blocking synchronous function that returns immediately.
    Call profile_bandwidth_once() first to trigger profiling.
    """
    return _cached_bandwidth if _bandwidth_profiled else None
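
Assumed usage of this pair (the real call sites are in start_polling_node_metrics below): the first await runs the benchmark in a worker thread, later calls return the cached value immediately:

# Illustrative usage, not code from the diff
async def example() -> None:
    first = await profile_bandwidth_once()   # runs the MLX benchmark in a worker thread
    again = await profile_bandwidth_once()   # returns the cached result immediately
    assert again == first == get_memory_bandwidth_cached()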


async def get_metrics_async() -> Metrics | None:
    """Return detailed Metrics on macOS or a minimal fallback elsewhere."""
@@ -71,6 +125,8 @@ async def start_polling_node_metrics(
    callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
):
    poll_interval_s = 1.0
    bandwidth_profile_started = False

    while True:
        try:
            metrics = await get_metrics_async()
@@ -85,6 +141,15 @@
            # do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
            memory_profile = get_memory_profile()

            # Start bandwidth profiling in background on first poll (non-blocking)
            if not bandwidth_profile_started:
                bandwidth_profile_started = True
                # Fire and forget - don't await, let it run in background
                asyncio.create_task(profile_bandwidth_once())

            # Use cached bandwidth (None until profiling completes)
            memory_bandwidth = get_memory_bandwidth_cached()

            await callback(
                NodePerformanceProfile(
                    model_id=model_id,
@@ -92,6 +157,7 @@
                    friendly_name=friendly_name,
                    network_interfaces=network_interfaces,
                    memory=memory_profile,
                    memory_bandwidth=memory_bandwidth,
                    system=SystemPerformanceProfile(
                        gpu_usage=metrics.gpu_usage[1],
                        temp=metrics.temp.gpu_temp_avg,

@@ -1,5 +1,6 @@
import socket
import sys
import time
from subprocess import CalledProcessError

import psutil
@@ -81,3 +82,68 @@ async def get_model_and_chip() -> tuple[str, str]:
    chip = chip_line.split(": ")[1] if chip_line else "Unknown Chip"

    return (model, chip)


def profile_memory_bandwidth() -> int | None:
    """
    Profile device memory bandwidth using MLX GPU operations.

    Uses a large array copy on the GPU to measure unified memory bandwidth.
    Returns measured bandwidth in bytes/second, or None if MLX is unavailable.
    """
    try:
        import mlx.core as mx

        if not mx.metal.is_available():
            return None

        # Use 2GB buffer to better saturate memory bandwidth
        # Use 2D shape to avoid potential issues with very large 1D arrays
        size_bytes = 2 * 1024 * 1024 * 1024
        side = int((size_bytes // 4) ** 0.5)  # Square 2D array of float32
        shape = (side, side)
        actual_bytes = side * side * 4
        bytes_transferred = actual_bytes * 2  # read + write

        # Warm-up: run the full benchmark operation multiple times to stabilize GPU
        for _ in range(3):
            src = mx.random.uniform(shape=shape, dtype=mx.float32)
            mx.eval(src)
            dst = src + 0.0
            mx.eval(dst)
            mx.synchronize()
            del src, dst

        # Benchmark: measure time to copy array
        best_bandwidth = 0.0
        num_runs = 4

        for _ in range(num_runs):
            src = mx.random.uniform(shape=shape, dtype=mx.float32)
            mx.eval(src)
            mx.synchronize()

            # Time the copy operation (src + 0.0 forces read of src, write of dst)
            start = time.perf_counter()
            dst = src + 0.0
            mx.eval(dst)
            mx.synchronize()
            end = time.perf_counter()

            bandwidth = bytes_transferred / (end - start)
            best_bandwidth = max(best_bandwidth, bandwidth)

            del src, dst

        return int(best_bandwidth)
    except Exception:
        return None
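
To put numbers on the measurement (illustrative timing, not a measured result): with the 2 GiB target, side works out to 23170, so roughly 2.15 GB is actually copied and about 4.29 GB is counted as transferred; a copy taking 10.7 ms therefore reports around 400 GB/s:

# Worked example of the bandwidth arithmetic above (assumed 10.7 ms copy time)
size_bytes = 2 * 1024 * 1024 * 1024
side = int((size_bytes // 4) ** 0.5)      # 23170
actual_bytes = side * side * 4            # ~2.147e9 bytes actually copied
bytes_transferred = actual_bytes * 2      # read + write ~= 4.295e9 bytes
elapsed = 0.0107                          # seconds, assumed
print(bytes_transferred / elapsed / 1e9)  # ~= 401 GB/s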


def get_memory_bandwidth(_chip_id: str) -> int | None:
    """
    Returns measured memory bandwidth in bytes/second.

    Uses MLX GPU operations for accurate unified memory bandwidth measurement.
    """
    return profile_memory_bandwidth()