diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index 2c14e3e3c..48973f236 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -335,6 +335,11 @@ impl Device { self.verification_machine.store.get_sessions(&k.to_base64()).await } + #[cfg(test)] + pub(crate) async fn get_most_recent_session(&self) -> OlmResult> { + self.inner.get_most_recent_session(self.verification_machine.store.inner()).await + } + /// Is this device considered to be verified. /// /// This method returns true if either [`is_locally_trusted()`] returns true @@ -612,6 +617,34 @@ impl ReadOnlyDevice { } } + /// Find and return the most recently created Olm [`Session`] we are sharing + /// with this device. + pub(crate) async fn get_most_recent_session( + &self, + store: &DynCryptoStore, + ) -> OlmResult> { + if let Some(sender_key) = self.curve25519_key() { + if let Some(s) = store.get_sessions(&sender_key.to_base64()).await? { + let mut sessions = s.lock().await; + + sessions.sort_by_key(|s| s.creation_time); + + Ok(sessions.last().cloned()) + } else { + Ok(None) + } + } else { + warn!( + user_id = ?self.user_id(), + device_id = ?self.device_id(), + "Trying to find a Olm session of a device, but the device doesn't have a \ + Curve25519 key", + ); + + Err(EventError::MissingSenderKey.into()) + } + } + /// Does this device support the olm.v2.curve25519-aes-sha2 encryption /// algorithm. #[cfg(feature = "experimental-algorithms")] @@ -680,44 +713,29 @@ impl ReadOnlyDevice { event_type: &str, content: Value, ) -> OlmResult<(Session, Raw)> { - let Some(sender_key) = self.curve25519_key() else { - warn!( + let session = self.get_most_recent_session(store).await?; + + if let Some(mut session) = session { + let message = session.encrypt(self, event_type, content).await?; + + trace!( user_id = ?self.user_id(), device_id = ?self.device_id(), - "Trying to encrypt a Megolm session, but the device doesn't \ - have a curve25519 key", + session_id = session.session_id(), + "Successfully encrypted a Megolm session", ); - return Err(EventError::MissingSenderKey.into()); - }; - let session = if let Some(s) = store.get_sessions(&sender_key.to_base64()).await? { - let mut sessions = s.lock().await; - sessions.sort_by_key(|s| s.last_use_time); - sessions.get(0).cloned() + Ok((session, message)) } else { - None - }; - - let Some(mut session) = session else { warn!( "Trying to encrypt a Megolm session for user {} on device {}, \ but no Olm session is found", self.user_id(), self.device_id() ); - return Err(OlmError::MissingSession); - }; - let message = session.encrypt(self, event_type, content).await?; - - trace!( - user_id = ?self.user_id(), - device_id = ?self.device_id(), - session_id = session.session_id(), - "Successfully encrypted a Megolm session", - ); - - Ok((session, message)) + Err(OlmError::MissingSession) + } } /// Update a device with a new device keys struct. diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index c58159b73..a0399c651 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -1663,7 +1663,12 @@ pub(crate) mod testing { #[cfg(test)] pub(crate) mod tests { - use std::{collections::BTreeMap, iter, sync::Arc}; + use std::{ + collections::BTreeMap, + iter, + sync::Arc, + time::{Duration, SystemTime}, + }; use assert_matches::assert_matches; use futures::select; @@ -1695,7 +1700,7 @@ pub(crate) mod tests { room_id, serde::Raw, uint, user_id, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, - OwnedDeviceKeyId, TransactionId, UserId, + OwnedDeviceKeyId, SecondsSinceUnixEpoch, TransactionId, UserId, }; use serde_json::json; use vodozemac::{ @@ -2027,23 +2032,34 @@ pub(crate) mod tests { assert!(user_sessions.contains_key(alice_device)); } - #[async_test] - async fn test_session_creation() { - let (alice_machine, bob_machine, one_time_keys) = get_machine_pair().await; - - let mut bob_keys = BTreeMap::new(); - - let (device_key_id, one_time_key) = one_time_keys.iter().next().unwrap(); - let mut keys = BTreeMap::new(); - keys.insert(device_key_id.clone(), one_time_key.clone()); - bob_keys.insert(bob_machine.device_id().into(), keys); - - let mut one_time_keys = BTreeMap::new(); - one_time_keys.insert(bob_machine.user_id().to_owned(), bob_keys); - + async fn create_session( + machine: &OlmMachine, + user_id: &UserId, + device_id: &DeviceId, + key_id: OwnedDeviceKeyId, + one_time_key: Raw, + ) { + let keys = BTreeMap::from([(key_id, one_time_key)]); + let keys = BTreeMap::from([(device_id.to_owned(), keys)]); + let one_time_keys = BTreeMap::from([(user_id.to_owned(), keys)]); let response = claim_keys::v3::Response::new(one_time_keys); - alice_machine.receive_keys_claim_response(&response).await.unwrap(); + machine.receive_keys_claim_response(&response).await.unwrap(); + } + + #[async_test] + async fn test_session_creation() { + let (alice_machine, bob_machine, mut one_time_keys) = get_machine_pair().await; + let (device_key_id, one_time_key) = one_time_keys.pop_first().unwrap(); + + create_session( + &alice_machine, + bob_machine.user_id(), + bob_machine.device_id(), + device_key_id, + one_time_key, + ) + .await; let session = alice_machine .store @@ -2055,6 +2071,80 @@ pub(crate) mod tests { assert!(!session.lock().await.is_empty()) } + #[async_test] + async fn getting_most_recent_session() { + let (alice_machine, bob_machine, mut one_time_keys) = get_machine_pair().await; + let (device_key_id, one_time_key) = one_time_keys.pop_first().unwrap(); + + let device = alice_machine + .get_device(bob_machine.user_id(), bob_machine.device_id(), None) + .await + .unwrap() + .unwrap(); + + assert!(device.get_most_recent_session().await.unwrap().is_none()); + + create_session( + &alice_machine, + bob_machine.user_id(), + bob_machine.device_id(), + device_key_id, + one_time_key.to_owned(), + ) + .await; + + for _ in 0..10 { + let (device_key_id, one_time_key) = one_time_keys.pop_first().unwrap(); + + create_session( + &alice_machine, + bob_machine.user_id(), + bob_machine.device_id(), + device_key_id, + one_time_key.to_owned(), + ) + .await; + } + + // Since the sessions are created quickly in succession and our timestamps have + // a resolution in seconds, it's very likely that we're going to end up + // with the same timestamps, so we manually masage them to be 10s apart. + let session_id = { + let sessions = alice_machine + .store + .get_sessions(&bob_machine.identity_keys().curve25519.to_base64()) + .await + .unwrap() + .unwrap(); + + let mut use_time = SystemTime::now(); + let mut sessions = sessions.lock().await; + + let mut session_id = None; + + // Iterate through the sessions skipping the first and last element so we know + // that the correct session isn't the first nor the last one. + let (_, sessions_slice) = sessions.as_mut_slice().split_last_mut().unwrap(); + + for session in sessions_slice.iter_mut().skip(1) { + session.creation_time = SecondsSinceUnixEpoch::from_system_time(use_time).unwrap(); + use_time += Duration::from_secs(10); + + session_id = Some(session.session_id().to_owned()); + } + + session_id.unwrap() + }; + + let newest_session = device.get_most_recent_session().await.unwrap().unwrap(); + + assert_eq!( + newest_session.session_id(), + session_id, + "The session we found is the one that was most recently created" + ); + } + #[async_test] async fn test_olm_encryption() { let (alice, bob) = get_machine_pair_with_session().await;