mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-30 16:51:02 -05:00
Compare commits
4 Commits
ciaran/pro
...
leo/profil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12669cd59f | ||
|
|
dc7c3bb5b4 | ||
|
|
ed59e1d0c5 | ||
|
|
b28c0bbe96 |
@@ -40,6 +40,7 @@ class Node:
|
||||
|
||||
node_id: NodeId
|
||||
event_index_counter: Iterator[int]
|
||||
_profiling_in_progress: set[NodeId] = field(default_factory=set)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
@@ -86,6 +87,7 @@ class Node:
|
||||
else:
|
||||
api = None
|
||||
|
||||
profiling_in_progress: set[NodeId] = set()
|
||||
if not args.no_worker:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
@@ -96,6 +98,7 @@ class Node:
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
event_index_counter=event_index_counter,
|
||||
profiling_in_progress=profiling_in_progress,
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
@@ -133,6 +136,7 @@ class Node:
|
||||
api,
|
||||
node_id,
|
||||
event_index_counter,
|
||||
profiling_in_progress,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
@@ -239,6 +243,7 @@ class Node:
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
event_index_counter=self.event_index_counter,
|
||||
profiling_in_progress=self._profiling_in_progress,
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.api:
|
||||
|
||||
@@ -8,8 +8,8 @@ from typing import Annotated, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio import BrokenResourceError, ClosedResourceError, EndOfStream, create_task_group
|
||||
from anyio.abc import SocketStream, TaskGroup
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
@@ -1250,6 +1250,7 @@ class API:
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
tg.start_soon(self._run_bandwidth_server)
|
||||
print_startup_banner(self.port)
|
||||
await serve(
|
||||
cast(ASGIFramework, self.app),
|
||||
@@ -1257,6 +1258,43 @@ class API:
|
||||
shutdown_trigger=lambda: anyio.sleep_forever(),
|
||||
)
|
||||
|
||||
async def _run_bandwidth_server(self):
    """Run the TCP server used for iperf-like bandwidth testing.

    Binds a TCP listener on ``self.port + 1`` and hands every accepted
    connection to ``_handle_bandwidth_connection``. The listener is closed
    when this coroutine exits, including on cancellation or error.
    """
    bandwidth_port = self.port + 1
    listener = await anyio.create_tcp_listener(local_port=bandwidth_port)
    # Hold the listener in a context manager so the listening socket is
    # released even if serve() is cancelled or raises.
    async with listener:
        logger.info(f"Bandwidth test server listening on port {bandwidth_port}")
        await listener.serve(self._handle_bandwidth_connection)
|
||||
|
||||
async def _handle_bandwidth_connection(self, stream: SocketStream) -> None:
    """Handle a single bandwidth test connection.

    The first byte received selects the test mode:

    * ``b"U"`` - upload test: the client sends data and this side counts the
      bytes until the ``b"DONE"`` sentinel arrives or the peer closes.
    * ``b"D"`` - download test: this side sends 1 MiB chunks for one second
      or until the peer goes away.

    Any other mode byte is ignored. Errors are logged at debug level and
    never propagate. The stream is always closed before returning.

    Args:
        stream: The accepted client connection.
    """
    try:
        # The handler owns the accepted stream, so close it on every exit path.
        async with stream:
            mode = await stream.receive(1)
            if mode == b"U":
                # Upload test: client sends, we receive
                bytes_received = 0
                start = time.perf_counter()
                while True:
                    try:
                        data = await stream.receive(1024 * 1024)
                    except EndOfStream:
                        break
                    marker_at = data.find(b"DONE")
                    if marker_at >= 0:
                        # Count the payload that arrived ahead of the
                        # sentinel in the same chunk instead of dropping it.
                        bytes_received += marker_at
                        break
                    if not data:
                        break
                    bytes_received += len(data)
                elapsed = time.perf_counter() - start
                logger.debug(f"Bandwidth upload: {bytes_received} bytes in {elapsed:.3f}s")
            elif mode == b"D":
                # Download test: we send, client receives
                chunk = b"X" * (1024 * 1024)
                start = time.perf_counter()
                while time.perf_counter() - start < 1.0:  # Send for 1s
                    try:
                        await stream.send(chunk)
                    except (BrokenResourceError, ClosedResourceError):
                        break
    except Exception as e:
        # Best-effort diagnostic endpoint: a failed test must not take down
        # the server task, so log and move on.
        logger.debug(f"Bandwidth connection error: {e}")
