From dcf00697539321cf4eac5cd4929d45347b947da7 Mon Sep 17 00:00:00 2001 From: Andy Balaam Date: Tue, 20 Feb 2024 14:26:33 +0000 Subject: [PATCH] export: Provide a streamed way to export keys Signed-off-by: Andy Balaam --- crates/matrix-sdk-crypto/CHANGELOG.md | 3 + crates/matrix-sdk-crypto/src/machine.rs | 2 +- crates/matrix-sdk-crypto/src/store/mod.rs | 135 ++++++++++++++++++++++ 3 files changed, 139 insertions(+), 1 deletion(-) diff --git a/crates/matrix-sdk-crypto/CHANGELOG.md b/crates/matrix-sdk-crypto/CHANGELOG.md index 655d7629f..42d99e867 100644 --- a/crates/matrix-sdk-crypto/CHANGELOG.md +++ b/crates/matrix-sdk-crypto/CHANGELOG.md @@ -26,6 +26,9 @@ Additions: to encrypt an event to a specific device. ([#3091](https://github.com/matrix-org/matrix-rust-sdk/pull/3091)) +- Add new API `store::Store::export_room_keys_stream` that provides room + keys on demand. + # 0.7.0 - Add method to mark a list of inbound group sessions as backed up: diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index 80ebba2b6..8274bbb26 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -2317,7 +2317,7 @@ pub(crate) mod tests { (machine, otk) } - async fn get_machine_pair( + pub async fn get_machine_pair( alice: &UserId, bob: &UserId, use_fallback_key: bool, diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index a6a2ceddc..67969a0e2 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -1633,6 +1633,48 @@ impl Store { Ok(exported) } + + /// Export room keys matching a predicate, providing them as an async + /// `Stream`. + /// + /// # Arguments + /// + /// * `predicate` - A closure that will be called for every known + /// `InboundGroupSession`, which represents a room key. If the closure + /// returns `true` the `InboundGroupSession` will be included in the export, + /// if the closure returns `false` it will not be included. + /// + /// # Examples + /// + /// ```no_run + /// use std::pin::pin; + /// + /// use matrix_sdk_crypto::{olm::ExportedRoomKey, OlmMachine}; + /// use ruma::{device_id, room_id, user_id}; + /// use tokio_stream::StreamExt; + /// # async { + /// let alice = user_id!("@alice:example.org"); + /// let machine = OlmMachine::new(&alice, device_id!("DEVICEID")).await; + /// let room_id = room_id!("!test:localhost"); + /// let mut keys = pin!(machine + /// .store() + /// .export_room_keys_stream(|s| s.room_id() == room_id) + /// .await + /// .unwrap()); + /// while let Some(key) = keys.next().await { + /// println!("{}", key.room_id); + /// } + /// # }; + /// ``` + pub async fn export_room_keys_stream( + &self, + predicate: impl FnMut(&InboundGroupSession) -> bool, + ) -> Result> { + // TODO: if/when there is a get_inbound_group_sessions_stream, use that here. + let sessions = self.get_inbound_group_sessions().await?; + Ok(futures_util::stream::iter(sessions.into_iter().filter(predicate)) + .then(|session| async move { session.export().await })) + } } impl Deref for Store { @@ -1661,3 +1703,96 @@ impl matrix_sdk_common::store_locks::BackingStore for LockableCryptoStore { self.0.try_take_leased_lock(lease_duration_ms, key, holder).await } } + +#[cfg(test)] +mod tests { + use std::pin::pin; + + use futures_util::StreamExt; + use matrix_sdk_test::async_test; + use ruma::{room_id, user_id}; + + use crate::{machine::tests::get_machine_pair, types::EventEncryptionAlgorithm}; + + #[async_test] + async fn export_room_keys_provides_selected_keys() { + // Given an OlmMachine with room keys in it + let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await; + let room1_id = room_id!("!room1:localhost"); + let room2_id = room_id!("!room2:localhost"); + let room3_id = room_id!("!room3:localhost"); + alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap(); + alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap(); + alice.create_outbound_group_session_with_defaults_test_helper(room3_id).await.unwrap(); + + // When I export some of the keys + let keys = alice + .store() + .export_room_keys(|s| s.room_id() == room2_id || s.room_id() == room3_id) + .await + .unwrap(); + + // Then the requested keys were provided + assert_eq!(keys.len(), 2); + assert_eq!(keys[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2); + assert_eq!(keys[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2); + assert_eq!(keys[0].room_id, "!room2:localhost"); + assert_eq!(keys[1].room_id, "!room3:localhost"); + assert_eq!(keys[0].session_key.to_base64().len(), 220); + assert_eq!(keys[1].session_key.to_base64().len(), 220); + } + + #[async_test] + async fn export_room_keys_stream_can_provide_all_keys() { + // Given an OlmMachine with room keys in it + let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await; + let room1_id = room_id!("!room1:localhost"); + let room2_id = room_id!("!room2:localhost"); + alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap(); + alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap(); + + // When I export the keys as a stream + let mut keys = pin!(alice.store().export_room_keys_stream(|_| true).await.unwrap()); + + // And collect them + let mut collected = vec![]; + while let Some(key) = keys.next().await { + collected.push(key); + } + + // Then all the keys were provided + assert_eq!(collected.len(), 2); + assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2); + assert_eq!(collected[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2); + assert_eq!(collected[0].room_id, "!room1:localhost"); + assert_eq!(collected[1].room_id, "!room2:localhost"); + assert_eq!(collected[0].session_key.to_base64().len(), 220); + assert_eq!(collected[1].session_key.to_base64().len(), 220); + } + + #[async_test] + async fn export_room_keys_stream_can_provide_a_subset_of_keys() { + // Given an OlmMachine with room keys in it + let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await; + let room1_id = room_id!("!room1:localhost"); + let room2_id = room_id!("!room2:localhost"); + alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap(); + alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap(); + + // When I export the keys as a stream + let mut keys = + pin!(alice.store().export_room_keys_stream(|s| s.room_id() == room1_id).await.unwrap()); + + // And collect them + let mut collected = vec![]; + while let Some(key) = keys.next().await { + collected.push(key); + } + + // Then all the keys matching our predicate were provided, and no others + assert_eq!(collected.len(), 1); + assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2); + assert_eq!(collected[0].room_id, "!room1:localhost"); + assert_eq!(collected[0].session_key.to_base64().len(), 220); + } +}