From 0154cccd4b0c05500ff6111ffb8e992b467e22e7 Mon Sep 17 00:00:00 2001 From: valere Date: Thu, 16 Feb 2023 17:25:08 +0100 Subject: [PATCH] feat(withheld) Remember no_olm and implement in memory store --- .../src/session_manager/group_sessions.rs | 94 +++++++++++++++++-- .../src/store/memorystore.rs | 61 +++++++++++- crates/matrix-sdk-crypto/src/store/mod.rs | 4 +- crates/matrix-sdk-crypto/src/store/traits.rs | 7 +- .../src/types/events/room_key_withheld.rs | 1 - crates/matrix-sdk-sled/src/crypto_store.rs | 13 ++- crates/matrix-sdk-sqlite/src/crypto_store.rs | 9 ++ 7 files changed, 175 insertions(+), 14 deletions(-) diff --git a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs index 2da800f4d..ed25bfd24 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs @@ -635,19 +635,26 @@ impl GroupSessionManager { }) .collect(); - no_olm.iter().for_each(|(u, list)| { - list.into_iter().for_each(|d| { - flat_with_code.push((u.to_owned(), d.to_owned(), WithheldCode::NoOlm)) - }); - }); + for (u, list_devices) in no_olm { + for dev in list_devices { + let no_olm_sent = + self.store.is_no_olm_sent(u.to_owned(), dev.to_owned()).await.unwrap_or(false); + if !no_olm_sent { + flat_with_code.push((u.to_owned(), dev.to_owned(), WithheldCode::NoOlm)); + changes + .no_olm_sent + .entry(u.to_owned()) + .or_insert_with(Vec::new) + .push(dev.to_owned()); + } + } + } let withheld_to_device_request: Vec> = flat_with_code .chunks(Self::MAX_TO_DEVICE_MESSAGES) .map(|chunk| { let mut messages = BTreeMap::new(); - // TODO do not resend code if already sent for that session (or target if - // no_olm) chunk.into_iter().for_each(|(user_id, device_id, code)| { let content = RoomKeyWithheldContent::create( algorithm.to_owned(), @@ -947,6 +954,77 @@ mod tests { .map(|r| r.message_count()) .sum(); assert_eq!(withheld_count, 2); + + requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).for_each(|r| { + println!("{:?}", r.messages); + }) + } + + #[async_test] + async fn test_no_olm_sent_once() { + let machine = machine().await; + let keys_claim = keys_claim_response(); + + let users = keys_claim.one_time_keys.keys().map(Deref::deref); + + let first_room_id = room_id!("!test:localhost"); + let requests = machine + .share_room_key(first_room_id, users, EncryptionSettings::default()) + .await + .unwrap(); + + // there will be two no_olm + let withheld_count: usize = requests + .iter() + .filter(|r| r.event_type == "m.room_key.withheld".into()) + .map(|r| { + let mut count = 0; + // count targets + for (_, message) in &r.messages { + message.iter().for_each(|(_, content)| { + let withheld: MegolmV1AesSha2WithheldContent = + content.deserialize_as::().unwrap(); + if withheld.code == WithheldCode::NoOlm { + count += 1; + } + }) + } + count + }) + .sum(); + assert_eq!(withheld_count, 2); + + // Use another session in an other room + let second_room_id = room_id!("!other:localhost"); + let users = keys_claim.one_time_keys.keys().map(Deref::deref); + let requests = machine + .share_room_key(second_room_id, users, EncryptionSettings::default()) + .await + .unwrap(); + + // there will be two no_olm + let withheld_count: usize = requests + .iter() + .filter(|r| r.event_type == "m.room_key.withheld".into()) + .map(|r| { + let mut count = 0; + // count targets + for (_, message) in &r.messages { + message.iter().for_each(|(_, content)| { + let withheld: MegolmV1AesSha2WithheldContent = + content.deserialize_as::().unwrap(); + if withheld.code == WithheldCode::NoOlm { + count += 1; + } + }) + } + count + }) + .sum(); + assert_eq!(withheld_count, 0); + + // Help how do I simulate the creation of a new session for the device + // with no session now? } #[async_test] @@ -1168,6 +1246,6 @@ mod tests { withheld.code == WithheldCode::Blacklisted }); - assert!(has_blacklist) + assert!(has_blacklist); } } diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index 87b3e39f8..fd342651c 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, convert::Infallible, sync::Arc}; +use std::{collections::HashMap, convert::Infallible, ops::Deref, sync::Arc}; use async_trait::async_trait; use dashmap::{DashMap, DashSet}; @@ -51,6 +51,7 @@ pub struct MemoryStore { identities: Arc>, outgoing_key_requests: Arc>, key_requests_by_info: Arc>, + no_olm_sent: Arc>>, } impl Default for MemoryStore { @@ -63,6 +64,7 @@ impl Default for MemoryStore { identities: Default::default(), outgoing_key_requests: Default::default(), key_requests_by_info: Default::default(), + no_olm_sent: Default::default(), } } } @@ -88,6 +90,11 @@ impl MemoryStore { async fn save_sessions(&self, sessions: Vec) { for session in sessions { let _ = self.sessions.add(session.clone()).await; + // we have a new session for that device so we can forgot no_olm + self.no_olm_sent + .entry(session.user_id.deref().to_owned()) + .or_default() + .remove(session.device_id.deref()); } } @@ -144,6 +151,13 @@ impl CryptoStore for MemoryStore { self.key_requests_by_info.insert(info_string, id); } + for (user_id, device_id_list) in changes.no_olm_sent { + self.no_olm_sent + .entry(user_id) + .or_insert_with(DashSet::new) + .extend(device_id_list.into_iter()); + } + Ok(()) } @@ -270,12 +284,27 @@ impl CryptoStore for MemoryStore { async fn load_backup_keys(&self) -> Result { Ok(BackupKeys::default()) } + + async fn is_no_olm_sent(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + Ok(self.no_olm_sent.entry(user_id).or_default().contains(&device_id)) + } + + /* async fn forget_no_olm_sent( + &self, + user_id: OwnedUserId, + device_id: OwnedDeviceId, + ) -> Result<()> { + self.no_olm_sent.entry(user_id).or_default().remove(&device_id); + Ok(()) + } + + */ } #[cfg(test)] mod tests { use matrix_sdk_test::async_test; - use ruma::room_id; + use ruma::{device_id, room_id, user_id}; use vodozemac::{Curve25519PublicKey, Ed25519PublicKey}; use crate::{ @@ -366,4 +395,32 @@ mod tests { store.save_changes(changes).await.unwrap(); assert!(store.is_message_known(&hash).await.unwrap()); } + + #[async_test] + async fn test_no_olm_withheld_store() { + let store = MemoryStore::new(); + + let user_id = user_id!("@alice:example.com"); + let device_id_1 = device_id!("DEV0001"); + let device_id_2 = device_id!("DEV0002"); + + let mut changes = Changes::default(); + changes + .no_olm_sent + .entry(user_id.to_owned()) + .or_insert_with(Vec::new) + .push(device_id_1.to_owned()); + + store.save_changes(changes).await.unwrap(); + + let is_withheld_for = + store.is_no_olm_sent(user_id.to_owned(), device_id_1.to_owned()).await.unwrap(); + + assert_eq!(true, is_withheld_for); + + let is_withheld_for = + store.is_no_olm_sent(user_id.to_owned(), device_id_2.to_owned()).await.unwrap(); + + assert_eq!(false, is_withheld_for); + } } diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 7a35c1284..5451d28b1 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -39,7 +39,7 @@ //! [`CryptoStore`]: trait.Cryptostore.html use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, fmt::Debug, ops::Deref, sync::{atomic::AtomicBool, Arc}, @@ -117,6 +117,7 @@ pub struct Changes { pub key_requests: Vec, pub identities: IdentityChanges, pub devices: DeviceChanges, + pub no_olm_sent: BTreeMap>, } /// A user for which we are tracking the list of devices. @@ -143,6 +144,7 @@ impl Changes { && self.key_requests.is_empty() && self.identities.is_empty() && self.devices.is_empty() + && self.no_olm_sent.is_empty() } } diff --git a/crates/matrix-sdk-crypto/src/store/traits.rs b/crates/matrix-sdk-crypto/src/store/traits.rs index 690abe217..6c38af86d 100644 --- a/crates/matrix-sdk-crypto/src/store/traits.rs +++ b/crates/matrix-sdk-crypto/src/store/traits.rs @@ -16,7 +16,7 @@ use std::{collections::HashMap, fmt, sync::Arc}; use async_trait::async_trait; use matrix_sdk_common::{locks::Mutex, AsyncTraitDeps}; -use ruma::{DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId}; +use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId}; use super::{BackupKeys, Changes, CryptoStoreError, Result, RoomKeyCounts}; use crate::{ @@ -186,6 +186,11 @@ pub trait CryptoStore: AsyncTraitDeps { &self, request_id: &TransactionId, ) -> Result<(), Self::Error>; + /// Remembers when a no_olm withheld message was already sent to avoid doing + /// it again. + /// Help: Using user/device as key here because it's sent in clear, but + /// should we use the sender_key? + async fn is_no_olm_sent(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result; } #[repr(transparent)] diff --git a/crates/matrix-sdk-crypto/src/types/events/room_key_withheld.rs b/crates/matrix-sdk-crypto/src/types/events/room_key_withheld.rs index d1fab90a4..9d97388a5 100644 --- a/crates/matrix-sdk-crypto/src/types/events/room_key_withheld.rs +++ b/crates/matrix-sdk-crypto/src/types/events/room_key_withheld.rs @@ -388,7 +388,6 @@ pub(super) mod test { #[test] fn no_olm_should_not_have_room_and_session() { let room_id = room_id!("!DwLygpkclUAfQNnfva:localhost:8481"); - let user_id = user_id!("@alice:example.org"); let device_id = device_id!("DEV001"); let sender_key = Curve25519PublicKey::from_base64("9n7mdWKOjr9c4NTlG6zV8dbFtNK79q9vZADoh7nMUwA") diff --git a/crates/matrix-sdk-sled/src/crypto_store.rs b/crates/matrix-sdk-sled/src/crypto_store.rs index 4b654824e..500ab698c 100644 --- a/crates/matrix-sdk-sled/src/crypto_store.rs +++ b/crates/matrix-sdk-sled/src/crypto_store.rs @@ -35,7 +35,7 @@ use matrix_sdk_crypto::{ TrackedUser, }; use matrix_sdk_store_encryption::StoreCipher; -use ruma::{DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId}; +use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId}; use serde::{de::DeserializeOwned, Serialize}; use sled::{ transaction::{ConflictableTransactionError, TransactionError}, @@ -58,6 +58,7 @@ const INBOUND_GROUP_TABLE_NAME: &str = "crypto-store-inbound-group-sessions"; const OUTBOUND_GROUP_TABLE_NAME: &str = "crypto-store-outbound-group-sessions"; const SECRET_REQUEST_BY_INFO_TABLE: &str = "crypto-store-secret-request-by-info"; const TRACKED_USERS_TABLE: &str = "crypto-store-secret-tracked-users"; +const NO_OLM_SENT_TABLE: &str = "crypto-store-no-olm-sent"; impl EncodeKey for InboundGroupSession { fn encode(&self) -> Vec { @@ -185,6 +186,8 @@ pub struct SledCryptoStore { identities: Tree, tracked_users: Tree, + + no_olm_sent: Tree, } impl std::fmt::Debug for SledCryptoStore { @@ -388,6 +391,8 @@ impl SledCryptoStore { let session_cache = SessionStore::new(); + let no_olm_sent = db.open_tree("no_olm_sent")?; + let database = Self { account_info: RwLock::new(None).into(), path, @@ -406,6 +411,7 @@ impl SledCryptoStore { tracked_users, olm_hashes, identities, + no_olm_sent, }; database.upgrade().await?; @@ -1010,6 +1016,11 @@ impl CryptoStore for SledCryptoStore { Ok(key) } + + async fn is_no_olm_sent(&self, user_id: OwnedUserId, device_id: OwnedDeviceId) -> Result { + // TODO, help not sure how to to that? + Ok(false) + } } #[cfg(test)] diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index 2b5523c18..f4dac0e10 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -946,6 +946,15 @@ impl CryptoStore for SqliteCryptoStore { let request_id = self.encode_key("key_requests", request_id.as_bytes()); Ok(self.acquire().await?.delete_key_request(request_id).await?) } + + async fn is_no_olm_sent( + &self, + user_id: OwnedUserId, + device_id: OwnedDeviceId, + ) -> StoreResult { + // TODO help I need some guidance here? + Ok(false) + } } #[cfg(test)]