|
||||
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
|
||||
@@ -28,6 +28,9 @@ def create_node_network() -> NodeNetworkInfo:
|
||||
def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
|
||||
latency_ms=1.0,
|
||||
other_to_sink_bandwidth_mbps=1000.0,
|
||||
sink_to_other_bandwidth_mbps=1000.0,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -366,7 +366,10 @@ def test_tensor_rdma_backend_connectivity_matrix(
|
||||
ip_address="10.0.0.1",
|
||||
)
|
||||
ethernet_conn = SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
|
||||
sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000"),
|
||||
latency_ms=1.0,
|
||||
other_to_sink_bandwidth_mbps=1000.0,
|
||||
sink_to_other_bandwidth_mbps=1000.0,
|
||||
)
|
||||
|
||||
node_network = {
|
||||
|
||||
@@ -15,6 +15,9 @@ def topology() -> Topology:
|
||||
def socket_connection() -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
|
||||
latency_ms=1.0,
|
||||
sink_to_other_bandwidth_mbps=1000.0,
|
||||
other_to_sink_bandwidth_mbps=1000.0,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ def test_state_serialization_roundtrip() -> None:
|
||||
sink=node_b,
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
|
||||
latency_ms=1.5,
|
||||
sink_to_other_bandwidth_mbps=1000.0,
|
||||
other_to_sink_bandwidth_mbps=1100.0,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -24,10 +24,22 @@ class RDMAConnection(FrozenModel):
|
||||
|
||||
class SocketConnection(FrozenModel):
    """A socket edge towards ``sink_multiaddr`` with its measured link metrics.

    Equality compares only ``sink_multiaddr`` and the hash is derived from the
    multiaddr's IP address, so latency and bandwidth values do not affect
    set/dict membership.
    """

    sink_multiaddr: Multiaddr
    latency_ms: float
    sink_to_other_bandwidth_mbps: float
    other_to_sink_bandwidth_mbps: float

    @property
    def bandwidth_mbps(self):
        """Return the smaller of the two directional bandwidth figures."""
        directional = (
            self.sink_to_other_bandwidth_mbps,
            self.other_to_sink_bandwidth_mbps,
        )
        return min(directional)

    def __hash__(self):
        # Hash on the IP only; equal multiaddrs share an IP, so this stays
        # consistent with __eq__.
        ip = self.sink_multiaddr.ip_address
        return hash(ip)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, SocketConnection):
            return self.sink_multiaddr == other.sink_multiaddr
        return NotImplemented
|
||||
|
||||
|
||||
class Connection(FrozenModel):
|
||||
source: NodeId
|
||||
|
||||
125
src/exo/utils/info_gatherer/connection_profiler.py
Normal file
125
src/exo/utils/info_gatherer/connection_profiler.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import time
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import SocketStream
|
||||
from exo.shared.logging import logger
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
# Number of HTTP round trips issued per latency measurement.
LATENCY_PING_COUNT = 5
# Seconds spent on each direction of the TCP bandwidth test.
BANDWIDTH_TEST_DURATION_S = 0.5
# The bandwidth test server listens on the API port plus this offset.
BANDWIDTH_TEST_PORT_OFFSET = 1
|
||||
|
||||
|
||||
class ConnectionProfile(BaseModel):
    """Measurements produced by ``profile_connection`` for one target."""

    # Mean HTTP round-trip time, in milliseconds.
    latency_ms: float
    # Throughput while this side was sending, in megabits per second.
    upload_mbps: float
    # Throughput while this side was receiving, in megabits per second.
    download_mbps: float
