feat: implement a generation counter for the CryptoStore lock

This commit is contained in:
Benjamin Bouvier
2023-06-26 16:39:27 +02:00
parent d72fd34325
commit d642c22db7
5 changed files with 163 additions and 42 deletions

View File

@@ -73,12 +73,6 @@ impl EncryptionSync {
error!("Error when stopping the encryption sync: {err}");
}
}
pub fn reload_caches(&self) {
if let Err(err) = RUNTIME.block_on(self.sync.reload_caches()) {
error!("Error when reloading caches: {err}");
}
}
}
impl Client {

View File

@@ -67,8 +67,8 @@ use crate::{
requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest},
session_manager::{GroupSessionManager, SessionManager},
store::{
Changes, DeviceChanges, DynCryptoStore, IdentityChanges, IntoCryptoStore, MemoryStore,
Result as StoreResult, SecretImportError, Store,
locks::LockStoreError, Changes, DeviceChanges, DynCryptoStore, IdentityChanges,
IntoCryptoStore, MemoryStore, Result as StoreResult, SecretImportError, Store,
},
types::{
events::{
@@ -129,6 +129,18 @@ pub struct OlmMachineInner {
/// A state machine that handles creating room key backups.
#[cfg(feature = "backups_v1")]
backup_machine: BackupMachine,
/// Latest "generation" of data known by the crypto store.
///
/// This is a counter that only increments, set in the database (and can
/// wrap). It's incremented whenever some process acquires a lock for the
/// first time. *This assumes the crypto store lock is being held, to
/// avoid data races on writing to this value in the store*.
///
/// The current process will maintain this value in local memory and in the
/// DB over time. Observing a different value than the one read in
/// memory, when reading from the store indicates that somebody else has
/// written into the database under our feet.
pub(crate) crypto_store_generation: Arc<Mutex<Option<u64>>>,
}
#[cfg(not(tarpaulin_include))]
@@ -142,6 +154,8 @@ impl std::fmt::Debug for OlmMachine {
}
impl OlmMachine {
const CURRENT_GENERATION_STORE_KEY: &str = "generation-counter";
/// Create a new memory based OlmMachine.
///
/// The created machine will keep the encryption keys only in memory and
@@ -212,6 +226,7 @@ impl OlmMachine {
identity_manager,
#[cfg(feature = "backups_v1")]
backup_machine,
crypto_store_generation: Arc::new(Mutex::new(None)),
});
Self { inner }
@@ -1728,6 +1743,106 @@ impl OlmMachine {
pub fn backup_machine(&self) -> &BackupMachine {
&self.inner.backup_machine
}
/// Syncs the database and in-memory generation counter.
///
/// This requires that the crypto store lock has been acquired already.
pub async fn initialize_crypto_store_generation(&self) -> StoreResult<()> {
// Avoid reentrant initialization by taking the lock for the entire's function
// scope.
let mut gen_guard = self.inner.crypto_store_generation.lock().await;
let prev_generation =
self.inner.store.get_custom_value(Self::CURRENT_GENERATION_STORE_KEY).await?;
let gen = match prev_generation {
Some(val) => {
// There was a value in the store. We need to signal that we're a different
// process, so we don't just reuse the value but increment it.
u64::from_le_bytes(
val.try_into().map_err(|_| LockStoreError::InvalidGenerationFormat)?,
)
.wrapping_add(1)
}
None => 0,
};
self.inner
.store
.set_custom_value(Self::CURRENT_GENERATION_STORE_KEY, gen.to_le_bytes().to_vec())
.await?;
*gen_guard = Some(gen);
Ok(())
}
/// If needs be, update the local and on-disk crypto store generation.
///
/// Returns true whether another user has modified the internal generation
/// counter, and as such we've incremented and updated it in the
/// database.
///
/// ## Requirements
///
/// - This assumes that `initialize_crypto_store_generation` has been called
/// beforehand.
/// - This requires that the crypto store lock has been acquired.
pub async fn maintain_crypto_store_generation(&self) -> StoreResult<bool> {
let mut gen_guard = self.inner.crypto_store_generation.lock().await;
// The database value must be there:
// - either we could initialize beforehand, thus write into the database,
// - or we couldn't, and then another process was holding onto the database's
// lock, thus
// has written a generation counter in there.
let actual_gen = self
.inner
.store
.get_custom_value(Self::CURRENT_GENERATION_STORE_KEY)
.await?
.ok_or(LockStoreError::MissingGeneration)?;
let actual_gen = u64::from_le_bytes(
actual_gen.try_into().map_err(|_| LockStoreError::InvalidGenerationFormat)?,
);
let expected_gen = match gen_guard.as_ref() {
Some(expected_gen) => {
if actual_gen == *expected_gen {
return Ok(false);
}
// Increment the biggest, and store it everywhere.
actual_gen.max(*expected_gen).wrapping_add(1)
}
None => {
// Some other process hold onto the lock when initializing, so we must reload.
// Increment database value, and store it everywhere.
actual_gen.wrapping_add(1)
}
};
tracing::debug!(
"Crypto store generation mismatch: previously known was {:?}, actual is {:?}, next is {}",
*gen_guard,
actual_gen,
expected_gen
);
// Update known value.
*gen_guard = Some(expected_gen);
// Update value in database.
self.inner
.store
.set_custom_value(
Self::CURRENT_GENERATION_STORE_KEY,
expected_gen.to_le_bytes().to_vec(),
)
.await?;
Ok(true)
}
}
#[cfg(any(feature = "testing", test))]

View File

@@ -314,6 +314,15 @@ pub enum LockStoreError {
/// Spent too long waiting for a database lock.
#[error("a lock timed out")]
LockTimeout,
/// The generation counter is missing, and should always be present.
#[error("missing generation counter in the store")]
MissingGeneration,
/// Unexpected format for the generation counter. Is someone tampering the
/// database?
#[error("invalid format of the generation counter")]
InvalidGenerationFormat,
}
#[cfg(test)]

View File

@@ -213,19 +213,6 @@ impl EncryptionSync {
Ok(())
}
/// Request a reload of the internal caches used by this sync.
///
/// This must be called every time the process running this loop was
/// suspended and got back into the foreground, and another process may have
/// written to the same underlying store (e.g. notification process vs
/// main process).
pub async fn reload_caches(&self) -> Result<(), Error> {
// Regenerate the crypto store caches first.
self.client.encryption().reload_caches().await.map_err(Error::ClientError)?;
Ok(())
}
}
/// Errors for the [`EncryptionSync`].

View File

@@ -864,6 +864,17 @@ impl Encryption {
let lock =
olm_machine.store().create_store_lock("cross_process_lock".to_owned(), lock_value);
// Gently try to initialize the crypto store generation counter.
//
// If we don't get the lock immediately, then it is already acquired by another
// process, and we'll get to reload next time we acquire the lock.
{
let guard = lock.try_lock_once().await?;
if guard.is_some() {
olm_machine.initialize_crypto_store_generation().await?;
}
}
self.client
.inner
.cross_process_crypto_store_lock
@@ -873,6 +884,22 @@ impl Encryption {
Ok(())
}
/// Maybe reload the `OlmMachine` after acquiring the lock for the first
/// time.
async fn on_lock_newly_acquired(&self) -> Result<(), Error> {
let olm_machine_guard = self.client.olm_machine().await;
if let Some(olm_machine) = olm_machine_guard.as_ref() {
// If the crypto store generation has changed,
if olm_machine.maintain_crypto_store_generation().await? {
// (get rid of the reference to the current crypto store first)
drop(olm_machine_guard);
// Recreate the OlmMachine.
self.client.base_client().regenerate_olm().await?;
}
}
Ok(())
}
/// If a lock was created with [`Self::enable_cross_process_store_lock`],
/// spin-waits until the lock is available.
///
@@ -883,20 +910,11 @@ impl Encryption {
max_backoff: Option<u32>,
) -> Result<Option<CryptoStoreLockGuard>, Error> {
if let Some(lock) = self.client.inner.cross_process_crypto_store_lock.get() {
match lock.try_lock_once().await? {
Some(guard) => Ok(Some(guard)),
None => {
// We didn't get the lock on the first attempt, so that means that another
// process is using it. Wait for it to release it.
let guard = lock.spin_lock(max_backoff).await?;
let guard = lock.spin_lock(max_backoff).await?;
// As we didn't get the lock on the first attempt, force-reload all the crypto
// state caches at once, by recreating the OlmMachine from scratch.
self.client.base_client().regenerate_olm().await?;
self.on_lock_newly_acquired().await?;
Ok(Some(guard))
}
}
Ok(Some(guard))
} else {
Ok(None)
}
@@ -908,19 +926,17 @@ impl Encryption {
/// Returns a guard to the lock, if it was obtained.
pub async fn try_lock_store_once(&self) -> Result<Option<CryptoStoreLockGuard>, Error> {
if let Some(lock) = self.client.inner.cross_process_crypto_store_lock.get() {
Ok(lock.try_lock_once().await?)
let maybe_guard = lock.try_lock_once().await?;
if maybe_guard.is_some() {
self.on_lock_newly_acquired().await?;
}
Ok(maybe_guard)
} else {
Ok(None)
}
}
/// Manually request that the internal crypto caches be reloaded.
pub async fn reload_caches(&self) -> Result<(), Error> {
// At this time, rleoading the `OlmMachine` ought to be sufficient.
self.client.base_client().regenerate_olm().await?;
Ok(())
}
}
#[cfg(all(test, not(target_arch = "wasm32")))]