From d41af396cc01fcc7a50f2c50a542ea71d6077b0a Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 27 Jun 2024 06:18:52 -0400 Subject: [PATCH] Embed device keys in Olm-encrypted messages (#3517) This patch implements MSC4147[1]. Signed-off-by: Hubert Chathi [1]: https://github.com/matrix-org/matrix-spec-proposals/pull/4147 --- bindings/matrix-sdk-crypto-ffi/src/lib.rs | 42 +++++- crates/matrix-sdk-crypto/Cargo.toml | 1 + crates/matrix-sdk-crypto/src/error.rs | 13 ++ .../src/identities/device.rs | 3 +- crates/matrix-sdk-crypto/src/machine.rs | 7 + crates/matrix-sdk-crypto/src/olm/account.rs | 54 ++++--- crates/matrix-sdk-crypto/src/olm/mod.rs | 6 +- crates/matrix-sdk-crypto/src/olm/session.rs | 132 +++++++++++++----- .../src/session_manager/sessions.rs | 13 +- .../src/store/integration_tests.rs | 11 ++ .../src/store/memorystore.rs | 10 ++ crates/matrix-sdk-crypto/src/store/traits.rs | 10 ++ .../src/crypto_store/mod.rs | 100 +++++++++++-- crates/matrix-sdk-sqlite/src/crypto_store.rs | 75 ++++++++-- 14 files changed, 393 insertions(+), 84 deletions(-) diff --git a/bindings/matrix-sdk-crypto-ffi/src/lib.rs b/bindings/matrix-sdk-crypto-ffi/src/lib.rs index 0ed02b8a5..63c5936a1 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/lib.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/lib.rs @@ -16,7 +16,11 @@ mod responses; mod users; mod verification; -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, + time::Duration, +}; use anyhow::Context as _; pub use backup_recovery_key::{ @@ -33,7 +37,9 @@ use matrix_sdk_common::deserialized_responses::ShieldState as RustShieldState; use matrix_sdk_crypto::{ olm::{IdentityKeys, InboundGroupSession, Session}, store::{Changes, CryptoStore, PendingChanges, RoomSettings as RustRoomSettings}, - types::{EventEncryptionAlgorithm as RustEventEncryptionAlgorithm, SigningKey}, + types::{ + DeviceKey, DeviceKeys, EventEncryptionAlgorithm as RustEventEncryptionAlgorithm, SigningKey, + }, CollectStrategy, EncryptionSettings as RustEncryptionSettings, }; use matrix_sdk_sqlite::SqliteCryptoStore; @@ -43,8 +49,8 @@ pub use responses::{ }; use ruma::{ events::room::history_visibility::HistoryVisibility as RustHistoryVisibility, - DeviceKeyAlgorithm, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, RoomId, - SecondsSinceUnixEpoch, UserId, + DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, + RoomId, SecondsSinceUnixEpoch, UserId, }; use serde::{Deserialize, Serialize}; use tokio::runtime::Runtime; @@ -332,6 +338,10 @@ async fn save_changes( processed_steps += 1; listener(processed_steps, total_steps); + // The Sessions were created with incorrect device keys, so clear the cache + // so that they'll get recreated with correct ones. + store.clear_caches().await; + Ok(()) } @@ -419,6 +429,27 @@ fn collect_sessions( ) -> anyhow::Result<(Vec, Vec)> { let mut sessions = Vec::new(); + // Create a DeviceKeys struct with enough information to get a working + // Session, but we will won't actually use the Sessions (and we'll clear + // the session cache after migration) so we don't need to worry about + // signatures. + let device_keys = DeviceKeys::new( + user_id.clone(), + device_id.clone(), + Default::default(), + BTreeMap::from([ + ( + DeviceKeyId::from_parts(DeviceKeyAlgorithm::Ed25519, &device_id), + DeviceKey::Ed25519(identity_keys.ed25519), + ), + ( + DeviceKeyId::from_parts(DeviceKeyAlgorithm::Curve25519, &device_id), + DeviceKey::Curve25519(identity_keys.curve25519), + ), + ]), + Default::default(), + ); + for session_pickle in session_pickles { let pickle = vodozemac::olm::Session::from_libolm_pickle(&session_pickle.pickle, pickle_key)? @@ -439,8 +470,7 @@ fn collect_sessions( last_use_time, }; - let session = - Session::from_pickle(user_id.clone(), device_id.clone(), identity_keys.clone(), pickle); + let session = Session::from_pickle(device_keys.clone(), pickle)?; sessions.push(session); processed_steps += 1; diff --git a/crates/matrix-sdk-crypto/Cargo.toml b/crates/matrix-sdk-crypto/Cargo.toml index e67f338cc..104930172 100644 --- a/crates/matrix-sdk-crypto/Cargo.toml +++ b/crates/matrix-sdk-crypto/Cargo.toml @@ -30,6 +30,7 @@ testing = ["dep:http"] [dependencies] aes = "0.8.1" as_variant = { workspace = true } +assert_matches2 = { workspace = true } async-trait = { workspace = true } bs58 = { version = "0.5.0" } byteorder = { workspace = true } diff --git a/crates/matrix-sdk-crypto/src/error.rs b/crates/matrix-sdk-crypto/src/error.rs index 3e4e0bcf8..cedb9712d 100644 --- a/crates/matrix-sdk-crypto/src/error.rs +++ b/crates/matrix-sdk-crypto/src/error.rs @@ -165,6 +165,19 @@ pub enum EventError { MismatchedRoom(OwnedRoomId, Option), } +/// Error type describing different errors that can happen when we create an +/// Olm session from a pickle. +#[derive(Error, Debug)] +pub enum SessionUnpickleError { + /// The device keys are missing the signing key + #[error("the device keys are missing the signing key")] + MissingSigningKey, + + /// The device keys are missing the identity key + #[error("the device keys are missing the identity key")] + MissingIdentityKey, +} + /// Error type describing different errors that happen when we check or create /// signatures for a Matrix JSON object. #[derive(Error, Debug)] diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index 79cd50eea..225c8c8bf 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -871,7 +871,8 @@ impl ReadOnlyDevice { } } - pub(crate) fn as_device_keys(&self) -> &DeviceKeys { + /// Return the device keys + pub fn as_device_keys(&self) -> &DeviceKeys { &self.inner } diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index 80795982c..dfa1ea3f1 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -173,7 +173,14 @@ impl OlmMachine { let static_account = account.static_data().clone(); let store = Arc::new(CryptoStoreWrapper::new(self.user_id(), MemoryStore::new())); + let device = ReadOnlyDevice::from_account(&account); store.save_pending_changes(PendingChanges { account: Some(account) }).await?; + store + .save_changes(Changes { + devices: DeviceChanges { new: vec![device], ..Default::default() }, + ..Default::default() + }) + .await?; Ok(Self::new_helper( device_id, diff --git a/crates/matrix-sdk-crypto/src/olm/account.rs b/crates/matrix-sdk-crypto/src/olm/account.rs index edeb52bc6..611253b4e 100644 --- a/crates/matrix-sdk-crypto/src/olm/account.rs +++ b/crates/matrix-sdk-crypto/src/olm/account.rs @@ -905,12 +905,16 @@ impl Account { /// and shared with us. /// /// * `fallback_used` - Was the one-time key a fallback key. + /// + /// * `our_device_keys` - Our own `DeviceKeys`, including cross-signing + /// signatures if applicable, for embedding in encrypted messages. pub fn create_outbound_session_helper( &self, config: SessionConfig, identity_key: Curve25519PublicKey, one_time_key: Curve25519PublicKey, fallback_used: bool, + our_device_keys: DeviceKeys, ) -> Session { let session = self.inner.create_outbound_session(config, identity_key, one_time_key); @@ -918,12 +922,10 @@ impl Account { let session_id = session.session_id(); Session { - user_id: self.static_data.user_id.clone(), - device_id: self.static_data.device_id.clone(), - our_identity_keys: self.static_data.identity_keys.clone(), inner: Arc::new(Mutex::new(session)), session_id: session_id.into(), sender_key: identity_key, + our_device_keys, created_using_fallback_key: fallback_used, creation_time: now, last_use_time: now, @@ -978,11 +980,15 @@ impl Account { /// /// * `key_map` - A map from the algorithm and device ID to the one-time key /// that the other account created and shared with us. + /// + /// * `our_device_keys` - Our own `DeviceKeys`, including cross-signing + /// signatures if applicable, for embedding in encrypted messages. #[allow(clippy::result_large_err)] pub fn create_outbound_session( &self, device: &ReadOnlyDevice, key_map: &BTreeMap>, + our_device_keys: DeviceKeys, ) -> Result { let pre_key_bundle = Self::find_pre_key_bundle(device, key_map)?; @@ -1012,6 +1018,7 @@ impl Account { identity_key, one_time_key, is_fallback, + our_device_keys, )) } } @@ -1026,11 +1033,15 @@ impl Account { /// /// * `their_identity_key` - The other account's identity/curve25519 key. /// + /// * `our_device_keys` - Our own `DeviceKeys`, including cross-signing + /// signatures if applicable, for embedding in encrypted messages. + /// /// * `message` - A pre-key Olm message that was sent to us by the other /// account. pub fn create_inbound_session( &mut self, their_identity_key: Curve25519PublicKey, + our_device_keys: DeviceKeys, message: &PreKeyMessage, ) -> Result { Span::current().record("session_id", debug(message.session_id())); @@ -1043,12 +1054,10 @@ impl Account { debug!(session=?result.session, "Decrypted an Olm message from a new Olm session"); let session = Session { - user_id: self.static_data.user_id.clone(), - device_id: self.static_data.device_id.clone(), - our_identity_keys: self.static_data.identity_keys.clone(), inner: Arc::new(Mutex::new(result.session)), session_id: session_id.into(), sender_key: their_identity_key, + our_device_keys, created_using_fallback_key: false, creation_time: now, last_use_time: now, @@ -1072,7 +1081,8 @@ impl Account { let one_time_map = other.signed_one_time_keys(); let device = ReadOnlyDevice::from_account(other); - let mut our_session = self.create_outbound_session(&device, &one_time_map).unwrap(); + let mut our_session = + self.create_outbound_session(&device, &one_time_map, self.device_keys()).unwrap(); other.mark_keys_as_published(); @@ -1104,8 +1114,13 @@ impl Account { }; let our_device = ReadOnlyDevice::from_account(self); - let other_session = - other.create_inbound_session(our_device.curve25519_key().unwrap(), &prekey).unwrap(); + let other_session = other + .create_inbound_session( + our_device.curve25519_key().unwrap(), + other.device_keys(), + &prekey, + ) + .unwrap(); (our_session, other_session.session) } @@ -1290,20 +1305,23 @@ impl Account { ); return Err(OlmError::SessionWedged( - session.user_id.to_owned(), + session.our_device_keys.user_id.to_owned(), session.sender_key(), )); } } - // We didn't find a matching session; try to create a new session. - let result = match self.create_inbound_session(sender_key, prekey_message) { - Ok(r) => r, - Err(e) => { - warn!("Failed to create a new Olm session from a pre-key message: {e:?}"); - return Err(OlmError::SessionWedged(sender.to_owned(), sender_key)); - } - }; + let device_keys = store.get_own_device().await?.as_device_keys().clone(); + let result = + match self.create_inbound_session(sender_key, device_keys, prekey_message) { + Ok(r) => r, + Err(e) => { + warn!( + "Failed to create a new Olm session from a pre-key message: {e:?}" + ); + return Err(OlmError::SessionWedged(sender.to_owned(), sender_key)); + } + }; // We need to add the new session to the session cache, otherwise // we might try to create the same session again. diff --git a/crates/matrix-sdk-crypto/src/olm/mod.rs b/crates/matrix-sdk-crypto/src/olm/mod.rs index 747f14bc0..8bf6d6e75 100644 --- a/crates/matrix-sdk-crypto/src/olm/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/mod.rs @@ -93,6 +93,7 @@ pub(crate) mod tests { sender_key, one_time_key, false, + alice.device_keys(), ); (alice, session) @@ -144,6 +145,7 @@ pub(crate) mod tests { alice_keys.curve25519, one_time_key, false, + bob.device_keys(), ); let plaintext = "Hello world"; @@ -156,7 +158,9 @@ pub(crate) mod tests { }; let bob_keys = bob.identity_keys(); - let result = alice.create_inbound_session(bob_keys.curve25519, &prekey_message).unwrap(); + let result = alice + .create_inbound_session(bob_keys.curve25519, alice.device_keys(), &prekey_message) + .unwrap(); assert_eq!(bob_session.session_id(), result.session.session_id()); diff --git a/crates/matrix-sdk-crypto/src/olm/session.rs b/crates/matrix-sdk-crypto/src/olm/session.rs index 2836ed57d..7fe1d93ff 100644 --- a/crates/matrix-sdk-crypto/src/olm/session.rs +++ b/crates/matrix-sdk-crypto/src/olm/session.rs @@ -14,7 +14,7 @@ use std::{fmt, sync::Arc}; -use ruma::{serde::Raw, OwnedDeviceId, OwnedUserId, SecondsSinceUnixEpoch}; +use ruma::{serde::Raw, SecondsSinceUnixEpoch}; use serde::{Deserialize, Serialize}; use serde_json::json; use tokio::sync::Mutex; @@ -24,14 +24,13 @@ use vodozemac::{ Curve25519PublicKey, }; -use super::IdentityKeys; #[cfg(feature = "experimental-algorithms")] use crate::types::events::room::encrypted::OlmV2Curve25519AesSha2Content; use crate::{ - error::{EventError, OlmResult}, + error::{EventError, OlmResult, SessionUnpickleError}, types::{ events::room::encrypted::{OlmV1Curve25519AesSha2Content, ToDeviceEncryptedEventContent}, - EventEncryptionAlgorithm, + DeviceKeys, EventEncryptionAlgorithm, }, ReadOnlyDevice, }; @@ -40,18 +39,14 @@ use crate::{ /// `Account`s #[derive(Clone)] pub struct Session { - /// The `UserId` associated with this session - pub user_id: OwnedUserId, - /// The specific `DeviceId` associated with this session - pub device_id: OwnedDeviceId, - /// The `IdentityKeys` associated with this session - pub our_identity_keys: Arc, /// The OlmSession pub inner: Arc>, /// Our sessionId pub session_id: Arc, /// The Key of the sender pub sender_key: Curve25519PublicKey, + /// Our own signed device keys + pub our_device_keys: DeviceKeys, /// Has this been created using the fallback key pub created_using_fallback_key: bool, /// When the session was created @@ -156,11 +151,12 @@ impl Session { recipient_device.ed25519_key().ok_or(EventError::MissingSigningKey)?; let payload = json!({ - "sender": &self.user_id, - "sender_device": &self.device_id, + "sender": &self.our_device_keys.user_id, + "sender_device": &self.our_device_keys.device_id, "keys": { - "ed25519": self.our_identity_keys.ed25519.to_base64(), + "ed25519": self.our_device_keys.ed25519_key().expect("Device doesn't have ed25519 key").to_base64(), }, + "device_keys": self.our_device_keys, "recipient": recipient_device.user_id(), "recipient_keys": { "ed25519": recipient_signing_key.to_base64(), @@ -178,14 +174,20 @@ impl Session { EventEncryptionAlgorithm::OlmV1Curve25519AesSha2 => OlmV1Curve25519AesSha2Content { ciphertext, recipient_key: self.sender_key, - sender_key: self.our_identity_keys.curve25519, + sender_key: self + .our_device_keys + .curve25519_key() + .expect("Device doesn't have curve25519 key"), message_id, } .into(), #[cfg(feature = "experimental-algorithms")] EventEncryptionAlgorithm::OlmV2Curve25519AesSha2 => OlmV2Curve25519AesSha2Content { ciphertext, - sender_key: self.our_identity_keys.curve25519, + sender_key: self + .our_device_keys + .curve25519_key() + .expect("Device doesn't have curve25519 key"), message_id, } .into(), @@ -227,36 +229,32 @@ impl Session { /// /// # Arguments /// - /// * `user_id` - Our own user id that the session belongs to. - /// - /// * `device_id` - Our own device ID that the session belongs to. - /// - /// * `our_identity_keys` - An clone of the Arc to our own identity keys. + /// * `our_device_keys` - Our own signed device keys. /// /// * `pickle` - The pickled version of the `Session`. - /// - /// * `pickle_mode` - The mode that was used to pickle the session, either - /// an unencrypted mode or an encrypted using passphrase. pub fn from_pickle( - user_id: OwnedUserId, - device_id: OwnedDeviceId, - our_identity_keys: Arc, + our_device_keys: DeviceKeys, pickle: PickledSession, - ) -> Self { + ) -> Result { + if our_device_keys.curve25519_key().is_none() { + return Err(SessionUnpickleError::MissingIdentityKey); + } + if our_device_keys.ed25519_key().is_none() { + return Err(SessionUnpickleError::MissingSigningKey); + } + let session: vodozemac::olm::Session = pickle.pickle.into(); let session_id = session.session_id(); - Session { - user_id, - device_id, - our_identity_keys, + Ok(Session { inner: Arc::new(Mutex::new(session)), session_id: session_id.into(), created_using_fallback_key: pickle.created_using_fallback_key, sender_key: pickle.sender_key, + our_device_keys, creation_time: pickle.creation_time, last_use_time: pickle.last_use_time, - } + }) } } @@ -285,3 +283,73 @@ pub struct PickledSession { /// The Unix timestamp when the session was last used. pub last_use_time: SecondsSinceUnixEpoch, } + +#[cfg(test)] +mod tests { + use assert_matches2::assert_let; + use matrix_sdk_test::async_test; + use ruma::{device_id, user_id}; + use serde_json::{self, Value}; + use vodozemac::olm::{OlmMessage, SessionConfig}; + + use crate::{ + identities::ReadOnlyDevice, olm::Account, + types::events::room::encrypted::ToDeviceEncryptedEventContent, + }; + + #[async_test] + async fn test_encryption_and_decryption() { + use ruma::events::dummy::ToDeviceDummyEventContent; + + // Given users Alice and Bob + let alice = + Account::with_device_id(user_id!("@alice:localhost"), device_id!("ALICEDEVICE")); + let mut bob = Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE")); + + // When Alice creates an Olm session with Bob + bob.generate_one_time_keys(1); + let one_time_key = *bob.one_time_keys().values().next().unwrap(); + let sender_key = bob.identity_keys().curve25519; + let mut alice_session = alice.create_outbound_session_helper( + SessionConfig::default(), + sender_key, + one_time_key, + false, + alice.device_keys(), + ); + + let alice_device = ReadOnlyDevice::from_account(&alice); + + // and encrypts a message + let message = alice_session + .encrypt(&alice_device, "m.dummy", ToDeviceDummyEventContent::new(), None) + .await + .unwrap() + .deserialize() + .unwrap(); + + #[cfg(feature = "experimental-algorithms")] + assert_let!(ToDeviceEncryptedEventContent::OlmV2Curve25519AesSha2(content) = message); + #[cfg(not(feature = "experimental-algorithms"))] + assert_let!(ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(content) = message); + + let prekey = if let OlmMessage::PreKey(m) = content.ciphertext { + m + } else { + panic!("Wrong Olm message type"); + }; + + // Then Bob should be able to create a session from the message and decrypt it. + let bob_session_result = bob + .create_inbound_session( + alice_device.curve25519_key().unwrap(), + bob.device_keys(), + &prekey, + ) + .unwrap(); + + // Also ensure that the encrypted payload has the device keys. + let plaintext: Value = serde_json::from_str(&bob_session_result.plaintext).unwrap(); + assert_eq!(plaintext["device_keys"]["user_id"].as_str(), Some("@alice:localhost")); + } +} diff --git a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs index e81f32ded..c367954d3 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs @@ -537,7 +537,8 @@ impl SessionManager { }; let account = store_transaction.account().await?; - let session = match account.create_outbound_session(&device, key_map) { + let device_keys = self.store.get_own_device().await?.as_device_keys().clone(); + let session = match account.create_outbound_session(&device, key_map, device_keys) { Ok(s) => s, Err(e) => { warn!( @@ -631,7 +632,7 @@ mod tests { identities::{IdentityManager, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity}, session_manager::GroupSessionCache, - store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store}, + store::{Changes, CryptoStoreWrapper, DeviceChanges, MemoryStore, PendingChanges, Store}, verification::VerificationMachine, }; @@ -688,7 +689,15 @@ mod tests { ); let store = Store::new(account.static_data().clone(), identity, store, verification); + let device = ReadOnlyDevice::from_account(&account); store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); + store + .save_changes(Changes { + devices: DeviceChanges { new: vec![device], ..Default::default() }, + ..Default::default() + }) + .await + .unwrap(); let session_cache = GroupSessionCache::new(store.clone()); let identity_manager = IdentityManager::new(store.clone()); diff --git a/crates/matrix-sdk-crypto/src/store/integration_tests.rs b/crates/matrix-sdk-crypto/src/store/integration_tests.rs index 0f457b480..d058e8f3c 100644 --- a/crates/matrix-sdk-crypto/src/store/integration_tests.rs +++ b/crates/matrix-sdk-crypto/src/store/integration_tests.rs @@ -85,6 +85,7 @@ macro_rules! cryptostore_integration_tests { sender_key, one_time_key, false, + alice.device_keys(), ); (alice, session) @@ -216,6 +217,16 @@ macro_rules! cryptostore_integration_tests { }) .await .expect("Can't save account"); + store + .save_changes(Changes { + devices: DeviceChanges { + new: vec![ReadOnlyDevice::from_account(&account)], + ..Default::default() + }, + ..Default::default() + }) + .await + .unwrap(); let changes = Changes { sessions: vec![session.clone()], ..Default::default() }; store.save_changes(changes).await.unwrap(); diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index cace29e12..0b465dcba 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -474,6 +474,12 @@ impl CryptoStore for MemoryStore { Ok(self.devices.user_devices(user_id)) } + async fn get_own_device(&self) -> Result { + let account = self.load_account().await?.unwrap(); + + Ok(self.devices.get(&account.user_id, &account.device_id).unwrap()) + } + async fn get_user_identity(&self, user_id: &UserId) -> Result> { Ok(self.identities.read().unwrap().get(user_id).cloned()) } @@ -1265,6 +1271,10 @@ mod integration_tests { self.0.get_user_devices(user_id).await } + async fn get_own_device(&self) -> Result { + self.0.get_own_device().await + } + async fn get_user_identity( &self, user_id: &UserId, diff --git a/crates/matrix-sdk-crypto/src/store/traits.rs b/crates/matrix-sdk-crypto/src/store/traits.rs index 442b30c25..2ce396c4b 100644 --- a/crates/matrix-sdk-crypto/src/store/traits.rs +++ b/crates/matrix-sdk-crypto/src/store/traits.rs @@ -203,6 +203,12 @@ pub trait CryptoStore: AsyncTraitDeps { user_id: &UserId, ) -> Result, Self::Error>; + /// Get the device for the current client. + /// + /// Since our own device is set when the store is created, this will always + /// return a device (unless there is an error). + async fn get_own_device(&self) -> Result; + /// Get the user identity that is attached to the given user id. /// /// # Arguments @@ -450,6 +456,10 @@ impl CryptoStore for EraseCryptoStoreError { self.0.get_user_devices(user_id).await.map_err(Into::into) } + async fn get_own_device(&self) -> Result { + self.0.get_own_device().await.map_err(Into::into) + } + async fn get_user_identity(&self, user_id: &UserId) -> Result> { self.0.get_user_identity(user_id).await.map_err(Into::into) } diff --git a/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs b/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs index 8c2f37195..f14c01930 100644 --- a/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs +++ b/crates/matrix-sdk-indexeddb/src/crypto_store/mod.rs @@ -447,7 +447,14 @@ impl IndexeddbCryptoStore { /// Process all the changes and do all encryption/serialization before the /// actual transaction. - async fn prepare_for_transaction(&self, changes: &Changes) -> Result { + /// + /// Returns a tuple where the first item is a `PendingIndexeddbChanges` + /// struct, and the second item is a boolean indicating whether the session + /// cache should be cleared. + async fn prepare_for_transaction( + &self, + changes: &Changes, + ) -> Result<(PendingIndexeddbChanges, bool)> { let mut indexeddb_changes = PendingIndexeddbChanges::new(); let private_identity_pickle = @@ -534,7 +541,17 @@ impl IndexeddbCryptoStore { let mut device_store = indexeddb_changes.get(keys::DEVICES); + let account_info = self.get_static_account(); + let mut clear_caches = false; for device in device_changes.new.iter().chain(&device_changes.changed) { + // If our own device key changes, we need to clear the session + // cache because the sessions contain a copy of our device key, and + // we want the sessions to use the new version. + if account_info.as_ref().is_some_and(|info| { + info.user_id == device.user_id() && info.device_id == device.device_id() + }) { + clear_caches = true; + } let key = self.serializer.encode_key(keys::DEVICES, (device.user_id(), device.device_id())); let device = self.serializer.serialize_value(&device)?; @@ -617,7 +634,7 @@ impl IndexeddbCryptoStore { } } - Ok(indexeddb_changes) + Ok((indexeddb_changes, clear_caches)) } } @@ -697,7 +714,7 @@ impl_crypto_store! { // TODO: #2000 should make this lock go away, or change its shape. let _guard = self.save_changes_lock.lock().await; - let indexeddb_changes = self.prepare_for_transaction(&changes).await?; + let (indexeddb_changes, clear_caches) = self.prepare_for_transaction(&changes).await?; let stores = indexeddb_changes.touched_stores(); @@ -713,9 +730,15 @@ impl_crypto_store! { tx.await.into_result()?; - // all good, let's update our caches:indexeddb - for session in changes.sessions { - self.session_cache.add(session).await; + if clear_caches { + self.clear_caches().await; + } else { + // All good, let's update our caches:indexeddb. + // We only do this if clear_caches is false, because the sessions may + // have been created using old device_keys. + for session in changes.sessions { + self.session_cache.add(session).await; + } } Ok(()) @@ -860,9 +883,12 @@ impl_crypto_store! { } async fn get_sessions(&self, sender_key: &str) -> Result>>>> { - let account_info = self.get_static_account().ok_or(CryptoStoreError::AccountUnset)?; - if self.session_cache.get(sender_key).is_none() { + let device_keys = self.get_own_device() + .await? + .as_device_keys() + .clone(); + let range = self.serializer.encode_to_range(keys::SESSION, sender_key)?; let sessions: Vec = self .inner @@ -873,13 +899,12 @@ impl_crypto_store! { .iter() .filter_map(|f| self.serializer.deserialize_value(f).ok().map(|p| { Session::from_pickle( - account_info.user_id.clone(), - account_info.device_id.clone(), - account_info.identity_keys.clone(), + device_keys.clone(), p, ) + .map_err(|_| IndexeddbCryptoStoreError::CryptoStoreError(CryptoStoreError::AccountUnset)) })) - .collect::>(); + .collect::>>()?; self.session_cache.set_for_sender(sender_key, sessions); } @@ -1094,6 +1119,13 @@ impl_crypto_store! { .collect::>()) } + async fn get_own_device(&self) -> Result { + let account_info = self.get_static_account().ok_or(CryptoStoreError::AccountUnset)?; + Ok(self.get_device(&account_info.user_id, &account_info.device_id) + .await? + .unwrap()) + } + async fn get_user_identity(&self, user_id: &UserId) -> Result> { Ok(self .inner @@ -1692,7 +1724,12 @@ mod wasm_unit_tests { #[cfg(all(test, target_arch = "wasm32"))] mod tests { - use matrix_sdk_crypto::cryptostore_integration_tests; + use matrix_sdk_crypto::{ + cryptostore_integration_tests, + store::{Changes, CryptoStore as _, DeviceChanges, PendingChanges}, + ReadOnlyDevice, + }; + use matrix_sdk_test::async_test; use super::IndexeddbCryptoStore; @@ -1708,6 +1745,43 @@ mod tests { .expect("Can't create store without passphrase"), } } + + #[async_test] + async fn cache_cleared_after_device_update() { + let store = get_store("cache_cleared_after_device_update", None).await; + // Given we created a session and saved it in the store + let (account, session) = cryptostore_integration_tests::get_account_and_session().await; + let sender_key = session.sender_key.to_base64(); + + store + .save_pending_changes(PendingChanges { account: Some(account.deep_clone()) }) + .await + .expect("Can't save account"); + + let changes = Changes { sessions: vec![session.clone()], ..Default::default() }; + store.save_changes(changes).await.unwrap(); + + store.session_cache.get(&sender_key).expect("We should have a session"); + + // When we save a new version of our device keys + store + .save_changes(Changes { + devices: DeviceChanges { + new: vec![ReadOnlyDevice::from_account(&account)], + ..Default::default() + }, + ..Default::default() + }) + .await + .unwrap(); + + // Then the session is no longer in the cache + assert!( + store.session_cache.get(&sender_key).is_none(), + "Session should not be in the cache!" + ); + } + cryptostore_integration_tests!(); } diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index 29842dbc5..426446d5c 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -780,9 +780,11 @@ impl CryptoStore for SqliteCryptoStore { } let this = self.clone(); - self.acquire() + let clear_caches = self + .acquire() .await? .with_transaction(move |txn| { + let mut clear_caches = false; if let Some(pickled_private_identity) = &pickled_private_identity { let serialized_private_identity = this.serialize_value(pickled_private_identity)?; @@ -804,7 +806,16 @@ impl CryptoStore for SqliteCryptoStore { txn.set_kv("backup_version_v1", &serialized_backup_version)?; } + let account_info = this.get_static_account(); for device in changes.devices.new.iter().chain(&changes.devices.changed) { + // If our own device key changes, we need to clear the + // session cache because the sessions contain a copy of our + // device key. + if account_info.clone().is_some_and(|info| { + info.user_id == device.user_id() && info.device_id == device.device_id() + }) { + clear_caches = true; + } let user_id = this.encode_key("device", device.user_id().as_bytes()); let device_id = this.encode_key("device", device.device_id().as_bytes()); let data = this.serialize_value(&device)?; @@ -875,10 +886,14 @@ impl CryptoStore for SqliteCryptoStore { txn.set_secret(&secret_name, &value)?; } - Ok::<_, Error>(()) + Ok::<_, Error>(clear_caches) }) .await?; + if clear_caches { + self.clear_caches().await; + } + Ok(()) } @@ -905,9 +920,9 @@ impl CryptoStore for SqliteCryptoStore { } async fn get_sessions(&self, sender_key: &str) -> Result>>>> { - let account_info = self.get_static_account().ok_or(Error::AccountUnset)?; - if self.session_cache.get(sender_key).is_none() { + let device_keys = self.get_own_device().await?.as_device_keys().clone(); + let sessions = self .acquire() .await? @@ -916,12 +931,8 @@ impl CryptoStore for SqliteCryptoStore { .into_iter() .map(|bytes| { let pickle = self.deserialize_value(&bytes)?; - Ok(Session::from_pickle( - account_info.user_id.clone(), - account_info.device_id.clone(), - account_info.identity_keys.clone(), - pickle, - )) + Session::from_pickle(device_keys.clone(), pickle) + .map_err(|_| Error::AccountUnset) }) .collect::>()?; @@ -1110,6 +1121,11 @@ impl CryptoStore for SqliteCryptoStore { .collect() } + async fn get_own_device(&self) -> Result { + let account_info = self.get_static_account().ok_or(Error::AccountUnset)?; + Ok(self.get_device(&account_info.user_id, &account_info.device_id).await?.unwrap()) + } + async fn get_user_identity(&self, user_id: &UserId) -> Result> { let user_id = self.encode_key("identity", user_id.as_bytes()); Ok(self @@ -1331,7 +1347,8 @@ mod tests { mod encrypted_tests { use matrix_sdk_crypto::{ cryptostore_integration_tests, cryptostore_integration_tests_time, - store::{Changes, CryptoStore as _, PendingChanges}, + store::{Changes, CryptoStore as _, DeviceChanges, PendingChanges}, + ReadOnlyDevice, }; use matrix_sdk_test::async_test; use once_cell::sync::Lazy; @@ -1377,6 +1394,42 @@ mod encrypted_tests { ); } + #[async_test] + async fn cache_cleared_after_device_update() { + let store = get_store("cache_cleared_after_device_update", None).await; + // Given we created a session and saved it in the store + let (account, session) = cryptostore_integration_tests::get_account_and_session().await; + let sender_key = session.sender_key.to_base64(); + + store + .save_pending_changes(PendingChanges { account: Some(account.deep_clone()) }) + .await + .expect("Can't save account"); + + let changes = Changes { sessions: vec![session.clone()], ..Default::default() }; + store.save_changes(changes).await.unwrap(); + + store.session_cache.get(&sender_key).expect("We should have a session"); + + // When we save a new version of our device keys + store + .save_changes(Changes { + devices: DeviceChanges { + new: vec![ReadOnlyDevice::from_account(&account)], + ..Default::default() + }, + ..Default::default() + }) + .await + .unwrap(); + + // Then the session is no longer in the cache + assert!( + store.session_cache.get(&sender_key).is_none(), + "Session should not be in the cache!" + ); + } + cryptostore_integration_tests!(); cryptostore_integration_tests_time!(); }