|
||||
|
||||
|
||||
async def measure_latency(target_ip: str, port: int = 52415) -> float:
    """Return the mean HTTP round-trip time to ``target_ip`` in milliseconds.

    Issues ``LATENCY_PING_COUNT`` GET requests to the target's ``/node_id``
    endpoint over one shared client and averages the round trips that
    answered with status 200. Timeouts, network errors and protocol errors
    are logged at debug level and that sample is skipped.

    Args:
        target_ip: IPv4 or IPv6 address of the peer.
        port: HTTP port of the peer.

    Raises:
        ConnectionError: If no request produced a usable sample.
    """
    # An address containing ":" is treated as IPv6 and bracketed in the URL.
    url = (
        f"http://[{target_ip}]:{port}/node_id"
        if ":" in target_ip
        else f"http://{target_ip}:{port}/node_id"
    )
    skippable = (httpx.TimeoutException, httpx.NetworkError, httpx.RemoteProtocolError)
    samples_ms: list[float] = []

    async with httpx.AsyncClient(timeout=10.0) as client:
        for _ in range(LATENCY_PING_COUNT):
            try:
                began = time.perf_counter()
                response = await client.get(url)
                finished = time.perf_counter()
            except skippable as e:
                logger.debug(f"Latency ping failed: {e}")
                continue
            if response.status_code == 200:
                samples_ms.append((finished - began) * 1000)

    if samples_ms:
        return sum(samples_ms) / len(samples_ms)
    raise ConnectionError(f"Failed to measure latency to {target_ip}:{port}")
|
||||
|
||||
|
||||
async def _measure_upload_tcp(stream: SocketStream, duration: float) -> float:
    """Push 1 MiB chunks into ``stream`` for ``duration`` seconds.

    Returns:
        The send rate in megabits per second, computed from the bytes handed
        to ``stream.send`` and the wall-clock time actually elapsed, or
        ``0.0`` if no time elapsed.
    """
    payload = b"X" * (1024 * 1024)  # 1MB
    payload_size = len(payload)
    sent_total = 0
    began = time.perf_counter()
    stop_at = began + duration

    # The deadline is only checked between sends, so the final send may
    # finish after it; the rate below uses the real elapsed time.
    while time.perf_counter() < stop_at:
        await stream.send(payload)
        sent_total += payload_size

    took = time.perf_counter() - began
    if took > 0:
        return (sent_total * 8 / took) / 1_000_000
    return 0.0
|
||||
|
||||
|
||||
async def _measure_download_tcp(stream: SocketStream, duration: float) -> float:
    """Receive data for ``duration`` seconds and return the rate in Mbps.

    Reading stops when the time budget runs out or when the peer closes the
    connection, whichever comes first. A peer that closes early still yields
    the throughput measured so far instead of an exception.

    Args:
        stream: Connected stream on which the peer is sending data.
        duration: Maximum time to keep receiving, in seconds.

    Returns:
        Megabits per second over the elapsed wall-clock time, or ``0.0`` if
        no time elapsed.
    """
    bytes_received = 0
    start = time.perf_counter()

    with anyio.move_on_after(duration):
        while True:
            try:
                data = await stream.receive(1024 * 1024)
            except anyio.EndOfStream:
                # anyio reports a closed peer by raising, not by returning
                # empty bytes; keep what was measured so far.
                break
            if not data:
                break
            bytes_received += len(data)

    elapsed = time.perf_counter() - start
    return (bytes_received * 8 / elapsed) / 1_000_000 if elapsed > 0 else 0.0
|
||||
|
||||
|
||||
async def measure_bandwidth_tcp(target_ip: str, port: int) -> tuple[float, float]:
    """Measure bandwidth using raw TCP like iperf.

    Opens one connection per direction. Each connection starts with a single
    mode byte: ``b"U"`` for the upload test (this side sends, then finishes
    with ``b"DONE"``) and ``b"D"`` for the download test (this side receives).
    A failure in either direction is logged at debug level and leaves that
    direction at ``0.0``; it never raises.

    Returns:
        ``(upload_mbps, download_mbps)``.
    """
    up_rate = 0.0
    try:
        async with await anyio.connect_tcp(target_ip, port) as uplink:
            # Upload: client sends, server receives
            await uplink.send(b"U")
            up_rate = await _measure_upload_tcp(uplink, BANDWIDTH_TEST_DURATION_S)
            await uplink.send(b"DONE")
            logger.debug(f"Upload: {up_rate:.1f} Mbps")
    except Exception as e:
        logger.debug(f"Upload TCP test failed: {e}")

    down_rate = 0.0
    try:
        async with await anyio.connect_tcp(target_ip, port) as downlink:
            # Download: client receives, server sends
            await downlink.send(b"D")
            down_rate = await _measure_download_tcp(downlink, BANDWIDTH_TEST_DURATION_S)
            logger.debug(f"Download: {down_rate:.1f} Mbps")
    except Exception as e:
        logger.debug(f"Download TCP test failed: {e}")

    return up_rate, down_rate
