Commit: actually got some comms through
(mirror of https://github.com/exo-explore/exo.git)
@@ -41,7 +41,7 @@ class RustConnectionMessage:
     @property
     def endpoint_id(self) -> EndpointId: ...
     @property
-    def current_transport_addrs(self) -> builtins.set[IpAddress]: ...
+    def current_transport_addrs(self) -> typing.Optional[builtins.set[IpAddress]]: ...

 @typing.final
 class RustConnectionReceiver:
@@ -97,7 +97,7 @@ pub struct PyConnectionMessage {
     #[pyo3(get)]
     pub endpoint_id: PyEndpointId,
     #[pyo3(get)]
-    pub current_transport_addrs: BTreeSet<PyIpAddress>,
+    pub current_transport_addrs: Option<BTreeSet<PyIpAddress>>,
 }

 #[gen_stub_pyclass]
@@ -169,16 +169,17 @@ impl PyConnectionReceiver {
             }) => {
                 return Ok(PyConnectionMessage {
                     endpoint_id: endpoint_id.into(),
-                    current_transport_addrs: data
-                        .ip_addrs()
-                        .map(|it| PyIpAddress { inner: it.clone() })
-                        .collect(),
+                    current_transport_addrs: Some(
+                        data.ip_addrs()
+                            .map(|it| PyIpAddress { inner: it.clone() })
+                            .collect(),
+                    ),
                 });
             }
             Some(DiscoveryEvent::Expired { endpoint_id }) => {
                 return Ok(PyConnectionMessage {
                     endpoint_id: endpoint_id.into(),
-                    current_transport_addrs: BTreeSet::new(),
+                    current_transport_addrs: None,
                 });
             }
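Note: the Option on the Rust side becomes typing.Optional in the stub above, giving Python a simple protocol: a (possibly empty) address set means the peer was discovered or refreshed, while None means its discovery record expired. A minimal consumer sketch (illustrative, not part of this commit; it assumes only the endpoint_id and current_transport_addrs attributes shown in the stub):

    # Track live peers from RustConnectionMessage-like objects.
    def apply_message(peers: dict, msg) -> None:
        if msg.current_transport_addrs is None:
            # None signals DiscoveryEvent::Expired: forget the peer.
            peers.pop(msg.endpoint_id, None)
        else:
            # A set (even an empty one) signals discovery or refresh.
            peers[msg.endpoint_id] = msg.current_transport_addrs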
rust/iroh_networking/examples/chatroom.rs (new file, +79 lines)
@@ -0,0 +1,79 @@
+use std::sync::Arc;
+use std::time::Duration;
+
+use iroh::SecretKey;
+use iroh_gossip::api::{Event, Message};
+use iroh_networking::ExoNet;
+use n0_future::StreamExt;
+use tokio::time::sleep;
+use tokio::{io, io::AsyncBufReadExt as _};
+use tracing_subscriber::EnvFilter;
+use tracing_subscriber::filter::LevelFilter;
+
+#[tokio::main]
+async fn main() {
+    let _ = tracing_subscriber::fmt()
+        .with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
+        .try_init();
+
+    // Configure swarm
+    let net = Arc::new(
+        ExoNet::init_iroh(SecretKey::generate(&mut rand::rng()), "chatroom")
+            .await
+            .expect("iroh init shouldn't fail"),
+    );
+    let innet = Arc::clone(&net);
+    let _jh = tokio::spawn(async move { innet.start_auto_dialer().await });
+
+    while net.known_peers.lock().await.is_empty() {
+        sleep(Duration::from_secs(1)).await
+    }
+
+    // Create a Gossipsub topic & subscribe
+    let (send, mut recv) = net
+        .subscribe("chatting")
+        .await
+        .expect("topic shouldn't fail");
+
+    // Read full lines from stdin
+    let mut stdin = io::BufReader::new(io::stdin()).lines();
+    println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
+
+    let jh1 = tokio::spawn(async move {
+        while let Ok(Some(line)) = stdin.next_line().await {
+            if let Err(e) = send.broadcast(line.into()).await {
+                println!("Publish error: {e:?}");
+            }
+        }
+    });
+
+    let jh2 = tokio::spawn(async move {
+        while let Some(Ok(event)) = recv.next().await {
+            match event {
+                // on gossipsub incoming
+                Event::Received(Message {
+                    content,
+                    delivered_from,
+                    ..
+                }) => println!(
+                    "\n\nGot message: '{}' from peer: {delivered_from}\n\n",
+                    String::from_utf8_lossy(&content),
+                ),
+
+                // on discovery
+                Event::NeighborUp(peer_id) => {
+                    println!("\n\nConnected to: {peer_id}\n\n");
+                }
+                Event::NeighborDown(peer_id) => {
+                    eprintln!("\n\nDisconnected from: {peer_id}\n\n");
+                }
+                Event::Lagged => {
+                    eprintln!("\n\nLagged\n\n");
+                }
+            }
+        }
+    });
+    jh1.await.unwrap();
+    jh2.await.unwrap();
+    _jh.await.unwrap();
+}
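Note: assuming the standard Cargo example layout implied by the path above, the demo should be runnable in two terminals with `cargo run --example chatroom` from the iroh_networking crate; each instance blocks until mDNS discovery fills known_peers, then joins the "chatting" topic.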
@@ -1,7 +1,7 @@
 use std::collections::BTreeSet;

 use iroh::{
-    Endpoint, SecretKey,
+    Endpoint, EndpointId, SecretKey,
     discovery::{
         IntoDiscoveryError,
         mdns::{DiscoveryEvent, MdnsDiscovery},
@@ -15,8 +15,9 @@ use iroh_gossip::{
     api::{ApiError, GossipReceiver, GossipSender},
 };

-use n0_error::stack_error;
+use n0_error::{e, stack_error};
 use n0_future::{Stream, StreamExt};
+use tokio::sync::Mutex;

 #[stack_error(derive, add_meta, from_sources)]
 pub enum Error {
@@ -27,14 +28,17 @@ pub enum Error {
     FailedCommunication { source: ApiError },
     #[error("No IP Protocol supported on device")]
     IPNotSupported { source: IntoDiscoveryError },
+    #[error("No peers found before subscribing")]
+    NoPeers,
 }

 #[derive(Debug)]
 pub struct ExoNet {
-    alpn: String,
-    router: Router,
-    gossip: Gossip,
-    mdns: MdnsDiscovery,
+    pub alpn: String,
+    pub router: Router,
+    pub gossip: Gossip,
+    pub mdns: MdnsDiscovery,
+    pub known_peers: Mutex<BTreeSet<EndpointId>>,
 }

 impl ExoNet {
@@ -55,31 +59,47 @@ impl ExoNet {
             router,
             gossip,
             mdns,
+            known_peers: Mutex::new(BTreeSet::new()),
         })
     }

     pub async fn start_auto_dialer(&self) {
-        let mut dialed = BTreeSet::new();
         let mut recv = self.connection_info().await;

         log::info!(
             "Starting auto dialer for id {}",
             self.router.endpoint().id().to_z32()
         );
         while let Some(item) = recv.next().await {
             match item {
                 DiscoveryEvent::Discovered { endpoint_info, .. } => {
-                    if !dialed.contains(&endpoint_info.endpoint_id) {
-                        log::info!("Dialing new peer {}", endpoint_info.endpoint_id.to_z32());
-                        let _ = self
+                    let id = endpoint_info.endpoint_id;
+                    if !self
+                        .known_peers
+                        .lock()
+                        .await
+                        .contains(&endpoint_info.endpoint_id)
+                        && let Ok(conn) = self
                             .router
                             .endpoint()
                             .connect(endpoint_info, self.alpn.as_bytes())
-                            .await;
-                    } else {
-                        dialed.insert(endpoint_info.endpoint_id);
+                            .await
+                        && conn.alpn() == self.alpn.as_bytes()
+                    {
+                        self.known_peers.lock().await.insert(id);
+                        match self.gossip.handle_connection(conn).await {
+                            Ok(()) => log::info!("Successfully dialled"),
+                            Err(_) => log::info!("Failed to dial peer"),
+                        }
                     }
                 }
                 DiscoveryEvent::Expired { endpoint_id } => {
-                    dialed.remove(&endpoint_id);
                     log::info!("Peer expired {}", endpoint_id.to_z32());
+                    self.known_peers.lock().await.remove(&endpoint_id);
                 }
             }
         }
         log::info!("Auto dialer stopping");
     }

     pub async fn connection_info(&self) -> impl Stream<Item = DiscoveryEvent> + Unpin + use<> {
@@ -87,9 +107,15 @@ impl ExoNet {
     }

     pub async fn subscribe(&self, topic: &str) -> Result<(GossipSender, GossipReceiver), Error> {
+        if self.known_peers.lock().await.is_empty() {
+            return Err(e!(Error::NoPeers));
+        }
         Ok(self
             .gossip
-            .subscribe(str_to_topic_id(topic), vec![])
+            .subscribe_and_join(
+                str_to_topic_id(topic),
+                self.known_peers.lock().await.clone().into_iter().collect(),
+            )
             .await?
             .split())
     }
@@ -1,15 +1,16 @@
 from typing import Callable
+from ipaddress import ip_address

 import pytest

 from exo.shared.types.common import NodeId
-from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.profiling import (
     MemoryPerformanceProfile,
     NodePerformanceProfile,
     SystemPerformanceProfile,
 )
 from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
+from exo.routing.connection_message import SocketAddress


 @pytest.fixture
@@ -54,10 +55,12 @@ def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
         send_back_port = port_counter
         port_counter += 1
         return Connection(
-            local_node_id=source_node_id,
-            send_back_node_id=sink_node_id,
-            send_back_multiaddr=Multiaddr(
-                address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
+            source_id=source_node_id,
+            sink_id=sink_node_id,
+            sink_addr=SocketAddress(
+                ip=ip_address(f"169.254.0.{ip_counter}"),
+                port=send_back_port,
+                zone_id=None,
             ),
             connection_profile=ConnectionProfile(
                 throughput=1000, latency=1000, jitter=1000
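Note: the f prefix on the synthesized address is load-bearing; without it, ip_address receives the literal text "169.254.0.{ip_counter}" and raises ValueError. A quick illustrative check:

    from ipaddress import ip_address

    ip_counter = 3
    assert str(ip_address(f"169.254.0.{ip_counter}")) == "169.254.0.3"
    # ip_address("169.254.0.{ip_counter}") would raise ValueError instead.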
@@ -44,7 +44,7 @@ from exo.utils.channels import channel
 @pytest.mark.asyncio
 async def test_master():
     keypair = get_node_id_keypair()
-    node_id = NodeId(keypair.to_peer_id().to_base58())
+    node_id = NodeId(str(keypair.endpoint_id()))
     session_id = SessionId(master_node_id=node_id, election_clock=0)

     ge_sender, global_event_receiver = channel[ForwarderEvent]()
@@ -75,7 +75,7 @@ async def test_master():
     async with anyio.create_task_group() as tg:
         tg.start_soon(master.run)

-        sender_node_id = NodeId(f"{keypair.to_peer_id().to_base58()}_sender")
+        sender_node_id = NodeId(f"{keypair.to_postcard_encoding()}_sender")
         # inject a NodePerformanceProfile event
         logger.info("inject a NodePerformanceProfile event")
         await local_event_sender.send(
@@ -1,4 +1,4 @@
-from typing import Callable
+from ipaddress import ip_address

 import pytest
 from loguru import logger
@@ -14,7 +14,6 @@ from exo.shared.types.events import InstanceCreated, InstanceDeleted
 from exo.shared.types.memory import Memory
 from exo.shared.types.models import ModelId, ModelMetadata
 from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
-from exo.shared.types.topology import Connection, NodeInfo
 from exo.shared.types.worker.instances import (
     Instance,
     InstanceId,
@@ -25,6 +24,7 @@ from exo.shared.types.worker.instances import (
 from exo.shared.types.worker.runners import ShardAssignments
 from exo.shared.types.worker.shards import Sharding

+from exo.master.tests.conftest import create_node, create_connection

 @pytest.fixture
 def topology() -> Topology:
@@ -76,8 +76,6 @@ def test_get_instance_placements_create_instance(
     expected_layers: tuple[int, int, int],
     topology: Topology,
     model_meta: ModelMetadata,
-    create_node: Callable[[int, NodeId | None], NodeInfo],
-    create_connection: Callable[[NodeId, NodeId], Connection],
 ):
     # arrange
     model_meta.n_layers = total_layers
@@ -124,7 +122,6 @@


 def test_get_instance_placements_one_node_exact_fit(
-    create_node: Callable[[int, NodeId | None], NodeInfo],
 ) -> None:
     topology = Topology()
     node_id = NodeId()
@@ -149,7 +146,6 @@


 def test_get_instance_placements_one_node_fits_with_extra_memory(
-    create_node: Callable[[int, NodeId | None], NodeInfo],
 ) -> None:
     topology = Topology()
     node_id = NodeId()
@@ -174,7 +170,6 @@


 def test_get_instance_placements_one_node_not_fit(
-    create_node: Callable[[int, NodeId | None], NodeInfo],
 ) -> None:
     topology = Topology()
     node_id = NodeId()
@@ -237,8 +232,6 @@ def test_get_transition_events_delete_instance(instance: Instance):
 def test_placement_prioritizes_leaf_cycle_with_less_memory(
     topology: Topology,
     model_meta: ModelMetadata,
-    create_node: Callable[[int, NodeId | None], NodeInfo],
-    create_connection: Callable[[NodeId, NodeId], Connection],
 ):
     # Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
     # neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
@@ -316,8 +309,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
 def test_tensor_rdma_backend_connectivity_matrix(
     topology: Topology,
     model_meta: ModelMetadata,
-    create_node: Callable[[int, NodeId | None], NodeInfo],
-    create_connection: Callable[[NodeId, NodeId], Connection],
 ):
     model_meta.n_layers = 12
     model_meta.storage_size.in_bytes = 1500
@@ -332,7 +323,7 @@ def test_tensor_rdma_backend_connectivity_matrix(

     ethernet_interface = NetworkInterfaceInfo(
         name="en0",
-        ip_address="192.168.1.100",
+        ip_address=ip_address("192.168.1.100"),
     )

     assert node_a.node_profile is not None
@@ -347,13 +338,13 @@ def test_tensor_rdma_backend_connectivity_matrix(
     conn_c_b = create_connection(node_id_c, node_id_b)
     conn_a_c = create_connection(node_id_a, node_id_c)

-    assert conn_a_b.send_back_multiaddr is not None
-    assert conn_b_c.send_back_multiaddr is not None
-    assert conn_c_a.send_back_multiaddr is not None
+    assert conn_a_b.sink_addr is not None
+    assert conn_b_c.sink_addr is not None
+    assert conn_c_a.sink_addr is not None

-    assert conn_b_a.send_back_multiaddr is not None
-    assert conn_c_b.send_back_multiaddr is not None
-    assert conn_a_c.send_back_multiaddr is not None
+    assert conn_b_a.sink_addr is not None
+    assert conn_c_b.sink_addr is not None
+    assert conn_a_c.sink_addr is not None

     node_a.node_profile = NodePerformanceProfile(
         model_id="test",
@@ -363,11 +354,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
         network_interfaces=[
             NetworkInterfaceInfo(
                 name="en3",
-                ip_address=conn_c_a.send_back_multiaddr.ip_address,
+                ip_address=conn_c_a.sink_addr.ip,
             ),
             NetworkInterfaceInfo(
                 name="en4",
-                ip_address=conn_b_a.send_back_multiaddr.ip_address,
+                ip_address=conn_b_a.sink_addr.ip,
             ),
             ethernet_interface,
         ],
@@ -381,11 +372,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
         network_interfaces=[
             NetworkInterfaceInfo(
                 name="en3",
-                ip_address=conn_c_b.send_back_multiaddr.ip_address,
+                ip_address=conn_c_b.sink_addr.ip,
             ),
             NetworkInterfaceInfo(
                 name="en4",
-                ip_address=conn_a_b.send_back_multiaddr.ip_address,
+                ip_address=conn_a_b.sink_addr.ip,
             ),
             ethernet_interface,
         ],
@@ -399,11 +390,11 @@ def test_tensor_rdma_backend_connectivity_matrix(
         network_interfaces=[
             NetworkInterfaceInfo(
                 name="en3",
-                ip_address=conn_a_c.send_back_multiaddr.ip_address,
+                ip_address=conn_a_c.sink_addr.ip,
             ),
             NetworkInterfaceInfo(
                 name="en4",
-                ip_address=conn_b_c.send_back_multiaddr.ip_address,
+                ip_address=conn_b_c.sink_addr.ip,
             ),
             ethernet_interface,
         ],
@@ -1,13 +1,14 @@
 import pytest
+from ipaddress import ip_address

 from exo.shared.topology import Topology
-from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.profiling import (
     MemoryPerformanceProfile,
     NodePerformanceProfile,
     SystemPerformanceProfile,
 )
 from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo
+from exo.routing.connection_message import SocketAddress


 @pytest.fixture
@@ -18,9 +19,9 @@ def topology() -> Topology:
 @pytest.fixture
 def connection() -> Connection:
     return Connection(
-        local_node_id=NodeId(),
-        send_back_node_id=NodeId(),
-        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
+        source_id=NodeId(),
+        sink_id=NodeId(),
+        sink_addr=SocketAddress(ip=ip_address("127.0.0.1"), port=1235, zone_id=None),
         connection_profile=ConnectionProfile(
             throughput=1000, latency=1000, jitter=1000
         ),
@@ -64,12 +65,8 @@ def test_add_connection(
     topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
 ):
     # arrange
-    topology.add_node(
-        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
-    )
-    topology.add_node(
-        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
-    )
+    topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
+    topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
     topology.add_connection(connection)

     # act
@@ -83,12 +80,8 @@ def test_update_node_profile(
     topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
 ):
     # arrange
-    topology.add_node(
-        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
-    )
-    topology.add_node(
-        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
-    )
+    topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
+    topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
     topology.add_connection(connection)

     new_node_profile = NodePerformanceProfile(
@@ -103,12 +96,10 @@ def test_update_node_profile(
     )

     # act
-    topology.update_node_profile(
-        connection.local_node_id, node_profile=new_node_profile
-    )
+    topology.update_node_profile(connection.source_id, node_profile=new_node_profile)

     # assert
-    data = topology.get_node_profile(connection.local_node_id)
+    data = topology.get_node_profile(connection.source_id)
     assert data == new_node_profile

@@ -116,21 +107,17 @@ def test_update_connection_profile(
     topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
 ):
     # arrange
-    topology.add_node(
-        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
-    )
-    topology.add_node(
-        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
-    )
+    topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
+    topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
     topology.add_connection(connection)

     new_connection_profile = ConnectionProfile(
         throughput=2000, latency=2000, jitter=2000
     )
     connection = Connection(
-        local_node_id=connection.local_node_id,
-        send_back_node_id=connection.send_back_node_id,
-        send_back_multiaddr=connection.send_back_multiaddr,
+        source_id=connection.source_id,
+        sink_id=connection.sink_id,
+        sink_addr=connection.sink_addr,
         connection_profile=new_connection_profile,
     )

@@ -146,12 +133,8 @@ def test_remove_connection_still_connected(
     topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
 ):
     # arrange
-    topology.add_node(
-        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
-    )
-    topology.add_node(
-        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
-    )
+    topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
+    topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
     topology.add_connection(connection)

     # act
@@ -165,31 +148,23 @@ def test_remove_node_still_connected(
     topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
 ):
     # arrange
-    topology.add_node(
-        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
-    )
-    topology.add_node(
-        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
-    )
+    topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
+    topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
     topology.add_connection(connection)

     # act
-    topology.remove_node(connection.local_node_id)
+    topology.remove_node(connection.source_id)

     # assert
-    assert topology.get_node_profile(connection.local_node_id) is None
+    assert topology.get_node_profile(connection.source_id) is None


 def test_list_nodes(
     topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
 ):
     # arrange
-    topology.add_node(
-        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
-    )
-    topology.add_node(
-        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
-    )
+    topology.add_node(NodeInfo(node_id=connection.source_id, node_profile=node_profile))
+    topology.add_node(NodeInfo(node_id=connection.sink_id, node_profile=node_profile))
     topology.add_connection(connection)

     # act
@@ -199,6 +174,6 @@ def test_list_nodes(
     assert len(nodes) == 2
     assert all(isinstance(node, NodeInfo) for node in nodes)
     assert {node.node_id for node in nodes} == {
-        connection.local_node_id,
-        connection.send_back_node_id,
+        connection.source_id,
+        connection.sink_id,
     }
@@ -27,13 +27,13 @@ class SocketAddress(CamelCaseModel):

 class ConnectionMessage(CamelCaseModel):
     node_id: NodeId
-    ips: set[SocketAddress]
+    ips: set[SocketAddress] | None

     @classmethod
     def from_rust(cls, message: RustConnectionMessage) -> "ConnectionMessage":
         return cls(
             node_id=NodeId(str(message.endpoint_id)),
-            ips=set(
+            ips=None if message.current_transport_addrs is None else set(
                 # TODO: better handle fallible conversion
                 SocketAddress(
                     ip=ip_address(addr.ip_addr()),
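Note: making ips Optional keeps an expired peer (None) distinguishable from a reachable peer with no usable addresses yet (empty set). An illustrative sketch, not part of the commit:

    from exo.routing.connection_message import ConnectionMessage
    from exo.shared.types.common import NodeId

    alive = ConnectionMessage(node_id=NodeId(), ips=set())  # known peer, no addrs yet
    gone = ConnectionMessage(node_id=NodeId(), ips=None)    # discovery record expired
    assert alive.ips is not None and gone.ips is None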
@@ -164,28 +164,36 @@ class Election:
         self._candidates.append(message)

     async def _connection_receiver(self) -> None:
+        current_peers: set[NodeId] = set()
         with self._cm_receiver as connection_messages:
             async for first in connection_messages:
-                # Delay after connection message for time to symmetrically setup
-                await anyio.sleep(0.2)
-                rest = connection_messages.collect()
+                if first.node_id not in current_peers or first.ips is None:
+                    if first.node_id not in current_peers:
+                        current_peers.add(first.node_id)
+                    if first.ips is None:
+                        current_peers.remove(first.node_id)
+                    # Delay after connection message for time to symmetrically setup
+                    await anyio.sleep(0.2)
+                    rest = connection_messages.collect()
+                    for msg in rest:
+                        if msg.node_id not in current_peers:
+                            current_peers.add(msg.node_id)
+                        if msg.ips is None:
+                            current_peers.remove(msg.node_id)

-                logger.debug(
-                    f"Connection messages received: {first} followed by {rest}"
-                )
-                logger.debug(f"Current clock: {self.clock}")
-                # These messages are strictly peer to peer
-                self.clock += 1
-                logger.debug(f"New clock: {self.clock}")
-                assert self._tg is not None
-                candidates: list[ElectionMessage] = []
-                self._candidates = candidates
-                logger.debug("Starting new campaign")
-                self._tg.start_soon(
-                    self._campaign, candidates, DEFAULT_ELECTION_TIMEOUT
-                )
-                logger.debug("Campaign started")
-                logger.debug("Connection message added")
+                    logger.info(
+                        f"Connection messages received: {first} followed by {rest}"
+                    )
+                    logger.info(f"Current clock: {self.clock}")
+                    # These messages are strictly peer to peer
+                    self.clock += 1
+                    logger.info(f"New clock: {self.clock}")
+                    candidates: list[ElectionMessage] = []
+                    self._candidates = candidates
+                    logger.info("Starting new campaign")
+                    assert self._tg is not None
+                    self._tg.start_soon(self._campaign, candidates)
+                    logger.info("Campaign started")

     async def _command_counter(self) -> None:
         with self._co_receiver as commands:
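Note: the bookkeeping above amounts to one small rule applied per message: add a node on first sight, drop it when its ips are None, and treat either case as a membership change that bumps the clock and starts a campaign. A hedged sketch of that rule as a standalone helper (the name is an assumption; discard is used so an expiry for an unknown peer cannot raise):

    def update_peers(current_peers: set, msg) -> bool:
        # Apply one connection message; return True if membership changed.
        changed = msg.node_id not in current_peers or msg.ips is None
        if msg.node_id not in current_peers:
            current_peers.add(msg.node_id)
        if msg.ips is None:
            current_peers.discard(msg.node_id)
        return changed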
@@ -1,7 +1,7 @@
 import pytest
 from anyio import create_task_group, fail_after, move_on_after

-from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
+from exo.routing.connection_message import ConnectionMessage
 from exo.shared.election import Election, ElectionMessage, ElectionResult
 from exo.shared.types.commands import ForwarderCommand, TestCommand
 from exo.shared.types.common import NodeId, SessionId
@@ -330,9 +330,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
         await cm_tx.send(
             ConnectionMessage(
                 node_id=NodeId(),
-                connection_type=ConnectionMessageType.Connected,
-                remote_ipv4="",
-                remote_tcp_port=0,
+                ips=set(),
             )
         )

@@ -23,7 +23,7 @@ def _get_keypair_concurrent_subprocess_task(
     sem.release()
     # wait to be told to begin simultaneous read
     ev.wait()
-    queue.put(get_node_id_keypair().to_protobuf_encoding())
+    queue.put(get_node_id_keypair().to_postcard_encoding())


 def _get_keypair_concurrent(num_procs: int) -> bytes:
@@ -1,7 +1,8 @@
+from ipaddress import ip_address
 from exo.shared.types.common import NodeId
-from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.state import State
 from exo.shared.types.topology import Connection
+from exo.routing.connection_message import SocketAddress


 def test_state_serialization_roundtrip() -> None:
@@ -12,9 +13,9 @@ def test_state_serialization_roundtrip() -> None:
     node_b = NodeId("node-b")

     connection = Connection(
-        local_node_id=node_a,
-        send_back_node_id=node_b,
-        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
+        sink_id=node_a,
+        source_id=node_b,
+        sink_addr=SocketAddress(ip=ip_address("127.0.0.1"), port=5354, zone_id=None),
     )

     state = State()
@@ -17,6 +17,8 @@ def check_connections(

     conns = list(state.topology.list_connections())
     for iface in state.node_profiles[remote_id].network_interfaces:
+        if sockets is None:
+            continue
         for sock in sockets:
             if iface.ip_address == sock.ip:
                 conn = Connection(source_id=local_id, sink_id=remote_id, sink_addr=sock)
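Note: sockets is loop-invariant, so the new guard could equally be hoisted above both loops. A behavior-preserving sketch (the helper name and signature are assumptions; only the fragment above is visible in this hunk):

    def _matching_connections(interfaces, sockets, local_id, remote_id):
        if sockets is None:
            return []
        return [
            Connection(source_id=local_id, sink_id=remote_id, sink_addr=sock)
            for iface in interfaces
            for sock in sockets
            if iface.ip_address == sock.ip
        ]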