diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index a45d6e99e..54339756c 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -56,7 +56,7 @@ pub struct MemoryStore { inbound_group_sessions: GroupSessionStore, outbound_group_sessions: StdRwLock>, private_identity: StdRwLock>, - tracked_users: StdRwLock>, + tracked_users: StdRwLock>, olm_hashes: StdRwLock>>, devices: DeviceStore, identities: StdRwLock>, @@ -324,12 +324,13 @@ impl CryptoStore for MemoryStore { } async fn load_tracked_users(&self) -> Result> { - Ok(self.tracked_users.read().unwrap().clone()) + Ok(self.tracked_users.read().unwrap().values().cloned().collect()) } async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> { self.tracked_users.write().unwrap().extend(tracked_users.iter().map(|(user_id, dirty)| { - TrackedUser { user_id: user_id.to_owned().into(), dirty: *dirty } + let user_id: OwnedUserId = user_id.to_owned().into(); + (user_id.clone(), TrackedUser { user_id, dirty: *dirty }) })); Ok(()) } @@ -559,23 +560,29 @@ mod tests { } #[async_test] - async fn test_tracked_users_store() { - // Given some tracked users - let tracked_users = - &[(user_id!("@dirty_user:s"), true), (user_id!("@clean_user:t"), false)]; - - // When we save them to the store + async fn test_tracked_users_are_stored_once_per_user_id() { + // Given a store containing 2 tracked users, both dirty + let user1 = user_id!("@user1:s"); + let user2 = user_id!("@user2:s"); + let user3 = user_id!("@user3:s"); let store = MemoryStore::new(); - store.save_tracked_users(tracked_users).await.unwrap(); + store.save_tracked_users(&[(user1, true), (user2, true)]).await.unwrap(); - // Then we can get them out again + // When we mark one as clean and add another + store.save_tracked_users(&[(user2, false), (user3, false)]).await.unwrap(); + + // Then we can get them out again and their dirty flags are correct let loaded_tracked_users = store.load_tracked_users().await.expect("failed to load tracked users"); - assert_eq!(loaded_tracked_users[0].user_id, user_id!("@dirty_user:s")); - assert!(loaded_tracked_users[0].dirty); - assert_eq!(loaded_tracked_users[1].user_id, user_id!("@clean_user:t")); - assert!(!loaded_tracked_users[1].dirty); - assert_eq!(loaded_tracked_users.len(), 2); + + let tracked_contains = |user_id, dirty| { + loaded_tracked_users.iter().any(|u| u.user_id == user_id && u.dirty == dirty) + }; + + assert!(tracked_contains(user1, true)); + assert!(tracked_contains(user2, false)); + assert!(tracked_contains(user3, false)); + assert_eq!(loaded_tracked_users.len(), 3); } #[async_test] diff --git a/crates/matrix-sdk-crypto/src/store/traits.rs b/crates/matrix-sdk-crypto/src/store/traits.rs index bebcdd03b..3347f373d 100644 --- a/crates/matrix-sdk-crypto/src/store/traits.rs +++ b/crates/matrix-sdk-crypto/src/store/traits.rs @@ -133,11 +133,14 @@ pub trait CryptoStore: AsyncTraitDeps { room_id: &RoomId, ) -> Result, Self::Error>; - /// Load the list of users whose devices we are keeping track of. + /// Provide the list of users whose devices we are keeping track of, and + /// whether they are considered dirty/outdated. async fn load_tracked_users(&self) -> Result, Self::Error>; - /// Save a list of users and their respective dirty/outdated flags to the - /// store. + /// Update the list of users whose devices we are keeping track of, and + /// whether they are considered dirty/outdated. + /// + /// Replaces any existing entry with a matching user ID. async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), Self::Error>; /// Get the device for the given user with the given device ID.