diff --git a/README.md b/README.md index 06b8893eb..216ad95cf 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ The rust-sdk consists of multiple crates that can be picked at your convenience: ## Minimum Supported Rust Version (MSRV) -These crates are built with the Rust language version 2021 and require a minimum compiler version of `1.65`. +These crates are built with the Rust language version 2021 and require a minimum compiler version of `1.70`. ## Status diff --git a/bindings/matrix-sdk-ffi/src/lib.rs b/bindings/matrix-sdk-ffi/src/lib.rs index c448528d1..b262de6f4 100644 --- a/bindings/matrix-sdk-ffi/src/lib.rs +++ b/bindings/matrix-sdk-ffi/src/lib.rs @@ -41,12 +41,12 @@ pub use matrix_sdk::ruma::{api::client::account::register, UserId}; pub use matrix_sdk_ui::timeline::PaginationOutcome; pub use platform::*; +// Re-exports for more convenient use inside other submodules +use self::error::ClientError; pub use self::{ authentication_service::*, client::*, event::*, notification::*, room::*, room_member::*, session_verification::*, sliding_sync::*, timeline::*, tracing::*, }; -// Re-exports for more convenient use inside other submodules -use self::{client::Client, error::ClientError}; uniffi::include_scaffolding!("api"); diff --git a/bindings/matrix-sdk-ffi/src/room.rs b/bindings/matrix-sdk-ffi/src/room.rs index aae310e9f..6100d7172 100644 --- a/bindings/matrix-sdk-ffi/src/room.rs +++ b/bindings/matrix-sdk-ffi/src/room.rs @@ -28,7 +28,7 @@ use matrix_sdk::{ }; use matrix_sdk_ui::timeline::{RoomExt, Timeline}; use mime::Mime; -use tracing::error; +use tracing::{error, info}; use super::RUNTIME; use crate::{ @@ -780,6 +780,22 @@ impl Room { } }); } + + pub fn cancel_send(&self, txn_id: String) { + let timeline = match &*self.timeline.read().unwrap() { + Some(t) => Arc::clone(t), + None => { + error!("Timeline not set up, can't retry sending message"); + return; + } + }; + + RUNTIME.spawn(async move { + if !timeline.cancel_send(txn_id.as_str().into()).await { + info!(txn_id, "Failed to discard local echo: Not found"); + } + }); + } } impl Room { diff --git a/crates/matrix-sdk-ui/src/timeline/inner.rs b/crates/matrix-sdk-ui/src/timeline/inner.rs index c3676744f..324134659 100644 --- a/crates/matrix-sdk-ui/src/timeline/inner.rs +++ b/crates/matrix-sdk-ui/src/timeline/inner.rs @@ -361,6 +361,18 @@ impl TimelineInner

