diff --git a/crates/matrix-sdk-sled/src/crypto_store.rs b/crates/matrix-sdk-sled/src/crypto_store.rs index 6d3175bf9..d3ee6183f 100644 --- a/crates/matrix-sdk-sled/src/crypto_store.rs +++ b/crates/matrix-sdk-sled/src/crypto_store.rs @@ -16,7 +16,10 @@ use std::{ borrow::Cow, collections::{HashMap, HashSet}, path::{Path, PathBuf}, - sync::{Arc, RwLock}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, RwLock, + }, }; use async_trait::async_trait; @@ -169,6 +172,9 @@ struct TrackedUser { #[derive(Clone)] pub struct SledCryptoStore { account_info: Arc>>, + tracked_user_loading_lock: Arc>, + tracked_users_loaded: Arc, + store_cipher: Option>, path: Option, inner: Db, @@ -398,6 +404,8 @@ impl SledCryptoStore { let database = Self { account_info: RwLock::new(None).into(), + tracked_user_loading_lock: Mutex::new(()).into(), + tracked_users_loaded: AtomicBool::new(false).into(), path, inner: db, store_cipher, @@ -424,14 +432,26 @@ impl SledCryptoStore { } async fn load_tracked_users(&self) -> Result<()> { - for value in &self.tracked_users { - let (_, user) = value.map_err(CryptoStoreError::backend)?; - let user: TrackedUser = self.deserialize_value(&user)?; + // If the users are loaded do nothing, otherwise acquire a lock. + if !self.tracked_users_loaded.load(Ordering::SeqCst) { + let _lock = self.tracked_user_loading_lock.lock().await; - self.tracked_users_cache.insert(user.user_id.to_owned()); + // Check again if the users have been loaded, in case another call to this + // method loaded the tracked users between the time we tried to + // acquire the lock and the time we actually acquired the lock. + if !self.tracked_users_loaded.load(Ordering::SeqCst) { + for value in &self.tracked_users { + let (_, user) = value.map_err(CryptoStoreError::backend)?; + let user: TrackedUser = self.deserialize_value(&user)?; - if user.dirty { - self.users_for_key_query_cache.insert(user.user_id); + self.tracked_users_cache.insert(user.user_id.to_owned()); + + if user.dirty { + self.users_for_key_query_cache.insert(user.user_id); + } + } + + self.tracked_users_loaded.store(true, Ordering::SeqCst); } } @@ -715,7 +735,6 @@ impl CryptoStore for SledCryptoStore { { let pickle = self.deserialize_value(&pickle)?; - self.load_tracked_users().await?; let account = ReadOnlyAccount::from_pickle(pickle)?; let account_info = AccountInfo { @@ -864,22 +883,32 @@ impl CryptoStore for SledCryptoStore { } async fn is_user_tracked(&self, user_id: &UserId) -> Result { + self.load_tracked_users().await?; + Ok(self.tracked_users_cache.contains(user_id)) } async fn has_users_for_key_query(&self) -> Result { + self.load_tracked_users().await?; + Ok(!self.users_for_key_query_cache.is_empty()) } async fn users_for_key_query(&self) -> Result> { + self.load_tracked_users().await?; + Ok(self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()) } async fn tracked_users(&self) -> Result> { + self.load_tracked_users().await?; + Ok(self.tracked_users_cache.to_owned().iter().map(|u| u.clone()).collect()) } async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result { + self.load_tracked_users().await?; + let already_added = self.tracked_users_cache.insert(user.to_owned()); if dirty {