|
||||
|
||||
|
||||
async def profile_connection(target_ip: str, port: int = 52415) -> ConnectionProfile:
    """Measure latency and both bandwidth directions to ``target_ip``.

    Latency is measured against ``port``; bandwidth is measured against
    ``port + BANDWIDTH_TEST_PORT_OFFSET``.

    Raises:
        ConnectionError: If latency cannot be measured (raised by
            ``measure_latency``) or if both bandwidth directions come back
            as ``0.0``.
    """
    logger.debug(f"Profiling connection to {target_ip}:{port}")

    rtt_ms = await measure_latency(target_ip, port)
    logger.debug(f"Measured latency to {target_ip}: {rtt_ms:.2f}ms")

    probe_port = port + BANDWIDTH_TEST_PORT_OFFSET
    up, down = await measure_bandwidth_tcp(target_ip, probe_port)
    logger.debug(
        f"Measured bandwidth to {target_ip}: "
        f"upload={up:.1f}Mbps, download={down:.1f}Mbps"
    )

    # One working direction is enough; only a total failure is an error.
    both_failed = up == 0.0 and down == 0.0
    if both_failed:
        raise ConnectionError(f"Failed to measure bandwidth to {target_ip}:{probe_port}")

    return ConnectionProfile(latency_ms=rtt_ms, upload_mbps=up, download_mbps=down)
