diff --git a/Cargo.lock b/Cargo.lock index e32f77ebc..5e26ad28b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2142,6 +2142,18 @@ dependencies = [ "subtle", ] +[[package]] +name = "growable-bloom-filter" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c669fa03050eb3445343f215d62fc1ab831e8098bc9a55f26e9724faff11075c" +dependencies = [ + "serde", + "serde_bytes", + "serde_derive", + "xxhash-rust", +] + [[package]] name = "h2" version = "0.3.26" @@ -3241,6 +3253,7 @@ dependencies = [ "eyeball-im", "futures-executor", "futures-util", + "growable-bloom-filter", "http 1.1.0", "matrix-sdk-common", "matrix-sdk-crypto", @@ -3411,6 +3424,7 @@ dependencies = [ "base64 0.22.1", "getrandom", "gloo-utils", + "growable-bloom-filter", "hkdf", "indexed_db_futures", "js-sys", @@ -3569,6 +3583,7 @@ dependencies = [ "futures-core", "futures-util", "fuzzy-matcher", + "growable-bloom-filter", "imbl", "indexmap 2.2.6", "itertools 0.12.1", @@ -4565,7 +4580,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19de2de2a00075bf566bee3bd4db014b11587e84184d3f7a791bc17f1a8e9e48" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.63", @@ -7156,6 +7171,12 @@ dependencies = [ "xshell", ] +[[package]] +name = "xxhash-rust" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "927da81e25be1e1a2901d59b81b37dd2efd1fc9c9345a55007f09bf5a2d3ee03" + [[package]] name = "zerocopy" version = "0.7.34" diff --git a/Cargo.toml b/Cargo.toml index 7c36d7dbf..2ac19dfd9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ futures-executor = "0.3.21" futures-util = { version = "0.3.26", default-features = false, features = [ "alloc", ] } +growable-bloom-filter = "2.1.0" http = "1.1.0" imbl = "2.0.0" itertools = "0.12.0" diff --git a/bindings/matrix-sdk-ffi/src/sync_service.rs b/bindings/matrix-sdk-ffi/src/sync_service.rs index 7a1e2800b..329784c3e 100644 --- a/bindings/matrix-sdk-ffi/src/sync_service.rs +++ b/bindings/matrix-sdk-ffi/src/sync_service.rs @@ -25,6 +25,7 @@ use matrix_sdk_ui::{ UnableToDecryptHook, UnableToDecryptInfo as SdkUnableToDecryptInfo, UtdHookManager, }, }; +use tracing::error; use crate::{ error::ClientError, helpers::unwrap_or_clone_arc, room_list::RoomListService, TaskHandle, @@ -93,6 +94,7 @@ impl SyncService { #[derive(Clone, uniffi::Object)] pub struct SyncServiceBuilder { + client: Client, builder: MatrixSyncServiceBuilder, utd_hook: Option>, @@ -100,7 +102,11 @@ pub struct SyncServiceBuilder { impl SyncServiceBuilder { pub(crate) fn new(client: Client) -> Arc { - Arc::new(Self { builder: MatrixSyncService::builder(client), utd_hook: None }) + Arc::new(Self { + client: client.clone(), + builder: MatrixSyncService::builder(client), + utd_hook: None, + }) } } @@ -109,20 +115,33 @@ impl SyncServiceBuilder { pub fn with_cross_process_lock(self: Arc, app_identifier: Option) -> Arc { let this = unwrap_or_clone_arc(self); let builder = this.builder.with_cross_process_lock(app_identifier); - Arc::new(Self { builder, utd_hook: this.utd_hook }) + Arc::new(Self { client: this.client, builder, utd_hook: this.utd_hook }) } - pub fn with_utd_hook(self: Arc, delegate: Box) -> Arc { + pub async fn with_utd_hook( + self: Arc, + delegate: Box, + ) -> Arc { // UTDs detected before this duration may be reclassified as "late decryption" // events (or discarded, if they get decrypted fast enough). const UTD_HOOK_GRACE_PERIOD: Duration = Duration::from_secs(60); let this = unwrap_or_clone_arc(self); - let utd_hook = Some(Arc::new( - UtdHookManager::new(Arc::new(UtdHook { delegate })) - .with_max_delay(UTD_HOOK_GRACE_PERIOD), - )); - Arc::new(Self { builder: this.builder, utd_hook }) + + let mut utd_hook = UtdHookManager::new(Arc::new(UtdHook { delegate }), this.client.clone()) + .with_max_delay(UTD_HOOK_GRACE_PERIOD); + + if let Err(e) = utd_hook.reload_from_store().await { + error!("Unable to reload UTD hook data from data store: {}", e); + // Carry on with the setup anyway; we shouldn't fail setup just + // because the UTD hook failed to load its data. + } + + Arc::new(Self { + client: this.client, + builder: this.builder, + utd_hook: Some(Arc::new(utd_hook)), + }) } pub async fn finish(self: Arc) -> Result, ClientError> { diff --git a/crates/matrix-sdk-base/Cargo.toml b/crates/matrix-sdk-base/Cargo.toml index a662e964a..fcf46c0ff 100644 --- a/crates/matrix-sdk-base/Cargo.toml +++ b/crates/matrix-sdk-base/Cargo.toml @@ -43,6 +43,7 @@ bitflags = { version = "2.4.0", features = ["serde"] } eyeball = { workspace = true } eyeball-im = { workspace = true } futures-util = { workspace = true } +growable-bloom-filter = { workspace = true } http = { workspace = true, optional = true } matrix-sdk-common = { workspace = true } matrix-sdk-crypto = { workspace = true, optional = true } diff --git a/crates/matrix-sdk-base/src/store/integration_tests.rs b/crates/matrix-sdk-base/src/store/integration_tests.rs index 4256860df..e47adfa12 100644 --- a/crates/matrix-sdk-base/src/store/integration_tests.rs +++ b/crates/matrix-sdk-base/src/store/integration_tests.rs @@ -5,6 +5,7 @@ use std::collections::{BTreeMap, BTreeSet}; use assert_matches::assert_matches; use assert_matches2::assert_let; use async_trait::async_trait; +use growable_bloom_filter::GrowableBloomBuilder; use matrix_sdk_test::test_json; use ruma::{ api::client::media::get_content_thumbnail::v3::Method, @@ -62,6 +63,8 @@ pub trait StateStoreIntegrationTests { async fn test_user_avatar_url_saving(&self); /// Test sync token saving. async fn test_sync_token_saving(&self); + /// Test UtdHookManagerData saving. + async fn test_utd_hook_manager_data_saving(&self); /// Test stripped room member saving. async fn test_stripped_member_saving(&self); /// Test room power levels saving. @@ -594,6 +597,37 @@ impl StateStoreIntegrationTests for DynStateStore { assert_matches!(self.get_kv_data(StateStoreDataKey::SyncToken).await, Ok(None)); } + async fn test_utd_hook_manager_data_saving(&self) { + // Before any data is written, the getter should return None. + assert!( + self.get_kv_data(StateStoreDataKey::UtdHookManagerData) + .await + .expect("Could not read data") + .is_none(), + "Store was not empty at start" + ); + + // Put some data in the store... + let data = GrowableBloomBuilder::new().build(); + self.set_kv_data( + StateStoreDataKey::UtdHookManagerData, + StateStoreDataValue::UtdHookManagerData(data.clone()), + ) + .await + .expect("Could not save data"); + + // ... and check it comes back. + let read_data = self + .get_kv_data(StateStoreDataKey::UtdHookManagerData) + .await + .expect("Could not read data") + .expect("no data found") + .into_utd_hook_manager_data() + .expect("not UtdHookManagerData"); + + assert_eq!(read_data, data); + } + async fn test_stripped_member_saving(&self) { let room_id = room_id!("!test_stripped_member_saving:localhost"); let user_id = user_id(); @@ -1351,6 +1385,12 @@ macro_rules! statestore_integration_tests { store.test_sync_token_saving().await } + #[async_test] + async fn test_utd_hook_manager_data_saving() { + let store = get_store().await.expect("creating store failed").into_state_store(); + store.test_utd_hook_manager_data_saving().await; + } + #[async_test] async fn test_stripped_member_saving() { let store = get_store().await.unwrap().into_state_store(); diff --git a/crates/matrix-sdk-base/src/store/memory_store.rs b/crates/matrix-sdk-base/src/store/memory_store.rs index a68b203ce..31320d742 100644 --- a/crates/matrix-sdk-base/src/store/memory_store.rs +++ b/crates/matrix-sdk-base/src/store/memory_store.rs @@ -19,6 +19,7 @@ use std::{ }; use async_trait::async_trait; +use growable_bloom_filter::GrowableBloom; use matrix_sdk_common::{instant::Instant, ring_buffer::RingBuffer}; use ruma::{ canonical_json::{redact, RedactedBecause}, @@ -52,6 +53,7 @@ pub struct MemoryStore { user_avatar_url: StdRwLock>, sync_token: StdRwLock>, filters: StdRwLock>, + utd_hook_manager_data: StdRwLock>, account_data: StdRwLock>>, profiles: StdRwLock>>, display_names: StdRwLock>>>, @@ -94,6 +96,7 @@ impl Default for MemoryStore { user_avatar_url: Default::default(), sync_token: Default::default(), filters: Default::default(), + utd_hook_manager_data: Default::default(), account_data: Default::default(), profiles: Default::default(), display_names: Default::default(), @@ -186,6 +189,12 @@ impl StateStore for MemoryStore { .get(user_id.as_str()) .cloned() .map(StateStoreDataValue::RecentlyVisitedRooms), + StateStoreDataKey::UtdHookManagerData => self + .utd_hook_manager_data + .read() + .unwrap() + .clone() + .map(StateStoreDataValue::UtdHookManagerData), }) } @@ -219,6 +228,13 @@ impl StateStore for MemoryStore { .expect("Session data not a list of recently visited rooms"), ); } + StateStoreDataKey::UtdHookManagerData => { + *self.utd_hook_manager_data.write().unwrap() = Some( + value + .into_utd_hook_manager_data() + .expect("Session data not the hook manager data"), + ); + } } Ok(()) @@ -236,6 +252,9 @@ impl StateStore for MemoryStore { StateStoreDataKey::RecentlyVisitedRooms(user_id) => { self.recently_visited_rooms.write().unwrap().remove(user_id.as_str()); } + StateStoreDataKey::UtdHookManagerData => { + *self.utd_hook_manager_data.write().unwrap() = None + } } Ok(()) } diff --git a/crates/matrix-sdk-base/src/store/traits.rs b/crates/matrix-sdk-base/src/store/traits.rs index 09ebbc0f4..963d644cc 100644 --- a/crates/matrix-sdk-base/src/store/traits.rs +++ b/crates/matrix-sdk-base/src/store/traits.rs @@ -21,6 +21,7 @@ use std::{ use as_variant::as_variant; use async_trait::async_trait; +use growable_bloom_filter::GrowableBloom; use matrix_sdk_common::AsyncTraitDeps; use ruma::{ events::{ @@ -808,6 +809,10 @@ pub enum StateStoreDataValue { /// A list of recently visited room identifiers for the current user RecentlyVisitedRooms(Vec), + + /// Persistent data for + /// `matrix_sdk_ui::unable_to_decrypt_hook::UtdHookManager`. + UtdHookManagerData(GrowableBloom), } impl StateStoreDataValue { @@ -830,6 +835,11 @@ impl StateStoreDataValue { pub fn into_recently_visited_rooms(self) -> Option> { as_variant!(self, Self::RecentlyVisitedRooms) } + + /// Get this value if it is the data for the `UtdHookManager`. + pub fn into_utd_hook_manager_data(self) -> Option { + as_variant!(self, Self::UtdHookManagerData) + } } /// A key for key-value data. @@ -846,6 +856,10 @@ pub enum StateStoreDataKey<'a> { /// Recently visited room identifiers RecentlyVisitedRooms(&'a UserId), + + /// Persistent data for + /// `matrix_sdk_ui::unable_to_decrypt_hook::UtdHookManager`. + UtdHookManagerData, } impl StateStoreDataKey<'_> { @@ -860,4 +874,8 @@ impl StateStoreDataKey<'_> { /// Key prefix to use for the /// [`RecentlyVisitedRooms`][Self::RecentlyVisitedRooms] variant. pub const RECENTLY_VISITED_ROOMS: &'static str = "recently_visited_rooms"; + + /// Key to use for the [`UtdHookManagerData`][Self::UtdHookManagerData] + /// variant. + pub const UTD_HOOK_MANAGER_DATA: &'static str = "utd_hook_manager_data"; } diff --git a/crates/matrix-sdk-indexeddb/Cargo.toml b/crates/matrix-sdk-indexeddb/Cargo.toml index 8063b55ac..5c6fbc05b 100644 --- a/crates/matrix-sdk-indexeddb/Cargo.toml +++ b/crates/matrix-sdk-indexeddb/Cargo.toml @@ -15,7 +15,7 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = ["e2e-encryption", "state-store"] -state-store = ["dep:matrix-sdk-base"] +state-store = ["dep:matrix-sdk-base", "growable-bloom-filter"] e2e-encryption = ["dep:matrix-sdk-crypto"] testing = ["matrix-sdk-crypto?/testing"] @@ -24,6 +24,7 @@ anyhow = { workspace = true } async-trait = { workspace = true } base64 = { workspace = true } gloo-utils = { version = "0.2.0", features = ["serde"] } +growable-bloom-filter = { workspace = true, optional = true } indexed_db_futures = "0.4.1" js-sys = { version = "0.3.58" } matrix-sdk-base = { workspace = true, features = ["js"], optional = true } diff --git a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs index c38a48b12..624861e24 100644 --- a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs +++ b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs @@ -20,6 +20,7 @@ use std::{ use anyhow::anyhow; use async_trait::async_trait; use gloo_utils::format::JsValueSerdeExt; +use growable_bloom_filter::GrowableBloom; use indexed_db_futures::prelude::*; use matrix_sdk_base::{ deserialized_responses::RawAnySyncOrStrippedState, @@ -388,6 +389,9 @@ impl IndexeddbStateStore { StateStoreDataKey::RecentlyVisitedRooms(user_id) => { self.encode_key(keys::KV, (StateStoreDataKey::RECENTLY_VISITED_ROOMS, user_id)) } + StateStoreDataKey::UtdHookManagerData => { + self.encode_key(keys::KV, StateStoreDataKey::UTD_HOOK_MANAGER_DATA) + } } } } @@ -449,6 +453,10 @@ impl_state_store!({ .map(|f| self.deserialize_event::>(&f)) .transpose()? .map(StateStoreDataValue::RecentlyVisitedRooms), + StateStoreDataKey::UtdHookManagerData => value + .map(|f| self.deserialize_event::(&f)) + .transpose()? + .map(StateStoreDataValue::UtdHookManagerData), }; Ok(value) @@ -475,6 +483,9 @@ impl_state_store!({ .into_recently_visited_rooms() .expect("Session data not a recently visited room list"), ), + StateStoreDataKey::UtdHookManagerData => self.serialize_event( + &value.into_utd_hook_manager_data().expect("Session data not UtdHookManagerData"), + ), }; let tx = diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index 260e166d1..e673c0c16 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -278,6 +278,9 @@ impl SqliteStateStore { StateStoreDataKey::RecentlyVisitedRooms(b) => { Cow::Owned(format!("{}:{b}", StateStoreDataKey::RECENTLY_VISITED_ROOMS)) } + StateStoreDataKey::UtdHookManagerData => { + Cow::Borrowed(StateStoreDataKey::UTD_HOOK_MANAGER_DATA) + } }; self.encode_key(keys::KV_BLOB, &*key_s) @@ -896,6 +899,9 @@ impl StateStore for SqliteStateStore { StateStoreDataKey::RecentlyVisitedRooms(_) => { StateStoreDataValue::RecentlyVisitedRooms(self.deserialize_value(&data)?) } + StateStoreDataKey::UtdHookManagerData => { + StateStoreDataValue::UtdHookManagerData(self.deserialize_value(&data)?) + } }) }) .transpose() @@ -919,6 +925,9 @@ impl StateStore for SqliteStateStore { StateStoreDataKey::RecentlyVisitedRooms(_) => self.serialize_value( &value.into_recently_visited_rooms().expect("Session data not breadcrumbs"), )?, + StateStoreDataKey::UtdHookManagerData => self.serialize_value( + &value.into_utd_hook_manager_data().expect("Session data not UtdHookManagerData"), + )?, }; self.acquire() diff --git a/crates/matrix-sdk-ui/CHANGELOG.md b/crates/matrix-sdk-ui/CHANGELOG.md index 56ef35a51..ba88d8b13 100644 --- a/crates/matrix-sdk-ui/CHANGELOG.md +++ b/crates/matrix-sdk-ui/CHANGELOG.md @@ -9,7 +9,14 @@ Breaking changes: Bug fixes: - `UtdHookManager` no longer re-reports UTD events as late decryptions. - ([#3840](https://github.com/matrix-org/matrix-rust-sdk/pull/3840)) + ([#3480](https://github.com/matrix-org/matrix-rust-sdk/pull/3480)) + +Other changes: + +- `UtdHookManager` no longer reports UTD events that were already reported in a + previous session. + ([#3519](https://github.com/matrix-org/matrix-rust-sdk/pull/3519)) + # 0.7.0 diff --git a/crates/matrix-sdk-ui/Cargo.toml b/crates/matrix-sdk-ui/Cargo.toml index ba5cd71ba..a73328061 100644 --- a/crates/matrix-sdk-ui/Cargo.toml +++ b/crates/matrix-sdk-ui/Cargo.toml @@ -34,6 +34,7 @@ eyeball-im-util = { workspace = true } futures-core = { workspace = true } futures-util = { workspace = true } fuzzy-matcher = "0.3.7" +growable-bloom-filter = { workspace = true } imbl = { workspace = true, features = ["serde"] } indexmap = "2.0.0" itertools = { workspace = true } diff --git a/crates/matrix-sdk-ui/src/timeline/event_handler.rs b/crates/matrix-sdk-ui/src/timeline/event_handler.rs index f0f6a96b3..b305ebb0d 100644 --- a/crates/matrix-sdk-ui/src/timeline/event_handler.rs +++ b/crates/matrix-sdk-ui/src/timeline/event_handler.rs @@ -286,7 +286,7 @@ impl<'a, 'o> TimelineEventHandler<'a, 'o> { /// `raw_event` is only needed to determine the cause of any UTDs, /// so if we know this is not a UTD it can be None. #[instrument(skip_all, fields(txn_id, event_id, position))] - pub(super) fn handle_event( + pub(super) async fn handle_event( mut self, day_divider_adjuster: &mut DayDividerAdjuster, event_kind: TimelineEventKind, @@ -367,7 +367,7 @@ impl<'a, 'o> TimelineEventHandler<'a, 'o> { // timeline. if let Some(hook) = self.meta.unable_to_decrypt_hook.as_ref() { if let Flow::Remote { event_id, .. } = &self.ctx.flow { - hook.on_utd(event_id, cause); + hook.on_utd(event_id, cause).await; } } } diff --git a/crates/matrix-sdk-ui/src/timeline/inner/mod.rs b/crates/matrix-sdk-ui/src/timeline/inner/mod.rs index 9cfa7b00f..0b370ed2c 100644 --- a/crates/matrix-sdk-ui/src/timeline/inner/mod.rs +++ b/crates/matrix-sdk-ui/src/timeline/inner/mod.rs @@ -466,16 +466,18 @@ impl TimelineInner

