diff --git a/bindings/matrix-sdk-ffi/src/encryption_sync.rs b/bindings/matrix-sdk-ffi/src/encryption_sync.rs index 2dbbfc531..2954d42c2 100644 --- a/bindings/matrix-sdk-ffi/src/encryption_sync.rs +++ b/bindings/matrix-sdk-ffi/src/encryption_sync.rs @@ -73,12 +73,6 @@ impl EncryptionSync { error!("Error when stopping the encryption sync: {err}"); } } - - pub fn reload_caches(&self) { - if let Err(err) = RUNTIME.block_on(self.sync.reload_caches()) { - error!("Error when reloading caches: {err}"); - } - } } impl Client { diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index 8a06ca88e..24394b681 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -67,8 +67,8 @@ use crate::{ requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, session_manager::{GroupSessionManager, SessionManager}, store::{ - Changes, DeviceChanges, DynCryptoStore, IdentityChanges, IntoCryptoStore, MemoryStore, - Result as StoreResult, SecretImportError, Store, + locks::LockStoreError, Changes, DeviceChanges, DynCryptoStore, IdentityChanges, + IntoCryptoStore, MemoryStore, Result as StoreResult, SecretImportError, Store, }, types::{ events::{ @@ -129,6 +129,18 @@ pub struct OlmMachineInner { /// A state machine that handles creating room key backups. #[cfg(feature = "backups_v1")] backup_machine: BackupMachine, + /// Latest "generation" of data known by the crypto store. + /// + /// This is a counter that only increments, set in the database (and can + /// wrap). It's incremented whenever some process acquires a lock for the + /// first time. *This assumes the crypto store lock is being held, to + /// avoid data races on writing to this value in the store*. + /// + /// The current process will maintain this value in local memory and in the + /// DB over time. 
Observing a different value than the one read in + /// memory, when reading from the store, indicates that somebody else has + /// written into the database under our feet. + pub(crate) crypto_store_generation: Arc>>, } #[cfg(not(tarpaulin_include))] @@ -142,6 +154,8 @@ impl std::fmt::Debug for OlmMachine { } impl OlmMachine { + const CURRENT_GENERATION_STORE_KEY: &str = "generation-counter"; + /// Create a new memory based OlmMachine. /// /// The created machine will keep the encryption keys only in memory and @@ -212,6 +226,7 @@ impl OlmMachine { identity_manager, #[cfg(feature = "backups_v1")] backup_machine, + crypto_store_generation: Arc::new(Mutex::new(None)), }); Self { inner } @@ -1728,6 +1743,106 @@ impl OlmMachine { pub fn backup_machine(&self) -> &BackupMachine { &self.inner.backup_machine } + + /// Syncs the database and in-memory generation counter. + /// + /// This requires that the crypto store lock has been acquired already. + pub async fn initialize_crypto_store_generation(&self) -> StoreResult<()> { + // Avoid reentrant initialization by taking the lock for the entire function's + // scope. + let mut gen_guard = self.inner.crypto_store_generation.lock().await; + + let prev_generation = + self.inner.store.get_custom_value(Self::CURRENT_GENERATION_STORE_KEY).await?; + + let gen = match prev_generation { + Some(val) => { + // There was a value in the store. We need to signal that we're a different + // process, so we don't just reuse the value but increment it. + u64::from_le_bytes( + val.try_into().map_err(|_| LockStoreError::InvalidGenerationFormat)?, + ) + .wrapping_add(1) + } + None => 0, + }; + + self.inner + .store + .set_custom_value(Self::CURRENT_GENERATION_STORE_KEY, gen.to_le_bytes().to_vec()) + .await?; + + *gen_guard = Some(gen); + + Ok(()) + } + + /// If needs be, update the local and on-disk crypto store generation. 
+ /// + /// Returns true if another user has modified the internal generation + /// counter, and as such we've incremented and updated it in the + /// database. + /// + /// ## Requirements + /// + /// - This assumes that `initialize_crypto_store_generation` has been called + /// beforehand. + /// - This requires that the crypto store lock has been acquired. + pub async fn maintain_crypto_store_generation(&self) -> StoreResult { + let mut gen_guard = self.inner.crypto_store_generation.lock().await; + + // The database value must be there: + // - either we could initialize beforehand, thus write into the database, + // - or we couldn't, and then another process was holding onto the database's + // lock, thus + // has written a generation counter in there. + let actual_gen = self + .inner + .store + .get_custom_value(Self::CURRENT_GENERATION_STORE_KEY) + .await? + .ok_or(LockStoreError::MissingGeneration)?; + + let actual_gen = u64::from_le_bytes( + actual_gen.try_into().map_err(|_| LockStoreError::InvalidGenerationFormat)?, + ); + + let expected_gen = match gen_guard.as_ref() { + Some(expected_gen) => { + if actual_gen == *expected_gen { + return Ok(false); + } + // Increment the larger of the two, and store it everywhere. + actual_gen.max(*expected_gen).wrapping_add(1) + } + None => { + // Some other process held onto the lock when initializing, so we must reload. + // Increment database value, and store it everywhere. + actual_gen.wrapping_add(1) + } + }; + + tracing::debug!( + "Crypto store generation mismatch: previously known was {:?}, actual is {:?}, next is {}", + *gen_guard, + actual_gen, + expected_gen + ); + + // Update known value. + *gen_guard = Some(expected_gen); + + // Update value in database. 
+ self.inner + .store + .set_custom_value( + Self::CURRENT_GENERATION_STORE_KEY, + expected_gen.to_le_bytes().to_vec(), + ) + .await?; + + Ok(true) + } } #[cfg(any(feature = "testing", test))] diff --git a/crates/matrix-sdk-crypto/src/store/locks.rs b/crates/matrix-sdk-crypto/src/store/locks.rs index d6171fe46..b7fabaa01 100644 --- a/crates/matrix-sdk-crypto/src/store/locks.rs +++ b/crates/matrix-sdk-crypto/src/store/locks.rs @@ -314,6 +314,15 @@ pub enum LockStoreError { /// Spent too long waiting for a database lock. #[error("a lock timed out")] LockTimeout, + + /// The generation counter is missing, and should always be present. + #[error("missing generation counter in the store")] + MissingGeneration, + + /// Unexpected format for the generation counter. Is someone tampering + /// with the database? + #[error("invalid format of the generation counter")] + InvalidGenerationFormat, } #[cfg(test)] diff --git a/crates/matrix-sdk-ui/src/encryption_sync/mod.rs b/crates/matrix-sdk-ui/src/encryption_sync/mod.rs index fce89a3ee..a50d9c484 100644 --- a/crates/matrix-sdk-ui/src/encryption_sync/mod.rs +++ b/crates/matrix-sdk-ui/src/encryption_sync/mod.rs @@ -213,19 +213,6 @@ impl EncryptionSync { Ok(()) } - - /// Request a reload of the internal caches used by this sync. - /// - /// This must be called every time the process running this loop was - /// suspended and got back into the foreground, and another process may have - /// written to the same underlying store (e.g. notification process vs - /// main process). - pub async fn reload_caches(&self) -> Result<(), Error> { - // Regenerate the crypto store caches first. - self.client.encryption().reload_caches().await.map_err(Error::ClientError)?; - - Ok(()) - } } /// Errors for the [`EncryptionSync`]. 
diff --git a/crates/matrix-sdk/src/encryption/mod.rs b/crates/matrix-sdk/src/encryption/mod.rs index e26cfa685..9e71785af 100644 --- a/crates/matrix-sdk/src/encryption/mod.rs +++ b/crates/matrix-sdk/src/encryption/mod.rs @@ -864,6 +864,17 @@ impl Encryption { let lock = olm_machine.store().create_store_lock("cross_process_lock".to_owned(), lock_value); + // Gently try to initialize the crypto store generation counter. + // + // If we don't get the lock immediately, then it is already acquired by another + // process, and we'll get to reload next time we acquire the lock. + { + let guard = lock.try_lock_once().await?; + if guard.is_some() { + olm_machine.initialize_crypto_store_generation().await?; + } + } + self.client .inner .cross_process_crypto_store_lock @@ -873,6 +884,22 @@ impl Encryption { Ok(()) } + /// Recreates the `OlmMachine` after acquiring the lock, if the crypto store + /// generation counter was found to have changed. + async fn on_lock_newly_acquired(&self) -> Result<(), Error> { + let olm_machine_guard = self.client.olm_machine().await; + if let Some(olm_machine) = olm_machine_guard.as_ref() { + // If the crypto store generation has changed, + if olm_machine.maintain_crypto_store_generation().await? { + // (release the guard on the current `OlmMachine` first) + drop(olm_machine_guard); + // Recreate the OlmMachine. + self.client.base_client().regenerate_olm().await?; + } + } + Ok(()) + } + /// If a lock was created with [`Self::enable_cross_process_store_lock`], /// spin-waits until the lock is available. /// @@ -883,20 +910,11 @@ impl Encryption { max_backoff: Option, ) -> Result, Error> { if let Some(lock) = self.client.inner.cross_process_crypto_store_lock.get() { - match lock.try_lock_once().await? { - Some(guard) => Ok(Some(guard)), - None => { - // We didn't get the lock on the first attempt, so that means that another - // process is using it. Wait for it to release it. 
- let guard = lock.spin_lock(max_backoff).await?; + let guard = lock.spin_lock(max_backoff).await?; - // As we didn't get the lock on the first attempt, force-reload all the crypto - // state caches at once, by recreating the OlmMachine from scratch. - self.client.base_client().regenerate_olm().await?; + self.on_lock_newly_acquired().await?; - Ok(Some(guard)) - } - } + Ok(Some(guard)) } else { Ok(None) } @@ -908,19 +926,17 @@ impl Encryption { /// Returns a guard to the lock, if it was obtained. pub async fn try_lock_store_once(&self) -> Result, Error> { if let Some(lock) = self.client.inner.cross_process_crypto_store_lock.get() { - Ok(lock.try_lock_once().await?) + let maybe_guard = lock.try_lock_once().await?; + + if maybe_guard.is_some() { + self.on_lock_newly_acquired().await?; + } + + Ok(maybe_guard) } else { Ok(None) } } - - /// Manually request that the internal crypto caches be reloaded. - pub async fn reload_caches(&self) -> Result<(), Error> { - // At this time, rleoading the `OlmMachine` ought to be sufficient. - self.client.base_client().regenerate_olm().await?; - - Ok(()) - } } #[cfg(all(test, not(target_arch = "wasm32")))]