|
||||
@@ -44,6 +44,7 @@ from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.info_gatherer.connection_profiler import profile_connection
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.utils.info_gatherer.net_profile import check_reachable
|
||||
from exo.utils.keyed_backoff import KeyedBackoff
|
||||
@@ -65,6 +66,7 @@ class Worker:
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
event_index_counter: Iterator[int],
|
||||
profiling_in_progress: set[NodeId] | None = None,
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
@@ -81,6 +83,8 @@ class Worker:
|
||||
self.state: State = State()
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup = create_task_group()
|
||||
self._profiling_in_progress: set[NodeId] = profiling_in_progress if profiling_in_progress is not None else set()
|
||||
self._profiling_lock = anyio.Lock()
|
||||
|
||||
self._nack_cancel_scope: CancelScope | None = None
|
||||
self._nack_attempts: int = 0
|
||||
@@ -282,37 +286,73 @@ class Worker:
|
||||
async def _connection_message_event_writer(self):
|
||||
with self.connection_message_receiver as connection_messages:
|
||||
async for msg in connection_messages:
|
||||
await self.event_sender.send(
|
||||
self._convert_connection_message_to_event(msg)
|
||||
)
|
||||
event = await self._convert_connection_message_to_event(msg)
|
||||
if event:
|
||||
await self.event_sender.send(event)
|
||||
|
||||
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
|
||||
match msg.connection_type:
|
||||
case ConnectionMessageType.Connected:
|
||||
return TopologyEdgeCreated(
|
||||
conn=Connection(
|
||||
source=self.node_id,
|
||||
sink=msg.node_id,
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
async def _convert_connection_message_to_event(self, msg: ConnectionMessage):
|
||||
if msg.connection_type == ConnectionMessageType.Connected:
|
||||
async with self._profiling_lock:
|
||||
if msg.node_id in self._profiling_in_progress:
|
||||
return None
|
||||
self._profiling_in_progress.add(msg.node_id)
|
||||
return await self._profile_and_emit_connection(
|
||||
msg.node_id,
|
||||
msg.remote_ipv4,
|
||||
msg.remote_tcp_port,
|
||||
)
|
||||
elif msg.connection_type == ConnectionMessageType.Disconnected:
|
||||
self._profiling_in_progress.discard(msg.node_id)
|
||||
target_ip = msg.remote_ipv4
|
||||
for connection in self.state.topology.list_connections():
|
||||
if (
|
||||
isinstance(connection.edge, SocketConnection)
|
||||
and connection.edge.sink_multiaddr.ip_address == target_ip
|
||||
):
|
||||
return TopologyEdgeDeleted(conn=connection)
|
||||
|
||||
case ConnectionMessageType.Disconnected:
|
||||
return TopologyEdgeDeleted(
|
||||
conn=Connection(
|
||||
source=self.node_id,
|
||||
sink=msg.node_id,
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
|
||||
),
|
||||
),
|
||||
async def _profile_and_emit_connection(
|
||||
self,
|
||||
sink_node_id: NodeId,
|
||||
remote_ip: str,
|
||||
remote_port: int,
|
||||
):
|
||||
try:
|
||||
profile = await profile_connection(remote_ip)
|
||||
except ConnectionError as e:
|
||||
logger.warning(f"Failed to profile connection to {sink_node_id}: {e}")
|
||||
profile = None
|
||||
|
||||
if profile:
|
||||
latency_ms = profile.latency_ms
|
||||
other_to_sink_mbps = profile.upload_mbps
|
||||
sink_to_other_mbps = profile.download_mbps
|
||||
else:
|
||||
latency_ms = 0.0
|
||||
other_to_sink_mbps = 0.0
|
||||
sink_to_other_mbps = 0.0
|
||||
|
||||
logger.info(
|
||||
f"Connection to {sink_node_id} profiled: "
|
||||
f"latency={latency_ms:.2f}ms, "
|
||||
f"upload={other_to_sink_mbps:.1f}Mbps, "
|
||||
f"download={sink_to_other_mbps:.1f}Mbps"
|
||||
)
|
||||
|
||||
return TopologyEdgeCreated(
|
||||
conn=Connection(
|
||||
source=self.node_id,
|
||||
sink=sink_node_id,
|
||||
edge=SocketConnection(
|
||||
sink_multiaddr=Multiaddr(
|
||||
address=f"/ip4/{remote_ip}/tcp/{remote_port}"
|
||||
),
|
||||
)
|
||||
latency_ms=latency_ms,
|
||||
sink_to_other_bandwidth_mbps=sink_to_other_mbps,
|
||||
other_to_sink_bandwidth_mbps=other_to_sink_mbps,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
async def _nack_request(self, since_idx: int) -> None:
|
||||
# We request all events after (and including) the missing index.
|
||||
@@ -389,21 +429,21 @@ class Worker:
|
||||
)
|
||||
for nid in conns:
|
||||
for ip in conns[nid]:
|
||||
edge = SocketConnection(
|
||||
# nonsense multiaddr
|
||||
temp_edge = SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
|
||||
if "." in ip
|
||||
# nonsense multiaddr
|
||||
else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
|
||||
latency_ms=0.0,
|
||||
sink_to_other_bandwidth_mbps=0.0,
|
||||
other_to_sink_bandwidth_mbps=0.0,
|
||||
)
|
||||
if edge not in edges:
|
||||
logger.debug(f"ping discovered {edge=}")
|
||||
await self.event_sender.send(
|
||||
TopologyEdgeCreated(
|
||||
conn=Connection(
|
||||
source=self.node_id, sink=nid, edge=edge
|
||||
)
|
||||
)
|
||||
if temp_edge not in edges:
|
||||
logger.debug(f"ping discovered new connection to {nid} at {ip}")
|
||||
self._tg.start_soon(
|
||||
self._profile_and_emit_connection,
|
||||
nid,
|
||||
ip,
|
||||
52415,
|
||||
)
|
||||
|
||||
for conn in self.state.topology.out_edges(self.node_id):
|
||||
|
||||
Reference in New Issue
Block a user