From 64ae77ec70f30c8a1e26db2c6d336536520ccf80 Mon Sep 17 00:00:00 2001 From: Ivan Enderlin Date: Wed, 8 Mar 2023 15:05:04 +0100 Subject: [PATCH] feat(sdk): Run SS response handling in a single non-cancellable block. --- crates/matrix-sdk/src/sliding_sync/builder.rs | 30 +++--- crates/matrix-sdk/src/sliding_sync/mod.rs | 98 ++++++++++++------- 2 files changed, 78 insertions(+), 50 deletions(-) diff --git a/crates/matrix-sdk/src/sliding_sync/builder.rs b/crates/matrix-sdk/src/sliding_sync/builder.rs index 6feaa3909..6f680399a 100644 --- a/crates/matrix-sdk/src/sliding_sync/builder.rs +++ b/crates/matrix-sdk/src/sliding_sync/builder.rs @@ -1,7 +1,7 @@ use std::{ collections::BTreeMap, fmt::Debug, - sync::{Arc, Mutex, RwLock as StdRwLock}, + sync::{Mutex, RwLock as StdRwLock}, }; use eyeball::Observable; @@ -289,23 +289,21 @@ impl SlidingSyncBuilder { let rooms = StdRwLock::new(rooms_found); let lists = StdRwLock::new(self.lists); - Ok(SlidingSync { - inner: Arc::new(SlidingSyncInner { - homeserver: self.homeserver, - client, - storage_key: self.storage_key, + Ok(SlidingSync::new(SlidingSyncInner { + homeserver: self.homeserver, + client, + storage_key: self.storage_key, - lists, - rooms, + lists, + rooms, - extensions: Mutex::new(self.extensions).into(), - reset_counter: Default::default(), + extensions: Mutex::new(self.extensions), + reset_counter: Default::default(), - pos: StdRwLock::new(Observable::new(None)), - delta_token: StdRwLock::new(Observable::new(delta_token_inner)), - subscriptions: StdRwLock::new(self.subscriptions), - unsubscribe: Default::default(), - }), - }) + pos: StdRwLock::new(Observable::new(None)), + delta_token: StdRwLock::new(Observable::new(delta_token_inner)), + subscriptions: StdRwLock::new(self.subscriptions), + unsubscribe: Default::default(), + })) } } diff --git a/crates/matrix-sdk/src/sliding_sync/mod.rs b/crates/matrix-sdk/src/sliding_sync/mod.rs index 725a0aec0..567c52ac4 100644 --- a/crates/matrix-sdk/src/sliding_sync/mod.rs +++ b/crates/matrix-sdk/src/sliding_sync/mod.rs @@ -597,6 +597,7 @@ mod list; mod room; use std::{ + borrow::BorrowMut, collections::BTreeMap, fmt::Debug, mem, @@ -614,6 +615,7 @@ use eyeball::Observable; use futures_core::stream::Stream; pub use list::*; use matrix_sdk_base::sync::SyncResponse; +use matrix_sdk_common::locks::Mutex as AsyncMutex; pub use room::*; use ruma::{ api::client::{ @@ -625,6 +627,7 @@ use ruma::{ assign, OwnedRoomId, RoomId, }; use serde::{Deserialize, Serialize}; +use tokio::spawn; use tracing::{debug, error, info_span, instrument, trace, warn, Instrument, Span}; use url::Url; use uuid::Uuid; @@ -641,9 +644,15 @@ use crate::{config::RequestConfig, Client, Result}; const MAXIMUM_SLIDING_SYNC_SESSION_EXPIRATION: u8 = 3; /// The Sliding Sync instance. +/// +/// It is OK to clone this type as much as you need: cloning it is cheap. #[derive(Clone, Debug)] pub struct SlidingSync { + /// The Sliding Sync data. inner: Arc, + + /// A lock to ensure that responses are handled one at a time. + response_handling_lock: Arc>, } #[derive(Debug)] @@ -680,6 +689,10 @@ pub(super) struct SlidingSyncInner { } impl SlidingSync { + pub(super) fn new(inner: SlidingSyncInner) -> Self { + Self { inner: Arc::new(inner), response_handling_lock: Arc::new(AsyncMutex::new(())) } + } + async fn cache_to_storage(&self) -> Result<(), crate::Error> { let Some(storage_key) = self.inner.storage_key.as_ref() else { return Ok(()) }; trace!(storage_key, "Saving to storage for later use"); @@ -979,11 +992,13 @@ impl SlidingSync { async fn sync_once( &self, stream_id: &str, - list_generators: &mut BTreeMap, + list_generators: Arc>>, ) -> Result> { let mut lists = BTreeMap::new(); { + let mut list_generators_lock = list_generators.lock().unwrap(); + let list_generators = list_generators_lock.borrow_mut(); let mut lists_to_remove = Vec::new(); for (name, generator) in list_generators.iter_mut() { @@ -997,10 +1012,10 @@ impl SlidingSync { for list_name in lists_to_remove { list_generators.remove(&list_name); } - } - if list_generators.is_empty() { - return Ok(None); + if list_generators.is_empty() { + return Ok(None); + } } let pos = self.inner.pos.read().unwrap().clone(); @@ -1067,43 +1082,57 @@ impl SlidingSync { // corrupted/incomplete states for Sliding Sync and other parts of // the code. // - // That's why we are running the handling of the response in a blocking - // mode since it cannot be cancelled abruptly. - debug!("Sliding sync response received"); + // That's why we are running the handling of the response in a spawned + // future that cannot be cancelled by anything. + let this = self.clone(); + let stream_id = stream_id.to_owned(); - match &response.txn_id { - None => { - error!(stream_id, "Sliding Sync has received an unexpected response: `txn_id` must match `stream_id`; it's missing"); + // Spawn a new future to ensure that the code inside this future cannot be + // cancelled if this method is cancelled. + spawn(async move { + debug!("Sliding sync response received"); + + // In case the task running this future is detached, we must be + // ensure responses are handled one at a time, hence we lock the + // `response_handling_lock`. + let global_lock = this.response_handling_lock.lock().await; + + match &response.txn_id { + None => { + error!(stream_id, "Sliding Sync has received an unexpected response: `txn_id` must match `stream_id`; it's missing"); + } + + Some(txn_id) if txn_id != &stream_id => { + error!( + stream_id, + txn_id, + "Sliding Sync has received an unexpected response: `txn_id` must match `stream_id`; they differ" + ); + } + + _ => {} } - Some(txn_id) if txn_id != stream_id => { - error!( - stream_id, - txn_id, - "Sliding Sync has received an unexpected response: `txn_id` must match `stream_id`; they differ" - ); - } + // Handle and transform a Sliding Sync Response to a `SyncResponse`. + // + // We may not need the `sync_response` in the future (once `SyncResponse` will + // move to Sliding Sync, i.e. to `v4::Response`), but processing the + // `sliding_sync_response` is vital, so it must be done somewhere; for now it + // happens here. + let sync_response = this.inner.client.process_sliding_sync(&response).await?; - _ => {} - } + debug!("Sliding sync response has been processed"); - // Handle and transform a Sliding Sync Response to a `SyncResponse`. - // - // We may not need the `sync_response` in the future (once `SyncResponse` will - // move to Sliding Sync, i.e. to `v4::Response`), but processing the - // `sliding_sync_response` is vital, so it must be done somewhere; for now it - // happens here. - let sync_response = self.inner.client.process_sliding_sync(&response).await?; + let updates = this.handle_response(response, sync_response, list_generators.lock().unwrap().borrow_mut())?; - debug!("Sliding sync response has been processed"); + this.cache_to_storage().await?; - let updates = self.handle_response(response, sync_response, list_generators)?; + drop(global_lock); - self.cache_to_storage().await?; + debug!("Sliding sync response has been handled"); - debug!("Sliding sync response has been handled"); - - Ok(Some(updates)) + Ok(Some(updates)) + }).await.unwrap() } /// Create a _new_ Sliding Sync stream. @@ -1113,7 +1142,7 @@ impl SlidingSync { #[instrument(name = "sync_stream", skip_all, parent = &self.inner.client.root_span)] pub fn stream(&self) -> impl Stream> + '_ { // Collect all the lists that need to be updated. - let mut list_generators = { + let list_generators = { let mut list_generators = BTreeMap::new(); let lock = self.inner.lists.read().unwrap(); @@ -1129,6 +1158,7 @@ impl SlidingSync { debug!(?self.inner.extensions, stream_id, "About to run the sync stream"); let instrument_span = Span::current(); + let list_generators = Arc::new(Mutex::new(list_generators)); async_stream::stream! { loop { @@ -1138,7 +1168,7 @@ impl SlidingSync { debug!(?self.inner.extensions, "Sync stream loop is running"); }); - match self.sync_once(&stream_id, &mut list_generators).instrument(sync_span.clone()).await { + match self.sync_once(&stream_id, list_generators.clone()).instrument(sync_span.clone()).await { Ok(Some(updates)) => { self.inner.reset_counter.store(0, Ordering::SeqCst);