{ let event_content = AnyMessageLikeEventContent::Reaction( ReactionEventContent::from(annotation.clone()), ); - state.handle_local_event( - sender, - sender_profile, - txn_id.clone(), - None, - TimelineEventKind::Message { - content: event_content.clone(), - relations: Default::default(), - }, - ); + state + .handle_local_event( + sender, + sender_profile, + txn_id.clone(), + None, + TimelineEventKind::Message { + content: event_content.clone(), + relations: Default::default(), + }, + ) + .await; ReactionState::Sending(txn_id) } @@ -490,13 +492,9 @@ impl TimelineInner

{ unreachable!("the None/None case has been handled above") }; - state.handle_local_event( - sender, - sender_profile, - TransactionId::new(), - None, - content, - ); + state + .handle_local_event(sender, sender_profile, TransactionId::new(), None, content) + .await; // Remember the remote echo to redact on the homeserver. ReactionState::Redacting(remote_echo_event_id.cloned()) @@ -654,7 +652,7 @@ impl TimelineInner

{ let profile = self.room_data_provider.profile_from_user_id(&sender).await; let mut state = self.state.write().await; - state.handle_local_event(sender, profile, txn_id, abort_handle, content); + state.handle_local_event(sender, profile, txn_id, abort_handle, content).await; } /// Update the send state of a local event represented by a transaction ID. @@ -922,7 +920,7 @@ impl TimelineInner

