diff --git a/crates/matrix-sdk-sqlite/src/event_cache_store.rs b/crates/matrix-sdk-sqlite/src/event_cache_store.rs index 0bcc04104..79b46c808 100644 --- a/crates/matrix-sdk-sqlite/src/event_cache_store.rs +++ b/crates/matrix-sdk-sqlite/src/event_cache_store.rs @@ -43,7 +43,10 @@ use ruma::{ OwnedEventId, RoomId, }; use rusqlite::{params_from_iter, OptionalExtension, ToSql, Transaction, TransactionBehavior}; -use tokio::fs; +use tokio::{ + fs, + sync::{Mutex, OwnedMutexGuard}, +}; use tracing::{debug, error, trace}; use crate::{ @@ -86,7 +89,16 @@ const CHUNK_TYPE_GAP_TYPE_STRING: &str = "G"; #[derive(Clone)] pub struct SqliteEventCacheStore { store_cipher: Option>, + + /// The pool of connections. pool: SqlitePool, + + /// We make the difference between connections for read operations, and for + /// write operations. We keep a single connection apart from write + /// operations. All other connections are used for read operations. The + /// lock is used to ensure there is one owner at a time. + write_connection: Arc>, + media_service: MediaService, } @@ -125,7 +137,7 @@ impl SqliteEventCacheStore { let pool = config.create_pool(Runtime::Tokio1)?; let this = Self::open_with_pool(pool, passphrase.as_deref()).await?; - this.pool.get().await?.apply_runtime_config(runtime_config).await?; + this.write().await?.apply_runtime_config(runtime_config).await?; Ok(this) } @@ -151,10 +163,17 @@ impl SqliteEventCacheStore { let last_media_cleanup_time = conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await?; media_service.restore(media_retention_policy, last_media_cleanup_time); - Ok(Self { store_cipher, pool, media_service }) + Ok(Self { + store_cipher, + pool, + // Use `conn` as our selected write connections. + write_connection: Arc::new(Mutex::new(conn)), + media_service, + }) } - async fn acquire(&self) -> Result { + // Acquire a connection for executing read operations. + async fn read(&self) -> Result { let connection = self.pool.get().await?; // Per https://www.sqlite.org/foreignkeys.html#fk_enable, foreign key @@ -166,6 +185,19 @@ impl SqliteEventCacheStore { Ok(connection) } + // Acquire a connection for executing write operations. + async fn write(&self) -> Result> { + let connection = self.write_connection.clone().lock_owned().await; + + // Per https://www.sqlite.org/foreignkeys.html#fk_enable, foreign key + // support must be enabled on a per-connection basis. Execute it every + // time we try to get a connection, since we can't guarantee a previous + // connection did enable it before. + connection.execute_batch("PRAGMA foreign_keys = ON;").await?; + + Ok(connection) + } + fn map_row_to_chunk( row: &rusqlite::Row<'_>, ) -> Result<(u64, Option, Option, String), rusqlite::Error> { @@ -425,7 +457,7 @@ impl EventCacheStore for SqliteEventCacheStore { let expiration = now + lease_duration_ms as u64; let num_touched = self - .acquire() + .write() .await? .with_transaction(move |txn| { txn.execute( @@ -457,7 +489,7 @@ impl EventCacheStore for SqliteEventCacheStore { let linked_chunk_id = linked_chunk_id.to_owned(); let this = self.clone(); - with_immediate_transaction(self.acquire().await?, move |txn| { + with_immediate_transaction(self, move |txn| { for up in updates { match up { Update::NewItemsChunk { previous, new, next } => { @@ -783,7 +815,7 @@ impl EventCacheStore for SqliteEventCacheStore { let this = self.clone(); let result = self - .acquire() + .read() .await? .with_transaction(move |txn| -> Result<_> { let mut items = Vec::new(); @@ -821,7 +853,7 @@ impl EventCacheStore for SqliteEventCacheStore { let hashed_linked_chunk_id = self.encode_key(keys::LINKED_CHUNKS, linked_chunk_id.storage_key()); - self.acquire() + self.read() .await? .with_transaction(move |txn| -> Result<_> { // I'm not a DB analyst, so for my own future sanity: this query joins the @@ -884,7 +916,7 @@ impl EventCacheStore for SqliteEventCacheStore { let this = self.clone(); self - .acquire() + .read() .await? .with_transaction(move |txn| -> Result<_> { // Find the latest chunk identifier to generate a `ChunkIdentifierGenerator`, and count the number of chunks. @@ -977,7 +1009,7 @@ impl EventCacheStore for SqliteEventCacheStore { let this = self.clone(); self - .acquire() + .read() .await? .with_transaction(move |txn| -> Result<_> { // Find the chunk before the chunk identified by `before_chunk_identifier`. @@ -1018,7 +1050,7 @@ impl EventCacheStore for SqliteEventCacheStore { } async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> { - self.acquire() + self.write() .await? .with_transaction(move |txn| { // Remove all the chunks, and let cascading do its job. @@ -1047,7 +1079,7 @@ impl EventCacheStore for SqliteEventCacheStore { self.encode_key(keys::LINKED_CHUNKS, linked_chunk_id.storage_key()); let linked_chunk_id = linked_chunk_id.to_owned(); - self.acquire() + self.read() .await? .with_transaction(move |txn| -> Result<_> { txn.chunk_large_query_over(events, None, move |txn, events| { @@ -1119,7 +1151,7 @@ impl EventCacheStore for SqliteEventCacheStore { let hashed_room_id = self.encode_key(keys::LINKED_CHUNKS, room_id); - self.acquire() + self.read() .await? .with_transaction(move |txn| -> Result<_> { let Some(event) = txn @@ -1153,7 +1185,7 @@ impl EventCacheStore for SqliteEventCacheStore { let filters = filters.map(ToOwned::to_owned); let store = self.clone(); - self.acquire() + self.read() .await? .with_transaction(move |txn| -> Result<_> { find_event_relations_transaction( @@ -1178,7 +1210,7 @@ impl EventCacheStore for SqliteEventCacheStore { let event_id = event_id.to_string(); let encoded_event = self.encode_event(&event)?; - self.acquire() + self.write() .await? .with_transaction(move |txn| -> Result<_> { txn.execute( @@ -1210,7 +1242,7 @@ impl EventCacheStore for SqliteEventCacheStore { let new_uri = self.encode_key(keys::MEDIA, to.source.unique_key()); let new_format = self.encode_key(keys::MEDIA, to.format.unique_key()); - let conn = self.acquire().await?; + let conn = self.write().await?; conn.execute( r#"UPDATE media SET uri = ?, format = ? WHERE uri = ? AND format = ?"#, (new_uri, new_format, prev_uri, prev_format), @@ -1228,7 +1260,7 @@ impl EventCacheStore for SqliteEventCacheStore { let uri = self.encode_key(keys::MEDIA, request.source.unique_key()); let format = self.encode_key(keys::MEDIA, request.format.unique_key()); - let conn = self.acquire().await?; + let conn = self.write().await?; conn.execute("DELETE FROM media WHERE uri = ? AND format = ?", (uri, format)).await?; Ok(()) @@ -1244,7 +1276,7 @@ impl EventCacheStore for SqliteEventCacheStore { async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { let uri = self.encode_key(keys::MEDIA, uri); - let conn = self.acquire().await?; + let conn = self.write().await?; conn.execute("DELETE FROM media WHERE uri = ?", (uri,)).await?; Ok(()) @@ -1282,7 +1314,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore { async fn media_retention_policy_inner( &self, ) -> Result, Self::Error> { - let conn = self.acquire().await?; + let conn = self.read().await?; conn.get_serialized_kv(keys::MEDIA_RETENTION_POLICY).await } @@ -1290,7 +1322,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore { &self, policy: MediaRetentionPolicy, ) -> Result<(), Self::Error> { - let conn = self.acquire().await?; + let conn = self.write().await?; conn.set_serialized_kv(keys::MEDIA_RETENTION_POLICY, policy).await?; Ok(()) } @@ -1314,7 +1346,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore { let format = self.encode_key(keys::MEDIA, request.format.unique_key()); let timestamp = time_to_timestamp(last_access); - let conn = self.acquire().await?; + let conn = self.write().await?; conn.execute( "INSERT OR REPLACE INTO media (uri, format, data, last_access, ignore_policy) VALUES (?, ?, ?, ?, ?)", (uri, format, data, timestamp, ignore_policy), @@ -1333,7 +1365,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore { let format = self.encode_key(keys::MEDIA, request.format.unique_key()); let ignore_policy = ignore_policy.is_yes(); - let conn = self.acquire().await?; + let conn = self.write().await?; conn.execute( r#"UPDATE media SET ignore_policy = ? WHERE uri = ? AND format = ?"#, (ignore_policy, uri, format), @@ -1352,7 +1384,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore { let format = self.encode_key(keys::MEDIA, request.format.unique_key()); let timestamp = time_to_timestamp(current_time); - let conn = self.acquire().await?; + let conn = self.write().await?; let data = conn .with_transaction::<_, rusqlite::Error, _>(move |txn| { // Update the last access. @@ -1383,7 +1415,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore { let uri = self.encode_key(keys::MEDIA, uri); let timestamp = time_to_timestamp(current_time); - let conn = self.acquire().await?; + let conn = self.write().await?; let data = conn .with_transaction::<_, rusqlite::Error, _>(move |txn| { // Update the last access. @@ -1413,7 +1445,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore { return Ok(()); } - let conn = self.acquire().await?; + let conn = self.write().await?; let removed = conn .with_transaction::<_, Error, _>(move |txn| { let mut removed = false; @@ -1532,7 +1564,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore { } async fn last_media_cleanup_time_inner(&self) -> Result, Self::Error> { - let conn = self.acquire().await?; + let conn = self.read().await?; conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await } } @@ -1630,33 +1662,35 @@ async fn with_immediate_transaction< T: Send + 'static, F: FnOnce(&Transaction<'_>) -> Result + Send + 'static, >( - conn: SqliteAsyncConn, + this: &SqliteEventCacheStore, f: F, ) -> Result { - conn.interact(move |conn| -> Result { - // Start the transaction in IMMEDIATE mode since all updates may cause writes, - // to avoid read transactions upgrading to write mode and causing - // SQLITE_BUSY errors. See also: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions - conn.set_transaction_behavior(TransactionBehavior::Immediate); + this.write() + .await? + .interact(move |conn| -> Result { + // Start the transaction in IMMEDIATE mode since all updates may cause writes, + // to avoid read transactions upgrading to write mode and causing + // SQLITE_BUSY errors. See also: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions + conn.set_transaction_behavior(TransactionBehavior::Immediate); - let code = || -> Result { - let txn = conn.transaction()?; - let res = f(&txn)?; - txn.commit()?; - Ok(res) - }; + let code = || -> Result { + let txn = conn.transaction()?; + let res = f(&txn)?; + txn.commit()?; + Ok(res) + }; - let res = code(); + let res = code(); - // Reset the transaction behavior to use Deferred, after this transaction has - // been run, whether it was successful or not. - conn.set_transaction_behavior(TransactionBehavior::Deferred); + // Reset the transaction behavior to use Deferred, after this transaction has + // been run, whether it was successful or not. + conn.set_transaction_behavior(TransactionBehavior::Deferred); - res - }) - .await - // SAFETY: same logic as in [`deadpool::managed::Object::with_transaction`].` - .unwrap() + res + }) + .await + // SAFETY: same logic as in [`deadpool::managed::Object::with_transaction`].` + .unwrap() } fn insert_chunk( @@ -1763,7 +1797,7 @@ mod tests { async fn get_event_cache_store_content_sorted_by_last_access( event_cache_store: &SqliteEventCacheStore, ) -> Vec> { - let sqlite_db = event_cache_store.acquire().await.expect("accessing sqlite db failed"); + let sqlite_db = event_cache_store.read().await.expect("accessing sqlite db failed"); sqlite_db .prepare("SELECT data FROM media ORDER BY last_access DESC", |mut stmt| { stmt.query(())?.mapped(|row| row.get(0)).collect() @@ -2053,7 +2087,7 @@ mod tests { // Check that cascading worked. Yes, SQLite, I doubt you. let gaps = store - .acquire() + .read() .await .unwrap() .with_transaction(|txn| -> rusqlite::Result<_> { @@ -2175,7 +2209,7 @@ mod tests { // Make sure the position have been updated for the remaining events. let num_rows: u64 = store - .acquire() + .read() .await .unwrap() .with_transaction(move |txn| { @@ -2324,7 +2358,7 @@ mod tests { // Check that cascading worked. Yes, SQLite, I doubt you. store - .acquire() + .read() .await .unwrap() .with_transaction(|txn| -> rusqlite::Result<_> {