From 486b6d6e2b2d8d6aeb7d02c56a43ebc38c0e86f6 Mon Sep 17 00:00:00 2001 From: Andy Balaam Date: Fri, 15 Mar 2024 11:43:37 +0000 Subject: [PATCH] crypto: Save outbound sessions in MemoryStore --- .../src/store/memorystore.rs | 41 +++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index f75b2155f..1e64a5253 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -54,6 +54,7 @@ pub struct MemoryStore { account: StdRwLock>, sessions: SessionStore, inbound_group_sessions: GroupSessionStore, + outbound_group_sessions: StdRwLock>, olm_hashes: StdRwLock>>, devices: DeviceStore, identities: StdRwLock>, @@ -74,6 +75,7 @@ impl Default for MemoryStore { account: Default::default(), sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), + outbound_group_sessions: Default::default(), olm_hashes: Default::default(), devices: DeviceStore::new(), identities: Default::default(), @@ -119,6 +121,10 @@ impl MemoryStore { self.inbound_group_sessions.add(session); } } + + fn save_outbound_group_sessions(&self, mut sessions: Vec) { + self.outbound_group_sessions.write().unwrap().append(&mut sessions); + } } type Result = std::result::Result; @@ -151,6 +157,7 @@ impl CryptoStore for MemoryStore { async fn save_changes(&self, changes: Changes) -> Result<()> { self.save_sessions(changes.sessions).await; self.save_inbound_group_sessions(changes.inbound_group_sessions); + self.save_outbound_group_sessions(changes.outbound_group_sessions); self.save_devices(changes.devices.new); self.save_devices(changes.devices.changed); @@ -297,8 +304,17 @@ impl CryptoStore for MemoryStore { Ok(self.backup_keys.read().await.to_owned()) } - async fn get_outbound_group_session(&self, _: &RoomId) -> Result> { - Ok(None) + async fn get_outbound_group_session( + &self, + room_id: &RoomId, + ) -> Result> { + Ok(self + .outbound_group_sessions + .read() + .unwrap() + .iter() + .find(|session| session.room_id() == room_id) + .cloned()) } async fn load_tracked_users(&self) -> Result> { @@ -487,7 +503,7 @@ mod tests { } #[async_test] - async fn test_group_session_store() { + async fn test_inbound_group_session_store() { let (account, _) = get_account_and_session_test_helper(); let room_id = room_id!("!test:localhost"); let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw"; @@ -511,6 +527,25 @@ mod tests { assert_eq!(inbound, loaded_session); } + #[async_test] + async fn test_outbound_group_session_store() { + // Given an outbound sessions + let (account, _) = get_account_and_session_test_helper(); + let room_id = room_id!("!test:localhost"); + let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await; + + // When we save it to the store + let store = MemoryStore::new(); + store.save_outbound_group_sessions(vec![outbound.clone()]); + + // Then we can get it out again + let loaded_session = store.get_outbound_group_session(room_id).await.unwrap().unwrap(); + assert_eq!( + serde_json::to_string(&outbound.pickle().await).unwrap(), + serde_json::to_string(&loaded_session.pickle().await).unwrap() + ); + } + #[async_test] async fn test_device_store() { let device = get_device();