{ // Notify observers that we managed to eventually decrypt an event. if let Some(hook) = unable_to_decrypt_hook { - hook.on_late_decrypt(&remote_event.event_id, cause); + hook.on_late_decrypt(&remote_event.event_id, cause).await; } Some(event) diff --git a/crates/matrix-sdk-ui/src/timeline/inner/state.rs b/crates/matrix-sdk-ui/src/timeline/inner/state.rs index 6b728acac..736d766cd 100644 --- a/crates/matrix-sdk-ui/src/timeline/inner/state.rs +++ b/crates/matrix-sdk-ui/src/timeline/inner/state.rs @@ -158,7 +158,7 @@ impl TimelineInnerState { /// Adds a local echo (for an event) to the timeline. #[instrument(skip_all)] - pub(super) fn handle_local_event( + pub(super) async fn handle_local_event( &mut self, own_user_id: OwnedUserId, own_profile: Option, @@ -183,13 +183,15 @@ impl TimelineInnerState { let mut day_divider_adjuster = DayDividerAdjuster::default(); - TimelineEventHandler::new(&mut txn, ctx).handle_event( - &mut day_divider_adjuster, - content, - // Local events are never UTD, so no need to pass in a raw_event - this is only used to - // determine the type of UTD if there is one. - None, - ); + TimelineEventHandler::new(&mut txn, ctx) + .handle_event( + &mut day_divider_adjuster, + content, + // Local events are never UTD, so no need to pass in a raw_event - this is only + // used to determine the type of UTD if there is one. + None, + ) + .await; txn.adjust_day_dividers(day_divider_adjuster); @@ -569,11 +571,9 @@ impl TimelineInnerStateTransaction<'_> { }, }; - TimelineEventHandler::new(self, ctx).handle_event( - day_divider_adjuster, - event_kind, - Some(&raw), - ) + TimelineEventHandler::new(self, ctx) + .handle_event(day_divider_adjuster, event_kind, Some(&raw)) + .await } fn clear(&mut self) { diff --git a/crates/matrix-sdk-ui/src/timeline/tests/encryption.rs b/crates/matrix-sdk-ui/src/timeline/tests/encryption.rs index 0ddf4dcea..ccbad4aa2 100644 --- a/crates/matrix-sdk-ui/src/timeline/tests/encryption.rs +++ b/crates/matrix-sdk-ui/src/timeline/tests/encryption.rs @@ -23,7 +23,10 @@ use std::{ use assert_matches::assert_matches; use assert_matches2::assert_let; use eyeball_im::VectorDiff; -use matrix_sdk::crypto::{decrypt_room_key_export, types::events::UtdCause, OlmMachine}; +use matrix_sdk::{ + crypto::{decrypt_room_key_export, types::events::UtdCause, OlmMachine}, + test_utils::test_client_builder, +}; use matrix_sdk_test::{async_test, BOB}; use ruma::{ assign, @@ -76,7 +79,8 @@ async fn test_retry_message_decryption() { } let hook = Arc::new(DummyUtdHook::default()); - let utd_hook = Arc::new(UtdHookManager::new(hook.clone())); + let client = test_client_builder(None).build().await.unwrap(); + let utd_hook = Arc::new(UtdHookManager::new(hook.clone(), client)); let timeline = TestTimeline::with_unable_to_decrypt_hook(utd_hook.clone()); let mut stream = timeline.subscribe().await; diff --git a/crates/matrix-sdk-ui/src/unable_to_decrypt_hook.rs b/crates/matrix-sdk-ui/src/unable_to_decrypt_hook.rs index d6e36a92e..14f72aee8 100644 --- a/crates/matrix-sdk-ui/src/unable_to_decrypt_hook.rs +++ b/crates/matrix-sdk-ui/src/unable_to_decrypt_hook.rs @@ -19,14 +19,22 @@ //! utilities to simplify usage of this trait. use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, sync::{Arc, Mutex}, time::{Duration, Instant}, }; -use matrix_sdk::crypto::types::events::UtdCause; +use growable_bloom_filter::{GrowableBloom, GrowableBloomBuilder}; +use matrix_sdk::{crypto::types::events::UtdCause, Client}; +use matrix_sdk_base::{StateStoreDataKey, StateStoreDataValue, StoreError}; use ruma::{EventId, OwnedEventId}; -use tokio::{spawn, task::JoinHandle, time::sleep}; +use tokio::{ + spawn, + sync::{Mutex as AsyncMutex, MutexGuard}, + task::JoinHandle, + time::sleep, +}; +use tracing::error; /// A generic interface which methods get called whenever we observe a /// unable-to-decrypt (UTD) event. @@ -77,21 +85,13 @@ struct PendingUtdReport { /// that delay. #[derive(Debug)] pub struct UtdHookManager { + /// A Client associated with the UTD hook. This is used to access the store + /// which we persist our data to. + client: Client, + /// The parent hook we'll call, when we have found a unique UTD. parent: Arc, - /// The set of events we've marked as UTDs. - /// - /// Events are added to this set when they are first flagged as UTDs. If - /// they are subsequently successfully decrypted, they are removed from - /// this set. (In other words, this is a superset of the events in - /// [`Self::pending_delayed`]. - /// - /// Note: this is unbounded, because we have absolutely no idea how long it - /// will take for a UTD to resolve, or if it will even resolve at any - /// point. - known_utds: Arc>>, - /// An optional delay before marking the event as UTD ("grace period"). max_delay: Option, @@ -103,16 +103,67 @@ pub struct UtdHookManager { /// Note: this is theoretically unbounded in size, although this set of /// tasks will degrow over time, as tasks expire after the max delay. pending_delayed: Arc>>, + + /// Bloom filter containing the event IDs of events which have been reported + /// as UTDs + reported_utds: Arc>, } impl UtdHookManager { /// Create a new [`UtdHookManager`] for the given hook. - pub fn new(parent: Arc) -> Self { + /// + /// A [`Client`] must also be provided; this provides a link to the + /// [`matrix_sdk_base::StateStore`] which is used to load and store the + /// persistent data. + pub fn new(parent: Arc, client: Client) -> Self { + let bloom_filter = + // Some slightly arbitrarily-chosen parameters here. We specify that, after 1000 + // UTDs, we want to have a false-positive rate of 1%. + // + // The GrowableBloomFilter is based on a series of (partitioned) Bloom filters; + // once the first starts getting full (the expected false-positive + // rate gets too high), it adds another Bloom filter. Each new entry + // is recorded in the most recent Bloom filter; when querying, if + // *any* of the component filters show a match, that shows + // an overall match. + // + // The first component filter is created based on the parameters we give. For + // reasons derived in the paper [1], a partitioned Bloom filter with + // target false-positive rate `P` after `n` insertions requires a + // number of slices `k` given by: + // + // k = log2(1/P) = -ln(P) / ln(2) + // + // ... where each slice has a number of bits `m` given by + // + // m = n / ln(2) + // + // We have to have a whole number of slices and bits, so the total number of + // bits M is: + // + // M = ceil(k) * ceil(m) + // = ceil(-ln(P) / ln(2)) * ceil(n / ln(2)) + // + // In other words, our FP rate of 1% after 1000 insertions requires: + // + // M = ceil(-ln(0.01) / ln(2)) * ceil(1000 / ln(2)) + // = 7 * 1443 = 10101 bits + // + // So our filter starts off with 1263 bytes of data (plus a little overhead). + // Once we hit 1000 UTDs, we add a second component filter with a capacity + // double that of the original and target error rate 85% of the + // original (another 2526 bytes), which then lasts us until a total + // of 3000 UTDs. + // + // [1]: https://gsd.di.uminho.pt/members/cbm/ps/dbloom.pdf + GrowableBloomBuilder::new().estimated_insertions(1000).desired_error_ratio(0.01).build(); + Self { + client, parent, - known_utds: Default::default(), max_delay: None, pending_delayed: Default::default(), + reported_utds: Arc::new(AsyncMutex::new(bloom_filter)), } } @@ -125,44 +176,57 @@ impl UtdHookManager { self } + /// Load the persistent data for the UTD hook from the store. + /// + /// If the client previously used a UtdHookManager, and UTDs were + /// encountered, the data on the reported UTDs is loaded from the store. + /// Otherwise, there is no effect. + pub async fn reload_from_store(&mut self) -> Result<(), StoreError> { + let existing_data = + self.client.store().get_kv_data(StateStoreDataKey::UtdHookManagerData).await?; + + if let Some(existing_data) = existing_data { + let bloom_filter = existing_data + .into_utd_hook_manager_data() + .expect("StateStore::get_kv_data should return data of the right type"); + self.reported_utds = Arc::new(AsyncMutex::new(bloom_filter)); + } + Ok(()) + } + /// The function to call whenever a UTD is seen for the first time. /// /// Pipe in any information that needs to be included in the final report. - pub(crate) fn on_utd(&self, event_id: &EventId, cause: UtdCause) { - // First of all, check if we already have a task to handle this UTD. If so, our - // work is done - let mut pending_delayed_lock = self.pending_delayed.lock().unwrap(); - if pending_delayed_lock.contains_key(event_id) { + pub(crate) async fn on_utd(&self, event_id: &EventId, cause: UtdCause) { + // Hold the lock on `reported_utds` throughout, to avoid races with other + // threads. + let mut reported_utds_lock = self.reported_utds.lock().await; + + // Check if this, or a previous instance of UtdHookManager, has already reported + // this UTD, and bail out if not. + if reported_utds_lock.contains(event_id) { return; } - // Keep track of UTDs we have already seen. - { - let mut known_utds = self.known_utds.lock().unwrap(); - if !known_utds.insert(event_id.to_owned()) { - return; - } + // Otherwise, check if we already have a task to handle this UTD. + if self.pending_delayed.lock().unwrap().contains_key(event_id) { + return; } - // Construct a closure which will report the UTD to the parent. - let report_utd = { - let info = - UnableToDecryptInfo { event_id: event_id.to_owned(), time_to_decrypt: None, cause }; - let parent = self.parent.clone(); - move || { - parent.on_utd(info); - } - }; + let info = + UnableToDecryptInfo { event_id: event_id.to_owned(), time_to_decrypt: None, cause }; let Some(max_delay) = self.max_delay else { // No delay: immediately report the event to the parent hook. - report_utd(); + Self::report_utd(info, &self.parent, &self.client, &mut reported_utds_lock).await; return; }; // Clone data shared with the task below. let pending_delayed = self.pending_delayed.clone(); - let target_event_id = event_id.to_owned(); + let reported_utds = self.reported_utds.clone(); + let parent = self.parent.clone(); + let client = self.client.clone(); // Spawn a task that will wait for the given delay, and maybe call the parent // hook then. @@ -170,15 +234,22 @@ impl UtdHookManager { // Wait for the given delay. sleep(max_delay).await; + // Make sure we take out the lock on `reported_utds` before removing the entry + // from `pending_delayed`, to ensure we don't race against another call to + // `on_utd` (which could otherwise see that the entry has been + // removed from `pending_delayed` but not yet added to + // `reported_utds`). + let mut reported_utds_lock = reported_utds.lock().await; + // Remove the task from the outstanding set. But if it's already been removed, // it's been decrypted since the task was added! - if pending_delayed.lock().unwrap().remove(&target_event_id).is_some() { - report_utd(); + if pending_delayed.lock().unwrap().remove(&info.event_id).is_some() { + Self::report_utd(info, &parent, &client, &mut reported_utds_lock).await; } }); // Add the task to the set of pending tasks. - pending_delayed_lock.insert( + self.pending_delayed.lock().unwrap().insert( event_id.to_owned(), PendingUtdReport { marked_utd_at: Instant::now(), report_task: handle }, ); @@ -189,14 +260,15 @@ impl UtdHookManager { /// /// Note: if this is called for an event that was never marked as a UTD /// before, it has no effect. - pub(crate) fn on_late_decrypt(&self, event_id: &EventId, cause: UtdCause) { - let mut pending_delayed_lock = self.pending_delayed.lock().unwrap(); - self.known_utds.lock().unwrap().remove(event_id); + pub(crate) async fn on_late_decrypt(&self, event_id: &EventId, cause: UtdCause) { + // Hold the lock on `reported_utds` throughout, to avoid races with other + // threads. + let mut reported_utds_lock = self.reported_utds.lock().await; // Only let the parent hook know about the late decryption if the event is // a pending UTD. If so, remove the event from the pending list — // doing so will cause the reporting task to no-op if it runs. - let Some(pending_utd_report) = pending_delayed_lock.remove(event_id) else { + let Some(pending_utd_report) = self.pending_delayed.lock().unwrap().remove(event_id) else { return; }; @@ -209,13 +281,49 @@ impl UtdHookManager { time_to_decrypt: Some(pending_utd_report.marked_utd_at.elapsed()), cause, }; - self.parent.on_utd(info); + Self::report_utd(info, &self.parent, &self.client, &mut reported_utds_lock).await; + } + + /// Helper for [`UtdHookManager::on_utd`] and + /// [`UtdHookManager.on_late_decrypt`]: reports the UTD to the parent, + /// and records the event as reported. + /// + /// Must be called with the lock held on [`UtdHookManager::reported_utds`], + /// and takes a `MutexGuard` to enforce that. + async fn report_utd<'a>( + info: UnableToDecryptInfo, + parent_hook: &Arc, + client: &Client, + reported_utds_lock: &mut MutexGuard<'a, GrowableBloom>, + ) { + let event_id = info.event_id.clone(); + parent_hook.on_utd(info); + reported_utds_lock.insert(event_id); + if let Err(e) = client + .store() + .set_kv_data( + StateStoreDataKey::UtdHookManagerData, + StateStoreDataValue::UtdHookManagerData(reported_utds_lock.clone()), + ) + .await + { + error!("Unable to persist UTD report data: {}", e); + } } } impl Drop for UtdHookManager { fn drop(&mut self) { // Cancel all the outstanding delayed tasks to report UTDs. + // + // Here, we don't take the lock on `reported_utd`s (indeed, we can't, since + // `reported_utds` has an async mutex, and `drop` has to be sync), but + // that's ok. We can't race against `on_utd` or `on_late_decrypt`, since + // they both have `&self` references which mean `drop` can't be called. + // We *could* race against one of the actual tasks to report + // UTDs, but that's ok too: either the report task will bail out when it sees + // the entry has been removed from `pending_delayed` (which is fine), or the + // report task will successfully report the UTD (which is fine). let mut pending_delayed = self.pending_delayed.lock().unwrap(); for (_, pending_utd_report) in pending_delayed.drain() { pending_utd_report.report_task.abort(); @@ -225,6 +333,7 @@ impl Drop for UtdHookManager { #[cfg(test)] mod tests { + use matrix_sdk::test_utils::no_retry_test_client; use matrix_sdk_test::async_test; use ruma::event_id; @@ -241,21 +350,21 @@ mod tests { } } - #[test] - fn test_deduplicates_utds() { + #[async_test] + async fn test_deduplicates_utds() { // If I create a dummy hook, let hook = Arc::new(Dummy::default()); // And I wrap with the UtdHookManager, - let wrapper = UtdHookManager::new(hook.clone()); + let wrapper = UtdHookManager::new(hook.clone(), no_retry_test_client(None).await); // And I call the `on_utd` method multiple times, sometimes on the same event, - wrapper.on_utd(event_id!("$1"), UtdCause::Unknown); - wrapper.on_utd(event_id!("$1"), UtdCause::Unknown); - wrapper.on_utd(event_id!("$2"), UtdCause::Unknown); - wrapper.on_utd(event_id!("$1"), UtdCause::Unknown); - wrapper.on_utd(event_id!("$2"), UtdCause::Unknown); - wrapper.on_utd(event_id!("$3"), UtdCause::Unknown); + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; + wrapper.on_utd(event_id!("$2"), UtdCause::Unknown).await; + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; + wrapper.on_utd(event_id!("$2"), UtdCause::Unknown).await; + wrapper.on_utd(event_id!("$3"), UtdCause::Unknown).await; // Then the event ids have been deduplicated, { @@ -272,32 +381,117 @@ mod tests { } } - #[test] - fn test_on_late_decrypted_no_effect() { + #[async_test] + async fn test_deduplicates_utds_from_previous_session() { + // Use a single client for both hooks, so that both hooks are backed by the same + // memorystore. + let client = no_retry_test_client(None).await; + + // Dummy hook 1, with the first UtdHookManager + { + let hook = Arc::new(Dummy::default()); + let wrapper = UtdHookManager::new(hook.clone(), client.clone()); + + // I call it a couple of times with different events + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; + wrapper.on_utd(event_id!("$2"), UtdCause::Unknown).await; + + // Sanity-check the reported event IDs + { + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 2); + assert_eq!(utds[0].event_id, event_id!("$1")); + assert!(utds[0].time_to_decrypt.is_none()); + assert_eq!(utds[1].event_id, event_id!("$2")); + assert!(utds[1].time_to_decrypt.is_none()); + } + } + + // Now, create a *new* hook, with a *new* UtdHookManager + { + let hook = Arc::new(Dummy::default()); + let mut wrapper = UtdHookManager::new(hook.clone(), client.clone()); + wrapper.reload_from_store().await.unwrap(); + + // Call it with more events, some of which match the previous instance + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; + wrapper.on_utd(event_id!("$3"), UtdCause::Unknown).await; + + // Only the *new* ones should be reported + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 1); + assert_eq!(utds[0].event_id, event_id!("$3")); + } + } + + /// Test that UTD events which had not yet been reported in a previous + /// session, are reported in the next session. + #[async_test] + async fn test_does_not_deduplicate_late_utds_from_previous_session() { + // Use a single client for both hooks, so that both hooks are backed by the same + // memorystore. + let client = no_retry_test_client(None).await; + + // Dummy hook 1, with the first UtdHookManager + { + let hook = Arc::new(Dummy::default()); + let wrapper = UtdHookManager::new(hook.clone(), client.clone()) + .with_max_delay(Duration::from_secs(2)); + + // a UTD event + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; + + // The event ID should not yet have been reported. + { + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 0); + } + } + + // Now, create a *new* hook, with a *new* UtdHookManager + { + let hook = Arc::new(Dummy::default()); + let mut wrapper = UtdHookManager::new(hook.clone(), client.clone()); + wrapper.reload_from_store().await.unwrap(); + + // Call the new hook with the same event + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; + + // And it should be reported. + sleep(Duration::from_millis(2500)).await; + + let utds = hook.utds.lock().unwrap(); + assert_eq!(utds.len(), 1); + assert_eq!(utds[0].event_id, event_id!("$1")); + } + } + + #[async_test] + async fn test_on_late_decrypted_no_effect() { // If I create a dummy hook, let hook = Arc::new(Dummy::default()); // And I wrap with the UtdHookManager, - let wrapper = UtdHookManager::new(hook.clone()); + let wrapper = UtdHookManager::new(hook.clone(), no_retry_test_client(None).await); // And I call the `on_late_decrypt` method before the event had been marked as // utd, - wrapper.on_late_decrypt(event_id!("$1"), UtdCause::Unknown); + wrapper.on_late_decrypt(event_id!("$1"), UtdCause::Unknown).await; // Then nothing is registered in the parent hook. assert!(hook.utds.lock().unwrap().is_empty()); } - #[test] - fn test_on_late_decrypted_after_utd_no_grace_period() { + #[async_test] + async fn test_on_late_decrypted_after_utd_no_grace_period() { // If I create a dummy hook, let hook = Arc::new(Dummy::default()); // And I wrap with the UtdHookManager, - let wrapper = UtdHookManager::new(hook.clone()); + let wrapper = UtdHookManager::new(hook.clone(), no_retry_test_client(None).await); // And I call the `on_utd` method for an event, - wrapper.on_utd(event_id!("$1"), UtdCause::Unknown); + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; // Then the UTD has been notified, but not as late-decrypted event. { @@ -308,7 +502,7 @@ mod tests { } // And when I call the `on_late_decrypt` method, - wrapper.on_late_decrypt(event_id!("$1"), UtdCause::Unknown); + wrapper.on_late_decrypt(event_id!("$1"), UtdCause::Unknown).await; // Then the event is not reported again as a late-decryption. { @@ -329,10 +523,11 @@ mod tests { // And I wrap with the UtdHookManager, configured to delay reporting after 2 // seconds. - let wrapper = UtdHookManager::new(hook.clone()).with_max_delay(Duration::from_secs(2)); + let wrapper = UtdHookManager::new(hook.clone(), no_retry_test_client(None).await) + .with_max_delay(Duration::from_secs(2)); // And I call the `on_utd` method for an event, - wrapper.on_utd(event_id!("$1"), UtdCause::Unknown); + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; // Then the UTD is not being reported immediately. assert!(hook.utds.lock().unwrap().is_empty()); @@ -365,10 +560,11 @@ mod tests { // And I wrap with the UtdHookManager, configured to delay reporting after 2 // seconds. - let wrapper = UtdHookManager::new(hook.clone()).with_max_delay(Duration::from_secs(2)); + let wrapper = UtdHookManager::new(hook.clone(), no_retry_test_client(None).await) + .with_max_delay(Duration::from_secs(2)); // And I call the `on_utd` method for an event, - wrapper.on_utd(event_id!("$1"), UtdCause::Unknown); + wrapper.on_utd(event_id!("$1"), UtdCause::Unknown).await; // Then the UTD has not been notified quite yet. assert!(hook.utds.lock().unwrap().is_empty()); @@ -377,7 +573,7 @@ mod tests { // If I wait for 1 second, and mark the event as late-decrypted, sleep(Duration::from_secs(1)).await; - wrapper.on_late_decrypt(event_id!("$1"), UtdCause::Unknown); + wrapper.on_late_decrypt(event_id!("$1"), UtdCause::Unknown).await; // Then it's being immediately reported as a late-decryption UTD. {