diff --git a/crates/matrix-sdk-base/src/rooms/normal.rs b/crates/matrix-sdk-base/src/rooms/normal.rs index 6b7aa75c8..bd6da1983 100644 --- a/crates/matrix-sdk-base/src/rooms/normal.rs +++ b/crates/matrix-sdk-base/src/rooms/normal.rs @@ -15,7 +15,7 @@ #[cfg(all(feature = "e2e-encryption", feature = "experimental-sliding-sync"))] use std::sync::RwLock as SyncRwLock; use std::{ - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashSet}, mem, sync::{atomic::AtomicBool, Arc}, }; @@ -1222,20 +1222,60 @@ impl Room { } let current_seen_events_guard = self.get_write_guarded_current_knock_request_ids().await?; - // We're not calling `get_seen_join_request_ids` here because we need to keep - // the Mutex's guard until we've updated the data let mut current_seen_events = current_seen_events_guard.clone().unwrap_or_default(); current_seen_events.extend(event_to_user_ids); self.update_seen_knock_request_ids(current_seen_events_guard, current_seen_events).await?; - self.store - .set_kv_data( - StateStoreDataKey::SeenKnockRequests(self.room_id()), - StateStoreDataValue::SeenKnockRequests(current_seen_events), - ) - .await?; + Ok(()) + } + + /// Removes the seen knock request ids that are no longer valid given the + /// current room members. + pub async fn remove_outdated_seen_knock_requests_ids(&self) -> StoreResult<()> { + let current_seen_events_guard = self.get_write_guarded_current_knock_request_ids().await?; + let mut current_seen_events = current_seen_events_guard.clone().unwrap_or_default(); + + // Get and deserialize the member events for the seen knock requests + let keys: Vec = current_seen_events.values().map(|id| id.to_owned()).collect(); + let raw_member_events: Vec = + self.store.get_state_events_for_keys_static(self.room_id(), &keys).await?; + let member_events = raw_member_events + .into_iter() + .map(|raw| raw.deserialize()) + .collect::, _>>()?; + + let mut ids_to_remove = Vec::new(); + + for (event_id, user_id) in current_seen_events.iter() { + // Check the seen knock request ids against the current room member events for + // the room members associated to them + let matching_member = member_events.iter().find(|event| event.user_id() == user_id); + + if let Some(member) = matching_member { + let member_event_id = member.event_id(); + // If the member event is not a knock or it's different knock, it's outdated + if *member.membership() != MembershipState::Knock + || member_event_id.is_some_and(|id| id != event_id) + { + ids_to_remove.push(event_id.to_owned()); + } + } else { + ids_to_remove.push(event_id.to_owned()); + } + } + + // If there are no ids to remove, do nothing + if ids_to_remove.is_empty() { + return Ok(()); + } + + for event_id in ids_to_remove { + current_seen_events.remove(&event_id); + } + + self.update_seen_knock_request_ids(current_seen_events_guard, current_seen_events).await?; Ok(()) } @@ -2101,7 +2141,6 @@ mod tests { use std::{ collections::BTreeSet, ops::{Not, Sub}, - pin::pin, str::FromStr, sync::Arc, time::Duration, @@ -2147,7 +2186,13 @@ mod tests { use super::{compute_display_name_from_heroes, Room, RoomHero, RoomInfo, RoomState, SyncInfo}; #[cfg(any(feature = "experimental-sliding-sync", feature = "e2e-encryption"))] use crate::latest_event::LatestEvent; - use crate::{rooms::RoomNotableTags, store::{IntoStateStore, MemoryStore, StateChanges, StateStore, StoreConfig}, test_utils::logged_in_base_client, BaseClient, MinimalStateEvent, OriginalMinimalStateEvent, RoomDisplayName, RoomInfoNotableUpdateReasons, RoomMembersUpdate, RoomStateFilter, SessionMeta}; + use crate::{ + rooms::RoomNotableTags, + store::{IntoStateStore, MemoryStore, StateChanges, StateStore, StoreConfig}, + test_utils::logged_in_base_client, + BaseClient, MinimalStateEvent, OriginalMinimalStateEvent, RoomDisplayName, + RoomInfoNotableUpdateReasons, RoomStateFilter, SessionMeta, + }; #[test] #[cfg(feature = "experimental-sliding-sync")] @@ -3739,27 +3784,4 @@ mod tests { ] ); } - - #[async_test] - async fn test_room_member_updates_sender_and_receiver() { - use assert_matches::assert_matches; - - let client = logged_in_base_client(None).await; - let room = client.get_or_create_room(room_id!("!a:b.c"), RoomState::Joined); - - let mut receiver = room.room_member_updates_sender.subscribe(); - - assert!(receiver.is_empty()); - - room.room_member_updates_sender - .send(RoomMembersUpdate::FullReload) - .expect("broadcasting a room members update failed"); - - let recv = pin!(receiver.recv()); - let next = matrix_sdk_common::timeout::timeout(recv, Duration::from_secs(1)) - .await - .expect("receiving a room members update timed out") - .expect("failed receiving a room members update"); - assert_matches!(next, RoomMembersUpdate::FullReload); - } } diff --git a/crates/matrix-sdk/tests/integration/room/joined.rs b/crates/matrix-sdk/tests/integration/room/joined.rs index 4e6ca00de..76f9b38a0 100644 --- a/crates/matrix-sdk/tests/integration/room/joined.rs +++ b/crates/matrix-sdk/tests/integration/room/joined.rs @@ -947,4 +947,109 @@ async fn test_subscribe_to_requests_to_join_reloads_members_on_limited_sync() { // There should be no other knock requests assert_pending!(stream) + assert_pending!(stream); + + handle.abort(); +} + +#[async_test] +async fn test_remove_outdated_seen_knock_requests_ids_when_membership_changed() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + + server.mock_room_state_encryption().plain().mount().await; + + let room_id = room_id!("!a:b.c"); + let f = EventFactory::new().room(room_id); + + let user_id = user_id!("@alice:b.c"); + let knock_event_id = event_id!("$alice-knock:b.c"); + let knock_event = f + .event(RoomMemberEventContent::new(MembershipState::Knock)) + .event_id(knock_event_id) + .sender(user_id) + .state_key(user_id) + .into_raw_timeline() + .cast(); + + // When syncing the room, we'll have a knock request coming from alice + let room = server.sync_room(&client, JoinedRoomBuilder::new(room_id).add_state_bulk(vec![knock_event])).await; + + // We then mark the knock request as seen + room.mark_knock_requests_as_seen(&[user_id.to_owned()]).await.unwrap(); + + // Now it's received again as seen + let seen = room.get_seen_knock_request_ids().await.unwrap(); + assert_eq!(seen.len(), 1); + + // If we then load the members again and the previously knocking member is in + // another state now + let joined_event = f + .event(RoomMemberEventContent::new(MembershipState::Join)) + .sender(user_id) + .state_key(user_id) + .into_raw_timeline() + .cast(); + + server.mock_get_members().ok(vec![joined_event]).mock_once().mount().await; + + room.mark_members_missing(); + room.sync_members().await.expect("could not reload room members"); + + // Calling remove outdated seen knock request ids will remove the seen id + room.remove_outdated_seen_knock_requests_ids().await.expect("could not remove outdated seen knock request ids"); + + let seen = room.get_seen_knock_request_ids().await.unwrap(); + assert!(seen.is_empty()); +} + +#[async_test] +async fn test_remove_outdated_seen_knock_requests_ids_when_we_have_an_outdated_knock() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + + server.mock_room_state_encryption().plain().mount().await; + + let room_id = room_id!("!a:b.c"); + let f = EventFactory::new().room(room_id); + + let user_id = user_id!("@alice:b.c"); + let knock_event_id = event_id!("$alice-knock:b.c"); + let knock_event = f + .event(RoomMemberEventContent::new(MembershipState::Knock)) + .event_id(knock_event_id) + .sender(user_id) + .state_key(user_id) + .into_raw_timeline() + .cast(); + + // When syncing the room, we'll have a knock request coming from alice + let room = server.sync_room(&client, JoinedRoomBuilder::new(room_id).add_state_bulk(vec![knock_event])).await; + + // We then mark the knock request as seen + room.mark_knock_requests_as_seen(&[user_id.to_owned()]).await.unwrap(); + + // Now it's received again as seen + let seen = room.get_seen_knock_request_ids().await.unwrap(); + assert_eq!(seen.len(), 1); + + // If we then load the members again and the previously knocking member has a different event id + let knock_event = f + .event(RoomMemberEventContent::new(MembershipState::Knock)) + .event_id(event_id!("$knock-2:b.c")) + .sender(user_id) + .state_key(user_id) + .into_raw_timeline() + .cast(); + + server.mock_get_members().ok(vec![knock_event]).mock_once().mount().await; + + room.mark_members_missing(); + room.sync_members().await.expect("could not reload room members"); + + // Calling remove outdated seen knock request ids will remove the seen id + room.remove_outdated_seen_knock_requests_ids().await.expect("could not remove outdated seen knock request ids"); + + let seen = room.get_seen_knock_request_ids().await.unwrap(); + assert!(seen.is_empty()); }