diff --git a/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs b/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs index 755426efe..2e672a704 100644 --- a/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs +++ b/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs @@ -488,7 +488,34 @@ impl InboundGroupSession { /// Check if the [`InboundGroupSession`] is better than the given other /// [`InboundGroupSession`] + #[deprecated( + note = "Sessions cannot be compared on a linear scale. Consider calling `compare_ratchet`, as well as comparing the `sender_data`." + )] pub async fn compare(&self, other: &InboundGroupSession) -> SessionOrdering { + match self.compare_ratchet(other).await { + SessionOrdering::Equal => { + match self.sender_data.compare_trust_level(&other.sender_data) { + Ordering::Less => SessionOrdering::Worse, + Ordering::Equal => SessionOrdering::Equal, + Ordering::Greater => SessionOrdering::Better, + } + } + result => result, + } + } + + /// Check if the [`InboundGroupSession`]'s ratchet index is better than that + /// of the given other [`InboundGroupSession`]. + /// + /// If the two sessions are not connected (i.e., they are from different + /// senders, or if advancing the ratchets to the same index does not + /// give the same ratchet value), returns [`SessionOrdering::Unconnected`]. + /// + /// Otherwise, returns [`SessionOrdering::Equal`], + /// [`SessionOrdering::Better`], or [`SessionOrdering::Worse`] respectively + /// depending on whether this session's first known index is equal to, + /// lower than, or higher than, that of `other`. + pub async fn compare_ratchet(&self, other: &InboundGroupSession) -> SessionOrdering { // If this is the same object the ordering is the same, we can't compare because // we would deadlock while trying to acquire the same lock twice. if Arc::ptr_eq(&self.inner, &other.inner) { @@ -501,17 +528,7 @@ impl InboundGroupSession { SessionOrdering::Unconnected } else { let mut other_inner = other.inner.lock().await; - - match self.inner.lock().await.compare(&mut other_inner) { - SessionOrdering::Equal => { - match self.sender_data.compare_trust_level(&other.sender_data) { - Ordering::Less => SessionOrdering::Worse, - Ordering::Equal => SessionOrdering::Equal, - Ordering::Greater => SessionOrdering::Better, - } - } - result => result, - } + self.inner.lock().await.compare(&mut other_inner) } } @@ -1057,6 +1074,7 @@ mod tests { } #[async_test] + #[allow(deprecated)] async fn test_session_comparison() { let alice = Account::with_device_id(alice_id(), alice_device_id()); let room_id = room_id!("!test:localhost"); @@ -1067,18 +1085,24 @@ mod tests { let mut copy = InboundGroupSession::from_pickle(inbound.pickle().await).unwrap(); assert_eq!(inbound.compare(&worse).await, SessionOrdering::Better); + assert_eq!(inbound.compare_ratchet(&worse).await, SessionOrdering::Better); assert_eq!(worse.compare(&inbound).await, SessionOrdering::Worse); + assert_eq!(worse.compare_ratchet(&inbound).await, SessionOrdering::Worse); assert_eq!(inbound.compare(&inbound).await, SessionOrdering::Equal); + assert_eq!(inbound.compare_ratchet(&inbound).await, SessionOrdering::Equal); assert_eq!(inbound.compare(©).await, SessionOrdering::Equal); + assert_eq!(inbound.compare_ratchet(©).await, SessionOrdering::Equal); copy.creator_info.curve25519_key = Curve25519PublicKey::from_base64("XbmrPa1kMwmdtNYng1B2gsfoo8UtF+NklzsTZiaVKyY") .unwrap(); assert_eq!(inbound.compare(©).await, SessionOrdering::Unconnected); + assert_eq!(inbound.compare_ratchet(©).await, SessionOrdering::Unconnected); } #[async_test] + #[allow(deprecated)] async fn test_session_comparison_sender_data() { let alice = Account::with_device_id(alice_id(), alice_device_id()); let room_id = room_id!("!test:localhost");