diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index e3dfb50be..baa8bf988 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -41,7 +41,10 @@ use ruma::{ events::secret::request::SecretName, }; use rusqlite::{OptionalExtension, named_params, params_from_iter}; -use tokio::{fs, sync::Mutex}; +use tokio::{ + fs, + sync::{Mutex, OwnedMutexGuard}, +}; use tracing::{debug, instrument, warn}; use vodozemac::Curve25519PublicKey; @@ -62,8 +65,16 @@ const DATABASE_NAME: &str = "matrix-sdk-crypto.sqlite3"; #[derive(Clone)] pub struct SqliteCryptoStore { store_cipher: Option>, + + /// The pool of connections. pool: SqlitePool, + /// We make the difference between connections for read operations, and for + /// write operations. We keep a single connection apart from write + /// operations. All other connections are used for read operations. The + /// lock is used to ensure there is one owner at a time. + write_connection: Arc>, + // DB values cached in memory static_account: Arc>>, save_changes_lock: Arc>, @@ -135,6 +146,8 @@ impl SqliteCryptoStore { Ok(SqliteCryptoStore { store_cipher, pool, + // Use `conn` as our selected write connections. + write_connection: Arc::new(Mutex::new(conn)), static_account: Arc::new(RwLock::new(None)), save_changes_lock: Default::default(), }) @@ -168,9 +181,17 @@ impl SqliteCryptoStore { self.static_account.read().unwrap().clone() } - async fn acquire(&self) -> Result { + /// Acquire a connection for executing read operations. + #[instrument(skip_all)] + async fn read(&self) -> Result { Ok(self.pool.get().await?) } + + /// Acquire a connection for executing write operations. + #[instrument(skip_all)] + async fn write(&self) -> OwnedMutexGuard { + self.write_connection.clone().lock_owned().await + } } const DATABASE_VERSION: u8 = 14; @@ -855,7 +876,7 @@ impl CryptoStore for SqliteCryptoStore { type Error = Error; async fn load_account(&self) -> Result> { - let conn = self.acquire().await?; + let conn = self.read().await?; if let Some(pickle) = conn.get_kv("account").await? { let pickle = self.deserialize_value(&pickle)?; @@ -870,7 +891,7 @@ impl CryptoStore for SqliteCryptoStore { } async fn load_identity(&self) -> Result> { - let conn = self.acquire().await?; + let conn = self.read().await?; if let Some(i) = conn.get_kv("identity").await? { let pickle = self.deserialize_value(&i)?; Ok(Some(PrivateCrossSigningIdentity::from_pickle(pickle).map_err(|_| Error::Unpickle)?)) @@ -894,8 +915,8 @@ impl CryptoStore for SqliteCryptoStore { }; let this = self.clone(); - self.acquire() - .await? + self.write() + .await .with_transaction(move |txn| { if let Some(pickled_account) = pickled_account { let serialized_account = this.serialize_value(&pickled_account)?; @@ -946,8 +967,8 @@ impl CryptoStore for SqliteCryptoStore { } let this = self.clone(); - self.acquire() - .await? + self.write() + .await .with_transaction(move |txn| { if let Some(pickled_private_identity) = &pickled_private_identity { let serialized_private_identity = @@ -1094,7 +1115,7 @@ impl CryptoStore for SqliteCryptoStore { let device_keys = self.get_own_device().await?.as_device_keys().clone(); let sessions: Vec<_> = self - .acquire() + .read() .await? .get_sessions_for_sender_key(self.encode_key("session", sender_key.as_bytes())) .await? @@ -1116,7 +1137,7 @@ impl CryptoStore for SqliteCryptoStore { ) -> Result> { let session_id = self.encode_key("inbound_group_session", session_id); let Some((room_id_from_db, value, backed_up)) = - self.acquire().await?.get_inbound_group_session(session_id).await? + self.read().await?.get_inbound_group_session(session_id).await? else { return Ok(None); }; @@ -1131,7 +1152,7 @@ impl CryptoStore for SqliteCryptoStore { } async fn get_inbound_group_sessions(&self) -> Result> { - self.acquire() + self.read() .await? .get_inbound_group_sessions() .await? @@ -1147,7 +1168,7 @@ impl CryptoStore for SqliteCryptoStore { room_id: &RoomId, ) -> Result> { let room_id = self.encode_key("inbound_group_session", room_id.as_bytes()); - self.acquire() + self.read() .await? .get_inbound_group_sessions_by_room_id(room_id) .await? @@ -1169,7 +1190,7 @@ impl CryptoStore for SqliteCryptoStore { after_session_id.map(|session_id| self.encode_key("inbound_group_session", session_id)); let sender_key = self.encode_key("inbound_group_session", sender_key.to_base64()); - self.acquire() + self.read() .await? .get_inbound_group_sessions_for_device_batch( sender_key, @@ -1189,7 +1210,7 @@ impl CryptoStore for SqliteCryptoStore { &self, backup_version: Option<&str>, ) -> Result { - Ok(self.acquire().await?.get_inbound_group_session_counts(backup_version).await?) + Ok(self.read().await?.get_inbound_group_session_counts(backup_version).await?) } async fn inbound_group_sessions_for_backup( @@ -1197,7 +1218,7 @@ impl CryptoStore for SqliteCryptoStore { _backup_version: &str, limit: usize, ) -> Result> { - self.acquire() + self.read() .await? .get_inbound_group_sessions_for_backup(limit) .await? @@ -1212,8 +1233,8 @@ impl CryptoStore for SqliteCryptoStore { session_ids: &[(&RoomId, &str)], ) -> Result<()> { Ok(self - .acquire() - .await? + .write() + .await .mark_inbound_group_sessions_as_backed_up( session_ids .iter() @@ -1224,11 +1245,11 @@ impl CryptoStore for SqliteCryptoStore { } async fn reset_backup_state(&self) -> Result<()> { - Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?) + Ok(self.write().await.reset_inbound_group_session_backup_state().await?) } async fn load_backup_keys(&self) -> Result { - let conn = self.acquire().await?; + let conn = self.read().await?; let backup_version = conn .get_kv("backup_version_v1") @@ -1246,7 +1267,7 @@ impl CryptoStore for SqliteCryptoStore { } async fn load_dehydrated_device_pickle_key(&self) -> Result> { - let conn = self.acquire().await?; + let conn = self.read().await?; conn.get_kv(DEHYDRATED_DEVICE_PICKLE_KEY) .await? @@ -1255,17 +1276,14 @@ impl CryptoStore for SqliteCryptoStore { } async fn delete_dehydrated_device_pickle_key(&self) -> Result<(), Self::Error> { - let conn = self.acquire().await?; - conn.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?; - - Ok(()) + Ok(self.write().await.clear_kv(DEHYDRATED_DEVICE_PICKLE_KEY).await?) } async fn get_outbound_group_session( &self, room_id: &RoomId, ) -> Result> { let room_id = self.encode_key("outbound_group_session", room_id.as_bytes()); - let Some(value) = self.acquire().await?.get_outbound_group_session(room_id).await? else { + let Some(value) = self.read().await?.get_outbound_group_session(room_id).await? else { return Ok(None); }; @@ -1283,7 +1301,7 @@ impl CryptoStore for SqliteCryptoStore { } async fn load_tracked_users(&self) -> Result> { - self.acquire() + self.read() .await? .get_tracked_users() .await? @@ -1303,7 +1321,7 @@ impl CryptoStore for SqliteCryptoStore { }) .collect::>()?; - Ok(self.acquire().await?.add_tracked_users(users).await?) + Ok(self.write().await.add_tracked_users(users).await?) } async fn get_device( @@ -1314,7 +1332,7 @@ impl CryptoStore for SqliteCryptoStore { let user_id = self.encode_key("device", user_id.as_bytes()); let device_id = self.encode_key("device", device_id.as_bytes()); Ok(self - .acquire() + .read() .await? .get_device(user_id, device_id) .await? @@ -1327,7 +1345,7 @@ impl CryptoStore for SqliteCryptoStore { user_id: &UserId, ) -> Result> { let user_id = self.encode_key("device", user_id.as_bytes()); - self.acquire() + self.read() .await? .get_user_devices(user_id) .await? @@ -1351,7 +1369,7 @@ impl CryptoStore for SqliteCryptoStore { async fn get_user_identity(&self, user_id: &UserId) -> Result> { let user_id = self.encode_key("identity", user_id.as_bytes()); Ok(self - .acquire() + .read() .await? .get_user_identity(user_id) .await? @@ -1364,7 +1382,7 @@ impl CryptoStore for SqliteCryptoStore { message_hash: &matrix_sdk_crypto::olm::OlmMessageHash, ) -> Result { let value = rmp_serde::to_vec(message_hash)?; - Ok(self.acquire().await?.has_olm_hash(value).await?) + Ok(self.read().await?.has_olm_hash(value).await?) } async fn get_outgoing_secret_requests( @@ -1373,7 +1391,7 @@ impl CryptoStore for SqliteCryptoStore { ) -> Result> { let request_id = self.encode_key("key_requests", request_id.as_bytes()); Ok(self - .acquire() + .read() .await? .get_outgoing_secret_request(request_id) .await? @@ -1385,7 +1403,7 @@ impl CryptoStore for SqliteCryptoStore { &self, key_info: &SecretInfo, ) -> Result> { - let requests = self.acquire().await?.get_outgoing_secret_requests().await?; + let requests = self.read().await?.get_outgoing_secret_requests().await?; for (request, sent_out) in requests { let request = self.deserialize_key_request(&request, sent_out)?; if request.info == *key_info { @@ -1396,7 +1414,7 @@ impl CryptoStore for SqliteCryptoStore { } async fn get_unsent_secret_requests(&self) -> Result> { - self.acquire() + self.read() .await? .get_unsent_secret_requests() .await? @@ -1410,7 +1428,7 @@ impl CryptoStore for SqliteCryptoStore { async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> { let request_id = self.encode_key("key_requests", request_id.as_bytes()); - Ok(self.acquire().await?.delete_key_request(request_id).await?) + Ok(self.write().await.delete_key_request(request_id).await?) } async fn get_secrets_from_inbox( @@ -1419,7 +1437,7 @@ impl CryptoStore for SqliteCryptoStore { ) -> Result> { let secret_name = self.encode_key("secrets", secret_name.to_string()); - self.acquire() + self.read() .await? .get_secrets_from_inbox(secret_name) .await? @@ -1430,7 +1448,7 @@ impl CryptoStore for SqliteCryptoStore { async fn delete_secrets_from_inbox(&self, secret_name: &SecretName) -> Result<()> { let secret_name = self.encode_key("secrets", secret_name.to_string()); - self.acquire().await?.delete_secrets_from_inbox(secret_name).await + self.write().await.delete_secrets_from_inbox(secret_name).await } async fn get_withheld_info( @@ -1441,7 +1459,7 @@ impl CryptoStore for SqliteCryptoStore { let room_id = self.encode_key("direct_withheld_info", room_id); let session_id = self.encode_key("direct_withheld_info", session_id); - self.acquire() + self.read() .await? .get_direct_withheld_info(session_id, room_id) .await? @@ -1458,7 +1476,7 @@ impl CryptoStore for SqliteCryptoStore { ) -> matrix_sdk_crypto::store::Result, Self::Error> { let room_id = self.encode_key("direct_withheld_info", room_id); - self.acquire() + self.read() .await? .get_withheld_sessions_by_room_id(room_id) .await? @@ -1469,7 +1487,7 @@ impl CryptoStore for SqliteCryptoStore { async fn get_room_settings(&self, room_id: &RoomId) -> Result> { let room_id = self.encode_key("room_settings", room_id.as_bytes()); - let Some(value) = self.acquire().await?.get_room_settings(room_id).await? else { + let Some(value) = self.read().await?.get_room_settings(room_id).await? else { return Ok(None); }; @@ -1485,7 +1503,7 @@ impl CryptoStore for SqliteCryptoStore { ) -> Result> { let room_id = self.encode_key("received_room_key_bundle", room_id); let user_id = self.encode_key("received_room_key_bundle", user_id); - self.acquire() + self.read() .await? .get_received_room_key_bundle(room_id, user_id) .await? @@ -1495,11 +1513,11 @@ impl CryptoStore for SqliteCryptoStore { async fn has_downloaded_all_room_keys(&self, room_id: &RoomId) -> Result { let room_id = self.encode_key("room_key_backups_fully_downloaded", room_id); - self.acquire().await?.has_downloaded_all_room_keys(room_id).await + self.read().await?.has_downloaded_all_room_keys(room_id).await } async fn get_custom_value(&self, key: &str) -> Result>> { - let Some(serialized) = self.acquire().await?.get_kv(key).await? else { + let Some(serialized) = self.read().await?.get_kv(key).await? else { return Ok(None); }; let value = if let Some(cipher) = &self.store_cipher { @@ -1520,14 +1538,14 @@ impl CryptoStore for SqliteCryptoStore { value }; - self.acquire().await?.set_kv(key, serialized).await?; + self.write().await.set_kv(key, serialized).await?; Ok(()) } async fn remove_custom_value(&self, key: &str) -> Result<()> { let key = key.to_owned(); - self.acquire() - .await? + self.write() + .await .interact(move |conn| conn.execute("DELETE FROM kv WHERE key = ?1", (&key,))) .await .unwrap()?; @@ -1549,8 +1567,8 @@ impl CryptoStore for SqliteCryptoStore { // Learn about the `excluded` keyword in https://sqlite.org/lang_upsert.html. let generation = self - .acquire() - .await? + .write() + .await .with_transaction(move |txn| { txn.query_row( "INSERT INTO lease_locks (key, holder, expiration) @@ -1581,7 +1599,7 @@ impl CryptoStore for SqliteCryptoStore { } async fn next_batch_token(&self) -> Result, Self::Error> { - let conn = self.acquire().await?; + let conn = self.read().await?; if let Some(token) = conn.get_kv("next_batch_token").await? { let maybe_token: Option = self.deserialize_value(&token)?; Ok(maybe_token)