From 8bfcd208019a39e6522ec7ef9fca8bba0d178609 Mon Sep 17 00:00:00 2001 From: Benjamin Bouvier Date: Tue, 10 Oct 2023 14:02:33 +0200 Subject: [PATCH] crypto: make `Account` not clonable and the `Account` in the cache optional --- .../src/dehydrated_devices.rs | 11 ++-- .../src/dehydrated_devices.rs | 14 +++-- .../src/gossiping/machine.rs | 33 +++++----- .../src/identities/device.rs | 2 +- .../src/identities/manager.rs | 45 ++++++++------ .../matrix-sdk-crypto/src/identities/user.rs | 3 +- crates/matrix-sdk-crypto/src/machine.rs | 61 +++++++++++-------- crates/matrix-sdk-crypto/src/olm/account.rs | 13 +++- .../src/session_manager/sessions.rs | 14 ++--- .../src/store/integration_tests.rs | 16 ++--- .../src/store/memorystore.rs | 11 +++- crates/matrix-sdk-crypto/src/store/mod.rs | 34 ++++++++--- 12 files changed, 156 insertions(+), 101 deletions(-) diff --git a/bindings/matrix-sdk-crypto-ffi/src/dehydrated_devices.rs b/bindings/matrix-sdk-crypto-ffi/src/dehydrated_devices.rs index c6550c2b2..548fb83be 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/dehydrated_devices.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/dehydrated_devices.rs @@ -54,12 +54,13 @@ impl Drop for DehydratedDevices { #[uniffi::export] impl DehydratedDevices { - pub fn create(&self) -> Arc { - DehydratedDevice { - inner: ManuallyDrop::new(self.inner.create()), + pub fn create(&self) -> Result, DehydrationError> { + let inner = self.runtime.block_on(self.inner.create())?; + + Ok(Arc::new(DehydratedDevice { + inner: ManuallyDrop::new(inner), runtime: self.runtime.to_owned(), - } - .into() + })) } pub fn rehydrate( diff --git a/crates/matrix-sdk-crypto/src/dehydrated_devices.rs b/crates/matrix-sdk-crypto/src/dehydrated_devices.rs index f4ba22fbf..f925f20ab 100644 --- a/crates/matrix-sdk-crypto/src/dehydrated_devices.rs +++ b/crates/matrix-sdk-crypto/src/dehydrated_devices.rs @@ -91,7 +91,7 @@ pub struct DehydratedDevices { impl DehydratedDevices { /// Create a new [`DehydratedDevice`] which can be uploaded to the server. - pub fn create(&self) -> DehydratedDevice { + pub async fn create(&self) -> Result { let user_id = self.inner.user_id(); let user_identity = self.inner.store().private_identity(); @@ -104,9 +104,11 @@ impl DehydratedDevices { store.clone(), ); - let store = Store::new(account, user_identity, store, verification_machine); + let store = + Store::new(account.static_data().clone(), user_identity, store, verification_machine); + store.save_pending_changes(crate::store::PendingChanges { account: Some(account) }).await?; - DehydratedDevice { store } + Ok(DehydratedDevice { store }) } /// Rehydrate the dehydrated device. @@ -456,10 +458,10 @@ mod tests { } #[async_test] - async fn dehydrated_device_creation() { + async fn test_dehydrated_device_creation() { let olm_machine = get_olm_machine().await; - let dehydrated_device = olm_machine.dehydrated_devices().create(); + let dehydrated_device = olm_machine.dehydrated_devices().create().await.unwrap(); let request = dehydrated_device .keys_for_upload("Foo".to_owned(), PICKLE_KEY) @@ -482,7 +484,7 @@ mod tests { let room_id = room_id!("!test:example.org"); let alice = get_olm_machine().await; - let dehydrated_device = alice.dehydrated_devices().create(); + let dehydrated_device = alice.dehydrated_devices().create().await.unwrap(); let mut request = dehydrated_device .keys_for_upload("Foo".to_owned(), PICKLE_KEY) diff --git a/crates/matrix-sdk-crypto/src/gossiping/machine.rs b/crates/matrix-sdk-crypto/src/gossiping/machine.rs index 5c6d9db31..00e213c99 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/machine.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/machine.rs @@ -1069,7 +1069,7 @@ mod tests { identities::{LocalTrust, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity}, session_manager::GroupSessionCache, - store::{CryptoStoreWrapper, MemoryStore, Store}, + store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store}, types::events::room::encrypted::{EncryptedEvent, RoomEncryptedEventContent}, verification::VerificationMachine, }; @@ -1113,7 +1113,7 @@ mod tests { } #[cfg(feature = "automatic-room-key-forwarding")] - fn test_gossip_machine(user_id: &UserId) -> GossipMachine { + async fn gossip_machine_test_helper(user_id: &UserId) -> GossipMachine { let user_id = user_id.to_owned(); let device_id = DeviceId::new(); @@ -1122,13 +1122,15 @@ mod tests { let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); let verification = VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone()); - let store = Store::new(account, identity, store, verification); + let store = Store::new(account.static_data().clone(), identity, store, verification); + store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); + let session_cache = GroupSessionCache::new(store.clone()); GossipMachine::new(store, session_cache, Default::default()) } - async fn get_machine() -> GossipMachine { + async fn get_machine_test_helper() -> GossipMachine { let user_id = alice_id().to_owned(); let account = Account::with_device_id(&user_id, alice_device_id()); let device = ReadOnlyDevice::from_account(&account).await; @@ -1141,8 +1143,9 @@ mod tests { let verification = VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone()); - let store = Store::new(account, identity, store, verification); + let store = Store::new(account.static_data().clone(), identity, store, verification); store.save_devices(&[device, another_device]).await.unwrap(); + store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); let session_cache = GroupSessionCache::new(store.clone()); GossipMachine::new(store, session_cache, Default::default()) @@ -1154,16 +1157,16 @@ mod tests { create_sessions: bool, algorithm: EventEncryptionAlgorithm, ) -> (GossipMachine, OutboundGroupSession, GossipMachine) { - let alice_machine = get_machine().await; + let alice_machine = get_machine_test_helper().await; let alice_device = ReadOnlyDevice::from_account( - alice_machine.inner.store.cache().await.unwrap().account().await.unwrap(), + &alice_machine.inner.store.cache().await.unwrap().account().await.unwrap(), ) .await; - let bob_machine = test_gossip_machine(other_machine_owner); + let bob_machine = gossip_machine_test_helper(other_machine_owner).await; let bob_device = ReadOnlyDevice::from_account( - bob_machine.inner.store.cache().await.unwrap().account().await.unwrap(), + &*bob_machine.inner.store.cache().await.unwrap().account().await.unwrap(), ) .await; @@ -1295,14 +1298,14 @@ mod tests { #[async_test] async fn create_machine() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; assert!(machine.outgoing_to_device_requests().await.unwrap().is_empty()); } #[async_test] async fn re_request_keys() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; let account = account(); let (outbound, session) = account.create_group_session_pair_with_defaults(room_id()).await; @@ -1325,7 +1328,7 @@ mod tests { #[async_test] #[cfg(feature = "automatic-room-key-forwarding")] async fn create_key_request() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; let account = account(); let second_account = alice_2_account(); let alice_device = ReadOnlyDevice::from_account(&second_account).await; @@ -1357,7 +1360,7 @@ mod tests { #[async_test] #[cfg(feature = "automatic-room-key-forwarding")] async fn receive_forwarded_key() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; let account = account(); let second_account = alice_2_account(); @@ -1466,7 +1469,7 @@ mod tests { #[async_test] #[cfg(feature = "automatic-room-key-forwarding")] async fn should_share_key_test() { - let machine = get_machine().await; + let machine = get_machine_test_helper().await; let account = account(); let own_device = @@ -1733,7 +1736,7 @@ mod tests { #[async_test] async fn test_secret_share_cycle() { - let alice_machine = get_machine().await; + let alice_machine = get_machine_test_helper().await; let second_account = alice_2_account(); let alice_device = ReadOnlyDevice::from_account(&second_account).await; diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index 9eb312eb2..7d4a22521 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -869,7 +869,7 @@ impl ReadOnlyDevice { pub async fn from_machine_test_helper( machine: &OlmMachine, ) -> Result { - Ok(ReadOnlyDevice::from_account(machine.store().cache().await?.account().await?).await) + Ok(ReadOnlyDevice::from_account(&*machine.store().cache().await?.account().await?).await) } /// Create a `ReadOnlyDevice` from an `Account` diff --git a/crates/matrix-sdk-crypto/src/identities/manager.rs b/crates/matrix-sdk-crypto/src/identities/manager.rs index a9b5df2a1..06a6dbd2b 100644 --- a/crates/matrix-sdk-crypto/src/identities/manager.rs +++ b/crates/matrix-sdk-crypto/src/identities/manager.rs @@ -884,7 +884,7 @@ pub(crate) mod testing { identities::IdentityManager, machine::testing::response_from_file, olm::{Account, PrivateCrossSigningIdentity}, - store::{CryptoStoreWrapper, MemoryStore, Store}, + store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store}, types::DeviceKeys, verification::VerificationMachine, UploadSigningKeysRequest, @@ -902,15 +902,20 @@ pub(crate) mod testing { device_id!("WSKKLTJZCL") } - pub(crate) async fn manager(user_id: &UserId, device_id: &DeviceId) -> IdentityManager { + pub(crate) async fn manager_test_helper( + user_id: &UserId, + device_id: &DeviceId, + ) -> IdentityManager { let identity = PrivateCrossSigningIdentity::new(user_id.into()).await; let identity = Arc::new(Mutex::new(identity)); let user_id = user_id.to_owned(); let account = Account::with_device_id(&user_id, device_id); + let static_account = account.static_data().clone(); let store = Arc::new(CryptoStoreWrapper::new(&user_id, MemoryStore::new())); let verification = - VerificationMachine::new(account.static_data.clone(), identity.clone(), store.clone()); - let store = Store::new(account, identity, store, verification); + VerificationMachine::new(static_account.clone(), identity.clone(), store.clone()); + let store = Store::new(static_account, identity, store, verification); + store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); IdentityManager::new(store) } @@ -1186,7 +1191,9 @@ pub(crate) mod tests { use serde_json::json; use stream_assert::{assert_closed, assert_pending, assert_ready}; - use super::testing::{device_id, key_query, manager, other_key_query, other_user_id, user_id}; + use super::testing::{ + device_id, key_query, manager_test_helper, other_key_query, other_user_id, user_id, + }; use crate::{ identities::manager::testing::{other_key_query_cross_signed, own_key_query}, olm::PrivateCrossSigningIdentity, @@ -1211,7 +1218,7 @@ pub(crate) mod tests { #[async_test] async fn test_tracked_users() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let alice = user_id!("@alice:example.org"); assert!( @@ -1231,13 +1238,13 @@ pub(crate) mod tests { #[async_test] async fn test_manager_creation() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; assert!(manager.store.tracked_users().await.unwrap().is_empty()) } #[async_test] async fn test_manager_key_query_response() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let other_user = other_user_id(); let devices = manager.store.get_user_devices(other_user).await.unwrap(); assert_eq!(devices.devices().count(), 0); @@ -1264,7 +1271,7 @@ pub(crate) mod tests { #[async_test] async fn test_manager_own_key_query_response() -> anyhow::Result<()> { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let our_user = user_id(); let devices = manager.store.get_user_devices(our_user).await?; assert_eq!(devices.devices().count(), 0); @@ -1299,7 +1306,7 @@ pub(crate) mod tests { #[async_test] async fn private_identity_invalidation_after_public_keys_change() { let user_id = user_id!("@example1:localhost"); - let manager = manager(user_id, "DEVICEID".into()).await; + let manager = manager_test_helper(user_id, "DEVICEID".into()).await; let identity_request = { let private_identity = manager.store.private_identity(); @@ -1387,7 +1394,7 @@ pub(crate) mod tests { #[async_test] async fn no_tracked_users_key_query_request() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; assert!( manager.store.tracked_users().await.unwrap().is_empty(), @@ -1407,8 +1414,8 @@ pub(crate) mod tests { /// user is not removed from the list of outdated users when the /// response is received #[async_test] - async fn invalidation_race_handling() { - let manager = manager(user_id(), device_id()).await; + async fn test_invalidation_race_handling() { + let manager = manager_test_helper(user_id(), device_id()).await; let alice = other_user_id(); manager.update_tracked_users([alice]).await.unwrap(); @@ -1436,7 +1443,7 @@ pub(crate) mod tests { #[async_test] async fn failure_handling() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let alice = user_id!("@alice:example.org"); assert!( @@ -1476,7 +1483,7 @@ pub(crate) mod tests { #[async_test] async fn test_out_of_band_key_query() { // build the request - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let (reqid, req) = manager.build_key_query_for_users(vec![user_id()]); assert!(req.device_keys.contains_key(user_id())); @@ -1495,7 +1502,7 @@ pub(crate) mod tests { #[async_test] async fn devices_stream() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]); let stream = manager.store.devices_stream(); @@ -1509,7 +1516,7 @@ pub(crate) mod tests { #[async_test] async fn identities_stream() { - let manager = manager(user_id(), device_id()).await; + let manager = manager_test_helper(user_id(), device_id()).await; let (request_id, _) = manager.build_key_query_for_users(vec![user_id()]); let stream = manager.store.user_identities_stream(); @@ -1523,7 +1530,7 @@ pub(crate) mod tests { #[async_test] async fn identities_stream_raw() { - let mut manager = Some(manager(user_id(), device_id()).await); + let mut manager = Some(manager_test_helper(user_id(), device_id()).await); let (request_id, _) = manager.as_ref().unwrap().build_key_query_for_users(vec![user_id()]); let stream = manager.as_ref().unwrap().store.identities_stream_raw(); @@ -1568,7 +1575,7 @@ pub(crate) mod tests { #[async_test] async fn identities_stream_raw_signature_update() { - let mut manager = Some(manager(user_id(), device_id()).await); + let mut manager = Some(manager_test_helper(user_id(), device_id()).await); let (request_id, _) = manager.as_ref().unwrap().build_key_query_for_users(vec![other_user_id()]); diff --git a/crates/matrix-sdk-crypto/src/identities/user.rs b/crates/matrix-sdk-crypto/src/identities/user.rs index b0f188f02..d0dda24b6 100644 --- a/crates/matrix-sdk-crypto/src/identities/user.rs +++ b/crates/matrix-sdk-crypto/src/identities/user.rs @@ -147,7 +147,8 @@ impl OwnUserIdentity { } let cache = self.store.cache().await?; - cache.account().await?.sign_master_key(self.master_key.clone()).await + let account = cache.account().await?; + account.sign_master_key(self.master_key.clone()).await } /// Send a verification request to our other devices. diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index d42db6808..f75774fd9 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -64,6 +64,7 @@ use crate::{ olm::{ Account, CrossSigningStatus, EncryptionSettings, ExportedRoomKey, IdentityKeys, InboundGroupSession, OlmDecryptionInfo, PrivateCrossSigningIdentity, SessionType, + StaticAccountData, }, requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, session_manager::{GroupSessionManager, SessionManager}, @@ -168,25 +169,23 @@ impl OlmMachine { ) -> Result { let account = Account::rehydrate(pickle_key, self.user_id(), device_id, device_data).await?; + let static_account = account.static_data().clone(); let store = Arc::new(CryptoStoreWrapper::new(self.user_id(), MemoryStore::new())); + store.save_pending_changes(PendingChanges { account: Some(account) }).await?; - Ok(Self::new_helper(device_id, store, account, self.store().private_identity())) + Ok(Self::new_helper(device_id, store, static_account, self.store().private_identity())) } fn new_helper( device_id: &DeviceId, store: Arc, - account: Account, + account: StaticAccountData, user_identity: Arc>, ) -> Self { - let verification_machine = VerificationMachine::new( - account.static_data().clone(), - user_identity.clone(), - store.clone(), - ); - let store = - Store::new(account.clone(), user_identity.clone(), store, verification_machine.clone()); + let verification_machine = + VerificationMachine::new(account.clone(), user_identity.clone(), store.clone()); + let store = Store::new(account, user_identity.clone(), store, verification_machine.clone()); let group_session_manager = GroupSessionManager::new(store.clone()); @@ -248,7 +247,8 @@ impl OlmMachine { store: impl IntoCryptoStore, ) -> StoreResult { let store = store.into_crypto_store(); - let account = match store.load_account().await? { + + let static_account = match store.load_account().await? { Some(account) => { if user_id != account.user_id() || device_id != account.device_id() { return Err(CryptoStoreError::MismatchedAccount { @@ -262,10 +262,12 @@ impl OlmMachine { .record("curve25519_key", display(account.identity_keys().curve25519)); debug!("Restored an Olm account"); - account + account.static_data().clone() } + None => { let account = Account::with_device_id(user_id, device_id); + let static_account = account.static_data().clone(); Span::current() .record("ed25519_key", display(account.identity_keys().ed25519)) @@ -283,13 +285,11 @@ impl OlmMachine { ..Default::default() }; store.save_changes(changes).await?; - store - .save_pending_changes(PendingChanges { account: Some(account.clone()) }) - .await?; + store.save_pending_changes(PendingChanges { account: Some(account) }).await?; debug!("Created a new Olm account"); - account + static_account } }; @@ -310,7 +310,7 @@ impl OlmMachine { let identity = Arc::new(Mutex::new(identity)); let store = Arc::new(CryptoStoreWrapper::new(user_id, store)); - Ok(OlmMachine::new_helper(device_id, store, account, identity)) + Ok(OlmMachine::new_helper(device_id, store, static_account, identity)) } /// Get the crypto store associated with this `OlmMachine` instance. @@ -2028,7 +2028,8 @@ impl OlmMachine { #[cfg(any(feature = "testing", test))] pub async fn uploaded_key_count(&self) -> Result { let cache = self.inner.store.cache().await?; - Ok(cache.account().await?.uploaded_key_count()) + let account = cache.account().await?; + Ok(account.uploaded_key_count()) } } @@ -2436,6 +2437,7 @@ pub(crate) mod tests { let mut request = machine.keys_for_upload(&account).await.expect("Can't prepare initial key upload"); + drop(account); // Release the account lock. let one_time_key: SignedKey = request .one_time_keys @@ -2461,16 +2463,27 @@ pub(crate) mod tests { ); ret.unwrap(); - let mut response = keys_upload_response(); - response.one_time_key_counts.insert( - DeviceKeyAlgorithm::SignedCurve25519, - (account.max_one_time_keys().await).try_into().unwrap(), - ); + let response = { + let cache = machine.store().cache().await.unwrap(); + let account = cache.account().await.unwrap(); + + let mut response = keys_upload_response(); + response.one_time_key_counts.insert( + DeviceKeyAlgorithm::SignedCurve25519, + (account.max_one_time_keys().await).try_into().unwrap(), + ); + + response + }; machine.receive_keys_upload_response(&response).await.unwrap(); - let ret = machine.keys_for_upload(&account).await; - assert!(ret.is_none()); + { + let cache = machine.store().cache().await.unwrap(); + let account = cache.account().await.unwrap(); + let ret = machine.keys_for_upload(&account).await; + assert!(ret.is_none()); + } } #[async_test] diff --git a/crates/matrix-sdk-crypto/src/olm/account.rs b/crates/matrix-sdk-crypto/src/olm/account.rs index 1c1c9e007..98ceef7db 100644 --- a/crates/matrix-sdk-crypto/src/olm/account.rs +++ b/crates/matrix-sdk-crypto/src/olm/account.rs @@ -314,7 +314,6 @@ impl StaticAccountData { /// /// An account is the central identity for encrypted communication between two /// devices. -#[derive(Clone)] pub struct Account { pub(crate) static_data: StaticAccountData, /// `vodozemac` account. @@ -1356,6 +1355,18 @@ impl Account { }) } } + + /// Test-only. + #[cfg(any(test, feature = "testing"))] + #[doc(hidden)] + pub fn clone_for_testing(&self) -> Self { + Self { + static_data: self.static_data.clone(), + inner: self.inner.clone(), + shared: self.shared.clone(), + uploaded_signed_key_count: self.uploaded_signed_key_count.clone(), + } + } } impl PartialEq for Account { diff --git a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs index 07a1309b4..1efcc5f5b 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs @@ -484,7 +484,7 @@ mod tests { identities::{IdentityManager, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity}, session_manager::GroupSessionCache, - store::{CryptoStoreWrapper, MemoryStore, Store}, + store::{CryptoStoreWrapper, MemoryStore, PendingChanges, Store}, verification::VerificationMachine, }; @@ -531,7 +531,6 @@ mod tests { let user_id = user_id(); let device_id = device_id(); - let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new())); let account = Account::with_device_id(user_id, device_id); let store = Arc::new(CryptoStoreWrapper::new(user_id, MemoryStore::new())); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id))); @@ -541,15 +540,12 @@ mod tests { store.clone(), ); - let store = Store::new(account.clone(), identity, store, verification); - { - // Perform a dummy transaction to sync in-memory cache with the db. - let tr = store.transaction().await.unwrap(); - tr.commit().await.unwrap(); - } + let store = Store::new(account.static_data().clone(), identity, store, verification); + store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap(); let session_cache = GroupSessionCache::new(store.clone()); + let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new())); let key_request = GossipMachine::new(store.clone(), session_cache, users_for_key_claim.clone()); @@ -656,7 +652,7 @@ mod tests { let (_, mut session) = { let cache = manager.store.cache().await.unwrap(); - let manager_account = cache.account().await.unwrap(); + let manager_account = &*cache.account().await.unwrap(); bob.create_session_for(manager_account).await }; diff --git a/crates/matrix-sdk-crypto/src/store/integration_tests.rs b/crates/matrix-sdk-crypto/src/store/integration_tests.rs index 50b4087df..5ee7a2dba 100644 --- a/crates/matrix-sdk-crypto/src/store/integration_tests.rs +++ b/crates/matrix-sdk-crypto/src/store/integration_tests.rs @@ -66,7 +66,7 @@ macro_rules! cryptostore_integration_tests { let store = get_store(name, None).await; let account = get_account(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account"); (account, store) } @@ -124,7 +124,7 @@ macro_rules! cryptostore_integration_tests { let store = get_store("load_account", None).await; let account = get_account(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account"); let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); @@ -138,7 +138,7 @@ macro_rules! cryptostore_integration_tests { get_store("load_account_with_passphrase", Some("secret_passphrase")).await; let account = get_account(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account"); let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); @@ -151,12 +151,12 @@ macro_rules! cryptostore_integration_tests { let store = get_store("save_and_share_account", None).await; let account = get_account(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account"); account.mark_as_shared(); account.update_uploaded_key_count(50); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account"); let loaded_account = store.load_account().await.expect("Can't load account"); let loaded_account = loaded_account.unwrap(); @@ -169,7 +169,7 @@ macro_rules! cryptostore_integration_tests { async fn load_sessions() { let store = get_store("load_sessions", None).await; let (account, session) = get_account_and_session().await; - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account"); let changes = Changes { sessions: vec![session.clone()], ..Default::default() }; @@ -193,7 +193,7 @@ macro_rules! cryptostore_integration_tests { let sender_key = session.sender_key.to_base64(); let session_id = session.session_id().to_owned(); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }).await.expect("Can't save account"); + store.save_pending_changes(PendingChanges { account: Some(account.clone_for_testing()), }).await.expect("Can't save account"); let changes = Changes { sessions: vec![session.clone()], ..Default::default() }; store.save_changes(changes).await.unwrap(); @@ -491,7 +491,7 @@ macro_rules! cryptostore_integration_tests { let account = Account::with_device_id(&user_id, device_id); - store.save_pending_changes(PendingChanges { account: Some(account.clone()), }) + store.save_pending_changes(PendingChanges { account: Some(account), }) .await .expect("Can't save account"); diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index 65ecf8eae..b0102e6bf 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -52,6 +52,7 @@ fn encode_key_info(info: &SecretInfo) -> String { /// An in-memory only store that will forget all the E2EE key once it's dropped. #[derive(Debug)] pub struct MemoryStore { + account: StdRwLock>, sessions: SessionStore, inbound_group_sessions: GroupSessionStore, olm_hashes: StdRwLock>>, @@ -70,6 +71,7 @@ pub struct MemoryStore { impl Default for MemoryStore { fn default() -> Self { MemoryStore { + account: Default::default(), sessions: SessionStore::new(), inbound_group_sessions: GroupSessionStore::new(), olm_hashes: Default::default(), @@ -126,7 +128,7 @@ impl CryptoStore for MemoryStore { type Error = Infallible; async fn load_account(&self) -> Result> { - Ok(None) + Ok(self.account.read().unwrap().as_ref().map(|acc| acc.clone_for_testing())) } async fn load_identity(&self) -> Result> { @@ -137,8 +139,11 @@ impl CryptoStore for MemoryStore { Ok(self.next_batch_token.read().await.clone()) } - async fn save_pending_changes(&self, _changes: PendingChanges) -> Result<()> { - // TODO(bnjbvr) why didn't save_changes save the account? + async fn save_pending_changes(&self, changes: PendingChanges) -> Result<()> { + if let Some(account) = changes.account { + *self.account.write().unwrap() = Some(account); + } + Ok(()) } diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 4e15cfe93..92a722e1f 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -55,7 +55,7 @@ use ruma::{ }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use thiserror::Error; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::{Mutex, MutexGuard, RwLock}; use tracing::{info, warn}; use vodozemac::{base64_encode, megolm::SessionOrdering, Curve25519PublicKey}; use zeroize::Zeroize; @@ -107,9 +107,11 @@ pub struct Store { #[derive(Debug)] pub(crate) struct StoreCache { + store: Arc, + tracked_users: StdRwLock>, tracked_user_loading_lock: RwLock, - account: Account, + account: Mutex>, } impl StoreCache { @@ -119,8 +121,19 @@ impl StoreCache { /// /// Note there should always be an account stored at least in the store, so this doesn't return /// an `Option`. - pub async fn account(&self) -> Result<&Account> { - Ok(&self.account) + pub async fn account(&self) -> Result + '_> { + let mut guard = self.account.lock().await; + if guard.is_some() { + Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap())) + } else { + match self.store.load_account().await? { + Some(account) => { + *guard = Some(account); + Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap())) + } + None => Err(CryptoStoreError::AccountUnset), + } + } } } @@ -163,7 +176,9 @@ impl StoreTransaction { /// Gets a `Account` for update. pub async fn account(&mut self) -> Result<&mut Account> { if self.changes.account.is_none() { - self.changes.account = Some(self.cache.account.clone()); + // Make sure the cache loaded the account. + let _ = self.cache.account().await?; + self.changes.account = self.cache.account.lock().await.take(); } Ok(self.changes.account.as_mut().unwrap()) } @@ -604,23 +619,24 @@ impl From<&InboundGroupSession> for RoomKeyInfo { impl Store { /// Create a new Store. pub(crate) fn new( - account: Account, + account: StaticAccountData, identity: Arc>, store: Arc, verification_machine: VerificationMachine, ) -> Self { Self { inner: Arc::new(StoreInner { - static_account: account.static_data().clone(), + static_account: account, identity, - store, + store: store.clone(), verification_machine, users_for_key_query: AsyncStdMutex::new(UsersForKeyQuery::new()), users_for_key_query_condvar: Condvar::new(), cache: Arc::new(StoreCache { + store, tracked_users: Default::default(), tracked_user_loading_lock: Default::default(), - account, + account: Default::default(), }), }), }