diff --git a/core/src/p2p/p2p_manager.rs b/core/src/p2p/p2p_manager.rs index 8d7a9ed28..63a80c39a 100644 --- a/core/src/p2p/p2p_manager.rs +++ b/core/src/p2p/p2p_manager.rs @@ -227,7 +227,7 @@ impl P2PManager { let mut len_buf = len.to_le_bytes(); debug_assert_eq!(len_buf.len(), 4); - head_buf.extend_from_slice(&mut len_buf); + head_buf.extend_from_slice(&len_buf); head_buf.append(&mut buf); self.manager.broadcast(head_buf).await; diff --git a/crates/p2p/src/manager.rs b/crates/p2p/src/manager.rs index 430889af8..674fcd656 100644 --- a/crates/p2p/src/manager.rs +++ b/crates/p2p/src/manager.rs @@ -1,25 +1,21 @@ -use std::{ - collections::{HashMap, HashSet}, - net::SocketAddr, - sync::{atomic::AtomicBool, Arc}, -}; +use std::{collections::HashSet, net::SocketAddr, sync::Arc}; use libp2p::{core::muxing::StreamMuxerBox, quic, Swarm, Transport}; use thiserror::Error; -use tokio::sync::{mpsc, oneshot, RwLock}; +use tokio::sync::{mpsc, oneshot}; use tracing::{debug, error, warn}; use crate::{ spacetime::{SpaceTime, UnicastStream}, - AsyncFn, DiscoveredPeer, Keypair, ManagerStream, ManagerStreamAction, Mdns, Metadata, PeerId, + AsyncFn, DiscoveredPeer, Keypair, ManagerStream, ManagerStreamAction, Mdns, MdnsState, + Metadata, PeerId, }; /// Is the core component of the P2P system that holds the state and delegates actions to the other components #[derive(Debug)] pub struct Manager { + pub(crate) mdns_state: Arc>, pub(crate) peer_id: PeerId, - pub(crate) listen_addrs: RwLock>, - pub(crate) discovered: RwLock>>, pub(crate) application_name: &'static [u8], event_stream_tx: mpsc::Sender>, } @@ -40,17 +36,19 @@ impl Manager { .then_some(()) .ok_or(ManagerError::InvalidAppName)?; + let peer_id = PeerId(keypair.public().to_peer_id()); let (event_stream_tx, event_stream_rx) = mpsc::channel(1024); + + let (mdns, mdns_state) = Mdns::new(application_name, peer_id, fn_get_metadata).unwrap(); let this = Arc::new(Self { + mdns_state, // Look this is bad but it's hard to avoid. Technically a memory leak but it's a small amount of memory and is should done on startup on the P2P system. application_name: Box::leak(Box::new( format!("/{}/spacetime/1.0.0", application_name) .as_bytes() .to_vec(), )), - peer_id: PeerId(keypair.public().to_peer_id()), - listen_addrs: RwLock::new(Default::default()), - discovered: RwLock::new(Default::default()), + peer_id, event_stream_tx, }); @@ -77,11 +75,11 @@ impl Manager { Ok(( this.clone(), ManagerStream { - manager: this.clone(), + manager: this, event_stream_rx, swarm, - mdns: Mdns::new(this, application_name, fn_get_metadata).unwrap(), - is_advertisement_queued: AtomicBool::new(false), + mdns, + queued_events: Default::default(), }, )) } @@ -98,11 +96,17 @@ impl Manager { } pub async fn listen_addrs(&self) -> HashSet { - self.listen_addrs.read().await.clone() + self.mdns_state.listen_addrs.read().await.clone() } pub async fn get_discovered_peers(&self) -> Vec> { - self.discovered.read().await.values().cloned().collect() + self.mdns_state + .discovered + .read() + .await + .values() + .cloned() + .collect() } pub async fn get_connected_peers(&self) -> Result, ()> { diff --git a/crates/p2p/src/manager_stream.rs b/crates/p2p/src/manager_stream.rs index eadeaae2d..484db9037 100644 --- a/crates/p2p/src/manager_stream.rs +++ b/crates/p2p/src/manager_stream.rs @@ -1,11 +1,4 @@ -use std::{ - fmt, - net::SocketAddr, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; +use std::{collections::VecDeque, fmt, net::SocketAddr, sync::Arc}; use libp2p::{ futures::StreamExt, @@ -13,7 +6,7 @@ use libp2p::{ dial_opts::{DialOpts, PeerCondition}, NetworkBehaviourAction, NotifyHandler, SwarmEvent, }, - Multiaddr, Swarm, + Swarm, }; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, error, warn}; @@ -63,7 +56,7 @@ where pub(crate) event_stream_rx: mpsc::Receiver>, pub(crate) swarm: Swarm>, pub(crate) mdns: Mdns, - pub(crate) is_advertisement_queued: AtomicBool, + pub(crate) queued_events: VecDeque>, } impl ManagerStream @@ -75,8 +68,12 @@ where pub async fn next(&mut self) -> Option> { // We loop polling internal services until an event comes in that needs to be sent to the parent application. loop { + if let Some(event) = self.queued_events.pop_front() { + return Some(event); + } + tokio::select! { - event = self.mdns.poll() => { + event = self.mdns.poll(&self.manager) => { if let Some(event) = event { return Some(event); } @@ -101,16 +98,11 @@ where SwarmEvent::IncomingConnectionError { local_addr, error, .. } => warn!("handshake error with incoming connection from '{}': {}", local_addr, error), SwarmEvent::OutgoingConnectionError { peer_id, error } => warn!("error establishing connection with '{:?}': {}", peer_id, error), SwarmEvent::BannedPeer { peer_id, .. } => warn!("banned peer '{}' attempted to connection and was rejected", peer_id), - SwarmEvent::NewListenAddr{ address, .. } => { + SwarmEvent::NewListenAddr { address, .. } => { match quic_multiaddr_to_socketaddr(address) { Ok(addr) => { debug!("listen address added: {}", addr); - self.manager.listen_addrs.write().await.insert(addr); - if !self.is_advertisement_queued.load(Ordering::Relaxed) { - self.is_advertisement_queued.store(true, Ordering::Relaxed); - self.mdns.advertise(); - } - self.mdns.advertise(); + self.mdns.register_addr(addr).await; return Some(Event::AddListenAddr(addr)); }, Err(err) => { @@ -120,8 +112,12 @@ where } }, SwarmEvent::ExpiredListenAddr { address, .. } => { - match Self::unregister_addr(&self.manager, &self.mdns, &self.is_advertisement_queued, address).await { - Ok(_) => {}, + match quic_multiaddr_to_socketaddr(address) { + Ok(addr) => { + debug!("listen address added: {}", addr); + self.mdns.unregister_addr(&addr).await; + return Some(Event::RemoveListenAddr(addr)); + }, Err(err) => { warn!("error passing listen address: {}", err); continue; @@ -131,14 +127,21 @@ where SwarmEvent::ListenerClosed { listener_id, addresses, reason } => { debug!("listener '{:?}' was closed due to: {:?}", listener_id, reason); for address in addresses { - match Self::unregister_addr(&self.manager, &self.mdns, &self.is_advertisement_queued, address).await { - Ok(_) => {}, + match quic_multiaddr_to_socketaddr(address) { + Ok(addr) => { + debug!("listen address added: {}", addr); + self.mdns.unregister_addr(&addr).await; + + self.queued_events.push_back(Event::RemoveListenAddr(addr)); + }, Err(err) => { warn!("error passing listen address: {}", err); continue; } } } + + // The `loop` will restart and begin returning the events from `queued_events`. } SwarmEvent::ListenerError { listener_id, error } => warn!("listener '{:?}' reported a non-fatal error: {}", listener_id, error), SwarmEvent::Dialing(_peer_id) => {}, @@ -191,13 +194,13 @@ where ); } ManagerStreamAction::BroadcastData(data) => { - let connected_peers = self.swarm.connected_peers().map(|v| *v).collect::>(); + let connected_peers = self.swarm.connected_peers().copied().collect::>(); let behaviour = self.swarm.behaviour_mut(); for peer_id in connected_peers { behaviour .pending_events .push_back(NetworkBehaviourAction::NotifyHandler { - peer_id: peer_id, + peer_id, handler: NotifyHandler::Any, event: OutboundRequest::Broadcast(data.clone()), }); @@ -205,28 +208,6 @@ where } } - return None; - } - - // TODO: Move into mdns - async fn unregister_addr( - manager: &Arc>, - mdns: &Mdns, - is_advertisement_queued: &AtomicBool, - address: Multiaddr, - ) -> Result, String> { - match quic_multiaddr_to_socketaddr(address) { - Ok(addr) => { - debug!("listen address removed: {}", addr); - manager.listen_addrs.write().await.remove(&addr); - let _ = mdns.unregister_mdns(); - if !is_advertisement_queued.load(Ordering::Relaxed) { - is_advertisement_queued.store(true, Ordering::Relaxed); - mdns.advertise(); - } - Ok(Event::RemoveListenAddr(addr)) - } - Err(err) => Err(err), - } + None } } diff --git a/crates/p2p/src/mdns.rs b/crates/p2p/src/mdns.rs index 4bec9c5b5..2363a279f 100644 --- a/crates/p2p/src/mdns.rs +++ b/crates/p2p/src/mdns.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, net::{IpAddr, SocketAddr}, pin::Pin, str::FromStr, @@ -8,7 +8,10 @@ use std::{ }; use mdns_sd::{ServiceDaemon, ServiceEvent, ServiceInfo}; -use tokio::time::{sleep_until, Instant, Sleep}; +use tokio::{ + sync::RwLock, + time::{sleep_until, Instant, Sleep}, +}; use tracing::{debug, error, warn}; use crate::{AsyncFn, DiscoveredPeer, Event, Manager, Metadata, PeerId}; @@ -16,18 +19,27 @@ use crate::{AsyncFn, DiscoveredPeer, Event, Manager, Metadata, PeerId}; /// TODO const MDNS_READVERTISEMENT_INTERVAL: Duration = Duration::from_secs(60); // Every minute re-advertise +/// TODO +#[derive(Debug)] +pub struct MdnsState { + pub discovered: RwLock>>, + pub listen_addrs: RwLock>, +} + /// TODO pub struct Mdns where TMetadata: Metadata, TMetadataFn: AsyncFn, { - manager: Arc>, + // used to ignore events from our own mdns advertisement + peer_id: PeerId, fn_get_metadata: TMetadataFn, mdns_daemon: ServiceDaemon, mdns_service_receiver: flume::Receiver, service_name: String, next_mdns_advertisement: Pin>, + state: Arc>, } impl Mdns @@ -36,10 +48,10 @@ where TMetadataFn: AsyncFn, { pub fn new( - manager: Arc>, application_name: &'static str, + peer_id: PeerId, fn_get_metadata: TMetadataFn, - ) -> Result + ) -> Result<(Self, Arc>), mdns_sd::Error> where TMetadataFn: AsyncFn, { @@ -47,88 +59,81 @@ where let service_name = format!("_{}._udp.local.", application_name); let mdns_service_receiver = mdns_daemon.browse(&service_name)?; - let this = Self { - manager, - fn_get_metadata, - mdns_daemon, - mdns_service_receiver, - service_name, - next_mdns_advertisement: Box::pin(sleep_until( - Instant::now() + MDNS_READVERTISEMENT_INTERVAL, - )), - }; - this.advertise(); - Ok(this) + let state = Arc::new(MdnsState { + discovered: RwLock::new(Default::default()), + listen_addrs: RwLock::new(Default::default()), + }); + Ok(( + Self { + peer_id, + fn_get_metadata, + mdns_daemon, + mdns_service_receiver, + service_name, + next_mdns_advertisement: Box::pin(sleep_until(Instant::now())), // Trigger an advertisement immediately + state: state.clone(), + }, + state, + )) } pub fn unregister_mdns(&self) -> mdns_sd::Result> { self.mdns_daemon - .unregister(&format!("{}.{}", self.manager.peer_id, self.service_name)) + .unregister(&format!("{}.{}", self.peer_id, self.service_name)) } - /// Do an mdns advertisement to the network - pub fn advertise(&self) { - // TODO: Instead of spawning maybe do this as part of the polling loop to avoid needing persitent reference to manager. - let manager = self.manager.clone(); - let service_name = self.service_name.clone(); - // let fn_get_metadata = self.fn_get_metadata.clone(); - let mdns_daemon = self.mdns_daemon.clone(); + /// Do an mdns advertisement to the network. + async fn advertise(&mut self) { + let metadata = (self.fn_get_metadata)().await.to_hashmap(); - let metadata_fut = (self.fn_get_metadata)(); + // This is in simple terms converts from `Vec<(ip, port)>` to `Vec<(Vec, port)>` + let mut services = HashMap::::new(); + for addr in self.state.listen_addrs.read().await.iter() { + let addr = match addr { + SocketAddr::V4(addr) => addr, + // TODO: Our mdns library doesn't support Ipv6. This code has the infra to support it so once this issue is fixed upstream we can just flip it on. + // Refer to issue: https://github.com/keepsimple1/mdns-sd/issues/61 + SocketAddr::V6(_) => continue, + }; - tokio::spawn(async move { - let metadata = metadata_fut.await.to_hashmap(); - let peer_id = manager.peer_id.0.to_base58(); - - // This is in simple terms converts from `Vec<(ip, port)>` to `Vec<(Vec, port)>` - let mut services = HashMap::::new(); - for addr in manager.listen_addrs.read().await.iter() { - let addr = match addr { - SocketAddr::V4(addr) => addr, - // TODO: Our mdns library doesn't support Ipv6. This code has the infra to support it so once this issue is fixed upstream we can just flip it on. - // Refer to issue: https://github.com/keepsimple1/mdns-sd/issues/61 - SocketAddr::V6(_) => continue, + if let Some(mut service) = services.remove(&addr.port()) { + service.insert_ipv4addr(*addr.ip()); + services.insert(addr.port(), service); + } else { + let service = match ServiceInfo::new( + &self.service_name, + &self.peer_id.to_string(), + &format!("{}.", self.peer_id), + *addr.ip(), + addr.port(), + Some(metadata.clone()), // TODO: Prevent the user defining a value that overflows a DNS record + ) { + Ok(service) => service, + Err(err) => { + warn!("error creating mdns service info: {}", err); + continue; + } }; - - if let Some(mut service) = services.remove(&addr.port()) { - service.insert_ipv4addr(*addr.ip()); - services.insert(addr.port(), service); - } else { - let service = match ServiceInfo::new( - &service_name, - &peer_id, - &format!("{}.", peer_id), - *addr.ip(), - addr.port(), - Some(metadata.clone()), // TODO: Prevent the user defining a value that overflows a DNS record - ) { - Ok(service) => service, - Err(err) => { - warn!("error creating mdns service info: {}", err); - continue; - } - }; - services.insert(addr.port(), service); - } + services.insert(addr.port(), service); } + } - for (_, service) in services.into_iter() { - debug!("advertising mdns service: {:?}", service); - match mdns_daemon.register(service) { - Ok(_) => {} - Err(err) => warn!("error registering mdns service: {}", err), - } + for (_, service) in services.into_iter() { + debug!("advertising mdns service: {:?}", service); + match self.mdns_daemon.register(service) { + Ok(_) => {} + Err(err) => warn!("error registering mdns service: {}", err), } - }); + } + + self.next_mdns_advertisement = + Box::pin(sleep_until(Instant::now() + MDNS_READVERTISEMENT_INTERVAL)); } // TODO: if the channel's sender is dropped will this cause the `tokio::select` in the `manager.rs` to infinitely loop? - pub async fn poll(&mut self) -> Option> { + pub async fn poll(&mut self, manager: &Arc>) -> Option> { tokio::select! { - _ = &mut self.next_mdns_advertisement => { - self.advertise(); - self.next_mdns_advertisement = Box::pin(sleep_until(Instant::now() + MDNS_READVERTISEMENT_INTERVAL)); - } + _ = &mut self.next_mdns_advertisement => self.advertise().await, event = self.mdns_service_receiver.recv_async() => { let event = event.unwrap(); // TODO: Error handling match event { @@ -142,7 +147,7 @@ where match PeerId::from_str(&raw_peer_id) { Ok(peer_id) => { // Prevent discovery of the current peer. - if peer_id == self.manager.peer_id { + if peer_id == self.peer_id { return None; } @@ -156,14 +161,14 @@ where Ok(metadata) => { let peer = { let mut discovered_peers = - self.manager.discovered.write().await; + self.state.discovered.write().await; let peer = if let Some(peer) = discovered_peers.remove(&peer_id) { peer } else { DiscoveredPeer { - manager: self.manager.clone(), + manager: manager.clone(), peer_id, metadata, addresses: info @@ -201,13 +206,13 @@ where match PeerId::from_str(&raw_peer_id) { Ok(peer_id) => { // Prevent discovery of the current peer. - if peer_id == self.manager.peer_id { + if peer_id == self.peer_id { return None; } { let mut discovered_peers = - self.manager.discovered.write().await; + self.state.discovered.write().await; let peer = discovered_peers.remove(&peer_id); return Some(Event::PeerExpired { @@ -229,4 +234,26 @@ where None } + + pub async fn register_addr(&mut self, addr: SocketAddr) { + self.state.listen_addrs.write().await.insert(addr); + + // If the next mdns advertisement is more than 250ms away, then we should queue one closer to now. + // This acts as a debounce for advertisements when many addresses are discovered close to each other (Eg. at startup) + if self.next_mdns_advertisement.deadline() > (Instant::now() + Duration::from_millis(250)) { + self.next_mdns_advertisement = + Box::pin(sleep_until(Instant::now() + Duration::from_millis(200))); + } + } + + pub async fn unregister_addr(&mut self, addr: &SocketAddr) { + self.state.listen_addrs.write().await.remove(addr); + + // If the next mdns advertisement is more than 250ms away, then we should queue one closer to now. + // This acts as a debounce for advertisements when many addresses are discovered close to each other (Eg. at startup) + if self.next_mdns_advertisement.deadline() > (Instant::now() + Duration::from_millis(250)) { + self.next_mdns_advertisement = + Box::pin(sleep_until(Instant::now() + Duration::from_millis(200))); + } + } } diff --git a/crates/p2p/src/spacetime/stream.rs b/crates/p2p/src/spacetime/stream.rs index f7837e1a4..5e415f82b 100644 --- a/crates/p2p/src/spacetime/stream.rs +++ b/crates/p2p/src/spacetime/stream.rs @@ -36,8 +36,9 @@ impl SpaceTimeStream { Self::Broadcast(mut stream) => { if let Some(stream) = stream.0.take() { BroadcastStream::close_inner(stream).await + } else if cfg!(debug_assertions) { + panic!("'BroadcastStream' should never be 'None' here!"); } else { - debug_assert!(true, "'BroadcastStream' should never be 'None' here!"); error!("'BroadcastStream' should never be 'None' here!"); Ok(()) }