{ Some(content) } + pub(super) async fn discard_local_echo(&self, txn_id: &TransactionId) -> bool { + let mut state = self.state.lock().await; + if let Some((idx, _)) = + rfind_event_item(&state.items, |it| it.transaction_id() == Some(txn_id)) + { + state.items.remove(idx); + true + } else { + false + } + } + /// Handle a back-paginated event. /// /// Returns the number of timeline updates that were made. diff --git a/crates/matrix-sdk-ui/src/timeline/mod.rs b/crates/matrix-sdk-ui/src/timeline/mod.rs index 62643577d..0fead1dca 100644 --- a/crates/matrix-sdk-ui/src/timeline/mod.rs +++ b/crates/matrix-sdk-ui/src/timeline/mod.rs @@ -418,6 +418,21 @@ impl Timeline { Ok(()) } + /// Discard a local echo for a message that failed to send. + /// + /// Returns whether the local echo with the given transaction ID was found. + /// + /// # Argument + /// + /// * `txn_id` - The transaction ID of a local echo timeline item that has a + /// `send_state()` of `SendState::FailedToSend { .. }`. *Note:* A send + /// state of `SendState::NotYetSent` might be supported in the future as + /// well, but there can be no guarantee for that actually stopping the + /// event from reaching the server. + pub async fn cancel_send(&self, txn_id: &TransactionId) -> bool { + self.inner.discard_local_echo(txn_id).await + } + /// Fetch unavailable details about the event with the given ID. /// /// This method only works for IDs of remote [`EventTimelineItem`]s, diff --git a/crates/matrix-sdk-ui/tests/integration/timeline/echo.rs b/crates/matrix-sdk-ui/tests/integration/timeline/echo.rs index 1f8f4f808..a6fbd4f69 100644 --- a/crates/matrix-sdk-ui/tests/integration/timeline/echo.rs +++ b/crates/matrix-sdk-ui/tests/integration/timeline/echo.rs @@ -263,3 +263,42 @@ async fn dedup_by_event_id_late() { assert_next_matches!(timeline_stream, VectorDiff::Remove { index: 1 }); assert_next_matches!(timeline_stream, VectorDiff::Remove { index: 0 }); } + +#[async_test] +async fn cancel_failed() { + let room_id = room_id!("!a98sd12bjh:example.org"); + let (client, server) = logged_in_client().await; + let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); + + let mut ev_builder = EventBuilder::new(); + ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id)); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let room = client.get_room(room_id).unwrap(); + let timeline = Arc::new(room.timeline().await); + let (_, mut timeline_stream) = + timeline.subscribe_filter_map(|item| item.as_event().cloned()).await; + + let txn_id: &TransactionId = "my-txn-id".into(); + + timeline.send(RoomMessageEventContent::text_plain("Hello, World!").into(), Some(txn_id)).await; + + // Local echo is added + assert_next_matches!(timeline_stream, VectorDiff::PushBack { value } => { + assert_matches!(value.send_state(), Some(EventSendState::NotSentYet)); + }); + + // Sending fails, the mock server has no matching route + assert_next_matches!(timeline_stream, VectorDiff::Set { index: 0, value } => { + assert_matches!(value.send_state(), Some(EventSendState::SendingFailed { .. })); + }); + + // Discard, assert the local echo is found + assert!(timeline.cancel_send(txn_id).await); + + // Observable local echo being removed + assert_next_matches!(timeline_stream, VectorDiff::Remove { index: 0 }); +} diff --git a/crates/matrix-sdk/src/sliding_sync/builder.rs b/crates/matrix-sdk/src/sliding_sync/builder.rs index 7072c7e33..3d9b07076 100644 --- a/crates/matrix-sdk/src/sliding_sync/builder.rs +++ b/crates/matrix-sdk/src/sliding_sync/builder.rs @@ -12,10 +12,11 @@ use url::Url; use super::{ cache::{format_storage_key_prefix, restore_sliding_sync_state}, + sticky_parameters::SlidingSyncStickyManager, Error, SlidingSync, SlidingSyncInner, SlidingSyncListBuilder, SlidingSyncPositionMarkers, SlidingSyncRoom, }; -use crate::{Client, Result}; +use crate::{sliding_sync::SlidingSyncStickyParameters, Client, Result}; /// Configuration for a Sliding Sync instance. /// @@ -261,6 +262,12 @@ impl SlidingSyncBuilder { let rooms = AsyncRwLock::new(self.rooms); let lists = AsyncRwLock::new(lists); + // Always enable to-device events and the e2ee-extension on the initial request, + // no matter what the caller wants. + let mut extensions = self.extensions.unwrap_or_default(); + extensions.to_device.enabled = Some(true); + extensions.e2ee.enabled = Some(true); + Ok(SlidingSync::new(SlidingSyncInner { _id: Some(self.id), sliding_sync_proxy: self.sliding_sync_proxy, @@ -270,7 +277,6 @@ impl SlidingSyncBuilder { lists, rooms, - extensions: self.extensions.unwrap_or_default(), reset_counter: Default::default(), position: StdRwLock::new(SlidingSyncPositionMarkers { @@ -279,7 +285,9 @@ impl SlidingSyncBuilder { to_device_token, }), - room_subscriptions: StdRwLock::new(self.subscriptions), + sticky: StdRwLock::new(SlidingSyncStickyManager::new( + SlidingSyncStickyParameters::new(self.subscriptions, extensions), + )), room_unsubscriptions: Default::default(), internal_channel: internal_channel_sender, diff --git a/crates/matrix-sdk/src/sliding_sync/list/builder.rs b/crates/matrix-sdk/src/sliding_sync/list/builder.rs index 488c8c34d..365f8de01 100644 --- a/crates/matrix-sdk/src/sliding_sync/list/builder.rs +++ b/crates/matrix-sdk/src/sliding_sync/list/builder.rs @@ -19,10 +19,14 @@ use tokio::sync::broadcast::Sender; use super::{ super::SlidingSyncInternalMessage, Bound, SlidingSyncList, SlidingSyncListCachePolicy, - SlidingSyncListInner, SlidingSyncListRequestGenerator, SlidingSyncMode, SlidingSyncState, + SlidingSyncListInner, SlidingSyncListRequestGenerator, SlidingSyncListStickyParameters, + SlidingSyncMode, SlidingSyncState, }; use crate::{ - sliding_sync::{cache::restore_sliding_sync_list, FrozenSlidingSyncRoom}, + sliding_sync::{ + cache::restore_sliding_sync_list, sticky_parameters::SlidingSyncStickyManager, + FrozenSlidingSyncRoom, + }, Client, RoomListEntry, }; @@ -198,13 +202,17 @@ impl SlidingSyncListBuilder { sync_mode: StdRwLock::new(self.sync_mode.clone()), // From the builder - sort: self.sort, - required_state: self.required_state, - filters: self.filters, - timeline_limit: StdRwLock::new(self.timeline_limit), + sticky: StdRwLock::new(SlidingSyncStickyManager::new( + SlidingSyncListStickyParameters::new( + self.sort, + self.required_state, + self.filters, + self.timeline_limit, + self.bump_event_types, + ), + )), name: self.name, cache_policy: self.cache_policy, - bump_event_types: self.bump_event_types, // Computed from the builder. request_generator: StdRwLock::new(SlidingSyncListRequestGenerator::new( diff --git a/crates/matrix-sdk/src/sliding_sync/list/mod.rs b/crates/matrix-sdk/src/sliding_sync/list/mod.rs index 919554a59..fac0f7780 100644 --- a/crates/matrix-sdk/src/sliding_sync/list/mod.rs +++ b/crates/matrix-sdk/src/sliding_sync/list/mod.rs @@ -2,6 +2,7 @@ mod builder; mod frozen; mod request_generator; mod room_list_entry; +mod sticky; use std::{ collections::HashSet, @@ -20,17 +21,13 @@ use futures_core::Stream; use imbl::Vector; pub(super) use request_generator::*; pub use room_list_entry::RoomListEntry; -use ruma::{ - api::client::sync::sync_events::v4, - assign, - events::{StateEventType, TimelineEventType}, - OwnedRoomId, -}; +use ruma::{api::client::sync::sync_events::v4, assign, OwnedRoomId, TransactionId}; use serde::{Deserialize, Serialize}; use tokio::sync::broadcast::Sender; use tracing::{instrument, warn}; -use super::{Error, SlidingSyncInternalMessage}; +use self::sticky::SlidingSyncListStickyParameters; +use super::{sticky_parameters::SlidingSyncStickyManager, Error, SlidingSyncInternalMessage}; use crate::Result; /// Should this [`SlidingSyncList`] be stored in the cache, and automatically @@ -116,12 +113,12 @@ impl SlidingSyncList { /// Get the timeline limit. pub fn timeline_limit(&self) -> Option { - *self.inner.timeline_limit.read().unwrap() + self.inner.sticky.read().unwrap().data().timeline_limit() } /// Set timeline limit. pub fn set_timeline_limit(&self, timeline: Option) { - *self.inner.timeline_limit.write().unwrap() = timeline; + self.inner.sticky.write().unwrap().data_mut().set_timeline_limit(timeline); } /// Get the current room list. @@ -198,8 +195,11 @@ impl SlidingSyncList { /// /// The next request is entirely calculated based on the request generator /// ([`SlidingSyncListRequestGenerator`]). - pub(super) fn next_request(&mut self) -> Result { - self.inner.next_request() + pub(super) fn next_request( + &self, + txn_id: &TransactionId, + ) -> Result { + self.inner.next_request(txn_id) } /// Returns the current cache policy for this list. @@ -239,6 +239,17 @@ impl SlidingSyncList { Ok(new_changes) } + + /// Commit the set of sticky parameters for this list. + pub fn maybe_commit_sticky(&mut self, txn_id: &TransactionId) { + self.inner.sticky.write().unwrap().maybe_commit(txn_id); + } + + /// Manually invalidate the sticky data, so the sticky parameters are + /// re-sent next time. + pub fn invalidate_sticky_data(&self) { + let _ = self.inner.sticky.write().unwrap().data_mut(); + } } #[cfg(any(test, feature = "testing"))] @@ -266,17 +277,10 @@ pub(super) struct SlidingSyncListInner { /// The state this list is in. state: StdRwLock>, - /// Sort the room list by this. - sort: Vec, - - /// Required states to return per room. - required_state: Vec<(StateEventType, String)>, - - /// Any filters to apply to the query. - filters: Option, - - /// The maximum number of timeline events to query for. - timeline_limit: StdRwLock>, + /// Parameters that are sticky, and can be sent only once per session (until + /// the connection is dropped or the server invalidates what the client + /// knows). + sticky: StdRwLock>, /// The total number of rooms that is possible to interact with for the /// given list. @@ -301,10 +305,6 @@ pub(super) struct SlidingSyncListInner { /// [`SlidingSyncInner::internal_channel`] to learn more. sliding_sync_internal_channel_sender: Sender, - /// The `bump_event_types` field. See - /// [`SlidingSyncListBuilder::bump_event_types`] to learn more. - bump_event_types: Vec, - #[cfg(any(test, feature = "testing"))] sync_mode: StdRwLock, } @@ -343,7 +343,7 @@ impl SlidingSyncListInner { } /// Update the state to the next request, and return it. - fn next_request(&self) -> Result { + fn next_request(&self, txn_id: &TransactionId) -> Result { let ranges = { // Use a dedicated scope to ensure the lock is released before continuing. let mut request_generator = self.request_generator.write().unwrap(); @@ -352,31 +352,24 @@ impl SlidingSyncListInner { }; // Here we go. - Ok(self.request(ranges)) + Ok(self.request(ranges, txn_id)) } /// Build a [`SyncRequestList`][v4::SyncRequestList] based on the current /// state of the request generator. #[instrument(skip(self), fields(name = self.name))] - fn request(&self, ranges: Ranges) -> v4::SyncRequestList { + fn request(&self, ranges: Ranges, txn_id: &TransactionId) -> v4::SyncRequestList { use ruma::UInt; let ranges = ranges.into_iter().map(|r| (UInt::from(*r.start()), UInt::from(*r.end()))).collect(); - let sort = self.sort.clone(); - let required_state = self.required_state.clone(); - let timeline_limit = self.timeline_limit.read().unwrap().map(UInt::from); - let filters = self.filters.clone(); - assign!(v4::SyncRequestList::default(), { - ranges, - room_details: assign!(v4::RoomDetailsConfig::default(), { - required_state, - timeline_limit, - }), - sort, - filters, - bump_event_types: self.bump_event_types.clone(), - }) + let mut request = assign!(v4::SyncRequestList::default(), { ranges }); + { + let mut sticky = self.sticky.write().unwrap(); + sticky.maybe_apply(&mut request, txn_id); + } + + request } /// Update the [`Self::room_list`]. It also updates @@ -975,13 +968,13 @@ mod tests { .timeline_limit(7) .build(sender); - assert_eq!(*list.inner.timeline_limit.read().unwrap(), Some(7)); + assert_eq!(list.inner.sticky.read().unwrap().data().timeline_limit(), Some(7)); list.set_timeline_limit(Some(42)); - assert_eq!(*list.inner.timeline_limit.read().unwrap(), Some(42)); + assert_eq!(list.inner.sticky.read().unwrap().data().timeline_limit(), Some(42)); list.set_timeline_limit(None); - assert_eq!(*list.inner.timeline_limit.read().unwrap(), None); + assert_eq!(list.inner.sticky.read().unwrap().data().timeline_limit(), None); } #[test] @@ -996,7 +989,7 @@ mod tests { let room1 = room_id!("!room1:bar.org"); // Simulate a request. - let _ = list.next_request(); + let _ = list.next_request("tid".into()); // A new response. let sync0: v4::SyncOp = serde_json::from_value(json!({ @@ -1032,7 +1025,7 @@ mod tests { $( { // Generate a new request. - let request = $list.next_request().unwrap(); + let request = $list.next_request("tid".into()).unwrap(); assert_eq!( request.ranges, @@ -1489,7 +1482,7 @@ mod tests { // Initial range. for _ in 0..=1 { // Simulate a request. - let _ = list.next_request(); + let _ = list.next_request("tid".into()); // A new response. let sync: v4::SyncOp = serde_json::from_value(json!({ @@ -1532,7 +1525,7 @@ mod tests { }); // Simulate a request. - let _ = list.next_request(); + let _ = list.next_request("tid".into()); // A new response. let sync: v4::SyncOp = serde_json::from_value(json!({ diff --git a/crates/matrix-sdk/src/sliding_sync/list/request_generator.rs b/crates/matrix-sdk/src/sliding_sync/list/request_generator.rs index 234a499a8..7204a8c72 100644 --- a/crates/matrix-sdk/src/sliding_sync/list/request_generator.rs +++ b/crates/matrix-sdk/src/sliding_sync/list/request_generator.rs @@ -134,7 +134,7 @@ impl SlidingSyncListRequestGenerator { SlidingSyncListRequestGeneratorKind::Paging { fully_loaded: true, .. } | SlidingSyncListRequestGeneratorKind::Growing { fully_loaded: true, .. } | SlidingSyncListRequestGeneratorKind::Selective => { - // Nothing to do: we already have the full ranges, return the existing ranges. + // Nothing to do: we already have the full ranges, return the existing ranges. // For the growing and paging modes, keep the current value of `requested_end`, // which is still valid. Ok(self.ranges.clone()) diff --git a/crates/matrix-sdk/src/sliding_sync/list/sticky.rs b/crates/matrix-sdk/src/sliding_sync/list/sticky.rs new file mode 100644 index 000000000..cc8566a55 --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/list/sticky.rs @@ -0,0 +1,64 @@ +use ruma::{ + api::client::sync::sync_events::v4, + events::{StateEventType, TimelineEventType}, +}; + +use super::Bound; +use crate::sliding_sync::sticky_parameters::StickyData; + +/// The set of `SlidingSyncList` request parameters that are *sticky*, as +/// defined by the [Sliding Sync MSC](https://github.com/matrix-org/matrix-spec-proposals/blob/kegan/sync-v3/proposals/3575-sync.md). +#[derive(Debug)] +pub(super) struct SlidingSyncListStickyParameters { + /// Sort the room list by this. + sort: Vec, + + /// Required states to return per room. + required_state: Vec<(StateEventType, String)>, + + /// Any filters to apply to the query. + filters: Option, + + /// The maximum number of timeline events to query for. + timeline_limit: Option, + + /// The `bump_event_types` field. See + /// [`SlidingSyncListBuilder::bump_event_types`] to learn more. + bump_event_types: Vec, +} + +impl SlidingSyncListStickyParameters { + pub fn new( + sort: Vec, + required_state: Vec<(StateEventType, String)>, + filters: Option, + timeline_limit: Option, + bump_event_types: Vec, + ) -> Self { + // Consider that each list will have at least one parameter set, so invalidate + // it by default. + Self { sort, required_state, filters, timeline_limit, bump_event_types } + } +} + +impl SlidingSyncListStickyParameters { + pub(super) fn timeline_limit(&self) -> Option { + self.timeline_limit + } + + pub(super) fn set_timeline_limit(&mut self, timeline: Option) { + self.timeline_limit = timeline; + } +} + +impl StickyData for SlidingSyncListStickyParameters { + type Request = v4::SyncRequestList; + + fn apply(&self, request: &mut v4::SyncRequestList) { + request.sort = self.sort.to_vec(); + request.room_details.required_state = self.required_state.to_vec(); + request.room_details.timeline_limit = self.timeline_limit.map(Into::into); + request.filters = self.filters.clone(); + request.bump_event_types = self.bump_event_types.clone(); + } +} diff --git a/crates/matrix-sdk/src/sliding_sync/mod.rs b/crates/matrix-sdk/src/sliding_sync/mod.rs index 14d16f27d..1576c9a06 100644 --- a/crates/matrix-sdk/src/sliding_sync/mod.rs +++ b/crates/matrix-sdk/src/sliding_sync/mod.rs @@ -21,6 +21,7 @@ mod client; mod error; mod list; mod room; +mod sticky_parameters; use std::{ collections::{BTreeMap, BTreeSet}, @@ -45,7 +46,7 @@ use ruma::{ error::ErrorKind, sync::sync_events::v4::{self, ExtensionsConfig}, }, - assign, OwnedRoomId, RoomId, + assign, OwnedRoomId, OwnedTransactionId, RoomId, TransactionId, }; use serde::{Deserialize, Serialize}; use tokio::{ @@ -55,6 +56,7 @@ use tokio::{ use tracing::{debug, error, instrument, warn, Instrument, Span}; use url::Url; +use self::sticky_parameters::{SlidingSyncStickyManager, StickyData}; use crate::{config::RequestConfig, Client, Result}; /// Number of times a Sliding Sync session can expire before raising an error. @@ -103,9 +105,8 @@ pub(super) struct SlidingSyncInner { /// The rooms details rooms: AsyncRwLock>, - /// Room subscriptions, i.e. rooms that may be out-of-scope of all lists but - /// one wants to receive updates. - room_subscriptions: StdRwLock>, + /// Request parameters that are sticky. + sticky: StdRwLock>, /// Rooms to unsubscribe, see [`Self::room_subscriptions`]. room_unsubscriptions: StdRwLock>, @@ -113,10 +114,6 @@ pub(super) struct SlidingSyncInner { /// Number of times a Sliding Sync session has been reset. reset_counter: AtomicU8, - /// Static configuration for extensions, passed in the slidinc sync - /// requests. - extensions: ExtensionsConfig, - /// Internal channel used to pass messages between Sliding Sync and other /// types. internal_channel: Sender, @@ -139,9 +136,11 @@ impl SlidingSync { /// Subscribe to a given room. pub fn subscribe_to_room(&self, room_id: OwnedRoomId, settings: Option) { self.inner - .room_subscriptions + .sticky .write() .unwrap() + .data_mut() + .room_subscriptions .insert(room_id, settings.unwrap_or_default()); self.inner.internal_channel_send_if_possible( @@ -151,8 +150,15 @@ impl SlidingSync { /// Unsubscribe from a given room. pub fn unsubscribe_from_room(&self, room_id: OwnedRoomId) { - // If removing the subscription was successful… - if self.inner.room_subscriptions.write().unwrap().remove(&room_id).is_some() { + // Note: we don't use `BTreeMap::remove` here, because that would require + // mutable access thus calling `data_mut()`, which in turn would + // invalidate the sticky parameters even if the `room_id` wasn't in the + // mapping. + + // If there's a subscription… + if self.inner.sticky.read().unwrap().data().room_subscriptions.contains_key(&room_id) { + // Remove it… + self.inner.sticky.write().unwrap().data_mut().room_subscriptions.remove(&room_id); // … then keep the unsubscription for the next request. self.inner.room_unsubscriptions.write().unwrap().insert(room_id); @@ -172,11 +178,6 @@ impl SlidingSync { self.inner.rooms.blocking_read().len() } - #[instrument(skip(self))] - fn update_to_device_since(&self, since: String) { - self.inner.position.write().unwrap().to_device_token = Some(since); - } - /// Find a list by its name, and do something on it if it exists. pub async fn on_list( &self, @@ -260,25 +261,6 @@ impl SlidingSync { self.inner.rooms.read().await.values().cloned().collect() } - fn prepare_extension_config(&self, pos: Option<&str>) -> ExtensionsConfig { - let mut extensions = self.inner.extensions.clone(); - - if pos.is_none() { - // The pos is `None`, it's either our initial sync or the proxy forgot about us - // and sent us an `UnknownPos` error. We need to send out the config for our - // extensions. - extensions.e2ee.enabled = Some(true); - extensions.to_device.enabled = Some(true); - } - - // Try to chime in a to-device token that may be unset or restored from the - // cache. - let to_device_since = self.inner.position.read().unwrap().to_device_token.clone(); - extensions.to_device.since = to_device_since; - - extensions - } - /// Handle the HTTP response. #[instrument(skip_all)] async fn handle_response( @@ -306,6 +288,17 @@ impl SlidingSync { let mut position_lock = self.inner.position.write().unwrap(); position_lock.pos = Some(sliding_sync_response.pos); position_lock.delta_token = sliding_sync_response.delta_token; + if let Some(to_device) = sliding_sync_response.extensions.to_device { + position_lock.to_device_token = Some(to_device.next_batch); + } + } + + // Commit sticky parameters, if needed. + if let Some(ref txn_id) = sliding_sync_response.txn_id { + let txn_id = txn_id.as_str().into(); + self.inner.sticky.write().unwrap().maybe_commit(txn_id); + let mut lists = self.inner.lists.write().await; + lists.values_mut().for_each(|list| list.maybe_commit_sticky(txn_id)); } let update_summary = { @@ -372,11 +365,6 @@ impl SlidingSync { } } - // Update the `to-device` next-batch if any. - if let Some(to_device) = sliding_sync_response.extensions.to_device { - self.update_to_device_since(to_device.next_batch); - } - updated_lists }; @@ -386,53 +374,69 @@ impl SlidingSync { Ok(update_summary) } + async fn generate_sync_request( + &self, + txn_id: OwnedTransactionId, + ) -> Result<(v4::Request, RequestConfig, BTreeSet)> { + // Collect requests for lists. + let mut requests_lists = BTreeMap::new(); + + { + let lists = self.inner.lists.read().await; + + for (name, list) in lists.iter() { + requests_lists.insert(name.clone(), list.next_request(&txn_id)?); + } + } + + // Collect the `pos` and `delta_token`. + let (pos, delta_token) = { + let position_lock = self.inner.position.read().unwrap(); + + (position_lock.pos.clone(), position_lock.delta_token.clone()) + }; + + Span::current().record("pos", &pos); + + // Collect other data. + let room_unsubscriptions = self.inner.room_unsubscriptions.read().unwrap().clone(); + let timeout = Duration::from_secs(30); + + let mut request = assign!(v4::Request::new(), { + txn_id: Some(txn_id.to_string()), + pos, + delta_token, + timeout: Some(timeout), + lists: requests_lists, + unsubscribe_rooms: room_unsubscriptions.iter().cloned().collect(), + }); + + { + let mut sticky_params = self.inner.sticky.write().unwrap(); + + sticky_params.maybe_apply(&mut request, &txn_id); + + // Set the to_device token if the extension is enabled. + if sticky_params.data().extensions.to_device.enabled == Some(true) { + request.extensions.to_device.since = + self.inner.position.read().unwrap().to_device_token.clone(); + } + } + + Ok(( + // The request itself. + request, + // Configure long-polling. We need 30 seconds for the long-poll itself, in + // addition to 30 more extra seconds for the network delays. + RequestConfig::default().timeout(timeout + Duration::from_secs(30)), + room_unsubscriptions, + )) + } + #[instrument(skip_all, fields(pos))] async fn sync_once(&self) -> Result { - let (request, request_config, requested_room_unsubscriptions) = { - // Collect requests for lists. - let mut requests_lists = BTreeMap::new(); - - { - let mut lists = self.inner.lists.write().await; - - for (name, list) in lists.iter_mut() { - requests_lists.insert(name.clone(), list.next_request()?); - } - } - - // Collect the `pos` and `delta_token`. - let (pos, delta_token) = { - let position_lock = self.inner.position.read().unwrap(); - - (position_lock.pos.clone(), position_lock.delta_token.clone()) - }; - - Span::current().record("pos", &pos); - - // Collect other data. - let room_subscriptions = self.inner.room_subscriptions.read().unwrap().clone(); - let room_unsubscriptions = self.inner.room_unsubscriptions.read().unwrap().clone(); - let timeout = Duration::from_secs(30); - let extensions = self.prepare_extension_config(pos.as_deref()); - - ( - // Build the request itself. - assign!(v4::Request::new(), { - // conn_id: self.inner.id.clone(), - pos, - delta_token, - timeout: Some(timeout), - lists: requests_lists, - room_subscriptions, - unsubscribe_rooms: room_unsubscriptions.iter().cloned().collect(), - extensions, - }), - // Configure long-polling. We need 30 seconds for the long-poll itself, in - // addition to 30 more extra seconds for the network delays. - RequestConfig::default().timeout(timeout + Duration::from_secs(30)), - room_unsubscriptions, - ) - }; + let (request, request_config, requested_room_unsubscriptions) = + self.generate_sync_request(TransactionId::new()).await?; debug!("Sending the sliding sync request"); @@ -532,7 +536,7 @@ impl SlidingSync { #[allow(unknown_lints, clippy::let_with_type_underscore)] // triggered by instrument macro #[instrument(name = "sync_stream", skip_all)] pub fn sync(&self) -> impl Stream> + '_ { - debug!(?self.inner.extensions, ?self.inner.position, "About to run the sync-loop"); + debug!(?self.inner.position, "About to run the sync-loop"); let sync_span = Span::current(); let mut internal_channel_receiver = self.inner.internal_channel.subscribe(); @@ -540,7 +544,7 @@ impl SlidingSync { stream! { loop { sync_span.in_scope(|| { - debug!(?self.inner.extensions, ?self.inner.position,"Sync-loop is running"); + debug!(?self.inner.position, "Sync-loop is running"); }); select! { @@ -592,17 +596,22 @@ impl SlidingSync { } // Let's reset the Sliding Sync session. - sync_span.in_scope(|| { + sync_span.in_scope(|| async { warn!("Session expired. Restarting Sliding Sync."); - // To “restart” a Sliding Sync session, we set `pos` to its initial value. + // To “restart” a Sliding Sync session, we set `pos` to its initial value, and uncommit the sticky parameters, so they're sent next time. { let mut position_lock = self.inner.position.write().unwrap(); position_lock.pos = None; } - debug!(?self.inner.extensions, ?self.inner.position, "Sliding Sync has been reset"); - }); + // Force invalidation of all the sticky parameters. + let _ = self.inner.sticky.write().unwrap().data_mut(); + + self.inner.lists.read().await.values().for_each(|list| list.invalidate_sticky_data()); + + debug!(?self.inner.position, "Sliding Sync has been reset"); + }).await; continue; } @@ -676,8 +685,21 @@ impl SlidingSync { #[derive(Debug)] pub(super) struct SlidingSyncPositionMarkers { + /// An ephemeral position in the current stream, as received from the + /// previous `/sync` response, or `None` for the first request. + /// + /// Should not be persisted. pos: Option, + + /// Server-provided opaque token that remembers what the last timeline and + /// state events stored by the client were. + /// + /// If `None`, the server will send the full information for all the lists + /// present in the request. delta_token: Option, + + /// Server-provided opaque token that remembers the current position in the + /// to-device extension's stream. to_device_token: Option, } @@ -710,58 +732,52 @@ pub struct UpdateSummary { pub rooms: Vec, } +/// The set of sticky parameters owned by the `SlidingSyncInner` instance, and +/// sent in the request. +#[derive(Debug)] +pub(super) struct SlidingSyncStickyParameters { + /// Room subscriptions, i.e. rooms that may be out-of-scope of all lists + /// but one wants to receive updates. + room_subscriptions: BTreeMap, + + /// The intended state of the extensions being supplied to sliding /sync + /// calls. + extensions: ExtensionsConfig, +} + +impl SlidingSyncStickyParameters { + /// Create a new set of sticky parameters. + pub fn new( + room_subscriptions: BTreeMap, + extensions: ExtensionsConfig, + ) -> Self { + Self { room_subscriptions, extensions } + } +} + +impl StickyData for SlidingSyncStickyParameters { + type Request = v4::Request; + + fn apply(&self, request: &mut Self::Request) { + assign!(request, { + room_subscriptions: self.room_subscriptions.clone(), + extensions: self.extensions.clone(), + }); + } +} + #[cfg(test)] mod tests { use assert_matches::assert_matches; use futures_util::{pin_mut, StreamExt}; - use ruma::{ - api::client::sync::sync_events::v4::{E2EEConfig, ToDeviceConfig}, - room_id, - }; - use wiremock::MockServer; + use ruma::{api::client::sync::sync_events::v4::ToDeviceConfig, room_id}; + use serde_json::json; + use wiremock::{Match, MockServer}; use super::*; - use crate::test_utils::logged_in_client; - - #[tokio::test] - async fn to_device_is_enabled_when_pos_is_none() -> Result<()> { - let server = MockServer::start().await; - let client = logged_in_client(Some(server.uri())).await; - - let sync = client.sliding_sync("test-slidingsync")?.build().await?; - let extensions = sync.prepare_extension_config(None); - - // If the user doesn't provide any extension config, we enable to-device and - // e2ee anyways. - assert_matches!( - extensions.to_device, - ToDeviceConfig { enabled: Some(true), since: None, .. } - ); - assert_matches!(extensions.e2ee, E2EEConfig { enabled: Some(true), .. }); - - let some_since = "some_since".to_owned(); - sync.update_to_device_since(some_since.to_owned()); - let extensions = sync.prepare_extension_config(Some("foo")); - - // If there's a `pos` and to-device `since` token, we make sure we put the token - // into the extension config. The rest doesn't need to be re-enabled due to - // stickyness. - assert_matches!( - extensions.to_device, - ToDeviceConfig { enabled: None, since: Some(since), .. } if since == some_since - ); - assert_matches!(extensions.e2ee, E2EEConfig { enabled: None, .. }); - - let extensions = sync.prepare_extension_config(None); - // Even if there isn't a `pos`, if we have a to-device `since` token, we put it - // into the request. - assert_matches!( - extensions.to_device, - ToDeviceConfig { enabled: Some(true), since: Some(since), .. } if since == some_since - ); - - Ok(()) - } + use crate::{ + sliding_sync::sticky_parameters::SlidingSyncStickyManager, test_utils::logged_in_client, + }; async fn new_sliding_sync( lists: Vec, @@ -797,7 +813,8 @@ mod tests { sliding_sync.subscribe_to_room(room1.clone(), None); { - let room_subscriptions = sliding_sync.inner.room_subscriptions.read().unwrap(); + let sticky = sliding_sync.inner.sticky.read().unwrap(); + let room_subscriptions = &sticky.data().room_subscriptions; assert!(room_subscriptions.contains_key(&room0)); assert!(room_subscriptions.contains_key(&room1)); @@ -808,7 +825,8 @@ mod tests { sliding_sync.unsubscribe_from_room(room2.clone()); { - let room_subscriptions = sliding_sync.inner.room_subscriptions.read().unwrap(); + let sticky = sliding_sync.inner.sticky.read().unwrap(); + let room_subscriptions = &sticky.data().room_subscriptions; assert!(!room_subscriptions.contains_key(&room0)); assert!(room_subscriptions.contains_key(&room1)); @@ -833,14 +851,6 @@ mod tests { .sync_mode(SlidingSyncMode::new_selective().add_range(0..=10))]) .await?; - // When no to-device token is present, `prepare_extensions_config` doesn't fill - // the request with it. - let config = sliding_sync.prepare_extension_config(Some("pos")); - assert!(config.to_device.since.is_none()); - - let config = sliding_sync.prepare_extension_config(None); - assert!(config.to_device.since.is_none()); - // When no to-device token is present, it's still not there after caching // either. let frozen = FrozenSlidingSync::from(&sliding_sync); @@ -849,13 +859,7 @@ mod tests { // When a to-device token is present, `prepare_extensions_config` fills the // request with it. let since = String::from("my-to-device-since-token"); - sliding_sync.update_to_device_since(since.clone()); - - let config = sliding_sync.prepare_extension_config(Some("pos")); - assert_eq!(config.to_device.since.as_ref(), Some(&since)); - - let config = sliding_sync.prepare_extension_config(None); - assert_eq!(config.to_device.since.as_ref(), Some(&since)); + sliding_sync.inner.position.write().unwrap().to_device_token = Some(since.clone()); let frozen = FrozenSlidingSync::from(&sliding_sync); assert_eq!(frozen.to_device_since, Some(since)); @@ -889,6 +893,268 @@ mod tests { Ok(()) } + #[test] + fn test_sticky_parameters_api_invalidated_flow() { + let r0 = room_id!("!room:example.org"); + + let mut room_subscriptions = BTreeMap::new(); + room_subscriptions.insert(r0.to_owned(), Default::default()); + + // At first it's invalidated. + let mut sticky = SlidingSyncStickyManager::new(SlidingSyncStickyParameters::new( + room_subscriptions, + Default::default(), + )); + assert!(sticky.is_invalidated()); + + // Then when we create a request, the sticky parameters are applied. + let txn_id: &TransactionId = "tid123".into(); + + let mut request = v4::Request::default(); + request.txn_id = Some(txn_id.to_string()); + + sticky.maybe_apply(&mut request, txn_id); + + assert!(request.txn_id.is_some()); + assert_eq!(request.room_subscriptions.len(), 1); + assert!(request.room_subscriptions.get(r0).is_some()); + + let tid = request.txn_id.unwrap(); + + sticky.maybe_commit(tid.as_str().into()); + assert!(!sticky.is_invalidated()); + + // Applying new parameters will invalidate again. + sticky + .data_mut() + .room_subscriptions + .insert(room_id!("!r1:bar.org").to_owned(), Default::default()); + assert!(sticky.is_invalidated()); + + // Committing with the wrong transaction id will keep it invalidated. + sticky.maybe_commit("wrong tid today, my love has gone away 🎵".into()); + assert!(sticky.is_invalidated()); + + // Restarting a request will only remember the last generated transaction id. + let txn_id1: &TransactionId = "tid456".into(); + let mut request1 = v4::Request::default(); + request1.txn_id = Some(txn_id1.to_string()); + sticky.maybe_apply(&mut request1, txn_id1); + + assert!(sticky.is_invalidated()); + assert_eq!(request1.room_subscriptions.len(), 2); + + let txn_id2: &TransactionId = "tid789".into(); + let mut request2 = v4::Request::default(); + request2.txn_id = Some(txn_id2.to_string()); + + sticky.maybe_apply(&mut request2, txn_id2); + assert!(sticky.is_invalidated()); + assert_eq!(request2.room_subscriptions.len(), 2); + + // Here we commit with the not most-recent TID, so it keeps the invalidated + // status. + sticky.maybe_commit(txn_id1); + assert!(sticky.is_invalidated()); + + // But here we use the latest TID, so the commit is effective. + sticky.maybe_commit(txn_id2); + assert!(!sticky.is_invalidated()); + } + + #[test] + fn test_extensions_are_sticky() { + let mut extensions = ExtensionsConfig::default(); + extensions.account_data.enabled = Some(true); + + // At first it's invalidated. + let mut sticky = SlidingSyncStickyManager::new(SlidingSyncStickyParameters::new( + Default::default(), + extensions, + )); + + assert!(sticky.is_invalidated(), "invalidated because of non default parameters"); + + // `StickyParameters::new` follows its caller's intent when it comes to e2ee and + // to-device. + let extensions = &sticky.data().extensions; + assert_eq!(extensions.e2ee.enabled, None); + assert_eq!(extensions.to_device.enabled, None,); + assert_eq!(extensions.to_device.since, None,); + + // What the user explicitly enabled is... enabled. + assert_eq!(extensions.account_data.enabled, Some(true),); + + let txn_id: &TransactionId = "tid123".into(); + let mut request = v4::Request::default(); + request.txn_id = Some(txn_id.to_string()); + sticky.maybe_apply(&mut request, txn_id); + assert!(sticky.is_invalidated()); + assert_eq!(request.extensions.to_device.enabled, None); + assert_eq!(request.extensions.to_device.since, None); + assert_eq!(request.extensions.e2ee.enabled, None); + assert_eq!(request.extensions.account_data.enabled, Some(true)); + } + + #[tokio::test] + async fn test_sticky_extensions_plus_since() -> Result<()> { + let server = MockServer::start().await; + let client = logged_in_client(Some(server.uri())).await; + + let mut ss_builder = client.sliding_sync("test-slidingsync")?; + ss_builder = ss_builder.add_list(SlidingSyncList::builder("new_list")); + + let sync = ss_builder.build().await?; + + // We get to-device and e2ee even without requesting it. + assert_eq!( + sync.inner.sticky.read().unwrap().data().extensions.to_device.enabled, + Some(true) + ); + assert_eq!(sync.inner.sticky.read().unwrap().data().extensions.e2ee.enabled, Some(true)); + // But what we didn't enable... isn't enabled. + assert_eq!(sync.inner.sticky.read().unwrap().data().extensions.account_data.enabled, None); + + // Even without a since token, the first request will contain the extensions + // configuration, at least. + let txn_id = TransactionId::new(); + let (request, _, _) = sync.generate_sync_request(txn_id.clone()).await?; + + assert_eq!(request.extensions.e2ee.enabled, Some(true)); + assert_eq!(request.extensions.to_device.enabled, Some(true)); + assert!(request.extensions.to_device.since.is_none()); + + { + // Committing with another transaction id doesn't validate anything. + let mut sticky = sync.inner.sticky.write().unwrap(); + assert!(sticky.is_invalidated()); + sticky.maybe_commit( + "hopefully the rng won't generate this very specific transaction id".into(), + ); + assert!(sticky.is_invalidated()); + } + + // Regenerating a request will yield the same one. + let txn_id2 = TransactionId::new(); + let (request, _, _) = sync.generate_sync_request(txn_id2.clone()).await?; + + assert_eq!(request.extensions.e2ee.enabled, Some(true)); + assert_eq!(request.extensions.to_device.enabled, Some(true)); + assert!(request.extensions.to_device.since.is_none()); + + assert!(txn_id != txn_id2, "the two requests must not share the same transaction id"); + + { + // Committing with the expected transaction id will validate it. + let mut sticky = sync.inner.sticky.write().unwrap(); + assert!(sticky.is_invalidated()); + sticky.maybe_commit(txn_id2.as_str().into()); + assert!(!sticky.is_invalidated()); + } + + // The next request should contain no sticky parameters. + let txn_id = TransactionId::new(); + let (request, _, _) = sync.generate_sync_request(txn_id).await?; + assert!(request.extensions.e2ee.enabled.is_none()); + assert!(request.extensions.to_device.enabled.is_none()); + assert!(request.extensions.to_device.since.is_none()); + + // If there's a to-device `since` token, we make sure we put the token + // into the extension config. The rest doesn't need to be re-enabled due to + // stickyness. + let since_token = "since"; + sync.inner.position.write().unwrap().to_device_token = Some(since_token.to_owned()); + + let txn_id = TransactionId::new(); + let (request, _, _) = sync.generate_sync_request(txn_id).await?; + + assert!(request.extensions.e2ee.enabled.is_none()); + assert!(request.extensions.to_device.enabled.is_none()); + assert_eq!(request.extensions.to_device.since.as_deref(), Some(since_token)); + + Ok(()) + } + + #[tokio::test] + async fn test_sticky_parameters_invalidated_by_reset() -> Result<()> { + let server = MockServer::start().await; + let client = logged_in_client(Some(server.uri())).await; + + let sliding_sync = client + .sliding_sync("test-slidingsync")? + .with_to_device_extension(assign!(ToDeviceConfig::default(), { enabled: Some(true) })) + .build() + .await?; + + // First request asks to enable the extension. + let (request, _, _) = sliding_sync.generate_sync_request(TransactionId::new()).await?; + assert!(request.extensions.to_device.enabled.is_some()); + + let sync = sliding_sync.sync(); + pin_mut!(sync); + + #[derive(Clone)] + struct SlidingSyncMatcher; + + impl Match for SlidingSyncMatcher { + fn matches(&self, request: &wiremock::Request) -> bool { + request.url.path() == "/_matrix/client/unstable/org.matrix.msc3575/sync" + && request.method == wiremock::http::Method::Post + } + } + + #[derive(Deserialize)] + struct PartialRequest { + txn_id: Option, + } + + let _mock_guard = wiremock::Mock::given(SlidingSyncMatcher) + .respond_with(|request: &wiremock::Request| { + // Repeat with the txn_id in the response, if set. + let request: PartialRequest = request.body_json().unwrap(); + wiremock::ResponseTemplate::new(200).set_body_json(json!({ + "txn_id": request.txn_id, + "pos": "0" + })) + }) + .mount_as_scoped(&server) + .await; + + let next = sync.next().await; + assert_matches!(next, Some(Ok(_update_summary))); + + // Next request doesn't ask to enable the extension. + let (request, _, _) = sliding_sync.generate_sync_request(TransactionId::new()).await?; + assert!(request.extensions.to_device.enabled.is_none()); + + let next = sync.next().await; + assert_matches!(next, Some(Ok(_update_summary))); + + // Stop responding with successful requests! + drop(_mock_guard); + + // When responding with M_UNKNOWN_POS, that regenerates the sticky parameters, + // so they're reset. + let _mock_guard = wiremock::Mock::given(SlidingSyncMatcher) + .respond_with(wiremock::ResponseTemplate::new(400).set_body_json(json!({ + "error": "foo", + "errcode": "M_UNKNOWN_POS", + }))) + .mount_as_scoped(&server) + .await; + + let next = sync.next().await; + + // The request will retry a few times, then end in an error eventually. + assert_matches!(next, Some(Err(err)) if err.client_api_error_kind() == Some(&ErrorKind::UnknownPos)); + + // Next request asks to enable the extension again. + let (request, _, _) = sliding_sync.generate_sync_request(TransactionId::new()).await?; + assert!(request.extensions.to_device.enabled.is_some()); + + Ok(()) + } + #[tokio::test] async fn test_stop_sync_loop() -> Result<()> { let (_server, sliding_sync) = new_sliding_sync(vec![SlidingSyncList::builder("foo") diff --git a/crates/matrix-sdk/src/sliding_sync/sticky_parameters.rs b/crates/matrix-sdk/src/sliding_sync/sticky_parameters.rs new file mode 100644 index 000000000..c93fc8da9 --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/sticky_parameters.rs @@ -0,0 +1,138 @@ +//! Sticky parameters are a way to spare bandwidth on the network, by sending +//! request parameters once and have the server remember them. +//! +//! The set of sticky parameters have to be agreed upon by the server and the +//! client; this is defined in the +//! [MSC](https://github.com/matrix-org/matrix-spec-proposals/blob/kegan/sync-v3/proposals/3575-sync.md). + +use ruma::{OwnedTransactionId, TransactionId}; + +/// A trait to implement for data that can be sticky, given a context. +pub trait StickyData { + /// Request type that will be applied to, if the sticky parameters have been + /// invalidated before. + type Request; + + /// Apply the current data onto the request. + fn apply(&self, request: &mut Self::Request); +} + +/// Helper data structure to manage sticky parameters, for any kind of data. +/// +/// Initially, the provided data is considered to be invalidated, so it's +/// applied onto the request the first time it's sent. Any changes to the +/// wrapped data happen via `[Self::data_mut]`, which invalidates the sticky +/// parameters; they will be applied automatically to the next request. +/// +/// When applying sticky parameters, we will also remember the transaction id +/// that was generated for us, stash it, so we can match the response against +/// the transaction id later, and only consider the data isn't invalidated +/// anymore (we say it's "committed" in that case) if the response's transaction +/// id match what we expect. +#[derive(Debug)] +pub struct SlidingSyncStickyManager { + /// The data managed by this sticky manager. + data: D, + + /// Was any of the parameters invalidated? If yes, reinitialize them. + invalidated: bool, + + /// If the sticky parameters were applied to a given request, this is + /// the transaction id generated for that request, that must be matched + /// upon in the next call to `commit()`. + txn_id: Option, +} + +impl SlidingSyncStickyManager { + /// Create a new `StickyManager` for the given data. + /// + /// Always assume the initial data invalidates the request, at first. + pub fn new(data: D) -> Self { + Self { data, txn_id: None, invalidated: true } + } + + /// Get a mutable reference to the managed data. + /// + /// Will invalidate the sticky set by default. If you don't need to modify + /// the data, use `Self::data()`; if you're not sure you're going to modify + /// the data, it's best to first use `Self::data()` then `Self::data_mut()` + /// when you're sure. + pub fn data_mut(&mut self) -> &mut D { + self.invalidated = true; + &mut self.data + } + + /// Returns a non-invalidating reference to the managed data. + pub fn data(&self) -> &D { + &self.data + } + + /// May apply some the managed sticky parameters to the given request. + /// + /// After receiving the response from this sliding sync, the caller MUST + /// also call [`Self::maybe_commit`] with the transaction id from the + /// server's response. + pub fn maybe_apply(&mut self, req: &mut D::Request, txn_id: &TransactionId) { + if self.invalidated { + self.txn_id = Some(txn_id.to_owned()); + self.data.apply(req); + } + } + + /// May mark the managed data as not invalidated anymore, if the transaction + /// id received from the response matches the one received from the request. + pub fn maybe_commit(&mut self, txn_id: &TransactionId) { + if self.invalidated && self.txn_id.as_deref() == Some(txn_id) { + self.invalidated = false; + } + } + + #[cfg(test)] + pub fn is_invalidated(&self) -> bool { + self.invalidated + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct EmptyStickyData; + + impl StickyData for EmptyStickyData { + type Request = bool; + + fn apply(&self, req: &mut Self::Request) { + // Mark that applied has had an effect. + *req = true; + } + } + + #[test] + fn test_sticky_parameters_api_non_invalidated_no_effect() { + let mut sticky = SlidingSyncStickyManager::new(EmptyStickyData); + + // At first, it's always invalidated. + assert!(sticky.is_invalidated()); + + let mut applied = false; + sticky.maybe_apply(&mut applied, "tid123".into()); + assert!(applied); + assert!(sticky.is_invalidated()); + + // Committing with the wrong transaction id won't commit. + sticky.maybe_commit("tid456".into()); + assert!(sticky.is_invalidated()); + + // Providing the correct transaction id will commit. + sticky.maybe_commit("tid123".into()); + assert!(!sticky.is_invalidated()); + + // Applying without being invalidated won't do anything. + let mut applied = false; + sticky.maybe_apply(&mut applied, "tid123".into()); + + assert!(!applied); + assert!(!sticky.is_invalidated()); + } +}