From e7e15ca280f97a4be0c18a0efb36f2b5d9de836b Mon Sep 17 00:00:00 2001 From: Ivan Enderlin Date: Mon, 7 Apr 2025 15:35:01 +0200 Subject: [PATCH] refactor(base): Create the `tracked_users` response processor. This patch creates the new `tracked_users::update_if_necessary` response processor, and uses it in two places where the code was duplicated. --- crates/matrix-sdk-base/src/client.rs | 44 ++++++------ .../src/response_processors/e2ee/mod.rs | 1 + .../response_processors/e2ee/tracked_users.rs | 69 +++++++++++++++++++ .../src/response_processors/mod.rs | 4 +- crates/matrix-sdk-base/src/sliding_sync.rs | 47 +++++-------- 5 files changed, 112 insertions(+), 53 deletions(-) create mode 100644 crates/matrix-sdk-base/src/response_processors/e2ee/tracked_users.rs diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index 61be1b76c..72dd5ecc4 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -1023,7 +1023,7 @@ impl BaseClient { let (raw_state_events, state_events): (Vec<_>, Vec<_>) = state_events.into_iter().unzip(); - let mut user_ids = self + let mut new_user_ids = self .handle_state( &mut context, &raw_state_events, @@ -1033,7 +1033,7 @@ impl BaseClient { ) .await?; - updated_members_in_room.insert(room_id.to_owned(), user_ids.clone()); + updated_members_in_room.insert(room_id.to_owned(), new_user_ids.clone()); for raw in &new_info.ephemeral.events { match raw.deserialize() { @@ -1065,7 +1065,7 @@ impl BaseClient { false, new_info.timeline.prev_batch, &push_rules, - &mut user_ids, + &mut new_user_ids, &mut room_info, &mut notifications, &mut ambiguity_cache, @@ -1086,22 +1086,16 @@ impl BaseClient { let mut room_info = context.state_changes.room_infos.get(&room_id).unwrap().clone(); #[cfg(feature = "e2e-encryption")] - if room_info.encryption_state().is_encrypted() { - if let Some(o) = self.olm_machine().await.as_ref() { - if !room.encryption_state().is_encrypted() { - // The room turned on encryption in this sync, we need - // to also get all the existing users and mark them for - // tracking. - let user_ids = self - .state_store - .get_user_ids(&room_id, RoomMemberships::ACTIVE) - .await?; - o.update_tracked_users(user_ids.iter().map(Deref::deref)).await? - } - - o.update_tracked_users(user_ids.iter().map(Deref::deref)).await?; - } - } + processors::e2ee::tracked_users::update_or_set_if_room_is_newly_encrypted( + &mut context, + olm_machine.as_ref(), + &new_user_ids, + room_info.encryption_state(), + room.encryption_state(), + &room_id, + &self.state_store, + ) + .await?; let notification_count = new_info.unread_notifications.into(); room_info.update_notification_count(notification_count); @@ -1439,11 +1433,13 @@ impl BaseClient { } #[cfg(feature = "e2e-encryption")] - if room.encryption_state().is_encrypted() { - if let Some(o) = self.olm_machine().await.as_ref() { - o.update_tracked_users(user_ids.iter().map(Deref::deref)).await? - } - } + processors::e2ee::tracked_users::update( + &mut context, + self.olm_machine().await.as_ref(), + room.encryption_state(), + &user_ids, + ) + .await?; context.state_changes.ambiguity_maps.insert(room_id.to_owned(), ambiguity_map); diff --git a/crates/matrix-sdk-base/src/response_processors/e2ee/mod.rs b/crates/matrix-sdk-base/src/response_processors/e2ee/mod.rs index e69e6ac05..5512fdc39 100644 --- a/crates/matrix-sdk-base/src/response_processors/e2ee/mod.rs +++ b/crates/matrix-sdk-base/src/response_processors/e2ee/mod.rs @@ -13,3 +13,4 @@ // limitations under the License. pub mod to_device; +pub mod tracked_users; diff --git a/crates/matrix-sdk-base/src/response_processors/e2ee/tracked_users.rs b/crates/matrix-sdk-base/src/response_processors/e2ee/tracked_users.rs new file mode 100644 index 000000000..cf221dbc2 --- /dev/null +++ b/crates/matrix-sdk-base/src/response_processors/e2ee/tracked_users.rs @@ -0,0 +1,69 @@ +// Copyright 2025 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::BTreeSet; + +use matrix_sdk_crypto::OlmMachine; +use ruma::{OwnedUserId, RoomId}; + +use super::super::Context; +use crate::{store::BaseStateStore, EncryptionState, Result, RoomMemberships}; + +/// Update tracked users, if the room is encrypted. +pub async fn update( + _context: &mut Context, + olm_machine: Option<&OlmMachine>, + room_encryption_state: EncryptionState, + user_ids_to_track: &BTreeSet, +) -> Result<()> { + if room_encryption_state.is_encrypted() { + if let Some(olm) = olm_machine { + if !user_ids_to_track.is_empty() { + olm.update_tracked_users(user_ids_to_track.iter().map(AsRef::as_ref)).await? + } + } + } + + Ok(()) +} + +/// Update tracked users, if the room is encrypted, or if the room has become +/// encrypted. +pub async fn update_or_set_if_room_is_newly_encrypted( + _context: &mut Context, + olm_machine: Option<&OlmMachine>, + user_ids_to_track: &BTreeSet, + new_room_encryption_state: EncryptionState, + previous_room_encryption_state: EncryptionState, + room_id: &RoomId, + state_store: &BaseStateStore, +) -> Result<()> { + if new_room_encryption_state.is_encrypted() { + if let Some(olm) = olm_machine { + if !previous_room_encryption_state.is_encrypted() { + // The room turned on encryption in this sync, we need + // to also get all the existing users and mark them for + // tracking. + let user_ids = state_store.get_user_ids(room_id, RoomMemberships::ACTIVE).await?; + olm.update_tracked_users(user_ids.iter().map(AsRef::as_ref)).await? + } + + if !user_ids_to_track.is_empty() { + olm.update_tracked_users(user_ids_to_track.iter().map(AsRef::as_ref)).await?; + } + } + } + + Ok(()) +} diff --git a/crates/matrix-sdk-base/src/response_processors/mod.rs b/crates/matrix-sdk-base/src/response_processors/mod.rs index 82d587f13..0bd09922d 100644 --- a/crates/matrix-sdk-base/src/response_processors/mod.rs +++ b/crates/matrix-sdk-base/src/response_processors/mod.rs @@ -26,7 +26,9 @@ use std::collections::BTreeMap; #[cfg(feature = "e2e-encryption")] mod with_e2ee { pub use super::{ - e2ee::to_device, latest_event::decrypt_latest_events, verification::verification, + e2ee::{to_device, tracked_users}, + latest_event::decrypt_latest_events, + verification::verification, }; } use ruma::OwnedRoomId; diff --git a/crates/matrix-sdk-base/src/sliding_sync.rs b/crates/matrix-sdk-base/src/sliding_sync.rs index 1939f0546..57103db71 100644 --- a/crates/matrix-sdk-base/src/sliding_sync.rs +++ b/crates/matrix-sdk-base/src/sliding_sync.rs @@ -15,8 +15,6 @@ //! Extend `BaseClient` with capabilities to handle MSC4186. use std::collections::BTreeMap; -#[cfg(feature = "e2e-encryption")] -use std::ops::Deref; #[cfg(feature = "e2e-encryption")] use matrix_sdk_common::deserialized_responses::TimelineEvent; @@ -37,9 +35,12 @@ use ruma::{ use tracing::{instrument, trace, warn}; use super::BaseClient; +#[cfg(feature = "e2e-encryption")] +use crate::latest_event::{is_suitable_for_latest_event, LatestEvent, PossibleLatestEvent}; use crate::{ error::Result, read_receipts::{compute_unread_counts, PreviousEventsProvider}, + response_processors as processors, response_processors::account_data::AccountDataProcessor, rooms::{ normal::{RoomHero, RoomInfoNotableUpdateReasons}, @@ -50,11 +51,6 @@ use crate::{ sync::{JoinedRoomUpdate, LeftRoomUpdate, Notification, RoomUpdates, SyncResponse}, RequestedRequiredStates, Room, RoomInfo, }; -#[cfg(feature = "e2e-encryption")] -use crate::{ - latest_event::{is_suitable_for_latest_event, LatestEvent, PossibleLatestEvent}, - response_processors as processors, RoomMemberships, -}; impl BaseClient { #[cfg(feature = "e2e-encryption")] @@ -357,7 +353,7 @@ impl BaseClient { requested_required_states: &[(StateEventType, String)], room_data: &http::response::Room, rooms_account_data: &mut BTreeMap>>, - store: &BaseStateStore, + state_store: &BaseStateStore, user_id: &UserId, account_data_processor: &AccountDataProcessor, notifications: &mut BTreeMap>, @@ -381,7 +377,7 @@ impl BaseClient { }; // Find or create the room in the store - let is_new_room = !store.room_exists(room_id); + let is_new_room = !state_store.room_exists(room_id); let stripped_state: Option, AnyStrippedStateEvent)>> = room_data @@ -395,7 +391,7 @@ impl BaseClient { context, &state_events, stripped_state.as_ref(), - store, + state_store, user_id, room_id, ); @@ -403,7 +399,7 @@ impl BaseClient { room_info.mark_state_partially_synced(); room_info.handle_encryption_state(requested_required_states); - let mut user_ids = if !state_events.is_empty() { + let mut new_user_ids = if !state_events.is_empty() { self.handle_state( context, &raw_state_events, @@ -442,7 +438,7 @@ impl BaseClient { true, room_data.prev_batch.clone(), &push_rules, - &mut user_ids, + &mut new_user_ids, &mut room_info, notifications, ambiguity_cache, @@ -457,26 +453,21 @@ impl BaseClient { &mut room_info, &timeline.events, Some(&context.state_changes), - Some(store), + Some(state_store), ) .await; #[cfg(feature = "e2e-encryption")] - if room_info.encryption_state().is_encrypted() { - if let Some(o) = self.olm_machine().await.as_ref() { - if !room.encryption_state().is_encrypted() { - // The room turned on encryption in this sync, we need - // to also get all the existing users and mark them for - // tracking. - let user_ids = store.get_user_ids(room_id, RoomMemberships::ACTIVE).await?; - o.update_tracked_users(user_ids.iter().map(Deref::deref)).await? - } - - if !user_ids.is_empty() { - o.update_tracked_users(user_ids.iter().map(Deref::deref)).await?; - } - } - } + processors::e2ee::tracked_users::update_or_set_if_room_is_newly_encrypted( + context, + self.olm_machine().await.as_ref(), + &new_user_ids, + room_info.encryption_state(), + room.encryption_state(), + room_id, + state_store, + ) + .await?; let notification_count = room_data.unread_notifications.clone().into(); room_info.update_notification_count(notification_count);